Peter Hawkins
cb33fdf3f7
Include SASS/PTX for Hopper GPUs.
2023-06-05 09:42:12 -04:00
Jake VanderPlas
3bef6214bb
Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct
2023-06-02 04:10:46 -07:00
Yash Katariya
4c48611fba
Finish jax and jaxlib 0.4.11 release
...
PiperOrigin-RevId: 536931532
2023-05-31 23:49:32 -07:00
Jake VanderPlas
7a87995ecd
Deprecate jax.interpreters.xla.Buffer, device_put, xla_call_p
2023-05-28 07:15:34 -07:00
Yash Katariya
fe3fed3627
Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
...
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Peter Hawkins
e464dc8700
Reland: [XLA:Python] Add buffer protocol support to jax.Array
...
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Jake Vanderplas
399e4ee87f
Copybara import of the project:
...
--
8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>:
Remove several deprecated APIs
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244
PiperOrigin-RevId: 534897394
2023-05-24 10:35:37 -07:00
Skye Wanderman-Milne
533a7c05f1
Update versions and changelog post 0.4.10 release
2023-05-11 18:16:02 -07:00
Skye Wanderman-Milne
82bbeef519
Update setup.py, WORKSPACE, and CHANGELOG for jax/jaxlib 0.4.10 release
2023-05-11 14:46:06 -07:00
jax authors
bbc96320ed
Merge pull request #15947 from skye:version
...
PiperOrigin-RevId: 530765476
2023-05-09 18:12:38 -07:00
Peter Hawkins
cc5e694658
Add improved TPU SVD accuracy to the changelog.
...
PiperOrigin-RevId: 530752990
2023-05-09 17:08:42 -07:00
Skye Wanderman-Milne
b02b043e7f
Update versions and changelog for 0.4.9 release
2023-05-09 17:06:59 -07:00
Yash Katariya
356cac014c
[Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
...
PiperOrigin-RevId: 528907173
2023-05-02 15:40:27 -07:00
Yash Katariya
e51d12cdef
Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
...
PiperOrigin-RevId: 528488319
2023-05-01 09:15:44 -07:00
Peter Hawkins
84c516974a
Revert: Switch to using Clang as the default compiler.
...
It appears this is causing deadlocks in multi-gpu tests.
PiperOrigin-RevId: 527706573
2023-04-27 15:52:28 -07:00
Parker Schuh
782d90dc85
Switch to using Clang as the default compiler.
...
PiperOrigin-RevId: 526815933
2023-04-24 19:01:49 -07:00
Yash Katariya
30c6871618
Deprecate and raise an exception for instantiate_const_outputs
argument of jax.xla_computation since it has been unused for a very long time.
...
PiperOrigin-RevId: 524295738
2023-04-14 08:20:20 -07:00
Yash Katariya
3e93833ed8
Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone
...
Also remove instantiate_const_outputs since that is unused
PiperOrigin-RevId: 524113088
2023-04-13 15:05:21 -07:00
Yash Katariya
738dd719bd
Remove experimental_cpp_pmap flag since it is always on
...
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Yash Katariya
694e43a44a
Remove experimental_cpp_jit
since that flag is unused and also remove experimental_cpp_pjit
.
...
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.
I am leaving pmap's flag alone for now.
PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Yash Katariya
d27a80dbfa
Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
...
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00
Jake VanderPlas
749dc1b95e
Remove deprecated function jnp.msort
2023-03-31 08:24:36 -07:00
Yash Katariya
69c9660aab
Raise deprecation warnings for {in|out}_axis_resources
for pjit and axis_resources
for with_sharding_constraint
...
PiperOrigin-RevId: 520748845
2023-03-30 14:51:01 -07:00
Skye Wanderman-Milne
30a51b21c3
Update version and changelog after jax 0.4.8 release
2023-03-29 14:27:09 -07:00
Yash Katariya
fbc05ee5ac
Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
...
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
Skye Wanderman-Milne
473d1c3685
Turn on PJRT C API by default.
...
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00
Skye Wanderman-Milne
00acf459c6
Bump minimum jaxlib version from 0.4.6 to 0.4.7.
...
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
Yash Katariya
86c0b36bfd
Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12
...
PiperOrigin-RevId: 520056811
2023-03-28 09:54:36 -07:00
Yash Katariya
670fba3a91
Finish jax and jaxlib 0.4.7 release
...
PiperOrigin-RevId: 519839723
2023-03-27 15:06:38 -07:00
Yash Katariya
e21aee18a8
Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
...
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Peter Hawkins
b7375b316b
Increase minimum NumPy version to 1.21.
...
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Peter Hawkins
8bb90b5fbe
[XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes.
...
PiperOrigin-RevId: 518830467
2023-03-23 05:13:39 -07:00
Peter Hawkins
e0453add22
Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated.
...
PiperOrigin-RevId: 518241326
2023-03-21 05:13:55 -07:00
George Necula
15acc49451
[jax2tf] Update CHANGELOG for native serialization.
...
PiperOrigin-RevId: 517994283
2023-03-20 09:43:32 -07:00
Yash Katariya
207cc10058
Error if jax_array or jax_jit_pjit_api_merge is set to False.
...
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00
Yash Katariya
c2d5527f72
[Jax cleanup]
...
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers
PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Peter Hawkins
28e4038933
Mark jax.numpy.DeviceArray as deprecated. Use jax.Array instead.
...
PiperOrigin-RevId: 516835920
2023-03-15 08:50:00 -07:00
Jake VanderPlas
6dd0e0153a
jnp.ndarray.at: deprecate passing additional arguments by position
2023-03-13 10:04:39 -07:00
Skye Wanderman-Milne
6560bf8c36
Update versions and changelog for jax + jaxlib 0.4.6 release
2023-03-09 14:50:42 -08:00
George Necula
961e09e614
[shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program
...
This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.
The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the output shape of the called TF function returns statically known
shapes.
The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.
This change should be enough for old-style jax2tf, but will require more
work for native serialization.
We also removed some old code that was trying to workaround some limitations
in shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.
PiperOrigin-RevId: 515137407
2023-03-08 14:10:08 -08:00
Ivy Zheng
46838a4116
fix typo
2023-03-05 13:14:57 -08:00
Ivy Zheng
dfe940de8d
add key path change to changelog
2023-03-03 18:05:37 -08:00
Skye Wanderman-Milne
ed2c5717c5
Bump version and changelog after jax 0.4.5 release
2023-03-02 16:08:34 -08:00
Jake VanderPlas
a283aa0cc3
Deprecate three jax.Array methods:
...
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
2023-02-23 16:15:09 -08:00
Jake VanderPlas
841bdcef5f
DOC: add is_ready() to CHANGELOG
2023-02-23 11:56:48 -08:00
Yash Katariya
0ffdeb3de2
Rename jax.sharding.OpShardingSharding
to jax.sharding.GSPMDSharding
. jax.sharding.OpShardingSharding
will be removed in 3 months from Feb 17, 2023.
...
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Yash Katariya
941722f7db
Finish jax and jaxlib 0.4.4 release
...
PiperOrigin-RevId: 510234171
2023-02-16 13:54:56 -08:00
Peter Hawkins
b389eed8bf
[JAX] Deprecate jax.experimental.maps.Mesh.
...
PiperOrigin-RevId: 509852142
2023-02-15 09:15:50 -08:00
Peter Hawkins
00d45feee6
Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
...
Use the aliases under jax.sharding instead.
PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Jake VanderPlas
dafb88a649
jax.numpy reductions: require initial to be a scalar
...
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00