Renormalize weight matrix using TensorFlow

2024/9/8 10:47:23

I'd like to add a max-norm constraint to several of the weight matrices in my TensorFlow graph, à la Torch's renorm method.

If the L2 norm of any neuron's weight vector exceeds max_norm, I'd like to scale its weights down so that their L2 norm is exactly max_norm.

What's the best way to express this using TensorFlow?


Here is a possible implementation:

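import tensorflow as tf

def max_norm_regularizer(threshold, axes=1, name="max_norm", collection="max_norm"):
    def max_norm(weights):
        # Rescale any slice along `axes` whose L2 norm exceeds the threshold
        clipped = tf.clip_by_norm(weights, clip_norm=threshold, axes=axes)
        # Op that writes the clipped values back into the variable
        clip_weights = tf.assign(weights, clipped, name=name)
        # Store the op so it can be fetched and run after each training step
        tf.add_to_collection(collection, clip_weights)
        return None  # there is no regularization loss term
    return max_norm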

Here's how you would use it:

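# TensorFlow 1.x only: tf.contrib was removed in TensorFlow 2.0
from tensorflow.contrib.layers import fully_connected
from tensorflow.contrib.framework import arg_scope

with arg_scope(
        [fully_connected],
        weights_regularizer=max_norm_regularizer(1.5)):
    hidden1 = fully_connected(X, 200, scope="hidden1")
    hidden2 = fully_connected(hidden1, 100, scope="hidden2")
    outputs = fully_connected(hidden2, 5, activation_fn=None, scope="outs")

# One clipping op per layer, registered by max_norm_regularizer
max_norm_ops = tf.get_collection("max_norm")

[...]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        for X_batch, y_batch in load_next_batch():
            # training_op is assumed to be defined in the elided part above
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
            # Clip the weights right after each training step
            sess.run(max_norm_ops)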

This creates a three-layer neural network and trains it with max-norm regularization at every layer (with a threshold of 1.5). I just tried it and it seems to work. Hope this helps! Suggestions for improvements are welcome. :)
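Note that max norm is not a loss term here (the regularizer returns None); the weights are simply clipped after each update. As a rough illustration of that update-then-clip pattern, here is a toy sketch in plain Python with no TensorFlow. The names `clip_to_max_norm` and `train`, the quadratic toy loss, and the learning rate are all made up for the example; they are not part of the code above.

```python
import math

def clip_to_max_norm(vec, max_norm):
    """Scale vec down so its L2 norm is at most max_norm (no-op if already within)."""
    norm = math.sqrt(sum(v * v for v in vec))
    if norm <= max_norm:
        return list(vec)
    return [v * max_norm / norm for v in vec]

def train(weights, target, max_norm, learning_rate=0.1, n_steps=200):
    """Gradient descent on the toy loss 0.5 * ||w - target||^2, clipping after every step."""
    for _ in range(n_steps):
        # "Training step": the gradient of the toy loss is (w - target)
        weights = [w - learning_rate * (w - t) for w, t in zip(weights, target)]
        # "Max-norm step": plays the role of running the clipping ops
        weights = clip_to_max_norm(weights, max_norm)
    return weights

# The unconstrained optimum [3, 4] has norm 5, well above the threshold of 1.5
final = train(weights=[0.0, 0.0], target=[3.0, 4.0], max_norm=1.5)
final_norm = math.sqrt(sum(w * w for w in final))
print(final, final_norm)
```

The weights keep the direction the loss pulls them in, but their norm never ends a step above the threshold, which is the behaviour the clipping ops are meant to enforce on each layer.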


This code is based on tf.clip_by_norm():

>>> x = tf.constant([0., 0., 3., 4., 30., 40., 300., 400.], shape=(4, 2))
>>> print(x.eval())
[[   0.    0.]
 [   3.    4.]
 [  30.   40.]
 [ 300.  400.]]
>>> clip_rows = tf.clip_by_norm(x, clip_norm=10, axes=1)
>>> print(clip_rows.eval())
[[ 0.          0.        ]
 [ 3.          4.        ]
 [ 6.          8.        ]  # clipped!
 [ 6.00000048  8.        ]] # clipped!

You can also clip columns if you need to:

>>> clip_cols = tf.clip_by_norm(x, clip_norm=350, axes=0)
>>> print(clip_cols.eval())
[[   0.            0.        ]
 [   3.            3.48245788]
 [  30.           34.82457733]
 [ 300.          348.24578857]]
                   # clipped!
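If you want to check the arithmetic behind these two examples without TensorFlow, something like the following plain-Python sketch should reproduce it. It only approximates what tf.clip_by_norm does for a 2-D input along a single axis; `clip_rows_by_norm` and `clip_cols_by_norm` are hypothetical helper names, not TensorFlow functions.

```python
import math

def clip_rows_by_norm(matrix, clip_norm):
    """Rescale each row whose L2 norm exceeds clip_norm so that its norm equals clip_norm."""
    clipped = []
    for row in matrix:
        norm = math.sqrt(sum(v * v for v in row))
        # Rows already within the limit (including all-zero rows) are left untouched
        scale = clip_norm / norm if norm > clip_norm else 1.0
        clipped.append([v * scale for v in row])
    return clipped

def clip_cols_by_norm(matrix, clip_norm):
    """Same rule applied per column: transpose, clip the rows, transpose back."""
    transposed = [list(col) for col in zip(*matrix)]
    return [list(row) for row in zip(*clip_rows_by_norm(transposed, clip_norm))]

x = [[0.0, 0.0], [3.0, 4.0], [30.0, 40.0], [300.0, 400.0]]
rows = clip_rows_by_norm(x, 10)   # counterpart of axes=1 above
cols = clip_cols_by_norm(x, 350)  # counterpart of axes=0 above

for row in rows:
    print(row)
for row in cols:
    print(row)
```

In the column case only the second column is rescaled: the first column's norm is about 301.5, below 350, while the second column's is about 402.0.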

