62 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Jake VanderPlas
7b41583414 refactor jax.lax to not depend on jax.numpy 2024-09-01 07:49:49 -07:00
Peter Hawkins
3d5784a343 Don't wrap singleton ir.Types during HLO lowering.
This is similar to https://github.com/google/jax/pull/22211, but for MLIR types instead of MLIR values.
2024-07-08 12:24:45 -04:00
jax authors
dffd72e290 Merge pull request #22211 from hawkinsp:singletons
PiperOrigin-RevId: 649135349
2024-07-03 11:07:00 -07:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
jax authors
fe3c8e15a8 Merge pull request #21806 from cgarciae:cond-passthrough-outputs
PiperOrigin-RevId: 646970169
2024-06-26 09:13:07 -07:00
Cristian Garcia
dae7e41ade fix cond passthrough outputs 2024-06-26 16:17:45 +01:00
George Necula
6e3fc9a768 Fix the eager mode execution for lax.platform_dependent
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
2024-06-21 17:07:48 +03:00
Matthew Johnson
7c125701c5 make cond forward inputs to outputs, reduces vmap lifting
Co-authored-by: Cristian Garcia <cgarciae@google.com>
2024-06-05 16:39:55 +00:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
rajasekharporeddy
0d68a1a82d Fix doc typos 2024-04-05 14:21:33 +05:30
Matthew Johnson
3736b322b7 [xmap-removal] remove reduce_axes from grad / vjp / backward_pass
The reduce_axes machinery was planned to be used for xmap. It's not needed for
e.g. shard_map, see https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html.
2024-02-25 15:50:54 -08:00
Peter Hawkins
f1ea67117e Split name_stack out of mlir.ModuleContext.
A unique name_stack is built for every equation, which means that we're constantly rebuilding ModuleContext objects, even though the lifetime of almost everything else (naturally) is the Module scope. Split name_stack into an object that is threaded separately, including as part of mlir.LoweringRuleContext.

PiperOrigin-RevId: 608594374
2024-02-20 07:17:23 -08:00
Peter Hawkins
67df647988 Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.

Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)

But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.

So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.

In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:

```
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
#     b:f32[] = sin a
#     c:f32[] = sin b
#     d:f32[] = sin c
#   in (d,) }

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```

Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!

So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```

PiperOrigin-RevId: 607664497
2024-02-16 05:57:12 -08:00
Matthew Johnson
ec7d28c0b2 revise logic for tangent types of extended dtypes
* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change

We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
2023-12-20 14:24:52 -08:00
George Necula
2d9da6c8fb Cleanup the code to picking lowering rules based on platform.
Previously, we had special-cased the code to pick the lowering
rule for a primitive based on the lowering platform, and separately
we had the code to handle multi-platform lowering. The latter,
called `mlir.lower_multi_platform` had its own special case for
when a single lowering rule applied.

We rename `mlir.lower_multi_platform` to `mlir.lower_per_platform`
to not imply that it is only for multi-platform. We simplify
its API (takes a dictionary instead of a list of tuples).
2023-11-19 18:39:59 +02:00
Peter Hawkins
8e8dc263bc Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
2023-11-17 11:47:14 -08:00
George Necula
8feb413211 Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
2023-11-02 14:31:38 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Emily Fertig
d7039b3640 Raise an error if a Ref is returned from lax.cond.
PiperOrigin-RevId: 567709582
2023-09-22 14:01:19 -07:00
George Necula
32ee27b5cb [callbacks] Add support for shardable ordered effects.
Ordered effects currently are not allowed in multi-device computations.
This is too restrictive sometimes, e.g., `io_callback(ordered=True)` uses
maximal sharding on one device and the callback would be issued only
once even in multi-device computations.

Here we add support for ordered shardable effects, which behave like
ordered effects except they are allowed in SPMD computations.
Currently, only `callback.IOOrderedEffect` is declared shardable.

In general, if the sharding of the side-effecting operation is not
maximal, then such effects would appear in a partial order, with
effects appearing ordered by program point and unordered among
the different devices at a given program point.

We also generalize the mechanism for tracking runtime tokens and
token buffers to work with multiple devices.

PiperOrigin-RevId: 566242557
2023-09-18 02:50:25 -07:00
Peter Hawkins
889489206b Remove the canonicalize_dtypes argument from mlir.ir_constant(s).
Instead, force the caller to explicitly canonicalize the argument if that's what they want.

The current behavior (canonicalize by default) is not the behavior we want to encourage: we want to canonicalize exactly where we need to and nowhere else.

PiperOrigin-RevId: 557806903
2023-08-17 06:44:12 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Roy Frostig
14f32653a1 resolve conditionals to default "shared operand form" more often
If both the second and third operand of a `lax.cond` call are callable, then
resolve it as a new-style (default) conditional, where both branches act on the
same operands.

This changes the behavior of five-argument `lax.cond` calls. It is a breaking
change for callers using the old-style `cond` calling convention (`pred`,
`true_arg`, `true_fn`, `false_arg`, `false_fn`) with a callable `true_arg`.

PiperOrigin-RevId: 543912445
2023-06-27 18:49:16 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Sharad Vikram
907782d2e5 Deduplicate references closed over across branches of a lax.cond.
This fixes a correctness issue that could crop up when doing `run_state(cond)`.

PiperOrigin-RevId: 540795172
2023-06-15 23:58:14 -07:00
Joey Teng
005d4ca78e
add explanation: switch will be converted to select when transformed with vmap in doc 2023-06-09 19:12:21 +01:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Sharad Vikram
5101184ad4 Add initial implementation of a run_state primitive 2023-04-03 21:32:32 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Matthew Johnson
268456ef54 enable pjit recursive typechecking
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).

This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!

It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
2023-03-22 16:59:22 -07:00
Peter Hawkins
148774587a Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().

PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
Sharad Vikram
4960e656af Refactor Ref abstract type to contain other AbstractValues 2023-02-23 17:02:40 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Peter Hawkins
54269c1145 Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Yash Katariya
d0eedf7e57 Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr
PiperOrigin-RevId: 509341618
2023-02-13 14:58:20 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Jake VanderPlas
6376dc9616 Fix excessive recompiles in lax.cond 2023-01-18 10:17:01 -08:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Matthew Johnson
e516d41180 cond transpose, use UndefinedPrimal not linear for transpose inputs 2023-01-16 10:39:19 -08:00
Jake VanderPlas
c9c6263251 DOC: clarify behavior of lax.cond & lax.select 2023-01-06 11:31:26 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
George Necula
8fb344a724 [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes,
  of which shape polymorphism is a special case.
  This implementation is enabled with the --jax_dynamic_shapes flag.
  At the moment, the JAX dynamic shapes support is still
  incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added) here in which we form a Jaxpr using abstract
  values that express dimension sizes as dimension polynomials
  (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
  This implementation is enabled when --jax_dynamic_shapes is off.
  With this implementation only 50 jax2tf tests fail (to be fixed
  separately).

The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
       shape = mlir.eval_dynamic_shape(ctx, shape)
       return mhlo.DynamicXXX(..., shape)
    else:
       return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
       if config.jax_dynamic_shapes:
          # Using Var
          return ... subst using ctx.axis_size_env ...
       else:
          # Using polynomials
          return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-12-08 08:19:35 +02:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00