jax authors
b486a95186
Merge pull request #21507 from renecotyfanboy:main
...
PiperOrigin-RevId: 641429523
2024-06-07 20:28:23 -07:00
jax authors
6c822c0124
Update XLA dependency to use revision
...
3195fdc851
.
PiperOrigin-RevId: 641387498
2024-06-07 16:19:00 -07:00
jax authors
d32404020b
Avoid "min() arg is an empty sequence" error after enabling "jax_explain_cache_misses".
...
PiperOrigin-RevId: 641381432
2024-06-07 15:52:35 -07:00
sdupourque
751d59ce67
increase default precision for hyp1f1
2024-06-08 00:38:51 +02:00
rajasekharporeddy
7989c70572
Add example code snippets to jax.scipy.linalg.expm and jax.scipy.linalg.polar docs
2024-06-08 03:30:12 +05:30
Yash Katariya
57826d8c65
Add a no input memories_test and enable memories test on vf 2x2
...
PiperOrigin-RevId: 641361865
2024-06-07 14:40:44 -07:00
jax authors
0d047a116a
Merge pull request #21718 from jakevdp:pallas-config
...
PiperOrigin-RevId: 641349981
2024-06-07 13:58:49 -07:00
Yash Katariya
44a13c9d4b
Merge code between make_jaxpr
and jit(f).trace
.
...
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.
Since we can keep the existing behavior and still merge the implementation is a good cleanup!
Fixes https://github.com/google/jax/issues/21116
PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
jax authors
25cc84b879
Merge pull request #21615 from selamw1:append_doc
...
PiperOrigin-RevId: 641344856
2024-06-07 13:39:57 -07:00
jax authors
dfc6076db2
Merge pull request #21744 from superbobry:typing
...
PiperOrigin-RevId: 641339815
2024-06-07 13:23:31 -07:00
Sergei Lebedev
136289e914
Added filelock to py_deps
...
This should unblock #21394 , which uses filelock in the compilation cache.
PiperOrigin-RevId: 641338150
2024-06-07 13:16:33 -07:00
jax authors
7d913f763a
Merge pull request #21298 from oliverdutton:pallas_interpreter_indexing_fix
...
PiperOrigin-RevId: 641325047
2024-06-07 12:29:31 -07:00
Sergei Lebedev
0786da8fd8
Removed unnecessary mypy exclusions from pyproject.toml
...
* 2/3 files type check just fine now
* the remaining one could be handled via a file-level directive
2024-06-07 20:07:42 +01:00
jax authors
f4c6437837
Merge pull request #21680 from ROCm:ci_spmm
...
PiperOrigin-RevId: 641316410
2024-06-07 11:57:12 -07:00
jax authors
af90464b53
Merge pull request #21733 from dfm:ffi-capsule-docstring
...
PiperOrigin-RevId: 641307843
2024-06-07 11:27:41 -07:00
jax authors
bd499a921e
Merge pull request #21690 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 641292860
2024-06-07 10:38:07 -07:00
jax authors
98d7235aee
Merge pull request #21501 from jakevdp:softmax-inf-doc
...
PiperOrigin-RevId: 641291919
2024-06-07 10:34:40 -07:00
jax authors
1459ac04a8
Merge pull request #21731 from tttc3:cross-product-typo
...
PiperOrigin-RevId: 641285460
2024-06-07 10:18:35 -07:00
jax authors
2899c9fada
Merge pull request #21692 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 641285369
2024-06-07 10:15:22 -07:00
jax authors
30feb352b4
Merge pull request #21656 from yamlyeti:yamlyeti-patch-1
...
PiperOrigin-RevId: 641284969
2024-06-07 10:12:02 -07:00
Dan Foreman-Mackey
1fa66590d1
Edit pycapsule
docstring to provide a little bit more context
...
The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
2024-06-07 13:07:03 -04:00
Paweł Paruzel
5fcd50b7fa
Refactor kernel function assigment
...
PiperOrigin-RevId: 641255192
2024-06-07 08:20:31 -07:00
jax authors
f51af87fc5
fp8 matmul in pallas
...
PiperOrigin-RevId: 641254832
2024-06-07 08:17:06 -07:00
Frederic Bastien
da8a7b2855
Add in the tutorial the idea to test 1 process per node and 1 process per GPU.
2024-06-07 10:00:04 -04:00
George Necula
3914cb415d
[export] Remove old deprecated APIs for jax.experimental.export.
...
See CHANGELOG.md.
The deprecation period has passed.
Also replace deprecated .call_exported with .call in tests.
PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
tttc3
21f71c6b66
fix typo in jax.numpy.linalg.cross
docstring
2024-06-07 13:43:51 +01:00
Sergei Lebedev
5d6413cecc
Added debug_callback to the list of exclusions in jax2tf/tests/primitives_test.py
...
PiperOrigin-RevId: 641149152
2024-06-07 00:01:30 -07:00
jax authors
c01c98400d
Add missing arguments for jnp.extract's python binding signature.
...
PiperOrigin-RevId: 641121305
2024-06-06 21:34:38 -07:00
rajasekharporeddy
6d94ae3274
Improve docs for jnp.angle and jnp.flip
2024-06-07 10:03:07 +05:30
rajasekharporeddy
6d85c3890d
Improve documentation for jnp.fliplr and jnp.flipud
2024-06-07 09:58:02 +05:30
jax authors
625ea07a7e
Merge pull request #21710 from jakevdp:fix-jax2tf
...
PiperOrigin-RevId: 641112498
2024-06-06 20:45:57 -07:00
Roy Frostig
ea6dfd1947
rename Specialized
to Traced
(and specialize
to trace
)
...
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
jax authors
dd40d8852d
Update XLA dependency to use revision
...
9449b0851c
.
PiperOrigin-RevId: 641069331
2024-06-06 17:12:57 -07:00
Jake VanderPlas
a2c31f4d15
pallas/mosaic test: avoid leaking global config state
2024-06-06 16:00:02 -07:00
jax authors
a1b5860427
Merge pull request #21711 from jakevdp:setup-module
...
PiperOrigin-RevId: 641049524
2024-06-06 15:59:07 -07:00
Jake VanderPlas
a861c55a28
test cleanup: use ExitStack to reduce test boilerplate
2024-06-06 14:18:27 -07:00
jax authors
d457f9a116
Merge pull request #21716 from gnecula:exp_rename_sharding
...
PiperOrigin-RevId: 641017765
2024-06-06 14:17:10 -07:00
George Necula
01ee768f73
[export] Rename in_shardings and out_shardings fields.
...
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Yash Katariya
aee62e4874
Implement lower
in terms of specialize
...
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
jax authors
90c83bb1e2
Merge pull request #21484 from dfm:custom-call-lowering
...
PiperOrigin-RevId: 640996459
2024-06-06 13:10:28 -07:00
Mark Sandler
2c246df439
Reverts dfe61285093ff826e1ad23bb36b77a42c01040b4
...
PiperOrigin-RevId: 640987745
2024-06-06 12:41:17 -07:00
Yash Katariya
fbf2a62aa1
Remove jaxpr
and name
from Lowered
because specialize
already has those. This keeps the abstraction boundary clear. Adapt export
to use specialize
.
...
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Tomás Longeri
a65d3ae0da
[Mosaic] Expand vector.shape_cast support for sublane (un)folding no-ops
...
- Support non-zero minor offsets without having to relayout (they're still a no-op).
- Remove restriction on tiling which now allows 1D packed types to work.
PiperOrigin-RevId: 640967375
2024-06-06 11:35:19 -07:00
Jake VanderPlas
48355cde83
jax2tf_test: ensure no modification of global config
2024-06-06 11:27:33 -07:00
jax authors
82516c5d4f
Merge pull request #21694 from rajasekharporeddy:doc_typos
...
PiperOrigin-RevId: 640956334
2024-06-06 11:05:37 -07:00
jax authors
cc4bd42390
Merge pull request #21688 from froystig:slab-heap
...
PiperOrigin-RevId: 640953143
2024-06-06 10:56:09 -07:00
jax authors
15e41a620f
Merge pull request #21702 from hawkinsp:cudnnplug
...
PiperOrigin-RevId: 640932820
2024-06-06 09:58:12 -07:00
Jevin Jiang
7a5975e174
[Pallas] Fix typo in test.
...
PiperOrigin-RevId: 640930803
2024-06-06 09:51:35 -07:00
Peter Hawkins
971ab0fba2
Make CuDNN SDPA API work with JAX with a CUDA plugin configuration.
2024-06-06 12:09:19 -04:00
Christos Perivolaropoulos
18e55d567f
[test_utils] Fix the encoding of capture_stdout so it works on windows.
...
PiperOrigin-RevId: 640910749
2024-06-06 08:43:25 -07:00