91 Commits

Author SHA1 Message Date
Yash Katariya
38f91bdaa5 Skip core tests which have nested pjits and DShapedArray.
PiperOrigin-RevId: 502013080
2023-01-13 22:39:31 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Peter Hawkins
6bda0d2863 Don't call dtypes.result_type() unnecessarily on the type of an array during abstractification.
Remove make_shaped_array since it has no more non-test users.

```
name        old cpu/op  new cpu/op  delta
device_put  69.4µs ± 6%  63.5µs ± 3%  -8.56%  (p=0.000 n=10+10)

name        old time/op             new time/op             delta
device_put  69.4µs ± 6%             63.5µs ± 3%  -8.56%        (p=0.000 n=10+10)
```

PiperOrigin-RevId: 491795793
2022-11-29 19:27:10 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
jax authors
5318df67fc Merge pull request #11151 from mattjj:input-fwd-residual-optimization
PiperOrigin-RevId: 455710192
2022-06-17 15:43:44 -07:00
Matthew Johnson
72a67906bf optimize grad-of-jit not to pass input-residuals as intermediate-residuals
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
Co-authored-by: Peter Hawkins <phawkins@google.com>
2022-06-17 15:24:28 -07:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
Matthew Johnson
05dda56019 add core.closed_call_p 2022-05-14 14:06:30 -07:00
Matthew Johnson
0b841cf35b make core_test.py pass with core.call 2022-05-11 15:45:40 -07:00
Matthew Johnson
bb56f40947 Internal change
PiperOrigin-RevId: 447549479
2022-05-09 13:26:30 -07:00
Matthew Johnson
705c07ae6d remove count attribute and total_ordering from core.Var
Originally we used the 'Var.count' attribute to ensure Var instances were
printed consistently regardless of context, even though only their object id
was load-bearing. That is, Var.count was only used for pretty printing. (#1949
added a total_ordering on Var for reasons out of scope of JAX's core code.)

But #8019 revised our pretty-printing so as not to use Var.count. Instead it
chose how to pretty-print Var instances based on their order of appearance in a
jaxpr. That meant Var.count really wasn't useful anymore. So this PR removes
Var.count.

In fact, Var.__repr__ and JaxprEqn.__repr__ were made confusing after #8019,
since they could print variable names totally different from the names that
would appear when the same JaxprEqn or Var objects were printed as part of a
jaxpr. That is, before this PR< we might have a jaxpr which printed like:

```python
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
```

Notice the variable names in the equation pretty-print don't correspond to any
in the jaxpr pretty-print!

So this PR changes JaxprEqn.__repr__ and Var.__repr__ to show Var object ids.
2022-05-09 09:31:23 -07:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Matthew Johnson
477dfa6e46 [remove-units] don't use abstract_unit for dropvar avals 2022-04-28 22:51:41 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
17fc5bd02e Merge pull request #9290 from fehiepsi:named
PiperOrigin-RevId: 438290209
2022-03-30 06:54:10 -07:00
Roy Frostig
d1f49c1917 add jaxpr check in partial_eval.trace_to_subjaxpr_dynamic 2022-03-23 19:56:34 -07:00
Roy Frostig
6f519576f6 remove _reduce_sum from public jax.lax module 2022-03-08 16:34:26 -08:00
Du Phan
3afe67367d NamedShape matches tuple behavior 2022-01-25 12:20:55 -05:00
Matthew Johnson
b55ba8afdb re-enable tests of #8955
f35014d had to revert part of #8955 because of a surprising downstream
breakage (relying on internal APIs). That breakage was isolated to how
_inline_literals handled invars.

The approach was a temporary one anyway: it relied on the fact that we
expect only to bind axis size variables at the top level and hence if we
didn't rename the input binders in _inline_literals we wouldn't need to
substitute new variables for any variables appearing in types. But a
more general approach would be to perform the necessary substitution
everywhere; after all, we might be inlining a literal into an axis size!

This commit takes the more general approach. It may fix the downstream
breakage automatically, just by virtue of being different; if not, I'll
figure out how to fix downstream.
2022-01-20 11:11:54 -08:00
jax authors
67723da38b Merge pull request #9143 from mattjj:fix-jaxpr-checking-error-messages
PiperOrigin-RevId: 420866862
2022-01-10 15:09:04 -08:00
Matthew Johnson
3548e023ec fix jaxpr type checking error messages
The pretty-printing changes a few months ago defined variable names
based on the state in JaxprPpContext instances. But that meant incorrect
variable names could be printed in jaxpr type checking error messages.

This commit correctly threads through the context so as to provide
error messages with coherent variable names.
2022-01-09 20:07:58 -08:00
Matthew Johnson
73b530aead Rolling forward again...
PiperOrigin-RevId: 420551242
2022-01-08 22:46:06 -08:00
jax authors
c335dfcc2a Re-applying #9136 after it was rolled back.
PiperOrigin-RevId: 420548165
2022-01-08 22:08:44 -08:00
Matthew Johnson
b0dabab99c Re-applying #9136 after it was rolled back.
PiperOrigin-RevId: 420545623
2022-01-08 21:39:41 -08:00
jax authors
0e201425e6 Internal change
PiperOrigin-RevId: 420436636
2022-01-07 23:27:47 -08:00
Matthew Johnson
c3b1d0dfd0 simpler jaxpr eqn params to bind params conversion
Final-style higher-order primitives, like call_p, xla_call_p (underlying
jit), xla_pmap_p (underlying pmap), and xmap_p (underlying xmap) have
slightly different bind signatures (while tracing) from their signatures
when they appear in jaxprs. In particular, their trace-time binds are
parameterized by a Python callable (or really a lu.WrappedFun)
representing the function to be applied, while in jaxpr eqns they are
parameterized by a jaxpr representing the same.

As a result, to round-trip from jaxpr to Python traceable, in
core.eval_jaxpr we have to convert from one parameter signature to the
other. (Basically we had to take the jaxpr and turn it into a Python
callable, via lu.wrap_init(partial(core.eval_jaxpr, call_jaxpr, ...)).)

However due to historical path dependence these conversion mechanisms
were all slightly distinct and kind of a mess. There was a case analysis
for call_jaxpr and map_jaxpr in core.eval_jaxpr_eqn (a helper function
created only because of this complexity), and there was a separate table
only used for the xmap rule.

In this PR we uniformized things! We basically only have a table (to
simplify core.eval_jaxpr), but instead of having it as a table we just
attached the rules to the different primitive classes (CallPrimitive,
MapPrimitive, and XmapPrimitive) to make things less error-prone (we
have a few different CallPrimitive instantiations, like call_p,
xla_call_p, named_call_p, and remat_call_p, and this way we don't have
to remember to populate the table separately for each).

This was actually a warmup simplification before we attempt to simplify
custom derivatives (to unify custom_jvp_call_p and
custom_jvp_call_jaxpr_p).

Co-authored-by: Roy Frostig <frostig@google.com>
2022-01-07 21:37:36 -08:00
Matthew Johnson
f35014d655 Temporarily revert a small part of https://github.com/google/jax/pull/8955 (in
partial_eval.py's _inline_literals) and skip new tests.

Some code seems to depend on whether we generate fresh invars (i.e. jaxpr input
binders) in that code. I'm not sure if it's a bug in the new JAX code or a bug in
the user code, but I'd like to un-break things while investigating.

PiperOrigin-RevId: 420296461
2022-01-07 08:17:09 -08:00
Matthew Johnson
4db899007b add staging logic for polymorphic shapes in jaxprs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-01-05 14:11:12 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Markus Kunesch
6708cd3158 Add dtype to string representation of ConcreteArray.
The string representation of ConcreteArray did not include the data type of the
wrapped value. This makes it harder to spot the reason for errors arising from
inconsistent values (issue #5364). This commit adds the data type to the string
representation of ConcreteArray.
2021-08-13 15:01:26 +00:00
Peter Hawkins
46cc654537 Move jax.abstract_arrays to jax._src.abstract_arrays.
PiperOrigin-RevId: 377044255
2021-06-02 06:25:22 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
0181d03902 add a memory leak test for jit jaxpr construction
Tweak implementation for `_inline_literals` not to include a class
defined in a function, since that seemed to cause leaking!
2021-03-17 13:09:02 -07:00
James Bradbury
10dcb26cb3 [avals with names] Add named_shape to ShapedArray and update typecompat
The second change in the avals-with-names stack:
- https://github.com/google/jax/pull/5524 Revise aval constructor call sites to use a new `aval.update` method
- **Add `named_shape` to `ShapedArray` and update typecompat**
- Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules
- Make `mapped_aval`, `unmapped_aval`, and their xmap equivalents swap positional and named axes (rather than just creating and deleting positional ones)
- Enable `lax.full` to create values with named axes
- Ensure `grad` and `jacfwd`/`jacrev` consistently act elementwise over named axes (by e.g. using a seed with named axes in `grad`, and prohibiting collectives if TAP isn't too unhappy) and align `vmap(transpose)` with `transpose(vmap)` by moving the `psum` in `transpose(psum)` into `backward_pass`
- Add `axis_name` kwarg to grad to indicate operating collectively over one or more named axes

PiperOrigin-RevId: 355880632
2021-02-05 10:41:05 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Roy Frostig
ec6b10d4ea clear some caches when setting up jaxpr typecheck tests
In order to test that the typechecker identifies invalid jaxprs, some
tests modify jaxprs in place. This is typically not allowed, since
jaxprs are assumed immutable, and may be cached. As a workaround, this
change clears the relevant caches before every test. This ought to
prevent some order-dependent test failures.
2021-01-06 11:27:54 -08:00
Jake VanderPlas
6393349783 raise_to_shaped: preserve weak_type by default 2020-10-08 11:53:52 -07:00
Roy Frostig
e7979258ee equation context for undefined var reads in jaxpr typechecker 2020-10-05 12:29:43 -07:00
Jake VanderPlas
40016cc47c Allow jax objects to be represented by multiple buffers 2020-09-29 11:53:17 -07:00
Peter Hawkins
a0e14b0552
Revert "Allow JAX objects to be represented by multiple buffers" 2020-09-29 09:26:11 -04:00
Jake VanderPlas
d1f80228e0 Allow jax objects to be represented by multiple buffers 2020-09-25 11:09:08 -07:00
Matthew Johnson
6614f94890
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328)
rename and simplify TypedJaxpr -> ClosedJaxpr

This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
  in_avals / out_avals no longer need to be constructed), making them
  easier to work with;
* correspondingly rules out a class of errors (mismatches between
  invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
  but they're closed in that they are packaged with their constant
  values).

This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-09-18 10:07:13 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Jake Vanderplas
a9fad49e1b
Add ability to specify individual test targets via a regex (#3549)
* Add ability to specify individual test targets

* fix missing imports

* Use re.search and include test class name
2020-06-29 14:16:51 -07:00