also:
* fix jit invariance bug around weak types
* elide trivial broadcasts
This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.
But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](b53a174042/jax/interpreters/partial_eval.py (L1225))
was not paying attention to weak types and hence clobbered them.
In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.
These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.
PiperOrigin-RevId: 406162068
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.
PiperOrigin-RevId: 404405186
This is a more descriptive name and a better location (next to other facilities for building XLA IR).
Quite a few users of the former xla_bridge.constant() didn't need anything other than uncanonicalized array constants. Change these users to use xla_client.ops.Constant instead; no need for the fancy utility in these cases.
PiperOrigin-RevId: 404270649
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.
Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.
PiperOrigin-RevId: 404071001
An upcoming change will move and rename these functions, and it's not clear they should have been public in the first place.
PiperOrigin-RevId: 404051961
This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.
Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
- Add docstring to abstract property
- Add explicit HTML documentation of this property
- Mark index update functions as deprecated, linking to this documentation
This means that jax and its dependencies (e.g. jaxlib) must be
manually installed before running the tests. This is useful for
testing an existing jax install, e.g. a later version of jaxlib, GPU
jaxlib, etc.