412 Commits

Author SHA1 Message Date
Jake VanderPlas
5c9134c30a Re-enable testCumulativeLogSumExp test
PiperOrigin-RevId: 499895651
2023-01-05 08:52:21 -08:00
Jake VanderPlas
008f35a6b4 skip testCumulativeLogSumExp due to timeout with updated LLVM
PiperOrigin-RevId: 499585313
2023-01-04 14:51:32 -08:00
Yash Katariya
1fc9197c79 Simplify Array's shard_arg_handler by merging pmap and pjit/xmap paths
PiperOrigin-RevId: 497991966
2022-12-27 10:16:44 -08:00
George Necula
4d4eba4539 Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.

This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.

PiperOrigin-RevId: 497171518
2022-12-22 08:34:47 -08:00
George Necula
2b716f292d Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.

This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.

PiperOrigin-RevId: 497079129
2022-12-21 21:59:51 -08:00
George Necula
ce5320a2e4 Copybara import of the project:
--
a74c74c25572eec23c28e08dbe67781a23be19fb by George Necula <gcnecula@gmail.com>:

Fix scatter in CLIP mode with uint32 and uint64 indices

Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.

PiperOrigin-RevId: 496883024
2022-12-21 03:46:27 -08:00
George Necula
a74c74c255 Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.
2022-12-21 10:25:24 +02:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Eugene Burmako
ee1ad39dd1 Port type inference for 6 ops from StableHLO to MHLO
Ops:
  1) AfterAllOp: https://github.com/openxla/stablehlo/pull/708.
  2) CreateTokenOp: https://github.com/openxla/stablehlo/pull/711.
  3) DynamicUpdateSliceOp: https://github.com/openxla/stablehlo/pull/686 and https://github.com/openxla/stablehlo/pull/757.
  4) OptimizationBarrierOp: https://github.com/openxla/stablehlo/pull/575.
  5) OutfeedOp: https://github.com/openxla/stablehlo/pull/713.
  6) SendOp: https://github.com/openxla/stablehlo/pull/580.

This PR prepares for migration from producing MHLO to producing StableHLO by
aligning type inference between dialects, so that switching from one to another
doesn't need changes to calls to Python builders.

PiperOrigin-RevId: 495404149
2022-12-14 13:38:26 -08:00
George Necula
27f5bd057c Improves handling of opaque types for dynamic shapes
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing of dynamic shapes with opaque types.

The general strategy is to push the actual selection of the MHLO ops
down into mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim)
so that we have one place where we pick whether we use the Dynamic
or static ops. These routines can also handle the opaque type.
This will result in a recursive
call to, e.g., mlir.slice_op, but the inner call will be using
the physical avals, which should not be opaque anymore.

While making this change I was confused by the fact that the
custom KeyTyRules in prng.py have lowerings that return multiple
MHLO ops. See https://github.com/google/jax/pull/11768#issuecomment-1342349102
and I changed the rules to return a single op.

.
2022-12-12 05:19:04 +01:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Yash Katariya
4443b861a5 Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental.
PiperOrigin-RevId: 492033543
2022-11-30 15:26:03 -08:00
Peter Hawkins
a13541441b Reenable a TPU test now that the compiler bug is fixed.
PiperOrigin-RevId: 487705048
2022-11-10 19:38:01 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Srinivas Vasudevan
5adfb08986 Add lax.cumlogsumexp for cumulative logsumexp operations.
PiperOrigin-RevId: 485158935
2022-10-31 15:08:52 -07:00
Peter Hawkins
02dc25f022 [JAX] Redisable int8 convolution tests on GPU due to CI failures.
PiperOrigin-RevId: 482191832
2022-10-19 06:50:36 -07:00
Peter Hawkins
807269990e Enable more GPU and TPU tests that pass at head.
Increase precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely users want to use low precision for these operations and high precision is probably the right default.

PiperOrigin-RevId: 482071629
2022-10-18 18:09:44 -07:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
Peter Hawkins
c7e5d3dc95 Add an internal jtu.sample_product test decorator.
This decorator samples from a cartesian product of parameterized tests
without materializing the full product explicitly.

Update lax_test.py to use the new decorator.

On my desktop machine, this improves the timing for `pytest
--collect-only tests/lax_test.py` from 6.8s to 1.9s.
2022-10-04 00:39:22 +00:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Jake VanderPlas
a6b24b379c Add regression test for lax.rev simplification error
PiperOrigin-RevId: 476430486
2022-09-23 12:07:15 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
28741b8e0d Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some tests changes to make them pass

PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00
Jake VanderPlas
34eae3de88 jax.lax: ensure GatherDimensionNumbers contains tuples for hashability 2022-09-12 12:10:17 -07:00
Peter Hawkins
6ddf3c4d97 Reapply: Use MLIR bytecode when passing IR to backends.
MLIR bytecode is more compact to represent and should be faster to generate and parse.

The previous attempt at this change broke for 0D convolutions. JAX was not ensuring that the padding attribute had the correct [N, 2] shape when N was 0.

PiperOrigin-RevId: 472991661
2022-09-08 08:11:16 -07:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
Yash Katariya
0584c6a1c4 Add support to handle arbitrary shardings to KeyArray. Resolve all the TODOs that were created before.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 471443690
2022-08-31 22:54:06 -07:00
Roy Frostig
8f045b12d6 internal rename: swap mentions of "custom eltypes" for "opaque dtypes"
Also, avoid direct set membership tests on `core.opaque_dtypes`. Update
callers to use `core.{is,has}_opaque_dtype` predicates instead.
2022-08-30 16:52:08 -07:00
jax authors
c26c7fddad Merge pull request #12167 from froystig:key-dtype3
PiperOrigin-RevId: 471114198
2022-08-30 16:07:57 -07:00
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Roy Frostig
73bf0aa30c access rules through a hidden attribute of opaque dtype 2022-08-30 14:06:01 -07:00
Yash Katariya
6340952e2a Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.

PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
Jake VanderPlas
b46b86db95 Fix testUnaryWeakTypes 2022-08-24 09:16:47 -07:00
Jake VanderPlas
ab4ec5804a lax.squeeze: ensure DeviceArray is returned 2022-08-23 10:48:40 -07:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Roy Frostig
c9d3094014 defer to custom eltype for gather lowering rule
This also adds an `mlir.delegate_lowering` helper that takes optional
replacements to the lowering context and restores/updates it on
return. This is used to define a gather lowering rule in tests here,
by calling out to the base gather lowering rule. I picked it off of a
working branch that uses it as well (for key-eltyped lowering rules).
2022-08-12 14:28:31 -07:00
Roy Frostig
7955799ae3 defer to custom eltype for slice lowering rule
We already handled dynamic slice, but plain slice is eltype-polymorphic too.
2022-08-09 19:13:34 -07:00
Matthew Johnson
81b6263ed0 Rolling forward #11768 after test failures caused roll-back (from use of np.empty).
PiperOrigin-RevId: 465712458
2022-08-05 22:19:33 -07:00
jax authors
6b0c0dc321 Internal change
PiperOrigin-RevId: 465705931
2022-08-05 21:08:43 -07:00
Matthew Johnson
348da51dc6 prototype unfettered element types in jaxpr arrays
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.

Within jaxprs, we also need to handle transformations with these types.

In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
  * allowing ShapedArray to have any object for its 'dtype' attribute
  * added core.custom_eltype set
  * extended existing handlers for ShapedArray to call the corresponding custom
    element type handlers
  * mlir lowerings of some fully-element-type-polymorphic primitives
  * tests

In this PR, we only actually use these new extension points in tests.

The applications to come that we have in mind are:
  * arrays of prngkeys (and even custom prngs, as well as reuse error checking)
  * arrays of bounded int type for dynamic shapes (and especially raggedness)
  * float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.

Jargon-wise, we may want to distinguish:
  * 'eltype' meaning jaxpr element types
  * 'dtype' meaning numpy dtypes (an existing convention)
  * 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.

We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
  * [x] broadcast
  * [ ] reshape
  * [x] transpose
  * [ ] pad
  * [x] slice, dynamic_slice, dynamic_update_slice
  * [ ] concatenate
  * [ ] all_to_all, gather, scatter, all_gather, collective_permute
  * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.

We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.

Co-authored-by: Roy Frostig <frostig@google.com>
2022-08-05 19:23:55 -07:00
Nicholas Junge
43b7d95c7e Enable bitwise reductions for integer dtypes in lax.reduce
This commit extends the existing bitwise reduction primitives to work
with all non-floating dtypes, including signed and unsigned integers.
Tests against numpy were added, in similar style to existing test
coverage, to assert the correct results for these new reductions.

Secondly, existing test coverage for `lax.reduce` was extended to check
for the correct primitive being produced in the resulting jaxpr for all
  `lax.reduce` operations.
2022-07-26 20:47:40 +02:00
jax authors
8ca7a9d71f Merge pull request #11114 from jakevdp:x64-lax-test
PiperOrigin-RevId: 455259737
2022-06-15 17:44:23 -07:00
reinerp
b51ee3752e Relax typechecking for preferred_element_type, to allow integer->floating dot products.
PiperOrigin-RevId: 455216435
2022-06-15 14:12:42 -07:00
Jake VanderPlas
b35053525f [x64] make lax_test compatible with strict dtype promotion 2022-06-15 12:09:08 -07:00
Nicholas Junge
311e6a92f9 Add bitwise XOR reducer to lax.reduce
This commit adds handling for the `lax.bitwise_xor` operation to `lax.reduce`. It also includes a new standard reduce primitive, modeled after the existing `and`/ `or` reducer primitives.
2022-06-15 16:56:51 +02:00
Peter Hawkins
823aae0ccb Verify that the indices passed to lax.dynamic_slice and lax.dynamic_update_slice are scalars.
We were previously missing a shape check for this, meaning that if non-scalars were passed we would receive a cryptic MHLO shape error rather than a helpful Python backtrace.

PiperOrigin-RevId: 452682084
2022-06-02 20:41:45 -07:00
Peter Hawkins
ef12b50ef7 Disallow complex dtypes for lax ops rem, lt, le, gt, and ge.
In none of these cases did the compiler do anything sensible with complex values anyway.

Note that complex comparisons are still allowed by the jax.numpy layer (and are defined to have lexicographic comparison semantics).
2022-06-01 17:54:13 -04:00