58 Commits

Author SHA1 Message Date
Jake VanderPlas
47ae5bddd7 Mark jax.abstract_arrays as deprecated 2023-06-07 23:36:40 -07:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Peter Hawkins
46cc654537 Move jax.abstract_arrays to jax._src.abstract_arrays.
PiperOrigin-RevId: 377044255
2021-06-02 06:25:22 -07:00
George Necula
e37727cbce [jax2tf] Implementation of a parametric shape-polymorphism feature for jax2tf.
See the PR description.
2021-04-08 10:42:38 +03:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Jake VanderPlas
9790232556 Python integer conversion: always return int64 or OverflowError 2021-03-29 09:26:19 -07:00
Matthew Johnson
8c3125c172 fix convert_element_type on large Py int inputs 2021-03-21 19:08:59 -07:00
George Necula
f105517ea2 Fixed mypy type errors for numpy 1.20
Revert also previous changes that pinned numpy to 1.19.

One of the changes in numpy 1.20 is to add more type annotations.
However, this sometimes make mypy give errors. A common example is
numpy.take, which with the new type annotation does not appear to
mypy as indexable.

Another change is that np.int and np.bool are deprecated. One
should use np.bool_ or np.int_, or the built-ins bool and int.
2021-02-05 10:40:47 +02:00
Peter Hawkins
d2a0bbd992 Add np.intc to the set of valid jaxtypes.
Fixes a number of test failures on Windows.
2020-11-20 13:55:12 -05:00
Peter Hawkins
81b6cd29ff [JAX] Move traceback_util.py into jax._src.
traceback_util is a JAX-private API.

PiperOrigin-RevId: 340659195
2020-11-04 09:02:59 -08:00
Roy Frostig
5d50e19364 add path exclusion opt-in to filtered stack traces and use it throughout the codebase 2020-10-26 12:31:19 -07:00
Lena Martens
ecad419cf3 Support grad with integer arguments.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
2020-09-28 19:07:04 +01:00
Srijan Saurav
40e20242db
Fix code quality issues (#4302)
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
2020-09-17 09:21:18 -07:00
Adam Paszke
17472b97ab Optimize zeros_like_shaped_array
This function is used a lot more now, because `ad.instantiate_zeros` now
goes through that and not `zeros_like_array`.
2020-06-05 15:52:03 +00:00
joao guilherme
d2f84d635b
Change instances of onp to np and np to jnp (#3044) 2020-05-12 20:37:05 -04:00
Peter Hawkins
50dc44be6f
Fix IntEnum test when checking is enabled. (#2981) 2020-05-07 08:46:13 +03:00
George Necula
a2c06d6113
Added clearer error message for tracers in numpy.split (#2508)
* Added clearer error message for tracers in numpy.split

Now we print:

ConcretizationTypeError: Abstract tracer value where concrete value is expected (in
jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid
tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray>

* Fixed tests, slight change to the error message

* Expanded the FAQ entry about abstract tracers for higher-order primitives

* Added clarification for tracers inside jit of grad

* Updated FAQ language in response to reviews
2020-04-22 09:25:06 +02:00
Chris Jones
1e7d13b5f9
Give Vars an aval. (#2299) 2020-03-09 10:14:23 +01:00
Pavel Sountsov
b2ef5bc095
Canonicalize the shape in the wrapper functions in random.py. (#2165)
* Canonicalize the shape in the wrapper functions in random.py.

This lets the user be more sloppy in using numpy arrays and statically
known DeviceArrays for shapes, and still hit the jit cache. When they
are not, the error is improved.

* Fix some errors.

* No need for the Poly workaround.

* Bypass canonicalization for None shapes in random.py.
2020-02-05 10:10:33 -08:00
Stephan Hoyer
be2704e425
Ensure ShapedArray.shape is always a tuple of builtins integers (#2039)
* Ensure ShapedArray.shape is always a tuple of builtins integers

Currently, it can sometimes include elements of type int64, e.g.,

    In [1]: import jax.numpy as jnp

    In [2]: x = jnp.arange(3) + 1

    In [3]: x.shape  # looks fine at first glance
    Out[3]: (3,)

    In [4]: type(x.shape[0])  # yikes!
    Out[4]: numpy.int64

This confirms my hypothesis that NumPy's scalar types are the root of all evil.

* Allow Poly in shapes

* Simple shape coercion in ShapedArray

* cleaner
2020-01-29 14:24:11 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Chase Roberts
82d6c6ce51 Added better error messages. (#2058)
#2057

Added better error messages for when a user accidentally uses a python cast instead of a the `jax.numpy` casting.
2020-01-27 15:44:33 -08:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Matthew Johnson
876c9c0ede fix x64 issue 2019-12-11 21:25:29 -08:00
Peter Hawkins
e87d9718c3
Support IntEnum values as arguments to JAX functions. (#1840)
* Support IntEnum values as arguments to JAX functions.

When abstractifying a Python value, search the method-resolution order (MRO) of the type rather than only looking at the value's own type. IntEnum instances are subclasses of int, so this allows us to correctly handle them as integers, much as NumPy itself does.
2019-12-11 12:27:11 -05:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
Peter Hawkins
ee36818a58
Add bfloat16 support to JAX. (#1720)
bfloat16 support is still immature, but this PR adds some initial support.

Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

implement our own type promotion operators that understand bfloat16 types
wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
2019-11-20 22:43:46 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Peter Hawkins
b8a5473614 Add experimental support for XLA infeed/outfeed. 2019-10-09 15:05:54 -04:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
cclauss
f17f2b976f Mark instances of 'long' with # noqa 2019-08-05 00:22:41 +02:00
Matthew Johnson
fb1e2124ff enable staging out more multi-replica computations
There are two real changes here:

1. In api.py, improve the handling of the axis environment in
`xla_computation` so that `xla_computation(pmap(lambda x: x))(x)` works,
by checking for pmap's included in the jaxpr to be staged out (analogous
to how jit-of-pmap works).

2. In pxla.py, handle as a special case the pmapping of computations for
which the output does not depend on the input. The purpose here is to
enable `xla_computation(pmap(lambda x: x))(x)` when `x = np.arange(8)`
yet only one XLA device is available. Evaluating that expression leads
to the (partial) evaluation of a trivial pmap (unit / empty-tuple inputs and
outputs), which would cause an error when we attempt to compile an XLA
computation for more replicas than available hardware devices. We don't
know the computation is trivial until after we've run the function, i.e.
until we're in the xla_pmap impl, so this is the right place to do it.

The other changes are unrelated miscellania.
2019-07-09 15:12:02 -07:00
Matthew Johnson
5aef18f897 improve literal hashing logic
This fixes a bug where scalar ndarray literals with different dtypes
could hash to the same value. It also makes scalar DeviceArray literals
hashable after #884.
2019-06-19 10:32:55 -07:00
Matthew Johnson
9c931ddebe allow more types to be jaxpr literals, fixes #772 2019-05-28 22:38:06 -07:00
Matthew Johnson
29629931a1
Merge pull request #704 from google/differentiable-scan
Differentiable scan!
2019-05-13 10:26:09 -07:00
Peter Hawkins
367833bea2 Changes for compatibility with a upcoming Jaxlib update.
Shape.abstract_arrays will only accept dtypes, not scalar type objects.
Add long to the set of types known to abstract_arrays in Python 2.
Make api_test.py accepting of long values in shapes.
2019-05-08 20:32:24 -04:00
Matthew Johnson
15d783a836 Merge remote-tracking branch 'origin/master' into differentiable-scan 2019-05-08 13:42:44 -07:00
Matthew Johnson
a17f8e4ca8 add jaxpr eqn structured input, transpose progress
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:19 -07:00
Matthew Johnson
6736823021 victory! patial eval of scan (+ linearize!)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:15 -07:00
Matthew Johnson
85755820bb add defvjp functions for custom VJPs
c.f. #116, which won't be closed until we add documentation
2019-04-23 17:47:28 -07:00
Matthew Johnson
aa5b036d6d misc python performance improvements 2019-04-15 07:45:10 -07:00
Peter Hawkins
356e6dcfe8 Add np.longlong to the set of JAX array types. 2019-04-01 15:49:12 -07:00
Matthew Johnson
65b6f19cf8 add a better error message on cond pval join error 2019-03-02 18:08:34 -08:00
Peter Hawkins
0129e94f79 Add {float16,uint16,uint8,int16,int8} types to abstract_arrays.
In principle this allows these types to be used. They are as yet untested, however.

Fixes #75.
2019-02-17 17:18:20 -05:00
Peter Hawkins
d43c65dcd8 Add preliminary support for np.complex128.
Only lightly tested.
2019-01-11 18:22:43 -05:00
Matthew Johnson
0f7c7c4eab generalize jacfwd and jacrev to handle pytrees 2019-01-06 12:49:41 -08:00
Matthew Johnson
1ae1ae17a2 add EyeConstant, new np.eye and np.array code 2018-12-18 09:16:59 -08:00