JAX Apply function only on slice of array under jit

2024/10/1 12:27:35

I am using JAX, and I want to perform an operation like

@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x

This cannot be performed under jit. Is there a way to do this with jax.ops or jax.lax? I thought of using jax.ops.index_update(x, idx, y), but I cannot find a way to compute y without running into the same problem again.

Answer

The previous answer by @rvinas using dynamic_slice works well if your index is static, but you can also accomplish this with a dynamic index using jnp.where. For example:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
    mask = jnp.arange(x.shape[0]) < index
    return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
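
For comparison, here is a minimal sketch of the static-index case mentioned above. It is not part of the original answer: the name fun_static is hypothetical, and it assumes a JAX version that provides the x.at[...] indexed-update syntax. Marking index as static with static_argnums makes it an ordinary Python integer at trace time, so the slice x[:index] has a known shape.

```python
from functools import partial

import jax
import jax.numpy as jnp


def other_fun(x):
    return x + 1


# `index` is declared static, so it is a plain Python int while tracing.
# The trade-off: each distinct value of `index` triggers a recompilation.
@partial(jax.jit, static_argnums=1)
def fun_static(x, index):
    # x.at[...].set(...) returns an updated copy; JAX arrays are immutable.
    return x.at[:index].set(other_fun(x[:index]))


x = jnp.arange(5)
result = fun_static(x, 3)
print(result)
# [1 2 3 3 4]
```

The jnp.where version compiles once and handles any index, but it evaluates other_fun on the whole array. The static version touches only the slice, at the cost of recompiling for every new index value.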