AttributeError: 'tuple' object has no attribute 'dim' when feeding input to a PyTorch LSTM network

2024/11/13 16:01:54

I am trying to run the following code:

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        out = self.lstm(x)
        out = self.hidden2tag(out)
        return out

state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]
device = torch.device("cuda")
net = LSTM(5, 3).to(device)
state_v = torch.FloatTensor(state).to(device)
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())

And that returns this error:

Traceback (most recent call last):
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 26, in <module>
    q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 15, in forward
    out = self.hidden2tag(out)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/functional.py", line 1022, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'

Does anyone know how to fix this? (That is, how to get rid of the output being a tuple so that it can be passed on through the LSTM network.)

Answer

PyTorch's nn.LSTM does not return a single tensor; it returns a tuple of the form (output, (h_n, c_n)).
You get this error because your linear layer self.hidden2tag cannot handle that tuple.
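To see the tuple directly, here is a minimal sketch that calls a bare nn.LSTM with the same sizes as in the question (input_size=5, hidden_size=12). The random input tensor is only an assumed stand-in for the real state data, shaped like the question's view(1, 6, 5) call.

```python
import torch
import torch.nn as nn

# Same sizes as in the question: input_size=5, hidden_size=12
lstm = nn.LSTM(5, 12)

# Default layout (batch_first=False): (seq_len, batch, input_size)
x = torch.randn(1, 6, 5)

result = lstm(x)
print(type(result))  # <class 'tuple'>

# The tuple holds the output tensor plus another tuple of final states
output, (h_n, c_n) = result
print(output.shape)  # torch.Size([1, 6, 12])
print(h_n.shape)     # torch.Size([1, 6, 12])
print(c_n.shape)     # torch.Size([1, 6, 12])
```

Passing result straight into nn.Linear is exactly what triggers the 'dim' error, since only output is a tensor.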

So change:

out = self.lstm(x)

to

out, states = self.lstm(x)

This fixes the error: unpacking the tuple leaves out holding just the output tensor.
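Putting the fix into the question's model gives roughly the following sketch. Two small changes are my own assumptions, not part of the original answer: the device falls back to the CPU when no GPU is available, and torch.tensor(..., dtype=torch.float32) replaces the legacy torch.FloatTensor constructor.

```python
import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        # Unpack the tuple: out is the per-step output tensor,
        # states is (h_n, c_n) and is not needed here
        out, states = self.lstm(x)
        out = self.hidden2tag(out)
        return out


state = [(1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7),
         (4, 5, 6, 7, 8), (5, 6, 7, 8, 9), (6, 7, 8, 9, 0)]

# Assumption: fall back to the CPU when CUDA is not available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = LSTM(5, 3).to(device)

state_v = torch.tensor(state, dtype=torch.float32, device=device)
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
print(q_vals_v.shape)  # torch.Size([1, 6, 3])
```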

out then holds the hidden state of the last LSTM layer for every time step, while states is another tuple, (h_n, c_n), containing the final hidden and cell states.
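The difference between out and states shows up in the shapes. This is a small sketch on random data, assuming a single-layer, unidirectional LSTM (the default) with a sequence length of 4 chosen only for illustration; in that setting the last time step of out is the same tensor as h_n.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=12)  # one layer, unidirectional

x = torch.randn(4, 2, 5)  # (seq_len=4, batch=2, input_size=5)
out, (h_n, c_n) = lstm(x)

# out has one hidden state per time step: (seq_len, batch, hidden_size)
print(out.shape)  # torch.Size([4, 2, 12])
# h_n and c_n only hold the final step: (num_layers, batch, hidden_size)
print(h_n.shape)  # torch.Size([1, 2, 12])

# For a single-layer, unidirectional LSTM the last step of out equals h_n
print(torch.allclose(out[-1], h_n[-1]))  # True
```

With more layers or bidirectional=True, h_n gains extra leading entries and this simple equality no longer holds as written.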

You can also take a look here:
https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM

You will get another error on the last line, because torch.max() called with a dim argument also returns a tuple, (values, indices), and a tuple has no .item() method. That should be easy to fix, but it is a different error :)
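One possible fix for that last line, sketched on a random stand-in for q_vals_v with the question's output shape (1, 6, 3). Note that with this shape, dim=1 in the original line reduces over the 6 rows rather than over the 3 actions; the sketch assumes you want the greedy action along the last (action) dimension, and which row you pick is up to your use case.

```python
import torch

# Assumed stand-in for the network output: (seq_len=1, batch=6, n_actions=3)
torch.manual_seed(0)
q_vals_v = torch.randn(1, 6, 3)

# With a dim argument, torch.max returns a (values, indices) named tuple
values, indices = torch.max(q_vals_v, dim=-1)
print(values.shape, indices.shape)  # torch.Size([1, 6]) torch.Size([1, 6])

# One greedy action per row, as plain Python ints
actions = indices.squeeze(0).tolist()

# A single action (here: for the last row) via .item() on a one-element tensor
action = int(indices[0, -1].item())
print(actions, action)
```

Calling .item() on an element of indices works because it is a one-element tensor; calling it on the whole (values, indices) tuple is what fails.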

https://en.xdnf.cn/q/72041.html
