150 Commits

Author SHA1 Message Date
Bharat123rox
69d12111fc Implemented np.fix 2019-04-30 22:33:25 +05:30
Peter Hawkins
c47cca2058 Perform division in mean using the target dtype, rather than performing a true_divide and then casting back to the correct type. 2019-04-26 15:51:45 -07:00
Dheeraj Rajaram Reddy
8bcd0afac5 Use lax.sign instead of lax.div(x, lax.abs(x)) 2019-04-14 22:42:52 +05:30
Dheeraj Rajaram Reddy
fd2c0746b1 Implement np.cbrt 2019-04-14 22:34:23 +05:30
Navneet M Kumar
9cf6096b72 Implement np.diff (#588)
* Added np.diff

* Added np.diff test

* Added nonzerodim shapes
2019-04-13 21:01:46 -07:00
Navneet M Kumar
d7f623ca9d Added np.array_equal (#609)
Added test for np.isrealobj, added np.array_equal
2019-04-13 09:53:59 -07:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00
Skye Wanderman-Milne
105e46f379 Factor out control flow from lax.py into lax_control_flow.py.
Also moves control flow tests to lax_control_flow_test.py.
2019-04-12 13:57:18 -07:00
Navneet M Kumar
8a15d57bf3 Adding np.isrealobj #70 (#594)
* Added np.isrealobj

* Added test for np.isrealobj
2019-04-11 10:16:55 -07:00
ayir
d42f515c65 Fixed numpy.zeros shape generator error 2019-04-08 19:15:47 +02:00
Matthew Johnson
f4e141d30e add 'optimize' kwarg to jax.numpy.einsum 2019-04-06 15:26:33 -07:00
Matthew Johnson
45b3c2fa1f tweak comments 2019-04-06 13:07:27 -07:00
Matthew Johnson
6ec2eb72e5 make np.arange(N) create lazy const, arange tests 2019-04-06 12:52:47 -07:00
Matthew Johnson
1be9abd322 add jax.numpy.einsum_path (fixes #579) 2019-04-06 10:33:18 -07:00
Matthew Johnson
31e35b204a make np.reshape reflect on argument method
Reshapes should be cheap, but because `np.reshape` would always call
`lax.reshape` regardless of whether it was given a raw ndarray or one of
our DeviceArrays, it would sometimes copy ndarray data into a
DeviceArray. Our general policy is always to copy data to the device
(and lazily leave it there until the host needs it), but this policy
fell down here because of doing a reshape on data before a `pmap`'d
computation: the op-by-op `np.reshape` call put all the data on one
device, then the following `pmap` function had to copy everything back
to the host then re-distribute it to multiple devices. (The location of
what logical shards need to go on which device is computation-dependent,
so it's not something we can reliably do before actually getting to
execute the specific `pmap` function of interest.)

This commit makes a simple change in the `jax.numpy` layer to make
`np.reshape(x, shape)` try calling `x.reshape(shape)`, so that when `x`
is an ndarray it will stay an ndarray (without any transfer). This
change is not in the `lax` layer so that the `lax` policy can stay
simple (always copy to device). We might revise these decisions in the
future, and for now they're just under-the-hood optimizations, with the
ability for a user to directly call `onp` or `lax` if they want to be
careful about where data lives.

This commit also changed `jax.replicate` to replicate (with
`onp.broadcast_to`, which uses stride tricks instead of allocating more
memory) data to have a leading axis of size `device_count`. The previous
solution, based on `pmap`ing a function with a lexical closure, caused
re-compilation on every call.
2019-04-04 11:25:23 -07:00
Peter Hawkins
42c62d0dad Refactor handling of XLA backends.
Use a new xla_client.get_local_backend() method if available, which will be available in a future Jaxlib release.
Use 'cpu', 'gpu' to name platforms instead of 'Host', and 'CUDA'.

Move logic to initialize backends into get_backend() instead of get_xla_client().
Remove xla_bridge.get_xla_client(). Just import xla_client.xla_bridge instead.

Remove _platform_name. Instead, ask the backend for its platform name.
2019-03-29 11:09:56 -04:00
Matthew Johnson
c9aa60102f
Merge pull request #528 from kroq-gar78/cross
Implement cross product and test cases
2019-03-26 07:05:35 -07:00
Roy Frostig
323353ebf4 implement higher-rank cases of lax_numpy.dot in terms of lax.dot_general, avoiding a reshape 2019-03-25 17:31:03 -07:00
Aditya Vaidya
0a5a633d20 Fix dtypes in cross product 2019-03-25 19:29:49 -05:00
Aditya Vaidya
48934dc9d1 Implement cross product and test cases 2019-03-25 18:22:32 -05:00
Matthew Johnson
629c573fd3 handle numpy < 1.14 behavior of isclose 2019-03-22 17:07:10 -07:00
Peter Hawkins
8de992d706 Simplify Gather and Scatter by removing the index_vector_dim.
Use a canonical choice of:
* there is always a vector index dimension
* it is always the last dimension in indices.
2019-03-01 10:34:46 -05:00
Matthew Johnson
c8b9fe23d6 fix #453 in full, more exhaustive tests 2019-02-27 07:50:19 -08:00
Matthew Johnson
02124e31bf fix bug in transpose with order='F' (fixes #453) 2019-02-27 07:42:26 -08:00
Matthew Johnson
acd0150b4e
Merge pull request #426 from google/remove-unused-asarray-case
remove unused jax.numpy.array case (was typo)
2019-02-24 19:13:24 -08:00
Matthew Johnson
43a66701d5 add comment 2019-02-24 19:06:59 -08:00
Matthew Johnson
9815914d74 remove unused jax.numpy.array case (was typo) 2019-02-21 07:55:56 -08:00
Peter Hawkins
b033627f78 Implement np.float_power. 2019-02-21 10:07:33 -05:00
Peter Hawkins
7b0bcbe0a0 Only use binary exponentiation for integer/integer power() calls to avoid gradient problems. 2019-02-21 08:21:11 -05:00
Peter Hawkins
a63a402d7f Implement np.power for integer exponents. 2019-02-20 14:50:16 -05:00
Matthew Johnson
bf4ea4c099 guard against onp.lcm and onp.gcd not existing 2019-02-19 17:28:43 -08:00
Peter Hawkins
7fc4e0237b Implement np.gcd and np.lcm.
Taking the loop primitives out for a spin!
2019-02-19 15:57:22 -05:00
Peter Hawkins
f39292043c Implement np.roll (#70). 2019-02-18 15:52:32 -05:00
Masahiro H
c02fd5903d
Fix typo in comment 2019-02-16 23:29:35 +09:00
Matthew Johnson
2865cfac07 fix shape/dtype promotion order in some numpy funs 2019-02-13 08:59:21 -08:00
Matthew Johnson
bbf33709a6 switch builtin numeric types on six.PY3 2019-02-13 08:31:48 -08:00
Matthew Johnson
8df660e9ea use more _const and _constant_like helpers 2019-02-13 08:25:11 -08:00
Matthew Johnson
ea9c311349 remove 'long' because it's not in py3 2019-02-13 08:15:48 -08:00
Matthew Johnson
9425fa812a add 'long' and 'complex' to pyval promotion logic 2019-02-13 08:14:32 -08:00
Matthew Johnson
cad7db762b improve numpy dtype promo logic on Python scalars 2019-02-13 08:06:37 -08:00
Peter Hawkins
55acfb15e6 Implement np.linalg.norm. 2019-02-07 10:51:55 -05:00
Peter Hawkins
9338d3d704 Implement np.heaviside. 2019-02-06 09:05:53 -05:00
Peter Hawkins
7edd1337f2 Add axis argument to np.stack. Implement np.{dstack,atleast_3d}. 2019-02-06 08:46:11 -05:00
Peter Hawkins
c5433bd892 Implement np.{empty,empty_like,ptp,isreal,iscomplex,sinc,vander,positive}.
Fix bug in definition of `np.imag` for real numbers.
Fix wrong output (pi vs 0) for `np.angle` for negative real numbers. Fix semantics of angle for integers.

Issue #70
2019-02-05 10:29:31 -05:00
Peter Hawkins
84c30f7790
Merge pull request #319 from hawkinsp/numpy
Forward np.{bartlett,blackman,hamming,hanning,kaiser} to numpy.
2019-02-04 21:49:50 -05:00
Matthew Johnson
1a9c945386
Merge pull request #318 from levskaya/master
actually fix nondeterminism in einsum
2019-02-04 18:39:25 -08:00
Peter Hawkins
980ea88bfe Forward np.{bartlett,blackman,hamming,hanning,kaiser} to numpy. 2019-02-04 21:26:58 -05:00
Anselm Levskaya
893cf82898
actually fix the nondeterminism error in einsum batchdims
In the case where front batch_dims are already ordered correctly, fix the batch_names ordering to be correct.
2019-02-04 16:22:39 -08:00
Peter Hawkins
aed00fe335 Add np.{hsplit,vsplit,dsplit,deg2rad,rad2deg,degrees,radians,hypot,reciprocal,product}.
Forward np.{issubsctype,array_str,array_repr} to numpy.
2019-02-04 09:36:30 -05:00
Anselm Levskaya
723e3c46e4
make einsum deterministic, correct.
Fixes a nondeterministic batch-dimension reordering error that was caused by using a python set collection ordering to fix the final output permutations
2019-02-04 03:52:23 -08:00