James Bradbury
fc3a3ac075
Merge pull request #1349 from xysun/issue1233
...
Support for non-scalar np.repeat arguments (fixes #1233 )
2019-09-20 09:27:47 -07:00
xysun
6fd938d8a0
no more nested loops of dynamic_update_slice!
2019-09-20 10:25:16 +01:00
Nikita Kitaev
0e7ea7e379
Reduce memory usage for argmax ( fixes #1330 )
2019-09-16 14:30:28 -07:00
Peter Hawkins
45a02f39f0
Temporarily remove jit
decorator on gather/scatter ops.
2019-09-16 13:57:07 -07:00
xysun
0f03afc5e3
python2.7 compatible
2019-09-14 22:32:45 +01:00
xysun
3c2003b24e
jit works!
2019-09-14 21:57:46 +01:00
xysun
c8eb5f657d
pretty much works
2019-09-14 21:24:28 +01:00
Peter Hawkins
b7b5328526
Merge pull request #1346 from hawkinsp/master
...
Add a DeviceArray._unstack() method that unpacks an array along its m…
2019-09-13 11:16:01 -07:00
Peter Hawkins
ae329dcdf4
Add a DeviceArray._unstack() method that unpacks an array along its major dimension.
...
Use it to implement pxla's shard_arg method for DeviceArrays; this is faster than slicing out each element one by one.
2019-09-13 13:47:09 -04:00
Peter Hawkins
723456bc82
Slice objects are unhashable, so unpack them into tuples for forming static arguments.
2019-09-13 13:39:39 -04:00
Peter Hawkins
5ffddc182e
JIT-compile index and index-update expressions.
...
Improves the performance of indexing in op-by-op mode.
2019-09-13 10:37:41 -04:00
Matthew Johnson
c52027691b
jax.numpy.stack and concatenate work on array args
...
fixes #1271
2019-09-02 07:55:25 -07:00
Matthew Johnson
cac042c34a
move asinh/acosh/atanh to lax_numpy.py only
2019-08-31 22:39:51 -07:00
Matthew Johnson
1cd37bd977
reset jax_numpy_rank_promotion to "allow" default
2019-08-27 11:21:50 -07:00
fehiepsi
28df8a666b
cast float64 to canonical dtype in np.cov
2019-08-26 01:32:18 -04:00
Matthew Johnson
3c2a73592c
improve rank promotion warning, add doc page
2019-08-25 14:28:53 -07:00
Matthew Johnson
afe21bafa4
address reviewer comments
2019-08-24 12:34:44 -07:00
Matthew Johnson
e90457d737
add dtype warnings to array-creation routines
...
fixes #1230
2019-08-24 08:19:05 -07:00
Matthew Johnson
d700716e19
add option to disable rank-promotion broadcasting
...
fixes #1236
2019-08-23 18:13:18 -07:00
Matthew Johnson
a8e0c2559c
Merge pull request #1224 from google/no-more-tuples
...
No more tuples
2019-08-21 14:21:36 -07:00
Matthew Johnson
b702f8de3e
De-tuplify the rest of the core
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Brian Patton
451ff2d694
Cope with old numpy lacking axis arg
2019-08-21 13:40:50 -04:00
David Majnemer
18beaab200
Use a faster, numerically more faithful, approach to logaddexp
...
We can use log1p and fewer instances of exp to compute the same result.
2019-08-16 22:36:00 -07:00
Brian Patton
85e5d63423
Fix an exception caused by cached()
hashing
...
I *think* the issue was that one of the elements in shape was a `DeviceArray`.
File "jax/random.py", line 717, in gamma
return _gamma(key, a, shape, dtype)
File "jax/api.py", line 151, in f_jitted
device_assignment=device_assignment)
File "jax/core.py", line 672, in call_bind
ans = primitive.impl(f, *args, **params)
File "jax/interpreters/xla.py", line 667, in _xla_call_impl
*map(abstractify, args))
File "jax/linear_util.py", line 213, in cached_fun
ans, f_prev = cached_fun_body(f, args)
File "jax/linear_util.py", line 210, in cached_fun_body
return call(f, *args), f
File "jax/interpreters/xla.py", line 679, in _xla_callable
jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File "jax/linear_util.py", line 161, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "jax/random.py", line 725, in _gamma
a = np.broadcast_to(a, shape)
File "jax/numpy/lax_numpy.py", line 821, in broadcast_to
lax.broadcast_shapes(shape, _shape(arr)) # error checking
File "jax/interpreters/xla.py", line 623, in __hash__
raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.
2019-08-16 16:49:45 -05:00
Brian Patton
48f64ec863
Avoid a TypeError when reps is or contains a ndarray
...
Something along the lines of `TypeError: multiply only accepts scalar or ndarray, but got a list.`
2019-08-16 12:31:53 -05:00
Peter Hawkins
9fcf526608
Merge pull request #1193 from hawkinsp/axis
...
Use _canonicalize_axis everywhere to canonicalize axes, rather than s…
2019-08-16 08:33:36 -04:00
Peter Hawkins
f4fde04760
Fix Python 2 compatibility problem.
2019-08-15 21:20:21 -04:00
Peter Hawkins
efe98e2b81
Use _canonicalize_axis everywhere to canonicalize axes, rather than sometimes mod or %.
...
_canonicalize_axis has behavior more faithful to NumPy, rejecting out of range axes.
2019-08-15 20:56:56 -04:00
Peter Hawkins
1ba13e1b82
Consistently return JAX arrays instead of Numpy-classic arrays from jax.numpy.
...
Avoids surprising behavior that sometimes arises when mixing the two.
2019-08-15 20:25:32 -04:00
Peter Hawkins
6d357fe884
Use select
instead of rem
to handle index wraparound.
2019-08-15 16:41:05 -04:00
Brian Patton
4b693777aa
Ensure reps is a tuple (allows list or other iterable)
2019-08-15 11:26:25 -05:00
David Majnemer
079ded4062
Use lax.rem less often in remainder
2019-08-14 12:00:04 -07:00
Peter Hawkins
a8ddf071bd
Add test case for concurrent device_get and device_put calls.
...
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
James Bradbury
d0c9f45349
fix jax.numpy reduction init_val for bools
2019-08-03 21:27:06 -07:00
Matthew Johnson
3168006f4a
fix np.var dtype bug
2019-08-02 11:26:17 -07:00
Matthew Johnson
1f3b4ae97e
Merge pull request #1091 from fehiepsi/tril
...
expose tril_indices, triu_indices similar to diag_indices
2019-08-01 20:58:22 -07:00
Matthew Johnson
fd98f957a9
Merge pull request #1088 from fehiepsi/median
...
Add numpy.median and support ddof for numpy.var
2019-08-01 20:57:28 -07:00
fehiepsi
7a5aecea31
expose tril_indices triu_indices
2019-08-01 17:35:36 -04:00
fehiepsi
45c5bd4fba
support ddof for var
2019-08-01 16:20:08 -04:00
fehiepsi
98152d9d07
add numpy.median
2019-08-01 14:19:41 -04:00
Jamie Townsend
47f9eedb60
Correct jax.numpy.pad signature
2019-08-01 15:44:23 +01:00
Peter Hawkins
a350191331
Merge pull request #1074 from hawkinsp/pytree
...
Add C++ implementation of Pytree logic.
2019-07-30 20:50:09 -04:00
wyjw
4dcae5debf
Update lax_numpy.py
2019-07-29 22:56:30 -04:00
Peter Hawkins
510a9167c5
Add C++ implementation of pytree logic.
...
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
wyjw
b89e5a7ac0
shape_c variable taken out
2019-07-29 13:11:43 -04:00
wyjw
5487c784d6
added shape check
2019-07-29 13:05:48 -04:00
wyjw
a06883d91f
Made changes based on review.
2019-07-29 11:53:40 -04:00
wyjw
a87627b57b
Revert "Corrcoef"
2019-07-29 11:24:05 -04:00
twnly
3b6edbbe2f
corrcoef
2019-07-29 11:12:02 -04:00
twnly
d9b7c5fa39
made changes to corrcoef
2019-07-29 11:06:08 -04:00