1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

102 Commits

Author SHA1 Message Date
Yash Katariya
a4ca0dbc6c Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
shuw
c099e8081d support e2m1fn 2025-03-05 17:44:34 +00:00
Bart Chrzaszcz
ac493655bf #sdy support JAX export tests when Shardy is enabled.
This CL only supports lowering a module with the exact same mesh, and loading it with either the exact same mesh or different meshes.

Note that we will be introducing some restrictions under Shardy for JAX export:

- You can only lower/save the module with meshes all of the same shape, but different axis names (this PR is right now only allowing the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, like with the restriction in the previous point, all shardings must use the same axis shapes, but can use different axis names (again this will be relaxed in a follow-up)

We may remove the restriction of having to use the exact same mesh shapes during export saving time and exact same mesh shaped during export loading time in the future. But for now we will keep this restriction while no one is using Shardy with JAX export.

PiperOrigin-RevId: 732878916
2025-03-03 04:57:06 -08:00
Matthew Johnson
1ae02bc069 skip tests with extra requirements 2025-02-05 01:48:28 +00:00
George Necula
6dd1234707 [export] Fix mis-used of NamedSharding in export tests 2025-01-24 09:18:02 +02:00
jax authors
8442d64a02 Merge pull request from wenscarl:fp8_e8m0fnu
PiperOrigin-RevId: 718996844
2025-01-23 13:41:35 -08:00
Peter Hawkins
3fa557289a Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.

PiperOrigin-RevId: 713296722
2025-01-08 08:14:50 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
George Necula
afcb62ea20 [export] Expand exporting to work with AbstractMesh.
This is a follow up from  that enabled lowering with
AbstractMesh.

This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
2024-12-16 10:30:46 +02:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Jake VanderPlas
e61a20b45a Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
George Necula
2feea414ac [export] Add support for serialization for some custom PyTree nodes
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.

Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).

When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserializaton a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than then function that was serialized. This can be confusing.

The Python pickle mode does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.

Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.

Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
2024-10-21 11:38:13 +02:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"]() change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Tom Ward
33bd2925f0 [export] Fix poly shape check for vjp function with integer valued, polymorphic output.
PiperOrigin-RevId: 650990009
2024-07-10 06:12:19 -07:00
Sergei Lebedev
56745818a6 Added basic support for int2/uint2 dtypes to JAX


PiperOrigin-RevId: 649366888
2024-07-04 04:13:24 -07:00
jax authors
dffd72e290 Merge pull request from hawkinsp:singletons
PiperOrigin-RevId: 649135349
2024-07-03 11:07:00 -07:00
George Necula
cfa3c91c32 [export] Disable serialization in export_test if flatbuffers is not installed
This allows one to run most of export_test even if flatbuffers
is not installed. Only the serialization and deserialization are
skipped.
2024-07-02 15:46:38 +02:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
George Necula
47f1b3de2c [export] Add documentation for debugging and for ensuring compatibility.
The rendered documentation is at https://jax--21976.org.readthedocs.build/en/21976/export/export.html#developer-documentation (for the export developer documentation, including compatibility) and https://jax--21976.org.readthedocs.build/en/21976/export/shape_poly.html#debugging (for the shape polymorphism debugging documentation)

While testing the compatibility mechanism I discovered that it can be circumvented by caches.
To fix this, I added export_ignore_forward_compatibility to mlir.LoweringParameters.
2024-06-28 08:36:55 +03:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
jax authors
93e8167e19 Merge pull request from ROCm:ci_export_test_fix
PiperOrigin-RevId: 646923809
2024-06-26 06:35:50 -07:00
Ruturaj4
651cc1cb6c [ROCM] Fix export test for rocm 2024-06-24 10:55:43 -05:00
George Necula
d737abda48 [export] Fix multi-platform lowering for unknown platform, with donated_argnums
I had to ensure that the check for platforms supporting donation
only kicks in when we actually have donation.
2024-06-23 07:26:12 +03:00
George Necula
b58ff2ba20 [shape_poly] Add documentation for shape polymorphism
This involved writing some new content and also moving and adapting
the documentation that existed as part of the jax2tf
README file:

https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion
2024-06-15 18:20:54 +03:00
Jake VanderPlas
a92fa547a0 Re-land https://github.com/google/jax/pull/21847
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b

PiperOrigin-RevId: 643202512
2024-06-13 19:53:53 -07:00
jax authors
0bcc81ceb3 Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
George Necula
7af03a8fd1 [export] Deprecate jax.experimental.export
And announce jax.export.

While turning on the DeprecationWarning I discovered a couple
of tests that needed adjustment.
2024-06-13 21:46:18 +03:00
George Necula
7c3a4db3e4 [export] Rename some API entry points
We take the opportunity of a new jax.export package to rename some
of the API entry points:

  * `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
    because this is more accurate. The dimension variables are global
    constants, but so is the platform index. And we need to run
    global constant propagation and shape refinement for all of these.
  * We rename "serialization version" with "calling convention version".
    Hence we now have `Exported.calling_convention_version`,
    and the configuration flag is renamed from `--jax-serialization-version`
    to `--jax-export-calling-convention-version`. Also,
    `jax.export.minimum_supported_serialization_version` is now
    `jax.export.minimum_supported_calling_convention_version`.
   * We rename `lowering_platforms` to `platforms` both as a field
    of `Exported` and as the kwarg to `export.export`.
   * We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.
2024-06-13 06:44:13 +02:00
George Necula
105cc9a103 [export] Add documentation for jax.export 2024-06-12 19:44:47 +02:00
George Necula
97db0e758d [pallas] Add support for cross-platform lowering
When implementing this I have discovered that the
multi-platform lowering support does not handle the case when
the lowering rule for a platform invoke tracing (via `mlir.lower_fun`)
and that tracing encounters a primitive that has lowering rules
only for a particular platform. To support this, I have added
the `LoweringRuleContext.platforms` to override
`ModuleContext.platforms` with a potentially narrower set
of lowering platforms. Added a test for this scenario.
2024-06-12 08:48:58 +02:00
George Necula
e3faf854b0 [export] Cleaned up types of [in|out]_shardings
Previously we declared Exported.in_shardings to be
a sequence of `core.AbstractValue`, but in reality we only
support `core.ShapedArray`. We change the type declaration and
this allowed us to clean up some `# type: ignore"
2024-06-11 13:46:44 +02:00
George Necula
b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Export = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
2024-06-10 19:31:51 +02:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
Roy Frostig
ea6dfd1947 rename Specialized to Traced (and specialize to trace)
PiperOrigin-RevId: 641076488
2024-06-06 17:43:08 -07:00
jax authors
a1b5860427 Merge pull request from jakevdp:setup-module
PiperOrigin-RevId: 641049524
2024-06-06 15:59:07 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
George Necula
01ee768f73 [export] Rename in_shardings and out_shardings fields.
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Yash Katariya
fbf2a62aa1 Remove jaxpr and name from Lowered because specialize already has those. This keeps the abstraction boundary clear. Adapt export to use specialize.
PiperOrigin-RevId: 640968129
2024-06-06 11:38:56 -07:00
George Necula
079eea5669 [export] Add a LoweringParameters.for_export boolean context for exporting
This boolean context field is set only when we are lowering for
exporting. It can be used, e.g., to adapt the lowering rules
for the export case.
2024-06-06 06:28:58 +01:00
Jake VanderPlas
9a080f4b83 Test: use context manager to set jax_serialization_version 2024-06-04 16:08:16 -07:00
George Necula
2f3e02a36a [export] Add helped methods to create XLACompatibleShardings for in_shardings and out_shardings 2024-06-04 09:07:16 +01:00
George Necula
be1e40dc2e Copybara import of the project:
--
f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8 by George Necula <gcnecula@gmail.com>:

[export] Fix

A user reported an error when trying to export a function
that has a "lower" attribute (to impersonate a jitted function)
but does not have a "__name__" attribute.
The solution is to use the default name "<unnamed function>".

While I was at it I have added a `util.fun_name` to get
the name of a Callable, and I use it in several places.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21572 from gnecula:exp_fix_name f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8
PiperOrigin-RevId: 639236990
2024-05-31 20:40:42 -07:00
George Necula
acb56a2909 [export] Fix calling under pmap of exported computation with polymorphic shapes
If we call a computation with shape polymorphism under pmap
we must refine the shapes before we compile.
We follow the same pattern for `UnloadedPmapExecutable` as
for `UnloadedMeshExecutable`: we store the `shape_poly_state`
from the `LoweringResult` into the `compile_args` and we
call `refine_polymorphic_shapes`.

Without this fix we may end up trying to compile HLO with
dynamic shapes.
2024-05-28 16:18:38 +03:00
George Necula
7dbab168fd [export] Simplify construction of shardings for the VJP
This is possible now due to improvements in the handling
of shardings in pjit. At the same time, re-enable the
checking of shardings for the arguments and results of
the VJP function.
2024-05-28 04:23:56 +03:00
Tom Ward
95c05521b4 [export] Enable model replication sharding using jax.pmap.
PiperOrigin-RevId: 636063465
2024-05-22 00:37:16 -07:00
George Necula
9948529735 [export] Relax the check that exported modules are used with same number of devices as when exported
Now we allow a module exported for 1 device and not using any sharding annotations
to be called from a computation that uses multiple devices. Such exported modules
can be parallelized trivially point-wise.
2024-05-21 11:40:16 -07:00
George Necula
6deeee27db [export] Fix device assignment error for grad of exported.
Currently, the export code uses a manufactured device assignment
for exporting the VJP function. We should use instead the same
device assigment that was used when exporting the primal function.

This PR fixes that for the case when the export is done through
the direct use of `jax.experimental.export`, and leaves as future
work the case when the use is from `jax2tf`. We add a disabled
tests for the latter case.

Bug: 
2024-05-20 16:11:01 -07:00
Ashish Shenoy
1d6ffdedc5 Reverts 85e91c2be4310d9728f7bfeefef921ee4a075135
PiperOrigin-RevId: 633622856
2024-05-14 10:07:44 -07:00
George Necula
98aead70eb [export] Relax the check that exported modules are used with same number of devices as when exported
Now we allow a module exported for 1 device and not using any sharding annotations
to be called from a computation that uses multiple devices. Such exported modules
can be parallelized trivially point-wise.
2024-05-13 20:09:43 +03:00