48 Commits

Author SHA1 Message Date
Parker Schuh
568a93bcd1 Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508769390
2023-02-10 15:32:57 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Jake VanderPlas
b679ef025f Remove unused CacheInfo namedtuple 2023-01-31 11:36:43 -08:00
Yash Katariya
c4d91d203c Remove local_imports of sharding.py. Adding pxla local imports but then cleaning those up will be super easy since those will be the only ones left and restricted to sharding.py file only.
Also remove `maybe_cached_property` from this CL since we are dropping 3.7 support

PiperOrigin-RevId: 491769101
2022-11-29 16:42:03 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Jake VanderPlas
5d15757741 [typing] annotate jax._src.util.safe_map 2022-10-20 10:15:04 -07:00
Jake VanderPlas
524745f322 TMP: annotate util.safe_zip 2022-10-19 10:29:53 -07:00
Jake VanderPlas
d60ceeadd0 [typing] annotate util.unzip2 & util.unzip3 2022-10-18 09:47:49 -07:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

 The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Parker Schuh
9b3dfb66fa Use c++ weakref LRU cache implementation as a drop in replacement for jax's
weakref_lru_cache.

PiperOrigin-RevId: 468550018
2022-08-18 14:36:08 -07:00
Jake VanderPlas
98fac62897 remove dead code: jax._src.util.taggedtuple 2022-07-25 15:14:25 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Jake VanderPlas
72470dee3a Comment on implementation of unzip2 & unzip3 2022-04-14 13:41:05 -07:00
Parker Schuh
df1c478ec5 Fix race condition for weakref destructor by catching rare exceptions. 2022-04-01 12:04:36 -07:00
Parker Schuh
c1bb767959
Update util.py 2022-03-23 12:26:09 -07:00
Parker Schuh
d0e0da02a1 Add weakreaf_lru_cache to prevent caches from pinning jaxprs.
To use this cache, the first argument must be some type that is
object identity hashed (like a jaxpr).
2022-03-21 10:56:44 -07:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Matthew Johnson
7077ce2e68 [remove units] make JaxprTrace.process_call not introduce units 2022-02-12 13:48:12 -08:00
Matthew Johnson
e321964245 de-duplicate util.memoize and util.cache
The only difference between the two was that
jax.config.jax_check_tracer_leaks disables the caching under util.cache
but not under util.memoize.

We could add that as an option on the same function if it turns out to
be important, but it seems unnecessary. Moreover there are only two
callers (in dtypes.py and in batching.py).

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-01-10 14:28:28 -08:00
Peter Hawkins
52fe821719 Merge xla._partition_outputs and util.unflatten.
PiperOrigin-RevId: 412117736
2021-11-24 12:52:40 -08:00
Jake VanderPlas
f6e3f1b4ad Cleanup: remove duplicate canonicalize_axis utility 2021-11-23 16:54:02 -08:00
Peter Hawkins
0f56838435 Fix a number of bugs in MLIR translation rules.
These bugs were found by running the existing tests with MLIR translations enabled, so no new tests are needed:

* Fix bug where we failed to propagate the symbol table to inner computations. This could lead to duplicate function names.
* Remove support for tupling arguments. It turns out that the MHLO->HLO conversion, which was the intended user, does not accept tupled arguments in the input MHLO. Instead, arguments are tupled if requested by a flag to the converter.
* Add a generic fallback to translate via the XLA HLO to MHLO if there is no MHLO-specific translation rule.
* If we are padding in select_and_scatter_add, we also need to slice the output.
* create_token may take arguments (which should be ignored).
* Fixed a number of misunderstandings of the mhlo.infeed contract.
* Untuple results in the fallback path iff the primitive is marked as having multiple results, not depending on the actual arity.
* Change xla.primitive_subcomputation not to filter token arguments, which is appropriate for a subcomputation.

PiperOrigin-RevId: 410519678
2021-11-17 07:20:56 -08:00
Peter Hawkins
70b8a6a806 Add a prototype IREE backend for JAX.
This is to support experimentation with the combination of JAX/IREE. Many things do not work yet.

PiperOrigin-RevId: 409980064
2021-11-15 07:57:04 -08:00
Peter Hawkins
8f6e077d9a Adds an initial prototype of an alternate JAX compilation path that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO.
This lowering is missing a number of features, but it is complete enough that many tests pass, and that I would like to start checking it in.

PiperOrigin-RevId: 409134016
2021-11-11 06:37:12 -08:00
Peter Hawkins
42e0d4e5f5 Remove jax._src.util.partialmethod.
Use functools.partialmethod instead, which has existed since Python 3.4. The JAX partialmethod doesn't work correctly in Python 3.10.

Issue #8097
2021-10-05 12:12:41 -04:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Matthew Johnson
2d28951ba4 address comments form @apaszke 2021-08-26 14:10:58 -07:00
Matthew Johnson
542641ca87 rejames/reblake implementation 2021-08-25 20:46:32 -07:00
Peter Hawkins
e709a2ea4d Delete tuple_replace.
It is unused.
2021-05-19 15:29:44 -04:00
Peter Hawkins
5261b776d2 Handle context manager configuration settings for matmul precision and numpy rank promotion correctly in JIT and linear_util caches.
PiperOrigin-RevId: 369643419
2021-04-21 06:36:35 -07:00
Skye Wanderman-Milne
1614572eb9 Add optional distributed debugging logging.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.

This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to faciliate
grepping for this info.

Example output:

```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
  process_index: 0
  device_count: 8
  local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
  python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
  devices: None
  abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
  python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
  mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
      dtype=object), ('x',))
  abstract args: []
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
  python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
  mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
       [TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
      dtype=object), ('x', 'y'))
  abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
2021-04-20 13:34:45 -07:00
Stephan Hoyer
acb0be9cb7 Add _python_jit_with_static_argnames. 2021-03-31 10:02:16 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
jax authors
01485ec4ee Merge pull request #5673 from apaszke:pgather
PiperOrigin-RevId: 356492981
2021-02-09 07:03:01 -08:00
Adam Paszke
b19dd87581 Add a pgather primitive, making it possible to index into mapped axes 2021-02-09 10:44:31 +00:00
jax authors
1bc82f139e Merge pull request #5644 from apaszke:xmap-more-reductions
PiperOrigin-RevId: 356378902
2021-02-08 16:25:10 -08:00
Adam Paszke
6965e8bbe3 Add support for named axes in jnp.mean and jnp.std 2021-02-08 20:43:23 +00:00
Jake VanderPlas
c1d1d94bf7 util.wraps: update same attributes as functools.wraps. 2021-02-08 11:31:53 -08:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Jake VanderPlas
b5454613f7 Add experimental context manager to enable/disable X64 mode 2021-01-25 13:23:15 -08:00
Matthew Johnson
203af4517b revive the leak checker, as a debug mode
Co-authored-by: James Bradbury <jekbradbury@google.com>
2021-01-22 18:31:00 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00