This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc.). That unification will come in follow-up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnums). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Because some user pytree node
types are not fully polymorphic in their element types (e.g. their __init__
methods sometimes contain assertions about their elements' shapes, expecting
them to be arrays), the new mechanism is fundamentally more compatible with
custom pytree node types (a sketch of this args-in-hand approach follows
this list).
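To make the second and third points concrete, here is a hedged sketch of
extracting argument names while the original args/kwargs are still in hand.
The helper name and exact behavior are illustrative only; the actual change
also records per-leaf pytree paths, not just top-level parameter names.
```py
import inspect
from typing import Optional, Tuple

# Illustrative helper, not the actual JAX implementation: bind the original
# args/kwargs against the callable's signature and drop static arguments so
# the remaining names line up with the traced (non-static) arguments.
def _arg_names_for_debug_info(fun, args, kwargs, static_argnums=(),
                              static_argnames=()) -> Optional[Tuple[str, ...]]:
  try:
    ba = inspect.signature(fun).bind(*args, **kwargs)
  except (ValueError, TypeError):
    return None  # give up on debug info rather than raising
  names = []
  for i, name in enumerate(ba.arguments):
    if i in static_argnums or name in static_argnames:
      continue  # static args stay on the signature but not in the pytrees
    names.append(name)
  return tuple(names)
```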
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening (a minimal sketch of such a class follows this list);
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants,
e.g. that a Jaxpr's `debug_info` has the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* adding some tests for static_argnums/static_argnames.
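Here is a rough illustration of the shape of the new class and the invariant
mentioned above; the field names come from the description, but everything
else (including the pruning helper) is a hypothetical sketch, not the actual
implementation:
```py
from typing import NamedTuple, Optional, Tuple

# Hedged sketch of a JaxprDebugInfo-like record: arg_names and result_paths
# are stored eagerly as strings instead of keeping an input PyTreeDef around
# for later dummy-unflattening.
class JaxprDebugInfo(NamedTuple):
  traced_for: str                            # e.g. 'jit' or 'scan'
  func_src_info: str                         # e.g. 'f at example.py:10'
  arg_names: Tuple[Optional[str], ...]       # e.g. ('x', 'y[0]', 'y[1]')
  result_paths: Tuple[Optional[str], ...]    # e.g. ("['a']", "['b']")

# Hypothetical illustration of the invariant maintained by jaxpr-processing
# functions such as dce_jaxpr: when invars are pruned, prune arg_names too,
# so len(arg_names) always matches the number of remaining invars.
def prune_arg_names(dbg: JaxprDebugInfo,
                    used: Tuple[bool, ...]) -> JaxprDebugInfo:
  new_names = tuple(n for n, keep in zip(dbg.arg_names, used) if keep)
  return dbg._replace(arg_names=new_names)
```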
Phew! Can't wait to land those follow-ups too :P
* Define use_cpp_class and use_cpp_method decorators as no-ops for type checking (a sketch follows this list).
* Remove the use of abc.ABC when defining the Sharding type. This triggers a pytype bug: the easiest fix seems to be to skip the use of the ABC.
* Write use_cpp_class decorator differently on ArrayImpl to work around pytype bug.
* Fix a few new type errors.
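A minimal sketch of the "no-op for type checking" idea; the decorator names
come from the bullet above, but their exact signatures here are assumptions:
```py
from typing import TYPE_CHECKING

if TYPE_CHECKING:
  # Under type checking, pretend the C++ acceleration does not exist: the
  # decorators simply return the Python class/method unchanged, so pytype
  # analyzes the plain Python definitions.
  def use_cpp_class(cpp_cls):
    def wrapper(cls):
      return cls
    return wrapper

  def use_cpp_method(f):
    return f
```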
PiperOrigin-RevId: 516631428
@curry is opaque to pytype.
Fix a false positive type error that turns up because pytype doesn't really understand that a functools.partial is a kind of Callable.
PiperOrigin-RevId: 513697380
Work around a pytype bug. It seems that the line
`from functools import cached_property`
causes pytype to give up on the entire module. Avoid the member import to fix the type inference.
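Concretely, the workaround amounts to importing the module rather than the
member; the class below is made up purely for illustration:
```py
import functools  # instead of `from functools import cached_property`

class Example:
  @functools.cached_property  # qualified access sidesteps the pytype issue
  def value(self):
    return 42  # placeholder for an expensive computation
```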
PiperOrigin-RevId: 513544106
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging, which JAX currently uses throughout. The change is advisable for the following reasons (among others):
- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.
Logging in JAX now takes place at the module level. Some Python namespaces within JAX already used module-scoped logging via absl.vlog; the following idiom provides the same functionality with Python builtin logging:
```py
import logging
logger = logging.getLogger(__name__)
logger.debug(...)
logger.info(...)
```
The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
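For instance, a downstream user could attach their own handler to the top-level "jax" logger and rely on propagation from the module-level loggers; this is standard library behavior, shown here only as an illustration:
```py
import logging

# All jax.* module loggers propagate to the "jax" logger, so a single
# handler attached there captures JAX's internal logs.
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(name)s %(levelname)s: %(message)s"))

jax_logger = logging.getLogger("jax")
jax_logger.addHandler(handler)
jax_logger.setLevel(logging.DEBUG)
```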
The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
The only difference between util.cache and util.memoize was that
jax.config.jax_check_tracer_leaks disables the caching under util.cache
but not under util.memoize.
We could add that as an option on the same function if it turns out to
be important, but it seems unnecessary. Moreover, there are only two
callers (in dtypes.py and in batching.py).
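For reference, a hedged sketch of the kind of unification this describes; the
decorator below is illustrative, not the actual util.cache, and it assumes the
flag is readable as config.jax_check_tracer_leaks (as mentioned above):
```py
import functools
from jax import config

# Illustrative unified cache, keeping util.cache's behavior: the cache is
# bypassed while jax_check_tracer_leaks is enabled.
def cache(f):
  cached_f = functools.lru_cache(maxsize=None)(f)

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    if config.jax_check_tracer_leaks:
      return f(*args, **kwargs)  # no caching under tracer-leak checking
    return cached_f(*args, **kwargs)
  return wrapper
```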
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
These bugs were found by running the existing tests with MLIR translations enabled, so no new tests are needed:
* Fix bug where we failed to propagate the symbol table to inner computations. This could lead to duplicate function names.
* Remove support for tupling arguments. It turns out that the MHLO->HLO conversion, which was the intended user, does not accept tupled arguments in the input MHLO. Instead, arguments are tupled if requested by a flag to the converter.
* Add a generic fallback to translate via the XLA HLO to MHLO if there is no MHLO-specific translation rule.
* If we are padding in select_and_scatter_add, we also need to slice the output.
* create_token may take arguments (which should be ignored).
* Fix a number of misunderstandings of the mhlo.infeed contract.
* Untuple results in the fallback path iff the primitive is marked as having multiple results, rather than depending on the actual arity.
* Change xla.primitive_subcomputation not to filter token arguments, which is appropriate for a subcomputation.
PiperOrigin-RevId: 410519678
This lowering is missing a number of features, but it is complete enough that many tests pass, and I would like to start checking it in.
PiperOrigin-RevId: 409134016
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that we should either check for hashability early or make it clear that hashing the value is not something we intended (a sketch of such a wrapper follows this list).
* Delete argnames_partial, which appears unused.
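A minimal sketch of such an explicitly unhashable wrapper; the class name is hypothetical:
```py
# Hypothetical wrapper: hashing it raises immediately, so a value that was
# never meant to serve as a cache key cannot do so silently.
class Unhashable:
  __slots__ = ("val",)

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    raise TypeError(f"cannot hash wrapped value: {self.val!r}")
```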
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.
This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to facilitate
grepping for this info.
Example output:
```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
process_index: 0
device_count: 8
local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
devices: None
abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
dtype=object), ('x',))
abstract args: []
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
[TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
dtype=object), ('x', 'y'))
abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```