319 Commits

Author SHA1 Message Date
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
George Necula
0beef34d25 [jax2tf] Fix conversion for argmin/argmax; add conversion for reduce
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax.
Those ops behave differently than JAX when the inputs contain NaN and Inf. Added
a few test cases in primitive_harness to expose the failures.

In order to implement an accurate conversion of argmin/argmax, we need to use the
XLA Reduce op.

Also tightened the shape checks for lax.argmin and lax.argmax, to ensure they are
not used with an empty reduced dimension. E.g., if the axis=-1, previously we got
an internal error:
```
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
```
PiperOrigin-RevId: 384182794
2021-07-12 01:11:42 -07:00
George Necula
1f946ad51e Fix grad of conv 0D.
This bug was introduced in #6345, and was not caught by existing tests.
Add a reproducing test.
2021-07-11 10:46:42 +03:00
Jake VanderPlas
27fc797a67 Add dynamic slice U8 index test 2021-06-23 13:29:15 -07:00
Peter Hawkins
4406c77086 [XLA] Improve the accuracy of complex log1p close to 0.
https://github.com/google/jax/issues/7004

PiperOrigin-RevId: 380188368
2021-06-18 08:12:23 -07:00
George Necula
a4e74a269c Disable testing dot_general with preferred_element_type on GPU.
Due to a XLA bug we get non-deterministic NaN on GPU.
This fixes flakiness in lax_test.py

PiperOrigin-RevId: 380179946
2021-06-18 07:12:54 -07:00
Peter Hawkins
6cc440b79d Fix handling of NaNs in GPU argmax translation rule. 2021-05-18 11:35:54 -04:00
Peter Hawkins
1350d21881 Add regression test for #5728.
This issue appears to have been fixed by jaxlib 0.1.66.
2021-05-12 13:45:16 -04:00
Peter Hawkins
4f9eb64c04 [XLA] Fix incomplete gamma functions where x is infinity.
Issue https://github.com/google/jax/issues/6535

PiperOrigin-RevId: 373122943
2021-05-11 04:31:41 -07:00
Lukas Geiger
f7f42694d9 Add support for preferred_element_type arg in convolutions 2021-04-22 10:29:31 +02:00
jax authors
6cc4bb0476 Merge pull request #6420 from apaszke:faster-tests
PiperOrigin-RevId: 368519866
2021-04-14 15:24:06 -07:00
Adam Paszke
e0357283a6 Speed up test generation 7x
There are a few test cases that generate millions of configurations,
only to have a handful of them selected by `cases_form_list`. I've
found all tests that spend over 100ms in case generation and
converted them to a new "test sampler" approach. The result: test
generation time drops from 15s to around 2s. Doesn't sound like much,
but I expect that we all run tests many times daily, so it seems like a
useful thing to have.

The rough idea is that the sampling generators get parameterized by a
sampler function that should be applied to the range of every `for` loop.
This allows us to sample runs of the generator through different
configurations by restricting each loop to a smaller subset. Right now
we always narrow it down to a single randomly selected instance. But,
we still retain the possibility of adding exhaustive testing in the
future, which can be achieved by passing in an identity sampling
function that wouldn't modify any loop ranges.
2021-04-14 15:58:05 +00:00
Peter Hawkins
2a3da097ed Change scatter shape test to use eval_shape rather than instantiating concrete arrays.
Reduces space usage during testing.
2021-04-13 14:10:11 -04:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Peter Hawkins
a54a5e59ee Remove backward compatibility code paths for jaxlib < 0.1.65.
Fix up a few version comments.
2021-04-09 15:39:38 -04:00
Peter Hawkins
1800e43e70 [JAX] Make lax_reference private by moving it to jax._src.lax_reference.
PiperOrigin-RevId: 367441557
2021-04-08 09:06:14 -07:00
George Necula
cbe5f54cca Added support for lax.pad, and more error checking 2021-04-08 10:42:38 +03:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Jake VanderPlas
33fde77bb1 Add lax.reduce_precision() 2021-04-05 09:54:14 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
jax authors
2022141b13 Merge pull request #6208 from majnemer:int-conv
PiperOrigin-RevId: 365544250
2021-03-29 04:20:05 -07:00
Matthew Johnson
8547c71bfd simplify public lax.convert_element_type api
Specifically:
1. don't expose weak_type in the public api, as it's jax-internal
2. don't make new_dtype optional, which could make bugs easier

This change keeps the public API simpler, and also makes
convert_element_type match the ConvertElementType HLO. As an internal
API we can call lax._convert_element_type just like before.
2021-03-28 10:32:02 -07:00
David Majnemer
7defa05009 Allow integer/boolean convolutions 2021-03-24 23:20:30 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Matthew Johnson
57d5c6af5f add clz primitive 2021-03-19 22:54:36 -07:00
Jake VanderPlas
6041e1b73c lax_test: increase test coverage 2021-03-18 07:41:03 -07:00
Jake VanderPlas
5f51d4fb1d Make lax._const() work for non-canonical dtypes 2021-03-17 13:07:53 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Tamas Berghammer
2ea526102d Add new lax.rng_bit_generator primitive
The new primitive provides access to the RngBitGenerator HLO
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator)
2021-03-16 16:30:09 +00:00
Jake VanderPlas
04bf02a4b6 convert_element_type: don't canonicalize old_dtype 2021-03-12 15:26:06 -08:00
James Bradbury
c622422dad [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-03-09 13:48:15 -08:00
Matthew Johnson
9b18135b6e Rollback of #5702 due to internal breakage.
PiperOrigin-RevId: 357943850
2021-02-17 07:32:09 -08:00
James Bradbury
fb160b8afd [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-02-16 15:46:14 -08:00
Matthew Johnson
268493bae8 specialize standard_primitive back to single-out 2021-02-12 10:30:46 -08:00
Jake VanderPlas
41b7a0f770 Re-land #4850 weak types change 2021-02-09 09:07:52 -08:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Skye Wanderman-Milne
997e6efa9c Improve error message when a reduction function returns an invalid return type.
Fixes #5536

Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-01-28 15:36:15 -08:00
Jonathan Malmaud
c0c4843b93 Add support for 'preferred_element_type' keyword arg in dot and dot_general.
XLA recently added support for this parameter to xops.DotGeneral. It's an optional parameter that controls the accumulation type used by the dot operation.

This is useful for eg quantized ANNs, where you might want to do matrix multiples with int8 tensors and get back an int32 tensor instead of an int8 tensor that suffers from severe overflow. Note it's not sufficient in this case to cast the inputs to 'dot' to int32 beforehand and rely on the default output dtype inference, since backend devices might have an accelerated path for int8*int8->int32 matmuls and we want that explicitly represented in the XLA.

Note because XLA still doesn't support integer dots on the CPU backend, that use case can't tested with a CPU-only test at the moment.
2021-01-19 18:56:46 +00:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
ca7d94646f Fix broadcast_in_dim transpose issue. 2020-12-29 12:58:07 -08:00
David Majnemer
242c7d5d8d Enable more tests on TPU
PiperOrigin-RevId: 349323353
2020-12-28 13:21:00 -08:00
Matthew Johnson
26c91aa185 remove unused import 2020-12-23 11:12:53 -08:00
Matthew Johnson
3dee321fb8 rollback of #4850 2020-12-23 11:01:58 -08:00
Jake VanderPlas
98f88152cd Fix bug in primitive_computation
fixes #4672
2020-12-17 12:51:35 -08:00
Jake VanderPlas
c820dbf44c Propagate weak_types in remaining lax primitives 2020-12-16 09:53:30 -08:00
Jake VanderPlas
7b097340bf Fix lax.convert_element_type() with dtype=None 2020-12-10 14:14:36 -08:00
Jake VanderPlas
c63097bc90 Add weak_type argument to convert_element_type_p 2020-12-10 11:10:21 -08:00
Jake VanderPlas
8a00b4e0ee lax.convert_element_type: always return DeviceArray 2020-12-07 09:10:34 -08:00
jax authors
9d70ef26d7 Merge pull request #5109 from minoring:fix-conv-stride-shape-rule
PiperOrigin-RevId: 346047432
2020-12-07 02:25:29 -08:00
George Necula
f51db5cd75
Update lax_test.py 2020-12-06 15:39:30 +02:00