keras/tensorflow model: gradient w.r.t. input returns the same (wrong?) value for all input data

2024/7/5 11:53:19

Given a trained Keras model, I am trying to compute the gradient of the output with respect to the input.

This example tries to fit the function y=x^2 with a Keras model of four ReLU-activated hidden layers (plus a linear output layer), and then computes the gradient of the model output with respect to the input.

from keras.models import Sequential
from keras.layers import Dense
from keras import backend as k
from sklearn.model_selection import train_test_split
import numpy as np
import tensorflow as tf

# random data
x = np.random.random((1000, 1))
y = x**2

# split train/val
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.15)

# model
model = Sequential()
# 1d input
model.add(Dense(10, input_shape=(1, ), activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(10, activation='relu'))
# 1d output
model.add(Dense(1))

## compile and fit
model.compile(loss='mse', optimizer='rmsprop', metrics=['mae'])
model.fit(x_train, y_train, batch_size=256, epochs=100, validation_data=(x_val, y_val), shuffle=True)

## compute derivative (gradient)
session = tf.Session()
session.run(tf.global_variables_initializer())
y_val_d_evaluated = session.run(tf.gradients(model.output, model.input), feed_dict={model.input: x_val})
print(y_val_d_evaluated)

x_val is a vector of 150 random numbers between 0 and 1 (15% of the 1000 samples).

My expectation is that y_val_d_evaluated (the gradient) should be:

A. an array of 150 different numbers (because x_val contains 150 different numbers);

B. values close to 2*x_val (the derivative of x^2).
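Expectation B can be checked independently of TensorFlow with a finite difference. The following is a minimal NumPy sketch; the helper names `central_difference` and `square` are hypothetical, and it only checks the target function x^2 itself.

```python
import numpy as np


def central_difference(f, x, h=1e-5):
    """Approximate df/dx elementwise with a central difference of step h."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


def square(x):
    return x ** 2


rng = np.random.default_rng(0)
x_val = rng.random((150, 1))  # 150 points in [0, 1), like x_val in the question
approx = central_difference(square, x_val)

# For x**2 the central difference equals 2*x exactly in real arithmetic,
# so the only error left is floating-point rounding.
print(np.max(np.abs(approx - 2 * x_val)))
```

The same helper could be pointed at a trained model's prediction function (for example `model.predict`) as a sanity check on its gradients; since Keras computes in float32 by default, a larger step such as `h=1e-3` is advisable there.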

Instead, every time I run this example, y_val_d_evaluated contains 150 equal values (e.g. [0.0150494], [-0.0150494], [0.0150494], [-0.0150494], ...). Moreover, the value is very different from 2x, and it changes every time I run the example.

Does anyone have suggestions to help me understand why this code does not give the expected gradient?

Answer

OK, I found the problem. The following lines:

session = tf.Session()
session.run(tf.global_variables_initializer())

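create a new TensorFlow session and run the variable initializers in it. Variable values are stored per session, so in this new session the model has its random initial parameters rather than the trained ones (those live in the session Keras used during fit). This explains why the value was different on every run.

It also explains why all 150 values were equal: Keras initializes Dense biases to zero, and a ReLU network with zero biases is positively homogeneous, f(c*x) = c*f(x) for c > 0. For inputs in (0, 1) the untrained model is therefore just f(x) = f(1)*x, whose derivative is the same constant for every x.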
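The identical-gradient symptom can be reproduced without TensorFlow. This NumPy-only sketch builds a hypothetical freshly initialized network with the question's layer sizes (1-10-10-10-10-1), using Glorot-uniform weights and zero biases, which are the Keras defaults for Dense layers. It then computes the exact input gradient by hand-written backpropagation. `forward` and `input_gradient` are illustrative names, not Keras functions.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical stand-in for an untrained copy of the question's model:
# Glorot-uniform kernels and zero biases (Keras's Dense defaults).
sizes = [1, 10, 10, 10, 10, 1]
weights = []
for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    weights.append(rng.uniform(-limit, limit, size=(fan_in, fan_out)))
biases = [np.zeros(fan_out) for fan_out in sizes[1:]]


def forward(x):
    """ReLU on the four hidden layers, linear output layer."""
    h = x
    for w, b in zip(weights[:-1], biases[:-1]):
        h = np.maximum(h @ w + b, 0.0)
    return h @ weights[-1] + biases[-1]


def input_gradient(x):
    """Exact d(output)/d(input) for every row of x, via manual backprop."""
    h = x
    masks = []
    for w, b in zip(weights[:-1], biases[:-1]):
        z = h @ w + b
        masks.append(z > 0)  # which ReLUs are active for each sample
        h = np.maximum(z, 0.0)
    # d(output)/d(last hidden activation), repeated for every sample: (batch, 10)
    g = np.tile(weights[-1].T, (x.shape[0], 1))
    for w, mask in zip(reversed(weights[:-1]), reversed(masks)):
        g = (g * mask) @ w.T  # back through the ReLU, then through the layer
    return g  # shape (batch, 1)


x_val = rng.uniform(0.0, 1.0, size=(150, 1))
grads = input_gradient(x_val)

print(np.ptp(grads))  # spread of the 150 gradients
print(np.allclose(forward(x_val), grads * x_val))  # f(x) == f'(x) * x
```

With zero biases the sign of every pre-activation is fixed for all positive inputs, so every sample sees the same ReLU on/off pattern and the spread is zero up to rounding. Replacing the zero biases with nonzero values, as training does, typically makes the gradients vary with x.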

The solution is to use the TensorFlow session that Keras is already using, instead of creating a new one:

session = k.get_session()

With this simple change, the results are as expected.

https://en.xdnf.cn/q/120334.html
