# Demonstrate Tensor.expand_as: broadcast a smaller tensor to the shape of
# another tensor without copying any data.
import torch

a = torch.rand(2, 3)     # source tensor, shape (2, 3)
b = torch.rand(2, 2, 3)  # only b's shape is used below; its values are not
print('a:', a)
print('b:', b)

# a.expand_as(b) is equivalent to a.expand(b.size()). a's trailing dims
# (2, 3) already match b's, so a new leading dim of size 2 is prepended.
# The result is a view on a's storage (the new dim has stride 0), so c[0] and
# c[1] are the same memory as a. Do not write into c in place; call .clone()
# first if an independent, writable tensor is needed.
c = a.expand_as(b)
print('c:', c)