16407 Commits

Author SHA1 Message Date
Matthew Johnson
18920337f3 add eye in lax_numpy 2018-12-07 07:34:52 -08:00
Matthew Johnson
50624bd978 rename in_bdims, out_bdims --> in_axes, out_axes 2018-12-07 06:53:29 -08:00
Dougal Maclaurin
ae7df43e9b Fixed bug due to input_shape kwarg not being modified in batching rule for reducers. Fixes b/120595235 2018-12-06 22:45:49 -05:00
Dougal Maclaurin
1627827ac6 Fixed a couple of bugs 2018-12-06 21:47:47 -05:00
Matthew Johnson
bbc92ce6eb
Split out jax and jaxlib packages (#11)
factor out 'jaxlib' as separate package
2018-12-06 21:35:03 -05:00
Dougal Maclaurin
8a3241e36b Merge branch 'master' into quickish-check 2018-12-06 19:02:28 -05:00
Dougal Maclaurin
494c2dbae2 Merge branch 'master' into quickish-check 2018-12-06 18:58:26 -05:00
Dougal Maclaurin
8b88027df0 Number of test cases settable with command-line flag 2018-12-06 18:30:59 -05:00
Peter Hawkins
e03af5df77
Merge pull request #3 from cclauss/patch-2
Undefined name: from ..core import JaxTuple
2018-12-06 18:20:23 -05:00
Dougal Maclaurin
ebc6cd1e03 Step through sizes in test case generation 2018-12-06 18:02:43 -05:00
Dougal Maclaurin
3dbf41f3e6 Generated function tests working with bazel 2018-12-06 17:31:52 -05:00
Peter Hawkins
c1b9eb19ea [JAX] Change semantics of dtype promotion to just call numpy.result_type.
* Enable tests for numpy scalars in lax_numpy_test.py.
* Fix invalid promotion in random.py.
* Split tests for bitwise ops into their own test case and test mixed signedness.
* Add complex64 to the set of types supported by abstractify.
2018-12-06 13:25:42 -05:00
Peter Hawkins
f2795cbdea [JAX] Add a NUMPY_SCALAR_SHAPE constant shape to test_utils.py to allow tests for both numpy scalars as distinct from 0D ndarrays.
Fix mishandling of scalars when passed to np.reshape().

PiperOrigin-RevId: 224326107
2018-12-06 06:47:58 -08:00
Dougal
d7b7200884 Error message for unimplemented numpy functions 2018-12-05 15:55:01 -05:00
Dougal Maclaurin
2e4ff400fa Fixed bug in vjp with constant-zero tangent outputs 2018-12-05 15:55:01 -05:00
Dougal Maclaurin
709cfe905d Set default TF log level to "1" to avoid reporting things like CPU frequency at import time. Also import jax.numpy in __init__.py because it has side effects that set up the infix operator overloading. 2018-12-05 15:55:01 -05:00
Dougal Maclaurin
307d195577 Wrapped static args to jit to be hashed on id. This is conservative but simple and cheap. 2018-12-05 15:55:01 -05:00
Matthew Johnson
2597300c7f source sync
PiperOrigin-RevId: 224157599
2018-12-05 09:36:16 -08:00
Matthew Johnson
3b049e3853 [XLA] separate out an Execute from ExecutePerReplica
[JAX] reduce the creation of XLA Shape protos on every call (which is slow)

PiperOrigin-RevId: 223915944
2018-12-05 09:36:12 -08:00
Dougal Maclaurin
2df36f7510 Made a shim to handle configuration without having absl parse command-line flags.
PiperOrigin-RevId: 223391288
2018-11-29 13:44:54 -08:00
Matthew Johnson
1d2aaad6fe fix handling of symbolic zeros for a few special primitives
PiperOrigin-RevId: 223264329
2018-11-29 12:06:17 -08:00
Dougal Maclaurin
ca2634ea5d source sync
PiperOrigin-RevId: 222923229
2018-11-27 16:51:22 -08:00
Peter Hawkins
f3513a7bfb [JAX] Rewrite OSS build script.
Significant changes:
* Mac OS X support.
* build script is in Python, not shell.
* build configuration is passed via flags, not environment variables.
* build script configures TF itself, and does not require explicitly checking out the TF git repository and running its configure script. Changes the TF dependency in the Bazel workspace to be an http_archive(), rather than a local checkout of TF.
* rather than trying to guess the path for Bazel-generated XLA artifacts, use a sh_binary() to perform installation of the built artifacts in to the JAX source tree. Bazel's runfiles mechanism is the supported route to find build artifacts.
* downloads Bazel in Python and checks its SHA256 before running it, rather than running an untrusted binary from the internet.
* intentionally does not delete the Bazel cache or Bazel after building.

Example of new build interaction:

Building without CUDA on Mac or Linux:
$ cd jax
$ python3 build.py   (or python2 build.py if you want a Python 2 build)

     _   _    __  __
    | | / \   \ \/ /
 _  | |/ _ \   \  /
| |_| / ___ \  /  \
 \___/_/   \_\/_/\_\

Starting local Bazel server and connecting to it...
Bazel binary path: /Users/xyz/bin/bazel
Python binary path: /Library/Frameworks/Python.framework/Versions/3.7/bin/python3
CUDA enabled: no

Building XLA and installing it in the JAX source tree...
...

Example of building with CUDA enabled on Linux:
$ python3 build.py --enable_cuda --cudnn_path=/usr/lib/x86_64-linux-gnu/
... as before, except ...
CUDA enabled: yes
CUDA toolkit path: /usr/local/cuda
CUDNN library path: /usr/lib/x86_64-linux-gnu/
...

PiperOrigin-RevId: 222868835
2018-11-27 16:51:17 -08:00
Sam Schoenholz
326773808b source sync
PiperOrigin-RevId: 222867729
2018-11-27 16:51:15 -08:00
Matthew Johnson
c293f7c875 minor: add @jit to threefry hash function in random.py
PiperOrigin-RevId: 222841601
2018-11-27 16:51:13 -08:00
Peter Hawkins
6001fd219d [JAX] Python 3 fix: xla_client.initialize_platform_name() accepts a string, not bytes, so we shouldn't pass it bytes. (The previous Python 3 change was a bit overzealous and made both changes; apparently this path isn't tested by the tests because they pass an explicit --jax_xla_backend flag..)
PiperOrigin-RevId: 222810811
2018-11-27 16:51:08 -08:00
Roy Frostig
642046e2c0 Bridge to the XRT backend.
PiperOrigin-RevId: 222763810
2018-11-27 16:51:06 -08:00
Matthew Johnson
dbf3c606f5 source sync
PiperOrigin-RevId: 222500675
2018-11-27 16:51:00 -08:00
cclauss
33292ff962
Undefined name: from ..core import JaxTuple
[flake8](http://flake8.pycqa.org) testing of https://github.com/google/jax on Python 3.7.1

$ __flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics__
```
./jax/interpreters/ad.py:189:20: F821 undefined name 'JaxTuple'
        return xt, JaxTuple(map(zeros_like_jaxval, xt))
                   ^
./jax/interpreters/ad.py:196:16: F821 undefined name 'JaxTuple'
        return JaxTuple(map(zeros_like_jaxval, yt)), yt
               ^
2    F821 undefined name 'JaxTuple'
2
```
2018-11-22 17:38:07 +01:00
Matthew Johnson
9101d66b4e source sync
PiperOrigin-RevId: 222487460
2018-11-21 20:32:33 -08:00
Matthew Johnson
8317cc3618 source sync
PiperOrigin-RevId: 222484671
2018-11-21 20:32:33 -08:00
Matthew Johnson
3b3490f406 source sync
PiperOrigin-RevId: 222483357
2018-11-21 20:32:33 -08:00
Matthew Johnson
ab53373665 source sync
PiperOrigin-RevId: 222470141
2018-11-21 20:32:33 -08:00
Matthew Johnson
2ae9a2bc35 source sync
PiperOrigin-RevId: 222461242
2018-11-21 20:32:16 -08:00
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Peter Hawkins
fe4edf2839 source sync
PiperOrigin-RevId: 222449830
2018-11-21 20:22:49 -08:00
Matthew Johnson
e5b76f4fde source sync
PiperOrigin-RevId: 222340967
2018-11-21 20:22:43 -08:00
Peter Hawkins
065bb0baa2 source sync
PiperOrigin-RevId: 222291726
2018-11-21 20:22:41 -08:00
Matthew Johnson
7f546b8c02 source sync
PiperOrigin-RevId: 222175432
2018-11-21 20:22:35 -08:00
Matthew Johnson
25fb9b421d source sync
PiperOrigin-RevId: 222170151
2018-11-21 20:22:33 -08:00
Roy Frostig
a3619ca89d source sync
PiperOrigin-RevId: 222153576
2018-11-21 20:22:30 -08:00
Matthew Johnson
377322d3d4 remove stray file 2018-11-19 21:58:15 -08:00
Matthew Johnson
b2b1e8d70c source sync 2018-11-19 21:24:30 -08:00
Matthew Johnson
50038c07c8 fix build file issues 2018-11-19 20:18:31 -08:00
Roy Frostig
fe11b19e46 source sync 2018-11-19 15:44:16 -08:00
Roy Frostig
3731ca2299 source sync 2018-11-19 15:08:46 -08:00
Roy Frostig
99f98f8e8c source sync 2018-11-19 13:50:57 -08:00
Roy Frostig
24f7f35e16 source sync 2018-11-19 13:33:37 -08:00
Roy Frostig
da2d53ad33 source sync 2018-11-19 13:29:47 -08:00