3235 Commits

Author SHA1 Message Date
James Bradbury
a8c5b49fda
Merge pull request #1722 from google/jb/sinc-double-where
Use double-where trick to avoid NaNs in grad(sinc)
2019-11-21 09:03:19 -08:00
Peter Hawkins
a8a19e196c
Implement batching rule for lax._select_and_gather_add (#1736) 2019-11-21 11:52:58 -05:00
Peter Hawkins
2b0cde3648
Fix test failure for jax.numpy.signbit(bfloat16) on TPU. (#1735) 2019-11-21 10:48:53 -05:00
Peter Hawkins
c60f3fd65d
Minor documentation fixes. (#1734) 2019-11-21 09:51:26 -05:00
Peter Hawkins
ee36818a58
Add bfloat16 support to JAX. (#1720)
bfloat16 support is still immature, but this PR adds some initial support.

Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

implement our own type promotion operators that understand bfloat16 types
wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
2019-11-20 22:43:46 -05:00
Skye Wanderman-Milne
b7d11ab90d Bump jaxlib version to 0.1.35 jaxlib-v0.1.35 2019-11-20 15:58:34 -08:00
Stephan Hoyer
ee29705712
Add jax.scipy.ndimage to online docs (#1724) 2019-11-20 12:35:10 -08:00
Skye Wanderman-Milne
ad09822067 Bump README to jaxlib 0.1.34 2019-11-20 10:59:54 -08:00
Tzu-Wei Sung
db46a22b23 Implementation of np.signbit (#1627)
Implement `np.signbit`.
2019-11-20 13:32:43 -05:00
Matthew Johnson
2353345446 only use one gensym in tracers_to_jaxpr
fixes a bug in #1721 revealed by additional internal testing
2019-11-20 09:12:15 -08:00
Peter Buchlovsky
9d1204689f Fix typo 2019-11-20 08:53:01 -08:00
Peter Buchlovsky
410ebfeb1c Fix typo 2019-11-20 08:52:46 -08:00
Matthew Johnson
68b7dc85c3 fix multi-host pmap, disambiguate nrep
When #1667 inlined a function into its caller, it mixed up two distinct
values referred to as `nrep` in the two functions: num_global_replicas
vs num_local_replicas. The result caused errors on multi-host setups.

Co-authored-by: Jonathan Heek <jheek@google.com>
2019-11-20 08:01:15 -08:00
Stephan Hoyer
65f0556ead
Add support for scipy.ndimage.map_coordinates with order=0 and order=1 (#1711)
* Add support for scipy.ndimage.map_coordinates with order=1

Higher dimensional interpolation will be a bit trickier, but this should
already be useful.

* move around docstring

* dtype fixes, more tests

* fixup float32 tests

* Handle order=0

* Tests for errors from map_coordinates
2019-11-19 17:14:09 -08:00
James Bradbury
1817cab012 Use double-where trick to avoid NaNs in grad(sinc) 2019-11-19 16:47:32 -08:00
Skye Wanderman-Milne
b1888881da Bump jaxlib version to 0.1.33 and update WORKSPACE.
Includes XLA fixes for CPU psum.
jaxlib-v0.1.34
2019-11-19 15:30:10 -08:00
Matthew Johnson
1fcebbaa0e fix reference cycle in jaxpr tracing using weakrefs
As one step in tracing user code to a jaxpr using the machinery in
partial_eval.py, we construct a bipartite graph made of JaxprTracer
nodes, corresponding to values in the user code, and recipe nodes
,particularly those corresponding to jaxpr equations, representing
primitive operations. (This representation was put in place in #1224,
since when primitives only had single outputs we could identify each
primitive operation with the JaxprTracer value it produced.) This graph
had reference cycles because each equation recipe points to both its
input and output tracers (as a jaxpr eqn has both input and output vars)
and a tracer must be able to point to the equation recipe that produced
it (for us to toposort the graph from in_tracers to out_tracers in
tracers_to_jaxpr).

Those cycles caused memory leaks. This commit removes the strong
reference cycle using weakrefs. In particular, equation recipes only
hold weak references to their output tracers.

Before this change, we used the core.JaxprEqn struct both to represent
equations in jaxprs (where invars and outvars are instances of the
core.Var class) and to represent equation recipes (where invars and
outvars are instances of the partial_eval.JaxprTracer class). That was a
bit lazy. This commit distinguishes the two as separate JaxprEqn and
JaxprEqnRecipe structs.

Bug find and test code from @trevorcai. Thanks!
2019-11-19 15:23:08 -08:00
Matthew Johnson
5edda9e66c
remove very old "what we're working on" 2019-11-19 14:20:49 -08:00
James Bradbury
68c053319f
Merge pull request #1707 from google/jb/readme-jaxlib
Bump jaxlib version in README
2019-11-19 12:17:12 -08:00
Matthew Johnson
a3474fec38 bump version for pypi 2019-11-19 07:02:29 -08:00
Peter Hawkins
5c3b99d0b4
Implement the __pos__ operator on JAX arrays. (#1718) 2019-11-18 22:00:32 -05:00
Peter Hawkins
f95e3e969f
Check for None in indexer dtype check. (#1717) 2019-11-18 22:00:23 -05:00
Peter Hawkins
6cf2e4b8bf
Add type check that indexers are integers or boolean values. (#1716)
Improves error if, say, a float type is passed as an indexer.
2019-11-18 21:04:27 -05:00
Peter Hawkins
d323431d5e
Relax test tolerance for core_test jvp tests. (#1714) 2019-11-18 15:36:29 -05:00
Peter Hawkins
9679a87901
Avoid out-of-bounds dereference for arity-0 nodes. (#1713) 2019-11-18 15:35:07 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
George Necula
397a244e7f
Merge pull request #1706 from gnecula/loops
An implementation of an experimental syntactic sugar for 'for' and `while` loops and conditionals.
2019-11-18 12:17:59 +01:00
George Necula
8ec6ea4742 Implemented suggestions from code review.
* added example of while_range to the module docstring.
* wrap the very long lines
2019-11-18 11:39:58 +01:00
Anselm Levskaya
f882359511
fix lax.scan notes in gotchas notebook
Note that lax.scan is now jittable and differentiable in the Gotchas notebook.
2019-11-17 00:19:24 -08:00
James Bradbury
5fa68774ea
Bump jaxlib version in README 2019-11-16 17:51:16 -08:00
Matthew Johnson
063419ab5f tweak test name (cf. #1704) 2019-11-16 14:40:25 -08:00
Chase Roberts
3978007be8 Explict typing 2019-11-16 14:39:13 -08:00
Chase Roberts
979a8d30b7 Cast perm to tuple 2019-11-16 14:39:13 -08:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… (#1701)
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
George Necula
d549d44e43 Improved documentation
Also fix for the Python 2 iterators.
2019-11-16 18:36:08 +01:00
George Necula
64e186c337 Fix tests for Python 2 and for X64 2019-11-16 18:05:45 +01:00
George Necula
d24c374d59 An implementation of an experimental syntactic sugar for 'for' loops.
See description in jax/experimental/loops.py.
2019-11-16 17:23:40 +01:00
Peter Hawkins
9b853a4255
Update XLA. (#1702)
Add support for building a CPU-only jaxlib with a CUDA-enabled toolchain.
2019-11-16 11:01:36 -05:00
Peter Hawkins
4fc765241f
Drop protobuf dependency from jax package. It appears unused. (#1700) 2019-11-15 14:55:26 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Matthew Johnson
3ac3271381 improve docs on shape/dtype loop stability
also tweak how some error messages are printed, and corresponding tests

fixes #1686
2019-11-14 22:28:47 -08:00
Matthew Johnson
dbf41348a0
Merge pull request #1697 from google/grad-argnums-error-message
improve grad error message without enough args
2019-11-14 21:48:18 -08:00
Matthew Johnson
728cb7fba8 improve grad error message without enough args
fixes #1696
2019-11-14 21:18:23 -08:00
Matthew Johnson
be28700b8b skip some tests on tpu 2019-11-14 16:51:39 -08:00
Matthew Johnson
73f7edba08
Merge pull request #1694 from google/issue1688
fix shard_args logic, closes #1688
2019-11-14 16:30:38 -08:00
Matthew Johnson
c19e65b7ab fix shard_args logic, closes #1688
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2019-11-14 16:16:21 -08:00
Skye Wanderman-Milne
e95098de54 Update jax version to 0.1.51. 2019-11-14 15:18:41 -08:00
android
3f0c1cd9dd Add TPU Driver as JAX backend for high-performance access to Google Cloud TPU hardware. (#1675) 2019-11-14 14:00:08 -08:00
Peter Hawkins
9ffdd6bdcf
Add a type check that verifies the lower and upper arguments to lax.fori_loop have equal types. (#1693) 2019-11-14 16:18:00 -05:00
Peter Hawkins
cc0568ef49
Remove test_util.check_raises_regexp. (#1692)
It does nothing that the builtin self.assertRaisesRegexp doesn't already do.
2019-11-14 16:00:55 -05:00