This enables us to use them more simply in the current and upcoming Python code. The Python bindings for enum and enum attributes leave much to be desired.
PiperOrigin-RevId: 718795667
If we repeatedly form tuples by concatenation during printing, we make what should be a linear time operation quadratic.
Also simplify the API contract of extend() to only add a single element, and remove the unused method wrap_name.
PiperOrigin-RevId: 718570432
Debugging info is needed for error messages, and for
lowering. For the former, we need debug info inside
tracers. For the latter, inside Jaxprs. We add a
new set of tests that intentionally leak tracers while
tracing and then we check that the tracers have the
expected debug info. We also form Jaxprs and we
check that they have the expected debug info.
We uncovered a few missing debug infos, those are
marked with TODO.
This change adds support for including "consts" (i.e. closed-over arrays) in the body of the decomposition definition for a `lax.composite` op. The caveat here is that, since the signature of the decomposition must match the composite itself, the values of any consts must be known when lowering so that they can be inlined into the decomposition's HLO. Therefore, there is no support for closing over tracers. Since `lax.composite` doesn't support most transformations anyways, this typically isn't going to be a major limitation except with `jax.jit` (as demonstrated in the tests).
PiperOrigin-RevId: 718048021
Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
... for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):
... for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz
```
PiperOrigin-RevId: 717971583
The paged_attention_kernel tests for TPU v6 was disabled in the past but I discovered that all the failing tests have `are_kv_quantized=True`. So we can still test the non-quantized part on TPU v6.
PiperOrigin-RevId: 717969073
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.
When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `arg_names` is None, and then the whole
debug_info ended up being `None`, throwing away even
available information.
We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
This enables the dialect lowering to depend on `wgmma.py` without creating a circular dependency. I need this in a follow up CL that implements the lowering of the WGMMA dialect op.
PiperOrigin-RevId: 717791901