112 Commits

Author SHA1 Message Date
Yash Katariya
13c34f9dc5 Move with_sharding_constraint out of experimental into jax.lax namespace.
PiperOrigin-RevId: 494635809
2022-12-11 22:55:21 -08:00
Srinivas Vasudevan
5adfb08986 Add lax.cumlogsumexp for cumulative logsumexp operations.
PiperOrigin-RevId: 485158935
2022-10-31 15:08:52 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Jake VanderPlas
37e7c1f8fd Add copy_p to jax.lax namespace 2022-08-19 13:09:15 -07:00
Nicholas Junge
311e6a92f9 Add bitwise XOR reducer to lax.reduce
This commit adds handling for the `lax.bitwise_xor` operation to `lax.reduce`. It also includes a new standard reduce primitive, modeled after the existing `and`/ `or` reducer primitives.
2022-06-15 16:56:51 +02:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Roy Frostig
64572795b7 remove _select_and_{gather,scatter}_add from public jax.lax module 2022-03-10 10:43:42 -08:00
Roy Frostig
2f6de4a2df remove _reduce_window_{min,max,sum,prod} from public jax.lax module 2022-03-10 10:43:42 -08:00
Roy Frostig
8f93629e87 remove _convert_element_type from public jax.lax module 2022-03-09 18:46:38 -08:00
Roy Frostig
bea77710d6 remove _dilate_shape from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
460feed1dd remove _ones and _zeros from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
0cae3160f5 remove _delta from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
90f31c1df0 remove _tri from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
3c345ee785 remove _eye from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
7824325c23 remove _broadcasting_shape_rule from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
2324e5b5a2 remove _upcast_fp16_for_computation from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
e262c72b19 remove _check_user_dtype_supported from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
afc8729689 remove _reduce_or and _reduce_and from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
c979f64e29 remove _reduce_min from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
299d4db98d remove _reduce_max from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
6f519576f6 remove _reduce_sum from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
7890fb7596 remove _one and _zero from public jax.lax module 2022-03-08 12:56:11 -08:00
Roy Frostig
3f88518363 remove three internal functions from public jax.lax module
... namely `_float`, `_input_dtype`, and `_broadcasting_select`.
2022-03-08 12:49:36 -08:00
Roy Frostig
731998279a remove _eq_meet from public jax.lax module
PiperOrigin-RevId: 433251361
2022-03-08 10:39:47 -08:00
Roy Frostig
3e77a56fa2 remove _complex from public jax.lax module
PiperOrigin-RevId: 433198652
2022-03-08 06:49:04 -08:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
jax authors
d9f82f7b9b [JAX] Move experimental.ann.approx_*_k into lax.
Updated docs, tests and the example code snippets.

PiperOrigin-RevId: 431781401
2022-03-01 14:46:33 -08:00
Jake VanderPlas
e13c847e04 Index update operators: add scatter_apply() 2022-02-18 09:44:40 -08:00
Peter Hawkins
8ca6622c0b Change lax.select_p to be an n-ary predicate, 'lax.select_n_p'. Change lax.select() to be a thin shim around the new n-ary version.
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.

Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.

PiperOrigin-RevId: 427517899
2022-02-09 11:03:09 -08:00
Roman Novak
b9b759d4ff
Merge branch 'main' into conv_local 2022-01-07 09:51:46 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
83d8c6c238 Split slice/update_slice/gather/scatter out of jax._src.lax.lax into jax._src.lax.slicing.
To solve a circular dependency problem where some functions in jax._src.lax.lax depend on slicing, I moved a number of utility functions, e.g., standard_primitive, into a new module `jax._src.lax.utils`. Only utilities that need to be present at module import time were moved.

PiperOrigin-RevId: 411921794
2021-11-23 16:35:18 -08:00
Peter Hawkins
4204a25c91 Split convolution functions out of jax._src.lax.lax and into a separate module (jax._src.lax.convolution).
No public API changes.

PiperOrigin-RevId: 411871903
2021-11-23 12:35:50 -08:00
Peter Hawkins
45d7ade995 Split windowed reductions and their gradients into a separate file inside the lax implementation. 2021-11-18 18:21:38 -05:00
Peter Hawkins
fa89f34be6 [JAX] Port lax translation rules to updated XLA translation rule API.
PiperOrigin-RevId: 404114709
2021-10-18 18:07:29 -07:00
Jake VanderPlas
a0c2fe0dfd Remove duplicate import 2021-10-05 15:47:09 -07:00
Peter Hawkins
867068821e Drop out-of-bounds indexes in gather. 2021-09-23 10:35:03 -04:00
Peter Hawkins
b56c2ccadd Remove export of jax.lax.partial. 2021-09-14 16:17:50 -04:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00