416 Commits

Author SHA1 Message Date
Peter Hawkins
7c871916f7 Deprecate jax.numpy.in1d.
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Jake VanderPlas
19a57e1a01 Deprecate jax.numpy.row_stack 2023-08-22 13:12:49 -07:00
Peter Hawkins
4224a4d129 Deprecate jax.scipy.linalg.tril and jax.scipy.linalg.triu.
The corresponding functions are deprecated in scipy. Use the equivalent jax.numpy functions instead.
2023-08-18 16:14:42 -04:00
George Necula
ad15a38ec1 [host_callback] Remove old backwards compatibility flag jax_host_callback_ad_transforms.
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557520668
2023-08-16 10:01:49 -07:00
George Necula
cf4e1d414b [jax2tf] Bump the default JAX serialization version to 7.
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.

See the CHANGELOG for details.

PiperOrigin-RevId: 556222908
2023-08-11 22:49:31 -07:00
Jake VanderPlas
ad8e719b82 Add jnp.ufunc and jnp.frompyfunc 2023-08-10 14:58:18 -07:00
Peter Hawkins
0e80d959c8 Mark jnp.{NINF,NZERO,PZERO} as deprecated.
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357).

PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Skye Wanderman-Milne
3e50fea29e Remove option to use StreamExecutor Cloud TPU client in JAX
It's been over three months since the new PJRT C API client was
enabled by default
(https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023).

PiperOrigin-RevId: 554935166
2023-08-08 14:05:27 -07:00
Jake Vanderplas
d8f799391b COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Peter Hawkins
afd56c15d9 Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.

PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Peter Hawkins
c879f65aa6 [JAX] Remove the non-coordination service distributed service implementation from JAX.
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.

PiperOrigin-RevId: 554608165
2023-08-07 15:17:25 -07:00
George Necula
8d80e2587b [jax2tf] Turn on JAX native serialization by default.
See changes to the README.md for mechanisms to override the default.

PiperOrigin-RevId: 554390866
2023-08-07 01:03:55 -07:00
Patrick Kidger
d6dad3827d Documented the shortening of tracebacks 2023-08-04 12:20:19 -07:00
Skye Wanderman-Milne
011fc88c03 Update versions and changelog for jax 0.4.14 release 2023-07-27 16:22:53 -07:00
jax authors
1ceddfc98a Merge pull request #16710 from gnecula:poly_max0
PiperOrigin-RevId: 549515427
2023-07-19 21:40:17 -07:00
Jake VanderPlas
7205160095 Re-parameterize jax.random.gamma for better behavior at endpoints 2023-07-19 16:15:03 -07:00
jax authors
5ae3ac28cd Add deprecation of jax.stages.Compiled.compiler_ir to the change log
PiperOrigin-RevId: 549415191
2023-07-19 13:48:55 -07:00
George Necula
e643f98558 [shape_poly] Reimplement the shape constraint checking using shape assertions.
Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality to xla_extension.refine_polymorphic_shapes
that handles custom calls @static_assertion.

As a beneficial side-effect now we get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
2023-07-19 09:56:33 +03:00
Peter Hawkins
59509dc2b3 Remove the jax_array config option, which does nothing.
PiperOrigin-RevId: 548981491
2023-07-18 06:16:06 -07:00
Yash Katariya
f0ce0d8c6a Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
PiperOrigin-RevId: 548673905
2023-07-17 06:35:49 -07:00
George Necula
603eeb1901 Copybara import of the project:
--
06bf5fe7b2ac97156df541bab989dc5beb1aff0c by George Necula <gcnecula@gmail.com>:

[jax2tf] Added a flag and environment variable to control the serialization version.

This allows us to control the serialization version to be compatible with
the deployed version of tf.XlaCallModule. In particular, we can run
most tests with the maximum available version, while keeping the
default lower.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16746 from gnecula:tf_version 06bf5fe7b2ac97156df541bab989dc5beb1aff0c
PiperOrigin-RevId: 548504243
2023-07-16 09:27:12 -07:00
Peter Hawkins
651f87733b Remove jax_jit_pjit_api_merge.
PiperOrigin-RevId: 548236671
2023-07-14 15:25:00 -07:00
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
Jake VanderPlas
21f6736005 Remove several deprecated APIs 2023-07-11 12:42:32 -07:00
Jake VanderPlas
b581ad1f33 Remove several deprecated jax.Array methods:
- arr.broadcast
- arr.broadcast_in_dim
- arr.split

These have been deprecated since JAX v0.4.5

PiperOrigin-RevId: 547228974
2023-07-11 10:34:27 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
d0e75ca117 Require index update optional arguments to be passed by keyword.
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)

PiperOrigin-RevId: 544620172
2023-06-30 04:30:34 -07:00
Roy Frostig
48903a382e add corner-case cond resolution fix to changelog 2023-06-28 10:09:10 -07:00
Jake VanderPlas
3f47ad367d jax.interpreters.pxla: remove deprecated functions:
- jax.interpreters.pxla.device_put
- jax.interpreters.pxla.make_sharded_device_array
2023-06-27 21:49:55 -07:00
Yash Katariya
c632cace1e Raise an error if a user passes None to host_local_array_to_global_array or global_array_to_host_local_array
PiperOrigin-RevId: 543596009
2023-06-26 18:15:43 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
19890086fa [Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
PiperOrigin-RevId: 542724110
2023-06-22 18:31:30 -07:00
Skye Wanderman-Milne
20095ab9da Update version numbers and changelog post 0.4.13 release 2023-06-22 17:54:57 -07:00
Peter Hawkins
487b640acf Jax 0.4.13 release. 2023-06-22 14:59:36 -07:00
Skye Wanderman-Milne
10424c5972 Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
  LoadedExecutable.cost_analysis, then fallback to the client method.

PiperOrigin-RevId: 542671990
2023-06-22 14:43:00 -07:00
Peter Hawkins
0ec03dbdce [XLA:Python] Fix __cuda_array_interface__.
Adds a test for __cuda_array_interface__ that does not depend on cupy.

Fixes https://github.com/google/jax/issues/16440

PiperOrigin-RevId: 541965361
2023-06-20 10:09:16 -07:00
Yash Katariya
6007698f4e Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.
The semantics are as follow:

* if the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings

* If the mesh context manager is provided, None will be treated as fully replicated as per the old semantics.

This will make sure that we don't break existing code depending on None meaning replicated but also start making the transition to None meaning UNSPECIFIED for jit and pjit.

PiperOrigin-RevId: 540705660
2023-06-15 15:22:22 -07:00
George Necula
0961fb9eba [jax2tf] Add native_lowering_disabled_checks parameter to jax2tf.convert.
Previously, we had a boolean `native_serialization_strict_checks` parameter
that was disabling all safety checks. This mechanism had several
disadvantages:

  * the mechanism did not differentiate between different safety checks.
    E.g., in order to disable checking of the custom call targets, one
    had to disable checking for all custom call targets, and also the
    checking that the serialization and execution platforms are the same.
  * the mechanism operated only at serialization time. Now, the
    XlaCallModule supports a `disabled_checks` attribute to control
    which safety checks should be disabled.

Here we replace the `native_serialization_strict_checks` with
`native_serialization_disabled_checks`, whose values are sequences
of disabled check descriptors.
2023-06-13 08:04:58 +03:00
Peter Hawkins
0ded163027 Fix cuda12 pip install.
The wheel is now called cudnn89.

Fixes #16362
2023-06-12 15:38:16 -04:00
Skye Wanderman-Milne
4b80103077 Update versions and changelog post 0.4.12 release 2023-06-08 16:53:22 -07:00
Yash Katariya
14451492c9 Delete OpShardingSharding export since it has been 3 months since it was deprecated. Also remove deprecation warnings for MeshPspecSharding.
PiperOrigin-RevId: 538880293
2023-06-08 13:53:38 -07:00
Peter Hawkins
6374b73ce6 [XLA:Python] Fix crash when accessing locals of JAX-generated Python tracebacks.
Use a dummy PyCodeObject in our dummy PyFrameObjects. Using a real PyCodeObject with a fake PyFrameObject confuses CPython 3.11+ when it attempts to compute the locals of a frame, since the frame lacks certain details such as closure information.

This unfortunately means we will not get source column information under Python 3.11 any more, but that is probably better than crashing.

Fixes https://github.com/google/jax/issues/16027

PiperOrigin-RevId: 538873850
2023-06-08 13:22:30 -07:00
Peter Hawkins
a47aca8205 [XLA:Python] Fix incorrect code source information under Python 3.11.
Do not multiply the result of PyFrame_GetLasti() by sizeof(_Py_CODEUNIT), because the CPython implementation already does this inside PyFrame_GetLasti().

* In CPython versions 3.9 or earlier, the f_lasti value in a PyFrameObject was in bytes.
* In CPython 3.10, f_lasti was changed to be in code units, which required multiplying it by sizeof(_Py_CODEUNIT) before passing it to functions like PyCode_Addr2Line(). https://docs.python.org/3/whatsnew/3.10.html#changes-in-the-c-api
* In CPython 3.11, direct access to the representation of the PyFrameObject was removed from the headers, requiring the use of PyFrame_GetLasti() (https://docs.python.org/3/whatsnew/3.11.html#pyframeobject-3-11-hiding). This function multiplies by sizeof(_Py_CODEUNIT) internally (deaf509e8f/Objects/frameobject.c (L1353)) so there is no need for the caller to do this multiplication.

It is difficult to write a good test for this, since the only symptom is slightly inaccurate code line information. This issue was found under a debug mode build of CPython (https://docs.python.org/3/using/configure.html#python-debug-build), where PyCode_Addr2Line() has additional checks for out of range lasti values.

PiperOrigin-RevId: 538847288
2023-06-08 11:42:26 -07:00
jax authors
8d27f20637 Merge pull request #16246 from chrisflesher:scipy-rotation-v3
PiperOrigin-RevId: 538788621
2023-06-08 08:10:58 -07:00
Chris Flesher
5be17ed90c Added scipy.spatial.transform Rotation and Slerp classes 2023-06-08 07:51:32 -05:00
Jake VanderPlas
47ae5bddd7 Mark jax.abstract_arrays as deprecated 2023-06-07 23:36:40 -07:00
Skye Wanderman-Milne
1a3ac88c35 [PJRT:C] Implement PJRT_Device_MemoryStats
Inspired by https://github.com/google/jax/issues/1491#issuecomment-1567696310

PiperOrigin-RevId: 538596599
2023-06-07 14:47:49 -07:00
Peter Hawkins
cb33fdf3f7 Include SASS/PTX for Hopper GPUs. 2023-06-05 09:42:12 -04:00
Jake VanderPlas
3bef6214bb Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct 2023-06-02 04:10:46 -07:00
Yash Katariya
4c48611fba Finish jax and jaxlib 0.4.11 release
PiperOrigin-RevId: 536931532
2023-05-31 23:49:32 -07:00