76 Commits

Author SHA1 Message Date
Sergei Lebedev
03b733bda7 Made has_side_effect= parameter of mlir.emit_python_callback keyword-only
This ensures that the call site always has parameter name and not just
a bare True/False argument.

PiperOrigin-RevId: 630166542
2024-05-02 13:44:54 -07:00
Sergei Lebedev
6e23c14f85 jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Sergei Lebedev
32922f61e9 jax.debug.callback now requires a Callable[..., None]
This makes the "return value is ignored" behavior explicit in the type.

PiperOrigin-RevId: 626430448
2024-04-19 11:55:08 -07:00
Sergei Lebedev
1be5451179 Import rich lazily
This ensures that the timing of `import jax` is not affected by `rich` being
installed.

See also #20778.
2024-04-16 22:25:33 +01:00
Yash Katariya
4c9241ecda Cache ClosedJaxpr creation to minimize cache_misses. ClosedJaxpr should always be created under a cache.
PiperOrigin-RevId: 593023314
2023-12-21 22:15:52 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Yash Katariya
10f6a35f83 Add a registry for primitives that require device_assignment during lowering
PiperOrigin-RevId: 589272990
2023-12-08 16:31:41 -08:00
Yash Katariya
5fb8ceca73 Make lowering oblivious to real physical devices. Instead cache lowering on HloSharding only (which is based on logical device numbers)
Make an exception for callbacks and custom_partitioning because they need access to device_assignment during lowering.

PiperOrigin-RevId: 589244695
2023-12-08 14:36:09 -08:00
Matthew Johnson
7608cce86f improve a debug.callback type error message for idiots
(i am the idiot)
2023-12-06 14:41:52 -08:00
Matthew Johnson
997db225e2 small tweaks to jax.debug.print docstring 2023-11-22 15:03:41 -08:00
George Necula
edbe49fb2a Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
Yash Katariya
ef20526a76 Return PositionalSharding if input's rank is >= 3 or a NamedSharding if a mesh is available via the context from inspect_array_sharding. Never return GSPMDSharding from inspect_array_sharding.
PiperOrigin-RevId: 573048344
2023-10-12 16:55:12 -07:00
Peter Hawkins
15126504a7 [JAX] Keep CPU host callbacks alive via IFRT, rather than by attaching them to the Python object.
We need to keep callback objects alive as long as any running executables are alive. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use after free. Instead, make the runtime keep host callbacks alive.

PiperOrigin-RevId: 571141106
2023-10-05 15:07:03 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Patrick Kidger
997d60ed1a
Fix docstring for jax.debug.{print,callback} 2023-07-17 15:58:58 +01:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Yash Katariya
c4b1b7db0d Make inspect_sharding_lowering_rule work only with HloSharding in it's callback. Also remove get_replicated_op_sharding since it is not needed anymore.
PiperOrigin-RevId: 538595966
2023-06-07 14:39:52 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Parker Schuh
5f4408ded7 Convert inspect_sharding to register the handler directly in c++ so that it can
work across the c-api boundary.

PiperOrigin-RevId: 527322386
2023-04-26 11:22:28 -07:00
Jake VanderPlas
055edf4a08 DOC: add docstrings for callback functions 2023-04-12 07:33:09 -07:00
Yash Katariya
393e5931d1 Move parse_flatten_op_sharding to sharding_impls.py to remove local import of pjit using that function from pxla.py
PiperOrigin-RevId: 523573375
2023-04-11 19:26:25 -07:00
Matthew Johnson
962e47f2a8 fix debug callback docstring (and return value) 2023-04-10 23:07:43 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Frederic Bastien
42e9753431 Fix inspect_array_sharding with grad. 2023-03-21 07:58:27 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Peter Hawkins
623282715d Split Mesh and ResourceEnv into a new module jax._src.mesh.
This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00
Matthew Johnson
1f67351f56 [shard_map] make debug_print work with shard_map, eager and jit 2023-03-08 20:38:03 -08:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Sharad Vikram
6f1714e57a Add some info in the docs about using jax.debug.print with f-strings 2023-02-15 15:16:37 -08:00
Peter Hawkins
00d45feee6 Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Rahul Joshi
b2d73cb73b Fix sharding for tuple created in inspect_sharding_partition 2022-11-29 11:35:04 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
jax authors
93bcb599c2 Merge pull request #13065 from sharadmv:vis-colors
PiperOrigin-RevId: 485467801
2022-11-01 18:11:06 -07:00
Sharad Vikram
3bbd5f3028 Add colors to sharding visualization 2022-11-01 16:55:06 -07:00
Rahul Joshi
42b2cf2107 Fix result type for InspectSharding custom call
- Use the result type that will be valid for the sharding that
  will be attached to this custom call.
2022-10-27 10:01:22 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Jake VanderPlas
5d15757741 [typing] annotate jax._src.util.safe_map 2022-10-20 10:15:04 -07:00