Roy Frostig
26045c49e7
remove core.{aval_method,aval_property}
...
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
jax authors
d8514d0ec6
Merge pull request #14500 from jakevdp:bcsr-matmul-test
...
PiperOrigin-RevId: 510034750
2023-02-15 21:26:06 -08:00
Peter Hawkins
37d4ad910a
Remove uses of jax.xla_computation from metadata_test.py
...
Add HLO source path canonicalization regex to trace state key because otherwise MetadataTest.test_source_file_prefix_removal fails due to caching of lowerings with different canonicalization regexs.
PiperOrigin-RevId: 509975754
2023-02-15 17:26:21 -08:00
jax authors
3838d7612a
Merge pull request #14504 from skye:host_callback_pjrt_error
...
PiperOrigin-RevId: 509972891
2023-02-15 17:11:01 -08:00
Jake VanderPlas
29f91c5038
[sparse] add bcsr_matmul batching tests
2023-02-15 15:46:37 -08:00
jax authors
7fa24703ec
Merge pull request #14496 from jakevdp:bcsr-concatenate
...
PiperOrigin-RevId: 509949683
2023-02-15 15:32:19 -08:00
Jake VanderPlas
6608242f95
sparse_test: reduce num_generated_cases to avoid timeouts
...
PiperOrigin-RevId: 509941080
2023-02-15 15:00:28 -08:00
Peter Hawkins
cd0533cab0
Replace uses of jnp.ndarray with jax.Array inside JAX.
...
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Skye Wanderman-Milne
7aa7e158f8
Modify JaxArrayTest.test_defragment to work on any numbers of devices
...
Also skip it when the PJRT C API is enabled, since the C API only supports auto defrag.
PiperOrigin-RevId: 509933635
2023-02-15 14:36:03 -08:00
Roy Frostig
537372a637
remove core.bint
...
PiperOrigin-RevId: 509932914
2023-02-15 14:28:29 -08:00
Jake VanderPlas
f3e5024787
[sparse] implement bcsr_concatenate
2023-02-15 14:10:56 -08:00
Jake VanderPlas
d688b6d6f3
[sparse] implement bcsr_broadcast_in_dim
2023-02-15 10:26:10 -08:00
Peter Hawkins
b389eed8bf
[JAX] Deprecate jax.experimental.maps.Mesh.
...
PiperOrigin-RevId: 509852142
2023-02-15 09:15:50 -08:00
Peter Hawkins
00d45feee6
Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
...
Use the aliases under jax.sharding instead.
PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Peter Hawkins
69b8a03400
Disable some slow tests under asan.
...
PiperOrigin-RevId: 509828659
2023-02-15 07:41:33 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Yash Katariya
b476661b4a
Add clear_cache endpoint to python pjit and cpp pjit functions.
...
PiperOrigin-RevId: 509696516
2023-02-14 18:46:25 -08:00
Jake VanderPlas
dafb88a649
jax.numpy reductions: require initial to be a scalar
...
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
Peter Hawkins
33bed1e520
Opt into higher matmul precision for A100 and TPU tests.
...
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
jax authors
aa98c99d3a
Merge pull request #14275 from xoiga123:fix-jax.numpy.hsplit
...
PiperOrigin-RevId: 509585801
2023-02-14 11:24:55 -08:00
Yash Katariya
1c651f2ea4
Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
...
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Zeynep Cankara
995ef40f68
[JAX] Improve error message when jit tracer passed to a shape.
...
Adds additional debugging message to the shape explaining why the value is a tracer.
Fixes #14279
PiperOrigin-RevId: 509545985
2023-02-14 09:13:01 -08:00
Jake VanderPlas
15196bc1aa
[sparse] enable bcsr_dot_general cusparse lowering
...
PiperOrigin-RevId: 509537223
2023-02-14 08:32:04 -08:00
Sharad Vikram
442aa028c2
Fix xmap staging rule to handle positional semantics
...
PiperOrigin-RevId: 509356614
2023-02-13 16:05:17 -08:00
Jake VanderPlas
e1ff0c1d7a
Make colab_gpu.ipynb compatible with newer JAX versions
...
PiperOrigin-RevId: 509356393
2023-02-13 15:56:58 -08:00
Yash Katariya
d0eedf7e57
Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr
...
PiperOrigin-RevId: 509341618
2023-02-13 14:58:20 -08:00
Yash Katariya
2fc64bee13
Change the axis_resources
argument of with_sharding_constraint
to shardings
to match pjit
and jit
.
...
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
Jake VanderPlas
58323d5b40
jax.numpy reductions: better validation of initial value
2023-02-13 08:43:25 -08:00
Yash Katariya
6caaffc20c
Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
...
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
jax authors
1bdcd5e138
Merge pull request #14415 from jakevdp:bcsr-matmul
...
PiperOrigin-RevId: 508785095
2023-02-10 16:55:05 -08:00
jax authors
26ddf3b571
Merge pull request #14419 from jakevdp:spsolve-cpu-lowering
...
PiperOrigin-RevId: 508777573
2023-02-10 16:16:05 -08:00
Jake VanderPlas
de8a77a3eb
[sparse] implement BCSR.__matmul__
2023-02-10 16:11:57 -08:00
jax authors
fc507f2ebe
Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples
...
PiperOrigin-RevId: 508777043
2023-02-10 16:08:32 -08:00
Yash Katariya
0d07372995
Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.
...
PiperOrigin-RevId: 508770290
2023-02-10 15:40:25 -08:00
Jake VanderPlas
552fc2c5a3
[sparse] add CPU lowering rule for sparse.linalg.spsolve
2023-02-10 15:35:42 -08:00
Matthew Johnson
9538bc3e73
generalize vmap spmd_axis_name to accept tuples of axis names
...
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
Peter Hawkins
2f80e46f64
[XLA:Python] Fix overly pessimistic handling of singleton dimensions in dlpack code.
...
Requires an accompanying jaxlib change.
Fixes https://github.com/google/jax/issues/14399
PiperOrigin-RevId: 508757315
2023-02-10 14:44:22 -08:00
jax authors
dc6bf9b725
Merge pull request #14408 from lucashofer:scipy_spence
...
PiperOrigin-RevId: 508756972
2023-02-10 14:36:15 -08:00
Yash Katariya
1526c3e20c
Improve the error message which is raised from _get_and_check_device_assignment
.
...
Before:
```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```
After:
```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Lucas Hofer
4636276214
added scipy special spence
...
added dtype to arrays in the _spence_poly function
2023-02-10 20:33:47 +00:00
Peter Hawkins
6ee67639e2
Split PyTorch interoperability tests into their own test.
...
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00
Jake VanderPlas
ac647b9459
[sparse] implement autodiff rules for bcsr_dot_general
2023-02-10 12:00:30 -08:00
Ngo Viet Hoai Bao
82e5767f77
update hsplit and testHVDSplit for 1D array
2023-02-10 14:27:37 +07:00
jax authors
12dc73dc6e
Merge pull request #14388 from jakevdp:bcsr-todense-ad
...
PiperOrigin-RevId: 508477843
2023-02-09 14:41:41 -08:00
Jieying Luo
668b82d529
[PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
...
Loading TPU PJRT plugin is moved to make_tpu_client.
This change is based on https://github.com/google/jax/pull/14011 .
PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Jake VanderPlas
7651866b1d
[sparse] implement autodiff rules for bcsr primitives
2023-02-09 14:19:22 -08:00
Rahul Batra
7d0d9b706e
[ROCm]: Re-enable Dirichlet Tests on ROCm
2023-02-09 20:19:07 +00:00
Peter Hawkins
8268cd562d
Add infrastructure for managing deprecations.
...
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.
PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Matthew Johnson
6fb3ace5d0
[shard-map] add vmap spmd_axis_name support, fix vmap rule bug
2023-02-08 23:54:28 -08:00
jax authors
bd7c227e96
Merge pull request #14373 from mattjj:shmap-check-rep-false
...
PiperOrigin-RevId: 508219490
2023-02-08 16:49:29 -08:00