As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
Lambdas are represented by their ids in the metadata of lowered HLO (see example below) and they change every time. This makes the compilation cache less effective as it causes the computation's fingerprint to change every time.
```
get-tuple-element.41724 = bf16[8]{0} get-tuple-element(reduce.41723), index=0, metadata={op_name="pjit(_wrapped_fn)/jit(main)/.../reduce[computation=<function _compute_argminmax.<locals>.reducer_fn at 0x7fa6ecfb2200> dimensions=(1,)]" source_file="..." source_line=...}
```
PiperOrigin-RevId: 601910715
Add HLO source path canonicalization regex to trace state key because otherwise MetadataTest.test_source_file_prefix_removal fails due to caching of lowerings with different canonicalization regexs.
PiperOrigin-RevId: 509975754
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
This is meant to be used with @colemanliyah's persistent compilation
cache, since the serialized HLO computation (including the source_file
metadata) is used in the cache key. The config can be used to remove
bits of the source file path that vary between program invocations, to
avoid spurious cache misses.
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 fo more information.
* Remove usage of xla_client.{Computation,ComputationBuilder}.
ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the the C++ XLA API. It dates back from when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.
Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
Some higher-order primitives, like 'scan' and 'while', benefit from
distinguishing constants from other inputs to their closure-converted
function arguments; the reason is that for those primitives constants
act differently from the other inputs, which are loop carries or
scanned-over values, and are handled differently by transformations. For
example, they're used differently than loop carries in lattice
fixed-point computations. As another example, in scan the constants in
the forward computation are fanned out, so when transposing scan we
generate an accumulate-add.
However, these considerations don't hold true for cond: since there's no
looping going on (and hence no lattice fixed-points), constants are
treated just like the other operands. So we don't need to carry around
the distinction. That simplifies the cond rules a bit.
Co-authored-by: Roy Frostig <frostig@google.com>