Keras: Understanding the number of trainable LSTM parameters

2024/10/14 7:15:48

I have run a Keras LSTM demo containing the following code (after line 166):

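from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed

m = 1
model = Sequential()
dim_in = m        # features per time step
dim_out = m       # outputs per time step
nb_units = 10     # size of the LSTM hidden/cell state
model.add(LSTM(input_shape=(None, dim_in), return_sequences=True, units=nb_units))
model.add(TimeDistributed(Dense(activation='linear', units=dim_out)))
model.compile(loss='mse', optimizer='rmsprop')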

When I add a call to model.summary() after this code, I see the following output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_4 (LSTM)                (None, None, 10)          480       
_________________________________________________________________
time_distributed_4 (TimeDist (None, None, 1)           11        
=================================================================
Total params: 491
Trainable params: 491
Non-trainable params: 0

I understand that the 11 params of the time distributed layer simply consist of nb_units weights plus one bias value.
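That arithmetic can be checked with a minimal sketch. The helper name `dense_params` is hypothetical, and it assumes a Dense layer with a bias (the Keras default, `use_bias=True`). `TimeDistributed` applies the same wrapped Dense layer at every time step, so it adds no parameters of its own.

```python
def dense_params(in_features: int, out_features: int, use_bias: bool = True) -> int:
    """Count the weights of a fully connected layer: one weight per
    input/output pair, plus one bias per output if use_bias is set."""
    weights = in_features * out_features
    biases = out_features if use_bias else 0
    return weights + biases


# TimeDistributed(Dense(units=dim_out)) on top of nb_units = 10 LSTM outputs;
# the wrapper shares one Dense layer across all time steps.
nb_units, dim_out = 10, 1
print(dense_params(nb_units, dim_out))  # 11
```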

Now for the LSTM layer: These answers say:

params = 4 * ((input_size + 1) * output_size + output_size^2)

In my case, with input_size = 1 and output_size = 1, this yields only 12 parameters for each of the 10 units, for a total of 120 parameters. Compared to the reported 480, this is off by a factor of 4. Where is my error?
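The question's arithmetic can be reproduced with a small helper that evaluates the quoted formula literally. The name `lstm_params` is hypothetical; the sketch only restates the formula and applies the per-unit reading described in the question.

```python
def lstm_params(input_size: int, output_size: int) -> int:
    """The quoted formula: 4 * ((input_size + 1) * output_size + output_size**2)."""
    return 4 * ((input_size + 1) * output_size + output_size ** 2)


# The question's reading: output_size = 1, evaluated once per Keras unit.
nb_units = 10
per_unit = lstm_params(input_size=1, output_size=1)
print(per_unit)             # 12
print(per_unit * nb_units)  # 120, but model.summary() reports 480
```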

Answer

The params formula holds for the whole layer, not per Keras unit.

Quoting this answer:

[In Keras], the unit means the dimension of the inner cells in LSTM.

LSTM in Keras only defines exactly one LSTM block, whose cells are of unit-length.
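One way to see that `units` is a dimension of the whole layer, rather than a count of independent cells, is to multiply out the shapes of the weight tensors a standard Keras LSTM (with `use_bias=True`) creates: `kernel` of shape `(input_dim, 4 * units)`, `recurrent_kernel` of shape `(units, 4 * units)` and `bias` of shape `(4 * units,)`. This sketch does not import Keras; it only restates those shapes, so treat it as an illustration rather than a query of a real model.

```python
from math import prod

input_dim = 1   # dim_in in the question
units = 10      # nb_units in the question

# The factor 4 stacks the input, forget, cell-candidate and output blocks side by side.
weight_shapes = {
    "kernel": (input_dim, 4 * units),        # input -> gates
    "recurrent_kernel": (units, 4 * units),  # previous h -> gates, every unit to every unit
    "bias": (4 * units,),
}

sizes = {name: prod(shape) for name, shape in weight_shapes.items()}
print(sizes)                # {'kernel': 40, 'recurrent_kernel': 400, 'bias': 40}
print(sum(sizes.values()))  # 480
```

The 10 x 40 recurrent kernel is the part that a per-unit reading cannot produce: each unit's previous output feeds the gates of all 10 units.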

Directly setting output_size = 10 (like in this comment) correctly yields the 480 parameters: 4 * ((1 + 1) * 10 + 10^2) = 4 * 120 = 480.
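The two readings can also be compared for several layer sizes. In this sketch (helper names hypothetical) the per-unit reading drops the recurrent weights that connect each unit to every other unit, so it undercounts by exactly 4 * (units**2 - units); the factor of 4 in the question is specific to input_size = 1 with 10 units. Adding the 11 Dense parameters reproduces the 491 total from model.summary().

```python
def lstm_params(input_size: int, output_size: int) -> int:
    # Whole-layer formula: output_size is the number of Keras units.
    return 4 * ((input_size + 1) * output_size + output_size ** 2)


def per_unit_reading(input_size: int, units: int) -> int:
    # The question's reading: treat every unit as a separate 1-dimensional LSTM.
    return units * lstm_params(input_size, 1)


input_size = 1
for units in (1, 5, 10, 20):
    correct = lstm_params(input_size, units)
    naive = per_unit_reading(input_size, units)
    # The gap is exactly the cross-unit recurrent weights: 4 * (units**2 - units).
    assert correct - naive == 4 * (units ** 2 - units)
    print(units, naive, correct)

# Full model from the question: LSTM(10 units) + TimeDistributed(Dense(1)).
total = lstm_params(1, 10) + (10 * 1 + 1)
print(total)  # 491
```

The loop prints `1 12 12`, `5 60 140`, `10 120 480` and `20 240 1760`, so the ratio between the two readings changes with the layer size.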

https://en.xdnf.cn/q/117980.html
