770 Commits

Author SHA1 Message Date
Peter Hawkins
407306293f Update lax documentation to reflect new code organization. 2019-04-15 12:16:14 -04:00
Matthew Johnson
aa5b036d6d misc python performance improvements 2019-04-15 07:45:10 -07:00
Dheeraj Rajaram Reddy
8bcd0afac5 Use lax.sign instead of lax.div(x, lax.abs(x)) 2019-04-14 22:42:52 +05:30
Dheeraj Rajaram Reddy
fd2c0746b1 Implement np.cbrt 2019-04-14 22:34:23 +05:30
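The two entries above suggest how a cube root can be built from a sign and a magnitude so that negative inputs work, and why lax.sign is preferable to lax.div(x, lax.abs(x)) at zero. A minimal sketch along those lines, not the actual implementation from these commits:

```python
import jax.numpy as np  # 2019-era convention in the jax repo
from jax import lax

def cbrt(x):
  # x ** (1/3) alone is nan for negative floats, so split into sign and
  # magnitude. lax.sign(x) is -1, 0, or 1, avoiding the 0/0 that
  # lax.div(x, lax.abs(x)) would produce at x == 0.
  return lax.sign(x) * np.abs(x) ** (1.0 / 3.0)

print(cbrt(np.array([-8.0, 0.0, 27.0])))  # roughly [-2., 0., 3.]
```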
Navneet M Kumar
9cf6096b72 Implement np.diff (#588)
* Added np.diff

* Added np.diff test

* Added nonzerodim shapes
2019-04-13 21:01:46 -07:00
Matthew Johnson
d7096a42c5
make jacrev work w/ complex inputs, update errors (#610)
* make jacrev work w/ complex inputs, update errors

* fix up complex handling in jacfwd and jacrev
2019-04-13 13:22:45 -07:00
Navneet M Kumar
d7f623ca9d Added np.array_equal (#609)
Added test for np.isrealobj, added np.array_equal
2019-04-13 09:53:59 -07:00
Matthew Johnson
f49ab50ec1 expose lax._safe_mul again (c.f. #608) 2019-04-13 08:18:57 -07:00
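`lax._safe_mul` is internal; assuming its semantics are that a zero operand forces an exactly-zero product (so a 0 * inf inside a JVP rule, such as the "safe mul for exp jvp" entry further down, does not turn into nan), a rough sketch of that idea in terms of public ops might look like:

```python
import jax.numpy as np

def safe_mul(x, y):
  # Value-level sketch only: wherever either operand is zero, return an exact
  # zero instead of the ordinary product, so 0 * inf yields 0 rather than nan.
  return np.where((x == 0) | (y == 0), 0.0, x * y)

print(safe_mul(np.array(0.0), np.array(np.inf)))  # 0.0 rather than nan
```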
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00
Peter Hawkins
4f0280fe36
Merge pull request #605 from hawkinsp/master
Copy pmap arguments to device in parallel.
2019-04-12 18:59:12 -04:00
Peter Hawkins
23c3ad17fc Fix indentation. 2019-04-12 17:26:21 -04:00
Peter Hawkins
ac8de3360d Add docstring and TODOs. 2019-04-12 17:11:20 -04:00
Matthew Johnson
a3a8e48e76
Merge pull request #604 from google/issue603
add error checks so that #603 doesn't fail silently
2019-04-12 14:06:01 -07:00
Skye Wanderman-Milne
105e46f379 Factor out control flow from lax.py into lax_control_flow.py.
Also moves control flow tests to lax_control_flow_test.py.
2019-04-12 13:57:18 -07:00
Matthew Johnson
849ea87b33 tree-map the real dtype check in api.py 2019-04-12 13:29:07 -07:00
Peter Hawkins
38d764c438 Add a device_put_many() method that copies multiple tensors to accelerator devices in parallel. Use it to copy pmap arguments. 2019-04-12 16:18:48 -04:00
Matthew Johnson
18671fa027 add error checks so that #603 doesn't fail silently 2019-04-12 12:01:19 -07:00
Matthew Johnson
8120f48b16
Merge pull request #551 from fehiepsi/gamma
Add Gamma sampler
2019-04-12 08:58:59 -07:00
Peter Hawkins
ca0d943999 Test case improvements:
* use numpy.random to select test cases, rather than the stdlib random module. This allows more control over random seeds. Pick a fixed random seed for each test case (sketched below).
* sort types in linalg_test.py so the choice of test cases is deterministic.
* use known_flags=True when doing early parsing of flags from parse_flags_with_absl.
2019-04-12 10:48:11 -04:00
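The test-infrastructure entry above moves random test-case selection from the stdlib random module to numpy.random with a fixed seed per test case. A minimal sketch of that pattern; the helper name and seed derivation are illustrative, not the repo's actual test_util code:

```python
import zlib
import numpy as onp  # "onp" for original numpy, matching the jax codebase

def case_rng(test_name):
  # Derive a fixed seed from the test name so that the "randomly" chosen
  # cases are identical on every run, instead of drawing from the stdlib
  # random module's global state.
  return onp.random.RandomState(zlib.crc32(test_name.encode("utf-8")))

rng = case_rng("testDiff_shape=(3, 4)_dtype=float32")
shape = tuple(rng.randint(1, 5, size=2))    # deterministic "random" shape
x = rng.randn(*shape).astype(onp.float32)   # deterministic test operand
```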
Matthew Johnson
35e340670c
Merge branch 'master' into gamma 2019-04-12 07:15:41 -07:00
Matthew Johnson
7dab5e025f
remove the current lax.scan implementation (#599)
fixes #598
2019-04-11 21:26:06 -07:00
Navneet M Kumar
8a15d57bf3 Adding np.isrealobj #70 (#594)
* Added np.isrealobj

* Added test for np.isrealobj
2019-04-11 10:16:55 -07:00
Matthew Johnson
de2a5f725d add warning; fix typo and bug in kwargs test 2019-04-11 06:58:09 -07:00
Matthew Johnson
2582598294 remove assert that fails on python3 map objects 2019-04-10 22:17:54 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
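The entry above makes jit trace keyword arguments like positional ones instead of treating them as static (which had forced a recompile for every new value). A small illustration of the user-visible effect, as I read the fix:

```python
from jax import jit
import jax.numpy as np

@jit
def scale(x, factor=1.0):
  return factor * x

x = np.arange(3.0)
# With the fix, `factor` is traced like a positional argument, so looping over
# many values does not trigger a fresh compilation per value (previously,
# keyword arguments were treated as static and baked into the compiled code).
for f in (0.5, 1.5, 2.5):
  print(scale(x, factor=f))
```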
Skye Wanderman-Milne
25278ac749 Factor out parallel functionality from lax.py into lax_parallel.py. 2019-04-10 13:02:04 -07:00
Matthew Johnson
7d7dd257b7
Merge pull request #592 from levskaya/tconv2
transposed convolution implementation
2019-04-10 06:44:26 -07:00
Anselm Levskaya
cef4c94c13 finish transposed convolution implementation and tests 2019-04-09 22:59:03 -07:00
Anselm Levskaya
797d411eeb initial transpose conv implementation 2019-04-09 15:06:46 -07:00
Roy Frostig
4c73b18b53 Merge branch 'master' into parallelize 2019-04-09 10:47:18 -07:00
ayir
d42f515c65 Fixed numpy.zeros shape generator error 2019-04-08 19:15:47 +02:00
Matthew Johnson
03c64af2d5 bump version for pypi 2019-04-07 17:50:43 -07:00
Matthew Johnson
f4e141d30e add 'optimize' kwarg to jax.numpy.einsum 2019-04-06 15:26:33 -07:00
Matthew Johnson
1fdaedbccf bump version for pypi 2019-04-06 14:18:10 -07:00
Matthew Johnson
45b3c2fa1f tweak comments 2019-04-06 13:07:27 -07:00
Matthew Johnson
6ec2eb72e5 make np.arange(N) create a lazy constant; add arange tests 2019-04-06 12:52:47 -07:00
Matthew Johnson
1be9abd322 add jax.numpy.einsum_path (fixes #579) 2019-04-06 10:33:18 -07:00
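Between the `optimize` kwarg added to jax.numpy.einsum a few entries above and the jax.numpy.einsum_path helper added here, contraction order can now be inspected and tuned much like in NumPy. A brief usage sketch:

```python
import jax.numpy as np

a = np.ones((8, 16))
b = np.ones((16, 32))
c = np.ones((32, 4))

# einsum_path reports a contraction order without performing the contraction,
# mirroring numpy.einsum_path.
path, info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
print(path)

# The optimize kwarg lets einsum pick a better contraction order than strict
# left-to-right evaluation of the operands.
out = np.einsum('ij,jk,kl->il', a, b, c, optimize='greedy')
print(out.shape)  # (8, 4)
```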
Matthew Johnson
f27844d93e bump version for pypi 2019-04-05 07:53:03 -07:00
Matthew Johnson
4e39876941
Merge pull request #566 from j-towns/jax-random-stax
Use jax.random for stax initialization
2019-04-05 07:45:09 -07:00
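The merged PR above switches stax parameter initialization over to jax.random, so layer init functions take an explicit PRNG key. A small usage sketch under the module path of the time (jax.experimental.stax); treat the details as an approximation of the API at this commit:

```python
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu

# Build a small MLP. After this change, init_fun takes an explicit jax.random
# key instead of drawing from a global numpy RandomState.
init_fun, apply_fun = stax.serial(Dense(64), Relu, Dense(10))

key = random.PRNGKey(0)
out_shape, params = init_fun(key, (-1, 784))

inputs = random.normal(key, (32, 784))
print(apply_fun(params, inputs).shape)  # (32, 10)
```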
Matthew Johnson
054d210a32 fix typo in xla_computation 2019-04-04 17:40:48 -07:00
Matthew Johnson
31e35b204a make np.reshape dispatch to the argument's reshape method
Reshapes should be cheap, but because `np.reshape` would always call
`lax.reshape` regardless of whether it was given a raw ndarray or one of
our DeviceArrays, it would sometimes copy ndarray data into a
DeviceArray. Our general policy is always to copy data to the device
(and lazily leave it there until the host needs it), but this policy
fell down here because of doing a reshape on data before a `pmap`'d
computation: the op-by-op `np.reshape` call put all the data on one
device, then the following `pmap` function had to copy everything back
to the host and then re-distribute it across multiple devices. (Which
logical shards need to go on which device is computation-dependent, so
it's not something we can reliably determine before actually executing
the specific `pmap` function of interest.)

This commit makes a simple change in the `jax.numpy` layer to make
`np.reshape(x, shape)` try calling `x.reshape(shape)`, so that when `x`
is an ndarray it will stay an ndarray (without any transfer). This
change is not in the `lax` layer so that the `lax` policy can stay
simple (always copy to device). We might revise these decisions in the
future, and for now they're just under-the-hood optimizations, with the
ability for a user to directly call `onp` or `lax` if they want to be
careful about where data lives.

This commit also changed `jax.replicate` to replicate data (with
`onp.broadcast_to`, which uses stride tricks instead of allocating more
memory) so that it has a leading axis of size `device_count`. The previous
solution, based on `pmap`ing a function with a lexical closure, caused
re-compilation on every call.
2019-04-04 11:25:23 -07:00
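A minimal sketch of the dispatch pattern the commit message above describes, plus the stride-tricks replication it mentions; this is illustrative, not the actual jax.numpy source:

```python
import numpy as onp   # "original" numpy: host arrays
from jax import lax

def reshape(a, newshape):
  # If the argument already knows how to reshape itself (a host ndarray or a
  # DeviceArray), defer to its method so no device transfer is forced here;
  # otherwise fall back to the lax primitive.
  try:
    return a.reshape(newshape)
  except AttributeError:
    return lax.reshape(a, newshape)

x = onp.arange(12.0)
print(type(reshape(x, (3, 4))))  # still a host numpy.ndarray, no transfer

# The jax.replicate change mentioned above: onp.broadcast_to adds a leading
# axis (of size device_count; 2 here as a stand-in) via stride tricks, so no
# extra memory is allocated for the copies.
replicated = onp.broadcast_to(x, (2,) + x.shape)
print(replicated.shape, replicated.strides)  # (2, 12) with a zero stride
```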
Matthew Johnson
4bdb6825ab
Merge pull request #574 from levskaya/configfix
correctly update jax config.values after absl flag parsing
2019-04-04 08:11:42 -07:00
Anselm Levskaya
116e329e10 correctly update jax config.values after absl flag parsing 2019-04-04 02:09:35 -07:00
Jamie Townsend
49f3f991d4 Use safe mul for exp jvp 2019-04-04 09:12:27 +01:00
Roy Frostig
357dd49635 Merge branch 'master' into parallelize 2019-04-03 15:32:06 -07:00
Roy Frostig
794af8bd55 parallelization rule for lax.select 2019-04-03 15:13:04 -07:00
Roy Frostig
2cec9f97d5 test psplit under a simulated pmap 2019-04-03 14:51:24 -07:00
Roy Frostig
866c21d8ae add a new parallel primitive: psplit_like 2019-04-03 14:50:01 -07:00
Peter Hawkins
a37441aa13
Merge pull request #570 from hawkinsp/master
Improve batching rule for conv_general_dilated
2019-04-03 13:15:47 -07:00
Peter Hawkins
1ae2df5444 Add support for batching conv_general_dilated on the left where lhs_dims[0] != 0 or out_dims[0] != 0. 2019-04-03 12:41:14 -07:00