1218 Commits

Author SHA1 Message Date
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
jax authors
d631fb10fd Merge pull request #17996 from mattjj:pow-int-jvp-bug-fix
PiperOrigin-RevId: 571482372
2023-10-06 18:26:01 -07:00
Matthew Johnson
6fe573147e fix pow jvp rule with int exponent (broken since #16419)
fixes #17995
2023-10-06 17:53:31 -07:00
jax authors
fae53d9577 Merge pull request #17959 from jakevdp:lax-abs
PiperOrigin-RevId: 571397881
2023-10-06 12:02:02 -07:00
jax authors
2052673f9c Merge pull request #17978 from hawkinsp:fft
PiperOrigin-RevId: 571379922
2023-10-06 10:52:48 -07:00
Peter Hawkins
4e1b8fcdd2 Check dtypes in fft_p's abstract eval rule.
In particular, this catches a bad error when a bfloat16 is passed to rfft.
2023-10-06 08:04:01 -04:00
George Necula
b580c5d5e7 [export] Fix dot_general multi-platform lowering
The previous lowering rule for dot_general was using
ctx.module_context.platform to customize the lowering per
platform. Now we set different lowering rules for
different platforms, thus enabling the multi-platform
lowering to generate the proper code.

This fixes the dot_general multi_platform_export_tests.

PiperOrigin-RevId: 571304531
2023-10-06 04:57:50 -07:00
Jake VanderPlas
60029e7d09 lax.abs: better error for unsigned inputs 2023-10-05 10:53:08 -07:00
Matthew Johnson
29af93b4cb [run_state] add scan nested test, tweak rule name to mention 'state' 2023-10-04 11:48:42 -07:00
jax authors
c3e73c67aa Merge pull request #17760 from superbobry:array-any
PiperOrigin-RevId: 570400629
2023-10-03 08:50:07 -07:00
Sergei Lebedev
5ab05e42c9 MAINT Clean up leftover Array = Any aliases in jax/_src/**.py
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
2023-10-01 12:19:21 +01:00
Sergei Lebedev
a8b8267f48 MAINT Reorder the overloads for lax.sort
`Array` is structurally a `Sequence[Array]`, so the first overload always
matches under pytype, which defines `collections.abc.Sequence` as a
`Protocol`.

See
b8f91a37e5/pytype/stubs/builtins/typing.pytd (L149).
2023-09-28 12:51:36 +01:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
Emily Fertig
d7039b3640 Raise an error if a Ref is returned from lax.cond.
PiperOrigin-RevId: 567709582
2023-09-22 14:01:19 -07:00
Yash Katariya
03877a9218 If a pmap out is replicated i.e. with out_axes=None make jnp.copy's impl go via apply_primitive which will put it on a single device.
If we don't do that, then it hits an error mentioned in https://github.com/google/jax/issues/17690.

Fixes https://github.com/google/jax/issues/17690

PiperOrigin-RevId: 567628026
2023-09-22 08:24:57 -07:00
jax authors
256612bb80 Merge pull request #17720 from superbobry:tuple-list-comp
PiperOrigin-RevId: 567433086
2023-09-21 15:16:12 -07:00
Sergei Lebedev
df7f6a06c0 MAINT Use a generator expression in tuple([... for ... in ...])
In a few cases I also replaced tuple([*xs, *ys]) with (*xs, ys), because
tuple literals support unpacking as well.
2023-09-21 22:25:38 +01:00
Jake VanderPlas
4edb74ba7b Fix some numpy 2.0 incompatibilities 2023-09-21 10:24:52 -07:00
George Necula
32ee27b5cb [callbacks] Add support for shardable ordered effects.
Ordered effects currently are not allowed in multi-device computations.
This is too restrictive sometimes, e.g., `io_callback(ordered=True)` uses
maximal sharding on one device and the callback would be issued only
once even in multi-device computations.

Here we add support for ordered shardable effects, which behave like
ordered effects except they are allowed in SPMD computations.
Currently, only `callback.IOOrderedEffect` is declared shardable.

In general, if the sharding of the side-effecting operation is not
maximal, then such effects would appear in a partial order, with
effects appearing ordered by program point and unordered among
the different devices at a given program point.

We also generalize the mechanism for tracking runtime tokens and
token buffers to work with multiple devices.

PiperOrigin-RevId: 566242557
2023-09-18 02:50:25 -07:00
Qiao Zhang
d4adf0095f Add default jvp and transpose rule for jax.lax.reduce_precision.
PiperOrigin-RevId: 564536160
2023-09-11 16:35:44 -07:00
Parker Schuh
bda9292523 Propagate ad.Zeros to the scan body function for jax.lax.scan for the output 'ys'.
Example of what this fixes:

```
def grad_fn(x):
  def scan_body(x, params):
    return x, x.sum()

  pred, state = jax.lax.scan(scan_body, x, None, length=2)
  return pred.sum(), state
x = np.zeros((5, 10), dtype=np.float32)
loss_grad_fn = jax.value_and_grad(grad_fn, has_aux=True)
print(jax.make_jaxpr(loss_grad_fn)(x))
```
PiperOrigin-RevId: 563544684
2023-09-07 14:36:53 -07:00
Adam Paszke
bb8d5a0121 Rewrite simple slicing to the static slicing primitive whenever possible
This makes it a lot easier to handle within Pallas and Mosaic.

PiperOrigin-RevId: 563128943
2023-09-06 09:43:00 -07:00
Matthew Johnson
70b58bbd30 rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
2023-08-31 17:31:21 -07:00
Matthew Johnson
8a04dfd830 rolling back shard_map transposition change to fix a bug
Reverts 437d7be73534403f39fbee9d6391be1c532933a1

PiperOrigin-RevId: 561730581
2023-08-31 12:39:56 -07:00
Matthew Johnson
fdd252f6ca [shard-map] add rewrite for efficient transposition 2023-08-30 15:08:11 -07:00
Peter Hawkins
93900245aa Remove jax.interpreters.xla.register_collective_primitive.
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.

PiperOrigin-RevId: 561150259
2023-08-29 15:10:05 -07:00
Peter Hawkins
d0a6813ea2 Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call().
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.

Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.

This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).

Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.

PiperOrigin-RevId: 561042402
2023-08-29 08:50:07 -07:00
Jake VanderPlas
6cec5d4416 lax.pow: fix shape mismatch failure in jvp rule 2023-08-25 10:05:55 -07:00
jax authors
af42359433 Merge pull request #16419 from mattjj:pow-jvp
PiperOrigin-RevId: 559266945
2023-08-22 17:15:04 -07:00
Matthew Johnson
1f8fb2c8bd change lowering rule to satisfy jax2tf 2023-08-22 16:48:11 -07:00
jax authors
209b6b02f4 Merge pull request #17144 from jakevdp:zeta
PiperOrigin-RevId: 558193896
2023-08-18 11:04:43 -07:00
Peter Hawkins
889489206b Remove the canonicalize_dtypes argument from mlir.ir_constant(s).
Instead, force the caller to explicitly canonicalize the argument if that's what they want.

The current behavior (canonicalize by default) is not the behavior we want to encourage: we want to canonicalize exactly where we need to and nowhere else.

PiperOrigin-RevId: 557806903
2023-08-17 06:44:12 -07:00
Jake VanderPlas
6cd467fd57 Create lax.zeta with native HLO lowering 2023-08-16 13:43:41 -07:00
Jake VanderPlas
0ad6196ff0 Create lax.polygamma with native HLO lowering 2023-08-16 11:57:05 -07:00
Peter Hawkins
47651c6a59 Remove uses of XLA translation rules.
Remove translation_rule argument to standard_primitive.

PiperOrigin-RevId: 557220350
2023-08-15 12:53:36 -07:00
Peter Hawkins
78cfdd1b35 Add some more type annotations to lax_numpy.py.
These type annotations are of course mostly ignored because the pytype: skip-file comment, but they help readers if nothing else.

PiperOrigin-RevId: 555955257
2023-08-11 08:07:24 -07:00
Peter Hawkins
bfaffe3183 Add version guards after GPU tridiagonal solve change.
PiperOrigin-RevId: 555931222
2023-08-11 06:41:05 -07:00
Srinivas Vasudevan
7dfc8ff49d Add batching rules to jax.lax.linalg.tridiagonal_solve.
PiperOrigin-RevId: 555700103
2023-08-10 16:25:59 -07:00
jax authors
be543f020d Merge pull request #17041 from mtsokol:update-ninf-usage
PiperOrigin-RevId: 555185385
2023-08-09 09:26:12 -07:00
Mateusz Sokół
1fedf04ed5 API: Remove NINF and PINF usages 2023-08-09 14:16:33 +02:00
Peter Hawkins
c9cf6b4423 Remove allowlist for multihost collectives.
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public), is on the list. So the allowlist serves no purpose any more.

PiperOrigin-RevId: 555124144
2023-08-09 04:43:51 -07:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Mateusz Sokół
d183a2c02f ENH: Update numpy exceptions imports 2023-08-07 19:08:41 +02:00
Patrick Kidger
808b0b26be Crash fix from ndim 2023-08-04 13:32:12 -07:00
Jérome Eertmans
6e55c20fbb chore(docs): improve jax.lax.scan
Make the docstring a bit more explicit about what is t

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-08-03 22:18:07 +02:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
George Necula
2eaf545a47 [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype
Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.
2023-07-31 10:54:05 +03:00
Matthew Johnson
69ad4df9a5 fix pow_p jvp rule at x=0. y=0
fixes #14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290.

Instead of making a new primitive, we made the old one polymorphic and switch
its behavior on the element type of its second argument.

There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.

Co-authored-by: Roy Frostig <frostig@google.com>
2023-07-28 17:14:47 -07:00
jax authors
b716f433a5 Merge pull request #16883 from mattjj:exp2-primitive
PiperOrigin-RevId: 551946122
2023-07-28 14:09:05 -07:00
Matthew Johnson
560ede0ff1 add an exp2 primitive and lax.exp2
part of fixing https://github.com/jax-ml/jax-triton/issues/204
2023-07-28 12:33:49 -07:00