Yash Katariya
58e46b48e6
Prepare for jax and jaxlib 0.4.4 release
...
PiperOrigin-RevId: 510152471
jax-v0.4.4
jaxlib-v0.4.4
jax-v0.4.4-rc
2023-02-16 08:37:15 -08:00
Peter Hawkins
c6a99b699e
Remove jax.interpreters.xla.lower_fun.
...
This function has been a stub that does nothing useful for a long time, and the only user I can find is Equinox which already guards this with a hasattr(xla, 'lower_fun') guard.
PiperOrigin-RevId: 510142446
2023-02-16 07:51:15 -08:00
George Necula
a9e886f956
[jax2tf] Enable all native lowering jax2tf tests
...
Filed bugs for the few remaining tests, and disabled them.
Fixed the logging of the compiled HLO on test failure.
PiperOrigin-RevId: 510135651
2023-02-16 07:15:57 -08:00
George Necula
454e4de524
[shape_poly] Fix the lowering for symbolic dimension expressions for division
...
The symbolic dimension expression use the Python semantics for division and remainder, while StableHLO is slightly different.
PiperOrigin-RevId: 510056597
2023-02-15 23:51:23 -08:00
John QiangZhang
d0b42f2ce8
Fix the simple bug on call_tf.replace_non_float and add unittest for floating and complex data type.
...
PiperOrigin-RevId: 510055139
2023-02-15 23:40:54 -08:00
Roy Frostig
26045c49e7
remove core.{aval_method,aval_property}
...
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
jax authors
d8514d0ec6
Merge pull request #14500 from jakevdp:bcsr-matmul-test
...
PiperOrigin-RevId: 510034750
2023-02-15 21:26:06 -08:00
Peter Hawkins
0af9fff5ca
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 510027595
2023-02-15 21:03:03 -08:00
Roy Frostig
1b2a318fd1
remove core.axis_substitution_rules
...
PiperOrigin-RevId: 509989925
2023-02-15 18:42:13 -08:00
Peter Hawkins
768960b4e4
Fix pytype errors.
...
PiperOrigin-RevId: 509984207
2023-02-15 18:12:42 -08:00
Peter Hawkins
37d4ad910a
Remove uses of jax.xla_computation from metadata_test.py
...
Add HLO source path canonicalization regex to trace state key because otherwise MetadataTest.test_source_file_prefix_removal fails due to caching of lowerings with different canonicalization regexs.
PiperOrigin-RevId: 509975754
2023-02-15 17:26:21 -08:00
jax authors
3838d7612a
Merge pull request #14504 from skye:host_callback_pjrt_error
...
PiperOrigin-RevId: 509972891
2023-02-15 17:11:01 -08:00
jax authors
fdc6d946ed
Merge pull request #14479 from skye:cost_analysis
...
PiperOrigin-RevId: 509964786
2023-02-15 16:34:48 -08:00
jax authors
6f1527f81a
Merge pull request #14489 from jakevdp:copy-array
...
PiperOrigin-RevId: 509960582
2023-02-15 16:16:44 -08:00
Skye Wanderman-Milne
c2819cfd91
MeshComputation.cost_analysis() isn't implemented with PJRT C API.
...
This was caught via PJitTest.testLowerCostAnalysis
(e74852f796/tests/pjit_test.py (L998)
). We
don't need to change the test because NotImplementedError is already
caught in Lowered.cost_analysis:
e74852f796/jax/_src/stages.py (L659-L660)
2023-02-16 00:15:24 +00:00
Skye Wanderman-Milne
d9f628c972
Raise a user-friendly error message if in/outfeed-based host_callback stuff is used with PJRT C API.
...
Prior to this change, it would crash horribly instead.
I manually tested by running the following on a Cloud TPU v4-8:
```
JAX_USE_PJRT_C_API_ON_TPU=1 python3 -m pytest tests/host_callback_test.py --tb=no
```
And verifying that all errors were the new error message.
The new error message is:
`host_callback functionality isn't supported with the new Cloud TPU
runtime. See https://jax.readthedocs.io/en/latest/debugging/index.html
and
https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
for alternatives. Please file a feature request at
https://github.com/google/jax/issues if none of the alternatives are
sufficent.`
2023-02-16 00:12:25 +00:00
Jake VanderPlas
29f91c5038
[sparse] add bcsr_matmul batching tests
2023-02-15 15:46:37 -08:00
Jake VanderPlas
936e4ae101
Add new argument to jax_test rule
...
PiperOrigin-RevId: 509952902
2023-02-15 15:45:47 -08:00
jax authors
7fa24703ec
Merge pull request #14496 from jakevdp:bcsr-concatenate
...
PiperOrigin-RevId: 509949683
2023-02-15 15:32:19 -08:00
Jake VanderPlas
6608242f95
sparse_test: reduce num_generated_cases to avoid timeouts
...
PiperOrigin-RevId: 509941080
2023-02-15 15:00:28 -08:00
Peter Hawkins
cd0533cab0
Replace uses of jnp.ndarray with jax.Array inside JAX.
...
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Skye Wanderman-Milne
7aa7e158f8
Modify JaxArrayTest.test_defragment to work on any numbers of devices
...
Also skip it when the PJRT C API is enabled, since the C API only supports auto defrag.
PiperOrigin-RevId: 509933635
2023-02-15 14:36:03 -08:00
Roy Frostig
537372a637
remove core.bint
...
PiperOrigin-RevId: 509932914
2023-02-15 14:28:29 -08:00
Jake VanderPlas
f3e5024787
[sparse] implement bcsr_concatenate
2023-02-15 14:10:56 -08:00
jax authors
9b288e9ab9
Merge pull request #14420 from jakevdp:bcoo-broadcast-in-dim
...
PiperOrigin-RevId: 509926024
2023-02-15 14:05:13 -08:00
Roy Frostig
22168a0253
remove core.{bot,Bot}
...
PiperOrigin-RevId: 509884508
2023-02-15 11:13:11 -08:00
jax authors
7fb4e3b26a
Merge pull request #14488 from jakevdp:doc-array-methods
...
PiperOrigin-RevId: 509881671
2023-02-15 11:03:09 -08:00
Jake VanderPlas
d688b6d6f3
[sparse] implement bcsr_broadcast_in_dim
2023-02-15 10:26:10 -08:00
jax authors
22c155798f
Merge pull request #14491 from gnecula:call_tf_test
...
PiperOrigin-RevId: 509865775
2023-02-15 10:09:26 -08:00
George Necula
9adc6e50ce
[jax2tf] Add a test to verify that native lowering can be nested inside non-native lowering
2023-02-15 18:47:11 +01:00
Jake VanderPlas
a6d68581b4
DOC: add better documentation for array methods
2023-02-15 09:21:56 -08:00
Peter Hawkins
b389eed8bf
[JAX] Deprecate jax.experimental.maps.Mesh.
...
PiperOrigin-RevId: 509852142
2023-02-15 09:15:50 -08:00
Jake VanderPlas
58b800db84
jnp.copy: ensure inputs are array-like
2023-02-15 08:29:45 -08:00
Peter Hawkins
00d45feee6
Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
...
Use the aliases under jax.sharding instead.
PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Peter Hawkins
69b8a03400
Disable some slow tests under asan.
...
PiperOrigin-RevId: 509828659
2023-02-15 07:41:33 -08:00
jax authors
3bd6ca014c
Merge pull request #14469 from gnecula:poly_percentile
...
PiperOrigin-RevId: 509828200
2023-02-15 07:33:26 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Yash Katariya
b476661b4a
Add clear_cache endpoint to python pjit and cpp pjit functions.
...
PiperOrigin-RevId: 509696516
2023-02-14 18:46:25 -08:00
jax authors
e74852f796
Merge pull request #14446 from jakevdp:initial-scalar
...
PiperOrigin-RevId: 509661731
2023-02-14 15:55:24 -08:00
Jake VanderPlas
dafb88a649
jax.numpy reductions: require initial to be a scalar
...
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
jax authors
c2b7c5f132
Merge pull request #14474 from jakevdp:doc-array-methods
...
PiperOrigin-RevId: 509639140
2023-02-14 14:29:13 -08:00
Katherine Wu
59e9746552
Fix issue where HLO could not be generated for custom gradient.
...
It appears that the custom gradient function must be traced in the same context as the context in which it was defined. Fixed by shuffling around the default graphs.
PiperOrigin-RevId: 509618802
2023-02-14 13:22:30 -08:00
jax authors
a9ef98992c
Merge pull request #14472 from nouiz:shmap_jep_fixes
...
PiperOrigin-RevId: 509617771
2023-02-14 13:14:33 -08:00
Jake VanderPlas
5958bf0d2f
DOC: improve documentation for jax.Array methods
2023-02-14 13:04:27 -08:00
Jake VanderPlas
967f2118bf
DOC: improve documentation for jax.Array methods
2023-02-14 13:03:10 -08:00
jax authors
5860cfdc71
Merge pull request #14453 from jakevdp:dtypes-doc
...
PiperOrigin-RevId: 509610755
2023-02-14 12:48:40 -08:00
Peter Hawkins
33bed1e520
Opt into higher matmul precision for A100 and TPU tests.
...
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
jax authors
aa98c99d3a
Merge pull request #14275 from xoiga123:fix-jax.numpy.hsplit
...
PiperOrigin-RevId: 509585801
2023-02-14 11:24:55 -08:00
Frederic Bastien
93c93133ea
Use right fct name.
2023-02-14 11:21:16 -08:00
Frederic Bastien
d2bb1e089d
Be consistent in the index used
2023-02-14 11:21:03 -08:00