517 Commits

Author SHA1 Message Date
George Necula
cea77f5d17 Improve some deprecation error messages 2024-01-07 07:09:39 +02:00
Jake VanderPlas
adefbca642 jax.core: deprecate several private APIs 2023-12-15 13:37:09 -08:00
George Necula
0a02d83015 [shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
2023-12-07 09:41:47 +01:00
Jake VanderPlas
2edb66de8a jax.core: point deprecation to jax.extend 2023-10-13 12:49:05 -07:00
Jake VanderPlas
e0944c938f jax.core: deprecate some inadvertent exports 2023-10-11 15:22:19 -07:00
Jake Vanderplas
d8f799391b COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Jake VanderPlas
3b6b988473 fix deprecations in core.py 2023-07-25 09:47:04 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
George Necula
aa85cd9a31 [shape_poly] Cleanup exported API symbols.
We remove the API functions related to shape polymorphism from the public API exported in jax.core. I could not remove a few API entry points because they
are referenced in Google. Will cleanup those uses next.

PiperOrigin-RevId: 545349000
2023-07-03 23:10:07 -07:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Peter Hawkins
eaf7eb2626 Break cycle between _src/core.py and _src/dtypes.py.
PiperOrigin-RevId: 532788430
2023-05-17 07:58:59 -07:00
jax authors
59f33a4338 Expose JaxprDebugInfo so others can use it for pytyping.
PiperOrigin-RevId: 525749186
2023-04-20 08:09:26 -07:00
Peter Hawkins
1d4b7a3701 Hide accidental exports from jax.core.
PiperOrigin-RevId: 511350939
2023-02-21 17:48:40 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Roy Frostig
6b4de4f91c remove several more symbols from jax.core
* `DBIdx`
* `DConcreteArray`
* `DimensionHandler`
* `DuplicateAxisNameError`

PiperOrigin-RevId: 510503517
2023-02-17 13:07:00 -08:00
Roy Frostig
e276859d11 remove several symbols from jax.core
* `ClosedCallPrimitive`
* `CustomPpEqnRule`
* `DArray`
* `DArrayDimHandler`

PiperOrigin-RevId: 510343926
2023-02-16 22:55:16 -08:00
Matthew Johnson
ec1e513659 remove accidental re-export of __future__.annotations from jax/core.py
PiperOrigin-RevId: 510233347
2023-02-16 13:47:28 -08:00
Roy Frostig
591e2c8937 remove some exports from jax.core
Namely:
* `AvalMapHandlerPair`
* `AxisEnvFrame`
* `AxisName`
* `AxisPrimitive`
* `AxisSubst`
PiperOrigin-RevId: 510224417
2023-02-16 13:12:35 -08:00
Roy Frostig
6b545a2ddc remove several exported symbols from jax.core
All of these are prefixed by an underscore.

PiperOrigin-RevId: 510194304
2023-02-16 11:20:36 -08:00
Roy Frostig
26045c49e7 remove core.{aval_method,aval_property}
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
Roy Frostig
1b2a318fd1 remove core.axis_substitution_rules
PiperOrigin-RevId: 509989925
2023-02-15 18:42:13 -08:00
Roy Frostig
537372a637 remove core.bint
PiperOrigin-RevId: 509932914
2023-02-15 14:28:29 -08:00
Roy Frostig
22168a0253 remove core.{bot,Bot}
PiperOrigin-RevId: 509884508
2023-02-15 11:13:11 -08:00
Peter Hawkins
a13a2c5cc2 [JAX] Remove obsolete unit type declarations in jax.core.
Remove obsolete unit test in host_callback.

PiperOrigin-RevId: 507473737
2023-02-06 07:33:14 -08:00
George Necula
15be538ebe [shape_poly] Fix the hashing and equality of symbolic dimensions 2023-02-04 08:30:44 +02:00
George Necula
1b04fcb4be [jax2tf] Improve handling of lax.pad and jnp.pad with polymorphic padding config
PiperOrigin-RevId: 498350702
2022-12-29 03:00:32 -08:00
Roy Frostig
523c6f7a53 [jax] move jax.core to jax._src.core
Re-export roughly all of the same symbols via `jax.core` for now.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 495766963
2022-12-15 20:35:20 -08:00
George Necula
8fb344a724 [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes,
  of which shape polymorphism is a special case.
  This implementation is enabled with the --jax_dynamic_shapes flag.
  At the moment, the JAX dynamic shapes support is still
  incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added) here in which we form a Jaxpr using abstract
  values that express dimension sizes as dimension polynomials
  (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
  This implementation is enabled when --jax_dynamic_shapes is off.
  With this implementation only 50 jax2tf tests fail (to be fixed
  separately).

The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
       shape = mlir.eval_dynamic_shape(ctx, shape)
       return mhlo.DynamicXXX(..., shape)
    else:
       return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
       if config.jax_dynamic_shapes:
          # Using Var
          return ... subst using ctx.axis_size_env ...
       else:
          # Using polynomials
          return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-12-08 08:19:35 +02:00
Peter Hawkins
ac72346ad3 Ensure that the initial dynamic_trace_state is canonicalized.
The non-canonical state meant that we were falling back to a more expensive comparison for the first jit-compiled function in the program. I doubt there will be any impact on real benchmarks, but this perturbs the results of running a single microbenchmark in isolation.

PiperOrigin-RevId: 493489154
2022-12-06 20:39:53 -08:00
jax authors
1027d55b8c Optimize core.find_top_trace
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.

PiperOrigin-RevId: 492481837
2022-12-02 09:00:50 -08:00
Adam Paszke
bbf22db08b Optimize core.find_top_trace
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.

PiperOrigin-RevId: 492460102
2022-12-02 07:04:52 -08:00
Jake VanderPlas
e7f53479e2 Some cleanups related to dropping Python 3.7 2022-11-29 15:54:49 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Jake VanderPlas
8fbf8da810 Declare Array.sharding & raise an error on tracers 2022-11-08 14:20:46 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Matthew Johnson
f2f2faa4fa add a basic prototype of piles, behind jax_dynamic_shapes
Co-authored-by: Adam Paszke <apaszke@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-11-06 17:03:04 -08:00
jax authors
8dea82e089 Merge pull request #13022 from mattjj:leak-checker-improvements
PiperOrigin-RevId: 484640693
2022-10-28 16:05:43 -07:00
Matthew Johnson
6ebf44a681 make leak checker errors explain why objects are alive
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-28 14:12:17 -07:00
Parker Schuh
5cfc708843 Remove error-prone most_recent_entry() support from lu.cache.
PiperOrigin-RevId: 484382188
2022-10-27 16:41:44 -07:00
Jake VanderPlas
1ed18fa500 add allow_opaque_dtype to dtypes.canonicalize_dtype utility 2022-10-17 13:47:42 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Matthew Johnson
a8826e672b [dynamic-shapes] Add basic slicing support
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-28 15:55:51 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
74698048f3 Tracer: add missing __round__ and __reversed__ methods 2022-09-20 09:09:23 -07:00
Jake VanderPlas
cc72a20e9b use jax._src.typing in lax.py & a few other places 2022-09-12 09:08:13 -07:00
Matthew Johnson
58826507cc [dynamic-shapes] add basic vmap-of-indexing support
The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison, and squeeze) for a simple
dynamic-batch-shape gather to work.

I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's just convenient given I'm looking at indexing again.
2022-09-08 17:52:12 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00