8 Commits

Author SHA1 Message Date
Peter Hawkins
70b8a6a806 Add a prototype IREE backend for JAX.
This is to support experimentation with the combination of JAX/IREE. Many things do not work yet.

PiperOrigin-RevId: 409980064
2021-11-15 07:57:04 -08:00
Peter Hawkins
3361c76dca Consolidate primitive and jit lowering paths.
Before this change, primitives have a special case dispatch path that attempts
to avoid building a jaxpr in the cache miss case. However, there's no good
reason for this: it makes the code more complicated, and we're not particularly
optimizing for fast cache misses anyway (we care mostly about cache hits).

Make the primitive lowering path trace a small function using the xla_callable
lowering path instead.
2021-10-13 12:36:53 -04:00
Yin Li
ece532556b
Simplify shape comparison with numpy assert 2021-10-05 11:30:19 -04:00
Yin Li
5d675220c0
Add float0 support to equality and closeness check 2021-10-04 21:32:57 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Adam Paszke
22dce0f483 Add reverse-mode AD support for pjit
This is a somewhat big patch, because the transposition process turns out to be
quite difficult. The biggest issue appears when we do partial evaluation and we have
to add a whole bunch of intermediate values as outputs of the primal computation,
but we don't have any partition specs for them!

A simple workaround would be to mark all of them as replicated, but that would
likely tank performance which is why we didn't go with that option. Instead, we use
a newly added XLA option called `allow_spmd_sharding_propagation_to_output` to compile
a throwaway executable that lets us query output sharding that XLA considers convenient
for the computation.

However, there's one more difficulty: XLA's `OpSharding` is much less constrained
than our `PartitionSpec`s. In particular, while `PartitionSpec`s can only represent
"block permutations" of devices (with blocks deliniated by mesh axes), `OpSharding`
allows arbitrary assignment (permutation) of tensor chunks to devices. This means that
not every `OpSharding` has a corresponding `PartitionSpec`, but I did implement a
(somewhat involved) procedure that should recover one whenever it exists.

Unfortunately this makes our support for reverse-mode AD partial, because we might
be unable to handle `OpSharding` returned by XLA. But this will only happen if XLA
actually comes up with sharding specifications on its own. If it merely propagates
the sharding obtained from `PartitionSpec`s into the middle of the computation, then
we should be good. In any case, if we end up seeing failures in this path, we should
consider relaxing `PartitionSpec`s, but that would be a pretty large change, so I decided
to avoid it unless there's no other way.

PiperOrigin-RevId: 399680306
2021-09-29 07:19:55 -07:00
Nicholas Junge
8a68f7761b Change absl flag documentation to mention re.search
The absl help texts for test target discovery mention that targets
will be discovered by `re.match` use. However, in the subsequent
implementation, actually `re.search` is used. This commit changes the
help texts for the `test_targets` and `exclude_test_targets` flags
to correctly mention `re.search` as the discovery algorithm.
2021-09-28 18:42:44 +02:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00