516 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
60828e9b19 raise error if vmap/pmap in_axes are booleans
fixes #6372
2021-04-09 14:47:36 -07:00
Peter Hawkins
a54a5e59ee Remove backward compatibility code paths for jaxlib < 0.1.65.
Fix up a few version comments.
2021-04-09 15:39:38 -04:00
Peter Hawkins
fb2824bdbb [JAX] Add static_argnames support to jax.jit.
Requires a new jaxlib build.

Add support for static_argnames in C++ JIT implementation.

PiperOrigin-RevId: 367627359
2021-04-09 07:11:04 -07:00
Peter Hawkins
9fad2441a2 Mark arguments to jax.jit() other than the function as keyword-only.
This change is to prevent breakage when options are added or removed.
2021-04-08 10:32:35 -04:00
jax authors
59685f98a8 Merge pull request #6366 from shoyer:inspect-fallback
PiperOrigin-RevId: 367347533
2021-04-07 19:58:51 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Stephan Hoyer
ebc86f83c4 add fallback for inspect.signature inside jit 2021-04-07 13:48:20 -07:00
jax authors
d2c53e0560 Merge pull request #6273 from shoyer:static-kwargs2
PiperOrigin-RevId: 366459406
2021-04-02 09:37:46 -07:00
Matthew Johnson
632876d773 Copybara import of the project:
--
35fcf2e2fd5b4c56cbb591f4c8bf01222a23dfe5 by Matthew Johnson <mattjj@google.com>:

remove deprecated custom_transforms code

PiperOrigin-RevId: 366108489
2021-03-31 13:50:56 -07:00
Stephan Hoyer
acb0be9cb7 Add _python_jit_with_static_argnames. 2021-03-31 10:02:16 -07:00
Peter Hawkins
0cd74c93b8 Fix declaration order problem when using JAX_DEBUG_NANS environment variable. 2021-03-29 20:10:40 -04:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Jamie Townsend
0a3ba6f2ce Instantiate zero outputs of linear_transpose 2021-03-26 11:03:07 +00:00
Matthew Johnson
848fed8b87 Work around CPython bug https://bugs.python.org/issue33261 exposed by changes to C++ JIT dispatch path.
PiperOrigin-RevId: 365189779
2021-03-25 22:15:45 -07:00
Peter Hawkins
a136f61ba7 [JAX] Remove function wrappers in C++ JIT dispatch path.
Notable changes:
* Make CompiledFunction implement __get__() so it can be used as a bound method.
* Allow dynamic attributes in CompiledFunction.

Includes changes from https://github.com/google/jax/pull/6220 and https://github.com/google/jax/pull/6183 as diffbases.

PiperOrigin-RevId: 365170596
2021-03-25 19:01:02 -07:00
Peter Hawkins
f6f0128c16 Move api_boundary annotation onto C++ jit cache_miss function.
There's a small overhead to introducing an additional Python call frame, which we can avoid if we annotate only the cache miss case. This change does not seem to affect the filtered stack traces; it appears we filter all JAX-internal frames whenever any api_boundary is present.
2021-03-25 17:06:07 -04:00
Peter Hawkins
cac1b891ce [JAX] Refactor NaN/Inf checking in jitted functions.
Avoid performing NaN/Inf checking in the common path for calling a jit-ted function. Instead, add a global/thread-local `posthook` function that, if, set, the C++ jit code calls with the inputs (function, args, kwargs, outputs). Use the posthook feature to implement NaN checking.

Add a `_cache_miss` attribute to the C++ JIT function objects to allow the NaN checking code to extract and call the cache miss function.

PiperOrigin-RevId: 365108787
2021-03-25 13:13:02 -07:00
Peter Hawkins
b1e1d0acd6 Switch disable_jit() to use the common boolean state mechanism. 2021-03-24 19:46:52 -04:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Peter Hawkins
368f3f056e Rollforward of:
[JAX] Add an opaque `extra_jit_context` field to the JAX C++ jit code.

This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.

PiperOrigin-RevId: 364611475
2021-03-23 12:00:43 -07:00
Peter Hawkins
7890d6cc2a Rollback of:
[JAX] Add an opaque `extra_jit_context` field to the JAX C++ jit code.

This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.

PiperOrigin-RevId: 364599983
2021-03-23 11:12:02 -07:00
Peter Hawkins
f2a6d46426 [JAX] Add an opaque extra_jit_context field to the JAX C++ jit code.
This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.

PiperOrigin-RevId: 364563982
2021-03-23 08:35:05 -07:00
Peter Hawkins
e1646bf2b8 [JAX] Include jax_enable_x64 in the C++ JIT call signature.
This allows us to avoid building a tuple as part of JIT dispatch.

PiperOrigin-RevId: 364409162
2021-03-22 14:29:49 -07:00
Jake VanderPlas
36ea462e47 grad(): improve error for traced argnums 2021-03-19 10:14:24 -07:00
Peter Hawkins
23756a040b [JAX] Refactor handling of JIT interpreter state in jax_jit API.
Create separate holder objects for global and thread-local state, and move enable_x64 and disable_jit context into the holder objects.

Expose the global and per-thread state objects to Python via pybind11.

Refactoring only; no functional changes intended.

PiperOrigin-RevId: 363510449
2021-03-17 14:39:34 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Peter Hawkins
140c0acbbe Remove the JAX lazy sublanguage.
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.

At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a `jit` computation, we can
   avoid materializing it in its expanded form.

It is not clear that laziness has sufficient power to weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
James Bradbury
c622422dad [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-03-09 13:48:15 -08:00
Peter Hawkins
2469ad1bb3 Cleanups for laziness. No functional changes intended.
Use None as a trivial lazy expression in more places. Simplify some code.
2021-03-07 11:33:04 -05:00
Jean-Baptiste Lespiau
ed16ad8ca2 Also change the C++ value for the flag, when the environment variable is set.
PiperOrigin-RevId: 360910313
2021-03-04 07:24:32 -08:00
Peter Hawkins
afd2aa2ea0 Remove device constants from lazy language.
Updated version of #4536.

This is removing the device constant part of #1668. We can do this because after #3370 and #4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
2021-03-03 21:17:31 -05:00
Jean-Baptiste Lespiau
654a5b332c Remove a callback to Python to get the value of some flag.
This is done to simplify the code, and not at all for performance, because it's only executed during the compilation phase.

One possible design question: should we let the user access the value of the flag if it has not been set? Right now, the Python code allows it I think (meaning the behavior may not match the flag value, which has not been parsed yet).
We could raise an error, by setting the flag value to absl::nullopt, and check it's not null. But it would be a breaking change, so I am a little reluctant doing so.

PiperOrigin-RevId: 360407549
2021-03-02 05:39:52 -08:00
Jean-Baptiste Lespiau
5d11d101c6 Move the x64 context manager threadlocal state from jax/python to xla/c++
Fixes #5532.

PiperOrigin-RevId: 360252057
2021-03-01 12:33:11 -08:00
Jean-Baptiste Lespiau
625bb8040e Fix an identation.
PiperOrigin-RevId: 360233433
2021-03-01 11:16:50 -08:00
Roy Frostig
912cc87a3d introduce linear_call for custom transposition.
This adds a primitive with a corresponding traceable function in
`custom_derivatives` that takes a callee and its transpose, both
functions. When the primitive is encountered during transposition, the
given transpose function is invoked instead of transpose-transforming
the callee. The invocation of the custom transposition is itself done
via a `linear_call`, with the original callee set as the transpose.
This maintains, in particular, that transposing twice is an identity.
2021-02-26 10:46:54 -08:00
Jake VanderPlas
067be89a0c DOC: minor documentation & formatting fixes 2021-02-23 10:31:44 -08:00
Matthew Johnson
9b18135b6e Rollback of #5702 due to internal breakage.
PiperOrigin-RevId: 357943850
2021-02-17 07:32:09 -08:00
James Bradbury
fb160b8afd [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-02-16 15:46:14 -08:00
Roy Frostig
afecab9ad7 accept any arguments with shape/dtype attributes after make_jaxpr 2021-02-10 17:07:10 -08:00
Matthew Johnson
7394048782 make jax.eval_shape duck typing more robust 2021-02-09 11:25:15 -08:00
James Bradbury
10dcb26cb3 [avals with names] Add named_shape to ShapedArray and update typecompat
The second change in the avals-with-names stack:
- https://github.com/google/jax/pull/5524 Revise aval constructor call sites to use a new `aval.update` method
- **Add `named_shape` to `ShapedArray` and update typecompat**
- Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules
- Make `mapped_aval`, `unmapped_aval`, and their xmap equivalents swap positional and named axes (rather than just creating and deleting positional ones)
- Enable `lax.full` to create values with named axes
- Ensure `grad` and `jacfwd`/`jacrev` consistently act elementwise over named axes (by e.g. using a seed with named axes in `grad`, and prohibiting collectives if TAP isn't too unhappy) and align `vmap(transpose)` with `transpose(vmap)` by moving the `psum` in `transpose(psum)` into `backward_pass`
- Add `axis_name` kwarg to grad to indicate operating collectively over one or more named axes

PiperOrigin-RevId: 355880632
2021-02-05 10:41:05 -08:00
jax authors
3575bc7639 Merge pull request #5628 from tomhennigan:changelist/355594738
PiperOrigin-RevId: 355834441
2021-02-05 05:55:59 -08:00
Jake VanderPlas
5e7be4a61f Cleanup: remove obsolete jaxlib version checks 2021-02-04 15:13:39 -08:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Tom Hennigan
7adb1e381d Add jax.default_backend() which returns the default platform name.
This can be useful when you need backend specific behaviour, e.g.:

    if jax.default_backend() == 'gpu':
      dataset = double_buffer(dataset)

Or if you want to assert a given backend is the default:

    assert jax.default_backend() == 'tpu'

I am a bit conflicted by the naming, "backend" is consistent with other APIs in
JAX (e.g. jit, local_devices etc) which accept a "backend" string which is used
to lookup an XLA backend by platform name.
2021-02-04 14:50:15 +00:00