505 Commits

Author SHA1 Message Date
Peter Hawkins
1d4b7a3701 Hide accidental exports from jax.core.
PiperOrigin-RevId: 511350939
2023-02-21 17:48:40 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Roy Frostig
6b4de4f91c remove several more symbols from jax.core
* `DBIdx`
* `DConcreteArray`
* `DimensionHandler`
* `DuplicateAxisNameError`

PiperOrigin-RevId: 510503517
2023-02-17 13:07:00 -08:00
Roy Frostig
e276859d11 remove several symbols from jax.core
* `ClosedCallPrimitive`
* `CustomPpEqnRule`
* `DArray`
* `DArrayDimHandler`

PiperOrigin-RevId: 510343926
2023-02-16 22:55:16 -08:00
Matthew Johnson
ec1e513659 remove accidental re-export of __future__.annotations from jax/core.py
PiperOrigin-RevId: 510233347
2023-02-16 13:47:28 -08:00
Roy Frostig
591e2c8937 remove some exports from jax.core
Namely:
* `AvalMapHandlerPair`
* `AxisEnvFrame`
* `AxisName`
* `AxisPrimitive`
* `AxisSubst`
PiperOrigin-RevId: 510224417
2023-02-16 13:12:35 -08:00
Roy Frostig
6b545a2ddc remove several exported symbols from jax.core
All of these are prefixed by an underscore.

PiperOrigin-RevId: 510194304
2023-02-16 11:20:36 -08:00
Roy Frostig
26045c49e7 remove core.{aval_method,aval_property}
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
Roy Frostig
1b2a318fd1 remove core.axis_substitution_rules
PiperOrigin-RevId: 509989925
2023-02-15 18:42:13 -08:00
Roy Frostig
537372a637 remove core.bint
PiperOrigin-RevId: 509932914
2023-02-15 14:28:29 -08:00
Roy Frostig
22168a0253 remove core.{bot,Bot}
PiperOrigin-RevId: 509884508
2023-02-15 11:13:11 -08:00
Peter Hawkins
a13a2c5cc2 [JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback.

PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
George Necula
15be538ebe [shape_poly] Fix the hashing and equality of symbolic dimensions 2023-02-04 08:30:44 +02:00
George Necula
1b04fcb4be [jax2tf] Improve handling of lax.pad and jnp.pad with polymorphic padding config
PiperOrigin-RevId: 498350702
2022-12-29 03:00:32 -08:00
Roy Frostig
523c6f7a53 [jax] move jax.core to jax._src.core
Re-export roughly all of the same symbols via `jax.core` for now.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 495766963
2022-12-15 20:35:20 -08:00
George Necula
8fb344a724 [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes,
  of which shape polymorphism is a special case.
  This implementation is enabled with the --jax_dynamic_shapes flag.
  At the moment, the JAX dynamic shapes support is still
  incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added) here in which we form a Jaxpr using abstract
  values that express dimension sizes as dimension polynomials
  (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
  This implementation is enabled when --jax_dynamic_shapes is off.
  With this implementation only 50 jax2tf tests fail (to be fixed
  separately).

The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
       shape = mlir.eval_dynamic_shape(ctx, shape)
       return mhlo.DynamicXXX(..., shape)
    else:
       return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
       if config.jax_dynamic_shapes:
          # Using Var
          return ... subst using ctx.axis_size_env ...
       else:
          # Using polynomials
          return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-12-08 08:19:35 +02:00
Peter Hawkins
ac72346ad3 Ensure that the initial dynamic_trace_state is canonicalized.
The non-canonical state meant that we were falling back to a more expensive comparison for the first jit-compiled function in the program. I doubt there will be any impact on real benchmarks, but this perturbs the results of running a single microbenchmark in isolation.

PiperOrigin-RevId: 493489154
2022-12-06 20:39:53 -08:00
jax authors
1027d55b8c Optimize core.find_top_trace
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.

PiperOrigin-RevId: 492481837
2022-12-02 09:00:50 -08:00
Adam Paszke
bbf22db08b Optimize core.find_top_trace
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.

PiperOrigin-RevId: 492460102
2022-12-02 07:04:52 -08:00
Jake VanderPlas
e7f53479e2 Some cleanups related to dropping Python 3.7 2022-11-29 15:54:49 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Jake VanderPlas
8fbf8da810 Declare Array.sharding & raise an error on tracers 2022-11-08 14:20:46 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Matthew Johnson
f2f2faa4fa add a basic prototype of piles, behind jax_dynamic_shapes
Co-authored-by: Adam Paszke <apaszke@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-11-06 17:03:04 -08:00
jax authors
8dea82e089 Merge pull request #13022 from mattjj:leak-checker-improvements
PiperOrigin-RevId: 484640693
2022-10-28 16:05:43 -07:00
Matthew Johnson
6ebf44a681 make leak checker errors explain why objects are alive
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-28 14:12:17 -07:00
Parker Schuh
5cfc708843 Remove error-prone most_recent_entry() support from lu.cache.
PiperOrigin-RevId: 484382188
2022-10-27 16:41:44 -07:00
Jake VanderPlas
1ed18fa500 add allow_opaque_dtype to dtypes.canonicalize_dtype utility 2022-10-17 13:47:42 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Matthew Johnson
a8826e672b [dynamic-shapes] Add basic slicing support
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-28 15:55:51 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
74698048f3 Tracer: add missing __round__ and __reversed__ methods 2022-09-20 09:09:23 -07:00
Jake VanderPlas
cc72a20e9b use jax._src.typing in lax.py & a few other places 2022-09-12 09:08:13 -07:00
Matthew Johnson
58826507cc [dynamic-shapes] add basic vmap-of-indexing support
The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison, and squeeze) for a simple
dynamic-batch-shape gather to work.

I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's just convenient given I'm looking at indexing again.
2022-09-08 17:52:12 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
jax authors
b09a6175fb Merge pull request #12217 from mattjj:dce-and-execute-trivial
PiperOrigin-RevId: 472856448
2022-09-07 17:41:16 -07:00
Matthew Johnson
3c811b1520 fix bugs, infeed/outfeed must be considered effectful
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-06 15:13:01 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Roy Frostig
8f045b12d6 internal rename: swap mentions of "custom eltypes" for "opaque dtypes"
Also, avoid direct set membership tests on `core.opaque_dtypes`. Update
callers to use `core.{is,has}_opaque_dtype` predicates instead.
2022-08-30 16:52:08 -07:00
Roy Frostig
73bf0aa30c access rules through a hidden attribute of opaque dtype 2022-08-30 14:06:01 -07:00
Matthew Johnson
bbb8048d2e Add batching rules for state primitives and for_loop
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-29 11:40:09 -07:00
Roy Frostig
6071a8f875 roll-forward #11952, take 2
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`

PiperOrigin-RevId: 469276609
2022-08-22 13:57:31 -07:00
jax authors
3a2f25ff31 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468840334
2022-08-19 21:02:18 -07:00
Roy Frostig
9789e83b26 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
2022-08-19 20:12:32 -07:00
jax authors
a6c6416872 Internal change
PiperOrigin-RevId: 468712508
2022-08-19 08:56:49 -07:00
Roy Frostig
34b63dfc77 teach jax2tf about custom eltypes, key arrays, and random key primitives
Specifically:

* Introduce a `physical_avals` view as a custom eltype method. This is
  analogous to the existing `aval_to_ir_types`, but where the output
  is an aval with a non-custom eltype (and hence a direct
  correspondence to TF and to lowerings).

* Change jax2tf to continue tracing with logical avals, but to
  maintain TF tensors of corresponding physical shape/dtype, and to
  translate to TF operations based on physical avals where relevant.

* Fix up various TF impl rules to follow physical avals. To this end,
  add a "physical" mode to jax2tf's `_convert_jax_impl` helper, which
  carries out the conversion using physical rather than logical avals.

* Write TF impl rules for `random_{seed,split,fold_in,bits}`
  primitives. To this end, factor out the part of these primitives'
  impl rules that operates on the base array and convert that, pass it
  through `_convert_jax_impl` in physical mode.

* Teach the jax2tf test harness how to unwrap key-array-typed outputs
  into physical `uint32` arrays that it can use in comparison tests.
2022-08-18 21:46:55 -07:00