36 Commits

Author SHA1 Message Date
Jake VanderPlas
98fac62897 remove dead code: jax._src.util.taggedtuple 2022-07-25 15:14:25 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Jake VanderPlas
72470dee3a Comment on implementation of unzip2 & unzip3 2022-04-14 13:41:05 -07:00
Parker Schuh
df1c478ec5 Fix race condition for weakref destructor by catching rare exceptions. 2022-04-01 12:04:36 -07:00
Parker Schuh
c1bb767959
Update util.py 2022-03-23 12:26:09 -07:00
Parker Schuh
d0e0da02a1 Add weakreaf_lru_cache to prevent caches from pinning jaxprs.
To use this cache, the first argument must be some type that is
object identity hashed (like a jaxpr).
2022-03-21 10:56:44 -07:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Matthew Johnson
7077ce2e68 [remove units] make JaxprTrace.process_call not introduce units 2022-02-12 13:48:12 -08:00
Matthew Johnson
e321964245 de-duplicate util.memoize and util.cache
The only difference between the two was that
jax.config.jax_check_tracer_leaks disables the caching under util.cache
but not under util.memoize.

We could add that as an option on the same function if it turns out to
be important, but it seems unnecessary. Moreover there are only two
callers (in dtypes.py and in batching.py).

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-01-10 14:28:28 -08:00
Peter Hawkins
52fe821719 Merge xla._partition_outputs and util.unflatten.
PiperOrigin-RevId: 412117736
2021-11-24 12:52:40 -08:00
Jake VanderPlas
f6e3f1b4ad Cleanup: remove duplicate canonicalize_axis utility 2021-11-23 16:54:02 -08:00
Peter Hawkins
0f56838435 Fix a number of bugs in MLIR translation rules.
These bugs were found by running the existing tests with MLIR translations enabled, so no new tests are needed:

* Fix bug where we failed to propagate the symbol table to inner computations. This could lead to duplicate function names.
* Remove support for tupling arguments. It turns out that the MHLO->HLO conversion, which was the intended user, does not accept tupled arguments in the input MHLO. Instead, arguments are tupled if requested by a flag to the converter.
* Add a generic fallback to translate via the XLA HLO to MHLO if there is no MHLO-specific translation rule.
* If we are padding in select_and_scatter_add, we also need to slice the output.
* create_token may take arguments (which should be ignored).
* Fixed a number of misunderstandings of the mhlo.infeed contract.
* Untuple results in the fallback path iff the primitive is marked as having multiple results, not depending on the actual arity.
* Change xla.primitive_subcomputation not to filter token arguments, which is appropriate for a subcomputation.

PiperOrigin-RevId: 410519678
2021-11-17 07:20:56 -08:00
Peter Hawkins
70b8a6a806 Add a prototype IREE backend for JAX.
This is to support experimentation with the combination of JAX/IREE. Many things do not work yet.

PiperOrigin-RevId: 409980064
2021-11-15 07:57:04 -08:00
Peter Hawkins
8f6e077d9a Adds an initial prototype of an alternate JAX compilation path that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO.
This lowering is missing a number of features, but it is complete enough that many tests pass, and that I would like to start checking it in.

PiperOrigin-RevId: 409134016
2021-11-11 06:37:12 -08:00
Peter Hawkins
42e0d4e5f5 Remove jax._src.util.partialmethod.
Use functools.partialmethod instead, which has existed since Python 3.4. The JAX partialmethod doesn't work correctly in Python 3.10.

Issue #8097
2021-10-05 12:12:41 -04:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Matthew Johnson
2d28951ba4 address comments form @apaszke 2021-08-26 14:10:58 -07:00
Matthew Johnson
542641ca87 rejames/reblake implementation 2021-08-25 20:46:32 -07:00
Peter Hawkins
e709a2ea4d Delete tuple_replace.
It is unused.
2021-05-19 15:29:44 -04:00
Peter Hawkins
5261b776d2 Handle context manager configuration settings for matmul precision and numpy rank promotion correctly in JIT and linear_util caches.
PiperOrigin-RevId: 369643419
2021-04-21 06:36:35 -07:00
Skye Wanderman-Milne
1614572eb9 Add optional distributed debugging logging.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.

This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to faciliate
grepping for this info.

Example output:

```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
  process_index: 0
  device_count: 8
  local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
  python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
  devices: None
  abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
  python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
  mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
      dtype=object), ('x',))
  abstract args: []
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
  python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
  mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
       [TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
      dtype=object), ('x', 'y'))
  abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
2021-04-20 13:34:45 -07:00
Stephan Hoyer
acb0be9cb7 Add _python_jit_with_static_argnames. 2021-03-31 10:02:16 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
jax authors
01485ec4ee Merge pull request #5673 from apaszke:pgather
PiperOrigin-RevId: 356492981
2021-02-09 07:03:01 -08:00
Adam Paszke
b19dd87581 Add a pgather primitive, making it possible to index into mapped axes 2021-02-09 10:44:31 +00:00
jax authors
1bc82f139e Merge pull request #5644 from apaszke:xmap-more-reductions
PiperOrigin-RevId: 356378902
2021-02-08 16:25:10 -08:00
Adam Paszke
6965e8bbe3 Add support for named axes in jnp.mean and jnp.std 2021-02-08 20:43:23 +00:00
Jake VanderPlas
c1d1d94bf7 util.wraps: update same attributes as functools.wraps. 2021-02-08 11:31:53 -08:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Jake VanderPlas
b5454613f7 Add experimental context manager to enable/disable X64 mode 2021-01-25 13:23:15 -08:00
Matthew Johnson
203af4517b revive the leak checker, as a debug mode
Co-authored-by: James Bradbury <jekbradbury@google.com>
2021-01-22 18:31:00 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00