393 Commits

Author SHA1 Message Date
Jake VanderPlas
21f6736005 Remove several deprecated APIs 2023-07-11 12:42:32 -07:00
Jake VanderPlas
b581ad1f33 Remove several deprecated jax.Array methods:
- arr.broadcast
- arr.broadcast_in_dim
- arr.split

These have been deprecated since JAX v0.4.5

PiperOrigin-RevId: 547228974
2023-07-11 10:34:27 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
d0e75ca117 Require index update optional arguments to be passed by keyword.
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)

PiperOrigin-RevId: 544620172
2023-06-30 04:30:34 -07:00
Roy Frostig
48903a382e add corner-case cond resolution fix to changelog 2023-06-28 10:09:10 -07:00
Jake VanderPlas
3f47ad367d jax.interpreters.pxla: remove deprecated functions:
- jax.interpreters.pxla.device_put
- jax.interpreters.pxla.make_sharded_device_array
2023-06-27 21:49:55 -07:00
Yash Katariya
c632cace1e Raise an error if a user passes None to host_local_array_to_global_array or global_array_to_host_local_array
PiperOrigin-RevId: 543596009
2023-06-26 18:15:43 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
19890086fa [Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 542724110
2023-06-22 18:31:30 -07:00
Skye Wanderman-Milne
20095ab9da Update version numbers and changelog post 0.4.13 release 2023-06-22 17:54:57 -07:00
Peter Hawkins
487b640acf Jax 0.4.13 release. 2023-06-22 14:59:36 -07:00
Skye Wanderman-Milne
10424c5972 Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
  LoadedExecutable.cost_analysis, then fallback to the client method.

PiperOrigin-RevId: 542671990
2023-06-22 14:43:00 -07:00
Peter Hawkins
0ec03dbdce [XLA:Python] Fix __cuda_array_interface__.
Adds a test for __cuda_array_interface__ that does not depend on cupy.

Fixes https://github.com/google/jax/issues/16440

PiperOrigin-RevId: 541965361
2023-06-20 10:09:16 -07:00
Yash Katariya
6007698f4e Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.
The semantics are as follow:

* if the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings

* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.

This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.

PiperOrigin-RevId: 540705660
2023-06-15 15:22:22 -07:00
George Necula
0961fb9eba [jax2tf] Add native_lowering_disabled_checks parameter to jax2tf.convert.
Previously, we had a boolean `native_serialization_strict_checks` parameter
that was disabling all safety checks. This mechanism had several
disadvantages:

  * the mechanism did not differentiate between different safety checks.
    E.g., in order to disable checking of the custom call targets, one
    had to disable checking for all custom call targets, and also the
    checking that the serialization and execution platforms are the same.
  * the mechanism operated only at serialization time. Now, the
    XlaCallModule supports a `disabled_checks` attribute to control
    which safety checks should be disabled.

Here we replace the `native_serialization_strict_checks` with
`native_serialization_disabled_checks`, whose values are sequences
of disabled check descriptors.
2023-06-13 08:04:58 +03:00
Peter Hawkins
0ded163027 Fix cuda12 pip install.
The wheel is now called cudnn89.

Fixes #16362
2023-06-12 15:38:16 -04:00
Skye Wanderman-Milne
4b80103077 Update versions and changelog post 0.4.12 release 2023-06-08 16:53:22 -07:00
Yash Katariya
14451492c9 Delete OpShardingSharding export since it has been 3 months since it was deprecated. Also remove deprecation warnings for MeshPspecSharding.
PiperOrigin-RevId: 538880293
2023-06-08 13:53:38 -07:00
Peter Hawkins
6374b73ce6 [XLA:Python] Fix crash when accessing locals of JAX-generated Python tracebacks.
Use a dummy PyCodeObject in our dummy PyFrameObjects. Using a real PyCodeObject with a fake PyFrameObject confuses CPython 3.11+ when it attempts to compute the locals of a frame, since the frame lacks certain details such as closure information.

This unfortunately means we will not get source column information under Python 3.11 any more, but that is probably better than crashing.

Fixes https://github.com/google/jax/issues/16027

PiperOrigin-RevId: 538873850
2023-06-08 13:22:30 -07:00
Peter Hawkins
a47aca8205 [XLA:Python] Fix incorrect code source information under Python 3.11.
Do not multiply the result of PyFrame_GetLasti() by sizeof(_Py_CODEUNIT), because the CPython implementation already does this inside PyFrame_GetLasti().

* In CPython versions 3.9 or earlier, the f_lasti value in a PyFrameObject was in bytes.
* In CPython 3.10, f_lasti was changed to be in code units, which required multiplying it by sizeof(_Py_CODEUNIT) before passing it to functions like PyCode_Addr2Line(). https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api
* In CPython 3.11, direct access to the representation of the PyFrameObject was removed from the headers, requiring the use of PyFrame_GetLasti() (https://docs.python.org/3/whatsnew/3.11.html#pyframeobject-3-11-hiding). This function multiplies by sizeof(_Py_CODEUNIT) internally (deaf509e8f/Objects/frameobject.c (L1353)) so there is no need for the caller to do this multiplication.

It is difficult to write a good test for this, since the only symptom is slightly inaccurate code line information. This issue was found under a debug mode build of CPython (https://docs.python.org/3/using/configure.html#python-debug-build), where PyCode_Addr2Line() has additional checks for out of range lasti values.

PiperOrigin-RevId: 538847288
2023-06-08 11:42:26 -07:00
jax authors
8d27f20637 Merge pull request #16246 from chrisflesher:scipy-rotation-v3
PiperOrigin-RevId: 538788621
2023-06-08 08:10:58 -07:00
Chris Flesher
5be17ed90c Added scipy.spatial.transform Rotation and Slerp classes 2023-06-08 07:51:32 -05:00
Jake VanderPlas
47ae5bddd7 Mark jax.abstract_arrays as deprecated 2023-06-07 23:36:40 -07:00
Skye Wanderman-Milne
1a3ac88c35 [PJRT:C] Implement PJRT_Device_MemoryStats
Inspired by https://github.com/google/jax/issues/1491#issuecomment-1567696310

PiperOrigin-RevId: 538596599
2023-06-07 14:47:49 -07:00
Peter Hawkins
cb33fdf3f7 Include SASS/PTX for Hopper GPUs. 2023-06-05 09:42:12 -04:00
Jake VanderPlas
3bef6214bb Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct 2023-06-02 04:10:46 -07:00
Yash Katariya
4c48611fba Finish jax and jaxlib 0.4.11 release
PiperOrigin-RevId: 536931532
2023-05-31 23:49:32 -07:00
Jake VanderPlas
7a87995ecd Deprecate jax.interpreters.xla.Buffer, device_put, xla_call_p 2023-05-28 07:15:34 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Jake Vanderplas
399e4ee87f Copybara import of the project:
--
8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>:

Remove several deprecated APIs

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244
PiperOrigin-RevId: 534897394
2023-05-24 10:35:37 -07:00
Skye Wanderman-Milne
533a7c05f1 Update versions and changelog post 0.4.10 release 2023-05-11 18:16:02 -07:00
Skye Wanderman-Milne
82bbeef519 Update setup.py, WORKSPACE, and CHANGELOG for jax/jaxlib 0.4.10 release 2023-05-11 14:46:06 -07:00
jax authors
bbc96320ed Merge pull request #15947 from skye:version
PiperOrigin-RevId: 530765476
2023-05-09 18:12:38 -07:00
Peter Hawkins
cc5e694658 Add improved TPU SVD accuracy to the changelog.
PiperOrigin-RevId: 530752990
2023-05-09 17:08:42 -07:00
Skye Wanderman-Milne
b02b043e7f Update versions and changelog for 0.4.9 release 2023-05-09 17:06:59 -07:00
Yash Katariya
356cac014c [Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 528907173
2023-05-02 15:40:27 -07:00
Yash Katariya
e51d12cdef Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 528488319
2023-05-01 09:15:44 -07:00
Peter Hawkins
84c516974a Revert: Switch to using Clang as the default compiler.
It appears this is causing deadlocks in multi-gpu tests.

PiperOrigin-RevId: 527706573
2023-04-27 15:52:28 -07:00
Parker Schuh
782d90dc85 Switch to using Clang as the default compiler.
PiperOrigin-RevId: 526815933
2023-04-24 19:01:49 -07:00
Yash Katariya
30c6871618 Deprecate and raise an exception for instantiate_const_outputs argument of jax.xla_computation since it has been unused for a very long time.
PiperOrigin-RevId: 524295738
2023-04-14 08:20:20 -07:00
Yash Katariya
3e93833ed8 Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone
Also remove instantiate_const_outputs since that is unused

PiperOrigin-RevId: 524113088
2023-04-13 15:05:21 -07:00
Yash Katariya
738dd719bd Remove experimental_cpp_pmap flag since it is always on
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Yash Katariya
694e43a44a Remove experimental_cpp_jit since that flag is unused and also remove experimental_cpp_pjit.
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Yash Katariya
d27a80dbfa Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00
Jake VanderPlas
749dc1b95e Remove deprecated function jnp.msort 2023-03-31 08:24:36 -07:00
Yash Katariya
69c9660aab Raise deprecation warnings for {in|out}_axis_resources for pjit and axis_resources for with_sharding_constraint
PiperOrigin-RevId: 520748845
2023-03-30 14:51:01 -07:00
Skye Wanderman-Milne
30a51b21c3 Update version and changelog after jax 0.4.8 release 2023-03-29 14:27:09 -07:00
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
Skye Wanderman-Milne
473d1c3685 Turn on PJRT C API by default.
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)

To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00