841 Commits

Author SHA1 Message Date
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Yash Katariya
b476661b4a Add clear_cache endpoint to python pjit and cpp pjit functions.
PiperOrigin-RevId: 509696516
2023-02-14 18:46:25 -08:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Yash Katariya
7350f00acd Remove jax_experimental_subjaxpr_lowering_cache since it was only for jit and was False by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Peter Hawkins
3d9ae6b467 Add a .cost_analysis() on lowered but uncompiled computations.
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.

PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Matthew Johnson
cd615b6be8 skip custom_jvp/vjp tests which dont work with initial-style staging
These tests, involving nondiff_argnums and/or closing over tracers, happen to
work with final-style JIT but not our initial-style primitives. We shouldn't
support this behavior anyway; there are good alternatives.
2023-02-01 20:34:47 -08:00
Yash Katariya
8a4de1f86a Remove the usage of _arrays from tests
PiperOrigin-RevId: 505871063
2023-01-30 20:02:37 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Jake VanderPlas
3564cd8f1c Fix typo in autodidax test 2023-01-27 12:28:53 -08:00
Jake VanderPlas
c89b537f3a Add smoketest for autodidax 2023-01-27 08:18:01 -08:00
Yash Katariya
1641c8f141 Don't run test_mismatched_nested_backends test with pjit and jit because jax_jit_pjit_api_merge will do that for us.
PiperOrigin-RevId: 504168144
2023-01-23 21:56:30 -08:00
Yash Katariya
fb9b5ec1e4 Add dce_rules for pjit primitive so that remat can DCE through the pjit primitive and remove unused residuals
PiperOrigin-RevId: 504123801
2023-01-23 17:32:20 -08:00
Yash Katariya
cb9a9952fe Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
Yash Katariya
acd8dadc74 Make some api_test pass with jit/pjit merge
PiperOrigin-RevId: 501392938
2023-01-11 15:21:16 -08:00
Jake VanderPlas
f317943f56 Warn rather than fail when reloading JAX
Fixes https://github.com/google/jax/issues/13857

PiperOrigin-RevId: 500727768
2023-01-09 09:11:50 -08:00
Jake VanderPlas
4b7e72c218 validate shape & dtype in ShapeDtypeStruct 2023-01-03 09:00:59 -08:00
Jake VanderPlas
53676932e8 Error on numpy masked array inputs. 2022-12-27 15:42:49 -08:00
Yash Katariya
57840dd916 Move functions into api_util.py and dispatch.py to remove circular import error when pjit is imported in api.py for merging the jit and pjit frontend API.
PiperOrigin-RevId: 497172760
2022-12-22 08:42:05 -08:00
Matthew Johnson
c2d9b5cee6 tweak dot_general pretty-printing rule to suppress default params 2022-12-21 10:29:01 -08:00
jax authors
87aa3aaf00 Merge pull request #13735 from jakevdp:private-linear-util
PiperOrigin-RevId: 496896110
2022-12-21 05:15:38 -08:00
jax authors
388ab7fff6 Merge pull request #13734 from mattjj:print-saved-residuals-tweaks
PiperOrigin-RevId: 496778515
2022-12-20 16:24:27 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Matthew Johnson
580fdb6e15 tweak print_saved_residuals 2022-12-20 12:00:46 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Yash Katariya
7d4ef891af Add device and backend API to pjit but resolve them away in infer_params. This is to merge jit and pjit frontend API.
The semantics of mentioning `device` or `backend` on `pjit` is the same as doing a `device_put` i.e. no matter which device the arg is on, reshard it to the device mentioned.

PiperOrigin-RevId: 495437165
2022-12-14 15:41:59 -08:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
jax authors
66ad07b5b7 Merge pull request #13442 from jakevdp:x64-api-test
PiperOrigin-RevId: 491933522
2022-11-30 09:04:46 -08:00
Jake VanderPlas
e7f53479e2 Some cleanups related to dropping Python 3.7 2022-11-29 15:54:49 -08:00
Jake VanderPlas
e916a49d6c [x64] update api_test for type safety 2022-11-29 14:32:15 -08:00
Johannes Reifferscheid
575c2f3783 Skip unsupported tests on XLA:CPU MLIR.
PiperOrigin-RevId: 490754048
2022-11-24 09:56:59 -08:00
Roy Frostig
ef9b2fe4a1 custom vmap: support closure and staged constants
The `custom_vmap` primitive stages out its wrapped function at call
time. It might extract closed-over or otherwise constant values
("consts") in doing so. To handle these, we can reduce back to the
empty closure setting: convert the consts to formal arguments, both in
the target function and in the custom vmap rule, and ignore them in
the latter.

We only need to play this trick once, on initial entry. After that, we
can resume in assuming an empty closure.
2022-11-22 17:44:08 -08:00
Yash Katariya
eca12411e7 Disable some tests with jax.Array that are failing in OSS due to using minimum_jaxlib_version. I will bump the version again this week.
PiperOrigin-RevId: 488708528
2022-11-15 11:10:29 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Eugene Burmako
55996328f2 Introduce XlaLowering::stablehlo() and use it in associated APIs
See tests/api_test.py for usage examples.

At the moment, stablehlo() works by using the hlo-legalize-to-stablehlo pass, which takes MHLO natively produced by JAX and converts it into StableHLO. This is an intermediate step towards switching JAX to natively produce StableHLO.

This CL adds both mhlo_to_stablehlo and stablehlo_to_mhlo to jaxlib, even though only the former is used at the moment. This is done in anticipation of switching JAX to natively produce StableHLO, where stablehlo_to_mhlo will be needed to provide backward compatibility for XlaLowering::mhlo(). We're adding stablehlo_to_mhlo now, so that in the future we don't have to update jaxlib again which will make deployment easier.

PiperOrigin-RevId: 487144342
2022-11-08 22:50:06 -08:00
Yash Katariya
e161d20dc3 Improve the error message when the avals a function was AOT compiled with doesn't match the input avals when its called.
PiperOrigin-RevId: 486294881
2022-11-04 21:25:46 -07:00
Matthew Johnson
4033007979 improve error when f_vjp gets more than one argument
fixes #13099
2022-11-03 15:20:10 -07:00
Adam Paszke
b0621e300c Fix the vmap rule for remat_p
It treated constants like args, but failed to convert_constvars_jaxpr to
adjust the calling convention.

PiperOrigin-RevId: 485847686
2022-11-03 05:34:09 -07:00
Yash Katariya
cc5af7ed98 Rename ReshapeableDevicesSharding to PositionalSharding and add an alias NamedSharding for MeshPspecSharding.
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.

PiperOrigin-RevId: 485753078
2022-11-02 19:13:13 -07:00
Hyeontaek Lim
bb0702842b Make device_put accept a prefix tree with Sharding leaves as the second argument
PiperOrigin-RevId: 485419880
2022-11-01 14:32:55 -07:00
jax authors
8dea82e089 Merge pull request #13022 from mattjj:leak-checker-improvements
PiperOrigin-RevId: 484640693
2022-10-28 16:05:43 -07:00
Matthew Johnson
6ebf44a681 make leak checker errors explain why objects are alive
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-28 14:12:17 -07:00
Parker Schuh
5cfc708843 Remove error-prone most_recent_entry() support from lu.cache.
PiperOrigin-RevId: 484382188
2022-10-27 16:41:44 -07:00
Peter Hawkins
ce9e009c4c [JAX:CPU] Enable buffer donation on CPU.
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.

PiperOrigin-RevId: 484001545
2022-10-26 10:13:01 -07:00
Matthew Johnson
60b236cff0 improve (and shorten!) pmap error messages about inconsistent axis sizes 2022-10-20 18:31:40 -07:00