4532 Commits

Author SHA1 Message Date
jax authors
0a84db59ed Merge pull request #6068 from jakevdp:fix-result-type
PiperOrigin-RevId: 363282198
2021-03-16 15:20:40 -07:00
Jake VanderPlas
d6408a4e6a Add extras_require to setup.py 2021-03-16 13:23:46 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Tamas Berghammer
2ea526102d Add new lax.rng_bit_generator primitive
The new primitive provides access to the RngBitGenerator HLO
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator)
2021-03-16 16:30:09 +00:00
jax authors
2d148a331c Merge pull request #6078 from jacobaustin123:master
PiperOrigin-RevId: 363183679
2021-03-16 08:03:23 -07:00
jax authors
25704a048a Merge pull request #6081 from gnecula:jax2tf_examples
PiperOrigin-RevId: 363151280
2021-03-16 04:14:48 -07:00
George Necula
f67aeeadc8 Fix the output directory 2021-03-16 12:01:16 +01:00
jax authors
53d21b2059 Merge pull request #6077 from hawkinsp:jaxlibimport
PiperOrigin-RevId: 363147337
2021-03-16 03:42:02 -07:00
George Necula
840d516203 [jax2tf] Removed more traces of support for batch polymorphism
See issue #6080

* Also cleanup the examples
2021-03-16 11:38:57 +01:00
Jacob Austin
9d28b67022
Fixed two small typos in jax.lax. 2021-03-15 23:26:31 -04:00
jax authors
3da12dbc4f Merge pull request #6048 from skye:debug_nans
PiperOrigin-RevId: 363078153
2021-03-15 18:24:44 -07:00
Peter Hawkins
ee53eeb541 Add helpful message when import jaxlib fails. 2021-03-15 21:07:59 -04:00
Jake VanderPlas
b0c5fba82a BUG: fix jnp.result_type for non-canonical weak types 2021-03-15 14:38:14 -07:00
Peter Hawkins
63c06ef77e [JAX] Add a .weak_type attribute to C++ array objects.
Use .weak_type instead of parsing avals from C++. Inspecting Python objects unnecessarily is slow. In addition we were building a Python bool object that we didn't need to build (`py::cast<py::bool_>` instead of `py::cast<bool>`).

Benchmarks on my workstation:

```
name                                old time/op             new time/op             delta
jit_trivial_dispatch                44.9µs ± 1%             44.3µs ± 0%   -1.37%          (p=0.008 n=5+5)
jit_trivial                         46.2µs ± 0%             45.6µs ± 0%   -1.39%          (p=0.008 n=5+5)
jit_simple_dispatch                 17.7µs ± 2%             16.6µs ± 1%   -6.37%          (p=0.008 n=5+5)
jit_simple                          18.5µs ± 5%             17.3µs ± 1%   -6.54%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_10    26.6µs ± 1%             22.6µs ± 2%  -15.12%          (p=0.008 n=5+5)
jit_simple_many_args_10             27.9µs ± 3%             24.6µs ± 4%  -12.00%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_100    107µs ± 1%               75µs ± 1%  -29.85%          (p=0.008 n=5+5)
jit_simple_many_args_100             108µs ± 1%               76µs ± 0%  -29.66%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000  1.01ms ± 1%             0.69ms ± 2%  -31.72%          (p=0.008 n=5+5)
jit_simple_many_args_1000           1.03ms ± 1%             0.71ms ± 2%  -30.77%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000  2.09ms ± 1%             1.43ms ± 3%  -31.78%          (p=0.008 n=5+5)
jit_simple_many_args_2000           2.08ms ± 1%             1.44ms ± 4%  -30.77%          (p=0.008 n=5+5)
jit_dispatch_without_transfer       1.41ms ± 1%             1.43ms ± 6%     ~             (p=1.000 n=5+5)
jit_dispatch_with_transfer          1.40ms ± 1%             1.40ms ± 1%     ~             (p=1.000 n=5+5)
```

PiperOrigin-RevId: 363002879
2021-03-15 12:30:15 -07:00
Peter Hawkins
9a2a1ada85 [JAX] Enable C++ device arrays by default.
[XLA:Python] Relax constraints on .aval and ._device attributes on C++ buffer objects. The constraints cause more problems than they solve. Switch _device to be a C++ attribute rather than a Python attribute. This avoids some unnecessary Python attribute parsing in the JIT dispatch path.

Change PyBuffer objects to call themselves `DeviceArray` in Python so as not to surprise JAX users.

PiperOrigin-RevId: 362969997
2021-03-15 10:20:22 -07:00
jax authors
80966fe5bf Merge pull request #6018 from jakevdp:conv-elem-type
PiperOrigin-RevId: 362956294
2021-03-15 09:19:52 -07:00
jax authors
3fb6a11a27 Merge pull request #6057 from inailuig:fix-complex-normal
PiperOrigin-RevId: 362930693
2021-03-15 06:56:14 -07:00
jax authors
1ad99d3dbb Merge pull request #5999 from sharadmv:callback-scan
PiperOrigin-RevId: 362898757
2021-03-15 03:10:41 -07:00
Clemens Giuliani
d78fe6b8fb fix the dtype of complex jax.random.normal and add a regression test for it 2021-03-14 22:28:38 +01:00
Skye Wanderman-Milne
c56649aaac Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)

This change doesn't appear to regress performance:

```
---------Benchmark summary for pmap_shard_outputs---------
  nouts    nshards       mean      %std    relative    mean/baseline
-------  ---------  ---------  --------  ----------  ---------------
     10          8   0.105598  5.06671      1               1.00693
    100          8   0.287756  0.870751     2.72502         0.973204
    500          8   1.20119   0.823624    11.3752          0.955185
   1000          8   2.56071   0           24.2497          0.983063
   5000          8  12.909     0          122.247           0.965925
    100          2   0.173727  5.15115      1.64518         0.98918
    100          4   0.207774  3.71411      1.9676          0.955849
    100          8   0.286103  1.60243      2.70937         0.971869
    100        100   2.34168   0           22.1755          0.904475
    100        500  15.9558    0          151.1             1.00483
```

Fixes #6044
2021-03-12 16:22:55 -08:00
Jake VanderPlas
04bf02a4b6 convert_element_type: don't canonicalize old_dtype 2021-03-12 15:26:06 -08:00
jax authors
8d3b4ac2f3 Merge pull request #6028 from jakevdp:transpose
PiperOrigin-RevId: 362590852
2021-03-12 13:38:13 -08:00
jax authors
77c1f313d9 Merge pull request #5966 from mtsokol:jax-numpy-where-keyword
PiperOrigin-RevId: 362565473
2021-03-12 11:35:44 -08:00
Jake VanderPlas
ed4c94497a jnp.array.transpose: support positional axis arguments 2021-03-12 11:16:50 -08:00
Mateusz Sokół
d743aa5803 Added 'where' keyword to 'jnp.{mean, var, std}' 2021-03-12 17:57:17 +01:00
Skye Wanderman-Milne
5cb5056ea7 Suppress gRPC log spam on Cloud TPU. 2021-03-11 22:52:54 +00:00
James Bradbury
72a3036b1a Hotfix for another assertion that's too strict about named shapes
PiperOrigin-RevId: 362164157
2021-03-10 16:09:37 -08:00
jax authors
cf9b77f1de Merge pull request #5998 from zhangqiaorjc:dev_put_count
PiperOrigin-RevId: 362143966
2021-03-10 14:36:55 -08:00
jax authors
61041cb1e3 Merge pull request #6010 from hawkinsp:issue4690
PiperOrigin-RevId: 362120286
2021-03-10 12:49:55 -08:00
Peter Hawkins
62a726d329 Add workaround for SelectAndScatter padding bug on CPU and GPU. 2021-03-10 15:25:32 -05:00
jax authors
c9c89c4820 Merge pull request #5997 from jakevdp:fix-piecewise
PiperOrigin-RevId: 362085337
2021-03-10 10:32:56 -08:00
Skye Wanderman-Milne
c32d1e5aae Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
2021-03-10 09:15:31 -08:00
Peter Hawkins
00349390fe Fix crash returning a Token from a jit computation on GPU.
Calling .numpy_dtype() doesn't work on tokens. But we don't need a numpy dtype here, an XLA dtype works just as well.
2021-03-10 10:18:38 -05:00
Sharad Vikram
ddaef193fe Add scan and while rule for jax.experimental.callback transformation 2021-03-09 19:46:16 -08:00
Peter Hawkins
140c0acbbe Remove the JAX lazy sublanguage.
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.

At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a `jit` computation, we can
   avoid materializing it in its expanded form.

It is not clear that laziness has sufficient power to weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
Qiao Zhang
9577860169 Add jtu.count_device_put for tests to count device_put. 2021-03-09 14:45:01 -08:00
Matthew Johnson
2b9ffb1fb3 make axis_index bind respect dynamic traces 2021-03-09 13:51:12 -08:00
James Bradbury
a8b8246554 add some todos 2021-03-09 13:51:09 -08:00
Roy Frostig
e779ed8299 simplify standard named_shape_rule
Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-03-09 13:48:26 -08:00
James Bradbury
c622422dad [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-03-09 13:48:15 -08:00
James Bradbury
ba0f785a1f allow named axes on while_loop condition aval 2021-03-09 13:45:17 -08:00
Jake VanderPlas
dbdb189de1 jnp.piecewise: support scalar inputs 2021-03-09 13:25:38 -08:00
jax authors
6515b5f676 Merge pull request #5977 from apaszke:xmap-with-control-flow
PiperOrigin-RevId: 361854852
2021-03-09 11:22:18 -08:00
jax authors
f1ba3bcc9d Merge pull request #5990 from jakevdp:fix-power
PiperOrigin-RevId: 361852538
2021-03-09 11:12:39 -08:00
jax authors
b0d14fd28f Merge pull request #5951 from apaszke:revive-all-to-all
PiperOrigin-RevId: 361841703
2021-03-09 10:29:12 -08:00
Jake VanderPlas
0c86c1fd11 jnp.power: fix overflow case for x1=0 2021-03-09 09:36:41 -08:00
Adam Paszke
ec29275d7e Substitute axis names in nested jaxprs
Previously any collectives buried inside control flow would fail to
compile with xmap, because it would not traverse those with its name
substitution. This adds a "catch-all" default substitution rule which
recursively applies to all jaxpr found in the params (at the top level).
2021-03-08 18:11:07 +00:00
Jake VanderPlas
c9d1ded024 jax.random.poisson: fix return value for lam=0 2021-03-08 09:27:11 -08:00
Adam Paszke
2c7c86a4ba Reenable multi-axis all_to_all 2021-03-08 12:45:03 +00:00
Peter Hawkins
2469ad1bb3 Cleanups for laziness. No functional changes intended.
Use None as a trivial lazy expression in more places. Simplify some code.
2021-03-07 11:33:04 -05:00