21201 Commits

Author SHA1 Message Date
Adam Paszke
3b4039c850 [Mosaic GPU] Load LLVM lowering interfaces for all dialects
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only repsponsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.

PiperOrigin-RevId: 641874183
2024-06-10 05:55:01 -07:00
George Necula
2ade7e7526 [pallas] Move the hardware_generation query in the code path that needs it
This change allows us to lower and export Pallas calls even
on machines that do not have TPUs, in many cases.

PiperOrigin-RevId: 641841079
2024-06-10 03:13:36 -07:00
jax authors
af95803d00 Merge pull request #21759 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 641831969
2024-06-10 02:29:12 -07:00
rajasekharporeddy
775c6f8727 Fix Typos in docs and one error message 2024-06-10 11:38:01 +05:30
jax authors
8fbe65b4b2 Update XLA dependency to use revision
32ba408c0e.

PiperOrigin-RevId: 641736314
2024-06-09 15:36:14 -07:00
Junwhan Ahn
6617a0d1ed Expand device_put benchmarks to run with different numbers of arrays and input types
For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays.

PiperOrigin-RevId: 641716268
2024-06-09 13:01:51 -07:00
Peter Hawkins
a8246ea67f Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

In a future release of JAX, this behavior will become an error.

PiperOrigin-RevId: 641690427
2024-06-09 09:18:29 -07:00
George Necula
14d87d3bf7 [export] Move the export implementation to jax._src.export.
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.

Had to add a few "#type ignore" to the _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.

PiperOrigin-RevId: 641688200
2024-06-09 08:59:50 -07:00
jax authors
aaa559a6b3 Update XLA dependency to use revision
7667c63918.

PiperOrigin-RevId: 641568627
2024-06-08 15:45:36 -07:00
jax authors
b486a95186 Merge pull request #21507 from renecotyfanboy:main
PiperOrigin-RevId: 641429523
2024-06-07 20:28:23 -07:00
jax authors
6c822c0124 Update XLA dependency to use revision
3195fdc851.

PiperOrigin-RevId: 641387498
2024-06-07 16:19:00 -07:00
jax authors
d32404020b Avoid "min() arg is an empty sequence" error after enabling "jax_explain_cache_misses".
PiperOrigin-RevId: 641381432
2024-06-07 15:52:35 -07:00
sdupourque
751d59ce67 increase default precision for hyp1f1 2024-06-08 00:38:51 +02:00
Yash Katariya
57826d8c65 Add a no input memories_test and enable memories test on vf 2x2
PiperOrigin-RevId: 641361865
2024-06-07 14:40:44 -07:00
jax authors
0d047a116a Merge pull request #21718 from jakevdp:pallas-config
PiperOrigin-RevId: 641349981
2024-06-07 13:58:49 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
jax authors
25cc84b879 Merge pull request #21615 from selamw1:append_doc
PiperOrigin-RevId: 641344856
2024-06-07 13:39:57 -07:00
jax authors
dfc6076db2 Merge pull request #21744 from superbobry:typing
PiperOrigin-RevId: 641339815
2024-06-07 13:23:31 -07:00
Sergei Lebedev
136289e914 Added filelock to py_deps
This should unblock #21394, which uses filelock in the compilation cache.

PiperOrigin-RevId: 641338150
2024-06-07 13:16:33 -07:00
jax authors
7d913f763a Merge pull request #21298 from oliverdutton:pallas_interpreter_indexing_fix
PiperOrigin-RevId: 641325047
2024-06-07 12:29:31 -07:00
Sergei Lebedev
0786da8fd8 Removed unnecessary mypy exclusions from pyproject.toml
* 2/3 files type check just fine now
* the remaining one could be handled via a file-level directive
2024-06-07 20:07:42 +01:00
jax authors
f4c6437837 Merge pull request #21680 from ROCm:ci_spmm
PiperOrigin-RevId: 641316410
2024-06-07 11:57:12 -07:00
jax authors
af90464b53 Merge pull request #21733 from dfm:ffi-capsule-docstring
PiperOrigin-RevId: 641307843
2024-06-07 11:27:41 -07:00
jax authors
bd499a921e Merge pull request #21690 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 641292860
2024-06-07 10:38:07 -07:00
jax authors
98d7235aee Merge pull request #21501 from jakevdp:softmax-inf-doc
PiperOrigin-RevId: 641291919
2024-06-07 10:34:40 -07:00
jax authors
1459ac04a8 Merge pull request #21731 from tttc3:cross-product-typo
PiperOrigin-RevId: 641285460
2024-06-07 10:18:35 -07:00
jax authors
2899c9fada Merge pull request #21692 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 641285369
2024-06-07 10:15:22 -07:00
jax authors
30feb352b4 Merge pull request #21656 from yamlyeti:yamlyeti-patch-1
PiperOrigin-RevId: 641284969
2024-06-07 10:12:02 -07:00
Dan Foreman-Mackey
1fa66590d1 Edit pycapsule docstring to provide a little bit more context
The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
2024-06-07 13:07:03 -04:00
Paweł Paruzel
5fcd50b7fa Refactor kernel function assigment
PiperOrigin-RevId: 641255192
2024-06-07 08:20:31 -07:00
jax authors
f51af87fc5 fp8 matmul in pallas
PiperOrigin-RevId: 641254832
2024-06-07 08:17:06 -07:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
tttc3
21f71c6b66 fix typo in jax.numpy.linalg.cross docstring 2024-06-07 13:43:51 +01:00
Sergei Lebedev
5d6413cecc Added debug_callback to the list of exclusions in jax2tf/tests/primitives_test.py
PiperOrigin-RevId: 641149152
2024-06-07 00:01:30 -07:00
jax authors
c01c98400d Add missing arguments for jnp.extract's python binding signature.
PiperOrigin-RevId: 641121305
2024-06-06 21:34:38 -07:00
rajasekharporeddy
6d94ae3274 Improve docs for jnp.angle and jnp.flip 2024-06-07 10:03:07 +05:30
rajasekharporeddy
6d85c3890d Improve documentation for jnp.fliplr and jnp.flipud 2024-06-07 09:58:02 +05:30
jax authors
625ea07a7e Merge pull request #21710 from jakevdp:fix-jax2tf
PiperOrigin-RevId: 641112498
2024-06-06 20:45:57 -07:00
Roy Frostig
ea6dfd1947 rename Specialized to Traced (and specialize to trace)
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
jax authors
dd40d8852d Update XLA dependency to use revision
9449b0851c.

PiperOrigin-RevId: 641069331
2024-06-06 17:12:57 -07:00
Jake VanderPlas
a2c31f4d15 pallas/mosaic test: avoid leaking global config state 2024-06-06 16:00:02 -07:00
jax authors
a1b5860427 Merge pull request #21711 from jakevdp:setup-module
PiperOrigin-RevId: 641049524
2024-06-06 15:59:07 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
jax authors
d457f9a116 Merge pull request #21716 from gnecula:exp_rename_sharding
PiperOrigin-RevId: 641017765
2024-06-06 14:17:10 -07:00
George Necula
01ee768f73 [export] Rename in_shardings and out_shardings fields.
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Yash Katariya
aee62e4874 Implement lower in terms of specialize
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
jax authors
90c83bb1e2 Merge pull request #21484 from dfm:custom-call-lowering
PiperOrigin-RevId: 640996459
2024-06-06 13:10:28 -07:00
Mark Sandler
2c246df439 Reverts dfe61285093ff826e1ad23bb36b77a42c01040b4
PiperOrigin-RevId: 640987745
2024-06-06 12:41:17 -07:00
Yash Katariya
fbf2a62aa1 Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Tomás Longeri
a65d3ae0da [Mosaic] Expand vector.shape_cast support for sublane (un)folding no-ops
- Support non-zero minor offsets without having to relayout (they're still a no-op).
- Remove restriction on tiling which now allows 1D packed types to work.

PiperOrigin-RevId: 640967375
2024-06-06 11:35:19 -07:00