# 24 fc timedistributed: one independent nn.Linear(8, 1) per time step, each
# applied only to the matching time slice of a (batch, time, feature) input.
import torch
from torch import nn


def time_distributed(layers, x):
    """Apply ``layers[t]`` to ``x[:, t, :]`` for every time step ``t``.

    Unlike a shared layer wrapped in a time-distributed helper, every time step
    here has its own weights.

    Args:
        layers: Indexable sequence of modules (e.g. ``nn.ModuleList``), one per
            time step. Each is called with a ``(batch, features)`` tensor.
        x: Tensor indexed as ``(batch, time, features)``.

    Returns:
        Tensor of shape ``(batch, time, out_features)``: the per-step outputs
        stacked along dim 1.

    Raises:
        ValueError: if ``len(layers)`` differs from ``x.shape[1]``. Without
            this check, surplus layers would be silently ignored.
    """
    if len(layers) != x.shape[1]:
        raise ValueError(
            f"expected {len(layers)} time steps to match the layers, "
            f"got x.shape[1] == {x.shape[1]}"
        )
    # Each layer maps (batch, features) -> (batch, out_features). Stacking on
    # dim 1 yields (batch, time, out_features), the same result as
    # unsqueeze(1) on every input followed by cat along dim 1.
    return torch.stack([layer(x[:, t, :]) for t, layer in enumerate(layers)], dim=1)


num = 24
fc = nn.ModuleList([nn.Linear(8, 1) for _ in range(num)])

# forward pass
# The original np.zeros(64, 24, 8) is not a valid NumPy call (shape must be a
# tuple). A NumPy array also has no .unsqueeze and cannot be fed to
# nn.Linear, so build a float32 torch tensor instead.
x = torch.zeros(64, 24, 8)
outs = time_distributed(fc, x)