Avoid importing Triton modules on Windows, since we don't build it.
Also avoid using an unescaped `\` in a regular expression.
PiperOrigin-RevId: 672507555
Of note, I moved the logic about which algorithm to use, and when to use the batched algorithm into the kernel in order to support shape polymorphism and export.
PiperOrigin-RevId: 671853879
Fix a type error that arises when we try to run the host callback tests with JAX_HOST_CALLBACK_LEGACY=False (in the process of deprecating jax.experimental.host_callback).
PiperOrigin-RevId: 671825020
This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.
First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b` when we parse "a" we obtain `b`.
Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.
Fixes: #23437Fixes: #23456
Remote configurations of python repositories are removed because hermetic Python repository rules install and configure python modules in Bazel cache on the host machine. The cache is shared across host and remote machines.
PiperOrigin-RevId: 671512134
simpler bitwise_right_shift implementation
to match previous PR
updating bitwise_right_shift_doc as an alias
readded jnp.bitwise_left_shift, jnp.bitwise_right_shift
Update sharded-computation doc to use make_mesh()
Rename `jtu.create_global_mesh` to `jtu.create_mesh` and use `jax.make_mesh` inside `jtu.create_mesh` to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
better true_divide and divide docs
doc wording update
[Mosaic TPU] Fix mosaic alignment check in concatenate rule.
PiperOrigin-RevId: 670837792
Fix pytype errors and args for jax.Array methods
Add docker builds for ubu22 and 24
Better docs for jax.numpy: log and log1p
random.key_impl: improve repr of output
Remove unused docstring addition: _PRECISION_DOC
update example optimizers library docstring
* JAXopt is being merged into Optax, so point only to Optax
* Update Optax's github repository URL
fixing merge duplication
updating tests to skip bitwise shift if numpy major version < 2
removed whitespace 659
keep non-bitwise tests for numpy < 2.0.0
more readable edit
This PR updates the FFI lowering rule to support a DeviceLoweringLayout
object as input when specifying the input and output layouts. For now,
this just converts the DLL object to its appropriate list of
minor-to-major integers because that's what the custom call op expects.