I'd like to add a max norm constraint to several of the weight matrices in my TensorFlow graph, à la Torch's renorm method.
If the L2 norm of any neuron's weight matrix exceeds max_norm, I'd like to scale its weights down so that their L2 norm is exactly max_norm.
What's the best way to express this using TensorFlow?
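For reference, the renorm operation itself is simple to state: if a weight vector's L2 norm exceeds the threshold, rescale it by threshold / norm; otherwise leave it alone. A minimal NumPy sketch of that rule (the function name `max_norm_renorm` and the axis convention are my own, not a TensorFlow API):

```python
import numpy as np

def max_norm_renorm(w, max_norm, axis=1):
    # Per-row (or per-column) L2 norms, keeping dims so broadcasting works.
    norms = np.linalg.norm(w, axis=axis, keepdims=True)
    # Rescale only where the norm exceeds the threshold; guard against /0.
    scale = np.where(norms > max_norm, max_norm / np.maximum(norms, 1e-12), 1.0)
    return w * scale

w = np.array([[3., 4.], [30., 40.]])
print(max_norm_renorm(w, 10.0))  # second row rescaled to norm 10: [6., 8.]
```

This is exactly what tf.clip_by_norm computes, except that in TensorFlow you also have to assign the result back to the variable.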
Here is a possible implementation:
import tensorflow as tf

def max_norm_regularizer(threshold, axes=1, name="max_norm", collection="max_norm"):
    def max_norm(weights):
        # Clip the weights so that each row's L2 norm is at most `threshold`.
        clipped = tf.clip_by_norm(weights, clip_norm=threshold, axes=axes)
        # Op that assigns the clipped values back to the variable; run it
        # after each training step.
        clip_weights = tf.assign(weights, clipped, name=name)
        tf.add_to_collection(collection, clip_weights)
        return None  # there is no regularization loss term
    return max_norm
Here's how you would use it:
from tensorflow.contrib.layers import fully_connected
from tensorflow.contrib.framework import arg_scope

with arg_scope([fully_connected],
               weights_regularizer=max_norm_regularizer(1.5)):
    hidden1 = fully_connected(X, 200, scope="hidden1")
    hidden2 = fully_connected(hidden1, 100, scope="hidden2")
    outputs = fully_connected(hidden2, 5, activation_fn=None, scope="outs")

max_norm_ops = tf.get_collection("max_norm")

[...]

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for X_batch, y_batch in load_next_batch():
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
            sess.run(max_norm_ops)
This creates a 3-layer neural network and trains it with max-norm regularization at every layer (with a threshold of 1.5). I just tried it and it seems to work. Hope this helps! Suggestions for improvements are welcome. :)
Notes
This code is based on tf.clip_by_norm():
>>> x = tf.constant([0., 0., 3., 4., 30., 40., 300., 400.], shape=(4, 2))
>>> print(x.eval())
[[   0.    0.]
 [   3.    4.]
 [  30.   40.]
 [ 300.  400.]]
>>> clip_rows = tf.clip_by_norm(x, clip_norm=10, axes=1)
>>> print(clip_rows.eval())
[[ 0.          0.        ]
 [ 3.          4.        ]
 [ 6.          8.        ]   # clipped!
 [ 6.00000048  8.        ]]  # clipped!
You can also clip columns if you need to:
>>> clip_cols = tf.clip_by_norm(x, clip_norm=350, axes=0)
>>> print(clip_cols.eval())
[[   0.            0.        ]
 [   3.            3.48245788]
 [  30.           34.82457733]
 [ 300.          348.24578857]]  # clipped!