PyTorch torch.dot Function
Jul. 19, 2024
As the PyTorch documentation puts it, the torch.dot function computes the dot product of two 1D tensors [1]:
torch.dot:
Computes the dot product of two 1D tensors.
NOTE: Unlike NumPy’s dot, torch.dot intentionally only supports computing the dot product of two 1D tensors with the same number of elements.
Parameters
- input (Tensor) – first tensor in the dot product, must be 1D.
- other (Tensor) – second tensor in the dot product, must be 1D.
Keyword Arguments
- out (Tensor, optional) – the output tensor.
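Written out as a formula, the documented contract is the familiar definition below. Both inputs must be 1D and have the same length n; the symbols a, b and n are introduced here only for illustration.

```latex
\mathbf{a}, \mathbf{b} \in \mathbb{R}^{n}, \qquad
\operatorname{torch.dot}(\mathbf{a}, \mathbf{b}) = \sum_{i=1}^{n} a_i \, b_i
```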
We can see that the torch.dot function only works on 1D tensors. It accepts a pair of 1D tensors:
```python
import torch

a, b = torch.tensor([1, 2]), torch.tensor([3, 4])
c = torch.dot(a, b)
print(a, b, c)
```

```
tensor([1, 2]) tensor([3, 4]) tensor(11)
```
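The result 11 is simply 1·3 + 2·4. As a quick sanity check, here is a minimal sketch assuming a recent PyTorch release: the same value comes out of an elementwise product followed by a sum, and out of the @ operator, which for two 1D tensors also returns their dot product. Note that torch.dot returns a 0-dim tensor rather than a Python number (.item() converts it), and that the documented same-length requirement is enforced at runtime.

```python
import torch

a, b = torch.tensor([1, 2]), torch.tensor([3, 4])
c = torch.dot(a, b)

# torch.dot returns a 0-dim (scalar) tensor, not a Python int
print(c.dim(), c.item())  # 0 11

# Equivalent formulations for two 1D tensors of equal length
print(torch.equal(c, (a * b).sum()))  # True
print(torch.equal(c, a @ b))          # True

# Inputs with different numbers of elements are rejected
try:
    torch.dot(a, torch.tensor([3, 4, 5]))
except RuntimeError as err:
    print("RuntimeError:", err)
```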
but not column vectors (2D tensors):
```python
a, b = torch.tensor([[1], [2]]), torch.tensor([[3], [4]])
c = torch.dot(a, b)
print(a, b, c)
```

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 2
      1 a, b = torch.tensor([[1], [2]]), torch.tensor([[3], [4]])
----> 2 c = torch.dot(a, b)
      3 print(a, b, c)

RuntimeError: 1D tensors expected, but got 2D and 2D tensors
```
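One way around this, assuming the column vectors are really meant as plain vectors, is to flatten them to 1D before calling torch.dot. In the sketch below the choice of flatten() is ours (view(-1) or squeeze(1) would also work for these shapes). By contrast, multiplying the transposed column by the other column with @ gives a 1×1 2D tensor, not a scalar.

```python
import torch

# Column vectors: shape (2, 1)
a, b = torch.tensor([[1], [2]]), torch.tensor([[3], [4]])

# Flatten to 1D tensors of shape (2,) so torch.dot accepts them
c = torch.dot(a.flatten(), b.flatten())
print(c)  # tensor(11)

# A matrix product (1, 2) @ (2, 1) keeps two dimensions
m = a.t() @ b
print(m, m.shape)  # tensor([[11]]) torch.Size([1, 1])
```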
or row vectors (2D tensors):
```python
a, b = torch.tensor([[1, 2]]), torch.tensor([[3, 4]])
c = torch.dot(a, b)
print(a, b, c)
```

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 2
      1 a, b = torch.tensor([[1, 2]]), torch.tensor([[3, 4]])
----> 2 c = torch.dot(a, b)
      3 print(a, b, c)

RuntimeError: 1D tensors expected, but got 2D and 2D tensors
```
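Row vectors can be handled in a similar way. The sketch below, assuming shape (1, 2) inputs as in the example, drops the leading size-1 dimension with squeeze(0) before calling torch.dot. Alternatively, (a * b).sum() skips the reshape entirely; it gives the dot product whenever both tensors have the same shape, since it sums over all elements.

```python
import torch

# Row vectors: shape (1, 2)
a, b = torch.tensor([[1, 2]]), torch.tensor([[3, 4]])

# Remove the leading size-1 dimension to get 1D tensors of shape (2,)
c = torch.dot(a.squeeze(0), b.squeeze(0))
print(c)  # tensor(11)

# Elementwise product plus a full sum; no reshaping needed when shapes match
s = (a * b).sum()
print(s)  # tensor(11)
```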
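The difference between 1D and 2D tensors should be noted [2]: a 1D tensor of shape (2,) is neither a column vector of shape (2, 1) nor a row vector of shape (1, 2), even though all three hold the same two numbers, and torch.dot accepts only the 1D form.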
To put it bluntly, this is kind of weird 😂
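To make the 1D versus 2D distinction concrete, the sketch below prints the shape and number of dimensions of the three kinds of tensor used in the examples. They hold the same two numbers and reshape converts between them, but only the first has a single dimension, which is what torch.dot checks for.

```python
import torch

vec = torch.tensor([1, 2])      # 1D tensor, shape (2,)
col = torch.tensor([[1], [2]])  # column vector, shape (2, 1)
row = torch.tensor([[1, 2]])    # row vector, shape (1, 2)

for name, t in [("vec", vec), ("col", col), ("row", row)]:
    print(name, tuple(t.shape), t.dim())
# vec (2,) 1
# col (2, 1) 2
# row (1, 2) 2

# Same elements, different shapes: reshape moves between the forms
print(torch.equal(col.reshape(-1), vec))     # True
print(torch.equal(vec.reshape(1, -1), row))  # True
```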
References