7722 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
8310867b41 Merge pull request #6515 from jakevdp:unique-complex
PiperOrigin-RevId: 369541171
2021-04-20 16:13:11 -07:00
jax authors
8321041cc4 Merge pull request #6475 from lgeiger:speedup-gh-actions
PiperOrigin-RevId: 369524098
2021-04-20 14:44:40 -07:00
jax authors
2c8b1f365d Merge pull request #6516 from zhangqiaorjc:treetest_fix
PiperOrigin-RevId: 369523337
2021-04-20 14:41:07 -07:00
jax authors
259559632c Merge pull request #6502 from skye:distributed_debug
PiperOrigin-RevId: 369522054
2021-04-20 14:34:48 -07:00
Qiao Zhang
99c142a0f3 Use assertRegex in TreeTest since pytest expects a different module prefix. 2021-04-20 14:10:30 -07:00
jax authors
93f26e1ed4 Merge pull request #6514 from hawkinsp:numpy2
PiperOrigin-RevId: 369514704
2021-04-20 13:59:19 -07:00
Jake VanderPlas
13824363b5 Make jnp.unique() support complex inputs 2021-04-20 13:57:30 -07:00
Skye Wanderman-Milne
1614572eb9 Add optional distributed debugging logging.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.

This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to faciliate
grepping for this info.

Example output:

```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
  process_index: 0
  device_count: 8
  local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
  python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
  devices: None
  abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
  python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
  mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
      dtype=object), ('x',))
  abstract args: []
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
  python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
  mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
       [TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
      dtype=object), ('x', 'y'))
  abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
2021-04-20 13:34:45 -07:00
jax authors
ba89c19759 Merge pull request #6266 from JoostvDoorn:padding-cropping-fft
PiperOrigin-RevId: 369497327
2021-04-20 12:35:37 -07:00
Peter Hawkins
4409e03c3b Remove workaround for NumPy < 1.14.
Following NEP 29, we don't need to support NumPy older than 1.17 these days (although in practice we also support 1.16).

Remove accidental export jnp.numpy_version.
2021-04-20 15:27:44 -04:00
jax authors
93b857806c Merge pull request #6512 from jakevdp:sparse-nonzero
PiperOrigin-RevId: 369494862
2021-04-20 12:24:23 -07:00
Adam Paszke
828e210601 Add a type checking rule for xmap
Also fix the type checking code in core, which incorrectly propagated output
avals of custom type checking rules.

PiperOrigin-RevId: 369485371
2021-04-20 11:43:33 -07:00
Adam Paszke
09a519828c Extend pjit error checking to rank errors too
Otherwise one gets an inscrutable `IndexError`.

PiperOrigin-RevId: 369485185
2021-04-20 11:39:58 -07:00
Jake VanderPlas
8808ddf991 Use new jit-compatible jnp.nonzero() in sparse ops 2021-04-20 11:29:54 -07:00
jax authors
fd66e3dcba Merge pull request #6511 from hawkinsp:numpy2
PiperOrigin-RevId: 369478592
2021-04-20 11:08:44 -07:00
jax authors
8d75b594f7 Merge pull request #6501 from jakevdp:nonzero-jit
PiperOrigin-RevId: 369473074
2021-04-20 10:45:48 -07:00
Peter Hawkins
122b2b2c71 Disable a test that fails with Numpy 1.17.5, but not Numpy 1.18.0. 2021-04-20 13:41:48 -04:00
jax authors
bbc7be064c Merge pull request #6239 from j-towns:lt-allow-integers
PiperOrigin-RevId: 369467931
2021-04-20 10:23:10 -07:00
Jake VanderPlas
8d17cce80e Add JIT-compatible version of jnp.nonzero 2021-04-20 09:18:49 -07:00
Adam Paszke
c09037bd14 Move vtile to batching.py, make it possible to add new BatchTraces
No substantial behavior change right now, but the ability to use
subclasses of BatchTrace comes in handy when adding support for
nesting xmaps in the SPMD lowering.

PiperOrigin-RevId: 369445693
2021-04-20 08:33:01 -07:00
Adam Paszke
93c63d0341 Fix cache misses when re-creating equivalent mesh objects
The `Mesh` class was missing `__eq__` and `__hash__` and inherited the
(bad) Python defaults of comparison and hashing by identity.

PiperOrigin-RevId: 369407380
2021-04-20 03:48:26 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Peter Hawkins
da1b819f26 Move contents of jax.custom_derivatives to jax._src.custom_derivatives.
PiperOrigin-RevId: 369340983
2021-04-19 17:51:49 -07:00
jax authors
be725472e4 Merge pull request #6500 from skye:v3-64-init
PiperOrigin-RevId: 369329513
2021-04-19 16:42:26 -07:00
Skye Wanderman-Milne
6722c14589 Add v3-64 config to automatic Cloud TPU pod slice initialization. 2021-04-19 15:49:40 -07:00
jax authors
83d9aa50e2 Merge pull request #6215 from skye:process_index
PiperOrigin-RevId: 369319728
2021-04-19 15:48:33 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Markus Kunesch
f030e70e82 xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.

The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.

Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.

Examples:

```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})

OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))

OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:36 -07:00
Jake VanderPlas
e152d6645d Fix cusparse headers
PiperOrigin-RevId: 369287633
2021-04-19 13:17:03 -07:00
Adam Paszke
d265fd5604 Ignore named shape when checking aval equality in AD
AD of code with named axes is still WIP, and pmap still doesn't take
proper care to handle them, so weaken the check for now.

PiperOrigin-RevId: 369265258
2021-04-19 11:35:34 -07:00
Adam Paszke
f285f0cc47 Fix pjit resource checks to verify the shapes against the local mesh.
PiperOrigin-RevId: 369263691
2021-04-19 11:28:47 -07:00
jax authors
290d749483 Merge pull request #6460 from LenaMartens:changelist/368626329
PiperOrigin-RevId: 369254004
2021-04-19 10:45:32 -07:00
Lena Martens
fa5e19b630 Fix Zero handling in select_jvp. 2021-04-19 17:03:07 +01:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
jax authors
0f96406130 Merge pull request #6461 from apaszke:xmap-awn
PiperOrigin-RevId: 369208554
2021-04-19 06:36:34 -07:00
Joost van Doorn
7091ae5af6 Add support for padding and cropping to fft 2021-04-17 08:38:24 +02:00
jax authors
6f0f717476 Merge pull request #6470 from google:issue6452
PiperOrigin-RevId: 368911086
2021-04-16 13:39:55 -07:00
Matthew Johnson
9d6263a743 support implicit broadcasting in transpose rules 2021-04-16 12:51:11 -07:00
Jake VanderPlas
919b11e81a Remove unnecessary dependency
PiperOrigin-RevId: 368882451
2021-04-16 11:10:33 -07:00
jax authors
c1ed3bd05f Merge pull request #6480 from LenaMartens:patch-3
PiperOrigin-RevId: 368856047
2021-04-16 08:56:17 -07:00
Lena Martens
fcf87cd7f2
Fix typo in NamedShape 2021-04-16 14:20:25 +01:00
Adam Paszke
c9b0b3122e Enable avals-with-names in xmap
Starting from this change, we start introducing xmapped names when
tracing the xmap jaxpr and eliminating them from avals when the values
are returned. This lets us enable two long-awaited checks:
1. Returning values that are mapped along more axes than `out_axes`
   declare now results in a readable error, instead of an internal
   vmap assertion.
2. We catch the resource-overlap error triggered by making two axes
   mapped to the same resources coincide in a single value.
2021-04-16 10:01:33 +00:00
jax authors
d9605c1627 Merge pull request #6477 from zhangqiaorjc:hb1
PiperOrigin-RevId: 368806831
2021-04-16 02:18:48 -07:00
Qiao Zhang
6247c53e9b Disable concurrent test_tap_multiple on CPU.
New TFRT CPU backend will have concurrent program execution which
can cause outfeed from different programs to interleave.
2021-04-15 21:02:38 -07:00
Lukas Geiger
78ea235ac7 CI: Install wheel pkg to improve pip-cache 2021-04-16 02:55:41 +02:00
George Necula
799ac413bd Re-enable test disabled due to LLVM integrate bug
PiperOrigin-RevId: 368744520
2021-04-15 16:51:15 -07:00
jax authors
0f1ddf6dc3 Merge pull request #6468 from jakevdp:fromdense
PiperOrigin-RevId: 368736649
2021-04-15 16:06:15 -07:00
jax authors
ee7158fcd6 Merge pull request #6472 from skye:lax_scipy_sparse_test
PiperOrigin-RevId: 368735746
2021-04-15 16:01:22 -07:00
Skye Wanderman-Milne
346df9c557 Disable lax_scipy_sparse_test.py cases that are hanging on GPU.
See #6471.
2021-04-15 15:54:15 -07:00