How to conditionally assign values to tensor [masking for loss function]?

2024/10/13 18:27:11

I want to create an L2 loss function that ignores values (i.e. pixels) where the label is 0. The tensor batch[1] contains the labels and output is the network's output tensor; both have shape (None, 300, 300, 1).

labels_mask = tf.identity(batch[1])
labels_mask[labels_mask > 0] = 1
loss = tf.reduce_sum(tf.square((output-batch[1])*labels_mask))/tf.reduce_sum(labels_mask)

My current code raises TypeError: 'Tensor' object does not support item assignment (on the second line). What's the TensorFlow way to do this? I also normalize the loss by tf.reduce_sum(labels_mask), which I hope works as intended.
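A minimal sketch of the masked loss without item assignment, assuming TensorFlow 2.x (eager mode) and float labels and outputs of the same dtype. The function name `masked_l2_loss` is hypothetical, and the `tf.maximum(..., 1.0)` guard against batches with no labelled pixels is an addition that is not in the question. Because the mask is 0/1, dividing by `tf.reduce_sum(labels_mask)` gives the mean squared error over the labelled pixels only, which is what the normalization above is after.

```python
import tensorflow as tf


def masked_l2_loss(labels, output):
    """Hypothetical helper: mean squared error over pixels whose label is non-zero.

    labels, output: float tensors of the same shape, e.g. (batch, 300, 300, 1).
    """
    # Tensors are immutable, so build the mask as a new tensor instead of
    # assigning into one: 1.0 where labels > 0, 0.0 elsewhere.
    labels_mask = tf.cast(labels > 0, output.dtype)
    # Errors at unlabelled pixels are multiplied by 0 and drop out of the sum.
    squared_error = tf.square((output - labels) * labels_mask)
    # Normalize by the number of labelled pixels; the tf.maximum guard (an
    # assumption, not from the question) avoids dividing by zero.
    num_labelled = tf.maximum(tf.reduce_sum(labels_mask), 1.0)
    return tf.reduce_sum(squared_error) / num_labelled
```

For example, with labels [[0, 2], [3, 0]] and output [[5, 1], [3, 7]], only the two non-zero labels count; their squared errors are 1 and 0, so the loss is 0.5.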

Answer

Here is an example of how to apply boolean indexing and conditionally assign values to a Variable:

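import tensorflow as tf

a = tf.Variable(initial_value=[0, 0, 4, 6, 1, 2, 4, 0])
mask = tf.greater_equal(a, 2)  # [False False  True  True False  True  True False]
indexes = tf.where(mask)  # [[2] [3] [5] [6]], shape=(4, 1)
# scatter_update expects integer indices (not the boolean mask);
# it is tf.scatter_update in TF 1.x and tf.compat.v1.scatter_update in TF 2.x
b = tf.compat.v1.scatter_update(a, indexes, tf.constant(1500))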

output:

[   0,    0, 1500, 1500,    1, 1500, 1500,    0]
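The scatter above mutates a tf.Variable. The labels in the question are an ordinary, immutable tf.Tensor, so here is a sketch of the tensor equivalents, assuming TensorFlow 2.x: `tf.where` with three arguments selects elementwise, and `tf.tensor_scatter_nd_update` writes into a copy at the indices returned by `tf.where(mask)`. Both return a new tensor and leave `a` unchanged; on the same data they should reproduce the output above.

```python
import tensorflow as tf

a = tf.constant([0, 0, 4, 6, 1, 2, 4, 0])
mask = tf.greater_equal(a, 2)

# Option 1: elementwise select; 1500 where mask is True, a's value elsewhere.
b_where = tf.where(mask, tf.constant(1500), a)

# Option 2: scatter into a copy of the tensor at the masked positions.
indexes = tf.where(mask)  # int64 indices, shape (4, 1)
updates = tf.fill([tf.shape(indexes)[0]], 1500)  # one value per index
b_scatter = tf.tensor_scatter_nd_update(a, indexes, updates)
```

For building a 0/1 loss mask, Option 1 (or simply casting the boolean mask) is usually the simpler choice, since no indices are needed.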
https://en.xdnf.cn/q/69505.html
