7282 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
7ad51264fd Update WORKSPACE for jaxlib 0.1.63 release, take 2
PiperOrigin-RevId: 363315929
jaxlib-v0.1.63
2021-03-16 18:26:52 -07:00
jax authors
0a84db59ed Merge pull request #6068 from jakevdp:fix-result-type
PiperOrigin-RevId: 363282198
2021-03-16 15:20:40 -07:00
jax authors
6e1cd395e8 Merge pull request #6075 from jakevdp:fix-formatting
PiperOrigin-RevId: 363274347
2021-03-16 14:45:24 -07:00
Jake VanderPlas
4d8e9540a4 Sharp Bits: fix formatting of code blocks within list 2021-03-16 14:21:21 -07:00
jax authors
3b7de31c36 Merge pull request #6087 from jakevdp:pypi-extras
PiperOrigin-RevId: 363263098
2021-03-16 13:58:15 -07:00
Jake VanderPlas
d6408a4e6a Add extras_require to setup.py 2021-03-16 13:23:46 -07:00
jax authors
d2d7ecf14b Merge pull request #6082 from google:minjaxlib
PiperOrigin-RevId: 363255561
2021-03-16 13:21:52 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
jax authors
d326b077d9 Merge pull request #6086 from hawkinsp:numpy
PiperOrigin-RevId: 363239393
2021-03-16 12:10:03 -07:00
jax authors
2bf7dbceec Merge pull request #5736 from tberghammer:changelist/357690519
PiperOrigin-RevId: 363229118
2021-03-16 11:28:46 -07:00
Peter Hawkins
48a7b153c1 Fix test failure with NumPy 1.20.
Fixes #6083
2021-03-16 13:46:13 -04:00
jax authors
265a663d88 Merge pull request #6084 from skye:workspace
PiperOrigin-RevId: 363205601
2021-03-16 09:51:55 -07:00
Skye Wanderman-Milne
b319d23431 Update WORKSPACE for jaxlib 0.1.63 release 2021-03-16 09:43:22 -07:00
Tamas Berghammer
2ea526102d Add new lax.rng_bit_generator primitive
The new primitive provides access to the RngBitGenerator HLO
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator)
2021-03-16 16:30:09 +00:00
jax authors
2d148a331c Merge pull request #6078 from jacobaustin123:master
PiperOrigin-RevId: 363183679
2021-03-16 08:03:23 -07:00
jax authors
25704a048a Merge pull request #6081 from gnecula:jax2tf_examples
PiperOrigin-RevId: 363151280
2021-03-16 04:14:48 -07:00
George Necula
f67aeeadc8 Fix the output directory 2021-03-16 12:01:16 +01:00
jax authors
53d21b2059 Merge pull request #6077 from hawkinsp:jaxlibimport
PiperOrigin-RevId: 363147337
2021-03-16 03:42:02 -07:00
George Necula
840d516203 [jax2tf] Removed more traces of support for batch polymorphism
See issue #6080

* Also cleanup the examples
2021-03-16 11:38:57 +01:00
Jacob Austin
9d28b67022
Fixed two small typos in jax.lax. 2021-03-15 23:26:31 -04:00
jax authors
3da12dbc4f Merge pull request #6048 from skye:debug_nans
PiperOrigin-RevId: 363078153
2021-03-15 18:24:44 -07:00
Peter Hawkins
ee53eeb541 Add helpful message when import jaxlib fails. 2021-03-15 21:07:59 -04:00
jax authors
b3165bb65c Merge pull request #6064 from ebuehrle:patch-1
PiperOrigin-RevId: 363067775
2021-03-15 17:22:27 -07:00
Jake VanderPlas
b0c5fba82a BUG: fix jnp.result_type for non-canonical weak types 2021-03-15 14:38:14 -07:00
Peter Hawkins
63c06ef77e [JAX] Add a .weak_type attribute to C++ array objects.
Use .weak_type instead of parsing avals from C++. Inspecting Python objects unnecessarily is slow. In addition we were building a Python bool object that we didn't need to build (`py::cast<py::bool_>` instead of `py::cast<bool>`).

Benchmarks on my workstation:

```
name                                old time/op             new time/op             delta
jit_trivial_dispatch                44.9µs ± 1%             44.3µs ± 0%   -1.37%          (p=0.008 n=5+5)
jit_trivial                         46.2µs ± 0%             45.6µs ± 0%   -1.39%          (p=0.008 n=5+5)
jit_simple_dispatch                 17.7µs ± 2%             16.6µs ± 1%   -6.37%          (p=0.008 n=5+5)
jit_simple                          18.5µs ± 5%             17.3µs ± 1%   -6.54%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_10    26.6µs ± 1%             22.6µs ± 2%  -15.12%          (p=0.008 n=5+5)
jit_simple_many_args_10             27.9µs ± 3%             24.6µs ± 4%  -12.00%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_100    107µs ± 1%               75µs ± 1%  -29.85%          (p=0.008 n=5+5)
jit_simple_many_args_100             108µs ± 1%               76µs ± 0%  -29.66%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000  1.01ms ± 1%             0.69ms ± 2%  -31.72%          (p=0.008 n=5+5)
jit_simple_many_args_1000           1.03ms ± 1%             0.71ms ± 2%  -30.77%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000  2.09ms ± 1%             1.43ms ± 3%  -31.78%          (p=0.008 n=5+5)
jit_simple_many_args_2000           2.08ms ± 1%             1.44ms ± 4%  -30.77%          (p=0.008 n=5+5)
jit_dispatch_without_transfer       1.41ms ± 1%             1.43ms ± 6%     ~             (p=1.000 n=5+5)
jit_dispatch_with_transfer          1.40ms ± 1%             1.40ms ± 1%     ~             (p=1.000 n=5+5)
```

PiperOrigin-RevId: 363002879
2021-03-15 12:30:15 -07:00
Peter Hawkins
9a2a1ada85 [JAX] Enable C++ device arrays by default.
[XLA:Python] Relax constraints on .aval and ._device attributes on C++ buffer objects. The constraints cause more problems than they solve. Switch _device to be a C++ attribute rather than a Python attribute. This avoids some unnecessary Python attribute parsing in the JIT dispatch path.

Change PyBuffer objects to call themselves `DeviceArray` in Python so as not to surprise JAX users.

PiperOrigin-RevId: 362969997
2021-03-15 10:20:22 -07:00
ebuehrle
2e90e72e90 Fix code typo 2021-03-15 17:20:47 +01:00
jax authors
80966fe5bf Merge pull request #6018 from jakevdp:conv-elem-type
PiperOrigin-RevId: 362956294
2021-03-15 09:19:52 -07:00
jax authors
3fb6a11a27 Merge pull request #6057 from inailuig:fix-complex-normal
PiperOrigin-RevId: 362930693
2021-03-15 06:56:14 -07:00
jax authors
1ad99d3dbb Merge pull request #5999 from sharadmv:callback-scan
PiperOrigin-RevId: 362898757
2021-03-15 03:10:41 -07:00
Clemens Giuliani
d78fe6b8fb fix the dtype of complex jax.random.normal and add a regression test for it 2021-03-14 22:28:38 +01:00
Skye Wanderman-Milne
c56649aaac Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)

This change doesn't appear to regress performance:

```
---------Benchmark summary for pmap_shard_outputs---------
  nouts    nshards       mean      %std    relative    mean/baseline
-------  ---------  ---------  --------  ----------  ---------------
     10          8   0.105598  5.06671      1               1.00693
    100          8   0.287756  0.870751     2.72502         0.973204
    500          8   1.20119   0.823624    11.3752          0.955185
   1000          8   2.56071   0           24.2497          0.983063
   5000          8  12.909     0          122.247           0.965925
    100          2   0.173727  5.15115      1.64518         0.98918
    100          4   0.207774  3.71411      1.9676          0.955849
    100          8   0.286103  1.60243      2.70937         0.971869
    100        100   2.34168   0           22.1755          0.904475
    100        500  15.9558    0          151.1             1.00483
```

Fixes #6044
2021-03-12 16:22:55 -08:00
Jake VanderPlas
04bf02a4b6 convert_element_type: don't canonicalize old_dtype 2021-03-12 15:26:06 -08:00
jax authors
8d3b4ac2f3 Merge pull request #6028 from jakevdp:transpose
PiperOrigin-RevId: 362590852
2021-03-12 13:38:13 -08:00
jax authors
77c1f313d9 Merge pull request #5966 from mtsokol:jax-numpy-where-keyword
PiperOrigin-RevId: 362565473
2021-03-12 11:35:44 -08:00
jax authors
a42e6533ae Merge pull request #6029 from jakevdp:timeout-minutes
PiperOrigin-RevId: 362561981
2021-03-12 11:19:49 -08:00
Jake VanderPlas
ed4c94497a jnp.array.transpose: support positional axis arguments 2021-03-12 11:16:50 -08:00
Jake VanderPlas
60dcd0da61 Set reasonable timeouts for github actions jobs 2021-03-12 09:31:43 -08:00
jax authors
ea07d41947 Merge pull request #6041 from tomhennigan:changelist/362481121
PiperOrigin-RevId: 362537513
2021-03-12 09:29:21 -08:00
Mateusz Sokół
d743aa5803 Added 'where' keyword to 'jnp.{mean, var, std}' 2021-03-12 17:57:17 +01:00
Tom Hennigan
4f74b3391c
Update README.md
Co-authored-by: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
2021-03-12 16:38:29 +00:00
jax authors
9785230e20 Merge pull request #6043 from hawkinsp:pocketfft
PiperOrigin-RevId: 362504596
2021-03-12 06:02:50 -08:00
Peter Hawkins
1ed321cbcd Update PocketFFT version to fix crash due to undersized aligned allocations. 2021-03-12 08:46:01 -05:00
Tom Hennigan
0fd83a30be Link to DeepMind JAX blog post. 2021-03-12 12:01:20 +00:00
jax authors
c7ebc3ed74 Merge pull request #6032 from skye:grpc_env_var
PiperOrigin-RevId: 362389283
2021-03-11 15:22:38 -08:00
jax authors
ee8ecb0a9b Merge pull request #6015 from sethvargo:patch-1
PiperOrigin-RevId: 362387994
2021-03-11 15:15:53 -08:00
Skye Wanderman-Milne
5cb5056ea7 Suppress gRPC log spam on Cloud TPU. 2021-03-11 22:52:54 +00:00
jax authors
077793cd64 Merge pull request #6019 from skye:examples_test
PiperOrigin-RevId: 362363477
2021-03-11 13:23:53 -08:00
jax authors
65ee6041cb Merge pull request #6020 from jamestwebber:patch-1
PiperOrigin-RevId: 362360958
2021-03-11 13:11:21 -08:00
James Webber
cf4f44548c
sync markdown 2021-03-11 14:24:41 -05:00