We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.
This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to faciliate
grepping for this info.
Example output:
```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
process_index: 0
device_count: 8
local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
devices: None
abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
dtype=object), ('x',))
abstract args: []
DISTRIBUTED_DEBUG_END
DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
[TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
dtype=object), ('x', 'y'))
abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
Following NEP 29, we don't need to support NumPy older than 1.17 these days (although in practice we also support 1.16).
Remove accidental export jnp.numpy_version.
No substantial behavior change right now, but the ability to use
subclasses of BatchTrace comes in handy when adding support for
nesting xmaps in the SPMD lowering.
PiperOrigin-RevId: 369445693
The `Mesh` class was missing `__eq__` and `__hash__` and inherited the
(bad) Python defaults of comparison and hashing by identity.
PiperOrigin-RevId: 369407380
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.
The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.
Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.
Examples:
```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})
OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))
OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
AD of code with named axes is still WIP, and pmap still doesn't take
proper care to handle them, so weaken the check for now.
PiperOrigin-RevId: 369265258