56 Commits

Author SHA1 Message Date
Matthew Johnson
1cf7d4ab5d Copybara import of the project:
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
Matthew Johnson
c8a34fe5cc add jax.block_until_ready function
fixes #8536
2021-12-14 11:02:14 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Roy Frostig
cdc9df2b66 add Lowered and Compiled to the jax public package 2021-11-10 18:57:39 -08:00
Qiao Zhang
0be30fbf96 Add jax.distributed.initialize for multi-host GPU. 2021-10-26 14:37:54 -07:00
Roy Frostig
98d245ebb4 add a config setting to control the default PRNG implementation
Also add explicit seeding functions for each PRNG implementation.
2021-10-07 21:22:40 -07:00
Jake VanderPlas
37719e4ad5 Import numpy explicitly in the jax namespace 2021-09-23 09:02:14 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
f35ab3693d Remove jax.partial from the JAX API.
Use functools.partial instead.
2021-09-20 09:19:53 -04:00
Peter Hawkins
6a1b626564 Remove jax.api.
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
2021-09-16 16:29:06 -04:00
Jake VanderPlas
33e2bed1b4 Fix package exports 2021-09-14 13:55:55 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00
Matthew Johnson
2d28951ba4 address comments form @apaszke 2021-08-26 14:10:58 -07:00
Jake VanderPlas
1a397a5572 import in the JAX namespace 2021-08-24 08:44:05 -07:00
Roy Frostig
60e0e9f929 implement backwards-compatible behavior and enable custom PRNGs only conditionally
Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
2021-08-19 20:43:11 -07:00
Peter Hawkins
46cc654537 Move jax.abstract_arrays to jax._src.abstract_arrays.
PiperOrigin-RevId: 377044255
2021-06-02 06:25:22 -07:00
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Tom Hennigan
441ded0676 Remove jax.argnums_partial. 2021-04-07 11:34:38 +00:00
Matthew Johnson
632876d773 Copybara import of the project:
--
35fcf2e2fd5b4c56cbb591f4c8bf01222a23dfe5 by Matthew Johnson <mattjj@google.com>:

remove deprecated custom_transforms code

PiperOrigin-RevId: 366108489
2021-03-31 13:50:56 -07:00
Matthew Johnson
89768a3d28 add jax_default_matmul_precision flag & context mngr 2021-03-24 14:03:58 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Skye Wanderman-Milne
c32d1e5aae Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
2021-03-10 09:15:31 -08:00
Skye Wanderman-Milne
902038a718 Revert breaking change:
Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.

This removes the need to manually set these env vars when running on a Cloud TPU pod slice.

PiperOrigin-RevId: 361681134
2021-03-08 16:10:54 -08:00
Skye Wanderman-Milne
5a2859e1b6 Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
2021-03-08 12:04:32 -08:00
Jake VanderPlas
12c84e7a50 Add jax.errors submodule & error troubleshooting docs 2021-03-03 12:39:12 -08:00
Tom Hennigan
7adb1e381d Add jax.default_backend() which returns the default platform name.
This can be useful when you need backend specific behaviour, e.g.:

    if jax.default_backend() == 'gpu':
      dataset = double_buffer(dataset)

Or if you want to assert a given backend is the default:

    assert jax.default_backend() == 'tpu'

I am a bit conflicted by the naming, "backend" is consistent with other APIs in
JAX (e.g. jit, local_devices etc) which accept a "backend" string which is used
to lookup an XLA backend by platform name.
2021-02-04 14:50:15 +00:00
Matthew Johnson
014f9a86b4 implement soft_pmap in terms of xmap 2021-01-28 07:59:57 -08:00
Roy Frostig
93c61e4d77 import closure_convert at top module level 2021-01-26 09:44:48 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Matthew Johnson
dc610e4516 add jax.device_put_replicated
Also move tests for device_put_sharded into pmap_test.py, since that
file tests with multiple devices even in our OSS CI.

Add both device_put_replicated and device_put_sharded to
jax/__init__.py.
2020-12-04 12:54:07 -08:00
Qiumin Xu
31600aac62 Add named_call public API.
Move named_call_p to core.py from lax.py.
Also move the translation rule to jax/interpreters/xla.py where the core_call translation rule is.
2020-11-12 17:32:01 -08:00
Lena Martens
ecad419cf3 Support grad with integer arguments.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
2020-09-28 19:07:04 +01:00
Stephan Hoyer
877053d8ab
Add jax.linear_transpose (#3398)
* Add jax.linear_transpose

Co-authored-by: Matthew Johnson <mattjj@google.com>

* add failing test for complex numbers

* Add picky dtype check for linear_transpose

* Lint fix

* Allow truncating dtypes to match inputs in linear_transpose

* Fix typo in shape check error

* improve docstring

* Don't support integer inputs; better docstring

* fixup

* Fix doctest

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-09-16 20:29:19 -07:00
Tom Hennigan
aa75209db3
Import profiler in jax/__init__.py (#3719) 2020-07-11 20:44:16 -07:00
Skye Wanderman-Milne
44eae61059
Turn off INFO logging (again). (#3707)
Something must have started logging earlier than before, causing INFO-level logging to be initialized before we disabled it. This change disables INFO logging sooner.
2020-07-10 11:11:48 -04:00
Peter Hawkins
b943b31b22
Add jax.image.resize. (#3703)
* Add jax.image.resize.

This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.

While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
2020-07-10 09:57:59 -04:00
Adam Paszke
4d40b208ed
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.

The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
  that can be inverted correctly if only we iterate over its equations
  in reverse. This is a bit strict, because users generally don't have
  too much control over that, and there are functions that produce
  jaxprs which will be treated as invertible when one topological
  ordering of equations is used, while they will be considered
  non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
  that zero argument pruning in JVPTrace makes it pretty much impossible
  to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
  that all the VJPs of primal primitives are jittable (not sure if
  that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
  `custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Sergei Lebedev
73b76e9976
Exported lax from jax/__init__.py (#3135)
This allows to use lax functions without separately importing jax.lax.
2020-05-19 15:40:03 -04:00
Peter Hawkins
f21ade3fa2
Remove jax.np from the jax namespace (use jax.numpy instead). (#3010) 2020-05-08 10:04:19 -04:00
Peter Hawkins
0ea22b7e19
Use a whitelist to restrict visibility in top-level jax namespace. (#2982)
* Use a whitelist to restrict visibility in top-level jax namespace.

The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.
2020-05-07 17:24:19 -04:00
Tom Hennigan
6f6209838a
Import nn in jax/__init__.py. 2019-11-01 09:28:48 +00:00