364 Commits

Author SHA1 Message Date
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Jake Vanderplas
399e4ee87f Copybara import of the project:
--
8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>:

Remove several deprecated APIs

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244
PiperOrigin-RevId: 534897394
2023-05-24 10:35:37 -07:00
Skye Wanderman-Milne
533a7c05f1 Update versions and changelog post 0.4.10 release 2023-05-11 18:16:02 -07:00
Skye Wanderman-Milne
82bbeef519 Update setup.py, WORKSPACE, and CHANGELOG for jax/jaxlib 0.4.10 release 2023-05-11 14:46:06 -07:00
jax authors
bbc96320ed Merge pull request #15947 from skye:version
PiperOrigin-RevId: 530765476
2023-05-09 18:12:38 -07:00
Peter Hawkins
cc5e694658 Add improved TPU SVD accuracy to the changelog.
PiperOrigin-RevId: 530752990
2023-05-09 17:08:42 -07:00
Skye Wanderman-Milne
b02b043e7f Update versions and changelog for 0.4.9 release 2023-05-09 17:06:59 -07:00
Yash Katariya
356cac014c [Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 528907173
2023-05-02 15:40:27 -07:00
Yash Katariya
e51d12cdef Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 528488319
2023-05-01 09:15:44 -07:00
Peter Hawkins
84c516974a Revert: Switch to using Clang as the default compiler.
It appears this is causing deadlocks in multi-gpu tests.

PiperOrigin-RevId: 527706573
2023-04-27 15:52:28 -07:00
Parker Schuh
782d90dc85 Switch to using Clang as the default compiler.
PiperOrigin-RevId: 526815933
2023-04-24 19:01:49 -07:00
Yash Katariya
30c6871618 Deprecate and raise an exception for instantiate_const_outputs argument of jax.xla_computation since it has been unused for a very long time.
PiperOrigin-RevId: 524295738
2023-04-14 08:20:20 -07:00
Yash Katariya
3e93833ed8 Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone
Also remove instantiate_const_outputs since that is unused

PiperOrigin-RevId: 524113088
2023-04-13 15:05:21 -07:00
Yash Katariya
738dd719bd Remove experimental_cpp_pmap flag since it is always on
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Yash Katariya
694e43a44a Remove experimental_cpp_jit since that flag is unused and also remove experimental_cpp_pjit.
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Yash Katariya
d27a80dbfa Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00
Jake VanderPlas
749dc1b95e Remove deprecated function jnp.msort 2023-03-31 08:24:36 -07:00
Yash Katariya
69c9660aab Raise deprecation warnings for {in|out}_axis_resources for pjit and axis_resources for with_sharding_constraint
PiperOrigin-RevId: 520748845
2023-03-30 14:51:01 -07:00
Skye Wanderman-Milne
30a51b21c3 Update version and changelog after jax 0.4.8 release 2023-03-29 14:27:09 -07:00
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
Skye Wanderman-Milne
473d1c3685 Turn on PJRT C API by default.
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)

To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
Yash Katariya
86c0b36bfd Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12
PiperOrigin-RevId: 520056811
2023-03-28 09:54:36 -07:00
Yash Katariya
670fba3a91 Finish jax and jaxlib 0.4.7 release
PiperOrigin-RevId: 519839723
2023-03-27 15:06:38 -07:00
Yash Katariya
e21aee18a8 Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Peter Hawkins
8bb90b5fbe [XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes.
PiperOrigin-RevId: 518830467
2023-03-23 05:13:39 -07:00
Peter Hawkins
e0453add22 Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated.
PiperOrigin-RevId: 518241326
2023-03-21 05:13:55 -07:00
George Necula
15acc49451 [jax2tf] Update CHANGELOG for native serialization.
PiperOrigin-RevId: 517994283
2023-03-20 09:43:32 -07:00
Yash Katariya
207cc10058 Error if jax_array or jax_jit_pjit_api_merge is set to False.
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Peter Hawkins
28e4038933 Mark jax.numpy.DeviceArray as deprecated. Use jax.Array instead.
PiperOrigin-RevId: 516835920
2023-03-15 08:50:00 -07:00
Jake VanderPlas
6dd0e0153a jnp.ndarray.at: deprecate passing additional arguments by position 2023-03-13 10:04:39 -07:00
Skye Wanderman-Milne
6560bf8c36 Update versions and changelog for jax + jaxlib 0.4.6 release 2023-03-09 14:50:42 -08:00
George Necula
961e09e614 [shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program
This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.

The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the output shape of the called TF function returns statically known
shapes.

The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.

This change should be enough for old-style jax2tf, but will require more
work for native serialization.

We also removed some old code that was trying to workaround some limitations
in shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.

PiperOrigin-RevId: 515137407
2023-03-08 14:10:08 -08:00
Ivy Zheng
46838a4116
fix typo 2023-03-05 13:14:57 -08:00
Ivy Zheng
dfe940de8d
add key path change to changelog 2023-03-03 18:05:37 -08:00
Skye Wanderman-Milne
ed2c5717c5 Bump version and changelog after jax 0.4.5 release 2023-03-02 16:08:34 -08:00
Jake VanderPlas
a283aa0cc3 Deprecate three jax.Array methods:
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
2023-02-23 16:15:09 -08:00
Jake VanderPlas
841bdcef5f DOC: add is_ready() to CHANGELOG 2023-02-23 11:56:48 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Yash Katariya
941722f7db Finish jax and jaxlib 0.4.4 release
PiperOrigin-RevId: 510234171
2023-02-16 13:54:56 -08:00
Peter Hawkins
b389eed8bf [JAX] Deprecate jax.experimental.maps.Mesh.
PiperOrigin-RevId: 509852142
2023-02-15 09:15:50 -08:00
Peter Hawkins
00d45feee6 Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Jake VanderPlas
dafb88a649 jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
Jake VanderPlas
7975192f92 Expose jax.typing & update docs 2023-02-13 15:53:08 -08:00
Yash Katariya
2fc64bee13 Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
Peter Hawkins
ec56d71d01 Drop support for NVIDIA Kepler series GPUs in jaxlib builds. 2023-02-10 14:15:15 -05:00
Yash Katariya
6ec9082cf5 Default jax_jit_pjit_api_merge to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.
This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
Skye Wanderman-Milne
21f12183bf Post 0.4.3 release updates 2023-02-08 10:08:59 -08:00