Renormalize weight matrix using TensorFlow

2024/9/8 10:47:23

I'd like to add a max norm constraint to several of the weight matrices in my TensorFlow graph, à la Torch's renorm method.

If the L2 norm of any neuron's weight vector exceeds max_norm, I'd like to scale its weights down so that their L2 norm is exactly max_norm.

What's the best way to express this using TensorFlow?

Answer

Here is a possible implementation:

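import tensorflow as tf

def max_norm_regularizer(threshold, axes=1, name="max_norm",
                         collection="max_norm"):
    def max_norm(weights):
        # Rescale every slice along `axes` whose L2 norm exceeds `threshold`.
        clipped = tf.clip_by_norm(weights, clip_norm=threshold, axes=axes)
        # Op that writes the clipped values back into the variable.
        clip_weights = tf.assign(weights, clipped, name=name)
        # Store the op so it can be fetched and run after each training step.
        tf.add_to_collection(collection, clip_weights)
        return None  # there is no regularization loss term
    return max_norm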

Here's how you would use it:

from tensorflow.contrib.layers import fully_connected
from tensorflow.contrib.framework import arg_scope

with arg_scope([fully_connected],
               weights_regularizer=max_norm_regularizer(1.5)):
    hidden1 = fully_connected(X, 200, scope="hidden1")
    hidden2 = fully_connected(hidden1, 100, scope="hidden2")
    outputs = fully_connected(hidden2, 5, activation_fn=None, scope="outs")

max_norm_ops = tf.get_collection("max_norm")

[...]

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for X_batch, y_batch in load_next_batch():
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
            sess.run(max_norm_ops)

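This creates a 3-layer neural network and trains it with max norm regularization at every layer (with a threshold of 1.5). The clipping ops collected under "max_norm" are run right after each training step, so the weights are pulled back under the threshold before the next batch. I just tried it, and it seems to work. Hope this helps! Suggestions for improvements are welcome. :)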
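A side note for anyone on TensorFlow 2.x, where tf.contrib no longer exists: Keras ships a built-in max norm constraint, so something along the lines of the sketch below may be all you need. This is not the code from the answer above, just a rough equivalent under the assumption that you are using tf.keras. The layer sizes mirror the example, the "relu" activations are my assumption, and the small matrix w is made up purely to show the effect. With axis=0, the constraint is applied to each column of a Dense kernel of shape (n_inputs, n_units), i.e. to each unit's incoming weights.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x with the bundled Keras API.
# axis=0 constrains each column of a (n_inputs, n_units) kernel.
max_norm = tf.keras.constraints.MaxNorm(max_value=1.5, axis=0)

# Keras applies kernel_constraint after every optimizer update,
# so no separate "run the clipping ops" step is needed.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(200, activation="relu", kernel_constraint=max_norm),
    tf.keras.layers.Dense(100, activation="relu", kernel_constraint=max_norm),
    tf.keras.layers.Dense(5, kernel_constraint=max_norm),
])

# Illustrative check on a made-up matrix: the first column has norm 5.0
# (above the threshold), the second has norm 0.5 (below it).
w = tf.constant([[3.0, 0.3],
                 [4.0, 0.4]])
clipped = max_norm(w)
col_norms = tf.norm(clipped, axis=0).numpy()
print(col_norms)
```

The first column norm should come out at about 1.5 and the second should stay at about 0.5 (Keras adds a tiny epsilon in the division, so the values are not bit-exact).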

Notes

This code is based on tf.clip_by_norm():

>>> x = tf.constant([0., 0., 3., 4., 30., 40., 300., 400.], shape=(4, 2))
>>> print(x.eval())
[[   0.    0.]
 [   3.    4.]
 [  30.   40.]
 [ 300.  400.]]
>>> clip_rows = tf.clip_by_norm(x, clip_norm=10, axes=1)
>>> print(clip_rows.eval())
[[ 0.          0.        ]
 [ 3.          4.        ]
 [ 6.          8.        ]  # clipped!
 [ 6.00000048  8.        ]] # clipped!

You can also clip columns if you need to:

>>> clip_cols = tf.clip_by_norm(x, clip_norm=350, axes=0)
>>> print(clip_cols.eval())
[[   0.            0.        ]
 [   3.            3.48245788]
 [  30.           34.82457733]
 [ 300.          348.24578857]]
                   # clipped!
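
If it helps to see the arithmetic behind those outputs, here is a small pure-Python sketch of the scaling rule: any row (or column) whose L2 norm exceeds clip_norm is multiplied by clip_norm / norm, and everything else is left alone. The function below is a hypothetical stand-in written for illustration, not TensorFlow's implementation, and it only handles 2-D lists.

```python
import math

def clip_by_norm(matrix, clip_norm, axis):
    """Hypothetical pure-Python stand-in for tf.clip_by_norm on a 2-D list.

    axis=1 clips each row, axis=0 clips each column.
    """
    if axis == 0:
        # Clip columns by transposing, clipping rows, and transposing back.
        transposed = [list(col) for col in zip(*matrix)]
        clipped = clip_by_norm(transposed, clip_norm, axis=1)
        return [list(row) for row in zip(*clipped)]
    result = []
    for row in matrix:
        norm = math.sqrt(sum(v * v for v in row))
        # Only vectors whose L2 norm exceeds clip_norm are rescaled.
        scale = clip_norm / norm if norm > clip_norm else 1.0
        result.append([v * scale for v in row])
    return result

x = [[0.0, 0.0], [3.0, 4.0], [30.0, 40.0], [300.0, 400.0]]
clip_rows = clip_by_norm(x, clip_norm=10, axis=1)
clip_cols = clip_by_norm(x, clip_norm=350, axis=0)
print(clip_rows)
print(clip_cols)
```

In the column case, the first column has norm of roughly 301.5, which is under 350, so it is untouched. The second has norm of roughly 402.0, so it is scaled by about 0.8706, which is where the 3.48 / 34.82 / 348.25 values come from.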