10460 Commits

Author SHA1 Message Date
Yash Katariya
1ad3551ec9 Release jax and jaxlib 0.3.0 as per the new release process.
PiperOrigin-RevId: 427809845
jaxlib-v0.3.0 jax-v0.3.0
2022-02-10 11:59:13 -08:00
jax authors
4df01eee08 Merge pull request #9523 from google:yashk2810-patch-6
PiperOrigin-RevId: 427804472
2022-02-10 11:39:07 -08:00
Yash Katariya
ef11d53ff9
Update TF commit for release 2022-02-10 11:33:20 -08:00
Yash Katariya
d82bcc2a0c Add jax[_ci] option to account for the new release process.
PiperOrigin-RevId: 427802081
2022-02-10 11:29:51 -08:00
jax authors
0b135fbfe8 Merge pull request #9518 from LenaMartens:changelist/427732530
PiperOrigin-RevId: 427766853
2022-02-10 09:18:46 -08:00
Lena Martens
2ba8aec274 Checkify: Fix docstring formatting and polish enabled_errors sets. 2022-02-10 16:54:43 +00:00
Lena Martens
1340fbbc09 Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 427744848
2022-02-10 07:39:50 -08:00
jax authors
744636db71 Merge pull request #9514 from google:yashk2810-patch-5
PiperOrigin-RevId: 427605108
2022-02-09 16:50:00 -08:00
Yash Katariya
cabc98c047
update TF commit for release 2022-02-09 16:34:44 -08:00
Peter Hawkins
8af0d8d033 Add complex number DLPack support to JAX and TensorFlow.
Fixes https://github.com/google/jax/issues/9497

PiperOrigin-RevId: 427579098
2022-02-09 14:58:00 -08:00
Peter Hawkins
74506c7dda Rollback of: Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; s...

PiperOrigin-RevId: 427576107
2022-02-09 14:44:45 -08:00
jax authors
ca18fe1846 Merge pull request #9511 from hawkinsp:eqnprint
PiperOrigin-RevId: 427574021
2022-02-09 14:35:47 -08:00
Peter Hawkins
0c7687666a Replace core.pp_eqn_compact() with core.str_eqn_compact().
pp_eqn_compact() is used for one purpose only: creating metadata to put
on HLO. In that case, we don't need such carefully-formatted strings,
and speed is more important.

Gave a 6% speedup on a researcher's model.
2022-02-09 22:08:45 +00:00
Hyeontaek Lim
b7e1fec250 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 427562278
2022-02-09 13:50:25 -08:00
jax authors
74fb420130 Merge pull request #9510 from mattjj:checkify-custom-vjp
PiperOrigin-RevId: 427546471
2022-02-09 12:48:00 -08:00
Matthew Johnson
d9270b242d [checkify] add custom_vjp support 2022-02-09 12:31:16 -08:00
jax authors
b82ef91f42 Merge pull request #9509 from mattjj:checkify-custom-jvp
PiperOrigin-RevId: 427541020
2022-02-09 12:25:22 -08:00
jax authors
61b884b0d4 Merge pull request #9494 from superbobry:lax_numpy-any
PiperOrigin-RevId: 427539082
2022-02-09 12:17:59 -08:00
Matthew Johnson
4ce749e681 [checkify] handle custom_jvp 2022-02-09 12:12:58 -08:00
Yash Katariya
02b8ce3373 Use thread local positional semantics and change thread_resources to subclass from threading.local.
Both these fixes makes sure that you can compile pjit in multiple threads.

PiperOrigin-RevId: 427517953
2022-02-09 11:19:56 -08:00
Peter Hawkins
8ca6622c0b Change lax.select_p to be an n-ary predicate, 'lax.select_n_p'. Change lax.select() to be a thin shim around the new n-ary version.
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.

Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.

PiperOrigin-RevId: 427517899
2022-02-09 11:03:09 -08:00
jax authors
e219ea08f5 Merge pull request #9204 from LenaMartens:changelist/420726115
PiperOrigin-RevId: 427470746
2022-02-09 07:54:17 -08:00
Peter Hawkins
82d8261308 Speed up source location computation when lowering a jaxpr to HLO/MHLO.
Speed up source_info_util.user_frames by using a newly refactored Traceback.raw_frames() attribute. Since we are interested only in one frame, it's best to avoid doing wasted work on all the frames we are going to ignore.

Change traceback.raw_frames() to return the transpose of what it previously returned because it means we only need to build 3 Python objects, rather than n + 1 Python objects for n frames.

PiperOrigin-RevId: 427320674
2022-02-08 16:17:40 -08:00
Yash Katariya
0a060cf183 Catch the device buffer order with the expected global_mesh.local_devices order that we should get. This is to make sure that we don't get some cryptic message from XLA and catch it in GDA itself.
PiperOrigin-RevId: 427315736
2022-02-08 15:58:20 -08:00
jax authors
4e8043f2d1 Merge pull request #9461 from MichaelMarien:quantile-tuple-axis
PiperOrigin-RevId: 427313122
2022-02-08 15:49:54 -08:00
jax authors
e8ec9570dd Merge pull request #9471 from jakevdp:generic
PiperOrigin-RevId: 427310192
2022-02-08 15:39:58 -08:00
jax authors
bce332bf83 Merge pull request #9481 from jakevdp:jax-enable-checks
PiperOrigin-RevId: 427310155
2022-02-08 15:34:51 -08:00
Sergei Lebedev
0fe377ce42 Added an explicit Any return type to lax_numpy.ndarray methods
This change makes ndarray a bit easier for tooling to handle, since de-facto
all these methods are supposed to return *something*, but the type inferrable
from their default implementations is None.

As a hand-wavy aside, in a type stub

    def f(): ...

could be treated equivalently to

    def f() -> Any: ...

because there is no body to infer return type from, and Any is a reasonable
fallback type. In a .py file, however, f is no longer just a function *type*
(as opposed to function *implementation*), and thus it has an inferrable
return type.
2022-02-08 22:18:16 +00:00
lenamartens
4d0db5d975 Fix build and address suggestion 2022-02-08 20:57:51 +00:00
Lena Martens
b2cf12aa7e
Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-02-08 20:23:40 +00:00
michaelmarien
3e9f8248f2 Expand implementation of lax_numpy._quantile to allow the input of a tuple as axis argument
* support and test edge case where axis argument is empty tuple ()
* replace swapaxis + reshape methodology by one call to lax.reshape for computational efficiency's sake
* add check on repeated axis and throw ValueError
* introduced and changed corresponding numpy code to swap and reshape axis to be quantiled
* introduced code to accomodate the reintroduction of those axes if keepdims=True
* added testcases
2022-02-08 21:03:02 +01:00
Lena Martens
0042edb5f4 Checkify: rename some symbols and add some docstrings. 2022-02-08 17:40:04 +00:00
jax authors
fbda1a650f Merge pull request #9489 from hawkinsp:sda
PiperOrigin-RevId: 427210907
2022-02-08 09:28:14 -08:00
jax authors
c137aadad7 Merge pull request #9488 from pnkraemer:jet-primitive-dynamic-slice
PiperOrigin-RevId: 427210782
2022-02-08 09:23:27 -08:00
Peter Hawkins
5679fedd2c Fix missing handler when lexically capturing a ShardedDeviceArray when MLIR enabled. 2022-02-08 09:51:57 -05:00
Nicholas Krämer
c07a0f1139 Add test and jet-primitive for dynamic_slice 2022-02-08 13:28:41 +01:00
jax authors
44c6c055d3 Merge pull request #9442 from mattjj:zeros
PiperOrigin-RevId: 427078926
2022-02-07 19:17:57 -08:00
jax authors
3c1a1ca6f0 Merge pull request #9419 from hawkinsp:versioning
PiperOrigin-RevId: 427073382
2022-02-07 18:31:52 -08:00
Peter Hawkins
8be057de1f Introduce a new jax/jaxlib versioning scheme.
Adds a design note that describes the scheme and how the jax and jaxlib versions
are related.
2022-02-07 17:59:42 -05:00
Jake VanderPlas
760f309fb5 Add jax.numpy.generic 2022-02-07 14:56:39 -08:00
Jake VanderPlas
f2222bb1cf CI: error if docstring rewrite fails 2022-02-07 14:43:00 -08:00
Peter Hawkins
287c476eec Cache traceback to MLIR location conversion.
Finding the user frame in a traceback is something we do for every jaxpr equation, and it shows up in profiles. We expect a reasonable amount of locality, e.g., many lines of code with similar provenance appearing together, so this seems like a place for a small LRU cache.

PiperOrigin-RevId: 427020947
2022-02-07 14:40:46 -08:00
Peter Hawkins
f539c9b9bd Hoist construction of predicates out of cond batching rule.
Avoids building the "which path are we following" predicate once for each input.

PiperOrigin-RevId: 427012972
2022-02-07 14:13:03 -08:00
Peter Hawkins
e7032fe910 Remove unnecessary dtype canonicalization from jax.core.raise_to_shaped.
I noticed this in passing while working on https://github.com/google/jax/pull/9468. It seems strange to me that we would change the dtype when raising a ShapedArray to a ShapedArray, and indeed it seems not to be necessary.

PiperOrigin-RevId: 427011028
2022-02-07 14:06:13 -08:00
jax authors
5648f768ff Merge pull request #9477 from hawkinsp:doc2
PiperOrigin-RevId: 427002408
2022-02-07 13:34:25 -08:00
Peter Hawkins
465b593293 Update scipy intersphinx inventory for SciPy 1.8.0.
According to https://github.com/scipy/scipy/issues/14267 the SciPy docs seems to have moved.
2022-02-07 16:19:46 -05:00
Adam Paszke
296832e891 Use aval_out to construct a sharding spec in shard to full
The shard's dimensions might be too small and might trigger asserts, even though
the shape has no influence on sharding specs.

PiperOrigin-RevId: 426955706
2022-02-07 10:39:41 -08:00
Adam Paszke
42cd7ed1d0 Allow nesting MANUAL-style xmaps in pjits
PiperOrigin-RevId: 426955137
2022-02-07 10:35:06 -08:00
jax authors
1bc8ee1df7 Merge pull request #9441 from jakevdp:test-all-types
PiperOrigin-RevId: 426937874
2022-02-07 09:31:42 -08:00
Jake VanderPlas
0b4b0fd07b tests: add unsigned ints to all_dtypes 2022-02-07 08:59:44 -08:00