What is colocate_with used for in TensorFlow?

2024/9/19 21:07:32

Here is the link to the official docs: https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/colocate_with

Answer

It's a context manager to make sure that the operation or tensor you're about to create will be placed on the same device the reference operation is on. Consider this piece of code (tested):

    import tensorflow as tf

    with tf.device("/cpu:0"):
        a = tf.constant(0.0, name="a")

    with tf.device("/gpu:0"):
        b = tf.constant(0.0, name="b")
        # c is created inside the GPU context but colocated with a
        with tf.colocate_with(a):
            c = tf.constant(0.0, name="c")
        d = tf.constant(0.0, name="d")

    # List every operation in the graph with its assigned device
    for operation in tf.get_default_graph().get_operations():
        print(operation.name, operation.device)

Outputs:

(u'a', u'/device:CPU:0')
(u'b', u'/device:GPU:0')
(u'c', u'/device:CPU:0')
(u'd', u'/device:GPU:0')

So it places tensor c on the same device as a, regardless of the GPU device context that is active when c is created. This can be very important for multi-GPU training. Imagine you're not careful and end up with a graph whose interdependent tensors are scattered randomly across 8 devices: a complete disaster efficiency-wise. tf.colocate_with() can make sure this doesn't happen.
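To make the precedence rule concrete without needing TensorFlow or a GPU, here is a minimal pure-Python sketch of how nested device and colocation contexts could resolve. ToyGraph and all of its methods are hypothetical names invented for this illustration; this is a toy model of the behaviour shown in the output above, not TensorFlow's actual implementation.

```python
from contextlib import contextmanager


class ToyGraph:
    """Toy stand-in for a graph (hypothetical; NOT TensorFlow's code)."""

    def __init__(self):
        self.ops = {}               # op name -> device string
        self._device_stack = []     # innermost device context is last
        self._colocate_stack = []   # innermost colocation target is last

    @contextmanager
    def device(self, name):
        # Push a device for the duration of the with-block.
        self._device_stack.append(name)
        try:
            yield
        finally:
            self._device_stack.pop()

    @contextmanager
    def colocate_with(self, op_name):
        # Push a reference op whose device new ops should share.
        self._colocate_stack.append(op_name)
        try:
            yield
        finally:
            self._colocate_stack.pop()

    def constant(self, name):
        # Assumption of this sketch: an active colocation target
        # wins over any enclosing device context.
        if self._colocate_stack:
            device = self.ops[self._colocate_stack[-1]]
        elif self._device_stack:
            device = self._device_stack[-1]
        else:
            device = ""
        self.ops[name] = device
        return name


def build_example():
    # Mirrors the structure of the TensorFlow snippet above.
    g = ToyGraph()
    with g.device("/device:CPU:0"):
        a = g.constant("a")
    with g.device("/device:GPU:0"):
        g.constant("b")
        with g.colocate_with(a):
            g.constant("c")
        g.constant("d")
    return g.ops


placements = build_example()
for name, device in placements.items():
    print(name, device)
```

In this sketch, c ends up on the CPU because the colocation target is consulted before the device stack, while d falls back to the enclosing GPU context once the colocate_with block has exited.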

It is not explained in the docs because it's meant to be used by internal libraries only, so there are no guarantees it will stay. (Very likely it will, however. If you want to know more, you can look it up in the source code as of May 2018; its location might move as the code changes.)

You're not likely to need this unless you're working on some low-level stuff. Most people use only one GPU, and even if you use multiple, you're generally building your graph one GPU at a time, that is, within one tf.device() context manager at a time.

One example where it's used is the tf.train.ExponentialMovingAverage class. It clearly looks like a good idea to colocate the decay and moving average variables with the value tensor they are tracking.
