280 Commits

Author SHA1 Message Date
joaogui1
f5abdafa82 Merge branch 'master' of https://github.com/google/jax into better-documentation 2019-10-29 15:53:18 -03:00
joaogui1
a0cf482636 Adds new functionality to wraps 2019-10-29 15:50:50 -03:00
Peter Hawkins
0d667d2727
Add tests for float16 support in lax_test.py. (#1553)
* Add tests for float16 support in lax_test.py.

Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.

* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.

* Add float16 testing to lax_numpy_test.py as well.

* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.

* Relax some test tolerances further.

* Further relax test tolerances.

* Another tolerance relaxation.

* Use decorator for the upcast to fp32 for computation pattern.

Relax test tolerance for float_power.
2019-10-22 19:53:59 -04:00
Peter Hawkins
2bf799b63a
Fix numpy version check that fails for development numpy versions. (#1540)
Numpy versions may contain strings if not a release build. Only look at the two major entries to avoid an exception.
2019-10-21 10:05:59 -04:00
Matthew Johnson
39e09b867d
Merge pull request #1524 from google/issue1521
broadcast arguments in jax.numpy.take_along_axis
2019-10-18 16:17:19 -07:00
Matthew Johnson
a0352f3969 fix up broadcasting in take_along_axis 2019-10-18 22:50:24 +00:00
Matthew Johnson
aa0692d307 improve broadcast_to, add error checks (fixes #1522) 2019-10-17 23:23:08 +00:00
Matthew Johnson
cc137ced4d broadcast arguments in take_along_axis, fixes #1521 2019-10-17 22:38:28 +00:00
Stephan Hoyer
d338449ed5
Use collections.abc.Sequence in favor of collections.Sequence (#1504)
* Use collections.abc.Sequence in favor of collections.Sequence

The later will be removed in Python 3.8, which is due out any day now!
(There is currently a warning that appears when importing lax_numpy.)

* restore collections import
2019-10-14 13:48:56 -07:00
Skye Wanderman-Milne
d99851af34 Revert "Revert "Add a pylintrc to make it easier to use linter (#1442)""
This reverts commit 54807b42addba538cb0c1f18d7a5c2d08a952821.
2019-10-08 14:39:36 -07:00
Skye Wanderman-Milne
54807b42ad Revert "Add a pylintrc to make it easier to use linter (#1442)"
This reverts commit a0bb2c0ea452975be76e0ba2c6055f5be4439aa3.

Temporarily reverting this to see if it's causing the github workflow failures.
2019-10-08 14:28:14 -07:00
joao guilherme
a0bb2c0ea4 Add a pylintrc to make it easier to use linter (#1442) 2019-10-04 18:19:31 -07:00
joaogui1
d21efd3cc7 Fixes the parameters descriptions 2019-10-03 11:04:09 -03:00
Skye Wanderman-Milne
226c9e9cd1 nanmean fix 2019-09-26 17:10:49 -07:00
Matthew Johnson
762b602f33
Merge pull request #1394 from j-towns/fix-scatter-caching
Ensure all ops get cache hits on second op-by-op mode call
2019-09-26 06:48:42 -07:00
Jamie Townsend
d2d0576892 Ensure cache hits for gcd, lcm 2019-09-25 16:19:26 +02:00
Helw150
03fb88e49e TODOs and wrong name 2019-09-23 10:04:31 -07:00
Helw150
3d21393d0c PR Response Changes 2019-09-23 08:53:49 -07:00
Helw150
c312729d62 Refactor and Test based on comments from old PR 2019-09-22 21:38:34 -07:00
Helw150
949e1ddf43 Weird Cherry Pick Remnant 2019-09-21 01:24:48 -07:00
Helw150
12de81456b Simplify nanmean with logical not 2019-09-21 01:24:11 -07:00
James Bradbury
fc3a3ac075
Merge pull request #1349 from xysun/issue1233
Support for non-scalar np.repeat arguments (fixes #1233)
2019-09-20 09:27:47 -07:00
xysun
6fd938d8a0 no more nested loops of dynamic_update_slice! 2019-09-20 10:25:16 +01:00
Nikita Kitaev
0e7ea7e379 Reduce memory usage for argmax (fixes #1330) 2019-09-16 14:30:28 -07:00
Peter Hawkins
45a02f39f0 Temporarily remove jit decorator on gather/scatter ops. 2019-09-16 13:57:07 -07:00
xysun
0f03afc5e3 python2.7 compatible 2019-09-14 22:32:45 +01:00
xysun
3c2003b24e jit works! 2019-09-14 21:57:46 +01:00
xysun
c8eb5f657d pretty much works 2019-09-14 21:24:28 +01:00
Peter Hawkins
b7b5328526
Merge pull request #1346 from hawkinsp/master
Add a DeviceArray._unstack() method that unpacks an array along its m…
2019-09-13 11:16:01 -07:00
Peter Hawkins
ae329dcdf4 Add a DeviceArray._unstack() method that unpacks an array along its major dimension.
Use it to implement pxla's shard_arg method for DeviceArrays; this is faster than slicing out each element one by one.
2019-09-13 13:47:09 -04:00
Peter Hawkins
723456bc82 Slice objects are unhashable, so unpack them into tuples for forming static arguments. 2019-09-13 13:39:39 -04:00
Peter Hawkins
5ffddc182e JIT-compile index and index-update expressions.
Improves the performance of indexing in op-by-op mode.
2019-09-13 10:37:41 -04:00
Matthew Johnson
c52027691b jax.numpy.stack and concatenate work on array args
fixes #1271
2019-09-02 07:55:25 -07:00
Matthew Johnson
cac042c34a move asinh/acosh/atanh to lax_numpy.py only 2019-08-31 22:39:51 -07:00
Matthew Johnson
1cd37bd977
reset jax_numpy_rank_promotion to "allow" default 2019-08-27 11:21:50 -07:00
fehiepsi
28df8a666b cast float64 to canonical dtype in np.cov 2019-08-26 01:32:18 -04:00
Matthew Johnson
3c2a73592c improve rank promotion warning, add doc page 2019-08-25 14:28:53 -07:00
Matthew Johnson
afe21bafa4 address reviewer comments 2019-08-24 12:34:44 -07:00
Matthew Johnson
e90457d737 add dtype warnings to array-creation routines
fixes #1230
2019-08-24 08:19:05 -07:00
Matthew Johnson
d700716e19 add option to disable rank-promotion broadcasting
fixes #1236
2019-08-23 18:13:18 -07:00
Matthew Johnson
a8e0c2559c
Merge pull request #1224 from google/no-more-tuples
No more tuples
2019-08-21 14:21:36 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Brian Patton
451ff2d694
Cope with old numpy lacking axis arg 2019-08-21 13:40:50 -04:00
David Majnemer
18beaab200 Use a faster, numerically more faithful, approach to logaddexp
We can use log1p and fewer instances of exp to compute the same result.
2019-08-16 22:36:00 -07:00
Brian Patton
85e5d63423
Fix an exception caused by cached() hashing
I *think* the issue was that one of the elements in shape was a `DeviceArray`.

  File "jax/random.py", line 717, in gamma
    return _gamma(key, a, shape, dtype)
  File "jax/api.py", line 151, in f_jitted
    device_assignment=device_assignment)
  File "jax/core.py", line 672, in call_bind
    ans = primitive.impl(f, *args, **params)
  File "jax/interpreters/xla.py", line 667, in _xla_call_impl
    *map(abstractify, args))
  File "jax/linear_util.py", line 213, in cached_fun
    ans, f_prev = cached_fun_body(f, args)
  File "jax/linear_util.py", line 210, in cached_fun_body
    return call(f, *args), f
  File "jax/interpreters/xla.py", line 679, in _xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "jax/linear_util.py", line 161, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "jax/random.py", line 725, in _gamma
    a = np.broadcast_to(a, shape)
  File "jax/numpy/lax_numpy.py", line 821, in broadcast_to
    lax.broadcast_shapes(shape, _shape(arr))  # error checking
  File "jax/interpreters/xla.py", line 623, in __hash__
    raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.
2019-08-16 16:49:45 -05:00
Brian Patton
48f64ec863
Avoid a TypeError when reps is or contains a ndarray
Something along the lines of `TypeError: multiply only accepts scalar or ndarray, but got a list.`
2019-08-16 12:31:53 -05:00
Peter Hawkins
9fcf526608
Merge pull request #1193 from hawkinsp/axis
Use _canonicalize_axis everywhere to canonicalize axes, rather than s…
2019-08-16 08:33:36 -04:00
Peter Hawkins
f4fde04760 Fix Python 2 compatibility problem. 2019-08-15 21:20:21 -04:00
Peter Hawkins
efe98e2b81 Use _canonicalize_axis everywhere to canonicalize axes, rather than sometimes mod or %.
_canonicalize_axis has behavior more faithful to NumPy, rejecting out of range axes.
2019-08-15 20:56:56 -04:00
Peter Hawkins
1ba13e1b82 Consistently return JAX arrays instead of Numpy-classic arrays from jax.numpy.
Avoids surprising behavior that sometimes arises when mixing the two.
2019-08-15 20:25:32 -04:00