Skye Wanderman-Milne
7ad51264fd
Update WORKSPACE for jaxlib 0.1.63 release, take 2
...
PiperOrigin-RevId: 363315929
jaxlib-v0.1.63
2021-03-16 18:26:52 -07:00
jax authors
0a84db59ed
Merge pull request #6068 from jakevdp:fix-result-type
...
PiperOrigin-RevId: 363282198
2021-03-16 15:20:40 -07:00
jax authors
6e1cd395e8
Merge pull request #6075 from jakevdp:fix-formatting
...
PiperOrigin-RevId: 363274347
2021-03-16 14:45:24 -07:00
Jake VanderPlas
4d8e9540a4
Sharp Bits: fix formatting of code blocks within list
2021-03-16 14:21:21 -07:00
jax authors
3b7de31c36
Merge pull request #6087 from jakevdp:pypi-extras
...
PiperOrigin-RevId: 363263098
2021-03-16 13:58:15 -07:00
Jake VanderPlas
d6408a4e6a
Add extras_require to setup.py
2021-03-16 13:23:46 -07:00
jax authors
d2d7ecf14b
Merge pull request #6082 from google:minjaxlib
...
PiperOrigin-RevId: 363255561
2021-03-16 13:21:52 -07:00
Peter Hawkins
328930b917
Increase minimum jaxlib version to 0.1.62.
2021-03-16 15:11:36 -04:00
jax authors
d326b077d9
Merge pull request #6086 from hawkinsp:numpy
...
PiperOrigin-RevId: 363239393
2021-03-16 12:10:03 -07:00
jax authors
2bf7dbceec
Merge pull request #5736 from tberghammer:changelist/357690519
...
PiperOrigin-RevId: 363229118
2021-03-16 11:28:46 -07:00
Peter Hawkins
48a7b153c1
Fix test failure with NumPy 1.20.
...
Fixes #6083
2021-03-16 13:46:13 -04:00
jax authors
265a663d88
Merge pull request #6084 from skye:workspace
...
PiperOrigin-RevId: 363205601
2021-03-16 09:51:55 -07:00
Skye Wanderman-Milne
b319d23431
Update WORKSPACE for jaxlib 0.1.63 release
2021-03-16 09:43:22 -07:00
Tamas Berghammer
2ea526102d
Add new lax.rng_bit_generator primitive
...
The new primitive provides access to the RngBitGenerator HLO
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator )
2021-03-16 16:30:09 +00:00
jax authors
2d148a331c
Merge pull request #6078 from jacobaustin123:master
...
PiperOrigin-RevId: 363183679
2021-03-16 08:03:23 -07:00
jax authors
25704a048a
Merge pull request #6081 from gnecula:jax2tf_examples
...
PiperOrigin-RevId: 363151280
2021-03-16 04:14:48 -07:00
George Necula
f67aeeadc8
Fix the output directory
2021-03-16 12:01:16 +01:00
jax authors
53d21b2059
Merge pull request #6077 from hawkinsp:jaxlibimport
...
PiperOrigin-RevId: 363147337
2021-03-16 03:42:02 -07:00
George Necula
840d516203
[jax2tf] Removed more traces of support for batch polymorphism
...
See issue #6080
* Also cleanup the examples
2021-03-16 11:38:57 +01:00
Jacob Austin
9d28b67022
Fixed two small typos in jax.lax.
2021-03-15 23:26:31 -04:00
jax authors
3da12dbc4f
Merge pull request #6048 from skye:debug_nans
...
PiperOrigin-RevId: 363078153
2021-03-15 18:24:44 -07:00
Peter Hawkins
ee53eeb541
Add helpful message when import jaxlib
fails.
2021-03-15 21:07:59 -04:00
jax authors
b3165bb65c
Merge pull request #6064 from ebuehrle:patch-1
...
PiperOrigin-RevId: 363067775
2021-03-15 17:22:27 -07:00
Jake VanderPlas
b0c5fba82a
BUG: fix jnp.result_type for non-canonical weak types
2021-03-15 14:38:14 -07:00
Peter Hawkins
63c06ef77e
[JAX] Add a .weak_type attribute to C++ array objects.
...
Use .weak_type instead of parsing avals from C++. Inspecting Python objects unnecessarily is slow. In addition we were building a Python bool object that we didn't need to build (`py::cast<py::bool_>` instead of `py::cast<bool>`).
Benchmarks on my workstation:
```
name old time/op new time/op delta
jit_trivial_dispatch 44.9µs ± 1% 44.3µs ± 0% -1.37% (p=0.008 n=5+5)
jit_trivial 46.2µs ± 0% 45.6µs ± 0% -1.39% (p=0.008 n=5+5)
jit_simple_dispatch 17.7µs ± 2% 16.6µs ± 1% -6.37% (p=0.008 n=5+5)
jit_simple 18.5µs ± 5% 17.3µs ± 1% -6.54% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_10 26.6µs ± 1% 22.6µs ± 2% -15.12% (p=0.008 n=5+5)
jit_simple_many_args_10 27.9µs ± 3% 24.6µs ± 4% -12.00% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_100 107µs ± 1% 75µs ± 1% -29.85% (p=0.008 n=5+5)
jit_simple_many_args_100 108µs ± 1% 76µs ± 0% -29.66% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000 1.01ms ± 1% 0.69ms ± 2% -31.72% (p=0.008 n=5+5)
jit_simple_many_args_1000 1.03ms ± 1% 0.71ms ± 2% -30.77% (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000 2.09ms ± 1% 1.43ms ± 3% -31.78% (p=0.008 n=5+5)
jit_simple_many_args_2000 2.08ms ± 1% 1.44ms ± 4% -30.77% (p=0.008 n=5+5)
jit_dispatch_without_transfer 1.41ms ± 1% 1.43ms ± 6% ~ (p=1.000 n=5+5)
jit_dispatch_with_transfer 1.40ms ± 1% 1.40ms ± 1% ~ (p=1.000 n=5+5)
```
PiperOrigin-RevId: 363002879
2021-03-15 12:30:15 -07:00
Peter Hawkins
9a2a1ada85
[JAX] Enable C++ device arrays by default.
...
[XLA:Python] Relax constraints on .aval and ._device attributes on C++ buffer objects. The constraints cause more problems than they solve. Switch _device to be a C++ attribute rather than a Python attribute. This avoids some unnecessary Python attribute parsing in the JIT dispatch path.
Change PyBuffer objects to call themselves `DeviceArray` in Python so as not to surprise JAX users.
PiperOrigin-RevId: 362969997
2021-03-15 10:20:22 -07:00
ebuehrle
2e90e72e90
Fix code typo
2021-03-15 17:20:47 +01:00
jax authors
80966fe5bf
Merge pull request #6018 from jakevdp:conv-elem-type
...
PiperOrigin-RevId: 362956294
2021-03-15 09:19:52 -07:00
jax authors
3fb6a11a27
Merge pull request #6057 from inailuig:fix-complex-normal
...
PiperOrigin-RevId: 362930693
2021-03-15 06:56:14 -07:00
jax authors
1ad99d3dbb
Merge pull request #5999 from sharadmv:callback-scan
...
PiperOrigin-RevId: 362898757
2021-03-15 03:10:41 -07:00
Clemens Giuliani
d78fe6b8fb
fix the dtype of complex jax.random.normal and add a regression test for it
2021-03-14 22:28:38 +01:00
Skye Wanderman-Milne
c56649aaac
Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
...
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)
This change doesn't appear to regress performance:
```
---------Benchmark summary for pmap_shard_outputs---------
nouts nshards mean %std relative mean/baseline
------- --------- --------- -------- ---------- ---------------
10 8 0.105598 5.06671 1 1.00693
100 8 0.287756 0.870751 2.72502 0.973204
500 8 1.20119 0.823624 11.3752 0.955185
1000 8 2.56071 0 24.2497 0.983063
5000 8 12.909 0 122.247 0.965925
100 2 0.173727 5.15115 1.64518 0.98918
100 4 0.207774 3.71411 1.9676 0.955849
100 8 0.286103 1.60243 2.70937 0.971869
100 100 2.34168 0 22.1755 0.904475
100 500 15.9558 0 151.1 1.00483
```
Fixes #6044
2021-03-12 16:22:55 -08:00
Jake VanderPlas
04bf02a4b6
convert_element_type: don't canonicalize old_dtype
2021-03-12 15:26:06 -08:00
jax authors
8d3b4ac2f3
Merge pull request #6028 from jakevdp:transpose
...
PiperOrigin-RevId: 362590852
2021-03-12 13:38:13 -08:00
jax authors
77c1f313d9
Merge pull request #5966 from mtsokol:jax-numpy-where-keyword
...
PiperOrigin-RevId: 362565473
2021-03-12 11:35:44 -08:00
jax authors
a42e6533ae
Merge pull request #6029 from jakevdp:timeout-minutes
...
PiperOrigin-RevId: 362561981
2021-03-12 11:19:49 -08:00
Jake VanderPlas
ed4c94497a
jnp.array.transpose: support positional axis arguments
2021-03-12 11:16:50 -08:00
Jake VanderPlas
60dcd0da61
Set reasonable timeouts for github actions jobs
2021-03-12 09:31:43 -08:00
jax authors
ea07d41947
Merge pull request #6041 from tomhennigan:changelist/362481121
...
PiperOrigin-RevId: 362537513
2021-03-12 09:29:21 -08:00
Mateusz Sokół
d743aa5803
Added 'where' keyword to 'jnp.{mean, var, std}'
2021-03-12 17:57:17 +01:00
Tom Hennigan
4f74b3391c
Update README.md
...
Co-authored-by: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
2021-03-12 16:38:29 +00:00
jax authors
9785230e20
Merge pull request #6043 from hawkinsp:pocketfft
...
PiperOrigin-RevId: 362504596
2021-03-12 06:02:50 -08:00
Peter Hawkins
1ed321cbcd
Update PocketFFT version to fix crash due to undersized aligned allocations.
2021-03-12 08:46:01 -05:00
Tom Hennigan
0fd83a30be
Link to DeepMind JAX blog post.
2021-03-12 12:01:20 +00:00
jax authors
c7ebc3ed74
Merge pull request #6032 from skye:grpc_env_var
...
PiperOrigin-RevId: 362389283
2021-03-11 15:22:38 -08:00
jax authors
ee8ecb0a9b
Merge pull request #6015 from sethvargo:patch-1
...
PiperOrigin-RevId: 362387994
2021-03-11 15:15:53 -08:00
Skye Wanderman-Milne
5cb5056ea7
Suppress gRPC log spam on Cloud TPU.
2021-03-11 22:52:54 +00:00
jax authors
077793cd64
Merge pull request #6019 from skye:examples_test
...
PiperOrigin-RevId: 362363477
2021-03-11 13:23:53 -08:00
jax authors
65ee6041cb
Merge pull request #6020 from jamestwebber:patch-1
...
PiperOrigin-RevId: 362360958
2021-03-11 13:11:21 -08:00
James Webber
cf4f44548c
sync markdown
2021-03-11 14:24:41 -05:00