Advice
0 votes
0 replies
59 views

I'm training multiple neural networks in parallel using Ray, where networks must synchronize at specific points during training (not just at completion) to share metadata and update hyperparameters ...
User: desert_ranger
1 vote
0 answers
61 views

I have a vectorized (ensemble) Q-network implemented using Flax Linen that works as expected. Each critic in the ensemble has separate parameters, and the output is stacked along the first dimension (...
User: Lucas Alegre
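The question uses Flax Linen, but the underlying idea (separate parameters per critic, outputs stacked along the first dimension) can be sketched in plain JAX with `jax.vmap`. This is a minimal sketch that does not use Flax. It assumes a two-layer MLP critic, and `init_critic` and `critic_apply` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

def init_critic(key, obs_dim, act_dim, hidden=32):
    """Hypothetical helper: parameters for one two-layer MLP critic."""
    k1, k2 = jax.random.split(key)
    in_dim = obs_dim + act_dim
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) / jnp.sqrt(hidden),
        "b2": jnp.zeros(1),
    }

def critic_apply(params, obs, act):
    """Q(s, a) for a single critic; returns shape (batch, 1)."""
    x = jnp.concatenate([obs, act], axis=-1)
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

num_critics, obs_dim, act_dim, batch = 5, 4, 2, 8
keys = jax.random.split(jax.random.PRNGKey(0), num_critics)
# vmap over keys: every parameter leaf gains a leading axis of size num_critics.
ensemble_params = jax.vmap(lambda k: init_critic(k, obs_dim, act_dim))(keys)
# vmap over parameters only; every critic sees the same obs/act batch.
ensemble_apply = jax.vmap(critic_apply, in_axes=(0, None, None))

obs = jnp.ones((batch, obs_dim))
act = jnp.ones((batch, act_dim))
q_values = ensemble_apply(ensemble_params, obs, act)  # shape (num_critics, batch, 1)
```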
1 vote
1 answer
123 views

In JAX, you can donate a function argument to save execution memory and time if that argument is no longer used. If you know that one of the inputs is not needed after the computation, and if ...
User: zhixin
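Here is a minimal sketch of buffer donation through the `donate_argnums` parameter of `jax.jit`. It assumes a state array that is replaced on every call, and `update`, `state`, and `delta` are hypothetical names. XLA can only reuse a donated buffer when some output has the same shape and dtype as that input.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def update(state, delta):
    # The output matches `state` in shape and dtype, so XLA may write the
    # result into the donated input buffer instead of allocating a new one.
    return state + delta

state = jnp.zeros((1024,))
delta = jnp.ones((1024,))
new_state = update(state, delta)
# Do not read `state` after this point: where donation is honored, its buffer
# has been invalidated. Backends that cannot donate may emit a warning instead.
```

In practice the name is usually rebound, as in `state = update(state, delta)`, so that the donated array cannot be read again by accident.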
1 vote
0 answers
130 views

I would like to implement the continuous wavelet transform (CWT) using JAX. According to ChatGPT, it is in practice computed by performing a discrete convolution with a sampled wavelet function at ...
User: W. Zhu
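One common formulation, and the one this question describes, computes each row of the CWT as a discrete convolution of the signal with the wavelet sampled at that scale. This is a minimal sketch. It assumes a real Ricker ("Mexican hat") wavelet and a hypothetical kernel-length heuristic of `min(10 * width, len(signal))`. `ricker` and `cwt` are illustrative helpers, not JAX APIs, and normalization conventions vary between references.

```python
import jax.numpy as jnp

def ricker(points, a):
    """Ricker wavelet of width `a`, sampled at `points` centered positions."""
    t = jnp.arange(points) - (points - 1) / 2.0
    amp = 2.0 / (jnp.sqrt(3.0 * a) * jnp.pi ** 0.25)
    return amp * (1.0 - (t / a) ** 2) * jnp.exp(-(t ** 2) / (2.0 * a ** 2))

def cwt(x, widths):
    """CWT of a 1-D signal: one 'same'-size convolution per width (scale)."""
    rows = []
    for w in widths:  # Python loop: each width needs a different kernel length
        n = min(10 * int(w), x.shape[0])  # assumed truncation heuristic
        # The Ricker wavelet is real and symmetric, so convolving with it is
        # the same as the correlation that appears in the CWT definition.
        rows.append(jnp.convolve(x, ricker(n, w), mode="same"))
    return jnp.stack(rows)  # shape (len(widths), len(x))

t = jnp.linspace(0.0, 8.0 * jnp.pi, 128)
coeffs = cwt(jnp.sin(t), widths=[1, 2, 4, 8])
```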
1 vote
1 answer
50 views

I defined a dataset class with `__len__` and `__getitem__`, where `__getitem__` returns a tuple of values. I can use `grain.transforms.Batch` to compose batches, but how do I specify how each item is combined into a ...
User: Richie Bendall
1 vote
1 answer
96 views

I’m implementing time series models using JAX in Python. These models are computationally expensive and need to be retrained over time using an expanding window approach. To improve performance, I ...
User: Ali Moin
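The excerpt stops before the author's own approach, so the following is a different technique, not the question's. Each window keeps the same shape because rows are masked instead of sliced, and `jax.vmap` can then fit all expanding windows in a single compiled call. This minimal sketch assumes a simple linear model `y ~ a + b*x` fitted by least squares. `fit_window` and the minimum window length of 5 are hypothetical.

```python
import jax
import jax.numpy as jnp

def fit_window(x, y, end):
    """Least-squares fit of y ~ a + b*x using only samples with index < end.

    Slicing would give each window a different shape; a 0/1 mask over the
    full arrays keeps every shape static, so the function can be vmapped.
    """
    mask = (jnp.arange(x.shape[0]) < end).astype(x.dtype)
    X = jnp.stack([jnp.ones_like(x), x], axis=1)   # design matrix, shape (T, 2)
    Xw = X * mask[:, None]                          # zero out rows outside the window
    return jnp.linalg.solve(Xw.T @ X, Xw.T @ y)     # [intercept, slope]

x = jnp.arange(20.0)
y = 1.0 + 2.0 * x
ends = jnp.arange(5, 21)  # window ends 5..20 (hypothetical minimum window of 5)
coefs = jax.jit(jax.vmap(fit_window, in_axes=(None, None, 0)))(x, y, ends)
```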
-1 votes
1 answer
124 views

I am currently working on GPU-optimized simulations. For that, I have a function whose header looks like this: `import jax` `from functools import partial` `@partial(partial, jax.jit, static_argnums=(2,...`
User: alo bre
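For reference, the common idiom binds the options of `jax.jit` with a single `functools.partial`. Applying `partial(partial, jax.jit, ...)` to a function yields another `partial` object, not a jitted function. This minimal sketch assumes that argument index 2 is the static one, and `simulate` with its toy update rule is hypothetical.

```python
from functools import partial
import jax
import jax.numpy as jnp

# A single `partial` binds the jit options; applying it to `simulate`
# returns the jitted function.
@partial(jax.jit, static_argnums=(2,))
def simulate(x, dt, n_steps):
    # n_steps is static, so this Python loop is unrolled at trace time.
    for _ in range(n_steps):
        x = x - dt * x
    return x

out = simulate(jnp.ones(3), 0.1, 5)
```

Each new value of the static `n_steps` triggers a fresh trace and compilation. `dt` stays a traced argument and can change without recompiling.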
1 vote
1 answer
502 views

I am trying to install JAX with GPU support on a powerful, dedicated Linux server, but I am stuck in what feels like a Catch-22 where every official installation method fails in a different way, ...
User: PowerPoint Trenton
3 votes
1 answer
107 views

I am running into a `FAILED_PRECONDITION: DNN library initialization failed` error when trying to parallelize a JAX function using either Python's multiprocessing library or joblib. The strange part is ...
User: PowerPoint Trenton
3 votes
1 answer
117 views

I am trying to implement batch normalization from scratch. Here is my code: `from functools import partial` `import jax` `@jax.tree_util.register_pytree_node_class` `class MyBN:` `def __init__(self, ...`
User: Yahya
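This is a minimal sketch of a batch-norm layer registered as a pytree. It assumes `gamma`, `beta`, and the running statistics are array leaves, while `momentum` and `eps` are static auxiliary data. The constructor signature and `bn_apply` are hypothetical, because the original `__init__` is truncated. State is updated functionally: the call returns a new `MyBN` instead of mutating `self`, which keeps the layer compatible with `jax.jit`.

```python
from functools import partial
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class MyBN:
    def __init__(self, gamma, beta, running_mean, running_var, momentum=0.9, eps=1e-5):
        self.gamma, self.beta = gamma, beta                      # learnable leaves
        self.running_mean, self.running_var = running_mean, running_var
        self.momentum, self.eps = momentum, eps                  # static Python floats

    def tree_flatten(self):
        children = (self.gamma, self.beta, self.running_mean, self.running_var)
        aux_data = (self.momentum, self.eps)                     # must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

    def __call__(self, x, training):
        if training:  # normalize with batch statistics, update running stats
            mean, var = x.mean(axis=0), x.var(axis=0)
            m = self.momentum
            new_state = MyBN(self.gamma, self.beta,
                             m * self.running_mean + (1 - m) * mean,
                             m * self.running_var + (1 - m) * var,
                             self.momentum, self.eps)
        else:         # normalize with the stored running statistics
            mean, var, new_state = self.running_mean, self.running_var, self
        y = self.gamma * (x - mean) / jnp.sqrt(var + self.eps) + self.beta
        return y, new_state

# `training` selects a Python branch, so it must be static under jit.
@partial(jax.jit, static_argnames="training")
def bn_apply(bn, x, training):
    return bn(x, training)

features = 3
bn0 = MyBN(jnp.ones(features), jnp.zeros(features),
           jnp.zeros(features), jnp.ones(features))
x = 2.0 * jax.random.normal(jax.random.PRNGKey(0), (16, features)) + 5.0
y, bn1 = bn_apply(bn0, x, training=True)
```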
2 votes
1 answer
135 views

Consider the example code: `from functools import partial` `from jax import jit` `import jax.numpy as jnp` `@partial(jit, static_argnums=(0,))` `def my_function(n):` `idx = jnp.tile(jnp.arange(n, dtype=int)...`
User: Ben
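When `n` is marked static, it reaches the function as a plain Python `int` during tracing. That means `jnp.arange(n)` has a concrete length and can determine array shapes. This is a minimal sketch: `tiled_indices` and the `(n, 1)` tiling are hypothetical, because the rest of the question's function is truncated.

```python
from functools import partial
from jax import jit
import jax.numpy as jnp

@partial(jit, static_argnums=(0,))
def tiled_indices(n):
    # n is a concrete Python int while tracing, so it can fix array shapes.
    row = jnp.arange(n, dtype=int)   # shape (n,)
    return jnp.tile(row, (n, 1))     # shape (n, n); every row is 0..n-1

out = tiled_indices(3)
```

Each distinct value of `n` triggers a separate compilation, so this pattern suits a small set of sizes.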
3 votes
1 answer
81 views

I use scan many times in my project and found by accident that scanning over an array's leading axis (`normal_scan` in the example below) is slower than scanning with an index (`scan_with_index`). Would ...
User: user1168149
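The two scan styles compared in this question can be sketched as follows, assuming a running-sum body. `scan_over_axis` and `scan_over_index` are hypothetical stand-ins for `normal_scan` and `scan_with_index`, whose bodies are not shown. Both return the same final carry and stacked outputs.

```python
import jax
import jax.numpy as jnp

def scan_over_axis(xs):
    # lax.scan slices xs along its leading axis and passes one element per step.
    def body(carry, x):
        return carry + x, carry  # (new carry, per-step output)
    return jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)

def scan_over_index(xs):
    # Scan over integer indices instead and read xs[i] inside the body.
    def body(carry, i):
        return carry + xs[i], carry
    return jax.lax.scan(body, jnp.zeros((), xs.dtype), jnp.arange(xs.shape[0]))

xs = jnp.arange(10.0)
total_a, ys_a = jax.jit(scan_over_axis)(xs)
total_b, ys_b = jax.jit(scan_over_index)(xs)
```

Which variant is faster can depend on the backend, the array sizes, and the JAX version. It is safer to time both on your own inputs, calling `block_until_ready()`, than to assume either one wins.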
1 vote
1 answer
205 views

I have a simple app with a main controller process and a child process that handles API calls. They communicate using Python queues. The app looks something like this: `import multiprocessing as mp` ...
User: BHK
1 vote
2 answers
150 views

I installed the mellon package, but when I try to import it, I get: `>>> import mellon Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/...`
User: Megan Cole
3 votes
1 answer
138 views

I am still exploring how to write a batching rule correctly. Right now, my batching rule doesn't work as expected for multiple outputs. Here is my code: `import jax` `import jax.numpy as jnp` `from ...`
User: Yahya
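The question's code is truncated and may register a batching rule for a custom primitive. As an illustration of a different route, the public `jax.custom_batching.custom_vmap` API shows the key point for multiple outputs: the rule returns the outputs together with a matching structure of per-output "is batched" flags. This is a minimal sketch with a hypothetical `sin_cos` function.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def sin_cos(x):
    return jnp.sin(x), jnp.cos(x)

@sin_cos.def_vmap
def sin_cos_vmap_rule(axis_size, in_batched, x):
    # in_batched holds one flag per positional argument (here, only x).
    x_batched, = in_batched
    # sin and cos are elementwise, so they work on the batched x directly.
    outs = (jnp.sin(x), jnp.cos(x))
    # The batched flags mirror the output structure: one flag per output.
    return outs, (x_batched, x_batched)

xs = jnp.linspace(0.0, 1.0, 4)
s, c = jax.vmap(sin_cos)(xs)
```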
