There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.
In the test, check for re-compilation which is what that test was doing before cl/535630905
PiperOrigin-RevId: 536575987
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.
PiperOrigin-RevId: 535630905
Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails.
Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants.
PiperOrigin-RevId: 535315975
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
This way we don't pass a potentially-large (Python builtin) int value to an
int32 JAX computation parameter and get an error.
Fixes#15068
Co-authored by: Matthew Johnson <mattjj@google.com>
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.
The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix#10818.
However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
The plugin is expected to calls jax._src.xla_bridge.register_plugin with its plugin_name, priority (default to be 400), path to .so file, and optional create options in their initialize() method.
Logics to register a plugin from ENV is not deleted to facilitate development with ENV.
PiperOrigin-RevId: 533280115
The semantics of eager wsc is the same as within a jit i.e. it will reshard to the given sharding only if the devices are the same and in the same order.
eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing and was decided that it is fine for now. jax.device_put should be the API to use for that.
PiperOrigin-RevId: 532858670