* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases
PiperOrigin-RevId: 608534008
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.
The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.
The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.
Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.
For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).
Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.
PiperOrigin-RevId: 583816243
This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py.
We disable the test in GitHub CI, because it is very large, pending
some changes to ensure it parallelizes well. The test is still
running in internal CI. This is matching the current behavior, since
jax2tf tests are only run internally.
PiperOrigin-RevId: 583603863
We expose 3 modes:
* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.
* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.
* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.
Public API coming soon.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.
Fixes https://github.com/google/jax/issues/18122
PiperOrigin-RevId: 573955671