10550 Commits

Author SHA1 Message Date
Hyeontaek Lim
beaa00c460 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 428590081
2022-02-14 13:11:49 -08:00
michaelmarien
20e5090b61 Add a warning to random.choice to notify users of the ill-defined behaviour when requesting more samples than non-zero probabilities and replace=False 2022-02-14 21:41:30 +01:00
jax authors
f229a703e7 Merge pull request #9562 from jakevdp:disable-rank-promotion
PiperOrigin-RevId: 428579739
2022-02-14 12:27:22 -08:00
Parker Schuh
7ce911b8d1 Add translation rule for optimization barrier.
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using
`jax_remat_opt_barrier` config flag.
2022-02-14 12:21:16 -08:00
jax authors
a457320b4d Merge pull request #9563 from hawkinsp:selectn
PiperOrigin-RevId: 428570762
2022-02-14 11:53:03 -08:00
Peter Hawkins
29c8a04527 Fix incorrect binary search comparison in lax.select_n lowering.
Fixes issue in https://github.com/google/jax/discussions/9556#discussioncomment-2175113
2022-02-14 14:29:38 -05:00
jax authors
fb4934c23b Merge pull request #9559 from hawkinsp:token
PiperOrigin-RevId: 428533661
2022-02-14 09:44:13 -08:00
jax authors
8b117c500f Merge pull request #9557 from LenaMartens:changelist/428497052
PiperOrigin-RevId: 428533464
2022-02-14 09:38:53 -08:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
Peter Hawkins
5a259925a0 Add constant handler for tokens.
Fixes https://github.com/google/jax/issues/9438
2022-02-14 12:09:29 -05:00
jax authors
2c01312d09 Merge pull request #9558 from jblespiau:changelist/427986940
PiperOrigin-RevId: 428526701
2022-02-14 09:09:20 -08:00
Jean-Baptiste Lespiau
799ecfa920 Remove e Type annotation for jit an pmap as there are additional attributes on the returned callable.
Using the experimental jax.jit(lambda x: x+1).lower(...) is raising an error with pytype.
2022-02-14 16:39:51 +00:00
Lena Martens
a4cacf5729 Checkify: handle named_call in process_call.
named_call does not specify donated_invars, this change handles this missing
param case.

For future reference: we might want to add a call_param_updater registry to define
how call params need to get updated wrt checkify, like eg. partial_eval/ad does.
2022-02-14 15:24:55 +00:00
jax authors
fb821e94ad Merge pull request #9491 from LenaMartens:changelist/427247461
PiperOrigin-RevId: 428492579
2022-02-14 06:38:30 -08:00
Jake VanderPlas
4f6004a3c9 JaxTestCase now sets jax_numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 428489444
2022-02-14 06:20:42 -08:00
jax authors
7af443aac7 Merge pull request #9539 from jakevdp:fix-jax-array
PiperOrigin-RevId: 428487387
2022-02-14 06:09:16 -08:00
RuffaloVM
fc4f47cdae shape_poly_test.py : remove duplicate word 2022-02-13 23:47:52 +09:00
jax authors
51c7d3bbb5 Merge pull request #9500 from mattjj:remove-units-3
PiperOrigin-RevId: 428282592
2022-02-12 21:37:32 -08:00
Matthew Johnson
d59af33cb6 [remove units] make JaxprTrace.process_call not introduce units 2022-02-12 21:06:51 -08:00
jax authors
0566ea4ccd Merge pull request #9456 from mattjj:jaxpr-pprint-color-flag-and-default
PiperOrigin-RevId: 428247626
2022-02-12 14:49:02 -08:00
Matthew Johnson
004bb684ea add flag for jaxpr colorful syntax highlighting
set it on by default
2022-02-12 14:15:28 -08:00
jax authors
06c401226f Merge pull request #9498 from mattjj:remove-units-2
PiperOrigin-RevId: 428244153
2022-02-12 14:09:55 -08:00
Matthew Johnson
7077ce2e68 [remove units] make JaxprTrace.process_call not introduce units 2022-02-12 13:48:12 -08:00
Jake VanderPlas
c069bfeefd Respect __jax_array__ in jnp.ndarray operations 2022-02-11 12:44:55 -08:00
jax authors
2ae10ea7b8 Merge pull request #9469 from hawkinsp:doc
PiperOrigin-RevId: 428053772
2022-02-11 11:44:14 -08:00
Peter Hawkins
b9b73ee69a Recommend optax in jax.experimental_libraries.optimizers documentation. 2022-02-11 14:30:06 -05:00
jax authors
fd3540712d Merge pull request #9526 from jakevdp:fix-util
PiperOrigin-RevId: 428047726
2022-02-11 11:20:11 -08:00
jax authors
f02c6fcd72 Merge pull request #9475 from jakevdp:doc-extra-params
PiperOrigin-RevId: 427943650
2022-02-11 01:03:14 -08:00
Yash Katariya
2162868ed9 Update values after release
PiperOrigin-RevId: 427910510
2022-02-10 20:32:53 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
jax authors
8b4a7ce910 Merge pull request #9330 from jakevdp:rank-promotion-final
PiperOrigin-RevId: 427878821
2022-02-10 17:08:39 -08:00
Jake VanderPlas
22ff25bb8e DOC: add ability to document extra_params within _wraps 2022-02-10 16:54:57 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Yash Katariya
8df1932100 Add ndim and size to GDA
PiperOrigin-RevId: 427874829
2022-02-10 16:46:26 -08:00
Jake VanderPlas
3e5048956d tests: use lax.broadcast_shapes in place of custom logic 2022-02-10 13:19:49 -08:00
Peter Hawkins
2512aed4bd Include the module name in MHLO IR dumps, rather than always naming dumps "builtin.module".
PiperOrigin-RevId: 427817318
2022-02-10 12:28:48 -08:00
Yash Katariya
1ad3551ec9 Release jax and jaxlib 0.3.0 as per the new release process.
PiperOrigin-RevId: 427809845
jaxlib-v0.3.0 jax-v0.3.0
2022-02-10 11:59:13 -08:00
jax authors
4df01eee08 Merge pull request #9523 from google:yashk2810-patch-6
PiperOrigin-RevId: 427804472
2022-02-10 11:39:07 -08:00
Yash Katariya
ef11d53ff9
Update TF commit for release 2022-02-10 11:33:20 -08:00
Yash Katariya
d82bcc2a0c Add jax[_ci] option to account for the new release process.
PiperOrigin-RevId: 427802081
2022-02-10 11:29:51 -08:00
jax authors
0b135fbfe8 Merge pull request #9518 from LenaMartens:changelist/427732530
PiperOrigin-RevId: 427766853
2022-02-10 09:18:46 -08:00
Lena Martens
2ba8aec274 Checkify: Fix docstring formatting and polish enabled_errors sets. 2022-02-10 16:54:43 +00:00
Lena Martens
1340fbbc09 Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 427744848
2022-02-10 07:39:50 -08:00
jax authors
744636db71 Merge pull request #9514 from google:yashk2810-patch-5
PiperOrigin-RevId: 427605108
2022-02-09 16:50:00 -08:00
Yash Katariya
cabc98c047
update TF commit for release 2022-02-09 16:34:44 -08:00
Peter Hawkins
8af0d8d033 Add complex number DLPack support to JAX and TensorFlow.
Fixes https://github.com/google/jax/issues/9497

PiperOrigin-RevId: 427579098
2022-02-09 14:58:00 -08:00
Peter Hawkins
74506c7dda Rollback of: Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; s...

PiperOrigin-RevId: 427576107
2022-02-09 14:44:45 -08:00
jax authors
ca18fe1846 Merge pull request #9511 from hawkinsp:eqnprint
PiperOrigin-RevId: 427574021
2022-02-09 14:35:47 -08:00
Peter Hawkins
0c7687666a Replace core.pp_eqn_compact() with core.str_eqn_compact().
pp_eqn_compact() is used for one purpose only: creating metadata to put
on HLO. In that case, we don't need such carefully-formatted strings,
and speed is more important.

Gave a 6% speedup on a researcher's model.
2022-02-09 22:08:45 +00:00
Hyeontaek Lim
b7e1fec250 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 427562278
2022-02-09 13:50:25 -08:00