104 Commits

Author SHA1 Message Date
Jake VanderPlas
87aec2433b internal: refactor array methods into separate private submodule 2023-03-23 10:57:53 -07:00
Etienne Pot
4cb32ba46f Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses
PiperOrigin-RevId: 518803946
2023-03-23 02:32:06 -07:00
Yash Katariya
58fed7001a Remove pxla.OutputType enum class now that the only output can be jax.Array
PiperOrigin-RevId: 517985356
2023-03-20 09:09:58 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
86263be17f Make PRNG seed types more liberal in what they accept.
PRNG seeds can be Arrays, not just concrete ints.

PiperOrigin-RevId: 516532859
2023-03-14 08:32:58 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
Parker Schuh
17079d9072 Add sharding to the signature of shard_args and update
the jax.Array handler unpack to single device arrays after
resharding.

PiperOrigin-RevId: 513624513
2023-03-02 13:29:03 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Jake VanderPlas
0b5443c6e8 Clean up: remove unused helper functions 2023-02-01 09:55:58 -08:00
jax authors
69a2931830 Merge pull request #14189 from froystig:opaque-dtypes-to-mlir-avals
PiperOrigin-RevId: 505219181
2023-01-27 15:07:49 -08:00
Roy Frostig
b1b4915c1c remove opaque dtype aval translation to MLIR types
We already have a mapping from opaquely-dtyped avals to basic
"physical" avals, and we can map the latter to MLIR types.
2023-01-27 14:27:30 -08:00
George Necula
d25bcac93d [shape_poly] Add better support for division, and working with strides
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
2023-01-25 07:37:54 -08:00
Yash Katariya
1fc9197c79 Simplify Array's shard_arg_handler by merging pmap and pjit/xmap paths
PiperOrigin-RevId: 497991966
2022-12-27 10:16:44 -08:00
George Necula
7d452adfd3 Add support for dynamic shapes to GPU threefry2x32 custom call.
In presence of dynamic shapes the ThreeFry2x32Descriptor will contain the
value n=-1, and the actual desired output length will be passed as
an additional operand. If the shape is static then the length will be
passed as part of the descriptor.

PiperOrigin-RevId: 497945778
2022-12-27 04:48:26 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
George Necula
27f5bd057c Improves handling of opaque types for dynamic shapes
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing of dynamic shapes with opaque types.

The general strategy is to push the actual selection of the MHLO ops
down into mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim)
so that we have one place where we pick whether we use the Dynamic
or static ops. These routines can also handle the opaque type.
This will result in a recursive
call to, e.g., mlir.slice_op, but the inner call will be using
the physical avals, which should not be opaque anymore.

While making this change I was confused by the fact that the
custom KeyTyRules in prng.py have lowerings that return multiple
MHLO ops. See https://github.com/google/jax/pull/11768#issuecomment-1342349102
and I changed the rules to return a single op.

.
2022-12-12 05:19:04 +01:00
George Necula
8fb344a724 [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes,
  of which shape polymorphism is a special case.
  This implementation is enabled with the --jax_dynamic_shapes flag.
  At the moment, the JAX dynamic shapes support is still
  incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added) here in which we form a Jaxpr using abstract
  values that express dimension sizes as dimension polynomials
  (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
  This implementation is enabled when --jax_dynamic_shapes is off.
  With this implementation only 50 jax2tf tests fail (to be fixed
  separately).

The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
       shape = mlir.eval_dynamic_shape(ctx, shape)
       return mhlo.DynamicXXX(..., shape)
    else:
       return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
       if config.jax_dynamic_shapes:
          # Using Var
          return ... subst using ctx.axis_size_env ...
       else:
          # Using polynomials
          return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-12-08 08:19:35 +02:00
Roy Frostig
431c51a3eb rename iota_32x2_shape to iota_2x32_shape
... for consistency with other internal Threefry primitive names.
2022-12-05 11:09:56 -08:00
Roy Frostig
75af6b58d9 add a jax2tf translation rule for the shaped-iota primitive
This allows for jax2tf conversion of the partitionable Threefry RNG.
2022-12-05 09:19:25 -08:00
Roy Frostig
a3483dbe32 docstring for shaped iota primitive 2022-12-05 09:15:27 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Roy Frostig
dab2909a31 make threefry split and fold_in symmetric
Namely, make it so that `split(key, n)[i]` equals `fold_in(key, i)`
for any key and for `0 <= i < n`.

This change affects the observed random bits for a fixed key (indirectly
through splits and folds), so here we guard it behind
`jax.config.jax_threefry_partitionable`. It's not described very well
by the flag name, but it makes for a simple way to bundle together
several random-bit-altering changes as part of the same upgrade cycle.
2022-11-21 15:24:48 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Patrick Kidger
d2afa84a6e PRNGKeyArray is now a virtual subclass of ndarray 2022-11-11 08:04:38 -08:00
Matthew Johnson
213d2c8592 integrate new (partitionable, count-space-exhaustive) counts generation 2022-10-29 00:05:49 -07:00
Roy Frostig
63bfb87edf wip bits-changing partitionable rng based on iota raveling
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-10-28 14:17:34 -07:00
jax authors
89b240ba02 Merge pull request #13012 from mattjj:rng-part-overgenerate
PiperOrigin-RevId: 484567918
2022-10-28 10:41:35 -07:00
Roy Frostig
c8b9280fb3 partitionable threefry PRNG random bits implementation
the cost is 2x overgeneration of bits

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-10-28 10:07:14 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
389a2e570d Add a backwards compat path for op_sharding.clone() because it doesn't exist with the latest jaxlib on pypi
PiperOrigin-RevId: 477034758
2022-09-26 17:50:19 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
edfbbd7203 Merge pull request #12297 from mattjj:computation-follows-data-prng
PiperOrigin-RevId: 473092328
2022-09-08 14:57:31 -07:00
Matthew Johnson
47b2dfe92f add _device attribute to PRNGKeyArray so that computation follows key placement
unrelated: remove some redundant hasattr + try / except AttributeError
2022-09-08 14:30:18 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Yash Katariya
0584c6a1c4 Add support to handle arbitrary shardings to KeyArray. Resolve all the TODOs that were created before.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 471443690
2022-08-31 22:54:06 -07:00
jax authors
bf7525e121 Merge pull request #12170 from froystig:just-dtype
PiperOrigin-RevId: 471409020
2022-08-31 18:36:47 -07:00
Roy Frostig
023764376c support key array pickling
Involves:
* a weaker notion of equality on key element types
* avoiding jitted functions as PRNG impl fields
2022-08-31 12:03:53 -07:00