Parker Schuh
5cfc708843
Remove error-prone most_recent_entry() support from lu.cache.
...
PiperOrigin-RevId: 484382188
2022-10-27 16:41:44 -07:00
Tianjian Lu
66e75edd0b
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
...
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Hyeontaek Lim
fc8f40ce0e
Internal visibility change
...
PiperOrigin-RevId: 484340424
2022-10-27 13:49:16 -07:00
jax authors
0c119cbb59
Merge pull request #13013 from ROCmSoftwarePlatform:rocm_update_amdgpu_archs
...
PiperOrigin-RevId: 484334622
2022-10-27 13:27:36 -07:00
Jake VanderPlas
f699acd931
Internal change
...
PiperOrigin-RevId: 484330913
2022-10-27 13:12:38 -07:00
Sharad Vikram
3e38675ac4
Update debugging_primitives_test to not use nontrivial floating point text comparisons
...
PiperOrigin-RevId: 484325096
2022-10-27 12:48:11 -07:00
Hyeontaek Lim
caac5f97d8
Internal changes for Array integration on alternative JAX backends.
...
PiperOrigin-RevId: 484319105
2022-10-27 12:20:24 -07:00
Sharad Vikram
4d94908e6b
Add more shards to for_loop_test_cpu
...
PiperOrigin-RevId: 484309176
2022-10-27 11:42:10 -07:00
Skye Wanderman-Milne
4b1fd63263
Re-enable skipped test
...
Fixes #12927
PiperOrigin-RevId: 484304818
2022-10-27 11:25:54 -07:00
Yash Katariya
9f80402845
Add a default PmapSharding
option which matches exactly pmap
's device placement.
...
PiperOrigin-RevId: 484289013
2022-10-27 10:28:25 -07:00
Rohit Santhanam
663441007c
[ROCm] Added gfx90a and gfx1030.
2022-10-27 16:57:06 +00:00
jax authors
978dcde8d6
MHLO Pretty Print - Enhance type printing for CopyOp, ClampOp, CstrReshapeOp, ComputeReshapeShapeOp, SelectOp.
...
Based on:
https://github.com/openxla/stablehlo/pull/37
PiperOrigin-RevId: 484271777
test-whl
2022-10-27 09:27:00 -07:00
jax authors
994e0ac1bb
Merge pull request #12979 from jakevdp:annotate-lax-numpy-3
...
PiperOrigin-RevId: 484271670
2022-10-27 09:20:28 -07:00
jax authors
94ac5fbd2f
Merge pull request #12997 from hawkinsp:minjaxlib
...
PiperOrigin-RevId: 484259964
2022-10-27 08:29:23 -07:00
Peter Hawkins
320d531521
Increase the minimum jaxlib version to 0.3.22.
...
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
George Necula
40c85bdab8
[jax2tf] Uses MHLO bytecode for XlaCallModule op.
...
Most of the changes here have to do with the fact that it is harder to
inspect the converted code, since the MHLO is not in text form. This means
that some tests need to be adjusted, and we are dropping an error message
when the converted code uses custom calls, since the detection was based
on inspecting the text of the MHLO.
PiperOrigin-RevId: 484220490
2022-10-27 04:49:08 -07:00
jax authors
540835f979
Merge pull request #12998 from jakevdp:annotate-reductions
...
PiperOrigin-RevId: 484126361
2022-10-26 18:41:31 -07:00
jax authors
9abacbdb56
Merge pull request #9079 from NeilGirdhar:annotate_tree
...
PiperOrigin-RevId: 484114597
2022-10-26 17:30:06 -07:00
jax authors
d2df0faf41
Merge pull request #12996 from mattjj:tweak-jnp-canonicalize-shape
...
PiperOrigin-RevId: 484100902
2022-10-26 16:24:37 -07:00
Peter Hawkins
bf21391248
[JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.
...
PiperOrigin-RevId: 484062717
2022-10-26 13:56:07 -07:00
Jake VanderPlas
709ffd7e77
[typing] annotate jax.numpy reduction operations
2022-10-26 13:33:15 -07:00
jax authors
a08ced86f3
Merge pull request #12991 from jakevdp:fix-faq
...
PiperOrigin-RevId: 484042682
2022-10-26 12:39:00 -07:00
Matthew Johnson
7e341817b4
[dynamic-shapes] tweak jnp.canonicalize_shape logic
...
The idea with jnp.canonicalize_shape is that it handles non-tuple shapes, i.e.
intended to be scalar-like arguments like Python builtin ints or numpy scalar
types or 0D arrays. To do that, it checks numpy.ndim(shape) == 0. But
numpy.ndim might attempt to convert its argument to a numpy.ndarray, which
breaks when the argument is a tuple with Tracers inside!
Instead, let's just check if the argument is one of the canonical sequence
types (list or tuple) and if so then not even call numpy.ndim.
2022-10-26 12:01:49 -07:00
Peter Hawkins
0814770601
Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge.
...
PiperOrigin-RevId: 484026031
2022-10-26 11:40:14 -07:00
James Bradbury
bdde0f0cc2
[mesh_utils] Support single-core 2D meshes
...
PiperOrigin-RevId: 484026013
2022-10-26 11:32:50 -07:00
Jake VanderPlas
e9194b26b0
FAQ: fix JIT numerics discussion
2022-10-26 11:30:17 -07:00
Jake VanderPlas
9c0f876bcc
[typing] annotate jnp.pad
2022-10-26 11:09:52 -07:00
jax authors
db2c8c1bdb
Merge pull request #12994 from hawkinsp:docfix
...
PiperOrigin-RevId: 484015353
2022-10-26 10:55:14 -07:00
Neil Girdhar
b742b04380
Annotate tree_util
2022-10-26 13:38:38 -04:00
Peter Hawkins
71a384d25e
Clarify in JAX Basics that JAX array creation is also an operation that requires accelerator dispatch and converting to a regular Python type is a blocking operation.
2022-10-26 13:38:17 -04:00
Peter Hawkins
ce9e009c4c
[JAX:CPU] Enable buffer donation on CPU.
...
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.
PiperOrigin-RevId: 484001545
2022-10-26 10:13:01 -07:00
jax authors
b4fdc12492
Merge pull request #12990 from apaszke:enable-einsum
...
PiperOrigin-RevId: 484000984
2022-10-26 10:06:13 -07:00
Adam Paszke
6e43ce363e
Remove a TODO from the xmap tutorial
...
xeinsum is already powerful enough to support the example.
2022-10-26 15:44:06 +00:00
George Necula
20e78b0cdb
[jax2tf] Improve jax2tf native lowering to work with JAX_ARRAY
...
One problem is that when the out_axis_resources is unspecified then the
_shard_value function crashes. The fix here is to skip the XlaSharding op
in that case.
This does not fully fix the problem, but it reduces it to b/255511660.
PiperOrigin-RevId: 483907100
2022-10-26 02:15:40 -07:00
jax authors
f45e20fa0c
Merge pull request #12981 from mattjj:sick
...
PiperOrigin-RevId: 483814209
2022-10-25 17:05:34 -07:00
jax authors
0d1e230d97
Merge pull request #12977 from yejingxin:main
...
PiperOrigin-RevId: 483812465
2022-10-25 16:58:14 -07:00
jax authors
553583dcee
Merge pull request #12975 from jakevdp:annotate-lax-numpy-2
...
PiperOrigin-RevId: 483804114
2022-10-25 16:20:01 -07:00
Matthew Johnson
612bb17508
hopefully fix pjit bug
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-25 16:01:55 -07:00
jax authors
6fd9cb1920
Merge pull request #12973 from mattjj:devices-sharding
...
PiperOrigin-RevId: 483790957
2022-10-25 15:26:12 -07:00
Hyeontaek Lim
5a8f2dd885
Add additional pjit tests using a trivial computation.
...
PiperOrigin-RevId: 483781291
2022-10-25 14:46:53 -07:00
Yash Katariya
55b673e3a0
autodiff pjit partial eval
...
PiperOrigin-RevId: 483780041
2022-10-25 14:45:52 -07:00
Matthew Johnson
95eb4249bb
tweaks to DevicesSharding
...
1. rename DevicesSharding -> ReshapeableDevicesSharding
2. fix repr to print device order faithfully
3. respect shape of np.ndarray argument to __init__
2022-10-25 14:28:48 -07:00
Anlun Xu
f5c7a2d444
[jax][xla:runtime:cpu] Enable JAX compilation cache for CPU
...
JAX CPU compilation cache is enabled under XLA_FLAGS=--xla_cpu_use_xla_runtime=true
PiperOrigin-RevId: 483775788
2022-10-25 14:25:12 -07:00
Jake VanderPlas
a5bccc8bf9
[typing] annotate next chunk of lax_numpy.py
2022-10-25 14:03:43 -07:00
Jingxin Ye
63964237b2
Skip two unit tests about custom sharding on libtpu
...
DETAILS:
Due to xc.register_custom_call_partitioner is not supported on libtpu, the following two tests are skipped:
tests/pjit_test.py::PJitTest::test_custom_partitioner
tests/debugging_primitives_test.py::InspectShardingTest::test_inspect_sharding_is_called_in_pjit
2022-10-25 20:55:15 +00:00
jax authors
05f78d73a0
Merge pull request #12960 from jakevdp:annotate-lax-numpy
...
PiperOrigin-RevId: 483764493
2022-10-25 13:48:39 -07:00
Yash Katariya
b0a1deaa21
Use the device_assignment from mesh if available and find the residual_shardings by lowering to XLA.
...
PiperOrigin-RevId: 483764290
2022-10-25 13:41:54 -07:00
Jake VanderPlas
2f27d516d7
[typing] annotate next part of lax_numpy.py
2022-10-25 12:36:26 -07:00
Yash Katariya
cf6b5097d0
Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
...
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
jax authors
548d7f4599
Merge pull request #12976 from jakevdp:cuda-release-comments
...
PiperOrigin-RevId: 483734792
2022-10-25 11:41:50 -07:00