21394 Commits

Author SHA1 Message Date
jax authors
b486a95186 Merge pull request #21507 from renecotyfanboy:main
PiperOrigin-RevId: 641429523
2024-06-07 20:28:23 -07:00
jax authors
6c822c0124 Update XLA dependency to use revision
3195fdc851.

PiperOrigin-RevId: 641387498
2024-06-07 16:19:00 -07:00
jax authors
d32404020b Avoid "min() arg is an empty sequence" error after enabling "jax_explain_cache_misses".
PiperOrigin-RevId: 641381432
2024-06-07 15:52:35 -07:00
sdupourque
751d59ce67 increase default precision for hyp1f1 2024-06-08 00:38:51 +02:00
rajasekharporeddy
7989c70572 Add example code snippets to jax.scipy.linalg.expm and jax.scipy.linalg.polar docs 2024-06-08 03:30:12 +05:30
Yash Katariya
57826d8c65 Add a no input memories_test and enable memories test on vf 2x2
PiperOrigin-RevId: 641361865
2024-06-07 14:40:44 -07:00
jax authors
0d047a116a Merge pull request #21718 from jakevdp:pallas-config
PiperOrigin-RevId: 641349981
2024-06-07 13:58:49 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
jax authors
25cc84b879 Merge pull request #21615 from selamw1:append_doc
PiperOrigin-RevId: 641344856
2024-06-07 13:39:57 -07:00
jax authors
dfc6076db2 Merge pull request #21744 from superbobry:typing
PiperOrigin-RevId: 641339815
2024-06-07 13:23:31 -07:00
Sergei Lebedev
136289e914 Added filelock to py_deps
This should unblock #21394, which uses filelock in the compilation cache.

PiperOrigin-RevId: 641338150
2024-06-07 13:16:33 -07:00
jax authors
7d913f763a Merge pull request #21298 from oliverdutton:pallas_interpreter_indexing_fix
PiperOrigin-RevId: 641325047
2024-06-07 12:29:31 -07:00
Sergei Lebedev
0786da8fd8 Removed unnecessary mypy exclusions from pyproject.toml
* 2/3 files type check just fine now
* the remaining one could be handled via a file-level directive
2024-06-07 20:07:42 +01:00
jax authors
f4c6437837 Merge pull request #21680 from ROCm:ci_spmm
PiperOrigin-RevId: 641316410
2024-06-07 11:57:12 -07:00
jax authors
af90464b53 Merge pull request #21733 from dfm:ffi-capsule-docstring
PiperOrigin-RevId: 641307843
2024-06-07 11:27:41 -07:00
jax authors
bd499a921e Merge pull request #21690 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 641292860
2024-06-07 10:38:07 -07:00
jax authors
98d7235aee Merge pull request #21501 from jakevdp:softmax-inf-doc
PiperOrigin-RevId: 641291919
2024-06-07 10:34:40 -07:00
jax authors
1459ac04a8 Merge pull request #21731 from tttc3:cross-product-typo
PiperOrigin-RevId: 641285460
2024-06-07 10:18:35 -07:00
jax authors
2899c9fada Merge pull request #21692 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 641285369
2024-06-07 10:15:22 -07:00
jax authors
30feb352b4 Merge pull request #21656 from yamlyeti:yamlyeti-patch-1
PiperOrigin-RevId: 641284969
2024-06-07 10:12:02 -07:00
Dan Foreman-Mackey
1fa66590d1 Edit pycapsule docstring to provide a little bit more context
The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
2024-06-07 13:07:03 -04:00
Paweł Paruzel
5fcd50b7fa Refactor kernel function assigment
PiperOrigin-RevId: 641255192
2024-06-07 08:20:31 -07:00
jax authors
f51af87fc5 fp8 matmul in pallas
PiperOrigin-RevId: 641254832
2024-06-07 08:17:06 -07:00
Frederic Bastien
da8a7b2855 Add in the tutorial the idea to test 1 process per node and 1 process per GPU. 2024-06-07 10:00:04 -04:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
tttc3
21f71c6b66 fix typo in jax.numpy.linalg.cross docstring 2024-06-07 13:43:51 +01:00
Sergei Lebedev
5d6413cecc Added debug_callback to the list of exclusions in jax2tf/tests/primitives_test.py
PiperOrigin-RevId: 641149152
2024-06-07 00:01:30 -07:00
jax authors
c01c98400d Add missing arguments for jnp.extract's python binding signature.
PiperOrigin-RevId: 641121305
2024-06-06 21:34:38 -07:00
rajasekharporeddy
6d94ae3274 Improve docs for jnp.angle and jnp.flip 2024-06-07 10:03:07 +05:30
rajasekharporeddy
6d85c3890d Improve documentation for jnp.fliplr and jnp.flipud 2024-06-07 09:58:02 +05:30
jax authors
625ea07a7e Merge pull request #21710 from jakevdp:fix-jax2tf
PiperOrigin-RevId: 641112498
2024-06-06 20:45:57 -07:00
Roy Frostig
ea6dfd1947 rename Specialized to Traced (and specialize to trace)
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
jax authors
dd40d8852d Update XLA dependency to use revision
9449b0851c.

PiperOrigin-RevId: 641069331
2024-06-06 17:12:57 -07:00
Jake VanderPlas
a2c31f4d15 pallas/mosaic test: avoid leaking global config state 2024-06-06 16:00:02 -07:00
jax authors
a1b5860427 Merge pull request #21711 from jakevdp:setup-module
PiperOrigin-RevId: 641049524
2024-06-06 15:59:07 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
jax authors
d457f9a116 Merge pull request #21716 from gnecula:exp_rename_sharding
PiperOrigin-RevId: 641017765
2024-06-06 14:17:10 -07:00
George Necula
01ee768f73 [export] Rename in_shardings and out_shardings fields.
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Yash Katariya
aee62e4874 Implement lower in terms of specialize
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
jax authors
90c83bb1e2 Merge pull request #21484 from dfm:custom-call-lowering
PiperOrigin-RevId: 640996459
2024-06-06 13:10:28 -07:00
Mark Sandler
2c246df439 Reverts dfe61285093ff826e1ad23bb36b77a42c01040b4
PiperOrigin-RevId: 640987745
2024-06-06 12:41:17 -07:00
Yash Katariya
fbf2a62aa1 Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Tomás Longeri
a65d3ae0da [Mosaic] Expand vector.shape_cast support for sublane (un)folding no-ops
- Support non-zero minor offsets without having to relayout (they're still a no-op).
- Remove restriction on tiling which now allows 1D packed types to work.

PiperOrigin-RevId: 640967375
2024-06-06 11:35:19 -07:00
Jake VanderPlas
48355cde83 jax2tf_test: ensure no modification of global config 2024-06-06 11:27:33 -07:00
jax authors
82516c5d4f Merge pull request #21694 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 640956334
2024-06-06 11:05:37 -07:00
jax authors
cc4bd42390 Merge pull request #21688 from froystig:slab-heap
PiperOrigin-RevId: 640953143
2024-06-06 10:56:09 -07:00
jax authors
15e41a620f Merge pull request #21702 from hawkinsp:cudnnplug
PiperOrigin-RevId: 640932820
2024-06-06 09:58:12 -07:00
Jevin Jiang
7a5975e174 [Pallas] Fix typo in test.
PiperOrigin-RevId: 640930803
2024-06-06 09:51:35 -07:00
Peter Hawkins
971ab0fba2 Make CuDNN SDPA API work with JAX with a CUDA plugin configuration. 2024-06-06 12:09:19 -04:00
Christos Perivolaropoulos
18e55d567f [test_utils] Fix the encoding of capture_stdout so it works on windows.
PiperOrigin-RevId: 640910749
2024-06-06 08:43:25 -07:00