7741 Commits

Author SHA1 Message Date
Jake VanderPlas
bb543f2b5b jnp.unique: add support for axis argument 2021-04-21 16:00:14 -07:00
Jake VanderPlas
90d606fe25 Remove jax.experimental.doubledouble
PiperOrigin-RevId: 369740697
2021-04-21 14:52:14 -07:00
Jake VanderPlas
122fbcbd09 cusparse: use cstdint types
PiperOrigin-RevId: 369739828
2021-04-21 14:48:31 -07:00
jax authors
c25706ea81 Merge pull request #6531 from zhangqiaorjc:skiphb
PiperOrigin-RevId: 369705323
2021-04-21 12:06:36 -07:00
Qiao Zhang
4af95c7f7b Skip hostcall back tests that require concurrency. 2021-04-21 12:00:34 -07:00
Adam Paszke
a73fc71e0f Implement JVP for pjit
PiperOrigin-RevId: 369692330
2021-04-21 11:05:29 -07:00
jax authors
cae886b6ad Merge pull request #6529 from hawkinsp:docs
PiperOrigin-RevId: 369678418
2021-04-21 10:04:20 -07:00
Peter Hawkins
aafe8870ae Document that JAX follows the NEP-29 deprecation policy.
Remove the "experimental" disclaimer from the concurrency documentation.
2021-04-21 11:12:41 -04:00
jax authors
8248837cf2 Merge pull request #6528 from hawkinsp:xla
PiperOrigin-RevId: 369652169
2021-04-21 07:39:03 -07:00
Peter Hawkins
8654a9816f Update XLA. 2021-04-21 09:51:29 -04:00
Peter Hawkins
5261b776d2 Handle context manager configuration settings for matmul precision and numpy rank promotion correctly in JIT and linear_util caches.
PiperOrigin-RevId: 369643419
2021-04-21 06:36:35 -07:00
jax authors
0517675a12 Merge pull request #6526 from LenaMartens:changelist/369624561
PiperOrigin-RevId: 369633081
2021-04-21 05:06:45 -07:00
Lena Martens
deb2227f4a Make sure the out_axes in the HashableFunction closure are hashable.
By flattening them before putting them in the closure.
2021-04-21 12:32:19 +01:00
Adam Paszke
42d2e7620a Implement nesting of pjits
Without this change nesting works only when the inner `pjit`ed functions don't
close over any values.

PiperOrigin-RevId: 369626779
2021-04-21 04:09:58 -07:00
jax authors
a7f07601da Merge pull request #6522 from skye:process_index2
PiperOrigin-RevId: 369563230
2021-04-20 18:32:02 -07:00
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
6e099bf20b Merge pull request #6521 from skye:test_fixes
PiperOrigin-RevId: 369559987
2021-04-20 18:05:56 -07:00
Skye Wanderman-Milne
feb79e5698 Fix some Cloud TPU test failures.
The new select_and_gather_add logic was inspired by
3a35f7072a.
2021-04-21 00:37:02 +00:00
jax authors
cdff8a1e04 Merge pull request #6519 from zhangqiaorjc:disable_test
PiperOrigin-RevId: 369555579
2021-04-20 17:33:36 -07:00
jax authors
8310867b41 Merge pull request #6515 from jakevdp:unique-complex
PiperOrigin-RevId: 369541171
2021-04-20 16:13:11 -07:00
Qiao Zhang
2baed8da01 Disable HostCallbackIdTapTest.test_tap_multiple_barriers
for TPU too since it runs on CPU.
2021-04-20 16:02:48 -07:00
jax authors
8321041cc4 Merge pull request #6475 from lgeiger:speedup-gh-actions
PiperOrigin-RevId: 369524098
2021-04-20 14:44:40 -07:00
jax authors
2c8b1f365d Merge pull request #6516 from zhangqiaorjc:treetest_fix
PiperOrigin-RevId: 369523337
2021-04-20 14:41:07 -07:00
jax authors
259559632c Merge pull request #6502 from skye:distributed_debug
PiperOrigin-RevId: 369522054
2021-04-20 14:34:48 -07:00
Qiao Zhang
99c142a0f3 Use assertRegex in TreeTest since pytest expects a different module prefix. 2021-04-20 14:10:30 -07:00
jax authors
93f26e1ed4 Merge pull request #6514 from hawkinsp:numpy2
PiperOrigin-RevId: 369514704
2021-04-20 13:59:19 -07:00
Jake VanderPlas
13824363b5 Make jnp.unique() support complex inputs 2021-04-20 13:57:30 -07:00
Skye Wanderman-Milne
1614572eb9 Add optional distributed debugging logging.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.

This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to faciliate
grepping for this info.

Example output:

```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
  process_index: 0
  device_count: 8
  local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
  python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
  devices: None
  abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
  python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
  mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
      dtype=object), ('x',))
  abstract args: []
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
  python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
  mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
       [TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
      dtype=object), ('x', 'y'))
  abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
2021-04-20 13:34:45 -07:00
jax authors
ba89c19759 Merge pull request #6266 from JoostvDoorn:padding-cropping-fft
PiperOrigin-RevId: 369497327
2021-04-20 12:35:37 -07:00
Peter Hawkins
4409e03c3b Remove workaround for NumPy < 1.14.
Following NEP 29, we don't need to support NumPy older than 1.17 these days (although in practice we also support 1.16).

Remove accidental export jnp.numpy_version.
2021-04-20 15:27:44 -04:00
jax authors
93b857806c Merge pull request #6512 from jakevdp:sparse-nonzero
PiperOrigin-RevId: 369494862
2021-04-20 12:24:23 -07:00
Adam Paszke
828e210601 Add a type checking rule for xmap
Also fix the type checking code in core, which incorrectly propagated output
avals of custom type checking rules.

PiperOrigin-RevId: 369485371
2021-04-20 11:43:33 -07:00
Adam Paszke
09a519828c Extend pjit error checking to rank errors too
Otherwise one gets an inscrutable `IndexError`.

PiperOrigin-RevId: 369485185
2021-04-20 11:39:58 -07:00
Jake VanderPlas
8808ddf991 Use new jit-compatible jnp.nonzero() in sparse ops 2021-04-20 11:29:54 -07:00
jax authors
fd66e3dcba Merge pull request #6511 from hawkinsp:numpy2
PiperOrigin-RevId: 369478592
2021-04-20 11:08:44 -07:00
jax authors
8d75b594f7 Merge pull request #6501 from jakevdp:nonzero-jit
PiperOrigin-RevId: 369473074
2021-04-20 10:45:48 -07:00
Peter Hawkins
122b2b2c71 Disable a test that fails with Numpy 1.17.5, but not Numpy 1.18.0. 2021-04-20 13:41:48 -04:00
jax authors
bbc7be064c Merge pull request #6239 from j-towns:lt-allow-integers
PiperOrigin-RevId: 369467931
2021-04-20 10:23:10 -07:00
Jake VanderPlas
8d17cce80e Add JIT-compatible version of jnp.nonzero 2021-04-20 09:18:49 -07:00
Adam Paszke
c09037bd14 Move vtile to batching.py, make it possible to add new BatchTraces
No substantial behavior change right now, but the ability to use
subclasses of BatchTrace comes in handy when adding support for
nesting xmaps in the SPMD lowering.

PiperOrigin-RevId: 369445693
2021-04-20 08:33:01 -07:00
Adam Paszke
93c63d0341 Fix cache misses when re-creating equivalent mesh objects
The `Mesh` class was missing `__eq__` and `__hash__` and inherited the
(bad) Python defaults of comparison and hashing by identity.

PiperOrigin-RevId: 369407380
2021-04-20 03:48:26 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Peter Hawkins
da1b819f26 Move contents of jax.custom_derivatives to jax._src.custom_derivatives.
PiperOrigin-RevId: 369340983
2021-04-19 17:51:49 -07:00
jax authors
be725472e4 Merge pull request #6500 from skye:v3-64-init
PiperOrigin-RevId: 369329513
2021-04-19 16:42:26 -07:00
Skye Wanderman-Milne
6722c14589 Add v3-64 config to automatic Cloud TPU pod slice initialization. 2021-04-19 15:49:40 -07:00
jax authors
83d9aa50e2 Merge pull request #6215 from skye:process_index
PiperOrigin-RevId: 369319728
2021-04-19 15:48:33 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Markus Kunesch
f030e70e82 xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.

The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.

Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.

Examples:

```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})

OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))

OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:36 -07:00
Jake VanderPlas
e152d6645d Fix cusparse headers
PiperOrigin-RevId: 369287633
2021-04-19 13:17:03 -07:00
Adam Paszke
d265fd5604 Ignore named shape when checking aval equality in AD
AD of code with named axes is still WIP, and pmap still doesn't take
proper care to handle them, so weaken the check for now.

PiperOrigin-RevId: 369265258
2021-04-19 11:35:34 -07:00