28 Commits

Author SHA1 Message Date
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Peter Hawkins
b8a5473614 Add experimental support for XLA infeed/outfeed. 2019-10-09 15:05:54 -04:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
cclauss
f17f2b976f Mark instances of 'long' with # noqa 2019-08-05 00:22:41 +02:00
Matthew Johnson
fb1e2124ff enable staging out more multi-replica computations
There are two real changes here:

1. In api.py, improve the handling of the axis environment in
`xla_computation` so that `xla_computation(pmap(lambda x: x))(x)` works,
by checking for pmap's included in the jaxpr to be staged out (analogous
to how jit-of-pmap works).

2. In pxla.py, handle as a special case the pmapping of computations for
which the output does not depend on the input. The purpose here is to
enable `xla_computation(pmap(lambda x: x))(x)` when `x = np.arange(8)`
yet only one XLA device is available. Evaluating that expression leads
to the (partial) evaluation of a trivial pmap (unit / empty-tuple inputs and
outputs), which would cause an error when we attempt to compile an XLA
computation for more replicas than available hardware devices. We don't
know the computation is trivial until after we've run the function, i.e.
until we're in the xla_pmap impl, so this is the right place to do it.

The other changes are unrelated miscellania.
2019-07-09 15:12:02 -07:00
Matthew Johnson
5aef18f897 improve literal hashing logic
This fixes a bug where scalar ndarray literals with different dtypes
could hash to the same value. It also makes scalar DeviceArray literals
hashable after #884.
2019-06-19 10:32:55 -07:00
Matthew Johnson
9c931ddebe allow more types to be jaxpr literals, fixes #772 2019-05-28 22:38:06 -07:00
Matthew Johnson
29629931a1
Merge pull request #704 from google/differentiable-scan
Differentiable scan!
2019-05-13 10:26:09 -07:00
Peter Hawkins
367833bea2 Changes for compatibility with a upcoming Jaxlib update.
Shape.abstract_arrays will only accept dtypes, not scalar type objects.
Add long to the set of types known to abstract_arrays in Python 2.
Make api_test.py accepting of long values in shapes.
2019-05-08 20:32:24 -04:00
Matthew Johnson
15d783a836 Merge remote-tracking branch 'origin/master' into differentiable-scan 2019-05-08 13:42:44 -07:00
Matthew Johnson
a17f8e4ca8 add jaxpr eqn structured input, transpose progress
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:19 -07:00
Matthew Johnson
6736823021 victory! patial eval of scan (+ linearize!)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:15 -07:00
Matthew Johnson
85755820bb add defvjp functions for custom VJPs
c.f. #116, which won't be closed until we add documentation
2019-04-23 17:47:28 -07:00
Matthew Johnson
aa5b036d6d misc python performance improvements 2019-04-15 07:45:10 -07:00
Peter Hawkins
356e6dcfe8 Add np.longlong to the set of JAX array types. 2019-04-01 15:49:12 -07:00
Matthew Johnson
65b6f19cf8 add a better error message on cond pval join error 2019-03-02 18:08:34 -08:00
Peter Hawkins
0129e94f79 Add {float16,uint16,uint8,int16,int8} types to abstract_arrays.
In principle this allows these types to be used. They are as yet untested, however.

Fixes #75.
2019-02-17 17:18:20 -05:00
Peter Hawkins
d43c65dcd8 Add preliminary support for np.complex128.
Only lightly tested.
2019-01-11 18:22:43 -05:00
Matthew Johnson
0f7c7c4eab generalize jacfwd and jacrev to handle pytrees 2019-01-06 12:49:41 -08:00
Matthew Johnson
1ae1ae17a2 add EyeConstant, new np.eye and np.array code 2018-12-18 09:16:59 -08:00
Matthew Johnson
bfe653c6b0 Tracer.__len__ should reflect on abstract value
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Peter Hawkins
0d4eb6c1e1 Make JAX flake8-clean.
Fixes #1.
2018-12-13 15:29:39 -05:00
Matthew Johnson
9cd6027979 add python built-in complex type to array types
fixes #74
2018-12-11 12:14:57 -08:00
Peter Hawkins
c1b9eb19ea [JAX] Change semantics of dtype promotion to just call numpy.result_type.
* Enable tests for numpy scalars in lax_numpy_test.py.
* Fix invalid promotion in random.py.
* Split tests for bitwise ops into their own test case and test mixed signedness.
* Add complex64 to the set of types supported by abstractify.
2018-12-06 13:25:42 -05:00
Dougal Maclaurin
ca2634ea5d source sync
PiperOrigin-RevId: 222923229
2018-11-27 16:51:22 -08:00
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00