Michael Hudgins
d4d1518c3d
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
...
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
c191bbcdb1
Make debug.print
work with static args. Fixes: https://github.com/google/jax/issues/23600
...
PiperOrigin-RevId: 676005582
2024-09-18 08:41:29 -07:00
Sergei Lebedev
91df9d1a17
Fixed validation in jax.debug.format
...
This commit ensures that no formatting is done during validation, because the
arguments could be abstract values.
Closes #23475 .
2024-09-09 10:53:35 +01:00
Sergei Lebedev
a44265aa73
Added a trivial discharge rule for debug_callback_p
...
This allows using jax.debug.print with Refs in interpreted Pallas kernels.
2024-07-29 22:26:01 +01:00
Sergei Lebedev
8d33a6c9a6
Bumped jaxlib version mypy uses on the CI
...
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
Peter Hawkins
8ab0c07edc
Don't wrap singleton ir.Values with tuples during HLO lowering.
...
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.
To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
Peter Hawkins
7f4ef63cd8
Run pyupgrade --py310-plus
.
...
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
jax authors
5e2710c2c2
Merge pull request #21261 from superbobry:mypy-ruff
...
PiperOrigin-RevId: 634654578
2024-05-17 00:10:27 -07:00
Sergei Lebedev
c3bc88d5e4
Bumped mypy to 1.10.0 and ruff to 0.4.4
2024-05-16 23:16:32 +01:00
Sergei Lebedev
01194bd2fb
Clarified the type of the inputs to callback APIs
...
The callback APIs were migrated to use jax.Arrays for both inputs and outputs
in JAX 0.4.27.
PiperOrigin-RevId: 634473890
2024-05-16 11:29:09 -07:00
Sergei Lebedev
03b733bda7
Made has_side_effect= parameter of mlir.emit_python_callback keyword-only
...
This ensures that the call site always has parameter name and not just
a bare True/False argument.
PiperOrigin-RevId: 630166542
2024-05-02 13:44:54 -07:00
Sergei Lebedev
6e23c14f85
jax.debug.callback now passes arguments as jax.Arrays
...
Prior to this change the behavior in eager and under jax.jit was inconsistent
>>> (lambda *args: jax.debug.callback(print, *args))([42])
[42]
>>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
[array(42, dtype=int32)]
It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.
Closes #20627 .
PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Sergei Lebedev
32922f61e9
jax.debug.callback now requires a Callable[..., None]
...
This makes the "return value is ignored" behavior explicit in the type.
PiperOrigin-RevId: 626430448
2024-04-19 11:55:08 -07:00
Sergei Lebedev
1be5451179
Import rich lazily
...
This ensures that the timing of `import jax` is not affected by `rich` being
installed.
See also #20778 .
2024-04-16 22:25:33 +01:00
Yash Katariya
4c9241ecda
Cache ClosedJaxpr creation to minimize cache_misses. ClosedJaxpr should always be created under a cache.
...
PiperOrigin-RevId: 593023314
2023-12-21 22:15:52 -08:00
Sergei Lebedev
f936613b06
Upgrade remaining sources to Python 3.9
...
This PR is a follow up to #18881 .
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Yash Katariya
10f6a35f83
Add a registry for primitives that require device_assignment during lowering
...
PiperOrigin-RevId: 589272990
2023-12-08 16:31:41 -08:00
Yash Katariya
5fb8ceca73
Make lowering oblivious to real physical devices. Instead cache lowering on HloSharding only (which is based on logical device numbers)
...
Make an exception for callbacks and custom_partitioning because they need access to device_assignment during lowering.
PiperOrigin-RevId: 589244695
2023-12-08 14:36:09 -08:00
Matthew Johnson
7608cce86f
improve a debug.callback type error message for idiots
...
(i am the idiot)
2023-12-06 14:41:52 -08:00
Matthew Johnson
997db225e2
small tweaks to jax.debug.print docstring
2023-11-22 15:03:41 -08:00
George Necula
edbe49fb2a
Cleanup the handling of single- and multi-platform lowering in ModuleContext
...
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.
PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
Yash Katariya
ef20526a76
Return PositionalSharding if input's rank is >= 3 or a NamedSharding if a mesh is available via the context from inspect_array_sharding. Never return GSPMDSharding from inspect_array_sharding.
...
PiperOrigin-RevId: 573048344
2023-10-12 16:55:12 -07:00
Peter Hawkins
15126504a7
[JAX] Keep CPU host callbacks alive via IFRT, rather than by attaching them to the Python object.
...
We need to keep callback objects alive as long as any running executables are alive. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use after free. Instead, make the runtime keep host callbacks alive.
PiperOrigin-RevId: 571141106
2023-10-05 15:07:03 -07:00
Peter Hawkins
319ab98980
Apply pyupgrade --py39-plus.
...
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Patrick Kidger
997d60ed1a
Fix docstring for jax.debug.{print,callback}
2023-07-17 15:58:58 +01:00
Peter Hawkins
816ba91263
Use lower-case PEP 585 names for types.
...
Issue https://github.com/google/jax/issues/16537
PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Yash Katariya
c4b1b7db0d
Make inspect_sharding_lowering_rule
work only with HloSharding in it's callback. Also remove get_replicated_op_sharding
since it is not needed anymore.
...
PiperOrigin-RevId: 538595966
2023-06-07 14:39:52 -07:00
Yash Katariya
ae9d1498e5
Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
...
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Parker Schuh
5f4408ded7
Convert inspect_sharding to register the handler directly in c++ so that it can
...
work across the c-api boundary.
PiperOrigin-RevId: 527322386
2023-04-26 11:22:28 -07:00
Jake VanderPlas
055edf4a08
DOC: add docstrings for callback functions
2023-04-12 07:33:09 -07:00
Yash Katariya
393e5931d1
Move parse_flatten_op_sharding to sharding_impls.py to remove local import of pjit using that function from pxla.py
...
PiperOrigin-RevId: 523573375
2023-04-11 19:26:25 -07:00
Matthew Johnson
962e47f2a8
fix debug callback docstring (and return value)
2023-04-10 23:07:43 -07:00
Peter Hawkins
be1cf46a49
Split sharding_impls into its own Bazel target.
...
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Peter Hawkins
c1f65fc8b2
Avoid imports from the public jax.* namespace in more places internally.
...
This change is in preparation for more cycle breaking in the Bazel dependency graph.
PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
6cc1bf54a1
Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
...
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.
PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Frederic Bastien
42e9753431
Fix inspect_array_sharding with grad.
2023-03-21 07:58:27 -07:00
Peter Hawkins
dea7450e4e
Remove references to jax.config.jax_array, which is always True at head.
...
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
1925aa1109
Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
...
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Peter Hawkins
623282715d
Split Mesh and ResourceEnv into a new module jax._src.mesh.
...
This work is an effort to reduce cyclic dependencies in JAX internals.
Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.
PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00
Matthew Johnson
1f67351f56
[shard_map] make debug_print work with shard_map, eager and jit
2023-03-08 20:38:03 -08:00
Yash Katariya
52a7701dda
Replace usage of {in|out}_axis_resources with {in|out}_shardings
...
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Sharad Vikram
a6c4c87f3e
Add JaxprInputEffect
and refactor StateEffect
s to use it
2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8
Refactor effects system to use effect types, not objects
2023-02-17 17:40:08 -08:00
Yash Katariya
0ffdeb3de2
Rename jax.sharding.OpShardingSharding
to jax.sharding.GSPMDSharding
. jax.sharding.OpShardingSharding
will be removed in 3 months from Feb 17, 2023.
...
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Sharad Vikram
6f1714e57a
Add some info in the docs about using jax.debug.print
with f-strings
2023-02-15 15:16:37 -08:00
Peter Hawkins
00d45feee6
Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
...
Use the aliases under jax.sharding instead.
PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Roy Frostig
1c84e4a753
migrate internal dependencies from jax.interpreters.batching
to jax._src.interpreters.batching
...
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.
PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Peter Hawkins
cc8d7fae32
Move jax.interpreters.mlir to jax._src.interpreters.mlir.
...
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Roy Frostig
219723c738
migrate internal dependencies from jax.interpreters.ad
to jax._src.interpreters.ad
...
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.
Includes some import fixups along the way.
PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Jake VanderPlas
4a6bbde409
Move jax.linear_util to jax._src.linear_util
2022-12-20 14:49:27 -08:00