Peter Hawkins
6fa31e59c4
Update version numbers after v0.4.29 release.
2024-06-10 14:37:53 -04:00
jax authors
3fe7377719
Merge pull request #21763 from gnecula:export_api
...
PiperOrigin-RevId: 641959833
2024-06-10 11:05:34 -07:00
George Necula
b33aca6b08
[export] Create the jax.export module APIs.
...
The functionality comes from the jax.experimental.export
module, which will be deprecated.
The following APIs are introduced:
```
from jax import export
def f(...): ...
ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)
blob: bytearray = ex.serialize()
rehydrated: export.Export = export.deserialize(blob)
def caller(...):
... rehydrated.call(*args, **kwargs)
```
Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.
Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:
* Instead of `jax.experimental.export.call(exp)` we now write
`exp.call`
* The `jax.experimental.export.export` allowed the function
argument to be any Python callable and it would wrap it with
a `jax.jit`. This is not supported anymore by export, and instead
the user must use `jax.jit`.
2024-06-10 19:31:51 +02:00
Sergei Lebedev
5e7ad600e2
Removed the double re-exporting of Pallas GPU/TPU APIs
...
jax.experimental.pallas.{gpu,tpu} now import directly from the relevant
jax._src.pallas.{triton,mosaic} submodules.
PiperOrigin-RevId: 641875127
2024-06-10 05:59:09 -07:00
George Necula
2ade7e7526
[pallas] Move the hardware_generation query in the code path that needs it
...
This change allows us to lower and export Pallas calls even
on machines that do not have TPUs, in many cases.
PiperOrigin-RevId: 641841079
2024-06-10 03:13:36 -07:00
jax authors
af95803d00
Merge pull request #21759 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 641831969
2024-06-10 02:29:12 -07:00
rajasekharporeddy
775c6f8727
Fix Typos in docs and one error message
2024-06-10 11:38:01 +05:30
George Necula
14d87d3bf7
[export] Move the export implementation to jax._src.export.
...
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.
Had to add a few "#type ignore" to the _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.
PiperOrigin-RevId: 641688200
2024-06-09 08:59:50 -07:00
jax authors
b486a95186
Merge pull request #21507 from renecotyfanboy:main
...
PiperOrigin-RevId: 641429523
2024-06-07 20:28:23 -07:00
jax authors
d32404020b
Avoid "min() arg is an empty sequence" error after enabling "jax_explain_cache_misses".
...
PiperOrigin-RevId: 641381432
2024-06-07 15:52:35 -07:00
sdupourque
751d59ce67
increase default precision for hyp1f1
2024-06-08 00:38:51 +02:00
Yash Katariya
44a13c9d4b
Merge code between make_jaxpr
and jit(f).trace
.
...
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.
Since we can keep the existing behavior and still merge the implementation is a good cleanup!
Fixes https://github.com/google/jax/issues/21116
PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
jax authors
25cc84b879
Merge pull request #21615 from selamw1:append_doc
...
PiperOrigin-RevId: 641344856
2024-06-07 13:39:57 -07:00
jax authors
dfc6076db2
Merge pull request #21744 from superbobry:typing
...
PiperOrigin-RevId: 641339815
2024-06-07 13:23:31 -07:00
jax authors
7d913f763a
Merge pull request #21298 from oliverdutton:pallas_interpreter_indexing_fix
...
PiperOrigin-RevId: 641325047
2024-06-07 12:29:31 -07:00
Sergei Lebedev
0786da8fd8
Removed unnecessary mypy exclusions from pyproject.toml
...
* 2/3 files type check just fine now
* the remaining one could be handled via a file-level directive
2024-06-07 20:07:42 +01:00
jax authors
f4c6437837
Merge pull request #21680 from ROCm:ci_spmm
...
PiperOrigin-RevId: 641316410
2024-06-07 11:57:12 -07:00
jax authors
af90464b53
Merge pull request #21733 from dfm:ffi-capsule-docstring
...
PiperOrigin-RevId: 641307843
2024-06-07 11:27:41 -07:00
jax authors
bd499a921e
Merge pull request #21690 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 641292860
2024-06-07 10:38:07 -07:00
jax authors
98d7235aee
Merge pull request #21501 from jakevdp:softmax-inf-doc
...
PiperOrigin-RevId: 641291919
2024-06-07 10:34:40 -07:00
jax authors
1459ac04a8
Merge pull request #21731 from tttc3:cross-product-typo
...
PiperOrigin-RevId: 641285460
2024-06-07 10:18:35 -07:00
jax authors
2899c9fada
Merge pull request #21692 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 641285369
2024-06-07 10:15:22 -07:00
Dan Foreman-Mackey
1fa66590d1
Edit pycapsule
docstring to provide a little bit more context
...
The docstring for the recently added `pycapsule` function in
`jax.extend.ffi` didn't conform to our usual docstring format, so I
updated it and added a little bit more context.
2024-06-07 13:07:03 -04:00
jax authors
f51af87fc5
fp8 matmul in pallas
...
PiperOrigin-RevId: 641254832
2024-06-07 08:17:06 -07:00
George Necula
3914cb415d
[export] Remove old deprecated APIs for jax.experimental.export.
...
See CHANGELOG.md.
The deprecation period has passed.
Also replace deprecated .call_exported with .call in tests.
PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
tttc3
21f71c6b66
fix typo in jax.numpy.linalg.cross
docstring
2024-06-07 13:43:51 +01:00
Sergei Lebedev
5d6413cecc
Added debug_callback to the list of exclusions in jax2tf/tests/primitives_test.py
...
PiperOrigin-RevId: 641149152
2024-06-07 00:01:30 -07:00
jax authors
c01c98400d
Add missing arguments for jnp.extract's python binding signature.
...
PiperOrigin-RevId: 641121305
2024-06-06 21:34:38 -07:00
rajasekharporeddy
6d94ae3274
Improve docs for jnp.angle and jnp.flip
2024-06-07 10:03:07 +05:30
rajasekharporeddy
6d85c3890d
Improve documentation for jnp.fliplr and jnp.flipud
2024-06-07 09:58:02 +05:30
jax authors
625ea07a7e
Merge pull request #21710 from jakevdp:fix-jax2tf
...
PiperOrigin-RevId: 641112498
2024-06-06 20:45:57 -07:00
Roy Frostig
ea6dfd1947
rename Specialized
to Traced
(and specialize
to trace
)
...
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
jax authors
a1b5860427
Merge pull request #21711 from jakevdp:setup-module
...
PiperOrigin-RevId: 641049524
2024-06-06 15:59:07 -07:00
Jake VanderPlas
a861c55a28
test cleanup: use ExitStack to reduce test boilerplate
2024-06-06 14:18:27 -07:00
jax authors
d457f9a116
Merge pull request #21716 from gnecula:exp_rename_sharding
...
PiperOrigin-RevId: 641017765
2024-06-06 14:17:10 -07:00
George Necula
01ee768f73
[export] Rename in_shardings and out_shardings fields.
...
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Yash Katariya
aee62e4874
Implement lower
in terms of specialize
...
PiperOrigin-RevId: 641005643
2024-06-06 13:39:07 -07:00
jax authors
90c83bb1e2
Merge pull request #21484 from dfm:custom-call-lowering
...
PiperOrigin-RevId: 640996459
2024-06-06 13:10:28 -07:00
Mark Sandler
2c246df439
Reverts dfe61285093ff826e1ad23bb36b77a42c01040b4
...
PiperOrigin-RevId: 640987745
2024-06-06 12:41:17 -07:00
Yash Katariya
fbf2a62aa1
Remove jaxpr
and name
from Lowered
because specialize
already has those. This keeps the abstraction boundary clear. Adapt export
to use specialize
.
...
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
Jake VanderPlas
48355cde83
jax2tf_test: ensure no modification of global config
2024-06-06 11:27:33 -07:00
jax authors
82516c5d4f
Merge pull request #21694 from rajasekharporeddy:doc_typos
...
PiperOrigin-RevId: 640956334
2024-06-06 11:05:37 -07:00
jax authors
cc4bd42390
Merge pull request #21688 from froystig:slab-heap
...
PiperOrigin-RevId: 640953143
2024-06-06 10:56:09 -07:00
jax authors
15e41a620f
Merge pull request #21702 from hawkinsp:cudnnplug
...
PiperOrigin-RevId: 640932820
2024-06-06 09:58:12 -07:00
Peter Hawkins
971ab0fba2
Make CuDNN SDPA API work with JAX with a CUDA plugin configuration.
2024-06-06 12:09:19 -04:00
Christos Perivolaropoulos
18e55d567f
[test_utils] Fix the encoding of capture_stdout so it works on windows.
...
PiperOrigin-RevId: 640910749
2024-06-06 08:43:25 -07:00
Dan Foreman-Mackey
ac560c0d90
Add helper function for building custom call lowering rules
...
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-06 11:34:08 -04:00
jax authors
fe9c1606fc
Merge pull request #21655 from gnecula:exp_lower
...
PiperOrigin-RevId: 640898362
2024-06-06 07:58:41 -07:00
Peter Hawkins
dfe6128509
Reverts da816d34eaad6a1c6536959ccb4bfee4466c037d
...
PiperOrigin-RevId: 640886105
2024-06-06 07:10:09 -07:00
Chris Jones
d700a0842b
[mosaic:gpu] Fix matmul example for mixed precision inputs.
...
The wgmma instruction group was never committed.
PiperOrigin-RevId: 640863541
2024-06-06 05:29:23 -07:00