We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.
I am leaving pmap's flag alone for now.
PiperOrigin-RevId: 522602754
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.
The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the output shape of the called TF function returns statically known
shapes.
The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.
This change should be enough for old-style jax2tf, but will require more
work for native serialization.
We also removed some old code that was trying to workaround some limitations
in shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.
PiperOrigin-RevId: 515137407
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
This changes the internals of JAX without affecting any public API.
Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.
PiperOrigin-RevId: 508143501