479 Commits

Author SHA1 Message Date
Jake VanderPlas
91a33362de Deprecate jax.lax.tie_in 2024-01-18 13:13:47 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Matthew Johnson
325a0084b9 handle convert_element_type(complex -> real) in constant folding
fixes #19059
2023-12-19 21:21:29 -08:00
Jake VanderPlas
8b74b93501 Test: fix casting warning in betainc test 2023-12-04 14:20:14 -08:00
George Necula
86e99a9e2c Disable test failing in jaxlib GPU build
PiperOrigin-RevId: 585902789
2023-11-28 02:33:52 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Jake VanderPlas
95a209f28b Tests: fix some failures for upstream numpy 2023-09-20 12:26:12 -07:00
Jake VanderPlas
56791eb9ec lax_test: adjust TPU tolerance for igamma & friends
PiperOrigin-RevId: 564859109
2023-09-12 15:59:41 -07:00
Qiao Zhang
d4adf0095f Add default jvp and transpose rule for jax.lax.reduce_precision.
PiperOrigin-RevId: 564536160
2023-09-11 16:35:44 -07:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Yash Katariya
970f4c9d4d Remove trivial execution from jax since it leads to 100x slower dispatch time.
Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces less trivial computations, so there is need for this to exist.

In the future, if we want to support forwarding of inputs to outputs, there would need to be a different way which the C++ dispatch path knows about.

```
jit_trivial_dispatch                                   246µs ± 3%                4µs ± 1%  -98.52%          (p=0.008 n=5+5)
jit_trivial                                            250µs ± 3%                5µs ± 1%  -98.19%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 560141018
2023-08-25 10:59:48 -07:00
Mateusz Sokół
d183a2c02f ENH: Update numpy exceptions imports 2023-08-07 19:08:41 +02:00
Peter Hawkins
2e042b6195 Enable test for indexing with u8 indices.
4e4eff35bf fixed the underlying XLA problem.

Fixes https://github.com/google/jax/issues/6122 https://github.com/google/jax/issues/16836

PiperOrigin-RevId: 552880163
2023-08-01 12:13:58 -07:00
jax authors
416814df2a Merge pull request #16826 from mattjj:issue16805
PiperOrigin-RevId: 551263673
2023-07-26 11:20:31 -07:00
Jake VanderPlas
0dbda849ef lax.dynamic_slice: avoid negative index correction for unsigned indices 2023-07-25 13:09:09 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
Matthew Johnson
9ddef5cf84 make _dot_general_batch_rule handle python builtin numeric types 2023-07-24 14:01:07 -07:00
Jake VanderPlas
2ffa9bd8df Refactor opaque dtype implementation.
This makes it closer to numpy, with dtypes.OpaqueDtype analogous to np.dtype,
and dtypes.opaque analogous to np.numeric. This will let us replace the
dtypes.is_opaque_dtype function with jnp.issubdtype(dtype, dtypes.opaque).
2023-07-20 19:51:52 -07:00
Jake VanderPlas
1b3da85758 Fix scatter batching rule for scatter_apply
The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
2023-07-10 16:42:45 -07:00
Jake VanderPlas
18bbc96279 Fix integer overflow in gather batching rule 2023-06-27 21:45:45 -07:00
Parker Schuh
819f731e8d jax.lax.collapse now takes Nones for stop_dimension.
PiperOrigin-RevId: 543598626
2023-06-26 18:30:34 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Yash Katariya
fa099fd262 Simplify sharding types input to physical_hlo_sharding and lower_jaxpr_to_fun.
Make sure lower_jaxpr_to_fun always sees HloSharding in arg_shardings and results_shardings.

Also make sure physical_hlo_sharding only accepts HloSharding as the input.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 538342152
2023-06-06 18:03:12 -07:00
Yash Katariya
01fdd91a5f Use _to_xla_hlo_sharding everywhere in JAX. Remove _to_xla_op_sharding in favor of _to_xla_hlo_sharding since constructing a C++ class is faster than protos and will help with further changes coming to HloSharding.
PiperOrigin-RevId: 537969500
2023-06-05 13:41:31 -07:00
Matthew Johnson
61b106ec8f allow lax.dot_general to accept different input dtypes
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.

The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix #10818.

However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
2023-05-22 10:33:42 -07:00
Roy Frostig
717d3c88fc inline and remove eq_mlir and ne_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
f18bff5371 inline and remove scatter_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
cc54b6e6ad inline and remove select_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
301d058b3d inline and remove gather_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
071c77e5bb inline and remove transpose_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
06132ac764 inline and remove broadcast_in_dim_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
0ac792f4ed inline and remove dynamic_update_slice_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
2dbdf1a6c1 inline and remove dynamic_slice_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
aed77c5031 inline and remove slice_mlir rules 2023-05-17 20:07:58 -07:00
Roy Frostig
129a4a5f35 inline and remove empty_mlir rules 2023-05-17 20:07:58 -07:00
Roy Frostig
180e26dafb remove physical_avals rule in favor of physical_element_aval 2023-05-17 20:07:58 -07:00
Peter Hawkins
eaf7eb2626 Break cycle between _src/core.py and _src/dtypes.py.
PiperOrigin-RevId: 532788430
2023-05-17 07:58:59 -07:00
jax authors
cf4c1edafa Merge pull request #15920 from froystig:issue15869
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
Roy Frostig
051c5dda6e delegate select lowering to opaque dtype rule
... and implement it for PRNG key arrays
2023-05-08 19:02:42 -07:00
Peter Hawkins
00b75aff82 Add tests for negative inputs to top-k.
Make the top-k test inputs larger.

This test would have caught the top-k bug fixed by https://github.com/openxla/xla/pull/2809

PiperOrigin-RevId: 530398528
2023-05-08 13:51:24 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Yash Katariya
728a5ed96a [shard-map] fix eager shmap+prngs, revise phys aval/sharding logic
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-04-05 23:04:41 -07:00
Jake VanderPlas
c2fe350455 future-proof lax.convert_element_type
In the future, np.array(large_value, 'int32') will error
2023-04-04 15:57:32 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
Parker Schuh
21541e60b1 Guard ArrayImpl checks by xla_extension_version.
PiperOrigin-RevId: 519191714
2023-03-24 11:15:36 -07:00
John QiangZhang
171b22dbbc Add padding option "SAME_LOWER" for ticket https://github.com/google/jax/pull/14990
PiperOrigin-RevId: 518984018
2023-03-23 15:50:16 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00