3846 Commits

Author SHA1 Message Date
Lena Martens
2ba8aec274 Checkify: Fix docstring formatting and polish enabled_errors sets. 2022-02-10 16:54:43 +00:00
Lena Martens
1340fbbc09 Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 427744848
2022-02-10 07:39:50 -08:00
Peter Hawkins
74506c7dda Rollback of: Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; s...

PiperOrigin-RevId: 427576107
2022-02-09 14:44:45 -08:00
Hyeontaek Lim
b7e1fec250 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 427562278
2022-02-09 13:50:25 -08:00
Matthew Johnson
d9270b242d [checkify] add custom_vjp support 2022-02-09 12:31:16 -08:00
jax authors
b82ef91f42 Merge pull request #9509 from mattjj:checkify-custom-jvp
PiperOrigin-RevId: 427541020
2022-02-09 12:25:22 -08:00
Matthew Johnson
4ce749e681 [checkify] handle custom_jvp 2022-02-09 12:12:58 -08:00
Peter Hawkins
8ca6622c0b Change lax.select_p to be an n-ary predicate, 'lax.select_n_p'. Change lax.select() to be a thin shim around the new n-ary version.
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.

Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.

PiperOrigin-RevId: 427517899
2022-02-09 11:03:09 -08:00
jax authors
e219ea08f5 Merge pull request #9204 from LenaMartens:changelist/420726115
PiperOrigin-RevId: 427470746
2022-02-09 07:54:17 -08:00
Yash Katariya
0a060cf183 Catch the device buffer order with the expected global_mesh.local_devices order that we should get. This is to make sure that we don't get some cryptic message from XLA and catch it in GDA itself.
PiperOrigin-RevId: 427315736
2022-02-08 15:58:20 -08:00
jax authors
4e8043f2d1 Merge pull request #9461 from MichaelMarien:quantile-tuple-axis
PiperOrigin-RevId: 427313122
2022-02-08 15:49:54 -08:00
michaelmarien
3e9f8248f2 Expand implementation of lax_numpy._quantile to allow the input of a tuple as axis argument
* support and test edge case where axis argument is empty tuple ()
* replace swapaxis + reshape methodology by one call to lax.reshape for computational efficiency's sake
* add check on repeated axis and throw ValueError
* introduced and changed corresponding numpy code to swap and reshape axis to be quantiled
* introduced code to accomodate the reintroduction of those axes if keepdims=True
* added testcases
2022-02-08 21:03:02 +01:00
Lena Martens
0042edb5f4 Checkify: rename some symbols and add some docstrings. 2022-02-08 17:40:04 +00:00
jax authors
fbda1a650f Merge pull request #9489 from hawkinsp:sda
PiperOrigin-RevId: 427210907
2022-02-08 09:28:14 -08:00
Peter Hawkins
5679fedd2c Fix missing handler when lexically capturing a ShardedDeviceArray when MLIR enabled. 2022-02-08 09:51:57 -05:00
Nicholas Krämer
c07a0f1139 Add test and jet-primitive for dynamic_slice 2022-02-08 13:28:41 +01:00
Peter Hawkins
8be057de1f Introduce a new jax/jaxlib versioning scheme.
Adds a design note that describes the scheme and how the jax and jaxlib versions
are related.
2022-02-07 17:59:42 -05:00
Adam Paszke
42cd7ed1d0 Allow nesting MANUAL-style xmaps in pjits
PiperOrigin-RevId: 426955137
2022-02-07 10:35:06 -08:00
Jake VanderPlas
0b4b0fd07b tests: add unsigned ints to all_dtypes 2022-02-07 08:59:44 -08:00
Peter Hawkins
942581d136 Fix test failure in line search test due to scipy 1.8 release. 2022-02-07 14:49:54 +00:00
jax authors
89be6d29c5 Merge pull request #9395 from jakevdp:faster-numpy-test
PiperOrigin-RevId: 426850791
2022-02-07 01:29:21 -08:00
James Bradbury
5dd1c75969 Add batch_axis to variance scaling initializers
PiperOrigin-RevId: 426522731
2022-02-04 17:02:11 -08:00
Jake VanderPlas
bd18ded481 NumpyUfuncTests: speed up test generation 2022-02-04 15:22:16 -08:00
Adam Paszke
086a607d8c Add experimental support for SPMD lowering of xmap via MANUAL sharding annotations
Note that it's still limited and turns out to be a bit hard (partly due to
unclear XLA semantics at this point). Using constants that are not xmap inputs
is likely to cause SPMD partitioner errors and cross-replica collectives don't seem
to work either.

In any case, the next step will be to allow nesting those xmaps inside pjits.

PiperOrigin-RevId: 426447989
2022-02-04 11:17:39 -08:00
Skye Wanderman-Milne
3fd91d6038 Run compilation_cache_test.py using TFRT
PiperOrigin-RevId: 426433041
2022-02-04 10:17:59 -08:00
jax authors
d11ead5212 Merge pull request #9372 from mattjj:tree-prefix-utils
PiperOrigin-RevId: 426320883
2022-02-03 22:49:33 -08:00
jax authors
45d96c490e Merge pull request #4671 from romanngg:conv_local
PiperOrigin-RevId: 426282505
2022-02-03 18:03:33 -08:00
Matthew Johnson
e186aa3f1e add and test pytree utils for better errors 2022-02-03 17:04:38 -08:00
Yash Katariya
50d899964c Add a replica_id check so that the checkpoint inconsistency that was observed due to wrong replica_id calculation will not happen again. We will catch this at the GDA level in __init__ and raise an error before going to checkpointing.
In the pxla path, we will catch this in `_gda_array_result_handler` while compiling.

PiperOrigin-RevId: 426248066
2022-02-03 15:09:36 -08:00
jax authors
82c882ccef Merge pull request #9432 from jakevdp:bcoo-extract-fix
PiperOrigin-RevId: 426202221
2022-02-03 11:51:30 -08:00
Tianjian Lu
5a012d5e7b [JAX] Added jit-able singular value decomposition.
PiperOrigin-RevId: 426193395
2022-02-03 11:16:55 -08:00
Jake VanderPlas
69cecbf415 [sparse] fix bcoo_extract batching rule 2022-02-03 09:37:09 -08:00
jax authors
d04dce3fa2 Merge pull request #9417 from hawkinsp:fft2
PiperOrigin-RevId: 426163984
2022-02-03 09:24:13 -08:00
Peter Hawkins
042c9bd7a5 Ensure that tree_util.Partial's .func attribute is stable.
Fixes #9429.
2022-02-03 10:44:13 -05:00
Peter Hawkins
84bccb2420 Support string fft_type values in lax.fft. 2022-02-03 08:52:38 -05:00
jax authors
a0abe8e4ac [JAX] Move the ann recall computation to ann.py.
This function is very useful for our users to evaluate the ann results
against the standard ann datasets that provides the ground truth.

PiperOrigin-RevId: 425997236
2022-02-02 15:50:13 -08:00
Yash Katariya
4e47de66fc Add the cache back now that Mesh's __hash__ is also being hashed on self.devices.shape.
PiperOrigin-RevId: 425711067
2022-02-01 14:06:01 -08:00
jax authors
fe14530347 Merge pull request #9391 from jakevdp:fix-constant-handler
PiperOrigin-RevId: 425677978
2022-02-01 11:44:09 -08:00
jax authors
e3fe4a2c7c Merge pull request #9316 from mattjj:djax-now-5
PiperOrigin-RevId: 425627062
2022-02-01 08:13:09 -08:00
Matthew Johnson
d9dcd1394a djax: let make_jaxpr build dyn shape jaxprs 2022-02-01 00:10:21 -08:00
jax authors
d8f74f50dc Merge pull request #9386 from jakevdp:fix-pjit-test
PiperOrigin-RevId: 425477260
2022-01-31 15:37:16 -08:00
jax authors
c12ca7f64c [XLA:TPU] Add 'aggregate_to_topk' option to ann in jax
Also adds a pmap test for demonstrating multi-tpu ann.

PiperOrigin-RevId: 425451716
2022-01-31 13:46:07 -08:00
Tianjian Lu
97d55eb13c [JAX] Re-enables lowering bcoo dot general to cuSparse.
PiperOrigin-RevId: 425410511
2022-01-31 11:02:57 -08:00
Jake VanderPlas
37e73fce7f Add complex types to mlir constant handlers 2022-01-31 10:56:52 -08:00
Jake VanderPlas
163ec36ee0 pjit_test: set jax_numpy_rank_promotion=raise 2022-01-31 08:44:11 -08:00
Yash Katariya
f3ae2c0dfe Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 424968695
2022-01-28 15:16:55 -08:00
Tom Hennigan
ace8c0a53a Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 424888671
2022-01-28 09:35:54 -08:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
jax authors
344171a508 Merge pull request #9359 from jakevdp:re-enable-array-test
PiperOrigin-RevId: 424871630
2022-01-28 08:05:56 -08:00
jax authors
e6f9ba0a14 Merge pull request #9275 from froystig:auto-vmap
PiperOrigin-RevId: 424765479
2022-01-27 19:38:31 -08:00