AttributeError: 'tuple' object has no attribute 'dim' when feeding input to a PyTorch LSTM network

2024/9/21 16:30:51

I am trying to run the following code:

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        out = self.lstm(x)
        out = self.hidden2tag(out)
        return out


state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]

device = torch.device("cuda")
net = LSTM(5, 3).to(device)
state_v = torch.FloatTensor(state).to(device)
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())

And that returns this error:

Traceback (most recent call last):
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/", line 26, in <module>
    q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/", line 15, in forward
    out = self.hidden2tag(out)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/", line 1022, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'

Does anyone know how to fix this? (That is, how to stop the tensor from being a tuple so that it can be fed through the LSTM network.)


The PyTorch LSTM (nn.LSTM) returns a tuple, not a single tensor.
You get this error because your linear layer self.hidden2tag cannot handle that tuple.

So change:

out = self.lstm(x)

to:

out, states = self.lstm(x)

This fixes the error by unpacking the tuple, so that out is just the output tensor.
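For reference, here is a minimal sketch of the model from the question with that one change applied. It runs on the CPU rather than "cuda" (my assumption, so it works without a GPU), and the name _states is just a placeholder for the part of the tuple that is not used.

```python
import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); keep only the output tensor.
        out, _states = self.lstm(x)
        return self.hidden2tag(out)


# Assumption: CPU only, so the sketch also runs on machines without CUDA.
net = LSTM(5, 3)
state = [(1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7),
         (4, 5, 6, 7, 8), (5, 6, 7, 8, 9), (6, 7, 8, 9, 0)]
state_v = torch.FloatTensor(state)

# Same reshaping as in the question: (1, 6, 5).
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
print(q_vals_v.shape)  # torch.Size([1, 6, 3])
```

Note that with the default batch_first=False, nn.LSTM reads the (1, 6, 5) input as sequence length 1 and batch size 6. Whether that is what you want depends on your data.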

out then holds the hidden state of the last LSTM layer for every time step, while states is another tuple that contains the final hidden state and the final cell state.
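If you want to check this yourself, a small sketch like the following prints what nn.LSTM hands back. The sizes (sequence length 4, batch 6) and the random input are arbitrary values picked for illustration, not taken from the question.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=5, hidden_size=12)

# With the default batch_first=False the layout is (seq_len, batch, input_size).
x = torch.randn(4, 6, 5)

result = lstm(x)
print(type(result))  # a plain tuple, which nn.Linear cannot consume

out, (h_n, c_n) = result
print(out.shape)  # (seq_len, batch, hidden_size)
print(h_n.shape)  # (num_layers, batch, hidden_size)
print(c_n.shape)  # same shape as h_n

# For a single-layer, unidirectional LSTM the last time step of out is h_n.
print(torch.allclose(out[-1], h_n[0]))
```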


You will get another error on the last line, because max() returns a tuple as well. That should be easy to fix, though, and it is a different error :)
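As a rough sketch of that second fix: torch.max with a dim argument returns a (values, indices) pair, so unpack it first and call .item() on the indices. The tensor q_vals below is a made-up example of shape (1, 3). With the (1, 6, 3) output from the question you would first have to decide which row you want an action for, and that depends on your setup.

```python
import torch

# Hypothetical Q-values for one state and three actions, shape (1, 3).
q_vals = torch.tensor([[0.1, 0.7, 0.2]])

# torch.max(..., dim=1) returns (values, indices); unpack before using them.
max_vals, max_idx = torch.max(q_vals, dim=1)

# .item() only works on a tensor with exactly one element.
action = int(max_idx.item())
print(action)  # 1
```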

