205 Commits

Author SHA1 Message Date
Jake Vanderplas
399e4ee87f Copybara import of the project:
--
8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>:

Remove several deprecated APIs

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244
PiperOrigin-RevId: 534897394
2023-05-24 10:35:37 -07:00
Yash Katariya
23d3dfd834 Remove _PositionalSemantics class since it is not used anymore because jax.Array always has GLOBAL semantics
PiperOrigin-RevId: 517493710
2023-03-17 13:30:04 -07:00
Peter Hawkins
623282715d Split Mesh and ResourceEnv into a new module jax._src.mesh.
This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00
Peter Hawkins
b389eed8bf [JAX] Deprecate jax.experimental.maps.Mesh.
PiperOrigin-RevId: 509852142
2023-02-15 09:15:50 -08:00
Peter Hawkins
4a523e3d74 Minimize exported names from jax.experimental.maps.
Move implementation of maps to jax._src.maps.

PiperOrigin-RevId: 509309092
2023-02-13 12:57:54 -08:00
Peter Hawkins
612a940160 Minimize the set of names exported from jax.experimental.pjit.
PiperOrigin-RevId: 508889911
2023-02-11 07:37:32 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Rebecca Chen
82bd889120 Silence some pytype errors.
PiperOrigin-RevId: 505816444
2023-01-30 15:33:20 -08:00
Yash Katariya
c4d21f97ea Make xmap use dispatch.sharded_lowering as dispatch.lower_xla_callable is deprecated.
PiperOrigin-RevId: 502398366
2023-01-16 09:42:40 -08:00
Jiho Lee
41b9c5e8cd [docs] donate_argnums FAQ link to rst format 2023-01-10 18:11:08 +09:00
Jake VanderPlas
57fe3fd136 cleanup: remove several unused imports across the package 2022-12-28 12:49:17 -08:00
Yash Katariya
2f3d75aa03 Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py.
PiperOrigin-RevId: 497230514
2022-12-22 13:35:23 -08:00
Yash Katariya
57840dd916 Move functions into api_util.py and dispatch.py to remove circular import error when pjit is imported in api.py for merging the jit and pjit frontend API.
PiperOrigin-RevId: 497172760
2022-12-22 08:42:05 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
George Necula
8fb344a724 [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes,
  of which shape polymorphism is a special case.
  This implementation is enabled with the --jax_dynamic_shapes flag.
  At the moment, the JAX dynamic shapes support is still
  incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added) here in which we form a Jaxpr using abstract
  values that express dimension sizes as dimension polynomials
  (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
  This implementation is enabled when --jax_dynamic_shapes is off.
  With this implementation only 50 jax2tf tests fail (to be fixed
  separately).

The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
       shape = mlir.eval_dynamic_shape(ctx, shape)
       return mhlo.DynamicXXX(..., shape)
    else:
       return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
       if config.jax_dynamic_shapes:
          # Using Var
          return ... subst using ctx.axis_size_env ...
       else:
          # Using polynomials
          return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-12-08 08:19:35 +02:00
Yash Katariya
934bc4e1b3 Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoint is jax.sharding.PartitionSpec and jax.sharding.Mesh.
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
jax authors
726b2bc2ee Add JAX monitoring library that instruments code via events.
PiperOrigin-RevId: 488731805
2022-11-15 12:41:41 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Adam Paszke
94ba43bfba Don't assume that vmap doesn't introduce constants
Because it doesn't hold e.g. for the batcher of ppermute.

PiperOrigin-RevId: 485601414
2022-11-02 08:27:10 -07:00
Adam Paszke
0fce5be556 Improve undefined axis checks
Previously we checked for out axes being a superset of the defined axes,
but that's just not the right relation. In particular, out_axes of {'a'}
are not a superset of defined axes {'b'}, but axis 'a' is undefined. The
correct check is to verify emptiness of their difference.
2022-10-27 17:05:26 +00:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
fc2902c6ac Make the gda and xmap sharding check work generally by checking the OpSharding protos.
PiperOrigin-RevId: 475560097
2022-09-20 08:24:47 -07:00
Yash Katariya
28741b8e0d Some miscellaneous changes to make tests pass when jax.Array is enabled by default.
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some tests changes to make them pass

PiperOrigin-RevId: 474642889
2022-09-15 13:27:40 -07:00
Yash Katariya
0a5d8e8ec6 Make nested xmap work with Arrays and GDA (in single process).
PiperOrigin-RevId: 474323667
2022-09-14 10:09:27 -07:00
Yash Katariya
da90234cae Delete soft_pmap as it has no users. Please use pjit or xmap if you do want soft_pmap.
`jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.

PiperOrigin-RevId: 474145090
2022-09-13 15:52:10 -07:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Yash Katariya
007d651ac8 Canonicalize all shardings to OpShardingSharding throughout pjit. Places where pspec is needed, parsed_flatten_op_sharding function is used to retrieve the pspec. The major places are global_to_local and local_to_global. Rest of the changes are just threading through OpShardingSharding.
I have added comments to places to explain things.

Dependence on MeshPspecSharding in Partial eval has been removed. It now depends on OpShardingSharding.

TODO: Fix the round trip through MeshPspecSharding in vmap batching handlers.
PiperOrigin-RevId: 465621165
2022-08-05 12:18:10 -07:00
Adam Paszke
bfb0814ce8 Remove experimental warning from xmap
It doesn't have bugs, or at least not noticeably more than the rest of our code :)
2022-08-05 18:27:11 +00:00
Matthew Johnson
e0c1e6c2ff add custom-policy partial eval and dce rules for pmap
Also add a failing test for xmap.
2022-07-28 21:13:25 -07:00
Sharad Vikram
4870710891 Enable debugging callbacks with pjit on TPU
PiperOrigin-RevId: 462527181
2022-07-21 20:22:14 -07:00
Yash Katariya
90687cc1ff Make lower_mesh_computation accept sharding instances. The new path is tested as everything in pjit goes through the new lower_sharding_computation except of AUTO and UNSPECIFIED (see below for these 2).
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths; `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
  * `lower_sharding_computation` works in SPMD mode since its only used in pjit. Majority of the logic is the same. The only difference is that `mesh` does not exist in this function.

* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.

* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.

* Rest of the changes are to make all other functions play nicely with sharding instances.

PiperOrigin-RevId: 461260553
2022-07-15 16:16:23 -07:00
George Necula
3d9c8fbe6f [dynamic-shapes] Ensure that the axis_size_env is passed to sub lowering contexts 2022-07-12 12:44:23 +03:00
Yash Katariya
09ba51f323 Move _get_array_mapping from gda.py to pxla.py
PiperOrigin-RevId: 459891853
2022-07-08 21:38:06 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Yash Katariya
7da733f94b Change the internals of with_sharding_constraint to use the sharding instances.
PiperOrigin-RevId: 459600050
2022-07-07 14:22:10 -07:00
jax authors
2b8fbe9fe4 Merge pull request #11367 from apaszke:xmap-tracer-leak
PiperOrigin-RevId: 459456785
2022-07-07 02:01:51 -07:00
Adam Paszke
5777c1eac2 Add support for post_process of xmap in BatchTrace
PiperOrigin-RevId: 459108183
2022-07-05 12:07:26 -07:00
Adam Paszke
7439e1b1f8 Properly count sublevels when tracing xmap body
Otherwise it can lead to tracer leak errors. I'm not a 100% sure how
this works out, because the sublevel counting has changed since I read
it previously. This replicates the changes applied to
DynamicJaxprTrace.process_map since I last looked at it.
2022-07-05 11:43:26 +00:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
Yash Katariya
1089c792d5 Add Array support to xmap. Just using the GDA path.
PiperOrigin-RevId: 454604138
2022-06-13 07:31:03 -07:00
Sharad Vikram
c3aa971948 Enable debug print in xmap 2022-06-07 14:05:46 -07:00
fehiepsi
3ca0d3f149 Rename mesh into Mesh in xmap tutorial and doc 2022-06-03 17:11:09 +07:00
David Hall
3e766c76fb
fix incorrect f-string format in xmap 2022-05-27 11:03:24 -07:00
Bart Chrzaszcz
ae908b8753 Remove local positional semantics assertion in make_xmap_callable.
PiperOrigin-RevId: 451347782
2022-05-27 02:57:18 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Matthew Johnson
7e241b682d improve partial_eval_jaxpr_custom
* add caching via weakref_lru_cache
* add inst_in argument (needed for fixedpoints for loop primitives, in
  follow-up PR), update callers not to over-instantiate inputs (previously I
  had used a convention where call primitives would just stage out eqns with
  all inputs instantiated, for expediene)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
  `instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc (also neede for
 fixpoints of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)
2022-05-11 13:20:23 -07:00