62 Commits

Author SHA1 Message Date
Matthew Johnson
af63365b8e make mlir arg and result names work with static_argnums/argnames
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in follow-up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)

Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).

This commit includes the changes from PR #15079, so that PR should be merged first.

Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
  handle static_argnums or static_argnames correctly. Instead it would fail,
  resulting in debug info being dropped from the jaxpr and ultimately the MLIR
  computation (but no Exception raised). We need to handle
  static_argnums/argnames because while the corresponding parameters remain on
  the Python callable signature, they are excluded from the args/kwargs
  pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
  when we still have the original args/kwargs in hand, i.e. much earlier than
  the previous mechanism. We then just have to pass this debug info to the
  right places. Indeed we often already had to work out some debug-related
  information at these call sites (e.g. whether the function is being staged
  out for jit, or scan, or whatever), so after this change we're working out
  all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
  unflatten user pytree defs with dummy objects (to reconstruct dummy
  args/kwargs trees so that we can call inspect.signature(fun).bind), since we
  just use the original args/kwargs instead. Since some user pytree node types
  are not fully polymorphic in their element types (e.g. their __init__ methods
  sometimes contained assertions about their elements' shapes, expecting them
  to be arrays), that means the new mechanism is fundamentally more compatible
  with custom pytree node types.
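
To illustrate the divergence described in the bullets above, here is a minimal stdlib-only sketch (not the code in this commit). It binds the original args/kwargs with `inspect.signature(fun).bind` and then drops the static parameters, so the remaining names line up with the dynamic arguments. The helper name `dynamic_arg_names` and its parameters are hypothetical.

```python
import inspect

def dynamic_arg_names(fun, args, kwargs, static_argnums=(), static_argnames=()):
    """Return the names of the non-static arguments, in binding order."""
    sig = inspect.signature(fun)
    bound = sig.bind(*args, **kwargs)  # uses the original args/kwargs, no dummies
    params = list(sig.parameters)
    # Static parameters stay on the signature but carry no traced values.
    static = {params[i] for i in static_argnums} | set(static_argnames)
    return tuple(name for name in bound.arguments if name not in static)

def f(x, n, y, *, mode="fast"):
    return x

names = dynamic_arg_names(f, (1.0, 3, 2.0), {"mode": "slow"},
                          static_argnums=(1,), static_argnames=("mode",))
```

Because the names are computed while the original arguments are still in hand, nothing has to be reconstructed from a pytree definition later.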

More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
  which in addition to the more precise name has fields like
  `arg_names: Tuple[Optional[str], ...]` and
  `result_paths: Tuple[Optional[str], ...]`, rather than
  `in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
  actual debug info more eagerly than before and we don't need pytrees for
  dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
  debug info about inputs which we have available at tracing time; in a
  follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
  delete `partial_eval.DebugInfo` and its corresponding helper methods (not
  done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
  partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
  partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
  `core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
  elements from the `arg_names` field), maintaining now-checked invariants like
  a Jaxpr's `debug_info` should have the same number of argument names as the
  jaxpr has invars (the jaxpr-processing functions updated here are enough for
  top-level jit jaxprs to have debug info attached, handling the original
  intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
  be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
  debug info on their outputs);
* adding some tests for static_argnums/static_argnames.
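
As a rough sketch of the invariant mentioned above (one argument name per jaxpr input), assuming a stand-in class rather than the real `core.JaxprDebugInfo`: the field names follow the commit message, while `DebugInfoSketch` and `prune_args` are hypothetical.

```python
from typing import NamedTuple, Optional, Sequence, Tuple

class DebugInfoSketch(NamedTuple):
    # Hypothetical stand-in; only the two fields named in the message.
    arg_names: Tuple[Optional[str], ...]
    result_paths: Tuple[Optional[str], ...]

def prune_args(dbg: DebugInfoSketch, used_inputs: Sequence[bool]) -> DebugInfoSketch:
    """Keep only the names of inputs that survive, as a DCE-style pass would."""
    # Invariant: exactly one name per input before pruning.
    assert len(dbg.arg_names) == len(used_inputs)
    kept = tuple(n for n, used in zip(dbg.arg_names, used_inputs) if used)
    return dbg._replace(arg_names=kept)

dbg = DebugInfoSketch(arg_names=("x", "y", None), result_paths=("['out']",))
pruned = prune_args(dbg, [True, False, True])
```

After pruning, the number of names again matches the number of remaining inputs, which is the property the checked invariant protects.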

Phew! Can't wait to land those follow-ups too :P
2023-03-20 11:50:30 -07:00
Peter Hawkins
a0121d9b9b Improve pytype inference for Sharding type.
* Define use_cpp_class and use_cpp_method decorators as no-ops for type checking.
* Remove the use of abc.ABC when defining the Sharding type. This triggers a pytype bug: the easiest fix seems to be to skip the use of the ABC.
* Write use_cpp_class decorator differently on ArrayImpl to work around pytype bug.
* Fix a few new type errors.

PiperOrigin-RevId: 516631428
2023-03-14 14:20:17 -07:00
Parker Schuh
d21c78a54b [Rollforward] Move PyBuffer methods used by PyArray to c++.
```
  def delete(self): ...
  def unsafe_buffer_pointer(self) -> Any: ...
  def clone(self) -> ArrayImpl: ...
  def _copy_single_device_array_to_host_async(self): ...
  def _single_device_array_to_np_array(self) -> np.ndarray: ...
  def on_device_size_in_bytes(self) -> int: ...
```

PiperOrigin-RevId: 516372847
2023-03-13 17:59:17 -07:00
Matthew Johnson
a6d3ae1446 use Partial to make ravel_pytree unflatteners jit-friendly
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-03-13 11:06:56 -07:00
Yash Katariya
96da1c4b71 [Rollback] Move PyBuffer methods used by PyArray to c++.
```
  def delete(self): ...
  def unsafe_buffer_pointer(self) -> Any: ...
  def clone(self) -> ArrayImpl: ...
  def _copy_single_device_array_to_host_async(self): ...
  def _single_device_array_to_np_array(self) -> np.ndarray: ...
  def on_device_size_in_bytes(self) -> int: ...
```

PiperOrigin-RevId: 515914842
2023-03-11 13:28:03 -08:00
Parker Schuh
e317a1ef06 Move PyBuffer methods used by PyArray to c++.
```
  def delete(self): ...
  def unsafe_buffer_pointer(self) -> Any: ...
  def clone(self) -> ArrayImpl: ...
  def _copy_single_device_array_to_host_async(self): ...
  def _single_device_array_to_np_array(self) -> np.ndarray: ...
  def on_device_size_in_bytes(self) -> int: ...
```

PiperOrigin-RevId: 515769831
2023-03-10 16:41:58 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
Peter Hawkins
bd2500579a Change definition of util.wraps so pytype can understand it.
@curry is opaque to pytype.

Fix a false positive type error that turns up because pytype doesn't really understand that a functools.partial is a kind of Callable.

PiperOrigin-RevId: 513697380
2023-03-02 18:41:52 -08:00
Matthew Johnson
c2aa5c5eed attach debug info to jaxpr, pass to mlir/mhlo
Co-authored-by: Peter Hawkins <phawkins@google.com>
2023-03-02 17:23:58 -08:00
Peter Hawkins
a002643a4a Fix stale reference to util.prod.
Work around pytype bug. It seems that the line
from functools import cached_property
causes pytype to give up on the entire module. Avoid the member import to fix the type inference.

PiperOrigin-RevId: 513544106
2023-03-02 08:24:30 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.
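
For reference, a small example of the stdlib call (the `shape` value is illustrative only):

```python
import math

shape = (2, 3, 4)
num_elements = math.prod(shape)  # product of all dimensions
empty_product = math.prod(())    # an empty iterable yields the start value, 1
```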

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
148774587a Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().

PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
Parker Schuh
f888e4814c [Rollforward] Convert _arrays to return PyArray instead of PyBuffer.
This change also converts all callsites that construct buffers to
return PyArrays.

PiperOrigin-RevId: 510486273
2023-02-17 11:52:43 -08:00
Yash Katariya
9316188b3a [Rollback] Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508827908
2023-02-10 21:36:56 -08:00
Parker Schuh
568a93bcd1 Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508769390
2023-02-10 15:32:57 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Jake VanderPlas
b679ef025f Remove unused CacheInfo namedtuple
2023-01-31 11:36:43 -08:00
Yash Katariya
c4d91d203c Remove local_imports of sharding.py. This adds local imports of pxla, but cleaning those up later will be easy, since they will be the only ones left and are restricted to sharding.py.
Also remove `maybe_cached_property` from this CL since we are dropping 3.7 support

PiperOrigin-RevId: 491769101
2022-11-29 16:42:03 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Jake VanderPlas
5d15757741 [typing] annotate jax._src.util.safe_map
2022-10-20 10:15:04 -07:00
Jake VanderPlas
524745f322 TMP: annotate util.safe_zip
2022-10-19 10:29:53 -07:00
Jake VanderPlas
d60ceeadd0 [typing] annotate util.unzip2 & util.unzip3
2022-10-18 09:47:49 -07:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. JAX code currently uses the latter; switching to Python builtin logging is advisable for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX now takes place at the module level. Some Python namespaces within JAX already used module-scoped logging via absl.vlog; the following idiom was adopted to provide the same functionality with Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
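
A hedged illustration of what this means for downstream users: a handler attached to the "jax" logger receives records from any descendant logger via propagation, without touching the root logger. The child logger name below is hypothetical.

```python
import io
import logging

stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(name)s:%(levelname)s:%(message)s"))

# Configure only the top-level "jax" logger; the root logger is untouched.
jax_logger = logging.getLogger("jax")
jax_logger.addHandler(handler)
jax_logger.setLevel(logging.DEBUG)

# Hypothetical module-level logger, as created by logging.getLogger(__name__).
child = logging.getLogger("jax._src.example_module")
child.debug("compiling")

captured = stream.getvalue()
```

The child logger has no handler or level of its own; it inherits the effective level from "jax" and its records propagate upward to the handler attached there.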

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Parker Schuh
9b3dfb66fa Use c++ weakref LRU cache implementation as a drop in replacement for jax's
weakref_lru_cache.

PiperOrigin-RevId: 468550018
2022-08-18 14:36:08 -07:00
Jake VanderPlas
98fac62897 remove dead code: jax._src.util.taggedtuple
2022-07-25 15:14:25 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility
2022-07-09 10:03:40 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working
2022-06-30 14:30:22 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

2022-05-17 22:14:05 +01:00
Jake VanderPlas
72470dee3a Comment on implementation of unzip2 & unzip3
2022-04-14 13:41:05 -07:00
Parker Schuh
df1c478ec5 Fix race condition for weakref destructor by catching rare exceptions.
2022-04-01 12:04:36 -07:00
Parker Schuh
c1bb767959 Update util.py
2022-03-23 12:26:09 -07:00
Parker Schuh
d0e0da02a1 Add weakref_lru_cache to prevent caches from pinning jaxprs.
To use this cache, the first argument must be some type that is
object identity hashed (like a jaxpr).
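
A minimal sketch of the idea, built only from the stdlib and not taken from this commit: cache results per first argument in a `weakref.WeakKeyDictionary`, so the cache itself never keeps that object alive. The names `weakref_cache_sketch`, `Jaxpr`, and `count_eqns` are hypothetical, and cached results that themselves reference the object would still pin it.

```python
import functools
import weakref

def weakref_cache_sketch(fn, maxsize=128):
    """Cache fn(obj, *args) per `obj` without keeping `obj` alive."""
    per_object = weakref.WeakKeyDictionary()  # entry vanishes when obj is collected

    def wrapper(obj, *args):
        cached = per_object.get(obj)
        if cached is None:
            ref = weakref.ref(obj)  # close over a weak reference, never obj itself
            cached = functools.lru_cache(maxsize=maxsize)(
                lambda *a: fn(ref(), *a))
            per_object[obj] = cached
        return cached(*args)

    return wrapper

class Jaxpr:  # hypothetical identity-hashed object
    pass

calls = []

def count_eqns(jaxpr, scale):
    calls.append(scale)
    return scale * 2

cached_count = weakref_cache_sketch(count_eqns)
j = Jaxpr()
first = cached_count(j, 3)
second = cached_count(j, 3)  # served from the per-object LRU cache
```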
2022-03-21 10:56:44 -07:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering
2022-02-23 09:59:09 -08:00
Matthew Johnson
7077ce2e68 [remove units] make JaxprTrace.process_call not introduce units
2022-02-12 13:48:12 -08:00
Matthew Johnson
e321964245 de-duplicate util.memoize and util.cache
The only difference between the two was that
jax.config.jax_check_tracer_leaks disables the caching under util.cache
but not under util.memoize.

We could add that as an option on the same function if it turns out to
be important, but it seems unnecessary. Moreover there are only two
callers (in dtypes.py and in batching.py).
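
The difference described above can be sketched as a cache decorator that is bypassed while a flag is set. This is an illustration under assumptions, not the util.py code: the `flags` dict stands in for the real config option, and `cache` here is a hypothetical reimplementation.

```python
import functools

# Hypothetical stand-in for jax.config.jax_check_tracer_leaks.
flags = {"check_tracer_leaks": False}

def cache(maxsize=4096):
    def wrap(f):
        cached = functools.lru_cache(maxsize)(f)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if flags["check_tracer_leaks"]:
                return f(*args, **kwargs)  # bypass the cache entirely
            return cached(*args, **kwargs)

        return wrapper
    return wrap

calls = []

@cache()
def square(x):
    calls.append(x)
    return x * x

square(3)
square(3)                      # second call hits the cache
calls_while_cached = list(calls)

flags["check_tracer_leaks"] = True
square(3)                      # flag set: the function runs again
```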

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-01-10 14:28:28 -08:00
Peter Hawkins
52fe821719 Merge xla._partition_outputs and util.unflatten.
PiperOrigin-RevId: 412117736
2021-11-24 12:52:40 -08:00
Jake VanderPlas
f6e3f1b4ad Cleanup: remove duplicate canonicalize_axis utility
2021-11-23 16:54:02 -08:00
Peter Hawkins
0f56838435 Fix a number of bugs in MLIR translation rules.
These bugs were found by running the existing tests with MLIR translations enabled, so no new tests are needed:

* Fix bug where we failed to propagate the symbol table to inner computations. This could lead to duplicate function names.
* Remove support for tupling arguments. It turns out that the MHLO->HLO conversion, which was the intended user, does not accept tupled arguments in the input MHLO. Instead, arguments are tupled if requested by a flag to the converter.
* Add a generic fallback to translate via the XLA HLO to MHLO if there is no MHLO-specific translation rule.
* If we are padding in select_and_scatter_add, we also need to slice the output.
* create_token may take arguments (which should be ignored).
* Fixed a number of misunderstandings of the mhlo.infeed contract.
* Untuple results in the fallback path iff the primitive is marked as having multiple results, not depending on the actual arity.
* Change xla.primitive_subcomputation not to filter token arguments, which is appropriate for a subcomputation.

PiperOrigin-RevId: 410519678
2021-11-17 07:20:56 -08:00
Peter Hawkins
70b8a6a806 Add a prototype IREE backend for JAX.
This is to support experimentation with the combination of JAX/IREE. Many things do not work yet.

PiperOrigin-RevId: 409980064
2021-11-15 07:57:04 -08:00
Peter Hawkins
8f6e077d9a Adds an initial prototype of an alternate JAX compilation path that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO.
This lowering is missing a number of features, but it is complete enough that many tests pass, and that I would like to start checking it in.

PiperOrigin-RevId: 409134016
2021-11-11 06:37:12 -08:00
Peter Hawkins
42e0d4e5f5 Remove jax._src.util.partialmethod.
Use functools.partialmethod instead, which has existed since Python 3.4. The JAX partialmethod doesn't work correctly in Python 3.10.
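
For readers unfamiliar with the stdlib replacement, a short example (the `Reducer` class is purely illustrative):

```python
import functools

class Reducer:
    def reduce(self, values, op):
        return op(values)

    # partialmethod pre-binds `op` while still receiving `self` at call time.
    total = functools.partialmethod(reduce, op=sum)
    largest = functools.partialmethod(reduce, op=max)

r = Reducer()
total = r.total([1, 2, 3])
largest = r.largest([1, 2, 3])
```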

Issue #8097
2021-10-05 12:12:41 -04:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended.
* Delete argnames_partial, which appears unused.
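
The "check for hashability early" option can be sketched as follows. This is an assumption-laden illustration, not the code in this commit; `check_hashable` and its error message are hypothetical.

```python
def check_hashable(value):
    """Raise early if a static argument cannot be used as a cache key."""
    try:
        hash(value)
    except TypeError as e:
        raise ValueError(
            f"Non-hashable static argument of type {type(value).__name__}: {value!r}"
        ) from e
    return value

ok = check_hashable((1, 2))  # tuples of ints are hashable

try:
    check_hashable([1, 2])   # lists are not
    rejected = False
except ValueError:
    rejected = True
```

Failing at the call site gives a clearer error than a `TypeError` raised later from inside a cache lookup.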
2021-09-30 15:50:07 -04:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial.
2021-09-13 21:09:58 -04:00
Matthew Johnson
2d28951ba4 address comments from @apaszke
2021-08-26 14:10:58 -07:00
Matthew Johnson
542641ca87 rejames/reblake implementation
2021-08-25 20:46:32 -07:00
Peter Hawkins
e709a2ea4d Delete tuple_replace.
It is unused.
2021-05-19 15:29:44 -04:00
Peter Hawkins
5261b776d2 Handle context manager configuration settings for matmul precision and numpy rank promotion correctly in JIT and linear_util caches.
PiperOrigin-RevId: 369643419
2021-04-21 06:36:35 -07:00
Skye Wanderman-Milne
1614572eb9 Add optional distributed debugging logging.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.

This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to facilitate
grepping for this info.

Example output:

```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
  process_index: 0
  device_count: 8
  local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
  python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
  devices: None
  abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
  python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
  mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
      dtype=object), ('x',))
  abstract args: []
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
  python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
  mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
       [TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
      dtype=object), ('x', 'y'))
  abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
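
Since the markers exist to make this output easy to grep, here is a small stdlib sketch for pulling the blocks out of captured log text. It assumes the markers appear alone on their own lines, as in the example output above; real log lines may carry prefixes that would need stripping first. `extract_debug_blocks` is a hypothetical helper.

```python
import re

def extract_debug_blocks(log_text):
    """Return the text between each BEGIN/END marker pair."""
    pattern = re.compile(
        r"DISTRIBUTED_DEBUG_BEGIN\n(.*?)\nDISTRIBUTED_DEBUG_END", re.DOTALL)
    return pattern.findall(log_text)

sample = """other output
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
  process_index: 0
DISTRIBUTED_DEBUG_END
more output
DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
DISTRIBUTED_DEBUG_END
"""

blocks = extract_debug_blocks(sample)
```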
2021-04-20 13:34:45 -07:00