170 Commits

Author SHA1 Message Date
Jake VanderPlas
05faf0f40d Remove deprecated functionality from jax.test_util
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
840c96692e Internal change
PiperOrigin-RevId: 468509799
2022-08-18 11:39:07 -07:00
Peter Hawkins
1e241dcf16 Catch ModuleNotFoundError instead of ImportError.
We frequently use the pattern
try:
  import m
except ImportError:
  # do something else.

This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
2022-08-18 15:22:49 +00:00
Jake VanderPlas
887abbc3b9 jax.test_util: remove deprecated test classes.
JaxTestCase and JaxTestLoader were deprecated in jax v0.3.1, released Feb 2022.
2022-06-27 11:04:50 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Jake VanderPlas
d9508304e4 Deprecate remaining functionality in jax.test_util 2022-04-21 12:12:40 -07:00
Jake VanderPlas
1246b6fc73 Separate jax.test_util implementations into public and private sources.
Eventually the private functionality will no longer be exported via the jax.test_util submodule.

PiperOrigin-RevId: 439415485
2022-04-04 14:43:39 -07:00
Jake VanderPlas
da3aaa1960 Add deprecation warning to JaxTestCase and JaxTestLoader 2022-02-17 14:58:58 -08:00
Jake VanderPlas
1de3999ea6 test_util: export with_config 2022-01-26 14:32:11 -08:00
Peter Hawkins
c491203bdd Readd jax.test_util.check_jvp and check_vjp to the public JAX API. 2021-10-14 11:55:11 -04:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
be28827164 Add atol and rtol arguments to jtu._CheckAgainstNumpy().
Prefer atol and rtol if they are provided.
2021-09-22 17:46:08 -04:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Matthew Johnson
9f919c69e4 fix custom_vjp issue? 2021-08-27 21:22:25 -07:00
Jake VanderPlas
6114e6a0d3 test_util: add decorator to set config values in test cases 2021-08-05 14:06:37 -07:00
Jake VanderPlas
63a788b4de Cleanup: switch to new version of super() 2021-08-05 13:11:07 -07:00
jax authors
060345c16a Merge pull request #7476 from jakevdp:remove-method
PiperOrigin-RevId: 388771537
2021-08-04 13:31:58 -07:00
Jake VanderPlas
2408343221 Cleanup: remove unused test utility 2021-08-03 13:45:37 -07:00
Reza Rahimi
b44d35664c change skip_on_devices to handle device tags 2021-07-30 19:17:21 +00:00
Adam Paszke
e987f6f9fc Make maps.EXPERIMENTAL_SPMD_LOWERING into a jax.config flag
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.

Second attempt, this time without hardening against the flags being
registered too late due to delayed imports.
2021-07-15 14:18:52 +00:00
jax authors
25e44821dd Make maps.EXPERIMENTAL_SPMD_LOWERING into a jax.config flag
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.

PiperOrigin-RevId: 384902895
2021-07-15 05:07:09 -07:00
Adam Paszke
8bc6e7f1d5 Make maps.EXPERIMENTAL_SPMD_LOWERING into a jax.config flag
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.

PiperOrigin-RevId: 384892199
2021-07-15 03:37:30 -07:00
Amol Mandhane
4bae17dd70 Add a helper decorator to disable implicit rank promotion in unit tests. 2021-07-05 13:40:41 +01:00
Peter Hawkins
75c9bf01f3 Fix most test failures under NumPy 1.21. 2021-06-22 16:31:44 -04:00
George Necula
973171bb6d [jax2tf] Add support for pjit. 2021-06-01 14:32:59 +03:00
Peter Hawkins
1350d21881 Add regression test for #5728.
This issue appears to have been fixed by jaxlib 0.1.66.
2021-05-12 13:45:16 -04:00
George Necula
ba5e11f86f [jax2tf] Improve the conversion of integer_pow for better numerical accuracy.
Previously we simply converted integer_pow to tf.math.pow. JAX instead uses
a series of multiplications. We now use the same lowering strategy as JAX, so
that we have the same numerical result.

Also improved the error messages for assertion failures.

PiperOrigin-RevId: 373351147
2021-05-12 05:45:39 -07:00
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
Peter Hawkins
d005e38f78 Promote the x.at[idx].set(y) operators as the preferred way to do indexed updates.
Mark the index_update() etc. operators as deprecated in the documentation.

Add new .divide and .power operators. Fixes #2694.
Add .multiply as an alias for .mul. To be more numpy-like we should probably prefer the longer names.
2021-05-10 20:32:00 -04:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Jake VanderPlas
7773d50486 Fix nanquantile for negative NaNs & adjust test harness to cover this 2021-04-15 09:42:24 -07:00
jax authors
6cc4bb0476 Merge pull request #6420 from apaszke:faster-tests
PiperOrigin-RevId: 368519866
2021-04-14 15:24:06 -07:00
Adam Paszke
e0357283a6 Speed up test generation 7x
There are a few test cases that generate millions of configurations,
only to have a handful of them selected by `cases_form_list`. I've
found all tests that spend over 100ms in case generation and
converted them to a new "test sampler" approach. The result: test
generation time drops from 15s to around 2s. Doesn't sound like much,
but I expect that we all run tests many times daily, so it seems like a
useful thing to have.

The rough idea is that the sampling generators get parameterized by a
sampler function that should be applied to the range of every `for` loop.
This allows us to sample runs of the generator through different
configurations by restricting each loop to a smaller subset. Right now
we always narrow it down to a single randomly selected instance. But,
we still retain the possibility of adding exhaustive testing in the
future, which can be achieved by passing in an identity sampling
function that wouldn't modify any loop ranges.
2021-04-14 15:58:05 +00:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
jax authors
cf9b77f1de Merge pull request #5998 from zhangqiaorjc:dev_put_count
PiperOrigin-RevId: 362143966
2021-03-10 14:36:55 -08:00
Qiao Zhang
9577860169 Add jtu.count_device_put for tests to count device_put. 2021-03-09 14:45:01 -08:00
Jake VanderPlas
dbdb189de1 jnp.piecewise: support scalar inputs 2021-03-09 13:25:38 -08:00
jax authors
144a633e42 Merge pull request #5405 from shoyer:check-grads-err-msg
PiperOrigin-RevId: 356818579
2021-02-10 13:51:23 -08:00
Stephan Hoyer
db6405c746 Better error messages for test_util.check_grads()
Rather than merely reporting a failure in check_grads(), we now
report the *specific* check that failed, e.g., "JVP tangent" or
"VJP of JVP cotangent projection". Gradient tests often fail for
spurious reasons (e.g., due to insufficient precision), so this should
be helpful for debugging.

I tested this manually by relaxing the tolerance for a test in
`linalg_test.py`.
2021-02-10 12:19:33 -08:00
Adam Paszke
0b7febea39 Add argument donation for xmap
Also, pass the body to XLA JIT when no parallel resources are used.
There is no reason to not do that given that we already require users to
pay the price of making their code jittable.
2021-02-08 12:46:45 +00:00
Jake VanderPlas
5140426c7c Support non-scalar fill values in jnp.full() & jnp.full_like() 2021-02-05 10:07:41 -08:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
George Necula
20be478a6e [host_callback] Add support for pmap and for passing the device to tap
* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
  `id_tap`, to have the tap function invoked with a device keyword argument.
* Added multiple tests involving pmap

Issue: #5134
Fixes: #5169
2020-12-15 10:46:23 +02:00
Adam Paszke
ca8028950e Fix pmap compilation cache regressions from #4904.
AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
2020-12-02 14:40:45 +00:00
jax authors
9a8ee95c08 Merge pull request #4419 from rsepassi:master
PiperOrigin-RevId: 334846970
2020-10-01 10:39:51 -07:00