18921 Commits

Author SHA1 Message Date
Sergei Lebedev
699565ae6d Removed unused _dtype_to_xla_type_string
PiperOrigin-RevId: 592862584
2023-12-21 08:04:57 -08:00
Matthew Johnson
68635692b5 remove use of cast in ad_util
i hate cast
2023-12-20 21:00:42 -08:00
jax authors
f64b80f3fb Update XLA dependency to use revision
70ba83748b.

PiperOrigin-RevId: 592716358
2023-12-20 19:20:28 -08:00
Yash Katariya
9792e00887 Cleanup _find_arg_mismatch logic
PiperOrigin-RevId: 592697969
2023-12-20 17:24:26 -08:00
Yash Katariya
57d74d6d24 Always return a NamedSharding from eager shard_map
PiperOrigin-RevId: 592689808
2023-12-20 16:42:57 -08:00
jax authors
454ffcc36f Merge pull request #19008 from 8bitmp3:jax-docs-advanced-autodiff
PiperOrigin-RevId: 592683399
2023-12-20 16:13:48 -08:00
8bitmp3
6afb83a463 Upgrade JAX Advanced Autodiff 201 2023-12-20 23:58:49 +00:00
David Majnemer
e089c9b84f Internal only changes.
Reverts 4347950d9d018a254fd00bded54ae79df2e71556

PiperOrigin-RevId: 592679160
2023-12-20 15:55:53 -08:00
Jake VanderPlas
ad3d743ca9 jnp.vectorize: support excluding arguments by keyword 2023-12-20 15:38:19 -08:00
jax authors
965fefdbbf Merge pull request #19071 from mattjj:tangent-dtypes
PiperOrigin-RevId: 592664820
2023-12-20 15:02:02 -08:00
jax authors
7ecd22cfcc Exclude test_gpu_memory_allocation from pytest execution.
PiperOrigin-RevId: 592664477
2023-12-20 14:53:33 -08:00
Matthew Johnson
ec7d28c0b2 revise logic for tangent types of extended dtypes
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change

We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
2023-12-20 14:24:52 -08:00
jax authors
6fcaf79dbb Merge pull request #19070 from mattjj:make-hypothesis-optional
PiperOrigin-RevId: 592646659
2023-12-20 13:40:32 -08:00
Matthew Johnson
d7cceb1ee8 make hypothesis use in all_gather_test.py optional
I think a31129a aka cl/587963496 accidentally made hypothesis a test dependency in tests/all_gather_test.py, rather than following our existing convention as in tests/state_test.py of making it optional. I think it was an accident because there's no discussion of adding hypothesis as a test dependence on the review for that PR/CL.

This PR changes tests/all_gather_test.py to follow the convention for making hypothesis optional.
2023-12-20 13:29:55 -08:00
Dmitri Gribenko
35b8fdc3b2 Integrate LLVM at llvm/llvm-project@7022a24771
Updates LLVM usage to match
[7022a24771c8](https://github.com/llvm/llvm-project/commit/7022a24771c8)

PiperOrigin-RevId: 592546932
2023-12-20 06:52:01 -08:00
jax authors
25ce6d14d0 Merge pull request #19061 from mattjj:issue19059
PiperOrigin-RevId: 592540938
2023-12-20 06:20:03 -08:00
Matthew Johnson
325a0084b9 handle convert_element_type(complex -> real) in constant folding
fixes #19059
2023-12-19 21:21:29 -08:00
jax authors
9d493997a9 Update XLA dependency to use revision
4d7fd1cb34.

PiperOrigin-RevId: 592417487
2023-12-19 20:19:25 -08:00
Yash Katariya
90e47fbc6d Always flatten args and kwargs together i.e. tree_flatten((args, kwargs)) so that we have a uniform in_tree structure everywhere.
Leads to a code cleanup and more standardization in jit.

PiperOrigin-RevId: 592388438
2023-12-19 17:32:07 -08:00
Yash Katariya
72fbdb2eb5 Expose shard_alike via jax.experimental. The API is x, y = shard_like(x, y).
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.

The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.

The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.

Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
2023-12-19 16:31:33 -08:00
Samuel Agyakwa
21a874b0bc [PJRT C API] Enable GPU Plugin tests internally
PiperOrigin-RevId: 592360226
2023-12-19 15:26:35 -08:00
Yash Katariya
9b6bf2cab0 Call shard_arg fallback in pjit's cpp fast path instead of dropping out completely.
PiperOrigin-RevId: 592344105
2023-12-19 14:26:01 -08:00
Jake VanderPlas
e98bb7c3ab jax.numpy: add trig aliases acos(h), asin(h), atan(h), atan2 2023-12-19 14:15:29 -08:00
Peter Hawkins
67d5c3bdea [JAX:GPU] Add a test that verifies that the XLA_PYTHON_CLIENT_PREALLOCATE environment variable is parsed correctly.
Fixes https://github.com/google/jax/issues/19035

PiperOrigin-RevId: 592322040
2023-12-19 13:06:08 -08:00
jax authors
2b54527d10 Merge pull request #19042 from jakevdp:array-api-linalg
PiperOrigin-RevId: 592302493
2023-12-19 11:54:59 -08:00
Jake VanderPlas
832ac874bd jnp.linalg: add matmul, tensordot, & svdvals 2023-12-19 11:36:09 -08:00
jax authors
cc242bda2a Merge pull request #19026 from adonath:eval_shape_with_static_arguments
PiperOrigin-RevId: 592294162
2023-12-19 11:28:34 -08:00
jax authors
0f6968f580 Merge pull request #18952 from jhrcek:jhrcek/fix-duplicate-words
PiperOrigin-RevId: 592289945
2023-12-19 11:20:15 -08:00
jax authors
1146a51569 Merge pull request #18967 from ROCmSoftwarePlatform:rocm-sparse-fix
PiperOrigin-RevId: 592289911
2023-12-19 11:11:59 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
jax authors
c172be1379 Merge pull request #19049 from gnecula:poly_segment
PiperOrigin-RevId: 592190406
2023-12-19 04:22:15 -08:00
George Necula
30bc5a2a5f [shape_poly] Update the jax.ops.segment{max|...} to with with shape polymorphism
The fix is very small, just had to check how we check for cases when tracers
are passed as num_segments. We add tests.
2023-12-19 12:02:39 +02:00
George Necula
bb84e6c22e Improve support for JAX_DUMP_IR_TO.
Previously the environment variable JAX_DUMP_IR_TO controlled
whether and where to dump the MLIR module prior to compilation. Now we move the code for that support from
compiler.py to mlir.py, so that it can be used in other
parts of the code. We also add support for logging to Sponge.

Using this support we now log the module on errors from
refine_polymorphic_shapes.

PiperOrigin-RevId: 592099633
2023-12-18 21:25:45 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences 2023-12-19 06:15:30 +01:00
jax authors
e82807297b Update XLA dependency to use revision
536cf3ff52.

PiperOrigin-RevId: 592092290
2023-12-18 20:51:50 -08:00
Roy Frostig
66e86a36ee decrease sample size of dirichlet test
Tweak the tolerance threshold instead.

PiperOrigin-RevId: 592051942
2023-12-18 17:14:09 -08:00
Axel Donath
e8330b5fc5 Add eval_shape example for function with static arguments
Improve wording and formating of dynamic eval_shape example
2023-12-18 19:19:48 -05:00
jax authors
f6fbd474a6 Merge pull request #19038 from jakevdp:trapezoid-err
PiperOrigin-RevId: 592020622
2023-12-18 15:06:07 -08:00
Jake VanderPlas
d7d2b767f1 integrate.trapezoid: fix function name in error message 2023-12-18 14:38:27 -08:00
Parker Schuh
7ba8622719 For custom_partitioning, directly emit call when inside of a shard_map.
PiperOrigin-RevId: 592011427
2023-12-18 14:32:38 -08:00
jax authors
afdb7370b9 Merge pull request #19032 from jakevdp:upload-artifact
PiperOrigin-RevId: 591996988
2023-12-18 13:38:50 -08:00
Jake VanderPlas
9b46e2d6a3 Support float8 in reduce_min & reduce_max 2023-12-18 13:37:45 -08:00
Rahul Batra
d7b2590805 [ROCm]: Lower sparse(some) ops correctly for ROCm
-Lower coo_spmv, coo_spmm, csr_spmv and csr_spmm
	correctly for ROCm
2023-12-18 20:47:44 +00:00
George Necula
cc2a3eb564 Move export backwards compatibility tests out of jax2tf. Step 2.
These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.

This is part of a larger cl/562671314 that ran into OSS build problems.

This is step 2: moves the other test data Python files.

PiperOrigin-RevId: 591934999
2023-12-18 10:16:54 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
George Necula
eed61f68aa Move export backwards compatibility tests out of jax2tf. Step 1.
These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.

This is part of a larger cl/562671314 that ran into OSS build problems.
I am attempting this smaller change first, and afterwards I will move more of the test data files, and then the actual test.

PiperOrigin-RevId: 591927484
2023-12-18 09:49:52 -08:00
Jake VanderPlas
873b33d7a5 CI: bump actions/upload-artifact from 3.1.3 to 4.0.0 2023-12-18 09:36:01 -08:00
jax authors
36cf5afa67 Merge pull request #19016 from gnecula:exp_fix_float0
PiperOrigin-RevId: 591869394
2023-12-18 05:42:17 -08:00
George Necula
7aba11f87f [export] Fix handling of float0 when exporting
There were two problems:
  * the float0 dtype was not part of the schema,
  * there was a bug invoking jax.vjp on a reloaded
    function, because of a mismatch between the type
    of symbolic zeros.

We changed the schema to add `f0`, but we add that
enum with a value larger than existing values, to
preserve backwards compatibility.
2023-12-18 14:51:27 +02:00
jax authors
259c285b10 [Jax] Enable jax_include_full_tracebacks_in_locations by default
PiperOrigin-RevId: 591783126
2023-12-17 21:56:13 -08:00