29 Commits

Author SHA1 Message Date
Jake VanderPlas
b49c75c0d7 [x64] make jax.experimental.loops consistent with default dtype 2021-12-08 12:08:49 -08:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Peter Hawkins
9f083d11da Use jax.* APIs rather than api.* names in tests.
Tests should use our own public APIs where they exist.
2021-09-13 16:01:32 -04:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
George Necula
555a215cfb [loops] Extend loops with support for pytrees
Also improve error checking and error messages.
2021-01-14 21:17:14 +02:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Roy Frostig
dbca9e682c unrevert #3674 (revert #3791) 2020-08-17 18:13:58 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Roy Frostig
fa2a0275c8 revert #3674 2020-07-17 15:44:51 -07:00
Roy Frostig
6416ca0e9d append filtered stack traces to error messages raised under transformations 2020-07-16 17:12:09 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
George Necula
a2c06d6113
Added clearer error message for tracers in numpy.split (#2508)
* Added clearer error message for tracers in numpy.split

Now we print:

ConcretizationTypeError: Abstract tracer value where concrete value is expected (in
jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid
tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray>

* Fixed tests, slight change to the error message

* Expanded the FAQ entry about abstract tracers for higher-order primitives

* Added clarification for tracers inside jit of grad

* Updated FAQ language in response to reviews
2020-04-22 09:25:06 +02:00
Matthew Johnson
f2de1bf345 add trace state check tearDown to JaxTestCase 2020-04-02 22:01:43 -07:00
Matthew Johnson
b78b7a0309 add global trace state checks to more tests 2020-04-02 18:03:58 -07:00
Matthew Johnson
ab0a005452 check sublevel is reset in loops_test.py 2020-04-02 17:18:47 -07:00
George Necula
d2a827a08a Ensure the global trace_state is restored on errors in loops
This is an attempted fix for https://github.com/google/jax/issues/2507
2020-04-01 10:23:14 +03:00
Matthew Johnson
1f03d48c83 try resetting global tracer state in loops_test.py
attempting to address #2507
2020-03-30 20:10:39 -07:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Peter Hawkins
687b9050df
Prepare to switch default dtypes in JAX to be 32-bit types. (#1827)
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
2019-12-09 21:18:39 -05:00
Peter Hawkins
45a1ba0bbc
Make more tests pass on TPU. (#1752) 2019-11-23 12:28:26 -05:00
George Necula
8ec6ea4742 Implemented suggestions from code review.
* added example of while_range to the module docstring.
* wrap the very long lines
2019-11-18 11:39:58 +01:00
George Necula
d549d44e43 Improved documentation
Also fix for the Python 2 iterators.
2019-11-16 18:36:08 +01:00
George Necula
64e186c337 Fix tests for Python 2 and for X64 2019-11-16 18:05:45 +01:00
George Necula
d24c374d59 An implementation of an experimental syntactic sugar for 'for' loops.
See description in jax/experimental/loops.py.
2019-11-16 17:23:40 +01:00