127 Commits

Author SHA1 Message Date
Peter Hawkins
3fc1fdb148 Add a JVP rule for the general case of lax.reduce. 2021-03-30 17:31:47 -04:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
jax authors
2022141b13 Merge pull request #6208 from majnemer:int-conv
PiperOrigin-RevId: 365544250
2021-03-29 04:20:05 -07:00
Matthew Johnson
8547c71bfd simplify public lax.convert_element_type api
Specifically:
1. don't expose weak_type in the public api, as it's jax-internal
2. don't make new_dtype optional, which could make bugs easier

This change keeps the public API simpler, and also makes
convert_element_type match the ConvertElementType HLO. As an internal
API we can call lax._convert_element_type just like before.
2021-03-28 10:32:02 -07:00
David Majnemer
7defa05009 Allow integer/boolean convolutions 2021-03-24 23:20:30 -07:00
Matthew Johnson
89768a3d28 add jax_default_matmul_precision flag & context mngr 2021-03-24 14:03:58 -07:00
Matthew Johnson
214d273d8c undo changes to host_callback (not needed anymore) 2021-03-21 19:43:12 -07:00
Matthew Johnson
fe4d12c10f move logic to traceable 2021-03-21 19:38:12 -07:00
Matthew Johnson
8c3125c172 fix convert_element_type on large Py int inputs 2021-03-21 19:08:59 -07:00
Matthew Johnson
af59542d00 Re-applying the changes in #6014, after they had to be rolled-back.
PiperOrigin-RevId: 364200195
2021-03-21 13:40:20 -07:00
Matthew Johnson
57d5c6af5f add clz primitive 2021-03-19 22:54:36 -07:00
Roy Frostig
7427991819 skip scalars when broadcasting for batch dimension agreement 2021-03-19 21:47:16 -07:00
jax authors
4f8814a760 Copybara import of the project:
--
bf15ba5310d5f9009571928f70548bcbc7e856c3 by Matthew Johnson <mattjj@google.com>:

don't device transfer in convert_element_type

Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
PiperOrigin-RevId: 363995032
2021-03-19 16:35:37 -07:00
Matthew Johnson
bf15ba5310 don't device transfer in convert_element_type
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2021-03-19 13:42:33 -07:00
Jake VanderPlas
5f51d4fb1d Make lax._const() work for non-canonical dtypes 2021-03-17 13:07:53 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Tamas Berghammer
2ea526102d Add new lax.rng_bit_generator primitive
The new primitive provides access to the RngBitGenerator HLO
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator)
2021-03-16 16:30:09 +00:00
Jacob Austin
9d28b67022
Fixed two small typos in jax.lax. 2021-03-15 23:26:31 -04:00
Jake VanderPlas
04bf02a4b6 convert_element_type: don't canonicalize old_dtype 2021-03-12 15:26:06 -08:00
Peter Hawkins
62a726d329 Add workaround for SelectAndScatter padding bug on CPU and GPU. 2021-03-10 15:25:32 -05:00
Peter Hawkins
140c0acbbe Remove the JAX lazy sublanguage.
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.

At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a `jit` computation, we can
   avoid materializing it in its expanded form.

It is not clear that laziness has sufficient power to weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
Roy Frostig
e779ed8299 simplify standard named_shape_rule
Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-03-09 13:48:26 -08:00
James Bradbury
c622422dad [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-03-09 13:48:15 -08:00
Peter Hawkins
2469ad1bb3 Cleanups for laziness. No functional changes intended.
Use None as a trivial lazy expression in more places. Simplify some code.
2021-03-07 11:33:04 -05:00
Peter Hawkins
afd2aa2ea0 Remove device constants from lazy language.
Updated version of #4536.

This is removing the device constant part of #1668. We can do this because after #3370 and #4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
2021-03-03 21:17:31 -05:00
Peter Hawkins
0dd1b5516d Implement lax.pad batching rule for batched padding values. 2021-02-17 13:30:20 -05:00
Matthew Johnson
9b18135b6e Rollback of #5702 due to internal breakage.
PiperOrigin-RevId: 357943850
2021-02-17 07:32:09 -08:00
jax authors
31a187c1c3 Merge pull request #5702 from google:awn
PiperOrigin-RevId: 357854529
2021-02-16 19:08:03 -08:00
Jake VanderPlas
e56c1d9d0c DOC: suppress some warnings 2021-02-16 17:18:38 -08:00
Roy Frostig
30dd558b27 simplify standard named_shape_rule
Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-02-16 15:46:14 -08:00
James Bradbury
fb160b8afd [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-02-16 15:46:14 -08:00
Matthew Johnson
268493bae8 specialize standard_primitive back to single-out 2021-02-12 10:30:46 -08:00
Jake VanderPlas
41b7a0f770 Re-land #4850 weak types change 2021-02-09 09:07:52 -08:00
Peter Hawkins
ac512718cf Add note that lax.reduce requires a monoid. 2021-02-08 09:23:35 -05:00
Matthew Johnson
ca4f7f7964 add check for __jax_array__ method before error
Before raising an error on an unrecognized type, first check if the
object defines a __jax_array__ method. If it does, call it!

This provides a way for custom types to be auto-converted to
JAX-compatible types.

Implementing this method is not sufficient for a type to be duck-typed
enough for use with jax.numpy. But it may be necessary. That is, someone
trying to add a duck-typed array to be used with JAX identified a need
for __jax_array__ or similar. The user would still need to add lots of
other properties and methods, like dtype and shape attributes.

revives #4725 after it was rolled back. fixes #5356.
2021-02-05 20:30:14 -08:00
George Necula
f105517ea2 Fixed mypy type errors for numpy 1.20
Revert also previous changes that pinned numpy to 1.19.

One of the changes in numpy 1.20 is to add more type annotations.
However, this sometimes make mypy give errors. A common example is
numpy.take, which with the new type annotation does not appear to
mypy as indexable.

Another change is that np.int and np.bool are deprecated. One
should use np.bool_ or np.int_, or the built-ins bool and int.
2021-02-05 10:40:47 +02:00
jax authors
8719609bd7 Merge pull request #5547 from skye:fix
PiperOrigin-RevId: 354434034
2021-01-28 17:38:34 -08:00
Skye Wanderman-Milne
997e6efa9c Improve error message when a reduction function returns an invalid return type.
Fixes #5536

Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-01-28 15:36:15 -08:00
Jake VanderPlas
af6da229da DOC: fix some minor formatting issues 2021-01-28 15:20:02 -08:00
James Bradbury
f1918f0b19 [avals with names] Revise aval constructor call sites to use a new aval.update method
PiperOrigin-RevId: 354182876
2021-01-27 15:14:02 -08:00
jax authors
62e89cbdf8 Merge pull request #5213 from malmaud:changelist/346683050
PiperOrigin-RevId: 352800932
2021-01-20 08:45:06 -08:00
Jake VanderPlas
0a89fc83cb integer_pow: fix jvp rule for y=0 2021-01-19 15:42:40 -08:00
Jake VanderPlas
7acf521d49 integer_pow: fix translation rule for y=0 2021-01-19 15:42:18 -08:00
Jake VanderPlas
9076008c62 lax.integer_pow(): always bind the primitive 2021-01-19 11:36:39 -08:00
Jonathan Malmaud
c0c4843b93 Add support for 'preferred_element_type' keyword arg in dot and dot_general.
XLA recently added support for this parameter to xops.DotGeneral. It's an optional parameter that controls the accumulation type used by the dot operation.

This is useful for eg quantized ANNs, where you might want to do matrix multiples with int8 tensors and get back an int32 tensor instead of an int8 tensor that suffers from severe overflow. Note it's not sufficient in this case to cast the inputs to 'dot' to int32 beforehand and rely on the default output dtype inference, since backend devices might have an accelerated path for int8*int8->int32 matmuls and we want that explicitly represented in the XLA.

Note because XLA still doesn't support integer dots on the CPU backend, that use case can't tested with a CPU-only test at the moment.
2021-01-19 18:56:46 +00:00
George Necula
67b5af97f7 Copybara import of the project:
--
9be685946252edc67c2c28261b100b9aee68614a by George Necula <gcnecula@gmail.com>:

Change the translation rule for lax.nextafter_p to ensure
broadcasting during translation.

Previously, this was the only binary arithmetic primitive that
did not have broadcasting during translation. Trying to use it
with non-equal shapes resulted in the error:

```
 RuntimeError: Internal: RET_CHECK failure
(external/org_tensorflow/tensorflow/compiler/xla/client/xla_builder.cc:748)
non_scalar_shape.value().dimensions() == shape->dimensions() Unimplemented
implicit broadcast.:
       This is a bug in JAX's shape-checking rules; please report it!
```

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5448 from gnecula:nextafter 9be685946252edc67c2c28261b100b9aee68614a
PiperOrigin-RevId: 352367039
2021-01-18 01:59:55 -08:00
Samuel Marks
6e458e2237
[*.py] Rename "Arguments:" to "Args:" 2021-01-15 11:49:19 +11:00
Jake VanderPlas
14fc5ce0de Validate axes values in reduce_op_shape_rule 2021-01-13 14:16:54 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
98aac23d92 Change from deflinear to deflinear2 2021-01-05 09:03:33 -08:00