14488 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
953910ab45 Disable timing out sparse_test.py on msan
PiperOrigin-RevId: 503475670
2023-01-20 10:41:20 -08:00
Skye Wanderman-Milne
068423bb96 Increase sharding on checkify_test.py to avoid asan timeouts
PiperOrigin-RevId: 503472266
2023-01-20 10:26:37 -08:00
jax authors
6a69b5a16c Merge pull request #14090 from jakevdp:ci-set-output
PiperOrigin-RevId: 503469220
2023-01-20 10:12:21 -08:00
Jake VanderPlas
25c9621295 CI: update deprecated uses of set-output 2023-01-20 09:51:05 -08:00
jax authors
cd5b26a0b9 Fix typo "invalud" -> "invalid" in error message.
PiperOrigin-RevId: 503452691
2023-01-20 08:48:24 -08:00
Qiao Zhang
650e1ef4c3 Expose fp8 types from jnp namespace.
PiperOrigin-RevId: 503353939
2023-01-19 22:23:33 -08:00
Yash Katariya
6dd4ebc8da Respect jax_disable_jit in pjit
PiperOrigin-RevId: 503297194
2023-01-19 16:36:00 -08:00
jax authors
622522e4a8 Merge pull request #14083 from rmlarsen:patch-3
PiperOrigin-RevId: 503287135
2023-01-19 15:53:30 -08:00
Rasmus Munk Larsen
c798fcaefc
Remove more uses of tan() in reduction tests.
This is to avoid subtly brittle tests. Tan() is an ill-conditioned function to evaluate near it's singularities.
2023-01-19 15:20:02 -08:00
jax authors
b0a7075f66 Merge pull request #14029 from 8bitmp3:move-debug
PiperOrigin-RevId: 503277693
2023-01-19 15:13:33 -08:00
8bitmp3
d6cc2bdb22 Make JAX Debugging and Profiling guides more visible, move Profiling to User Guides from Notes 2023-01-19 23:05:24 +00:00
jax authors
fae7306d88 Merge pull request #14042 from jakevdp:sharp-bits-dynamic-shapes
PiperOrigin-RevId: 503258721
2023-01-19 14:00:44 -08:00
Jake VanderPlas
9e355a6606 Sharp Bits: add section on Dynamic shapes 2023-01-19 11:37:03 -08:00
jax authors
0088d16064 Merge pull request #14079 from jakevdp:doc-reorg
PiperOrigin-RevId: 503207831
2023-01-19 10:49:50 -08:00
jax authors
7085699832 Merge pull request #14038 from jakevdp:sharp-bits-exceptions
PiperOrigin-RevId: 503196077
2023-01-19 10:08:54 -08:00
Jake VanderPlas
e5f0103895 Reorganization among notes, user guides, advanced guides 2023-01-19 09:52:43 -08:00
jax authors
3df9250169 Merge pull request #14077 from jakevdp:sparse-doc
PiperOrigin-RevId: 503189736
2023-01-19 09:44:18 -08:00
jax authors
97d23dc9d7 Merge pull request #13995 from chiamp:compute_fans
PiperOrigin-RevId: 503180128
2023-01-19 09:12:54 -08:00
Jake VanderPlas
736b6f6a52 [sparse] update primitive coverage in module doc 2023-01-19 09:06:38 -08:00
jax authors
6a1164e9ac Merge pull request #14062 from jakevdp:package-doc-titles
PiperOrigin-RevId: 503180070
2023-01-19 09:04:39 -08:00
Marcus Chiam
45c2f31887 Added shape error checking for compute_fans
Update tests/nn_test.py

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-01-18 20:59:11 -08:00
jax authors
f138656ebb Merge pull request #14049 from jakevdp:doc-fix-grid
PiperOrigin-RevId: 503052849
2023-01-18 20:39:18 -08:00
jax authors
1feced0183 Merge pull request #14037 from jakevdp:fix-faq
PiperOrigin-RevId: 503052186
2023-01-18 20:31:43 -08:00
Yash Katariya
5714616dd6 Set no_kwargs to False because pjit supports kwargs
PiperOrigin-RevId: 503019556
2023-01-18 17:14:24 -08:00
jax authors
5f4e95b6c7 Add gcs support to serialization for internal tensor spec.
PiperOrigin-RevId: 503013068
2023-01-18 16:43:03 -08:00
jax authors
30a0df2b37 Merge pull request #14066 from rmlarsen:patch-1
PiperOrigin-RevId: 502980887
2023-01-18 14:33:44 -08:00
jax authors
5b1d67f0bb Merge pull request #14008 from skye:tpu_init
PiperOrigin-RevId: 502980676
2023-01-18 14:33:28 -08:00
Peter Hawkins
1929d34cb6 Fix test failures under NumPy 1.24.
NumPy 1.24 release notes: https://numpy.org/devdocs/release/1.24.0-notes.html

The fixes vary, but there are three particularly common changes:
* NumPy 1.24 removes a number of deprecated NumPy type aliases references (np.bool, np.int, np.float, np.complex, np.object, np.str, np.unicode, np.long). This change replaces them with their recommended replacements (bool, int, float, complex, object, str, str, int).
* Under NumPy 1.24 no longer automatically infers dtype=object when ragged sequences are passed to np.array(). See https://numpy.org/neps/nep-0034-infer-dtype-is-object.html . In most cases the fix is to pass dtype=object explicitly, but in some cases where the raggedness seems accidental other fixes were used.
* NumPy 1.24 is pickier about the dtype= option passed to comparison ufuncs.

PiperOrigin-RevId: 502979152
2023-01-18 14:25:33 -08:00
Qiao Zhang
d58266eac7 Store sorted flattened dict keys in PyTree as a c++ vector instead of py::list to avoid creating new python object on every single dict flatten. For deeply nested dict, this avoids excessive gc pressure and avoids the slowdown whenever gc needs to sweep too many live python objects.
PiperOrigin-RevId: 502967020
2023-01-18 13:40:43 -08:00
Rasmus Munk Larsen
2ee33a0728
Update lax_control_flow_test.py
Fix brittle scan test. Adding tan(randn) is numerically brittle because evaluating tan() near its singularities is ill-conditioned.
2023-01-18 13:11:56 -08:00
Yash Katariya
4add3b8cee Make pjit an AxisPrimitive so that it can run the batching rules even if the argument is not batched but there is a axis_index/named shapes inside the pjitted function.
PiperOrigin-RevId: 502955369
2023-01-18 12:56:07 -08:00
George Necula
30cf057bf3 [host_callback] Add device_index to hcb.call and add tests
The device_index feature works only with outfeed, add an
error message.

PiperOrigin-RevId: 502951721
2023-01-18 12:41:11 -08:00
jax authors
96b67bcb9f Merge pull request #14053 from gnecula:tf_weak_dim_as_value
PiperOrigin-RevId: 502946430
2023-01-18 12:21:29 -08:00
Skye Wanderman-Milne
6d0e22eaf9 Don't run FP8 dtype test on TPU.
This change makes dtypes_test.py pass even when not using Bazel (e.g. with
pytest). It also improves TPU coverage when using Bazel.

PiperOrigin-RevId: 502930531
2023-01-18 11:22:17 -08:00
Jake VanderPlas
703dd462d3 DOC: make grid placement responsive for smaller screens 2023-01-18 11:21:27 -08:00
Jake VanderPlas
81e627d5bd DOC: make API doc titles more uniform 2023-01-18 10:59:42 -08:00
jax authors
e3c2602ed5 Merge pull request #14059 from jakevdp:fix-cond-compilation
PiperOrigin-RevId: 502916647
2023-01-18 10:35:57 -08:00
Jake VanderPlas
6376dc9616 Fix excessive recompiles in lax.cond 2023-01-18 10:17:01 -08:00
George Necula
58035a7b53 [shape_poly] Fix handling of weak_type for conversions of symbolic dimensions to Array
In presence of static shapes `jnp.array(x.shape[0])` has weak_type. We must
preserve that behavior even with symbolic dimensions.
2023-01-18 12:56:48 +02:00
Yash Katariya
a37121e195 Don't depend on flatten_axis_resources which will error because flatten_axes passes a dummy object() which doesn't work with checks in user pytrees.
Only do this if the original {in|out}_shardings are _UNSPECIFIED.

PiperOrigin-RevId: 502792305
2023-01-18 00:13:04 -08:00
jax authors
74df2f9927 Merge pull request #14031 from gnecula:tf_jit_pjit
PiperOrigin-RevId: 502790279
2023-01-18 00:01:34 -08:00
George Necula
97b58bfae7 [jax2tf] Adjust jax2tf for the pjit==jit API migration.
jax2tf treats jit and pjit differently: jit was inlined while
pjit resulted in a recursive call to _interpret_jaxpr. This
resulted in differences of handling of constant sharing.

This PR actually makes the constant sharing more aggressive.
This should be Ok, because we are only sharing non-scalars
which JAX has already lifted to the top-level of the Jaxpr.
2023-01-18 09:04:21 +02:00
jax authors
b0e30eb067 Merge pull request #13911 from 8bitmp3:main
PiperOrigin-RevId: 502743593
2023-01-17 18:50:27 -08:00
Yash Katariya
05e1ddd4ea Make error_test a jax_test so that we can test other configs and fix it with jit/pjit merge.
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
Parker Schuh
b58dd3cbe1 Add support for __signature__ to PjitFunction.
PiperOrigin-RevId: 502731453
2023-01-17 17:28:14 -08:00
jax authors
6d80e5909b Merge pull request #13423 from nvlcambier:lcambier/on_demand_ci_jax_gpu
PiperOrigin-RevId: 502727655
2023-01-17 17:11:18 -08:00
jax authors
e6e8350389 Merge pull request #14047 from jakevdp:doc-installation
PiperOrigin-RevId: 502727042
2023-01-17 17:03:43 -08:00
Yash Katariya
53fceab17c pjit allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test(). This is because the inner backend will take priority.
PiperOrigin-RevId: 502721834
2023-01-17 16:39:45 -08:00
Leopold Cambier
056702c1cb Multinodes CICD on GPUs using on-demand cluster and e2e tests using T5X 2023-01-17 16:29:30 -08:00
Jake VanderPlas
5ebd1e79ff DOC: add GPU & TPU installation tabs on index page 2023-01-17 16:19:57 -08:00