13579 Commits

Author SHA1 Message Date
Parker Schuh
5cfc708843 Remove error-prone most_recent_entry() support from lu.cache.
PiperOrigin-RevId: 484382188
2022-10-27 16:41:44 -07:00
Tianjian Lu
66e75edd0b [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Hyeontaek Lim
fc8f40ce0e Internal visibility change
PiperOrigin-RevId: 484340424
2022-10-27 13:49:16 -07:00
jax authors
0c119cbb59 Merge pull request #13013 from ROCmSoftwarePlatform:rocm_update_amdgpu_archs
PiperOrigin-RevId: 484334622
2022-10-27 13:27:36 -07:00
Jake VanderPlas
f699acd931 Internal change
PiperOrigin-RevId: 484330913
2022-10-27 13:12:38 -07:00
Sharad Vikram
3e38675ac4 Update debugging_primitives_test to not use nontrivial floating point text comparisons
PiperOrigin-RevId: 484325096
2022-10-27 12:48:11 -07:00
Hyeontaek Lim
caac5f97d8 Internal changes for Array integration on alternative JAX backends.
PiperOrigin-RevId: 484319105
2022-10-27 12:20:24 -07:00
Sharad Vikram
4d94908e6b Add more shards to for_loop_test_cpu
PiperOrigin-RevId: 484309176
2022-10-27 11:42:10 -07:00
Skye Wanderman-Milne
4b1fd63263 Re-enable skipped test
Fixes #12927

PiperOrigin-RevId: 484304818
2022-10-27 11:25:54 -07:00
Yash Katariya
9f80402845 Add a default PmapSharding option which matches exactly pmap's device placement.
PiperOrigin-RevId: 484289013
2022-10-27 10:28:25 -07:00
Rohit Santhanam
663441007c [ROCm] Added gfx90a and gfx1030. 2022-10-27 16:57:06 +00:00
jax authors
978dcde8d6 MHLO Pretty Print - Enhance type printing for CopyOp, ClampOp, CstrReshapeOp, ComputeReshapeShapeOp, SelectOp.
Based on:
https://github.com/openxla/stablehlo/pull/37

PiperOrigin-RevId: 484271777
test-whl
2022-10-27 09:27:00 -07:00
jax authors
994e0ac1bb Merge pull request #12979 from jakevdp:annotate-lax-numpy-3
PiperOrigin-RevId: 484271670
2022-10-27 09:20:28 -07:00
jax authors
94ac5fbd2f Merge pull request #12997 from hawkinsp:minjaxlib
PiperOrigin-RevId: 484259964
2022-10-27 08:29:23 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
George Necula
40c85bdab8 [jax2tf] Uses MHLO bytecode for XlaCallModule op.
Most of the changes here have to do with the fact that it is harder to
inspect the converted code, since the MHLO is not in text form. This means
that some tests need to be adjusted, and we are dropping an error message
when the converted code uses custom calls, since the detection was based
on inspecting the text of the MHLO.

PiperOrigin-RevId: 484220490
2022-10-27 04:49:08 -07:00
jax authors
540835f979 Merge pull request #12998 from jakevdp:annotate-reductions
PiperOrigin-RevId: 484126361
2022-10-26 18:41:31 -07:00
jax authors
9abacbdb56 Merge pull request #9079 from NeilGirdhar:annotate_tree
PiperOrigin-RevId: 484114597
2022-10-26 17:30:06 -07:00
jax authors
d2df0faf41 Merge pull request #12996 from mattjj:tweak-jnp-canonicalize-shape
PiperOrigin-RevId: 484100902
2022-10-26 16:24:37 -07:00
Peter Hawkins
bf21391248 [JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.
PiperOrigin-RevId: 484062717
2022-10-26 13:56:07 -07:00
Jake VanderPlas
709ffd7e77 [typing] annotate jax.numpy reduction operations 2022-10-26 13:33:15 -07:00
jax authors
a08ced86f3 Merge pull request #12991 from jakevdp:fix-faq
PiperOrigin-RevId: 484042682
2022-10-26 12:39:00 -07:00
Matthew Johnson
7e341817b4 [dynamic-shapes] tweak jnp.canonicalize_shape logic
The idea with jnp.canonicalize_shape is that it handles non-tuple shapes, i.e.
intended to be scalar-like arguments like Python builtin ints or numpy scalar
types or 0D arrays. To do that, it checks numpy.ndim(shape) == 0. But
numpy.ndim might attempt to convert its argument to a numpy.ndarray, which
breaks when the argument is a tuple with Tracers inside!

Instead, let's just check if the argument is one of the canonical sequence
types (list or tuple) and if so then not even call numpy.ndim.
2022-10-26 12:01:49 -07:00
Peter Hawkins
0814770601 Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge.
PiperOrigin-RevId: 484026031
2022-10-26 11:40:14 -07:00
James Bradbury
bdde0f0cc2 [mesh_utils] Support single-core 2D meshes
PiperOrigin-RevId: 484026013
2022-10-26 11:32:50 -07:00
Jake VanderPlas
e9194b26b0 FAQ: fix JIT numerics discussion 2022-10-26 11:30:17 -07:00
Jake VanderPlas
9c0f876bcc [typing] annotate jnp.pad 2022-10-26 11:09:52 -07:00
jax authors
db2c8c1bdb Merge pull request #12994 from hawkinsp:docfix
PiperOrigin-RevId: 484015353
2022-10-26 10:55:14 -07:00
Neil Girdhar
b742b04380 Annotate tree_util 2022-10-26 13:38:38 -04:00
Peter Hawkins
71a384d25e Clarify in JAX Basics that JAX array creation is also an operation that requires accelerator dispatch and converting to a regular Python type is a blocking operation. 2022-10-26 13:38:17 -04:00
Peter Hawkins
ce9e009c4c [JAX:CPU] Enable buffer donation on CPU.
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.

PiperOrigin-RevId: 484001545
2022-10-26 10:13:01 -07:00
jax authors
b4fdc12492 Merge pull request #12990 from apaszke:enable-einsum
PiperOrigin-RevId: 484000984
2022-10-26 10:06:13 -07:00
Adam Paszke
6e43ce363e Remove a TODO from the xmap tutorial
xeinsum is already powerful enough to support the example.
2022-10-26 15:44:06 +00:00
George Necula
20e78b0cdb [jax2tf] Improve jax2tf native lowering to work with JAX_ARRAY
One problem is that when the out_axis_resources is unspecified then the
_shard_value function crashes. The fix here is to skip the XlaSharding op
in that case.

This does not fully fix the problem, but it reduces it to b/255511660.

PiperOrigin-RevId: 483907100
2022-10-26 02:15:40 -07:00
jax authors
f45e20fa0c Merge pull request #12981 from mattjj:sick
PiperOrigin-RevId: 483814209
2022-10-25 17:05:34 -07:00
jax authors
0d1e230d97 Merge pull request #12977 from yejingxin:main
PiperOrigin-RevId: 483812465
2022-10-25 16:58:14 -07:00
jax authors
553583dcee Merge pull request #12975 from jakevdp:annotate-lax-numpy-2
PiperOrigin-RevId: 483804114
2022-10-25 16:20:01 -07:00
Matthew Johnson
612bb17508 hopefully fix pjit bug
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-25 16:01:55 -07:00
jax authors
6fd9cb1920 Merge pull request #12973 from mattjj:devices-sharding
PiperOrigin-RevId: 483790957
2022-10-25 15:26:12 -07:00
Hyeontaek Lim
5a8f2dd885 Add additional pjit tests using a trivial computation.
PiperOrigin-RevId: 483781291
2022-10-25 14:46:53 -07:00
Yash Katariya
55b673e3a0 autodiff pjit partial eval
PiperOrigin-RevId: 483780041
2022-10-25 14:45:52 -07:00
Matthew Johnson
95eb4249bb tweaks to DevicesSharding
1. rename DevicesSharding -> ReshapeableDevicesSharding
2. fix repr to print device order faithfully
3. respect shape of np.ndarray argument to __init__
2022-10-25 14:28:48 -07:00
Anlun Xu
f5c7a2d444 [jax][xla:runtime:cpu] Enable JAX compilation cache for CPU
JAX CPU compilation cache is enabled under XLA_FLAGS=--xla_cpu_use_xla_runtime=true

PiperOrigin-RevId: 483775788
2022-10-25 14:25:12 -07:00
Jake VanderPlas
a5bccc8bf9 [typing] annotate next chunk of lax_numpy.py 2022-10-25 14:03:43 -07:00
Jingxin Ye
63964237b2 Skip two unit tests about custom sharding on libtpu
DETAILS:
Due to xc.register_custom_call_partitioner is not supported on libtpu, the following two tests are skipped:
tests/pjit_test.py::PJitTest::test_custom_partitioner
tests/debugging_primitives_test.py::InspectShardingTest::test_inspect_sharding_is_called_in_pjit
2022-10-25 20:55:15 +00:00
jax authors
05f78d73a0 Merge pull request #12960 from jakevdp:annotate-lax-numpy
PiperOrigin-RevId: 483764493
2022-10-25 13:48:39 -07:00
Yash Katariya
b0a1deaa21 Use the device_assignment from mesh if available and find the residual_shardings by lowering to XLA.
PiperOrigin-RevId: 483764290
2022-10-25 13:41:54 -07:00
Jake VanderPlas
2f27d516d7 [typing] annotate next part of lax_numpy.py 2022-10-25 12:36:26 -07:00
Yash Katariya
cf6b5097d0 Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
jax authors
548d7f4599 Merge pull request #12976 from jakevdp:cuda-release-comments
PiperOrigin-RevId: 483734792
2022-10-25 11:41:50 -07:00