8224 Commits

Author SHA1 Message Date
George Necula
c70de6deed [better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
2025-02-02 06:23:03 +02:00
jax authors
de48ce2a4c Merge pull request #26174 from skye:cpu_configs
PiperOrigin-RevId: 722199229
2025-02-01 16:20:16 -08:00
Christos Perivolaropoulos
b23f8f414b [pallas/pallas_mgpu] Discharging run_scoped should not be discharging the intermediates
When we do run_scoped[jaxpr, R1,R2], it can't be assumed that references
corresponding to R1 and R2 can be safely discharged. Sometimes they can (eg
Accumulator) but sometimes they can't (eg SMEM scratch). It should be up to the
lowering rule to do such discharging.

This further means that during lowering there is no guarantee that the
references will not be used/returned by nested scoped blocks so we also remove
that check.

PiperOrigin-RevId: 722137352
2025-02-01 09:37:03 -08:00
Christos Perivolaropoulos
8649132d86 [pallas] Support DMA start partial discharge and run_scoped() does its own partial discharge.
This CL lays the ground for a future CL that makes run_scoped discharge to not request the discharge of the temporary buffers it creates. This causes issues becausa

a) dma_start can't discharge some but not all its references
b) run_scoped() lowering depends on run_scoped discharge to remove the run_scoped operation (or it goes in an infinite loop).

PiperOrigin-RevId: 722126566
2025-02-01 08:23:23 -08:00
Jevin Jiang
ed952c8e65 [Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector.
PiperOrigin-RevId: 722015152
2025-01-31 21:27:08 -08:00
jax authors
872e6c0ec4 Merge pull request #25766 from carlosgmartin:nn_initializers_variance_scaling_mode_fan_geo_avg
PiperOrigin-RevId: 721928532
2025-01-31 15:41:50 -08:00
jax authors
a9f4dd7182 Merge pull request #26249 from jakevdp:fix-sterling
PiperOrigin-RevId: 721922732
2025-01-31 15:26:37 -08:00
carlosgmartin
96d3447e89 Add mode='fan_geo_avg' to nn.initializers.variance_scaling. 2025-01-31 17:52:22 -05:00
Emily Fertig
3b2410f77c Reverts bb951136e9b91a584bb422119ada76cc69c86024
PiperOrigin-RevId: 721908669
2025-01-31 14:42:22 -08:00
Jake VanderPlas
216bd9a6cc Fix dtype issue in stirling approximation 2025-01-31 14:13:02 -08:00
jax authors
70aed64a15 Merge pull request #26245 from jakevdp:fix-line
PiperOrigin-RevId: 721879951
2025-01-31 13:16:51 -08:00
jax authors
cb79ff4d85 Merge pull request #26194 from jax-ml:fix-dist-init-runtime-error
PiperOrigin-RevId: 721878712
2025-01-31 13:12:46 -08:00
Jake VanderPlas
0df7f182d6 delete unnecessary line 2025-01-31 12:44:14 -08:00
Gunhyun Park
20555f63da Lower np.ndarray to DenseElementsAttr instead of ArrayAttr.
PiperOrigin-RevId: 721833949
2025-01-31 11:06:06 -08:00
Jake VanderPlas
522200ff45 jax.lax: improve documentation for several functions 2025-01-31 09:30:56 -08:00
Peter Hawkins
60fad99c9c Fix CI failures in xla_metadata_test.
PiperOrigin-RevId: 721572019
2025-01-30 17:28:54 -08:00
Yash Katariya
1f33cad321 remove checks since they are redundant and we can change out_aval because of various reasons
PiperOrigin-RevId: 721535417
2025-01-30 15:14:34 -08:00
Yash Katariya
9107ee4a22 Do automatic casting from auto -> manual when the context mesh is manual and avals are in auto mode. This happens when values are being closed over in a shard_map. The casting is happening at lax level but we can move this to a different place later on.
PiperOrigin-RevId: 721495804
2025-01-30 13:14:04 -08:00
Gunhyun Park
a8df383ccf Fix lax.ragged_all_to_all degenerate case
In a singleton group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except update size is not statically known. This operation can't be expressed with standard HLO instructions -- the backend will handle this case separately.

Added small improvement to error messages.

PiperOrigin-RevId: 721473063
2025-01-30 12:05:02 -08:00
Yash Katariya
f4e2c6c34c Try to match out_spec with in_spec if both shardings are full auto and they are equivalent to each other. This is because of backwards compatibility reasons where tests expect the in and out shardings to match.
PiperOrigin-RevId: 721470917
2025-01-30 11:59:57 -08:00
Emily Fertig
bb951136e9 Return arrays from ArrayImpl._check_and_rearrange.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

PiperOrigin-RevId: 721412760
2025-01-30 09:10:50 -08:00
jax authors
1003ba93c3 Merge pull request #26150 from jreiffers:main
PiperOrigin-RevId: 721400896
2025-01-30 08:32:20 -08:00
George Necula
32c98b9a76 [better_errors] Refactor more uses of pe.tracing_debug_info (part 3)
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).

In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.

One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
2025-01-30 07:40:05 +02:00
Yash Katariya
d223dfc3f7 Allow multiple meshes for avals but in that case, just use empty_abstract_mesh instead of enabling computation follows data only for **Auto mode**.
PiperOrigin-RevId: 721224349
2025-01-29 20:47:34 -08:00
Zac Cranko
186478d213 improve runtime error message 2025-01-29 21:03:26 +00:00
jax authors
152099ee0e Merge pull request #26188 from dfm:revert-callback-docs
PiperOrigin-RevId: 721077634
2025-01-29 12:53:57 -08:00
Sergei Lebedev
d4ced960ab Pulled DLDeviceType to XLA backend mapping into a global
I also updated `to_dlpack` and `from_dlpack` to handle `KeyError` instead of `TypeError`, because I think `TypeError` was never actually raised.

PiperOrigin-RevId: 721052736
2025-01-29 11:38:50 -08:00
Yash Katariya
dcb28f1218 [sharding_in_types] Add vmap + explicit sharding support. The main changes are:
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to the the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.

This should eventually help us get rid of `spmd_axis_name` from `vmap`.

PiperOrigin-RevId: 721007659
2025-01-29 09:34:27 -08:00
Bixia Zheng
20843643ab [jax:custom_partitioning] Make propagate_user_sharding default to None.
Revise documentation for sharding_rule and add a link to jax-shardy-guide.

PiperOrigin-RevId: 721001922
2025-01-29 09:14:35 -08:00
Jake VanderPlas
955e7c4793 Internal: avoid adding _DimExpr to dtypes._weak_types
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.

PiperOrigin-RevId: 721000907
2025-01-29 09:11:02 -08:00
Dan Foreman-Mackey
e2eff1f8d5 Revert https://github.com/jax-ml/jax/pull/25982 since callbacks can now use JAX functions. 2025-01-29 11:12:32 -05:00
jax authors
a459e7e4cd Merge pull request #26151 from gnecula:debug_info_collect_lowered_jaxprs
PiperOrigin-RevId: 720911587
2025-01-29 04:00:03 -08:00
jax authors
bf22b53cf4 Merge pull request #26154 from jakevdp:pure-callback-doc
PiperOrigin-RevId: 720763192
2025-01-28 17:28:02 -08:00
Skye Wanderman-Milne
2aa810fe60 Make JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES env vars
Before, these values could only be specified via jax.config or
flags. This PR makes them proper configs, so they also work as env
vars.
2025-01-28 17:17:56 -08:00
Gunhyun Park
809e1133c8 Add support for axis_name and axis_index_groups to lax.ragged_all_to_all
PiperOrigin-RevId: 720738861
2025-01-28 16:02:03 -08:00
George Necula
f8673cde94 [better_errors] Expand debug info testing with eager mode, and MLIR module checking.
Made several improvements to the debug info tests:

 * added support for eager mode, which sometimes uses
   different code paths for the debug info, e.g., for
   `jvp(pmap)`. To check the debugging info in these cases we add
   instrumentation to collect the lowered Jaxprs and MLIR modules right
   after lowering, and we check the debugging information there.
 * added support for checking for the presence of regular expressions
   and strings in the lowered module, to check that the location
   information and arg_names and result_paths is present. This
   is now enabled only for a subset of the tests.
 * simplified the pretty-printing of the arg_names and result_paths
   in the debug info, to remove a layer of parentheses and string,
   so that instead of `arg_names=("x", "y")` we now pretty-print
   just `arg_names=x,y"
 * added support for checking the provenance information in
   leaked tracers
2025-01-28 23:51:06 +02:00
Dan Foreman-Mackey
09392d8160 Simplify dtype inference in lax.linalg.eig abstract eval rule.
I came across this when working on an unrelated issue, but the explicit use of `finfo` was causing some `UserWarning`s, and it was really unnecessary.

PiperOrigin-RevId: 720691470
2025-01-28 13:35:53 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can atleast upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Jake VanderPlas
ba2858f834 DOC: add discussion of exceptions in pure_callback 2025-01-28 09:53:47 -08:00
Johannes Reifferscheid
55d891c5bf Don't apply shardy config during def_partition.
PR #25834 intended to dynamically choose the the partitioner API, but
it still applies the configuration value too early (it should only be
applied in __call__, not in def_partition and __call__).
2025-01-28 15:37:47 +01:00
Yash Katariya
7ed7e0b5b1 [sharding_in_types] Add clamp_p sharding rule.
PiperOrigin-RevId: 720428881
2025-01-27 21:58:08 -08:00
Yash Katariya
ae705fef9c [sharding_in_types] Add support for svd_p
PiperOrigin-RevId: 720409750
2025-01-27 20:31:54 -08:00
Justin Fu
54ac172b4c [Pallas] Refactor Pallas HLO interpret mode to a standalone file.
Also replaces the interpreter context (used only for handling extended dtypes) with a physicalize Jaxpr pass.

PiperOrigin-RevId: 720371033
2025-01-27 17:52:27 -08:00
jax authors
763ffb3f73 Merge pull request #26128 from jakevdp:norm-doc
PiperOrigin-RevId: 720243405
2025-01-27 11:24:57 -08:00
jax authors
579a8fc500 Merge pull request #26123 from dfm:custom-dce-changelog
PiperOrigin-RevId: 720241697
2025-01-27 11:20:13 -08:00
Jake VanderPlas
a6a0226a53 jnp.linalg.norm: better documentation & error text for axis 2025-01-27 10:39:19 -08:00
Dan Foreman-Mackey
782138fb6f Add custom_dce to changelogs and API docs. 2025-01-27 13:03:34 -05:00
Peter Hawkins
893197fc9a [JAX] Add a note that setting environment variables in tests is not thread-safe.
PiperOrigin-RevId: 720202505
2025-01-27 09:40:12 -08:00
George Necula
878272ee3c [better_errors] Refactor more uses of pe.tracing_debug_info (part 2)
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following #26097) for Pallas.
2025-01-27 16:10:56 +02:00
jax authors
2a6accd63f Merge pull request #26097 from gnecula:debug_info_no_pe_debug_info
PiperOrigin-RevId: 720106054
2025-01-27 03:46:09 -08:00