This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.
If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.
PiperOrigin-RevId: 713296722
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).
This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
_broadcast_to needlessly squeezes away size 1 dimensions before passing its input to broadcast_in_dim. But broadcast_in_dim is perfectly happy to broadcast size 1 dimensions, so we don't need this squeeze.
The original PR was reverted because of downstream breakage.
Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)
But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.
So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.
In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:
```
import jax
def f(x):
for _ in range(3):
x = jax.numpy.sin(x)
return x
jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
# b:f32[] = sin a
# c:f32[] = sin b
# d:f32[] = sin c
# in (d,) }
_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```
Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!
So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```
PiperOrigin-RevId: 607664497
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
fixes#14397
For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290.
Instead of making a new primitive, we made the old one polymorphic and switch
its behavior on the element type of its second argument.
There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.
Co-authored-by: Roy Frostig <frostig@google.com>
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.
We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.
We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
There are boundaries across which segment lengths should flow as
explicit, object-level arguments, instead of as "metadata". To cross
such a boundary, we need to be able to extract the embedded segment
lengths from a RaggedAxis object, replacing them with references to
known object-level arguments; and we need to be able to reconstruct
the RaggedAxis object as metadata by resolving those references.
This touches _gather_batching_rule because slicing is implemented as a
gather, but we only test the case exercised by the slice that occurs
in our test transformer model, namely the unstack operation
q, k, v = qkv
(which turns into three slices on an non-batched and non-ragged axis).
Co-authored-by: Matthew Johnson <mattjj@google.com>
The advantage (already being realized) is that the batching rules
become much simpler: we just batch along the stacked axis as always,
and when a reduction is about to occur, also mask out the padding
elements, replacing them with the identity element of the reduction.
This commit
- Changes the intended representation of data for piles and the
corresponding BatchTracers.
- Re-defines ConcatAxis as RaggedAxis to represent the metadata.
- Updates `defreducer` to require the identity function (in case
masking is needed), and supplies it everywhere.
- Flushes batching.segment_sum, as it is dead code now.
- Deletes unpack_concat_axes and reassemble_concat_axes, because they
are irrelevant to the padded representation.
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:
[Rollback] Update required Python version to 3.9
PiperOrigin-RevId: 528905991
Fixesiree-org/iree-jax#57
An alternative fix would've been just to add the dtype attribute to IreeBuffer.
But it seems better not to make demands on the underlying runtime objects when
we don't need to.
I had to run the test with:
`JAX_PLATFORM_NAME=iree JAX_ARRAY=0 JAX_JIT_PJIT_API_MERGE=0 python tests/dynamic_api_test.py DynamicShapeTest.test_iree_buffer_doesnt_need_dtype_attribute`