259 Commits

Author SHA1 Message Date
James Bradbury
fc3a3ac075
Merge pull request #1349 from xysun/issue1233
Support for non-scalar np.repeat arguments (fixes #1233)
2019-09-20 09:27:47 -07:00
xysun
6fd938d8a0 no more nested loops of dynamic_update_slice! 2019-09-20 10:25:16 +01:00
Nikita Kitaev
0e7ea7e379 Reduce memory usage for argmax (fixes #1330) 2019-09-16 14:30:28 -07:00
Peter Hawkins
45a02f39f0 Temporarily remove jit decorator on gather/scatter ops. 2019-09-16 13:57:07 -07:00
xysun
0f03afc5e3 python2.7 compatible 2019-09-14 22:32:45 +01:00
xysun
3c2003b24e jit works! 2019-09-14 21:57:46 +01:00
xysun
c8eb5f657d pretty much works 2019-09-14 21:24:28 +01:00
Peter Hawkins
b7b5328526
Merge pull request #1346 from hawkinsp/master
Add a DeviceArray._unstack() method that unpacks an array along its m…
2019-09-13 11:16:01 -07:00
Peter Hawkins
ae329dcdf4 Add a DeviceArray._unstack() method that unpacks an array along its major dimension.
Use it to implement pxla's shard_arg method for DeviceArrays; this is faster than slicing out each element one by one.
2019-09-13 13:47:09 -04:00
Peter Hawkins
723456bc82 Slice objects are unhashable, so unpack them into tuples for forming static arguments. 2019-09-13 13:39:39 -04:00
Peter Hawkins
5ffddc182e JIT-compile index and index-update expressions.
Improves the performance of indexing in op-by-op mode.
2019-09-13 10:37:41 -04:00
Matthew Johnson
c52027691b jax.numpy.stack and concatenate work on array args
fixes #1271
2019-09-02 07:55:25 -07:00
Matthew Johnson
cac042c34a move asinh/acosh/atanh to lax_numpy.py only 2019-08-31 22:39:51 -07:00
Matthew Johnson
1cd37bd977
reset jax_numpy_rank_promotion to "allow" default 2019-08-27 11:21:50 -07:00
fehiepsi
28df8a666b cast float64 to canonical dtype in np.cov 2019-08-26 01:32:18 -04:00
Matthew Johnson
3c2a73592c improve rank promotion warning, add doc page 2019-08-25 14:28:53 -07:00
Matthew Johnson
afe21bafa4 address reviewer comments 2019-08-24 12:34:44 -07:00
Matthew Johnson
e90457d737 add dtype warnings to array-creation routines
fixes #1230
2019-08-24 08:19:05 -07:00
Matthew Johnson
d700716e19 add option to disable rank-promotion broadcasting
fixes #1236
2019-08-23 18:13:18 -07:00
Matthew Johnson
a8e0c2559c
Merge pull request #1224 from google/no-more-tuples
No more tuples
2019-08-21 14:21:36 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Brian Patton
451ff2d694
Cope with old numpy lacking axis arg 2019-08-21 13:40:50 -04:00
David Majnemer
18beaab200 Use a faster, numerically more faithful, approach to logaddexp
We can use log1p and fewer instances of exp to compute the same result.
2019-08-16 22:36:00 -07:00
Brian Patton
85e5d63423
Fix an exception caused by cached() hashing
I *think* the issue was that one of the elements in shape was a `DeviceArray`.

  File "jax/random.py", line 717, in gamma
    return _gamma(key, a, shape, dtype)
  File "jax/api.py", line 151, in f_jitted
    device_assignment=device_assignment)
  File "jax/core.py", line 672, in call_bind
    ans = primitive.impl(f, *args, **params)
  File "jax/interpreters/xla.py", line 667, in _xla_call_impl
    *map(abstractify, args))
  File "jax/linear_util.py", line 213, in cached_fun
    ans, f_prev = cached_fun_body(f, args)
  File "jax/linear_util.py", line 210, in cached_fun_body
    return call(f, *args), f
  File "jax/interpreters/xla.py", line 679, in _xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "jax/linear_util.py", line 161, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "jax/random.py", line 725, in _gamma
    a = np.broadcast_to(a, shape)
  File "jax/numpy/lax_numpy.py", line 821, in broadcast_to
    lax.broadcast_shapes(shape, _shape(arr))  # error checking
  File "jax/interpreters/xla.py", line 623, in __hash__
    raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.
2019-08-16 16:49:45 -05:00
Brian Patton
48f64ec863
Avoid a TypeError when reps is or contains a ndarray
Something along the lines of `TypeError: multiply only accepts scalar or ndarray, but got a list.`
2019-08-16 12:31:53 -05:00
Peter Hawkins
9fcf526608
Merge pull request #1193 from hawkinsp/axis
Use _canonicalize_axis everywhere to canonicalize axes, rather than s…
2019-08-16 08:33:36 -04:00
Peter Hawkins
f4fde04760 Fix Python 2 compatibility problem. 2019-08-15 21:20:21 -04:00
Peter Hawkins
efe98e2b81 Use _canonicalize_axis everywhere to canonicalize axes, rather than sometimes mod or %.
_canonicalize_axis has behavior more faithful to NumPy, rejecting out of range axes.
2019-08-15 20:56:56 -04:00
Peter Hawkins
1ba13e1b82 Consistently return JAX arrays instead of Numpy-classic arrays from jax.numpy.
Avoids surprising behavior that sometimes arises when mixing the two.
2019-08-15 20:25:32 -04:00
Peter Hawkins
6d357fe884 Use select instead of rem to handle index wraparound. 2019-08-15 16:41:05 -04:00
Brian Patton
4b693777aa
Ensure reps is a tuple (allows list or other iterable) 2019-08-15 11:26:25 -05:00
David Majnemer
079ded4062 Use lax.rem less often in remainder 2019-08-14 12:00:04 -07:00
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
James Bradbury
d0c9f45349 fix jax.numpy reduction init_val for bools 2019-08-03 21:27:06 -07:00
Matthew Johnson
3168006f4a fix np.var dtype bug 2019-08-02 11:26:17 -07:00
Matthew Johnson
1f3b4ae97e
Merge pull request #1091 from fehiepsi/tril
expose tril_indices, triu_indices similar to diag_indices
2019-08-01 20:58:22 -07:00
Matthew Johnson
fd98f957a9
Merge pull request #1088 from fehiepsi/median
Add numpy.median and support ddof for numpy.var
2019-08-01 20:57:28 -07:00
fehiepsi
7a5aecea31 expose tril_indices triu_indices 2019-08-01 17:35:36 -04:00
fehiepsi
45c5bd4fba support ddof for var 2019-08-01 16:20:08 -04:00
fehiepsi
98152d9d07 add numpy.median 2019-08-01 14:19:41 -04:00
Jamie Townsend
47f9eedb60
Correct jax.numpy.pad signature 2019-08-01 15:44:23 +01:00
Peter Hawkins
a350191331
Merge pull request #1074 from hawkinsp/pytree
Add C++ implementation of Pytree logic.
2019-07-30 20:50:09 -04:00
wyjw
4dcae5debf
Update lax_numpy.py 2019-07-29 22:56:30 -04:00
Peter Hawkins
510a9167c5 Add C++ implementation of pytree logic.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
wyjw
b89e5a7ac0
shape_c variable taken out 2019-07-29 13:11:43 -04:00
wyjw
5487c784d6
added shape check 2019-07-29 13:05:48 -04:00
wyjw
a06883d91f
Made changes based on review. 2019-07-29 11:53:40 -04:00
wyjw
a87627b57b
Revert "Corrcoef" 2019-07-29 11:24:05 -04:00
twnly
3b6edbbe2f corrcoef 2019-07-29 11:12:02 -04:00
twnly
d9b7c5fa39 made changes to corrcoef 2019-07-29 11:06:08 -04:00