398 Commits

Author SHA1 Message Date
Peter Hawkins
59509dc2b3 Remove the jax_array config option, which does nothing.
PiperOrigin-RevId: 548981491
2023-07-18 06:16:06 -07:00
Yash Katariya
f0ce0d8c6a Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
PiperOrigin-RevId: 548673905
2023-07-17 06:35:49 -07:00
George Necula
603eeb1901 Copybara import of the project:
--
06bf5fe7b2ac97156df541bab989dc5beb1aff0c by George Necula <gcnecula@gmail.com>:

[jax2tf] Added a flag and environment variable to control the serialization version.

This allows us to control the serialization version to be compatible with
the deployed version of tf.XlaCallModule. In particular, we can run
most tests with the maximum available version, while keeping the
default lower.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16746 from gnecula:tf_version 06bf5fe7b2ac97156df541bab989dc5beb1aff0c
PiperOrigin-RevId: 548504243
2023-07-16 09:27:12 -07:00
Peter Hawkins
651f87733b Remove jax_jit_pjit_api_merge.
PiperOrigin-RevId: 548236671
2023-07-14 15:25:00 -07:00
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
Jake VanderPlas
21f6736005 Remove several deprecated APIs 2023-07-11 12:42:32 -07:00
Jake VanderPlas
b581ad1f33 Remove several deprecated jax.Array methods:
- arr.broadcast
- arr.broadcast_in_dim
- arr.split

These have been deprecated since JAX v0.4.5

PiperOrigin-RevId: 547228974
2023-07-11 10:34:27 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
d0e75ca117 Require index update optional arguments to be passed by keyword.
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)

PiperOrigin-RevId: 544620172
2023-06-30 04:30:34 -07:00
Roy Frostig
48903a382e add corner-case cond resolution fix to changelog 2023-06-28 10:09:10 -07:00
Jake VanderPlas
3f47ad367d jax.interpreters.pxla: remove deprecated functions:
- jax.interpreters.pxla.device_put
- jax.interpreters.pxla.make_sharded_device_array
2023-06-27 21:49:55 -07:00
Yash Katariya
c632cace1e Raise an error if a user passes None to host_local_array_to_global_array or global_array_to_host_local_array
PiperOrigin-RevId: 543596009
2023-06-26 18:15:43 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
19890086fa [Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 542724110
2023-06-22 18:31:30 -07:00
Skye Wanderman-Milne
20095ab9da Update version numbers and changelog post 0.4.13 release 2023-06-22 17:54:57 -07:00
Peter Hawkins
487b640acf Jax 0.4.13 release. 2023-06-22 14:59:36 -07:00
Skye Wanderman-Milne
10424c5972 Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
  LoadedExecutable.cost_analysis, then fallback to the client method.

PiperOrigin-RevId: 542671990
2023-06-22 14:43:00 -07:00
Peter Hawkins
0ec03dbdce [XLA:Python] Fix __cuda_array_interface__.
Adds a test for __cuda_array_interface__ that does not depend on cupy.

Fixes https://github.com/google/jax/issues/16440

PiperOrigin-RevId: 541965361
2023-06-20 10:09:16 -07:00
Yash Katariya
6007698f4e Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.
The semantics are as follow:

* if the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings

* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.

This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.

PiperOrigin-RevId: 540705660
2023-06-15 15:22:22 -07:00
George Necula
0961fb9eba [jax2tf] Add native_lowering_disabled_checks parameter to jax2tf.convert.
Previously, we had a boolean `native_serialization_strict_checks` parameter
that was disabling all safety checks. This mechanism had several
disadvantages:

  * the mechanism did not differentiate between different safety checks.
    E.g., in order to disable checking of the custom call targets, one
    had to disable checking for all custom call targets, and also the
    checking that the serialization and execution platforms are the same.
  * the mechanism operated only at serialization time. Now, the
    XlaCallModule supports a `disabled_checks` attribute to control
    which safety checks should be disabled.

Here we replace the `native_serialization_strict_checks` with
`native_serialization_disabled_checks`, whose values are sequences
of disabled check descriptors.
2023-06-13 08:04:58 +03:00
Peter Hawkins
0ded163027 Fix cuda12 pip install.
The wheel is now called cudnn89.

Fixes #16362
2023-06-12 15:38:16 -04:00
Skye Wanderman-Milne
4b80103077 Update versions and changelog post 0.4.12 release 2023-06-08 16:53:22 -07:00
Yash Katariya
14451492c9 Delete OpShardingSharding export since it has been 3 months since it was deprecated. Also remove deprecation warnings for MeshPspecSharding.
PiperOrigin-RevId: 538880293
2023-06-08 13:53:38 -07:00
Peter Hawkins
6374b73ce6 [XLA:Python] Fix crash when accessing locals of JAX-generated Python tracebacks.
Use a dummy PyCodeObject in our dummy PyFrameObjects. Using a real PyCodeObject with a fake PyFrameObject confuses CPython 3.11+ when it attempts to compute the locals of a frame, since the frame lacks certain details such as closure information.

This unfortunately means we will not get source column information under Python 3.11 any more, but that is probably better than crashing.

Fixes https://github.com/google/jax/issues/16027

PiperOrigin-RevId: 538873850
2023-06-08 13:22:30 -07:00
Peter Hawkins
a47aca8205 [XLA:Python] Fix incorrect code source information under Python 3.11.
Do not multiply the result of PyFrame_GetLasti() by sizeof(_Py_CODEUNIT), because the CPython implementation already does this inside PyFrame_GetLasti().

* In CPython versions 3.9 or earlier, the f_lasti value in a PyFrameObject was in bytes.
* In CPython 3.10, f_lasti was changed to be in code units, which required multiplying it by sizeof(_Py_CODEUNIT) before passing it to functions like PyCode_Addr2Line(). https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api
* In CPython 3.11, direct access to the representation of the PyFrameObject was removed from the headers, requiring the use of PyFrame_GetLasti() (https://docs.python.org/3/whatsnew/3.11.html#pyframeobject-3-11-hiding). This function multiplies by sizeof(_Py_CODEUNIT) internally (deaf509e8f/Objects/frameobject.c (L1353)) so there is no need for the caller to do this multiplication.

It is difficult to write a good test for this, since the only symptom is slightly inaccurate code line information. This issue was found under a debug mode build of CPython (https://docs.python.org/3/using/configure.html#python-debug-build), where PyCode_Addr2Line() has additional checks for out of range lasti values.

PiperOrigin-RevId: 538847288
2023-06-08 11:42:26 -07:00
jax authors
8d27f20637 Merge pull request #16246 from chrisflesher:scipy-rotation-v3
PiperOrigin-RevId: 538788621
2023-06-08 08:10:58 -07:00
Chris Flesher
5be17ed90c Added scipy.spatial.transform Rotation and Slerp classes 2023-06-08 07:51:32 -05:00
Jake VanderPlas
47ae5bddd7 Mark jax.abstract_arrays as deprecated 2023-06-07 23:36:40 -07:00
Skye Wanderman-Milne
1a3ac88c35 [PJRT:C] Implement PJRT_Device_MemoryStats
Inspired by https://github.com/google/jax/issues/1491#issuecomment-1567696310

PiperOrigin-RevId: 538596599
2023-06-07 14:47:49 -07:00
Peter Hawkins
cb33fdf3f7 Include SASS/PTX for Hopper GPUs. 2023-06-05 09:42:12 -04:00
Jake VanderPlas
3bef6214bb Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct 2023-06-02 04:10:46 -07:00
Yash Katariya
4c48611fba Finish jax and jaxlib 0.4.11 release
PiperOrigin-RevId: 536931532
2023-05-31 23:49:32 -07:00
Jake VanderPlas
7a87995ecd Deprecate jax.interpreters.xla.Buffer, device_put, xla_call_p 2023-05-28 07:15:34 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Jake Vanderplas
399e4ee87f Copybara import of the project:
--
8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>:

Remove several deprecated APIs

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244
PiperOrigin-RevId: 534897394
2023-05-24 10:35:37 -07:00
Skye Wanderman-Milne
533a7c05f1 Update versions and changelog post 0.4.10 release 2023-05-11 18:16:02 -07:00
Skye Wanderman-Milne
82bbeef519 Update setup.py, WORKSPACE, and CHANGELOG for jax/jaxlib 0.4.10 release 2023-05-11 14:46:06 -07:00
jax authors
bbc96320ed Merge pull request #15947 from skye:version
PiperOrigin-RevId: 530765476
2023-05-09 18:12:38 -07:00
Peter Hawkins
cc5e694658 Add improved TPU SVD accuracy to the changelog.
PiperOrigin-RevId: 530752990
2023-05-09 17:08:42 -07:00
Skye Wanderman-Milne
b02b043e7f Update versions and changelog for 0.4.9 release 2023-05-09 17:06:59 -07:00
Yash Katariya
356cac014c [Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 528907173
2023-05-02 15:40:27 -07:00
Yash Katariya
e51d12cdef Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 528488319
2023-05-01 09:15:44 -07:00
Peter Hawkins
84c516974a Revert: Switch to using Clang as the default compiler.
It appears this is causing deadlocks in multi-gpu tests.

PiperOrigin-RevId: 527706573
2023-04-27 15:52:28 -07:00
Parker Schuh
782d90dc85 Switch to using Clang as the default compiler.
PiperOrigin-RevId: 526815933
2023-04-24 19:01:49 -07:00
Yash Katariya
30c6871618 Deprecate and raise an exception for instantiate_const_outputs argument of jax.xla_computation since it has been unused for a very long time.
PiperOrigin-RevId: 524295738
2023-04-14 08:20:20 -07:00
Yash Katariya
3e93833ed8 Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone
Also remove instantiate_const_outputs since that is unused

PiperOrigin-RevId: 524113088
2023-04-13 15:05:21 -07:00
Yash Katariya
738dd719bd Remove experimental_cpp_pmap flag since it is always on
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Yash Katariya
694e43a44a Remove experimental_cpp_jit since that flag is unused and also remove experimental_cpp_pjit.
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Yash Katariya
d27a80dbfa Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00