15738 Commits

Author SHA1 Message Date
Jake VanderPlas
23c1d62910 internal: move more NumPy APIs to ensure_arraylike 2025-01-23 08:48:13 -08:00
Yash Katariya
6b95ad0a53 Don't calculate tracing debug_info twice
PiperOrigin-RevId: 718860522
2025-01-23 07:37:14 -08:00
Yash Katariya
33aa088a5c Remove axis_name from pmap_unmapped_aval_handlers
PiperOrigin-RevId: 718859837
2025-01-23 07:34:38 -08:00
Dimitar (Mitko) Asenov
6f609926a6 [Mosaic GPU] Remove an unnecessary restriction in the vector.store lowering
This was made obsolete by:
f89accc56a

PiperOrigin-RevId: 718808561
2025-01-23 04:24:14 -08:00
Dimitar (Mitko) Asenov
3a411d883a [Mosaic GPU] Implement basic WGMMAFragLayout inference and propagation
PiperOrigin-RevId: 718781860
2025-01-23 02:48:04 -08:00
Parker Schuh
f3e27b6c28 Support axis_index using a nested shard_map instead of iota with full to shard.
PiperOrigin-RevId: 718661661
2025-01-22 19:14:37 -08:00
Yash Katariya
704b2e5fba [sharding_in_types] Make vmap work with shard_map + pallas
PiperOrigin-RevId: 718578207
2025-01-22 16:48:32 -08:00
Peter Hawkins
cd51e9dd14 Speed up name stack printing.
If we repeatedly form tuples by concatenation during printing, we make what should be a linear time operation quadratic.

Also simplify the API contract of extend() to only add a single element, and remove the unused method wrap_name.

PiperOrigin-RevId: 718570432
2025-01-22 16:24:16 -08:00
Sharad Vikram
64e9b07ee3 Update debugging Pallas g3doc to remove text about scalar printing restriction
PiperOrigin-RevId: 718560406
2025-01-22 15:55:05 -08:00
Yash Katariya
23d360bded Remove axis_name from unmapped_aval
PiperOrigin-RevId: 718558713
2025-01-22 15:49:04 -08:00
jax authors
f6243ff8e1 Merge pull request #25889 from Stella-S-Yan:cache_reset
PiperOrigin-RevId: 718537398
2025-01-22 14:52:05 -08:00
Peter Hawkins
f4adcc650f Set __slots__ on core.Trace subclasses.
This is easy to do and makes field accesses on Trace classes slightly faster.
2025-01-22 16:17:54 -05:00
jax authors
fc9356085e Merge pull request #26024 from jakevdp:setops-args
PiperOrigin-RevId: 718467165
2025-01-22 11:37:51 -08:00
Justin Fu
10bb38bb79 [Mosaic GPU] Add manual consumed barrier handling to WS pipeline.
PiperOrigin-RevId: 718451678
2025-01-22 10:59:58 -08:00
George Necula
849ccc978b [better_errors] Expand the tests for debug_info
Debugging info is needed for error messages, and for
lowering. For the former, we need debug info inside
tracers. For the latter, inside Jaxprs. We add a
new set of tests that intentionally leak tracers while
tracing and then we check that the tracers have the
expected debug info. We also form Jaxprs and we
check that they have the expected debug info.
We uncovered a few missing debug infos, those are
marked with TODO.
2025-01-22 16:49:16 +01:00
jax authors
e304e9ea16 Merge pull request #25992 from gnecula:debug_info_arg_names
PiperOrigin-RevId: 718216003
2025-01-21 22:17:08 -08:00
Stella S Yan
f87c94db75 Fix cache init when JAX Array is created early (#25768) 2025-01-22 03:44:29 +00:00
Yash Katariya
3aa55992fe Remove device_context from trace_context because we don't need it there. We can get compilation cache misses (and tracing/lowering cache hit) naturally without putting concrete devices into trace_context.
PiperOrigin-RevId: 718113413
2025-01-21 16:21:36 -08:00
Jake VanderPlas
a69f9dcc19 jax.numpy setops: use ensure_arraylike & avoid asarray 2025-01-21 16:05:49 -08:00
Yash Katariya
051861bbf1 Error out if contracting dimensions are sharded and ask the user to provide the output sharding
PiperOrigin-RevId: 718087084
2025-01-21 15:06:54 -08:00
Dan Foreman-Mackey
aa8c0010e2 Add support for constants in the decomposition of lax.composites.
This change adds support for including "consts" (i.e. closed-over arrays) in the body of the decomposition definition for a `lax.composite` op. The caveat here is that, since the signature of the decomposition must match the composite itself, the values of any consts must be known when lowering so that they can be inlined into the decomposition's HLO. Therefore, there is no support for closing over tracers. Since `lax.composite` doesn't support most transformations anyways, this typically isn't going to be a major limitation except with `jax.jit` (as demonstrated in the tests).

PiperOrigin-RevId: 718048021
2025-01-21 13:28:42 -08:00
Peter Hawkins
afb750cf76 Remove unnecessary use of xla_client.OpMetadata class.
We create this object and immediately turn it into a different object. We can cut out a step here!

PiperOrigin-RevId: 718023353
2025-01-21 12:24:51 -08:00
Tom Hennigan
7f43316e27 Add an option to simplify keystr output and use a custom separator.
Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:

    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']

This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):

    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz
```

PiperOrigin-RevId: 717971583
2025-01-21 10:18:42 -08:00
jax authors
70a5175d0a Merge pull request #25647 from Rifur13:bwd_pass
PiperOrigin-RevId: 717954065
2025-01-21 09:37:41 -08:00
Sergei Lebedev
4c363766f8 [mosaic_gpu] Removed debug prints in jax._src.lib.mosaic_gpu
PiperOrigin-RevId: 717952850
2025-01-21 09:36:12 -08:00
jax authors
3b5b98163b Merge pull request #25982 from roth-jakob:warning_callbacks
PiperOrigin-RevId: 717951477
2025-01-21 09:29:52 -08:00
Yash Katariya
d30476a0ae Make make_mesh take visible_axes, hidden_axes and collective_axes as parameters instead of axis_types to make it a more cleaner API.
The mesh axis names provided to those parameters should be disjoint i.e. no overlap.

PiperOrigin-RevId: 717941577
2025-01-21 09:02:39 -08:00
Christos Perivolaropoulos
c4643c6156 [mosaic_gpu] a function for bitwidth as well as bytewidht.
This is to enable s4 and friends.

PiperOrigin-RevId: 717896640
2025-01-21 06:52:04 -08:00
George Necula
3f73f7b0eb [better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.

When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.

We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-21 13:38:10 +01:00
Dimitar (Mitko) Asenov
f89accc56a [Mosaic GPU] Add support for converting all fragmented layouts to ir and back.
This will be used in the layout inference and lowering of the dialect WGMMA op

PiperOrigin-RevId: 717836648
2025-01-21 03:27:03 -08:00
Dimitar (Mitko) Asenov
0a89760c24 [Mosaic GPU] Do not use mgpu in wgmma.py
This enables the dialect lowering to depend on `wgmma.py` without creating a circular dependency. I need this in a follow up CL that implements the lowering of the WGMMA dialect op.

PiperOrigin-RevId: 717791901
2025-01-21 01:06:42 -08:00
Yash Katariya
bba5ada525 Make sure reshard and mesh_cast behave properly under eager mode
PiperOrigin-RevId: 717714860
2025-01-20 20:33:09 -08:00
Yash Katariya
a943ebf449 [sharding_in_types] Move the calculation of new_mesh inside the decorator so that functions can be decorated with hidden_axes and visible_axes at the top level
PiperOrigin-RevId: 717701048
2025-01-20 19:41:22 -08:00
Gleb Pobudzey
b8b9f2bc33 Fix the backwards pass and support more block sizes. 2025-01-21 02:48:37 +00:00
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
jax authors
7f19b345fb Merge pull request #25984 from Rifur13:use_exp2
PiperOrigin-RevId: 717627777
2025-01-20 14:20:39 -08:00
Yash Katariya
799eb98cac Add reshard API in experimental. Currently for sharding_in_types we have 2 APIs: mesh_cast and reshard. Both work in sharding_in_types mode and affect the sharding of the aval. Following are the semantics of both:
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.

* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.

We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is because `device_put` does a lot of stuff right now (and is going to get even more powers in the near future like cross-host transfers) and it's semantics would be very confusing if we keep piling sharding-in-types stuff on it.

PiperOrigin-RevId: 717588253
2025-01-20 11:39:25 -08:00
George Necula
4fd0bb05b1 [better_errors] Finally remove api_util.debug_info.
Following https://github.com/jax-ml/jax/pull/25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.

PiperOrigin-RevId: 717583667
2025-01-20 11:19:53 -08:00
George Necula
dcf72b01f4 [better_errors] Improvements in propagation of debugging info
Added some documentation for `TracingDebugInfo` (docstring, comments
about `arg_names`, since it was not obvious to me that this would
flatten the non-static arguments).

Laying the ground for the unification of the old `api_util.debug_info`
and `partial_eval.tracing_debug_info`: we rename the former to
`api_util.tracing_debug_info`, we push inside the calls to
`fun_sourceinfo` and `fun_signature` (which were done by the callers
until now), and we rewrite the latter in terms
of the former. We leave for a future PR the actual replacing of the
latter with the former throughout.

In the process of above, cleaned up the one case when `partial_eval.tracing_debug_info`
received None for the `in_tree` and `out_tracer_thunk`. The function contained
catch-all exception clauses to handle those, but doing so it masked other places
where we fail to collect debug info due to programming mistakes. E.g., in
one place we passed a `WrappedFun` instead of a `Callable`, resulting in missing debugging info.

Added more type declarations.

Added a `state_test` with a failure to track debugging information, manifested
with a leaked tracer without function provenance. Fixing this in a subsequent PR.
2025-01-20 15:09:51 +01:00
Jakob Roth
4f8699c8a1 Update docs of callbacks
Callback functions should not call into JAX. This information was
missing in the docs of the callbacks. This commit adds this information
to the docs.

See: #25861, #24255
2025-01-19 11:33:20 +01:00
Gleb Pobudzey
1f59506384 Speed up attention kernel by using exp2 2025-01-19 06:18:05 +00:00
jax authors
aed9c6f149 Merge pull request #25969 from jakevdp:fix-util
PiperOrigin-RevId: 717104490
2025-01-18 18:02:43 -08:00
jax authors
cc38d8c10e Merge pull request #25976 from arvoelke:fix-memory-space-error
PiperOrigin-RevId: 717101811
2025-01-18 17:52:47 -08:00
Yash Katariya
5a068da699 Remove memories flag now that JAX 0.5.0 has been released since it always defaults to True.
PiperOrigin-RevId: 716908015
2025-01-17 22:13:04 -08:00
Yash Katariya
36daf36913 Add a sharding rule for reduce_precision_p and properly thread eqn.ctx in loops.py where we create pe.new_jaxpr_eqn's
PiperOrigin-RevId: 716849111
2025-01-17 17:31:24 -08:00
Aaron Russell Voelker
4173842736
add f-string to mosaic memory space error msg 2025-01-17 20:16:36 -05:00
Yash Katariya
c7f8d17f5a Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet.
PiperOrigin-RevId: 716839003
2025-01-17 16:48:58 -08:00
Jake VanderPlas
45a352041c internal: check integer overflow in lax.asarray 2025-01-17 14:38:13 -08:00
Yash Katariya
12b59f8e53 Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Same for visible mode and visible_mode_ctx.
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode.

PiperOrigin-RevId: 716771872
2025-01-17 13:01:07 -08:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00