5891 Commits

Author SHA1 Message Date
jax authors
92819f7b4b Merge pull request #8143 from jakevdp:union1d-fill-value
PiperOrigin-RevId: 402238900
2021-10-11 02:08:30 -07:00
George Necula
a75fb371f2 [jax2tf] Improved handling of getitem for shape polymorphism
* give an error for NumPy indexing with slices when the elements
  of the slices are not constant. This check existed, but was
  throing an error when the elements are dimension polynomials.
* give an error for NumPy indexing with slices when the dimension
  size is not constant.
* Improvements in the handling of enable_xla=False for shape
  polymorphism.
* Added test cases for the above.
2021-10-11 09:14:57 +02:00
George Karpenkov
f2aef25fba Use variadic reduce on GPU for argmax/argmin
PiperOrigin-RevId: 401923051
2021-10-08 21:58:56 -07:00
Jake VanderPlas
a4241a2aa3 jnp.union1d: add optional fill_value argument 2021-10-08 15:18:25 -07:00
jax authors
da1caf5d6d Merge pull request #7997 from google:aot
PiperOrigin-RevId: 401837898
2021-10-08 13:03:40 -07:00
Roy Frostig
0c75f52fa8 ahead-of-time lowering and compilation for jit 2021-10-08 10:54:45 -07:00
Roy Frostig
75468c7495 factor out jit input preparation 2021-10-08 10:54:45 -07:00
Jake VanderPlas
486aac949a jnp.array: handle raw device buffers 2021-10-08 10:41:43 -07:00
jax authors
dd5df5a562 Merge pull request #8121 from jakevdp:unique-fill-value
PiperOrigin-RevId: 401785306
2021-10-08 09:10:05 -07:00
George Necula
3938018228 Applied review suggestsions 2021-10-08 10:11:31 +02:00
jax authors
f9bead4b75 Merge pull request #8135 from google:default-rng
PiperOrigin-RevId: 401692501
2021-10-07 23:06:34 -07:00
jax authors
2028087a04 Merge pull request #8137 from mattjj:shaped-array-len
PiperOrigin-RevId: 401690179
2021-10-07 22:51:02 -07:00
Matthew Johnson
482e41d796 remove ShapedArray.__len__
It was confusing to overload, since we sometimes think of avals like
shapes paired with dtypes, and in that case len(aval) should perhaps be
like len(aval.shape). The only place where this behavior was relied on
was sparse/ops.py.
2021-10-07 22:04:16 -07:00
Roy Frostig
98d245ebb4 add a config setting to control the default PRNG implementation
Also add explicit seeding functions for each PRNG implementation.
2021-10-07 21:22:40 -07:00
Matthew Johnson
022cb8c0fc rbg_split and rbg_fold_in: use vmap for fewer HLOs 2021-10-07 21:19:06 -07:00
jax authors
b002bc178e Merge pull request #8123 from mattjj:fix-rng-bit-generator-again
PiperOrigin-RevId: 401673628
2021-10-07 20:39:24 -07:00
Matthew Johnson
634d252bb3 improvements to RBG PRNG
1. factor out rbg_prng_impl and unsafe_rbg_prng_impl. the former uses
   threefry2x32 for split and fold_in, while the latter uses untested
   heuristics based on calling rng_bit_generator itself as a kind of
   hash function
2. for unsafe_rbg_prng_impl's split and fold_in, generate longer
   sequences from rng_bit_generator (10x iterations) which may be useful on
   some backends
3. for unsafe_rbg_prng_impl, actually apply rng_bit_generator as our
   'hash function' in fold_in

Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
2021-10-07 18:59:13 -07:00
jax authors
efa5edfd39 Merge pull request #8091 from eelregit:check-eq-float0
PiperOrigin-RevId: 401494757
2021-10-07 06:23:59 -07:00
jax authors
c877f8bcd2 Merge pull request #8094 from sracaniere:patch-2
PiperOrigin-RevId: 401485340
2021-10-07 05:18:04 -07:00
jax authors
b0b60d6293 Merge pull request #8116 from jakevdp:at-docs
PiperOrigin-RevId: 401447782
2021-10-07 01:16:53 -07:00
jax authors
8f0589f085 Merge pull request #8117 from LenaMartens:changelist/400933831
PiperOrigin-RevId: 401418826
2021-10-06 21:54:20 -07:00
Jake VanderPlas
0b93c46c71 jnp.unique: add fill_value for when size is not None 2021-10-06 16:28:36 -07:00
jax authors
3c117fd6ed Merge pull request #8090 from skye:compilation_cache_xla_flags
PiperOrigin-RevId: 401343120
2021-10-06 14:42:18 -07:00
Skye Wanderman-Milne
f939048e62 Include XLA_FLAGS in persistent compilation cache key.
This is to prevent false cache hits when the compiler behavior is
changed via flags. Flags known to not affect the compiled executable
(e.g. dumping HLO) are excluded from the key.

Note that any XLA flags with arguments should use = and not a space,
e.g. `--xla_flag=value`, not `--xla_flag value`. I believe this is
already a requirement of ABSL flags in general, but I'm not 100% sure.

Also note that this doesn't currently support XLA flags specified via
--flagfile. Please file a feature request if this is needed.
2021-10-06 14:11:40 -07:00
Jake VanderPlas
359f55eb65 [sparse] fix type checking issue 2021-10-06 12:08:22 -07:00
jax authors
1f5bfde66b Merge pull request #8102 from jakevdp:refactor-bcoo
PiperOrigin-RevId: 401294831
2021-10-06 11:24:17 -07:00
Jake VanderPlas
bba04e0985 Document extra arguments to jnp.ndarray.at[] 2021-10-06 11:22:00 -07:00
Jean-Baptiste Lespiau
803b83ee15 Enable C++ pmap.
On CPU:
```
name                                     old cpu/op  new cpu/op  delta
pmap_trivial_2_devices                    128µs ± 6%    14µs ± 3%  -89.06%  (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           212µs ± 2%    35µs ± 1%  -83.54%  (p=0.008 n=5+5)
pmap_trivial_8_devices                    215µs ± 1%    40µs ± 4%  -81.31%  (p=0.008 n=5+5)
pmap_simple_2_devices                     123µs ± 5%    15µs ± 6%  -87.70%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            211µs ± 3%    35µs ± 2%  -83.24%  (p=0.008 n=5+5)
pmap_simple_8_devices                     217µs ± 5%    40µs ± 2%  -81.68%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices_100_args  5.42ms ± 7%  0.52ms ± 2%  -90.44%  (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.5ms ±21%  17.5ms ±37%  -34.18%  (p=0.008 n=5+5)
sda_index_1                              7.45µs ± 6%  7.53µs ± 6%     ~     (p=0.222 n=5+5)
sda_index_2                              14.1µs ± 1%  14.3µs ± 4%     ~     (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%  56.9µs ± 4%     ~     (p=0.310 n=5+5)

name                                     old time/op             new time/op             delta
pmap_trivial_2_devices                    136µs ± 8%               19µs ± 3%  -86.08%          (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           216µs ± 3%               39µs ± 2%  -81.94%          (p=0.008 n=5+5)
pmap_trivial_8_devices                    219µs ± 2%               49µs ±38%  -77.67%          (p=0.008 n=5+5)
pmap_simple_2_devices                     130µs ± 5%               20µs ± 5%  -84.38%          (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            216µs ± 3%               39µs ± 5%  -81.71%          (p=0.008 n=5+5)
pmap_simple_8_devices                     221µs ± 6%               43µs ± 1%  -80.41%          (p=0.016 n=5+4)
pmap_simple_dispatch_8_devices_100_args  5.52ms ± 7%             0.59ms ± 2%  -89.28%          (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.6ms ±21%             17.6ms ±37%  -34.04%          (p=0.008 n=5+5)
sda_index_1                              7.48µs ± 8%             7.53µs ± 6%     ~             (p=0.310 n=5+5)
sda_index_2                              14.1µs ± 1%             14.3µs ± 4%     ~             (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%             56.9µs ± 4%     ~             (p=0.310 n=5+5)
```

PiperOrigin-RevId: 401274089
2021-10-06 10:08:28 -07:00
Lena Martens
342948dcc4 Add batching rule for rng_bit_generator. 2021-10-06 17:50:07 +01:00
Jake VanderPlas
c2dd90e3a0 [sparse] Factor BCOO-related routines into a separate submodule 2021-10-06 08:06:18 -07:00
jax authors
e679302dfa Merge pull request #8109 from jakevdp:sparsify-spdot
PiperOrigin-RevId: 401218394
2021-10-06 05:25:08 -07:00
jax authors
9d6dcfb0b3 Merge pull request #8110 from google:all-gather-hlo
PiperOrigin-RevId: 401218393
2021-10-06 05:24:49 -07:00
jax authors
de6705be89 Merge pull request #8107 from jakevdp:copy
PiperOrigin-RevId: 401218373
2021-10-06 05:20:42 -07:00
jax authors
26b70db194 Merge pull request #8106 from google:fix-rng-bit-generator-bit-manipulation
PiperOrigin-RevId: 401143001
2021-10-05 21:03:08 -07:00
Matthew Johnson
8d641a1d1b fix rng_bit_generator translation rule
Fix two issues:
1. bit packing was incorrect
2. output key had different shape from input key

Co-authored-by: Lena Martens <lenamartens@google.com>
2021-10-05 20:19:57 -07:00
Jake VanderPlas
e22c232c31 jnp.array: replace host round-trip with on-device copy 2021-10-05 20:10:57 -07:00
Peter Hawkins
104a46594b Add DeprecationWarnings to jax.ops.index_... operators.
Remove uses of index_... in Common Gotchas notebook.
2021-10-05 20:47:22 -04:00
James Bradbury
86022adf2f
Use all_gather+reduce_scatter HLOs on TPU
The all-gather and reduce-scatter HLOs were wired through for GPU but not TPU, but they should also work there (and be more performant than the all-reduce based fallback).
2021-10-05 17:17:48 -07:00
Jake VanderPlas
3a440d665f [sparse] add sparsify support for sparse-sparse matmul 2021-10-05 16:45:48 -07:00
Jake VanderPlas
a0c2fe0dfd Remove duplicate import 2021-10-05 15:47:09 -07:00
Peter Hawkins
6a284ce5ad Fix incorrect EllipsisType reference for Python 3.10 2021-10-05 16:16:59 -04:00
Peter Hawkins
29447ed261 Fixes for Python 3.10.
With these changes, the JAX test suite passes on Python 3.10.
2021-10-05 15:25:28 -04:00
Peter Hawkins
42e0d4e5f5 Remove jax._src.util.partialmethod.
Use functools.partialmethod instead, which has existed since Python 3.4. The JAX partialmethod doesn't work correctly in Python 3.10.

Issue #8097
2021-10-05 12:12:41 -04:00
Yin Li
ece532556b
Simplify shape comparison with numpy assert 2021-10-05 11:30:19 -04:00
Adam Paszke
1520fa261f Change semantics of positional shapes depending on pjit nesting
One interesting angle of pjit is that it is a boundary between a multi-controller
world in which `.shape` attributes of all arrays (and avals) correspond to
slices of data that are internall to a given process, and a single-controller
world where `.shape` refers to the global array constructed by concatenating
per-device chunks. I haven't fully appreciated this previously which made
pjit nests (and xmaps in pjits) to incorrectly increase shapes with every
level of nesting, when only the outermost call that should make the change.

We now keep track of a flag that determines whether the positional shape of
avals we see is global or local in any given context. Note that sizes of named
axes have been and still are global only.

PiperOrigin-RevId: 400949756
2021-10-05 04:31:32 -07:00
sracaniere
8bb8c8e994
Update doc for eigh.
Mention `eigenvectors` before `eigenvalues` in doc to match the order of returned values.
2021-10-05 09:19:23 +01:00
Matthew Johnson
7ec797b794 lower rng_bit_generator using a BitcastConvertType 2021-10-04 21:47:12 -07:00
Yin Li
5d675220c0
Add float0 support to equality and closeness check 2021-10-04 21:32:57 -04:00
jax authors
50141e667f Merge pull request #8088 from skye:fix_formatting
PiperOrigin-RevId: 400862458
2021-10-04 18:21:29 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00