61 Commits

Author SHA1 Message Date
Alexey Radul
6f09fe840e Better error message when broadcasting ragged to static shape.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:29 -04:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Over time we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
f67acee129 Merge pull request #16430 from jakevdp:bool-error
PiperOrigin-RevId: 542951181
2023-06-23 14:00:12 -07:00
jax authors
63415a9184 Merge pull request #16386 from axch:ragged-einsum
PiperOrigin-RevId: 542887557
2023-06-23 10:00:07 -07:00
Ayaka
feb34ce074 Fix typo: ConcretizationError -> ConcretizationTypeError 2023-06-22 16:01:35 +08:00
Ayaka
5da5804824 Fix typo in documentation 2023-06-22 15:47:41 +08:00
Jake VanderPlas
f1e603e4b3 errors: create TracerBoolConversionError for more targeted debugging tips 2023-06-21 01:41:45 -07:00
Jake VanderPlas
452a3b928b Errors: avoid printing tracer repr for concretization errors 2023-06-20 00:33:51 -07:00
Lena Martens
fbf8823da3 Add live-analysis memory optimization to more jaxpr interpreters.
Follow-up on 8a85e76a5cff0897eccbafc48da836b6f6704e5d

PiperOrigin-RevId: 540857501
2023-06-16 06:08:51 -07:00
Alexey Radul
63f912c220 Test and implement ragged einsum. 2023-06-13 17:04:43 -04:00
Alexey Radul
d67e309482 Update todo comments based on offline discussion. 2023-06-13 10:44:52 -04:00
Alexey Radul
effaf674ae Test and fix jnp.broadcast_to. 2023-06-08 16:17:43 -04:00
Matthew Johnson
1c6a892c7e Improve printing of bints and piles, and allow bints in convert_element_type. 2023-05-19 13:14:48 -07:00
Alexey Radul
2daeec83ce Redefine the pile representation from concatenated to stacked-and-padded.
The advantage (already being realized) is that the batching rules
become much simpler: we just batch along the stacked axis as always,
and when a reduction is about to occur, also mask out the padding
elements, replacing them with the identity element of the reduction.

This commit

- Changes the intended representation of data for piles and the
  corresponding BatchTracers.
- Re-defines ConcatAxis as RaggedAxis to represent the metadata.
- Updates `defreducer` to require the identity function (in case
  masking is needed), and supplies it everywhere.
- Flushes batching.segment_sum, as it is dead code now.
- Deletes unpack_concat_axes and reassemble_concat_axes, because they
  are irrelevant to the padded representation.
2023-05-19 13:13:15 -07:00
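The masking idea described in 2daeec83ce can be illustrated with a minimal pure-Python sketch. It is not the JAX batching code: `masked_reduce`, the list-of-lists layout, and the explicit `lengths` list are all assumptions made for illustration. Each ragged row is stored padded to a common width, and before reducing, every padding position is replaced by the identity element of the reduction.

```python
from functools import reduce
import operator


def masked_reduce(padded_rows, lengths, op, identity):
    """Reduce each padded row, replacing padding with the reduction's identity.

    Hypothetical helper: `padded_rows` is the stacked-and-padded data and
    `lengths[i]` is the true (unpadded) length of row i.
    """
    results = []
    for row, n in zip(padded_rows, lengths):
        # Positions at or beyond the true length n are padding; mask them out.
        masked = [x if i < n else identity for i, x in enumerate(row)]
        results.append(reduce(op, masked, identity))
    return results


# Three ragged rows of lengths 3, 1, 2, stacked and padded to width 3.
# The padding value (-99) is arbitrary, to show that masking hides it.
padded = [[1, 2, 3], [4, -99, -99], [5, 6, -99]]
lengths = [3, 1, 2]

sums = masked_reduce(padded, lengths, operator.add, 0)   # identity of + is 0
prods = masked_reduce(padded, lengths, operator.mul, 1)  # identity of * is 1
```

Because the identity differs per reduction (0 for sum, 1 for product), a reducer has to be told which identity to use, which is consistent with the commit's note that `defreducer` now requires it.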
Roy Frostig
180e26dafb remove physical_avals rule in favor of physical_element_aval 2023-05-17 20:07:58 -07:00
Peter Hawkins
eaf7eb2626 Break cycle between _src/core.py and _src/dtypes.py.
PiperOrigin-RevId: 532788430
2023-05-17 07:58:59 -07:00
George Necula
876c53abb7 [shape_poly] Refactor the unification of the argument abstract values with the actual arguments
This was called shape_poly.compute_dim_values. We rename it to
shape_poly.unify_avals_with_args and we add better error reporting to it.
Now it will identify the arg/kwarg where there is a shape discrepancy.

This is intended to be a pure refactoring, in preparation for adding
support for shape polymorphism to jax_export.call_exported.
2023-04-27 08:59:59 +02:00
Matthew Johnson
84ae14e7d3 [djax] handle simple reshapes and size-0 checks
One of the main changes here is that we don't do division in handling
x.reshape(..., -1) unless we have to.
2023-04-21 19:20:48 -07:00
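One plausible way to avoid division when resolving `x.reshape(..., -1)`, as mentioned in 84ae14e7d3, is to cancel target dimensions against matching input dimensions first and divide only for whatever is left. The sketch below is an assumption about the general idea, not the commit's implementation; `fill_minus_one` is a hypothetical name, and it works on plain integers where djax would compare symbolic sizes.

```python
import math


def fill_minus_one(old_shape, new_shape):
    """Return new_shape with its single -1 entry replaced by the inferred size.

    Hypothetical helper. Known target dimensions are cancelled against equal
    input dimensions; division happens only if some cannot be cancelled.
    """
    if list(new_shape).count(-1) != 1:
        raise ValueError("expected exactly one -1 in new_shape")
    known = [d for d in new_shape if d != -1]
    if 0 in known:
        # With a zero-sized target dimension the -1 entry is ambiguous.
        raise ValueError("cannot infer -1 when another target dimension is 0")

    remaining = list(old_shape)
    leftover = []
    for d in known:
        if d in remaining:
            remaining.remove(d)   # cancel a common factor, no division needed
        else:
            leftover.append(d)

    size = math.prod(remaining)
    if leftover:
        divisor = math.prod(leftover)
        if size % divisor != 0:
            raise ValueError(f"cannot reshape {old_shape} into {new_shape}")
        size //= divisor          # division only when cancellation fell short
    return tuple(size if d == -1 else d for d in new_shape)


no_division = fill_minus_one((2, 3, 4), (2, -1))   # 2 cancels; result is 3 * 4
with_division = fill_minus_one((2, 3, 4), (6, -1))  # 6 matches nothing; 24 // 6
```

With symbolic dimensions, the cancellation path matters because a product of remaining dimensions is always well defined, whereas an exact quotient may not be expressible.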
Peter Hawkins
a3b262c379 Use the traceback of the call site when assigning a source location to an inlined function.
Improves but does not completely fix https://github.com/google/jax/issues/15663 . The non-inlined case still has similar problems.
2023-04-19 13:56:53 -04:00
Jake VanderPlas
72bb8ab753 jax.Array: dynamically define abstract methods 2023-04-18 13:08:32 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
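The corner case behind 5521423d92 can be seen with a small sketch. `num_elements` is a hypothetical helper, not a JAX function; it only shows why `math.prod` suits static shape arithmetic: the empty product is the integer 1, so the size of a scalar shape `()` stays an `int`.

```python
import math


def num_elements(shape):
    """Number of elements in an array with the given static shape."""
    # math.prod returns the int 1 for an empty iterable, whereas the commit
    # message notes that np.prod([]) returns a float.
    return math.prod(shape)


scalar_size = num_elements(())       # shape of a scalar: empty product
matrix_size = num_elements((3, 4))
```

Keeping sizes as Python integers avoids accidental float values leaking into shape computations, where an integer is expected.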
Peter Hawkins
1c8512b1fa Micro-optimization: speed up JaxprEqn.replace().
PiperOrigin-RevId: 523415813
2023-04-11 09:00:12 -07:00
Matthew Johnson
9dabb6fa59 [shard-map] better errors for not-implemented-in-eager features 2023-04-08 21:12:40 -07:00
jax authors
c42aae9fd7 Merge pull request #15221 from froystig:custom-vjp-symbolic-zeros2
PiperOrigin-RevId: 522823918
2023-04-08 09:49:45 -07:00
Peter Hawkins
dee8279377 Add __slots__ to core.Var
PiperOrigin-RevId: 522659264
2023-04-07 12:33:37 -07:00
Roy Frostig
d51b8e6839 custom_vjp symbolic zeros support, take two
This change re-introduces symbolic zero support for `custom_vjp`.

This time:

* The forward rule API is slightly different, accepting two-field
  records at pytree leaves rather than pairs.

* In the default setting where symbolic_zeros is not set, there are no
  new requirements from pytree node definitions that are involved in
  the primal arguments. This avoids any change in behavior on the
  default path. In particular, custom pytree node definitions that
  aren't completely polymorphic in unflattening can remain as is.

* There is an additional test involving a custom pytree node.
2023-04-05 11:17:05 -07:00
George Necula
cd35e901aa [shape_poly] Cleanup handling of dimension variables.
We unify the way we compute with dimension variables (computing
their values from the shape of the actual arguments, and also
using those values to evaluate shapes that contain dimension variables).

We remove DimExprValueMlir, and all computations with dimension variables
and DimExpr are now done by JAX interpretation, followed by lowering to
TF or StableHLO.
2023-04-03 13:33:29 +02:00
Matthew Johnson
6a2b081506 fix bug from #15335 by checking main_trace tag 2023-03-30 22:35:03 -07:00
Peter Hawkins
c2d6fcc0e6 Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
Peter Hawkins
f461c4ef0c Move jax._src.typing into a separate Bazel target.
PiperOrigin-RevId: 518899136
2023-03-23 10:30:08 -07:00
jax authors
e39578cd73 Merge pull request #15154 from mattjj:pjit-typecheck
PiperOrigin-RevId: 518717095
2023-03-22 17:31:59 -07:00
Matthew Johnson
268456ef54 enable pjit recursive typechecking
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).

This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accommodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!

It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
2023-03-22 16:59:22 -07:00
Peter Hawkins
64e1f5fe3d Revert: custom_vjp symbolic zeros support
PiperOrigin-RevId: 518597609
2023-03-22 09:56:09 -07:00
Roy Frostig
ac7491ced0 custom_vjp symbolic zeros support 2023-03-21 14:14:35 +00:00
Matthew Johnson
af63365b8e make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in follow-up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)

Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).

This commit includes the changes from PR #15079, so that PR should be merged first.

Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
  handle static_argnums or static_argnames correctly. Instead it would fail,
  resulting in debug info being dropped from the jaxpr and ultimately the MLIR
  computation (but no Exception raised). We need to handle
  static_argnums/argnames because while the corresponding parameters remain on
  the Python callable signature, they are excluded from the args/kwargs
  pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
  when we still have the original args/kwargs in hand, i.e. much earlier than
  the previous mechanism. We then just have to pass this debug info to the
  right places. Indeed we often already had to work out some debug-related
  information at these call sites (e.g. whether the function is being staged
  out for jit, or scan, or whatever), so after this change we're working out
  all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
  unflatten user pytree defs with dummy objects (to reconstruct dummy
  args/kwargs trees so that we can call inspect.signature(fun).bind), since we
  just use the original args/kwargs instead. Since some user pytree node types
  are not fully polymorphic in their element types (e.g. their __init__ methods
  sometimes contained assertions about their elements' shapes, expecting them
  to be arrays), that means the new mechanism is fundamentally more compatible
  with custom pytree node types.

More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
  which in addition to the more precise name has fields like
  `arg_names: Tuple[Optional[str], ...]` and
  `result_paths: Tuple[Optional[str], ...]`, rather than
  `in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
  actual debug info more eagerly than before and we don't need pytrees for
  dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
  debug info about inputs which we have available at tracing time; in a
  follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
  delete `partial_eval.DebugInfo` and its corresponding helper methods (not
  done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
  partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
  partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
  `core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
  elements from the `arg_names` field), maintaining now-checked invariants like
  a Jaxpr's `debug_info` should have the same number of argument names as the
  jaxpr has invars (the jaxpr-processing functions updated here are enough for
  top-level jit jaxprs to have debug info attached, handling the original
  intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
  be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
  debug info on their outputs);
* adding some tests for static_argnums/static_argnames.

Phew! Can't wait to land those follow-ups too :P
2023-03-20 11:50:30 -07:00
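The divergence described in af63365b8e, where static parameters stay on the Python signature but are absent from the traced args/kwargs, can be sketched with the standard library alone. `arg_names_without_static` is a hypothetical helper and ignores pytrees, defaults, and `*args`; it only illustrates working out argument names while the original args/kwargs are still in hand.

```python
import inspect


def arg_names_without_static(fun, args, kwargs,
                             static_argnums=(), static_argnames=()):
    """Names of the non-static arguments of a call, in signature order.

    Hypothetical simplification: assumes no *args/**kwargs parameters and
    flat (non-pytree) arguments.
    """
    sig = inspect.signature(fun)
    # Bind the original call, before any static arguments are stripped out.
    bound = sig.bind(*args, **kwargs)
    params = list(sig.parameters)
    static_names = set(static_argnames) | {params[i] for i in static_argnums}
    return tuple(name for name in bound.arguments if name not in static_names)


def f(x, n, y, *, mode):
    return x


names = arg_names_without_static(
    f, (1.0, 3, 2.0), {"mode": "fast"},
    static_argnums=(1,), static_argnames=("mode",))
```

Binding against the real arguments sidesteps the need to rebuild dummy args/kwargs later, which is the compatibility benefit the commit message points out for custom pytree node types.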
Peter Hawkins
c1fbd2caa8 [JAX] Check for AttributeError from getattr(), not KeyError.
PiperOrigin-RevId: 517462731
2023-03-17 11:26:47 -07:00
Jake VanderPlas
7610013ebf Improve error for tolist() and tobytes() on tracer objects 2023-03-17 09:42:49 -07:00
Adam Paszke
1301968248 Optimize canonicalize_shape
I was looking at some profiles and noticed canonicalize_shape showing up as a noticeable
overhead in certain cases. Which makes sense, given that we carefully check all possible
cases before trying to consider integers as plausible elements (which are the most popular
_by far_). And this function is pretty hot, because it gets called any time we create a new
`ShapedArray`.

I wrote a small benchmark that repeatedly calls canonicalize_shape on a 4-sized tuple of
integers.

Before:
7.62µs ± 8%

After:
1.42µs ± 2%

So a pretty easy 5x improvement overall. And in more real cases, when resharding an array
onto 8 TPUs, 50% of the time was spent on creating shapes for avals of device buffers.

PiperOrigin-RevId: 516795311
2023-03-15 05:10:09 -07:00
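The optimization described in 1301968248 amounts to checking the most common case first. The sketch below is an assumed simplification, not the actual `canonicalize_shape`: `canonicalize_shape_sketch` is a hypothetical name, and the real function handles more cases (such as symbolic dimensions) than shown here.

```python
import operator


def canonicalize_shape_sketch(shape):
    """Return `shape` as a tuple of Python ints (hypothetical simplified version)."""
    # Fast path: plain ints are by far the most common elements, so test for
    # them first and skip any per-element conversion.
    if all(type(d) is int for d in shape):
        return tuple(shape)
    # Slow path: accept anything implementing __index__ (e.g. NumPy integers)
    # and reject non-integers such as floats.
    try:
        return tuple(operator.index(d) for d in shape)
    except TypeError as e:
        raise TypeError(
            f"Shapes must be sequences of integers, got {shape!r}") from e


class IndexLike:
    """Stand-in for an integer-like type such as a NumPy scalar."""

    def __init__(self, value):
        self.value = value

    def __index__(self):
        return self.value


fast = canonicalize_shape_sketch((2, 3, 4, 5))
slow = canonicalize_shape_sketch((IndexLike(2), 3))
```

Since a function like this runs every time a shaped abstract value is created, ordering the checks by frequency is where the speedup comes from.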
Peter Hawkins
8c7ba99f82 Make Tracer types on JaxprTrace more precise.
instantiate_const() must take and return a JaxprTracer.

Teach pytype that the Tracer returned by full_raise() must be an instance of the Tracer type associated with the Trace, using a Generic type.

PiperOrigin-RevId: 516554216
2023-03-14 09:56:21 -07:00
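The typing pattern mentioned in 8c7ba99f82 can be sketched with `typing.Generic`. The classes below are heavily simplified stand-ins that reuse the JAX names for readability; their constructors and method bodies are assumptions, not the real implementations.

```python
from typing import Any, Generic, TypeVar, cast


class Tracer:
    """Simplified stand-in for a tracer tied to the trace that created it."""

    def __init__(self, trace: "Trace") -> None:
        self._trace = trace


TracerType = TypeVar("TracerType", bound=Tracer)


class Trace(Generic[TracerType]):
    """A trace parameterized by the tracer type it produces."""

    def pure(self, val: Any) -> TracerType:
        raise NotImplementedError

    def full_raise(self, val: Any) -> TracerType:
        # A type checker now knows the result is this trace's tracer type.
        if isinstance(val, Tracer) and val._trace is self:
            return cast(TracerType, val)
        return self.pure(val)


class JaxprTracer(Tracer):
    def __init__(self, trace: "JaxprTrace", val: Any) -> None:
        super().__init__(trace)
        self.val = val


class JaxprTrace(Trace[JaxprTracer]):
    def pure(self, val: Any) -> JaxprTracer:
        return JaxprTracer(self, val)

    def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer:
        # Takes and returns a JaxprTracer, as the commit message requires.
        return tracer


trace = JaxprTrace()
raised = trace.full_raise(3)        # statically typed as JaxprTracer
same = trace.full_raise(raised)     # already belongs to this trace
```

With the base class generic in its tracer type, `JaxprTrace.full_raise` needs no override or cast at call sites for the checker to infer `JaxprTracer`.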
Matthew Johnson
b05975b964 add result info to mhlo, fixes #14780
incidentally fixes #14787
2023-03-06 21:21:26 -08:00
Matthew Johnson
c2aa5c5eed attach debug info to jaxpr, pass to mlir/mhlo
Co-authored-by: Peter Hawkins <phawkins@google.com>
2023-03-02 17:23:58 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Matthew Johnson
5c4525cb10 custom_jvp symbolic zeros support
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-02-24 07:33:49 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Zeynep Cankara
995ef40f68 [JAX] Improve error message when jit tracer passed to a shape.
Adds additional debugging message to the shape explaining why the value is a tracer.

Fixes #14279

PiperOrigin-RevId: 509545985
2023-02-14 09:13:01 -08:00
Jake VanderPlas
60256df668 [typing] define additional methods & properties on jax.Array
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers).
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
2023-02-10 09:42:32 -08:00
Matthew Johnson
a964dc3b9a simpler pretty-print for pjit, tweak custom pp rule signature 2023-02-09 12:45:51 -08:00
Matthew Johnson
644d3b650f minor tweaks to type annotations, specialize code on those types
I noticed some slightly-too-general type annotations in core.py. By tightening
them we could simplify the code too. (I think these were leftovers from
pre-omnistaging...)
2023-02-02 20:24:26 -08:00