5522 Commits

Author SHA1 Message Date
Roy Frostig
26045c49e7 remove core.{aval_method,aval_property}
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
jax authors
d8514d0ec6 Merge pull request #14500 from jakevdp:bcsr-matmul-test
PiperOrigin-RevId: 510034750
2023-02-15 21:26:06 -08:00
Peter Hawkins
37d4ad910a Remove uses of jax.xla_computation from metadata_test.py
Add HLO source path canonicalization regex to trace state key because otherwise MetadataTest.test_source_file_prefix_removal fails due to caching of lowerings with different canonicalization regexs.

PiperOrigin-RevId: 509975754
2023-02-15 17:26:21 -08:00
jax authors
3838d7612a Merge pull request #14504 from skye:host_callback_pjrt_error
PiperOrigin-RevId: 509972891
2023-02-15 17:11:01 -08:00
Jake VanderPlas
29f91c5038 [sparse] add bcsr_matmul batching tests 2023-02-15 15:46:37 -08:00
jax authors
7fa24703ec Merge pull request #14496 from jakevdp:bcsr-concatenate
PiperOrigin-RevId: 509949683
2023-02-15 15:32:19 -08:00
Jake VanderPlas
6608242f95 sparse_test: reduce num_generated_cases to avoid timeouts
PiperOrigin-RevId: 509941080
2023-02-15 15:00:28 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Skye Wanderman-Milne
7aa7e158f8 Modify JaxArrayTest.test_defragment to work on any numbers of devices
Also skip it when the PJRT C API is enabled, since the C API only supports auto defrag.

PiperOrigin-RevId: 509933635
2023-02-15 14:36:03 -08:00
Roy Frostig
537372a637 remove core.bint
PiperOrigin-RevId: 509932914
2023-02-15 14:28:29 -08:00
Jake VanderPlas
f3e5024787 [sparse] implement bcsr_concatenate 2023-02-15 14:10:56 -08:00
Jake VanderPlas
d688b6d6f3 [sparse] implement bcsr_broadcast_in_dim 2023-02-15 10:26:10 -08:00
Peter Hawkins
b389eed8bf [JAX] Deprecate jax.experimental.maps.Mesh.
PiperOrigin-RevId: 509852142
2023-02-15 09:15:50 -08:00
Peter Hawkins
00d45feee6 Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Peter Hawkins
69b8a03400 Disable some slow tests under asan.
PiperOrigin-RevId: 509828659
2023-02-15 07:41:33 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Yash Katariya
b476661b4a Add clear_cache endpoint to python pjit and cpp pjit functions.
PiperOrigin-RevId: 509696516
2023-02-14 18:46:25 -08:00
Jake VanderPlas
dafb88a649 jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
jax authors
aa98c99d3a Merge pull request #14275 from xoiga123:fix-jax.numpy.hsplit
PiperOrigin-RevId: 509585801
2023-02-14 11:24:55 -08:00
Yash Katariya
1c651f2ea4 Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Zeynep Cankara
995ef40f68 [JAX] Improve error message when jit tracer passed to a shape.
Adds additional debugging message to the shape explaining why the value is a tracer.

Fixes #14279

PiperOrigin-RevId: 509545985
2023-02-14 09:13:01 -08:00
Jake VanderPlas
15196bc1aa [sparse] enable bcsr_dot_general cusparse lowering
PiperOrigin-RevId: 509537223
2023-02-14 08:32:04 -08:00
Sharad Vikram
442aa028c2 Fix xmap staging rule to handle positional semantics
PiperOrigin-RevId: 509356614
2023-02-13 16:05:17 -08:00
Jake VanderPlas
e1ff0c1d7a Make colab_gpu.ipynb compatible with newer JAX versions
PiperOrigin-RevId: 509356393
2023-02-13 15:56:58 -08:00
Yash Katariya
d0eedf7e57 Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr
PiperOrigin-RevId: 509341618
2023-02-13 14:58:20 -08:00
Yash Katariya
2fc64bee13 Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
Jake VanderPlas
58323d5b40 jax.numpy reductions: better validation of initial value 2023-02-13 08:43:25 -08:00
Yash Katariya
6caaffc20c Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
jax authors
1bdcd5e138 Merge pull request #14415 from jakevdp:bcsr-matmul
PiperOrigin-RevId: 508785095
2023-02-10 16:55:05 -08:00
jax authors
26ddf3b571 Merge pull request #14419 from jakevdp:spsolve-cpu-lowering
PiperOrigin-RevId: 508777573
2023-02-10 16:16:05 -08:00
Jake VanderPlas
de8a77a3eb [sparse] implement BCSR.__matmul__ 2023-02-10 16:11:57 -08:00
jax authors
fc507f2ebe Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples
PiperOrigin-RevId: 508777043
2023-02-10 16:08:32 -08:00
Yash Katariya
0d07372995 Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.
PiperOrigin-RevId: 508770290
2023-02-10 15:40:25 -08:00
Jake VanderPlas
552fc2c5a3 [sparse] add CPU lowering rule for sparse.linalg.spsolve 2023-02-10 15:35:42 -08:00
Matthew Johnson
9538bc3e73 generalize vmap spmd_axis_name to accept tuples of axis names
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
Peter Hawkins
2f80e46f64 [XLA:Python] Fix overly pessimistic handling of singleton dimensions in dlpack code.
Requires an accompanying jaxlib change.

Fixes https://github.com/google/jax/issues/14399

PiperOrigin-RevId: 508757315
2023-02-10 14:44:22 -08:00
jax authors
dc6bf9b725 Merge pull request #14408 from lucashofer:scipy_spence
PiperOrigin-RevId: 508756972
2023-02-10 14:36:15 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Lucas Hofer
4636276214 added scipy special spence
added dtype to arrays in the _spence_poly function
2023-02-10 20:33:47 +00:00
Peter Hawkins
6ee67639e2 Split PyTorch interoperability tests into their own test.
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00
Jake VanderPlas
ac647b9459 [sparse] implement autodiff rules for bcsr_dot_general 2023-02-10 12:00:30 -08:00
Ngo Viet Hoai Bao
82e5767f77 update hsplit and testHVDSplit for 1D array 2023-02-10 14:27:37 +07:00
jax authors
12dc73dc6e Merge pull request #14388 from jakevdp:bcsr-todense-ad
PiperOrigin-RevId: 508477843
2023-02-09 14:41:41 -08:00
Jieying Luo
668b82d529 [PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
Loading TPU PJRT plugin is moved to make_tpu_client.

This change is based on https://github.com/google/jax/pull/14011.

PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Jake VanderPlas
7651866b1d [sparse] implement autodiff rules for bcsr primitives 2023-02-09 14:19:22 -08:00
Rahul Batra
7d0d9b706e [ROCm]: Re-enable Dirichlet Tests on ROCm 2023-02-09 20:19:07 +00:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Matthew Johnson
6fb3ace5d0 [shard-map] add vmap spmd_axis_name support, fix vmap rule bug 2023-02-08 23:54:28 -08:00
jax authors
bd7c227e96 Merge pull request #14373 from mattjj:shmap-check-rep-false
PiperOrigin-RevId: 508219490
2023-02-08 16:49:29 -08:00