PyTorch Tensor’s requires_grad Attribute, detach and detach_ Methods

Jul. 18, 2024

requires_grad attribute

Let $y(x)=x^2+3x+2$. What is the value of $\dfrac{\mathrm{d}y}{\mathrm{d}x}\Big\vert_{x=3}$? We can compute it with PyTorch’s automatic differentiation:

import torch

x = torch.tensor(3.0, requires_grad=True)
y = x**2 + 3*x + 2  # y = x^2+3x+2

print("y:", y)

# dy/dx
y.backward()

print("x.grad:", x.grad)
y: tensor(20., grad_fn=<AddBackward0>)
x.grad: tensor(9.)

After calling y.backward(), the value of $\mathrm{d}y/\mathrm{d}x$ is stored in the grad attribute of x.[1]
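As a quick check by hand (a worked derivation, not PyTorch output), differentiating $y(x)=x^2+3x+2$ gives the same two numbers printed above:

$$
\frac{\mathrm{d}y}{\mathrm{d}x} = 2x+3, \qquad y(3) = 3^2 + 3\cdot 3 + 2 = 20, \qquad \frac{\mathrm{d}y}{\mathrm{d}x}\Big\vert_{x=3} = 2\cdot 3 + 3 = 9.
$$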

Note that in the above process, the requires_grad attribute must be set to True when creating x with torch.tensor. The PyTorch documentation describes requires_grad as follows:

torch.Tensor.requires_grad[2]

Is True if gradients need to be computed for this Tensor, False otherwise.

In fact, a tensor created by torch.tensor has requires_grad set to False by default. Taking the above case again, if we don’t set requires_grad to True, an error occurs:

import torch

x = torch.tensor(3.0)  # with default `requires_grad` value
y = x**2 + 3*x + 2  # y = x^2+3x+2

print("y:", y)

# dy/dx
y.backward()

print("x.grad:", x.grad)
Cell In[5], line 9
      6 print("y:", y)
      8 # dy/dx
----> 9 y.backward()
     11 print("x.grad:", x.grad)

...
...

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

So, in this case it is necessary to set requires_grad=True to ensure that the gradient computation, i.e. y.backward()[3], succeeds.
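If a tensor has already been created without requires_grad=True, one option is to switch the flag on in place with Tensor.requires_grad_() before computing y. A minimal sketch, reusing the same function as above; note that the flag has to be set before y is built, since operations that ran earlier were not recorded:

```python
import torch

# Created with the default requires_grad=False, as in the failing script.
x = torch.tensor(3.0)
print(x.requires_grad)  # False

# requires_grad_() sets the flag on x itself (x is a leaf tensor, so this is allowed).
x.requires_grad_()
print(x.requires_grad)  # True

# Operations from now on are recorded in the graph, so backward() works.
y = x**2 + 3*x + 2
y.backward()
print(x.grad)  # tensor(9.)
```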


detach and detach_ methods

In fact, we can see the difference in the requires_grad attribute of x between the first two scripts above simply by printing it:

import torch

x = torch.tensor(3.0)
print(x)

x = torch.tensor(3.0, requires_grad=True)
print(x)
tensor(3.)
tensor(3., requires_grad=True)
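Besides printing the tensor, the flag can be read directly from the requires_grad attribute. The sketch below also uses is_leaf (not mentioned above), which is True for tensors created directly by the user and False for tensors produced by an operation on a tensor that requires gradients; the names a, b and c are just for illustration:

```python
import torch

a = torch.tensor(3.0)                      # default: requires_grad=False
b = torch.tensor(3.0, requires_grad=True)  # gradient tracking enabled
c = b * 2                                  # computed from b

print(a.requires_grad)             # False
print(b.requires_grad, b.is_leaf)  # True True
print(c.requires_grad, c.is_leaf)  # True False
```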

Moreover, if a tensor’s requires_grad is True, then tensors computed from it also require gradients and carry a grad_fn attribute, which identifies the operation that produced them: in Case 1 below, y ends with an addition (AddBackward0) and z ends with a subtraction (SubBackward0). For example:

import torch

# Case 1
x = torch.tensor(3.0, requires_grad=True)
y = x**2 + 3*x + 2
z = (y + 1)/2 - 0.5
print(y)
print(z)

# Case 2
x = torch.tensor(3.0)
y = x**2 + 3*x + 2
z = (y + 1)/2 - 0.5
print(y)
print(z)
tensor(20., grad_fn=<AddBackward0>)
tensor(10., grad_fn=<SubBackward0>)
tensor(20.)
tensor(10.)

For the computation in Case 1, note that PyTorch constructs a computational graph[4] as the operations run, and automatic differentiation relies on this graph. Constructing the graph and keeping it involved in subsequent computations adds time and memory overhead, which, I think, is why requires_grad is False by default.
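The graph from Case 1 can be made visible by following the grad_fn links: each node has a next_functions attribute, a tuple of (node, index) pairs pointing at the nodes that produced its inputs. The helper graph_node_names below is a hypothetical name for a small breadth-first walk, a sketch rather than an official PyTorch utility:

```python
import torch


def graph_node_names(root):
    """Breadth-first walk over the autograd nodes reachable from `root`.

    Returns the class names of the visited nodes. Entries in
    next_functions can be None (e.g. for constants that need no
    gradient), so those are skipped.
    """
    names = []
    visited = set()
    queue = [root]
    while queue:
        node = queue.pop(0)
        if node is None or node in visited:
            continue
        visited.add(node)
        names.append(type(node).__name__)
        for next_node, _ in node.next_functions:
            queue.append(next_node)
    return names


x = torch.tensor(3.0, requires_grad=True)
y = x**2 + 3*x + 2
z = (y + 1)/2 - 0.5

names = graph_node_names(z.grad_fn)
# Starts at SubBackward0 and should include PowBackward0 (from x**2) and
# AccumulateGrad, the node that writes gradients into the leaf tensor x.
print(names)
print(x.grad_fn)  # None: x is a leaf, not the result of an operation
```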

Moreover, many built-in PyTorch functions and operations construct a computational graph automatically when their inputs require gradients. If we don’t plan to use the graph, we can detach a tensor from it with the detach or detach_ method:

torch.Tensor.detach[5]

Returns a new Tensor, detached from the current graph.

The result will never require gradient.

This method also affects forward mode AD gradients and the result will never have forward mode AD gradients.


torch.Tensor.detach_[6]

Detaches the Tensor from the graph that created it, making it a leaf. Views cannot be detached in-place.

This method also affects forward mode AD gradients and the result will never have forward mode AD gradients.

As can be seen, detach and detach_ serve similar purposes. The difference is that detach returns a new tensor that is detached from the computational graph while the original tensor remains in the graph, whereas detach_ detaches the original tensor itself from the graph, in place. For example:
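One detail worth checking: the tensor returned by detach is not an independent copy. The full PyTorch documentation for detach states that it shares storage with the original, so in-place changes through one are visible in the other. The sketch below checks this with data_ptr() (the address of a tensor’s first element) and also shows detach_ turning a tensor into a leaf in place; y2 is just an illustrative name:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x**2 + 3*x + 2

# detach(): a new tensor object with no graph history, same underlying memory.
z = y.detach()
print(z.data_ptr() == y.data_ptr())      # True: storage is shared
print(z.requires_grad, y.requires_grad)  # False True

z.add_(1)  # in-place change made through the detached tensor...
print(y)   # ...shows up in y: tensor(21., grad_fn=<AddBackward0>)

# detach_(): modifies y2 itself, so it becomes a leaf with no history.
y2 = x**2 + 3*x + 2
y2.detach_()
print(y2.requires_grad, y2.is_leaf, y2.grad_fn)  # False True None
```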

(1) detach method

import torch

x = torch.tensor(3.0, requires_grad=True)
print('x: ', x)

y = x**2 + 3*x + 2
print('y (before detach): ', y)

z = y.detach()
print('y (after detach): ', y)
print('z: ', z)

y.backward()
print('x.grad:', x.grad)

z.backward()
x:  tensor(3., requires_grad=True)
y (before detach):  tensor(20., grad_fn=<AddBackward0>)
y (after detach):  tensor(20., grad_fn=<AddBackward0>)
z:  tensor(20.)
x.grad: tensor(9.)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[34], line 16
     13 y.backward()
     14 print('x.grad:', x.grad)
---> 16 z.backward()

...
...

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The detach method does not remove y from the graph, so y.backward() still works. However, z.backward() fails, because z, the new tensor returned by y.detach(), has no grad_fn and does not require gradients.
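A typical reason to call detach is to treat an intermediate value as a constant inside a larger expression. In the sketch below (w_tracked, w_const and grad_tracked are illustrative names), multiplying x by y keeps the gradient path through y, while multiplying by y.detach() cuts it. With $y(3)=20$ and $\mathrm{d}y/\mathrm{d}x\vert_{x=3}=9$, the product rule gives $20 + 3\cdot 9 = 47$ for the first case and just $20$ for the second:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x**2 + 3*x + 2  # y(3) = 20, dy/dx at 3 = 9

# Gradient flows through both factors: d(y*x)/dx = y + x*dy/dx = 20 + 3*9 = 47
w_tracked = y * x
w_tracked.backward()
grad_tracked = x.grad.item()
print(grad_tracked)  # 47.0

x.grad = None  # reset, because backward() accumulates into x.grad

# y.detach() acts as the constant 20: d(20*x)/dx = 20
w_const = y.detach() * x
w_const.backward()
print(x.grad)  # tensor(20.)
```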

(2) detach_ method

import torch

x = torch.tensor(3.0, requires_grad=True)
print('x: ', x)

y = x**2 + 3*x + 2
print('y (before detach_): ', y)

z = y.detach_()
print('y (after detach_): ', y)
print('z: ', z)

y.backward()
x:  tensor(3., requires_grad=True)
y (before detach_):  tensor(20., grad_fn=<AddBackward0>)
y (after detach_):  tensor(20.)
z:  tensor(20.)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[37], line 13
     10 print('y (after detach_): ', y)
     11 print('z: ', z)
---> 13 y.backward()

...
...

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The detach_ method detaches y itself from the graph, so y.backward() is no longer available.
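The detach_ documentation also notes that views cannot be detached in place. A sketch of what that looks like, assuming a view obtained by indexing a tensor that is part of the graph (the exact error message may differ between PyTorch versions; a, b and v are illustrative names):

```python
import torch

a = torch.arange(3.0, requires_grad=True)
b = a * 2  # non-leaf tensor in the graph
v = b[0]   # integer indexing returns a view of b

failed = False
try:
    v.detach_()  # in-place detach of a view is rejected
except RuntimeError as err:
    failed = True
    print("detach_ failed:", err)

# The out-of-place detach() works on the same view.
v_detached = v.detach()
print(v_detached.requires_grad)  # False
```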


References