jax authors
0a84db59ed
Merge pull request #6068 from jakevdp:fix-result-type
...
PiperOrigin-RevId: 363282198
2021-03-16 15:20:40 -07:00
Jake VanderPlas
d6408a4e6a
Add extras_require to setup.py
2021-03-16 13:23:46 -07:00
Peter Hawkins
328930b917
Increase minimum jaxlib version to 0.1.62.
2021-03-16 15:11:36 -04:00
Tamas Berghammer
2ea526102d
Add new lax.rng_bit_generator primitive
...
The new primitive provides access to the RngBitGenerator HLO
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator )
2021-03-16 16:30:09 +00:00
jax authors
2d148a331c
Merge pull request #6078 from jacobaustin123:master
...
PiperOrigin-RevId: 363183679
2021-03-16 08:03:23 -07:00
jax authors
25704a048a
Merge pull request #6081 from gnecula:jax2tf_examples
...
PiperOrigin-RevId: 363151280
2021-03-16 04:14:48 -07:00
George Necula
f67aeeadc8
Fix the output directory
2021-03-16 12:01:16 +01:00
jax authors
53d21b2059
Merge pull request #6077 from hawkinsp:jaxlibimport
...
PiperOrigin-RevId: 363147337
2021-03-16 03:42:02 -07:00
George Necula
840d516203
[jax2tf] Removed more traces of support for batch polymorphism
...
See issue #6080
* Also cleanup the examples
2021-03-16 11:38:57 +01:00
Jacob Austin
9d28b67022
Fixed two small typos in jax.lax.
2021-03-15 23:26:31 -04:00
jax authors
3da12dbc4f
Merge pull request #6048 from skye:debug_nans
...
PiperOrigin-RevId: 363078153
2021-03-15 18:24:44 -07:00
Peter Hawkins
ee53eeb541
Add helpful message when import jaxlib
fails.
2021-03-15 21:07:59 -04:00
Jake VanderPlas
b0c5fba82a
BUG: fix jnp.result_type for non-canonical weak types
2021-03-15 14:38:14 -07:00
Peter Hawkins
63c06ef77e
[JAX] Add a .weak_type attribute to C++ array objects.
...
Use .weak_type instead of parsing avals from C++. Inspecting Python objects unnecessarily is slow. In addition we were building a Python bool object that we didn't need to build (`py::cast<py::bool_>` instead of `py::cast<bool>`).
Benchmarks on my workstation:
```
name old time/op new time/op delta
jit_trivial_dispatch 44.9µs ± 1% 44.3µs ± 0% -1.37% (p=0.008 n=5+5)
jit_trivial 46.2µs ± 0% 45.6µs ± 0% -1.39% (p=0.008 n=5+5)
jit_simple_dispatch 17.7µs ± 2% 16.6µs ± 1% -6.37% (p=0.008 n=5+5)
jit_simple 18.5µs ± 5% 17.3µs ± 1% -6.54% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_10 26.6µs ± 1% 22.6µs ± 2% -15.12% (p=0.008 n=5+5)
jit_simple_many_args_10 27.9µs ± 3% 24.6µs ± 4% -12.00% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_100 107µs ± 1% 75µs ± 1% -29.85% (p=0.008 n=5+5)
jit_simple_many_args_100 108µs ± 1% 76µs ± 0% -29.66% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000 1.01ms ± 1% 0.69ms ± 2% -31.72% (p=0.008 n=5+5)
jit_simple_many_args_1000 1.03ms ± 1% 0.71ms ± 2% -30.77% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000 2.09ms ± 1% 1.43ms ± 3% -31.78% (p=0.008 n=5+5)
jit_simple_many_args_2000 2.08ms ± 1% 1.44ms ± 4% -30.77% (p=0.008 n=5+5)
jit_dispatch_without_transfer 1.41ms ± 1% 1.43ms ± 6% ~ (p=1.000 n=5+5)
jit_dispatch_with_transfer 1.40ms ± 1% 1.40ms ± 1% ~ (p=1.000 n=5+5)
```
PiperOrigin-RevId: 363002879
2021-03-15 12:30:15 -07:00
Peter Hawkins
9a2a1ada85
[JAX] Enable C++ device arrays by default.
...
[XLA:Python] Relax constraints on .aval and ._device attributes on C++ buffer objects. The constraints cause more problems than they solve. Switch _device to be a C++ attribute rather than a Python attribute. This avoids some unnecessary Python attribute parsing in the JIT dispatch path.
Change PyBuffer objects to call themselves `DeviceArray` in Python so as not to surprise JAX users.
PiperOrigin-RevId: 362969997
2021-03-15 10:20:22 -07:00
jax authors
80966fe5bf
Merge pull request #6018 from jakevdp:conv-elem-type
...
PiperOrigin-RevId: 362956294
2021-03-15 09:19:52 -07:00
jax authors
3fb6a11a27
Merge pull request #6057 from inailuig:fix-complex-normal
...
PiperOrigin-RevId: 362930693
2021-03-15 06:56:14 -07:00
jax authors
1ad99d3dbb
Merge pull request #5999 from sharadmv:callback-scan
...
PiperOrigin-RevId: 362898757
2021-03-15 03:10:41 -07:00
Clemens Giuliani
d78fe6b8fb
fix the dtype of complex jax.random.normal and add a regression test for it
2021-03-14 22:28:38 +01:00
Skye Wanderman-Milne
c56649aaac
Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
...
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)
This change doesn't appear to regress performance:
```
---------Benchmark summary for pmap_shard_outputs---------
nouts nshards mean %std relative mean/baseline
------- --------- --------- -------- ---------- ---------------
10 8 0.105598 5.06671 1 1.00693
100 8 0.287756 0.870751 2.72502 0.973204
500 8 1.20119 0.823624 11.3752 0.955185
1000 8 2.56071 0 24.2497 0.983063
5000 8 12.909 0 122.247 0.965925
100 2 0.173727 5.15115 1.64518 0.98918
100 4 0.207774 3.71411 1.9676 0.955849
100 8 0.286103 1.60243 2.70937 0.971869
100 100 2.34168 0 22.1755 0.904475
100 500 15.9558 0 151.1 1.00483
```
Fixes #6044
2021-03-12 16:22:55 -08:00
Jake VanderPlas
04bf02a4b6
convert_element_type: don't canonicalize old_dtype
2021-03-12 15:26:06 -08:00
jax authors
8d3b4ac2f3
Merge pull request #6028 from jakevdp:transpose
...
PiperOrigin-RevId: 362590852
2021-03-12 13:38:13 -08:00
jax authors
77c1f313d9
Merge pull request #5966 from mtsokol:jax-numpy-where-keyword
...
PiperOrigin-RevId: 362565473
2021-03-12 11:35:44 -08:00
Jake VanderPlas
ed4c94497a
jnp.array.transpose: support positional axis arguments
2021-03-12 11:16:50 -08:00
Mateusz Sokół
d743aa5803
Added 'where' keyword to 'jnp.{mean, var, std}'
2021-03-12 17:57:17 +01:00
Skye Wanderman-Milne
5cb5056ea7
Suppress gRPC log spam on Cloud TPU.
2021-03-11 22:52:54 +00:00
James Bradbury
72a3036b1a
Hotfix for another assertion that's too strict about named shapes
...
PiperOrigin-RevId: 362164157
2021-03-10 16:09:37 -08:00
jax authors
cf9b77f1de
Merge pull request #5998 from zhangqiaorjc:dev_put_count
...
PiperOrigin-RevId: 362143966
2021-03-10 14:36:55 -08:00
jax authors
61041cb1e3
Merge pull request #6010 from hawkinsp:issue4690
...
PiperOrigin-RevId: 362120286
2021-03-10 12:49:55 -08:00
Peter Hawkins
62a726d329
Add workaround for SelectAndScatter padding bug on CPU and GPU.
2021-03-10 15:25:32 -05:00
jax authors
c9c89c4820
Merge pull request #5997 from jakevdp:fix-piecewise
...
PiperOrigin-RevId: 362085337
2021-03-10 10:32:56 -08:00
Skye Wanderman-Milne
c32d1e5aae
Automatically initialize Cloud TPU topology env vars if running on a Cloud TPU VM.
...
This removes the need to manually set these env vars when running on a Cloud TPU pod slice.
2021-03-10 09:15:31 -08:00
Peter Hawkins
00349390fe
Fix crash returning a Token from a jit computation on GPU.
...
Calling .numpy_dtype() doesn't work on tokens. But we don't need a numpy dtype here, an XLA dtype works just as well.
2021-03-10 10:18:38 -05:00
Sharad Vikram
ddaef193fe
Add scan and while rule for jax.experimental.callback transformation
2021-03-09 19:46:16 -08:00
Peter Hawkins
140c0acbbe
Remove the JAX lazy sublanguage.
...
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.
At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a `jit` computation, we can
avoid materializing it in its expanded form.
It is not clear that laziness has sufficient power to weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
Qiao Zhang
9577860169
Add jtu.count_device_put for tests to count device_put.
2021-03-09 14:45:01 -08:00
Matthew Johnson
2b9ffb1fb3
make axis_index bind respect dynamic traces
2021-03-09 13:51:12 -08:00
James Bradbury
a8b8246554
add some todos
2021-03-09 13:51:09 -08:00
Roy Frostig
e779ed8299
simplify standard named_shape_rule
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-03-09 13:48:26 -08:00
James Bradbury
c622422dad
[avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules
2021-03-09 13:48:15 -08:00
James Bradbury
ba0f785a1f
allow named axes on while_loop condition aval
2021-03-09 13:45:17 -08:00
Jake VanderPlas
dbdb189de1
jnp.piecewise: support scalar inputs
2021-03-09 13:25:38 -08:00
jax authors
6515b5f676
Merge pull request #5977 from apaszke:xmap-with-control-flow
...
PiperOrigin-RevId: 361854852
2021-03-09 11:22:18 -08:00
jax authors
f1ba3bcc9d
Merge pull request #5990 from jakevdp:fix-power
...
PiperOrigin-RevId: 361852538
2021-03-09 11:12:39 -08:00
jax authors
b0d14fd28f
Merge pull request #5951 from apaszke:revive-all-to-all
...
PiperOrigin-RevId: 361841703
2021-03-09 10:29:12 -08:00
Jake VanderPlas
0c86c1fd11
jnp.power: fix overflow case for x1=0
2021-03-09 09:36:41 -08:00
Adam Paszke
ec29275d7e
Substitute axis names in nested jaxprs
...
Previously any collectives buried inside control flow would fail to
compile with xmap, because it would not traverse those with its name
substitution. This adds a "catch-all" default substitution rule which
recursively applies to all jaxpr found in the params (at the top level).
2021-03-08 18:11:07 +00:00
Jake VanderPlas
c9d1ded024
jax.random.poisson: fix return value for lam=0
2021-03-08 09:27:11 -08:00
Adam Paszke
2c7c86a4ba
Reenable multi-axis all_to_all
2021-03-08 12:45:03 +00:00
Peter Hawkins
2469ad1bb3
Cleanups for laziness. No functional changes intended.
...
Use None as a trivial lazy expression in more places. Simplify some code.
2021-03-07 11:33:04 -05:00