Dougal Maclaurin
3dbf41f3e6
Generated function tests working with bazel
2018-12-06 17:31:52 -05:00
Dougal Maclaurin
c3374a9d5f
Added build rule for generated_fun_test (formerly quickish_check)
2018-12-06 17:04:00 -05:00
Dougal Maclaurin
29113dd606
Made tests runnable with bazel
2018-12-06 17:00:47 -05:00
Peter Hawkins
c1b9eb19ea
[JAX] Change semantics of dtype promotion to just call numpy.result_type.
...
* Enable tests for numpy scalars in lax_numpy_test.py.
* Fix invalid promotion in random.py.
* Split tests for bitwise ops into their own test case and test mixed signedness.
* Add complex64 to the set of types supported by abstractify.
2018-12-06 13:25:42 -05:00
Matthew Johnson
b4344a07bc
fix typo in WORKSPACE
2018-12-06 07:13:27 -08:00
Matthew Johnson
90afb3a155
update tensorflow release
2018-12-06 07:03:28 -08:00
Dougal Maclaurin
678bcee7ba
Merge branch 'quickish-check' of github.com:google/jax into quickish-check
2018-12-06 09:51:49 -05:00
Peter Hawkins
f2795cbdea
[JAX] Add a NUMPY_SCALAR_SHAPE constant shape to test_utils.py to allow tests for both numpy scalars as distinct from 0D ndarrays.
...
Fix mishandling of scalars when passed to np.reshape().
PiperOrigin-RevId: 224326107
2018-12-06 06:47:58 -08:00
Alex Wiltschko
47ade41368
Adding quickstart notebook, and corresponding gitignore rules
2018-12-06 08:53:44 -05:00
Dougal Maclaurin
727755178d
Added license header
2018-12-05 15:55:01 -05:00
Dougal
d7b7200884
Error message for unimplemented numpy functions
2018-12-05 15:55:01 -05:00
Dougal Maclaurin
2e4ff400fa
Fixed bug in vjp with constant-zero tangent outputs
2018-12-05 15:55:01 -05:00
Dougal Maclaurin
2f44eba01d
Made some subset of vjp/jvp inputs static in quickercheck. Exposing bugs.
2018-12-05 15:55:01 -05:00
Dougal Maclaurin
f1d7ea8972
Added reverse-mode checks
2018-12-05 15:55:01 -05:00
Dougal Maclaurin
6fcee12cec
Added forward derivative checks
2018-12-05 15:55:01 -05:00
Dougal Maclaurin
709cfe905d
Set default TF log level to "1" to avoid reporting things like CPU frequency at import time. Also import jax.numpy in __init__.py because it has side effects that set up the infix operator overloading.
2018-12-05 15:55:01 -05:00
Dougal Maclaurin
307d195577
Wrapped static args to jit to be hashed on id. This is conservative but simple and cheap.
2018-12-05 15:55:01 -05:00
Dougal Maclaurin
5bb8f87e22
First pass at a quickcheck-style property-based tester to test invariants of jit/grad/vmap etc on random functions
2018-12-05 15:55:01 -05:00
Matthew Johnson
2597300c7f
source sync
...
PiperOrigin-RevId: 224157599
2018-12-05 09:36:16 -08:00
Matthew Johnson
3b049e3853
[XLA] separate out an Execute from ExecutePerReplica
...
[JAX] reduce the creation of XLA Shape protos on every call (which is slow)
PiperOrigin-RevId: 223915944
2018-12-05 09:36:12 -08:00
Peter Hawkins
77db9bd556
[JAX] Update XLA in JAX workspace to include optimized computation launch implementation.
...
PiperOrigin-RevId: 223863753
2018-12-05 09:36:08 -08:00
Dougal Maclaurin
ac6cee2157
Added license header
2018-12-05 11:21:58 -05:00
Dougal
585f011d63
Error message for unimplemented numpy functions
2018-12-05 10:01:14 -05:00
Dougal Maclaurin
f5232aaeea
Fixed bug in vjp with constant-zero tangent outputs
2018-12-03 22:24:46 -05:00
Dougal Maclaurin
99023a24fa
Made some subset of vjp/jvp inputs static in quickercheck. Exposing bugs.
2018-12-03 09:52:28 -05:00
Roy Frostig
f5b051b431
Double gpu test shards
...
PiperOrigin-RevId: 223647959
2018-12-02 11:50:46 -08:00
Roy Frostig
20878c76f4
source sync
...
PiperOrigin-RevId: 223530503
2018-12-02 11:50:39 -08:00
Dougal Maclaurin
a5df01abfd
Added reverse-mode checks
2018-11-30 17:14:27 -05:00
Dougal Maclaurin
4715ce52ec
Added forward derivative checks
2018-11-30 16:41:12 -05:00
Dougal Maclaurin
5d5a6bc7ce
Set default TF log level to "1" to avoid reporting things like CPU frequency at import time. Also import jax.numpy in __init__.py because it has side effects that set up the infix operator overloading.
2018-11-30 16:17:54 -05:00
Dougal Maclaurin
0411beb357
Wrapped static args to jit to be hashed on id. This is conservative but simple and cheap.
2018-11-30 16:16:28 -05:00
Dougal Maclaurin
aab9faa41d
First pass at a quickcheck-style property-based tester to test invariants of jit/grad/vmap etc on random functions
2018-11-30 15:22:02 -05:00
Dougal Maclaurin
2df36f7510
Made a shim to handle configuration without having absl parse command-line flags.
...
PiperOrigin-RevId: 223391288
2018-11-29 13:44:54 -08:00
Matthew Johnson
1d2aaad6fe
fix handling of symbolic zeros for a few special primitives
...
PiperOrigin-RevId: 223264329
2018-11-29 12:06:17 -08:00
Matthew Johnson
aefd8f9eb5
add contributing.md
...
PiperOrigin-RevId: 223085949
2018-11-27 17:18:37 -08:00
Roy Frostig
14844224fc
source sync
...
PiperOrigin-RevId: 223081582
2018-11-27 16:51:28 -08:00
Matthew Johnson
0ea98501aa
source sync
...
PiperOrigin-RevId: 223080639
2018-11-27 16:51:26 -08:00
Peter Hawkins
93ac03ea08
[JAX] Update TensorFlow release to 0b6ed4887a
to pick up fixes to the XLA Mac OS X build.
...
PiperOrigin-RevId: 223048563
2018-11-27 16:51:24 -08:00
Dougal Maclaurin
ca2634ea5d
source sync
...
PiperOrigin-RevId: 222923229
2018-11-27 16:51:22 -08:00
Peter Hawkins
599ea38175
[JAX] Explicitly use /bin/bash in install_xla_in_source_tree.sh.
...
It turns out the script uses some bash-isms not supported by the /bin/sh (dash) shell on a Debian machine.
PiperOrigin-RevId: 222884583
2018-11-27 16:51:19 -08:00
Peter Hawkins
f3513a7bfb
[JAX] Rewrite OSS build script.
...
Significant changes:
* Mac OS X support.
* build script is in Python, not shell.
* build configuration is passed via flags, not environment variables.
* build script configures TF itself, and does not require explicitly checking out the TF git repository and running its configure script. Changes the TF dependency in the Bazel workspace to be an http_archive(), rather than a local checkout of TF.
* rather than trying to guess the path for Bazel-generated XLA artifacts, use a sh_binary() to perform installation of the built artifacts in to the JAX source tree. Bazel's runfiles mechanism is the supported route to find build artifacts.
* downloads Bazel in Python and checks its SHA256 before running it, rather than running an untrusted binary from the internet.
* intentionally does not delete the Bazel cache or Bazel after building.
Example of new build interaction:
Building without CUDA on Mac or Linux:
$ cd jax
$ python3 build.py (or python2 build.py if you want a Python 2 build)
_ _ __ __
| | / \ \ \/ /
_ | |/ _ \ \ /
| |_| / ___ \ / \
\___/_/ \_\/_/\_\
Starting local Bazel server and connecting to it...
Bazel binary path: /Users/xyz/bin/bazel
Python binary path: /Library/Frameworks/Python.framework/Versions/3.7/bin/python3
CUDA enabled: no
Building XLA and installing it in the JAX source tree...
...
Example of building with CUDA enabled on Linux:
$ python3 build.py --enable_cuda --cudnn_path=/usr/lib/x86_64-linux-gnu/
... as before, except ...
CUDA enabled: yes
CUDA toolkit path: /usr/local/cuda
CUDNN library path: /usr/lib/x86_64-linux-gnu/
...
PiperOrigin-RevId: 222868835
2018-11-27 16:51:17 -08:00
Sam Schoenholz
326773808b
source sync
...
PiperOrigin-RevId: 222867729
2018-11-27 16:51:15 -08:00
Matthew Johnson
c293f7c875
minor: add @jit to threefry hash function in random.py
...
PiperOrigin-RevId: 222841601
2018-11-27 16:51:13 -08:00
Matthew Johnson
d318c68827
source sync
...
PiperOrigin-RevId: 222822439
2018-11-27 16:51:11 -08:00
Peter Hawkins
6001fd219d
[JAX] Python 3 fix: xla_client.initialize_platform_name() accepts a string, not bytes, so we shouldn't pass it bytes. (The previous Python 3 change was a bit overzealous and made both changes; apparently this path isn't tested by the tests because they pass an explicit --jax_xla_backend flag..)
...
PiperOrigin-RevId: 222810811
2018-11-27 16:51:08 -08:00
Roy Frostig
642046e2c0
Bridge to the XRT backend.
...
PiperOrigin-RevId: 222763810
2018-11-27 16:51:06 -08:00
Matthew Johnson
dbf3c606f5
source sync
...
PiperOrigin-RevId: 222500675
2018-11-27 16:51:00 -08:00
cclauss
c3c64138a4
Undefined name: from six.moves import xrange
...
[flake8](http://flake8.pycqa.org ) testing of https://github.com/google/jax on Python 3.7.1
$ __flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics__
```
./examples/resnet50.py:124:12: F821 undefined name 'xrange'
for i in xrange(num_steps):
^
1 F821 undefined name 'xrange'
1
```
2018-11-22 17:50:50 +01:00
cclauss
33292ff962
Undefined name: from ..core import JaxTuple
...
[flake8](http://flake8.pycqa.org ) testing of https://github.com/google/jax on Python 3.7.1
$ __flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics__
```
./jax/interpreters/ad.py:189:20: F821 undefined name 'JaxTuple'
return xt, JaxTuple(map(zeros_like_jaxval, xt))
^
./jax/interpreters/ad.py:196:16: F821 undefined name 'JaxTuple'
return JaxTuple(map(zeros_like_jaxval, yt)), yt
^
2 F821 undefined name 'JaxTuple'
2
```
2018-11-22 17:38:07 +01:00
cclauss
2b995217b0
Explicit tuples are not valid function parameters in Python 3
...
[flake8](http://flake8.pycqa.org ) testing of https://github.com/google/jax on Python 3.7.1
$ __flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics__
```
./examples/mnist_vae.py:113:21: E999 SyntaxError: invalid syntax
def body_fun(i, (rng, opt_state, images)):
^
1 E999 SyntaxError: invalid syntax
1
```
2018-11-22 17:29:47 +01:00