37 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Tom Hennigan
441ded0676 Remove jax.argnums_partial. 2021-04-07 11:34:38 +00:00
Matthew Johnson
632876d773 Copybara import of the project:
--
35fcf2e2fd5b4c56cbb591f4c8bf01222a23dfe5 by Matthew Johnson <mattjj@google.com>:

remove deprecated custom_transforms code

PiperOrigin-RevId: 366108489
2021-03-31 13:50:56 -07:00
Matthew Johnson
89768a3d28 add jax_default_matmul_precision flag & context mngr 2021-03-24 14:03:58 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Skye Wanderman-Milne
c32d1e5aae Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
2021-03-10 09:15:31 -08:00
Skye Wanderman-Milne
902038a718 Revert breaking change:
Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.

This removes the need to manually set these env vars when running on a Cloud TPU pod slice.

PiperOrigin-RevId: 361681134
2021-03-08 16:10:54 -08:00
Skye Wanderman-Milne
5a2859e1b6 Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
2021-03-08 12:04:32 -08:00
Jake VanderPlas
12c84e7a50 Add jax.errors submodule & error troubleshooting docs 2021-03-03 12:39:12 -08:00
Tom Hennigan
7adb1e381d Add jax.default_backend() which returns the default platform name.
This can be useful when you need backend specific behaviour, e.g.:

    if jax.default_backend() == 'gpu':
      dataset = double_buffer(dataset)

Or if you want to assert a given backend is the default:

    assert jax.default_backend() == 'tpu'

I am a bit conflicted by the naming, "backend" is consistent with other APIs in
JAX (e.g. jit, local_devices etc) which accept a "backend" string which is used
to lookup an XLA backend by platform name.
2021-02-04 14:50:15 +00:00
Matthew Johnson
014f9a86b4 implement soft_pmap in terms of xmap 2021-01-28 07:59:57 -08:00
Roy Frostig
93c61e4d77 import closure_convert at top module level 2021-01-26 09:44:48 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Matthew Johnson
dc610e4516 add jax.device_put_replicated
Also move tests for device_put_sharded into pmap_test.py, since that
file tests with multiple devices even in our OSS CI.

Add both device_put_replicated and device_put_sharded to
jax/__init__.py.
2020-12-04 12:54:07 -08:00
Qiumin Xu
31600aac62 Add named_call public API.
Move named_call_p to core.py from lax.py.
Also move the translation rule to jax/interpreters/xla.py where the core_call translation rule is.
2020-11-12 17:32:01 -08:00
Lena Martens
ecad419cf3 Support grad with integer arguments.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
2020-09-28 19:07:04 +01:00
Stephan Hoyer
877053d8ab
Add jax.linear_transpose (#3398)
* Add jax.linear_transpose

Co-authored-by: Matthew Johnson <mattjj@google.com>

* add failing test for complex numbers

* Add picky dtype check for linear_transpose

* Lint fix

* Allow truncating dtypes to match inputs in linear_transpose

* Fix typo in shape check error

* improve docstring

* Don't support integer inputs; better docstring

* fixup

* Fix doctest

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-09-16 20:29:19 -07:00
Tom Hennigan
aa75209db3
Import profiler in jax/__init__.py (#3719) 2020-07-11 20:44:16 -07:00
Skye Wanderman-Milne
44eae61059
Turn off INFO logging (again). (#3707)
Something must have started logging earlier than before, causing INFO-level logging to be initialized before we disabled it. This change disables INFO logging sooner.
2020-07-10 11:11:48 -04:00
Peter Hawkins
b943b31b22
Add jax.image.resize. (#3703)
* Add jax.image.resize.

This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.

While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
2020-07-10 09:57:59 -04:00
Adam Paszke
4d40b208ed
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.

The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
  that can be inverted correctly if only we iterate over its equations
  in reverse. This is a bit strict, because users generally don't have
  too much control over that, and there are functions that produce
  jaxprs which will be treated as invertible when one topological
  ordering of equations is used, while they will be considered
  non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
  that zero argument pruning in JVPTrace makes it pretty much impossible
  to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
  that all the VJPs of primal primitives are jittable (not sure if
  that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
  `custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Sergei Lebedev
73b76e9976
Exported lax from jax/__init__.py (#3135)
This allows to use lax functions without separately importing jax.lax.
2020-05-19 15:40:03 -04:00
Peter Hawkins
f21ade3fa2
Remove jax.np from the jax namespace (use jax.numpy instead). (#3010) 2020-05-08 10:04:19 -04:00
Peter Hawkins
0ea22b7e19
Use a whitelist to restrict visibility in top-level jax namespace. (#2982)
* Use a whitelist to restrict visibility in top-level jax namespace.

The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.
2020-05-07 17:24:19 -04:00
Tom Hennigan
6f6209838a
Import nn in jax/__init__.py. 2019-11-01 09:28:48 +00:00
Matthew Johnson
861e939324 import random in jax/__init__.py 2019-07-15 23:04:50 +01:00
Peter Hawkins
41ac7e2d85 Use a regular import to add jax.__version__ rather than exec() trickery.
(The exec() trickery is needed for setup.py, but not for jax/__init__.py.)
2019-02-19 11:38:28 -05:00
Matthew Johnson
9a9c304644 add version attribute
following idea 3 here:
https://packaging.python.org/guides/single-sourcing-package-version/
2019-02-13 20:04:38 -08:00
Dougal Maclaurin
709cfe905d Set default TF log level to "1" to avoid reporting things like CPU frequency at import time. Also import jax.numpy in __init__.py because it has side effects that set up the infix operator overloading. 2018-12-05 15:55:01 -05:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00