14859 Commits

Author SHA1 Message Date
Yash Katariya
58e46b48e6 Prepare for jax and jaxlib 0.4.4 release
PiperOrigin-RevId: 510152471
jax-v0.4.4 jaxlib-v0.4.4 jax-v0.4.4-rc
2023-02-16 08:37:15 -08:00
Peter Hawkins
c6a99b699e Remove jax.interpreters.xla.lower_fun.
This function has been a stub that does nothing useful for a long time, and the only user I can find is Equinox which already guards this with a hasattr(xla, 'lower_fun') guard.

PiperOrigin-RevId: 510142446
2023-02-16 07:51:15 -08:00
George Necula
a9e886f956 [jax2tf] Enable all native lowering jax2tf tests
Filed bugs for the few remaining tests, and disabled them.
Fixed the logging of the compiled HLO on test failure.

PiperOrigin-RevId: 510135651
2023-02-16 07:15:57 -08:00
George Necula
454e4de524 [shape_poly] Fix the lowering for symbolic dimension expressions for division
The symbolic dimension expression use the Python semantics for division and remainder, while StableHLO is slightly different.

PiperOrigin-RevId: 510056597
2023-02-15 23:51:23 -08:00
John QiangZhang
d0b42f2ce8 Fix the simple bug on call_tf.replace_non_float and add unittest for floating and complex data type.
PiperOrigin-RevId: 510055139
2023-02-15 23:40:54 -08:00
Roy Frostig
26045c49e7 remove core.{aval_method,aval_property}
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
jax authors
d8514d0ec6 Merge pull request #14500 from jakevdp:bcsr-matmul-test
PiperOrigin-RevId: 510034750
2023-02-15 21:26:06 -08:00
Peter Hawkins
0af9fff5ca Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 510027595
2023-02-15 21:03:03 -08:00
Roy Frostig
1b2a318fd1 remove core.axis_substitution_rules
PiperOrigin-RevId: 509989925
2023-02-15 18:42:13 -08:00
Peter Hawkins
768960b4e4 Fix pytype errors.
PiperOrigin-RevId: 509984207
2023-02-15 18:12:42 -08:00
Peter Hawkins
37d4ad910a Remove uses of jax.xla_computation from metadata_test.py
Add HLO source path canonicalization regex to trace state key because otherwise MetadataTest.test_source_file_prefix_removal fails due to caching of lowerings with different canonicalization regexs.

PiperOrigin-RevId: 509975754
2023-02-15 17:26:21 -08:00
jax authors
3838d7612a Merge pull request #14504 from skye:host_callback_pjrt_error
PiperOrigin-RevId: 509972891
2023-02-15 17:11:01 -08:00
jax authors
fdc6d946ed Merge pull request #14479 from skye:cost_analysis
PiperOrigin-RevId: 509964786
2023-02-15 16:34:48 -08:00
jax authors
6f1527f81a Merge pull request #14489 from jakevdp:copy-array
PiperOrigin-RevId: 509960582
2023-02-15 16:16:44 -08:00
Skye Wanderman-Milne
c2819cfd91 MeshComputation.cost_analysis() isn't implemented with PJRT C API.
This was caught via PJitTest.testLowerCostAnalysis
(e74852f796/tests/pjit_test.py (L998)). We
don't need to change the test because NotImplementedError is already
caught in Lowered.cost_analysis:
e74852f796/jax/_src/stages.py (L659-L660)
2023-02-16 00:15:24 +00:00
Skye Wanderman-Milne
d9f628c972 Raise a user-friendly error message if in/outfeed-based host_callback stuff is used with PJRT C API.
Prior to this change, it would crash horribly instead.

I manually tested by running the following on a Cloud TPU v4-8:
```
JAX_USE_PJRT_C_API_ON_TPU=1 python3 -m pytest tests/host_callback_test.py --tb=no
```
And verifying that all errors were the new error message.

The new error message is:
`host_callback functionality isn't supported with the new Cloud TPU
runtime. See https://jax.readthedocs.io/en/latest/debugging/index.html
and
https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
for alternatives. Please file a feature request at
https://github.com/google/jax/issues if none of the alternatives are
sufficent.`
2023-02-16 00:12:25 +00:00
Jake VanderPlas
29f91c5038 [sparse] add bcsr_matmul batching tests 2023-02-15 15:46:37 -08:00
Jake VanderPlas
936e4ae101 Add new argument to jax_test rule
PiperOrigin-RevId: 509952902
2023-02-15 15:45:47 -08:00
jax authors
7fa24703ec Merge pull request #14496 from jakevdp:bcsr-concatenate
PiperOrigin-RevId: 509949683
2023-02-15 15:32:19 -08:00
Jake VanderPlas
6608242f95 sparse_test: reduce num_generated_cases to avoid timeouts
PiperOrigin-RevId: 509941080
2023-02-15 15:00:28 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Skye Wanderman-Milne
7aa7e158f8 Modify JaxArrayTest.test_defragment to work on any numbers of devices
Also skip it when the PJRT C API is enabled, since the C API only supports auto defrag.

PiperOrigin-RevId: 509933635
2023-02-15 14:36:03 -08:00
Roy Frostig
537372a637 remove core.bint
PiperOrigin-RevId: 509932914
2023-02-15 14:28:29 -08:00
Jake VanderPlas
f3e5024787 [sparse] implement bcsr_concatenate 2023-02-15 14:10:56 -08:00
jax authors
9b288e9ab9 Merge pull request #14420 from jakevdp:bcoo-broadcast-in-dim
PiperOrigin-RevId: 509926024
2023-02-15 14:05:13 -08:00
Roy Frostig
22168a0253 remove core.{bot,Bot}
PiperOrigin-RevId: 509884508
2023-02-15 11:13:11 -08:00
jax authors
7fb4e3b26a Merge pull request #14488 from jakevdp:doc-array-methods
PiperOrigin-RevId: 509881671
2023-02-15 11:03:09 -08:00
Jake VanderPlas
d688b6d6f3 [sparse] implement bcsr_broadcast_in_dim 2023-02-15 10:26:10 -08:00
jax authors
22c155798f Merge pull request #14491 from gnecula:call_tf_test
PiperOrigin-RevId: 509865775
2023-02-15 10:09:26 -08:00
George Necula
9adc6e50ce [jax2tf] Add a test to verify that native lowering can be nested inside non-native lowering 2023-02-15 18:47:11 +01:00
Jake VanderPlas
a6d68581b4 DOC: add better documentation for array methods 2023-02-15 09:21:56 -08:00
Peter Hawkins
b389eed8bf [JAX] Deprecate jax.experimental.maps.Mesh.
PiperOrigin-RevId: 509852142
2023-02-15 09:15:50 -08:00
Jake VanderPlas
58b800db84 jnp.copy: ensure inputs are array-like 2023-02-15 08:29:45 -08:00
Peter Hawkins
00d45feee6 Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Peter Hawkins
69b8a03400 Disable some slow tests under asan.
PiperOrigin-RevId: 509828659
2023-02-15 07:41:33 -08:00
jax authors
3bd6ca014c Merge pull request #14469 from gnecula:poly_percentile
PiperOrigin-RevId: 509828200
2023-02-15 07:33:26 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Yash Katariya
b476661b4a Add clear_cache endpoint to python pjit and cpp pjit functions.
PiperOrigin-RevId: 509696516
2023-02-14 18:46:25 -08:00
jax authors
e74852f796 Merge pull request #14446 from jakevdp:initial-scalar
PiperOrigin-RevId: 509661731
2023-02-14 15:55:24 -08:00
Jake VanderPlas
dafb88a649 jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
jax authors
c2b7c5f132 Merge pull request #14474 from jakevdp:doc-array-methods
PiperOrigin-RevId: 509639140
2023-02-14 14:29:13 -08:00
Katherine Wu
59e9746552 Fix issue where HLO could not be generated for custom gradient.
It appears that the custom gradient function must be traced in the same context as the context in which it was defined. Fixed by shuffling around the default graphs.

PiperOrigin-RevId: 509618802
2023-02-14 13:22:30 -08:00
jax authors
a9ef98992c Merge pull request #14472 from nouiz:shmap_jep_fixes
PiperOrigin-RevId: 509617771
2023-02-14 13:14:33 -08:00
Jake VanderPlas
5958bf0d2f DOC: improve documentation for jax.Array methods 2023-02-14 13:04:27 -08:00
Jake VanderPlas
967f2118bf DOC: improve documentation for jax.Array methods 2023-02-14 13:03:10 -08:00
jax authors
5860cfdc71 Merge pull request #14453 from jakevdp:dtypes-doc
PiperOrigin-RevId: 509610755
2023-02-14 12:48:40 -08:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
jax authors
aa98c99d3a Merge pull request #14275 from xoiga123:fix-jax.numpy.hsplit
PiperOrigin-RevId: 509585801
2023-02-14 11:24:55 -08:00
Frederic Bastien
93c93133ea Use right fct name. 2023-02-14 11:21:16 -08:00
Frederic Bastien
d2bb1e089d Be consistent in the index used 2023-02-14 11:21:03 -08:00