21232 Commits

Author SHA1 Message Date
Sergei Lebedev
70f6ab3128 Updated the type annotations of *_spec= parameters of pl.pallas_call
The previous type did not work for nested pytrees and for some reason neither
pytype nor mypy flagged that.

I also re-enabled type checking for most pallas/*.py files.
2024-06-11 12:22:00 +01:00
Sergei Lebedev
f8473509cf Removed kernel_regeneration_util from Mosaic
It was only used for persisting kernel metadata, and that can be done via
jax.named_scope instead.

PiperOrigin-RevId: 642195336
2024-06-11 02:36:41 -07:00
jax authors
11370b758f Merge pull request #21782 from jakevdp:rel-entr
PiperOrigin-RevId: 642094313
2024-06-10 18:51:22 -07:00
jax authors
02b5d4769d Swap operands of dot if the LHS is fed by a parameter
PiperOrigin-RevId: 642090766
2024-06-10 18:33:05 -07:00
Justin Fu
9439f63645 [Pallas] Add pallas TPU random key impls and lowering rules for basic prng ops (seed/foldin/bits/unwrap/wrap).
PiperOrigin-RevId: 642085019
2024-06-10 18:08:19 -07:00
jax authors
3d4ee0dd7a Merge pull request #21791 from jakevdp:remove-deprecated
PiperOrigin-RevId: 642068297
2024-06-10 16:58:39 -07:00
Jake VanderPlas
266028f4a1 Remove unused variable 2024-06-10 16:30:44 -07:00
Yash Katariya
956226c929 Raise an error if device_put sees an invalid value.
PiperOrigin-RevId: 642053543
2024-06-10 16:07:44 -07:00
jax authors
71c19b779d Rewrite vector.contraction with bf16 accumulator and output into a
contraction with f32 accumulator and output, where the accumulator is
extended and the output truncated. For targets that do not support bf16
matmul, the lhs and rhs are extended to f32.

PiperOrigin-RevId: 642051952
2024-06-10 16:02:46 -07:00
jax authors
9d9dd36219 Adds test_compute_no_inputs_host_replicated in memories_test.py
PiperOrigin-RevId: 642033992
2024-06-10 15:02:34 -07:00
jax authors
bb24a92593 Update XLA dependency to use revision
af7fe24506.

PiperOrigin-RevId: 642026581
2024-06-10 14:38:29 -07:00
jax authors
af004302c1 Merge pull request #21516 from nouiz:paralell_computation
PiperOrigin-RevId: 642004618
2024-06-10 13:29:10 -07:00
jax authors
27de85439e Merge pull request #21781 from hawkinsp:release
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
jax authors
489febee04 Enable input fusion for a specific kernel pattern.
cl/640530524 introduces batching support for some pallas calls that don't currently support it yet using dynamic slicing the input and dynamically updating the output. This CL ensures that XLA-guided input fusion into pallas kernel is working as expected for such pattern. We don't have support for fusion on the output side yet for pallas kernels.

PiperOrigin-RevId: 641989012
2024-06-10 12:37:49 -07:00
jax authors
f4dfa840e3 Merge pull request #21774 from jakevdp:tree-all-is-leaf
PiperOrigin-RevId: 641978173
2024-06-10 12:01:05 -07:00
Jevin Jiang
53daa0c742 [XLA:Mosaic] Fix infer layout for nested loop.
- We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block.
- Refactor the logic of assume layouts for block arguments to a helper function.
- Add tests for nested fori loop and while loop.

PiperOrigin-RevId: 641973011
2024-06-10 11:49:01 -07:00
jax authors
f6ce973860 Merge pull request #21745 from pkgoogle:better_right_shift_doc
PiperOrigin-RevId: 641972495
2024-06-10 11:45:38 -07:00
Vadym Matsishevskyi
a073476fa0 chore: adopt new local wheel installation logic
PiperOrigin-RevId: 641972325
2024-06-10 11:41:52 -07:00
Peter Hawkins
6fa31e59c4 Update version numbers after v0.4.29 release. 2024-06-10 14:37:53 -04:00
Jake VanderPlas
afe088f876 Simplify definition of jax.scipy.special.kl_div 2024-06-10 11:36:35 -07:00
jax authors
3fe7377719 Merge pull request #21763 from gnecula:export_api
PiperOrigin-RevId: 641959833
2024-06-10 11:05:34 -07:00
George Necula
b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Export = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
2024-06-10 19:31:51 +02:00
Piseth Ky
07d90e5195 adding doc string to right_shift
updated punctuation and phrasing

fix example comment/code ordering

reformatting of description and adding print_binary helper

moving helper function to ufuncs.py

moved print_binary definition to doc string

fix in doc print_binary def and other edits
2024-06-10 09:59:42 -07:00
Jake VanderPlas
814b32a44b tree_all: add support for is_leaf 2024-06-10 09:46:15 -07:00
jax authors
833c7ba789 Allow generation of sharding strategies with mixed mesh shapes by default.
PiperOrigin-RevId: 641930205
2024-06-10 09:38:39 -07:00
Adam Paszke
0739d520b1 [Mosaic GPU] Don't always run with llvm::DebugFlag enabled
This slipped past during code review.

PiperOrigin-RevId: 641899993
2024-06-10 07:50:26 -07:00
Thomas Köppe
cd93b46df4 Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.
PiperOrigin-RevId: 641879836
2024-06-10 06:21:16 -07:00
jax authors
991797a8a9 Merge pull request #21765 from hawkinsp:release
PiperOrigin-RevId: 641876244
2024-06-10 06:03:58 -07:00
Sergei Lebedev
5e7ad600e2 Removed the double re-exporting of Pallas GPU/TPU APIs
jax.experimental.pallas.{gpu,tpu} now import directly from the relevant
jax._src.pallas.{triton,mosaic} submodules.

PiperOrigin-RevId: 641875127
2024-06-10 05:59:09 -07:00
Adam Paszke
3b4039c850 [Mosaic GPU] Load LLVM lowering interfaces for all dialects
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only repsponsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.

PiperOrigin-RevId: 641874183
2024-06-10 05:55:01 -07:00
Peter Hawkins
e071053d97 Prepare for 0.4.29 release. 2024-06-10 08:54:28 -04:00
George Necula
2ade7e7526 [pallas] Move the hardware_generation query in the code path that needs it
This change allows us to lower and export Pallas calls even
on machines that do not have TPUs, in many cases.

PiperOrigin-RevId: 641841079
2024-06-10 03:13:36 -07:00
jax authors
af95803d00 Merge pull request #21759 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 641831969
2024-06-10 02:29:12 -07:00
rajasekharporeddy
775c6f8727 Fix Typos in docs and one error message 2024-06-10 11:38:01 +05:30
jax authors
8fbe65b4b2 Update XLA dependency to use revision
32ba408c0e.

PiperOrigin-RevId: 641736314
2024-06-09 15:36:14 -07:00
Junwhan Ahn
6617a0d1ed Expand device_put benchmarks to run with different numbers of arrays and input types
For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays.

PiperOrigin-RevId: 641716268
2024-06-09 13:01:51 -07:00
Peter Hawkins
a8246ea67f Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

In a future release of JAX, this behavior will become an error.

PiperOrigin-RevId: 641690427
2024-06-09 09:18:29 -07:00
George Necula
14d87d3bf7 [export] Move the export implementation to jax._src.export.
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.

Had to add a few "#type ignore" to the _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.

PiperOrigin-RevId: 641688200
2024-06-09 08:59:50 -07:00
jax authors
aaa559a6b3 Update XLA dependency to use revision
7667c63918.

PiperOrigin-RevId: 641568627
2024-06-08 15:45:36 -07:00
jax authors
b486a95186 Merge pull request #21507 from renecotyfanboy:main
PiperOrigin-RevId: 641429523
2024-06-07 20:28:23 -07:00
jax authors
6c822c0124 Update XLA dependency to use revision
3195fdc851.

PiperOrigin-RevId: 641387498
2024-06-07 16:19:00 -07:00
jax authors
d32404020b Avoid "min() arg is an empty sequence" error after enabling "jax_explain_cache_misses".
PiperOrigin-RevId: 641381432
2024-06-07 15:52:35 -07:00
sdupourque
751d59ce67 increase default precision for hyp1f1 2024-06-08 00:38:51 +02:00
Yash Katariya
57826d8c65 Add a no input memories_test and enable memories test on vf 2x2
PiperOrigin-RevId: 641361865
2024-06-07 14:40:44 -07:00
jax authors
0d047a116a Merge pull request #21718 from jakevdp:pallas-config
PiperOrigin-RevId: 641349981
2024-06-07 13:58:49 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
jax authors
25cc84b879 Merge pull request #21615 from selamw1:append_doc
PiperOrigin-RevId: 641344856
2024-06-07 13:39:57 -07:00
jax authors
dfc6076db2 Merge pull request #21744 from superbobry:typing
PiperOrigin-RevId: 641339815
2024-06-07 13:23:31 -07:00
Sergei Lebedev
136289e914 Added filelock to py_deps
This should unblock #21394, which uses filelock in the compilation cache.

PiperOrigin-RevId: 641338150
2024-06-07 13:16:33 -07:00
jax authors
7d913f763a Merge pull request #21298 from oliverdutton:pallas_interpreter_indexing_fix
PiperOrigin-RevId: 641325047
2024-06-07 12:29:31 -07:00