The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.
The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.
The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.
Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.
For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).
Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.
PiperOrigin-RevId: 583816243
This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py.
We disable the test in GitHub CI, because it is very large, pending
some changes to ensure it parallelizes well. The test is still
running in internal CI. This is matching the current behavior, since
jax2tf tests are only run internally.
PiperOrigin-RevId: 583603863
We expose 3 modes:
* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.
* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.
* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.
Public API coming soon.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.
Fixes https://github.com/google/jax/issues/18122
PiperOrigin-RevId: 573955671
These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more.
Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices.
PiperOrigin-RevId: 565367453
Those modules have been developed initially for jax2tf
but they do not depend on TF anymore. They are used for JAX
native serialization. We move them under
jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF.
We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later.
PiperOrigin-RevId: 562988740