777 questions
Advice
0 votes, 0 replies, 59 views
Can Ray pause/resume tasks at synchronization points when GPUs are limited?
I'm training multiple neural networks in parallel using Ray, where networks must synchronize at specific points during training (not just at completion) to share metadata and update hyperparameters ...
1 vote, 0 answers, 61 views
How to vectorize (ensemble) nnx.Modules with separate parameters using nnx.vmap in JAX/Flax
I have a vectorized (ensemble) Q-network implemented using Flax Linen that works as expected. Each critic in the ensemble has separate parameters, and the output is stacked along the first dimension (...
1 vote, 1 answer, 123 views
Why does the order of return variables affect a jitted JAX function's performance so much?
In JAX, you can donate a function argument to save execution memory and time if that argument is no longer used.
If you know that one of the inputs is not needed after the computation, and if ...
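The donation mechanism described in this excerpt is exposed through the `donate_argnums` parameter of `jax.jit`. Below is a minimal sketch of the general mechanism only. `scaled_update` is a hypothetical function, not the asker's code. Backends that do not implement donation may ignore it and emit a warning instead.

```python
import jax
import jax.numpy as jnp
from functools import partial

# Hypothetical update step: argument 0 (`params`) is donated, so XLA may reuse
# its buffer for the output instead of allocating a fresh one.
@partial(jax.jit, donate_argnums=(0,))
def scaled_update(params, grads, lr):
    return params - lr * grads

params = jnp.ones((4,))
grads = jnp.full((4,), 0.5)
new_params = scaled_update(params, grads, 0.1)
# Do not read `params` after this call: if the backend honoured the donation,
# its buffer has been invalidated.
print(new_params)
```

The caller's side of the contract is simple. Once an argument is donated, the original array must be treated as consumed.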
1 vote, 0 answers, 130 views
Is my JAX implementation of continuous wavelet transform correct?
I would like to implement continuous wavelet transform (CWT) using JAX. According to ChatGPT, it is in practice computed by performing a discrete convolution with a sampled wavelet function at ...
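For reference, here is a sketch of the convolution-based approach the excerpt describes. For each scale, the signal is convolved with the wavelet sampled at that scale. The sketch assumes a real Ricker ("Mexican hat") wavelet and mirrors the structure of SciPy's former `scipy.signal.cwt` (deprecated in recent SciPy releases). `ricker` and `cwt` are hypothetical helper names, and this is not the asker's implementation.

```python
import jax.numpy as jnp

def ricker(points, a):
    # Ricker wavelet sampled at `points` integer offsets centred on 0, width `a`
    # (same formula as the former scipy.signal.ricker).
    t = jnp.arange(points) - (points - 1) / 2.0
    amp = 2.0 / (jnp.sqrt(3.0 * a) * jnp.pi ** 0.25)
    return amp * (1.0 - (t / a) ** 2) * jnp.exp(-0.5 * (t / a) ** 2)

def cwt(signal, widths):
    # One row per scale: convolve with the (time-reversed) sampled wavelet and
    # keep the output the same length as the input ("same" mode).
    rows = []
    for a in widths:
        points = min(10 * a, signal.shape[0])  # support length, as in SciPy
        wavelet = ricker(points, a)
        rows.append(jnp.convolve(signal, wavelet[::-1], mode="same"))
    return jnp.stack(rows)

t = jnp.linspace(0.0, 1.0, 256)
sig = jnp.sin(2 * jnp.pi * 8 * t)
coeffs = cwt(sig, widths=[1, 2, 4, 8, 16])
print(coeffs.shape)  # (5, 256)
```

The widths are passed as Python ints so that each wavelet length is static. For a complex wavelet such as Morlet, the reversed kernel would also need to be conjugated.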
1 vote, 1 answer, 50 views
Grain (JAX): equivalent to PyTorch's `collate_fn` for batches
I defined a dataset class with `__len__` and `__getitem__`, which returns a tuple of values. I can use `grain.transforms.Batch` to compose batches, but how do I specify how each item is combined into a ...
1 vote, 1 answer, 96 views
How to JIT-compile a function in JAX when input dimensions grow over time?
I’m implementing time series models using JAX in Python. These models are computationally expensive and need to be retrained over time using an expanding window approach. To improve performance, I ...
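Each new input shape triggers a fresh trace and compilation under `jax.jit`. A common workaround is to pad inputs up to a small set of bucket sizes and carry a mask. The sketch below assumes power-of-two buckets and a hypothetical masked loss `masked_mse`. It illustrates the general idea rather than the asker's models.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def masked_mse(y, y_hat, mask):
    # Padded positions have mask == 0, so they contribute nothing to the loss.
    err = (y - y_hat) ** 2 * mask
    return err.sum() / mask.sum()

def next_bucket(n):
    # Round up to the next power of two: an expanding window of length 1..N
    # then needs only O(log N) distinct compiled shapes.
    return 1 << (n - 1).bit_length()

def pad_to_bucket(x):
    n = x.shape[0]
    size = next_bucket(n)
    padded = np.zeros(size, dtype=np.float32)
    padded[:n] = x
    mask = np.zeros(size, dtype=np.float32)
    mask[:n] = 1.0
    return jnp.asarray(padded), jnp.asarray(mask)

series = np.arange(1.0, 11.0, dtype=np.float32)
for t in range(3, 11):  # expanding window
    y, mask = pad_to_bucket(series[:t])
    y_hat = jnp.zeros_like(y)  # stand-in for hypothetical model predictions
    loss = masked_mse(y, y_hat, mask)
```

There is a trade-off in bucket size. Coarser buckets waste compute on padding, while finer buckets mean more recompilations.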
-1 votes, 1 answer, 124 views
Passing 4 arguments to a jit function with 4 parameters raises TypeError: jit() takes 1 positional argument but 5 positional arguments
I am currently working on GPU-optimized simulations. For that, I have a function whose header looks like this:
import jax
from functools import partial
@partial(partial, jax.jit, static_argnums=(2,...
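Consider the decorator `@partial(partial, jax.jit, static_argnums=...)` applied to `f`. It evaluates to `partial(jax.jit, f, static_argnums=...)`, so calling the result with four arguments becomes `jax.jit(f, a, b, c, d, ...)`. That call passes five positional arguments to `jit`, which matches the `TypeError` in the title. The usual decorator form uses a single `partial` around `jax.jit`. In the sketch below, `simulate_step` and its parameters are hypothetical. The `static_argnums=(2, 3)` value is an assumption standing in for the truncated value in the question.

```python
import jax
import jax.numpy as jnp
from functools import partial

# Broken form: @partial(partial, jax.jit, static_argnums=(2, 3))
# Working form: jax.jit with static_argnums pre-bound by a single partial.
@partial(jax.jit, static_argnums=(2, 3))
def simulate_step(state, dt, n_steps, mode):
    # `n_steps` and `mode` are static, so plain Python control flow on them is fine.
    for _ in range(n_steps):
        if mode == "decay":
            state = state * (1.0 - dt)
        else:
            state = state + dt
    return state

out = simulate_step(jnp.ones(3), 0.1, 2, "decay")
```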
1 vote, 1 answer, 502 views
How to correctly install JAX with CUDA on Linux when `jax[cuda12_pip]` consistently falls back to the CPU version?
I am trying to install JAX with GPU support on a powerful, dedicated Linux server, but I am stuck in what feels like a Catch-22 where every official installation method fails in a different way, ...
3 votes, 1 answer, 107 views
JAX crashes with `CUDNN_STATUS_INTERNAL_ERROR` when using `joblib` or `multiprocessing`, but works in a single process
I am running into a FAILED_PRECONDITION: DNN library initialization failed error when trying to parallelize a JAX function using either Python's multiprocessing library or joblib.
The strange part is ...
3 votes, 1 answer, 117 views
Gradient Error of Batch Norm That is Implemented from Scratch
I am trying to implement batch normalization from scratch. Here is my code.
from functools import partial
import jax
@jax.tree_util.register_pytree_node_class
class MyBN:
    def __init__(self, ...
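`jax.tree_util.register_pytree_node_class` expects the class to define two methods. `tree_flatten` returns `(children, aux_data)`, and the classmethod `tree_unflatten(aux_data, children)` rebuilds the object. Below is a minimal sketch of a from-scratch batch norm registered this way. `SimpleBN` is hypothetical, works in training mode only, keeps no running statistics, and is not the asker's `MyBN`. It shows that `jax.grad` returns gradients shaped like the module itself.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class SimpleBN:
    def __init__(self, scale, bias, eps=1e-5):
        self.scale = scale
        self.bias = bias
        self.eps = eps

    def __call__(self, x):
        # Normalise over the batch axis (axis 0), then apply scale and bias.
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        x_hat = (x - mean) / jnp.sqrt(var + self.eps)
        return self.scale * x_hat + self.bias

    def tree_flatten(self):
        # Array leaves (seen by jax.grad) versus static, hashable aux data.
        return (self.scale, self.bias), (self.eps,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

def loss(bn, x):
    return (bn(x) ** 2).mean()

bn = SimpleBN(jnp.ones(3), jnp.zeros(3))
x = jnp.arange(12.0).reshape(4, 3)
grads = jax.grad(loss)(bn, x)  # a SimpleBN holding d(loss)/d(scale), d(loss)/d(bias)
```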
2 votes, 1 answer, 135 views
Does `jax` compilation save runtime memory by recognizing array elements that are duplicated by indexing?
Consider the example code:
from functools import partial
from jax import jit
import jax.numpy as jnp
@partial(jit, static_argnums=(0,))
def my_function(n):
    idx = jnp.tile(jnp.arange(n, dtype=int)...
3 votes, 1 answer, 81 views
JAX scan over a leading dimension (normal way) or with an index
I use scan many times in my project, and accidentally found that scanning an array over a leading axis (normal_scan in the example below) is slower than scanning with an index (scan_with_index).
Would ...
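The two scan styles named in this excerpt can be sketched as follows. `normal_scan` lets `jax.lax.scan` slice the array along its leading axis. `scan_with_index` scans over integer indices and reads rows from the closed-over array. The function bodies here are hypothetical (a running sum), not the asker's code. The sketch only checks that both give the same result and makes no claim about which is faster.

```python
import jax
import jax.numpy as jnp

def normal_scan(xs):
    # lax.scan slices xs along axis 0 and passes one row per step.
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, out = jax.lax.scan(step, jnp.zeros(xs.shape[1:], xs.dtype), xs)
    return out

def scan_with_index(xs):
    # Scan over indices instead; the body indexes row i of the closed-over xs.
    def step(carry, i):
        carry = carry + xs[i]
        return carry, carry
    init = jnp.zeros(xs.shape[1:], xs.dtype)
    _, out = jax.lax.scan(step, init, jnp.arange(xs.shape[0]))
    return out

xs = jnp.arange(12.0).reshape(4, 3)
a = jax.jit(normal_scan)(xs)
b = jax.jit(scan_with_index)(xs)
```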
1 vote, 1 answer, 205 views
Is JAX really incompatible with Python multiprocessing? [closed]
I have a simple app with a main controller process, and a child process that handles API calls. They communicate using Python queues.
The app looks something like this:
import multiprocessing as mp
...
1 vote, 2 answers, 150 views
`AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)` following `import mellon`
Installed the mellon package but when I try to import it, I get:
>>> import mellon
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/...
3 votes, 1 answer, 138 views
How to make a batching rule for multiple outputs
I am still exploring how to write batching rules correctly. Right now, my batching rule doesn't work as expected for multiple outputs. Here is my code.
import jax
import jax.numpy as jnp
from ...
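This is not the primitive-level batching rule the question appears to register. JAX also has a higher-level hook, `jax.custom_batching.custom_vmap`, whose vmap rule returns the outputs together with a pytree of batched flags matching their structure. The sketch below uses that different mechanism only to show how a rule for a two-output function reports one flag per output. `sum_and_product` is a hypothetical function. The sketch assumes elementwise operations, so an unbatched input simply broadcasts against a batched one.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def sum_and_product(x, y):
    # A single function with two outputs.
    return x + y, x * y

@sum_and_product.def_vmap
def sum_and_product_vmap(axis_size, in_batched, xs, ys):
    # in_batched has one bool per input; batched inputs carry the batch axis first.
    x_batched, y_batched = in_batched
    outs = (xs + ys, xs * ys)  # elementwise, so broadcasting handles unbatched inputs
    out_batched = x_batched or y_batched
    # One flag per output, with the same tuple structure as `outs`.
    return outs, (out_batched, out_batched)

xs = jnp.arange(6.0).reshape(3, 2)
ys = jnp.full((3, 2), 2.0)
s, p = jax.vmap(sum_and_product)(xs, ys)
```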