Anish Tondwalkar
ac44d2c2e3
Migrate besseli0e off xla_fallback
...
PiperOrigin-RevId: 519241252
2023-03-24 14:39:40 -07:00
Yash Katariya
257ac6a993
If each host has the full value of the Array, allow fetching it to host. Fixes #15162
...
Benchmarks:
```
name old cpu/op new cpu/op delta
np_asarray_8_devices 3.71ms ± 6% 3.32ms ± 7% -10.48% (p=0.008 n=5+5)
name old time/op new time/op delta
np_asarray_8_devices 3.86ms ± 6% 3.49ms ± 7% -9.72% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 519222320
2023-03-24 13:21:57 -07:00
Peter Hawkins
6ed66ada0f
Delete remote TPU support.
...
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.
PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
John QiangZhang
fad4e6f95a
[1/n] store embedded tf.graph to stablehlo.custom_call
...
PiperOrigin-RevId: 519194911
2023-03-24 11:27:24 -07:00
Parker Schuh
21541e60b1
Guard ArrayImpl checks by xla_extension_version.
...
PiperOrigin-RevId: 519191714
2023-03-24 11:15:36 -07:00
Yash Katariya
bc231ee0bf
After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host.
...
PiperOrigin-RevId: 519170785
2023-03-24 10:00:41 -07:00
Anish Tondwalkar
8c75e27f67
Migrate random_gamma_grad off xla_fallback
...
PiperOrigin-RevId: 519154537
2023-03-24 08:49:40 -07:00
Anish Tondwalkar
8d1d522618
Migrate igamma_grad_a_p off xla_fallback
...
PiperOrigin-RevId: 519148548
2023-03-24 08:21:22 -07:00
Anish Tondwalkar
4a9b09485e
Migrate igammac_p off xla_fallback path
...
It is now decomposed into stablehlo ops.
PiperOrigin-RevId: 519122775
2023-03-24 05:58:38 -07:00
jax authors
32e712864c
Merge pull request #15192 from mattjj:issue15190
...
PiperOrigin-RevId: 519037959
2023-03-23 20:48:33 -07:00
jax authors
1982a113d6
Merge pull request #15187 from mattjj:djax-revival
...
PiperOrigin-RevId: 519036576
2023-03-23 20:38:01 -07:00
Matthew Johnson
7743fcd758
[dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit
2023-03-23 20:20:01 -07:00
Matthew Johnson
793387e496
fix jax.Array.round()
...
fixes #15190
2023-03-23 20:16:23 -07:00
jax authors
e9bc7ee866
Merge pull request #15184 from jakevdp:move-median
...
PiperOrigin-RevId: 519003606
2023-03-23 17:14:50 -07:00
Jake VanderPlas
6f8885a0c2
lax_numpy: move quantile-based functions to reductions.py
2023-03-23 16:39:20 -07:00
Anish Tondwalkar
f981243af5
Migrate igamma_p off xla_fallback
...
We decompose it into a series or a call to igammac.
PiperOrigin-RevId: 518993077
2023-03-23 16:26:59 -07:00
George Necula
d777cf229e
[jax2tf] A simple failing test on TPU with native serialization
...
PiperOrigin-RevId: 518987577
2023-03-23 16:04:53 -07:00
John QiangZhang
171b22dbbc
Add padding option "SAME_LOWER" for ticket https://github.com/google/jax/pull/14990
...
PiperOrigin-RevId: 518984018
2023-03-23 15:50:16 -07:00
George Necula
7df52459d0
[jax2tf] Create a jax_export library with JAX-only pieces for native serialization
...
This is a pure refactor, no functionality should change.
PiperOrigin-RevId: 518982222
2023-03-23 15:42:45 -07:00
jax authors
6407fa63ae
Merge pull request #15178 from mattjj:improve-scan-errors
...
PiperOrigin-RevId: 518977054
2023-03-23 15:28:16 -07:00
Anish Tondwalkar
adbdaa47a3
Refactor special functions into their own module.
...
We're going to want to decompose these using series and
continued fraction representations, and for that we'll need
control flow
PiperOrigin-RevId: 518977008
2023-03-23 15:21:15 -07:00
jax authors
1f7c305cf6
Merge pull request #15172 from jakevdp:jax-array-refactor
...
PiperOrigin-RevId: 518971494
2023-03-23 15:00:15 -07:00
Matthew Johnson
ba2ff519ca
improve scan error messages
2023-03-23 14:53:05 -07:00
George Necula
1136d0f6c7
[jax2tf] Minor addition to the documentation
...
PiperOrigin-RevId: 518969936
2023-03-23 14:52:01 -07:00
Cristian Garcia
dfc24f2981
add trailing-whitespace pre-commit hook
2023-03-23 21:17:54 +00:00
Jake VanderPlas
1286446b52
Fix mypy issue in jax/experimental/jet.py
2023-03-23 13:56:11 -07:00
Yash Katariya
a9e48af260
Deprecated xla_call_p since it has been replaced with pjit.pjit_p
...
PiperOrigin-RevId: 518921538
2023-03-23 11:44:42 -07:00
Jake VanderPlas
87aec2433b
internal: refactor array methods into separate private submodule
2023-03-23 10:57:53 -07:00
Peter Hawkins
f461c4ef0c
Move jax._src.typing into a separate Bazel target.
...
PiperOrigin-RevId: 518899136
2023-03-23 10:30:08 -07:00
jax authors
383cf41848
Merge pull request #14937 from b0nce:fix-stats
...
PiperOrigin-RevId: 518888600
2023-03-23 10:01:54 -07:00
Peter Hawkins
befce6d2c8
[XLA:Python] Allow passing ExecutableBuildOptions to outfeed receiver.
...
Outfeed receiver compiles computations (during shutdown), and if the correct options aren't provided, then it may not be able to do things like find ptxas for CUDA builds. Plumb the executable build options through from Python.
PiperOrigin-RevId: 518852909
2023-03-23 07:31:06 -07:00
jax authors
54e8101f00
Merge pull request #15123 from jakevdp:fix-mean-large-dims
...
PiperOrigin-RevId: 518852476
2023-03-23 07:23:20 -07:00
Peter Hawkins
8bb90b5fbe
[XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes.
...
PiperOrigin-RevId: 518830467
2023-03-23 05:13:39 -07:00
George Necula
bd1f53ed6d
[jax2tf] Fix tests broken by upgrade of XlaCallModule
...
PiperOrigin-RevId: 518811580
2023-03-23 03:16:41 -07:00
Etienne Pot
4cb32ba46f
Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses
...
PiperOrigin-RevId: 518803946
2023-03-23 02:32:06 -07:00
jax authors
6d1c849a53
fix typo: "one of more" -> "one or more"
...
PiperOrigin-RevId: 518762341
2023-03-22 22:06:50 -07:00
jax authors
e39578cd73
Merge pull request #15154 from mattjj:pjit-typecheck
...
PiperOrigin-RevId: 518717095
2023-03-22 17:31:59 -07:00
Parker Schuh
484eb26d2a
Redefine compile_and_serialize
as serialize(lowered.compile())
.
...
This has the downside of keeping around the UnloadedMeshComputation,
but it makes the serialize() API easier to understand.
PiperOrigin-RevId: 518715469
2023-03-22 17:23:19 -07:00
Matthew Johnson
268456ef54
enable pjit recursive typechecking
...
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).
This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!
It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
2023-03-22 16:59:22 -07:00
jax authors
5f724cf9a2
Merge pull request #15151 from mattjj:qiao-remat-print-res-info
...
PiperOrigin-RevId: 518709114
2023-03-22 16:53:22 -07:00
Matthew Johnson
6b4262d9f6
add experimental jax_log_checkpoint_residuals option
...
The main idea here is to improve tooling for knowing what residuals are being
saved and why. There's a lot more that can be done here (e.g. naming the
arguments, explaining what JVP rule produced these residuals, explaining what
consumed them, etc) but this is a start.
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2023-03-22 16:26:56 -07:00
Jake VanderPlas
91040120ec
DOC: add formulae for distributions in jax.random
2023-03-22 12:36:10 -07:00
jax authors
dd2ecf4bb5
Merge pull request #15085 from mattjj:arg-info-in-mlir-5
...
PiperOrigin-RevId: 518642948
2023-03-22 12:31:35 -07:00
Kevin Gleason
022b47fd91
Improve handling of dynamic shapes in jax native serialization
...
PiperOrigin-RevId: 518634912
2023-03-22 12:06:54 -07:00
Jake VanderPlas
398532b32c
jax.random: remove scale from wald function
2023-03-22 10:44:50 -07:00
Peter Hawkins
64e1f5fe3d
Revert: custom_vjp
symbolic zeros support
...
PiperOrigin-RevId: 518597609
2023-03-22 09:56:09 -07:00
jax authors
00e6c73b68
Merge pull request #15114 from JiaYaobo:add_wald_random
...
PiperOrigin-RevId: 518592428
2023-03-22 09:37:44 -07:00
Peter Hawkins
4b0eee1632
Fix mypy failures in jax2tf.
...
PiperOrigin-RevId: 518572905
2023-03-22 08:18:40 -07:00
George Necula
bab83c3a10
[jax2tf] Fix grad of pjit in native lowering.
...
Since jax2tf.convert is called recursively for the purpose of
serializing the vjp function, we must ensure that if the primal
function is a pjit with shardings then the vjp function must also
be converted as a pjit.
Without this fix the serialization with gradients of a pjit function
will fail the an error that there are shardings but not pjit at
the top-level.
2023-03-22 10:29:30 +01:00
jiayaobo
f7a14d65d2
add wald random generator
...
add wald to random.py
2023-03-22 11:06:59 +08:00