17018 Commits

Author SHA1 Message Date
Jake VanderPlas
61f50bd3b6 jnp.ufunc: minor cleanups & test fixes 2023-08-14 15:19:46 -07:00
Peter Hawkins
d6e06f4476 Move the XLA commit out of the top-level JAX WORKSPACE file and into a separate .bzl file.
No functional changes intended.

PiperOrigin-RevId: 556906943
2023-08-14 14:09:45 -07:00
jax authors
abf918443e Merge pull request #17108 from froystig:key-array-repr
PiperOrigin-RevId: 556899317
2023-08-14 13:45:16 -07:00
jax authors
92102fbb86 Merge pull request #17101 from jakevdp:ufuncs-jit
PiperOrigin-RevId: 556896763
2023-08-14 13:36:33 -07:00
Jake VanderPlas
b3a02e1b62 jnp.ufunc: add __hash__ method and jit methods by default
This allows the JIT cache to work properly with ufunc methods, because bound
methods are created with a new ID each time.
2023-08-14 13:06:18 -07:00
Peter Hawkins
619377ebc1 Second attempt at fixing funm tolerance for LLVM change.
An LLVM change seems to have made this test fail. The impact seems small, so we can just relax the test tolerance.

PiperOrigin-RevId: 556886248
2023-08-14 13:01:19 -07:00
Roy Frostig
e58f5d283a de-emphasize internal array implementation type in key array repr 2023-08-14 12:48:19 -07:00
jax authors
0d1174a7ca Merge pull request #17103 from jakevdp:fix-numpy-warning
PiperOrigin-RevId: 556838682
2023-08-14 10:41:38 -07:00
Peter Hawkins
3a40cc3ca9 Relax test tolerance for funm test.
PiperOrigin-RevId: 556838400
2023-08-14 10:33:01 -07:00
Jake VanderPlas
227eec159a Ignore numpy deprecation warning 2023-08-14 10:08:57 -07:00
jax authors
3b97ff2b6d Merge pull request #17087 from gnecula:poly_minor_cleanup
PiperOrigin-RevId: 556790275
2023-08-14 07:58:30 -07:00
jax authors
23351bdbf0 Add fields kernel_name and kernel_regeneration_metadata to tpu_custom_call backend config
Add kernel regeneration utility functions.

PiperOrigin-RevId: 556717465
2023-08-14 02:35:20 -07:00
jax authors
a771ca2525 Merge pull request #17086 from gnecula:compat_shape
PiperOrigin-RevId: 556332378
2023-08-12 08:32:59 -07:00
George Necula
b90a7b7539 [shape_poly] Minor cleanup 2023-08-12 09:45:22 +02:00
George Necula
deefdbe3b4 [jax_export] Add backwards compatibility tests for shape_assertion. 2023-08-12 09:34:32 +02:00
jax authors
35d33f620c Instrument metrics to measure compilation cache savings in JAX -> PJRT.
Create metrics:
1) '/jax/compilation_cache/cache_retrieval_time_sec' to record the time duration for getting cache entries.
2) '/jax/compilation_cache/original_compile_time_saved_sec' to record the time saved on cache hits.

PiperOrigin-RevId: 556243588
2023-08-12 00:20:42 -07:00
George Necula
cf4e1d414b [jax2tf] Bump the default JAX serialization version to 7.
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.

See the CHANGELOG for details.

PiperOrigin-RevId: 556222908
2023-08-11 22:49:31 -07:00
jax authors
580b860284 Merge pull request #17054 from jakevdp:frompyfunc
PiperOrigin-RevId: 556161155
2023-08-11 18:11:29 -07:00
Hyeontaek Lim
3352ea3956 Fix dispatch performance regression with DeviceList
This change fixes the dispatch performance regression caused by switching to
`DeviceList` in pxla.py.

PiperOrigin-RevId: 556117825
2023-08-11 15:43:31 -07:00
Peter Hawkins
48f19f64ff Relax axis type annotations for reduction methods on jax.Array.
It appears that some users pass lists as axis arguments, and these are allowed by the types on the regular `jax.numpy` functions. Relax the type annotations on the methods to match the free functions.

PiperOrigin-RevId: 556085084
2023-08-11 14:11:37 -07:00
jax authors
0eb51bec3b Merge pull request #17080 from jakevdp:tpu-requirements
PiperOrigin-RevId: 556034802
2023-08-11 11:52:32 -07:00
Jake VanderPlas
f2e41b3c27 Add requests dependency to jax[tpu]
Fixes #17075
2023-08-11 11:46:58 -07:00
jax authors
cd24a15188 Reverts 7012a05497faf4d33c967bee3cebc83588234e63
PiperOrigin-RevId: 556001895
2023-08-11 10:30:13 -07:00
Peter Hawkins
0d955f08bd Skip tridiagonal solve batching test on older jaxlibs.
PiperOrigin-RevId: 555975222
2023-08-11 09:12:03 -07:00
Peter Hawkins
78cfdd1b35 Add some more type annotations to lax_numpy.py.
These type annotations are of course mostly ignored because the pytype: skip-file comment, but they help readers if nothing else.

PiperOrigin-RevId: 555955257
2023-08-11 08:07:24 -07:00
Peter Hawkins
bfaffe3183 Add version guards after GPU tridiagonal solve change.
PiperOrigin-RevId: 555931222
2023-08-11 06:41:05 -07:00
Chris Jones
deed8b71b1 [pallas] Minor cleanup to pallas_call_p JVP code.
PiperOrigin-RevId: 555862179
2023-08-11 02:30:35 -07:00
Sharad Vikram
7e1278c040 [Pallas] Simplify indexing logic in Mosaic lowering
PiperOrigin-RevId: 555781285
2023-08-10 21:18:50 -07:00
Parker Schuh
03575c4b33 Pad generated sharding specs with None up to ndims to simplify comparing dims
across different partitioned arguments.

PiperOrigin-RevId: 555712119
2023-08-10 17:02:31 -07:00
Srinivas Vasudevan
7dfc8ff49d Add batching rules to jax.lax.linalg.tridiagonal_solve.
PiperOrigin-RevId: 555700103
2023-08-10 16:25:59 -07:00
jax authors
60c3fdf683 Merge pull request #17069 from jakevdp:unpackbits-count
PiperOrigin-RevId: 555695640
2023-08-10 16:13:43 -07:00
Yash Katariya
5349ea6209 [Memories] Allow device_put outside jax.jit to work with different memory kinds.
Currently only jax.Arrays work. Other types will be added in subsequent CLs.

PiperOrigin-RevId: 555677540
2023-08-10 15:26:19 -07:00
Jake VanderPlas
ad8e719b82 Add jnp.ufunc and jnp.frompyfunc 2023-08-10 14:58:18 -07:00
Jake VanderPlas
4df58052aa jnp.unpackbits: fix handling of count & add tests 2023-08-10 14:34:11 -07:00
Chris Jones
f187352569 [pallas] Remove redundant lines from pallas_call_p JVP rule.
If a tangent is already non-zero, `instantiate_zero` is a no-op.

PiperOrigin-RevId: 555625200
2023-08-10 13:27:41 -07:00
Chris Jones
9b3703462c [pallas] Remove redundant check in pallas_call_p JVP rule.
An identical check already exists a few lines above.

PiperOrigin-RevId: 555624518
2023-08-10 13:18:49 -07:00
Peter Hawkins
1a539bbddc Fix type errors introduced by NINF deprecation change.
PiperOrigin-RevId: 555614877
2023-08-10 12:54:22 -07:00
Peter Hawkins
0e80d959c8 Mark jnp.{NINF,NZERO,PZERO} as deprecated.
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357).

PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
John QiangZhang
cf026ce745 Instead of tf.Graph protobuf, we switch to tf.Saved_model for back_compat_tf_test.
PiperOrigin-RevId: 555500398
2023-08-10 08:22:34 -07:00
John QiangZhang
1ddc340a1a Fix typo here. 'gpu' should be 'cuda' or 'rocm' here.
PiperOrigin-RevId: 555495314
2023-08-10 08:08:10 -07:00
George Necula
c3aa3a4c31 [jax2tf] Disable some graph serialization tests on GPU
We recently increased the test coverage of testing for dot_general with different dtype for lhs and rhs. Some of the new combinations of dtypes are not supported by XLA:GPU, and we disable those tests now.

PiperOrigin-RevId: 555465495
2023-08-10 06:30:40 -07:00
Malcolm Reynolds
7012a05497 Rollback, breaks internal project
Reverts 6b8bb7bd5990c5207c8b4f793f8ce0702060c8da

PiperOrigin-RevId: 555455350
2023-08-10 05:39:36 -07:00
jax authors
eb076c4c44 Explicitly set AutoFDO profile version in CompileOptions.
Set the AutoFDO profile version specified in --jax_xla_profile_version
if non-zero. Otherwise, expect that there is a function set in
get_latest_profile_version that will return a non-zero profile version
that should be used. If this function is not set or it returns 0,
set -1 instead to indicate that no attempt should be made to retrieve
an AutoFDO profile later on.

Testing: updated unit tests.
PiperOrigin-RevId: 555333728
2023-08-09 18:24:56 -07:00
jax authors
aac4cdad56 Set up plumbing for adding new compilation-cache-key generation algorithm.
The new cache-key generation algorithm will coexist with the original
version until the new one is fully deployed. While they coexist,
--jax_use_original_compilation_cache_key_generation will determine which
one is used. Once the new algorithm is deployed, the original algorithm
and this flag will be removed.

This change sets up the plumbing. Later changes will implement the new
algorithm.

Testing: test workload.
PiperOrigin-RevId: 555333628
2023-08-09 18:16:22 -07:00
Jevin Jiang
5d8f5c20fa [Mosaic] Remove multiple results check in apply layout.
PiperOrigin-RevId: 555320679
2023-08-09 17:17:25 -07:00
Parker Schuh
74bcd65bbd Make mesh available to custom_partitioning lowering rules.
PiperOrigin-RevId: 555319896
2023-08-09 17:08:57 -07:00
Hyeontaek Lim
97b96bbd4b [JAX] Introduce DeviceList backed by C++ xla::ifrt::DeviceList
This change adds `xla_client.DeviceList` that is implemented in C++
`jax::PyDeviceList`. `jax::PyDeviceList` implements the features of
`pxla._DeviceAssignment` as a functional drop-in replacement.
`jax::PyDeviceList` internally has `xla::ifrt::DeviceList`, which will be used
when using IFRT APIs without having to construct a new copy of a potentially
large device list.

`pxla._DeviceAssignment`'s interface is changed slightly to encourage avoiding
conversion to tuple.

Note that for the backward compatibility (and fast `xla_client.Device`
conversion), `jax::PyDeviceList` still uses a Python tuple whose element can be
any Python object matches `xla_client.Device` interface with duck typing. This
duck typing support will be removed when such use case is deprecated.
Eventually, we can try to avoid any type conversion to remove a shadow copy of
device list in JAX.

PiperOrigin-RevId: 555317152
2023-08-09 16:58:01 -07:00
jax authors
22a005c2a3 Support (u)int4 for jax.dtypes.scalar_type_of()
PiperOrigin-RevId: 555265053
2023-08-09 13:52:04 -07:00
jax authors
1e32fd598d Merge pull request #17042 from cottrell:me
PiperOrigin-RevId: 555226539
2023-08-09 11:42:57 -07:00
Peter Hawkins
9529ed0f4d Remove workarounds for MLIR constant construction.
* https://reviews.llvm.org/D155209 added support to the MLIR Python bindings for passing types like bfloat16 directly if an explicit IR type is provided.
* the crash for non-splat size 1 constants appears fixed at head, although I don't know which change fixed it.

PiperOrigin-RevId: 555225604
2023-08-09 11:34:28 -07:00