PyTorch torch.matmul (Matrix Multiplication) Function

Jul. 17, 2024

PyTorch torch.matmul function [1] computes the “matrix product of two tensors.” Given two tensors a and b, we can use a.matmul(b) or torch.matmul(a,b) to calculate their matrix product, provided their shapes are compatible. For example:

import torch

a = torch.arange(1,13).reshape(4,3) # 4-by-3
b = torch.arange(1,4) # 1-D, shape (3,)
c = a.matmul(b)
d = torch.matmul(a,b)

print(a)
print(b)
print(c)
print(d)
print(a.shape)
print(b.shape)
print(c.shape)
print(d.shape)
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
tensor([1, 2, 3])
tensor([14, 32, 50, 68])
tensor([14, 32, 50, 68])
torch.Size([4, 3])
torch.Size([3])
torch.Size([4])
torch.Size([4])
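Why is the result 1-D with four elements? When the second argument is 1-D, torch.matmul treats it as a column vector (a 1 is appended to its shape), performs an ordinary matrix product, and then removes that extra dimension. A minimal sketch checking this equivalence explicitly, rebuilding the same a and b as above:

```python
import torch

a = torch.arange(1, 13).reshape(4, 3)  # shape (4, 3)
b = torch.arange(1, 4)                 # 1-D, shape (3,)

# Treat b explicitly as a 3-by-1 column vector, multiply, then drop the extra dimension
col = b.unsqueeze(-1)                        # shape (3, 1)
explicit = torch.matmul(a, col).squeeze(-1)  # (4, 1) -> (4,)

# torch.matmul performs the same 1-D handling internally
implicit = torch.matmul(a, b)

print(explicit)                         # tensor([14, 32, 50, 68])
print(torch.equal(explicit, implicit))  # True
print((a @ b).shape)                    # the @ operator also calls matmul: torch.Size([4])
```

Each entry is the dot product of one row of a with b, e.g. 4*1 + 5*2 + 6*3 = 32 for the second row.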

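However, if we swap the order of a and b, an error occurs. This time the 1-D tensor b comes first, so it is treated as a 1-by-3 matrix, and its 3 columns do not match the 4 rows of a: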

b.matmul(a)
----> 1 b.matmul(a)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x3 and 4x3)
torch.matmul(b,a)
----> 1 torch.matmul(b,a)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x3 and 4x3)
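If the intent is to multiply by b from the left, one option (our own choice for illustration, not part of the original example) is to transpose a so that its shape becomes 3-by-4. A short sketch, assuming the same a and b as above:

```python
import torch

a = torch.arange(1, 13).reshape(4, 3)  # shape (4, 3)
b = torch.arange(1, 4)                 # 1-D, shape (3,)

# b is treated as 1x3, a.T is 3x4 -> result 1x4, then the prepended dim is removed -> (4,)
e = torch.matmul(b, a.T)
print(e)        # tensor([14, 32, 50, 68])
print(e.shape)  # torch.Size([4])

# The original order still fails the shape check; catch it instead of crashing
try:
    torch.matmul(b, a)
except RuntimeError as err:
    print("RuntimeError:", err)
```

Because b @ a.T sums b[k] * a[j, k] over k, it yields the same values as a @ b; the exact wording of the caught error message may differ between PyTorch versions.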


References