74 Commits

Author SHA1 Message Date
George Necula
b79c7948ee Removed dependency on distutils.strtobool 2020-02-06 17:27:46 +01:00
George Necula
b18a4d8583 Disabled tests known to fail on Mac, and optionally slow tests.
Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
2020-02-05 18:02:56 +01:00
George Necula
4f5987ccd9 Simplify Jaxpr: remove freevars.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.

The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
2020-02-03 18:58:05 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Skye Wanderman-Milne
891aecb941
Add test utilities for counting compilations. (#1895)
Also uses the new utilities to check that pmap doesn't compile constant computations.
2019-12-19 11:19:58 -08:00
Peter Hawkins
d8d3a7bc87
Allow scalar numpy arrays as shapes in np.{zeros,ones,full}. (#1881) 2019-12-17 17:20:51 -05:00
Stephan Hoyer
6ac1c569e8
Use HIGHEST precision for dot_general in linalg JVP rules (#1835) 2019-12-10 00:38:18 -08:00
Peter Hawkins
687b9050df
Prepare to switch default dtypes in JAX to be 32-bit types. (#1827)
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
2019-12-09 21:18:39 -05:00
Peter Hawkins
fb79d56ace
Fixes to type handling. (#1824)
* Fixes to type handling.

* Specify exactly which types to test in lax_test.py, rather than relying on non-x64 mode to squash unsupported types.
* Fix some excessive promotions in jax.numpy.
* Fix some buggy RNGs that returned the wrong type for complex inputs.
2019-12-06 14:49:27 -05:00
Peter Hawkins
ff94b4442a
Remove np._promote_args_like, and replace its users with a newer _pro… (#1802)
* Remove np._promote_args_like, and replace its users with a newer _promote_args_inexact.

We no longer want to promote arguments exactly like NumPy; NumPy has a bad habit of promoting integer types to float64, whereas we want to promote to jax.numpy.float_, which may not be the same.

For example
```
import numpy as onp
onp.sin(3).dtype
```
returns `onp.dtype(float64)`.

However, it turns out that all of the users of `_promote_args_like` are using it for exactly one behavior: promoting integers or bools to inexact types like float. Implement that behavior explicitly rather than mimicing the behavior of NumPy.

* Relax test tolerances.
2019-12-03 10:05:51 -05:00
Peter Hawkins
f3c8af49e7
Fix bugs in handling of convolutions whose LHS has spatial size 0. (#1794)
* Fix bugs in handling of convolutions whose LHS has spatial size 0.

* Use onp.shape to compute shapes.
2019-12-02 14:43:43 -05:00
George Necula
2bb74b627e Ensure jaxpr.eqn.params are printed sorted, so we get deterministic output 2019-11-26 14:05:08 +01:00
George Necula
603258ebb8 Fixed a couple of tests 2019-11-26 13:56:58 +01:00
George Necula
5c15dda2c9 Changed api.make_jaxpr to return a TypedJaxpr
* A TypedJaxpr contains more useful information (consts, types)
* Also forced the instantiation of constants when producing the jaxpr.
  Before:
  >>>print(api.make_jaxpr(lambda x: 1.)(0.))
     lambda ; ; a.
     let
     in [*]}
  After this change:
  >>>print(api.make_jaxpr(lambda x: 1.)(0.))
     lambda ; ; a.
     let
     in [1.0]}
2019-11-26 09:17:03 +01:00
Peter Hawkins
ee36818a58
Add bfloat16 support to JAX. (#1720)
bfloat16 support is still immature, but this PR adds some initial support.

Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.

The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:

implement our own type promotion operators that understand bfloat16 types
wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
2019-11-20 22:43:46 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… (#1701)
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
Peter Hawkins
cc0568ef49
Remove test_util.check_raises_regexp. (#1692)
It does nothing that the builtin self.assertRaisesRegexp doesn't already do.
2019-11-14 16:00:55 -05:00
Stephan Hoyer
a9a6cf8a2e
Faster test collection, second try (#1653)
* Faster test collection, second try

Follows @hawkinsp's suggestion from #1632 to rewrite everything in terms of
RNG factories, creating actual RNG functions *inside* each test method instead
of when they are collected.

* use np.testing.assert_allclose
2019-11-11 12:51:15 -08:00
Peter Hawkins
0d667d2727
Add tests for float16 support in lax_test.py. (#1553)
* Add tests for float16 support in lax_test.py.

Make test tolerances per-type, rather than a single tolerance based on the x64 mode.
Don't test float16 on TPU because it doesn't support float16.
Rework a number of the gradient tests. For linear primitives, increase eps and use a per-type tol.

* Perform float16 sinh and cosh in float32 precision.
More tweaks to test tolerances to get tests to pass.

* Add float16 testing to lax_numpy_test.py as well.

* Fix tolerance computation for testReducer test.
Relax tolerance for polyval.

* Relax some test tolerances further.

* Further relax test tolerances.

* Another tolerance relaxation.

* Use decorator for the upcast to fp32 for computation pattern.

Relax test tolerance for float_power.
2019-10-22 19:53:59 -04:00
Skye Wanderman-Milne
d6a9202c38 test_util._cast_to_shape fix 2019-09-27 11:18:36 -07:00
Matthew Johnson
762b602f33
Merge pull request #1394 from j-towns/fix-scatter-caching
Ensure all ops get cache hits on second op-by-op mode call
2019-09-26 06:48:42 -07:00
Peter Hawkins
6978591a2a
Merge pull request #1386 from Helw150/will/nanMeanAndNanTests
Nan Mean and Adds Nan Reducers to testing
2019-09-25 16:55:23 -04:00
Jamie Townsend
7c979e7c28 Always test for cache misses on second op-by-op call 2019-09-25 16:16:00 +02:00
Helw150
03fb88e49e TODOs and wrong name 2019-09-23 10:04:31 -07:00
Helw150
3d21393d0c PR Response Changes 2019-09-23 08:53:49 -07:00
Helw150
c312729d62 Refactor and Test based on comments from old PR 2019-09-22 21:38:34 -07:00
Roman Novak
58ed2e31d8
Make assertAllClose work with non-array types.
In our application this would make comparing trees with other entries like Nones and enums easier. I may be missing some other issues though, let me know if this change makes sense!
2019-09-22 15:32:12 -07:00
Peter Hawkins
0fae8fec80 Add a jtu.device_under_test() method. Use it instead of reparsing the device flag everywhere. 2019-08-04 17:17:49 -04:00
Peter Hawkins
08ffe4f3b0 Remove stale comment. 2019-07-03 12:05:07 -04:00
Matthew Johnson
1ecf75770b make check_grads check all orders <= given order 2019-05-21 17:22:33 -07:00
Matthew Johnson
8c00df1d61 finer-grained passing and failing tests 2019-05-21 15:14:28 -07:00
Matthew Johnson
42a1ad4307 change dtype promotion behavior for jit-invariance
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`

The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)

Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.

In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
2019-05-19 18:49:16 -07:00
Matthew Johnson
07bf50967b make jtu.skip_on_devices read jax_platform_name
fixes #696
2019-05-10 12:27:15 -07:00
Peter Hawkins
6d77fb7d20 Fix type mismatch for nan_to_num for 64-bit types. Fixes #683.
Add tests for isinf/isnan/isposinf/isneginf/nan_to_num now that nan/inf are honored on the CPU backend.

Add complex number support to more of the RNG test utils. Add test RNG that emits both nans and infs.
2019-05-07 15:07:43 -04:00
Matthew Johnson
642d2dc802 revies optimizers api, fix misc bugs
* add more optimizers numerical tests
* update examples and readme with new optimziers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
2019-05-03 12:44:52 -07:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Peter Hawkins
ca0d943999 Test case improvements:
* use numpy.random to select test cases, rather than random. This allows more control over random seeds. Pick a fixed random seed for each test case.
* sort types in linalg_test.py so the choice of test cases is deterministic.
* use known_flags=True when doing early parsing of flags from parse_flags_with_absl.
2019-04-12 10:48:11 -04:00
Matthew Johnson
a2778b245c reconcile _CheckAgainstNumpy arg order common use 2019-03-22 17:09:35 -07:00
Skye Wanderman-Milne
65cac39647 Fix convert_element_type bug when converting from a complex dtype. 2019-03-18 14:27:34 -07:00
Peter Hawkins
3a456a3e73 Make assertAllClose check shapes for exact equality.
Currently assertAllClose delegates to np.is_allclose, which has broadcasting semantics.

Fix some newly failing test cases.
2019-03-14 21:59:31 -04:00
Peter Hawkins
382df2e9a4 Actually use the user-provided epsilon value in tests. 2019-02-15 14:38:55 -05:00
Peter Hawkins
89727e4377 Fix complex64/complex128 confusion in test harness. 2019-02-01 14:01:06 -05:00
Matthew Johnson
5f5baaa4cd tweak tests for internal purposes 2019-01-28 12:55:24 -08:00
Peter Hawkins
a21c3c4562 Fix definition of rand_uniform(). 2019-01-14 16:48:07 -05:00
Peter Hawkins
e00dc5d39d Restrict the range of np.tan() test to [-1.5, 1.5) to avoid numerical problems. 2019-01-14 15:28:53 -05:00
Peter Hawkins
0fa5af9dbb Implement the mode='constant' case of np.pad. 2019-01-09 21:26:22 -05:00
Matthew Johnson
ca27f0a2b2 add jacobian / hessian pytree tests (fixes #173) 2019-01-07 08:54:14 -08:00