74 Commits

Author SHA1 Message Date
Peter Hawkins
3fa557289a Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.

PiperOrigin-RevId: 713296722
2025-01-08 08:14:50 -08:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-25 00:10:01 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
34ce9f21db Simplify implementation of _broadcast_to.
_broadcast_to needlessly squeezes away size 1 dimensions before passing its input to broadcast_in_dim. But broadcast_in_dim is perfectly happy to broadcast size 1 dimensions, so we don't need this squeeze.
2024-07-24 10:57:54 -04:00
jax authors
787b7c262f Merge pull request #21414 from mattjj:pjit-forwarding-rule
PiperOrigin-RevId: 637248964
2024-05-25 11:57:07 -07:00
Matthew Johnson
0a693faf48 add pjit forwarding rule
Co-authored-by: Roy Frostig <frostig@google.com>
2024-05-25 17:46:01 +00:00
Sergei Lebedev
2473ebf508 Removed mentions of iree from the test suite 2024-05-24 10:31:57 +01:00
jax authors
35fdf5da24 [Easy][djax] Better logging for name:dim mismatch
PiperOrigin-RevId: 636000226
2024-05-21 19:22:38 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Peter Hawkins
67df647988 Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.

Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)

But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.

So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.

In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:

```
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
#     b:f32[] = sin a
#     c:f32[] = sin b
#     d:f32[] = sin c
#   in (d,) }

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```

Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!

So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```

PiperOrigin-RevId: 607664497
2024-02-16 05:57:12 -08:00
Jake VanderPlas
a1ee8c1743 Improve shape validation when jax_dynamic_shapes=True 2023-12-12 13:58:46 -08:00
Peter Hawkins
baa77562e5 Use scoped disable_jit() in dynamic_api_test.
This test was leaving jit disabled, affecting other tests.

PiperOrigin-RevId: 587803847
2023-12-04 12:20:36 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Matthew Johnson
69ad4df9a5 fix pow_p jvp rule at x=0. y=0
fixes #14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290.

Instead of making a new primitive, we made the old one polymorphic and switch
its behavior on the element type of its second argument.

There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.

Co-authored-by: Roy Frostig <frostig@google.com>
2023-07-28 17:14:47 -07:00
jax authors
4c0aafb425 Merge pull request #16719 from axch:ragged-test-cases
PiperOrigin-RevId: 548231820
2023-07-14 15:03:13 -07:00
Alexey Radul
fbb587232c Rename Piles to Jumbles, to avoid unfortunate Imperial entanglements. 2023-07-14 15:19:49 -04:00
jax authors
0e538e559d Merge pull request #16713 from gnecula:poly_clean4
PiperOrigin-RevId: 548092798
2023-07-14 04:52:17 -07:00
Alexey Radul
f348366041 Record the example transformer layer as a test case. 2023-07-13 15:59:31 -04:00
Alexey Radul
4345ccddcf Test a few examples with jitting enabled and disabled. 2023-07-13 13:19:38 -04:00
George Necula
71ac0bb446 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 16:37:53 +03:00
George Necula
58d6c4c1ec Roll back #16689
PiperOrigin-RevId: 547773322
2023-07-13 06:05:50 -07:00
George Necula
d21a667235 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 09:59:41 +03:00
Matthew Johnson
e04db23651 Indirectify ragged axes across jitting boundaries, input- and output-side.
Also propagate DShapedArray through at least the simple cases of
shardings that show up in test cases.

Co-authored-by: Alexey Radul <axch@google.com>
2023-07-11 15:21:55 -04:00
Alexey Radul
edf77f9a12 Revive indirection of RaggedAxis batch dims.
There are boundaries across which segment lengths should flow as
explicit, object-level arguments, instead of as "metadata".  To cross
such a boundary, we need to be able to extract the embedded segment
lengths from a RaggedAxis object, replacing them with references to
known object-level arguments; and we need to be able to reconstruct
the RaggedAxis object as metadata by resolving those references.
2023-07-11 15:21:55 -04:00
Alexey Radul
924394297b Test and implement slicing not dropping raggedness information. 2023-07-11 15:21:55 -04:00
Alexey Radul
defe71228c Clearer test names. 2023-07-07 09:23:33 -04:00
Alexey Radul
aa3c49f134 Test a different configuration of einsum.
This version stresses my transpose_ragged_axes method, which, it
seems, was interpreting the permutation the wrong way.  Fixed.
2023-07-07 09:23:33 -04:00
Alexey Radul
89dd69ea2d Test and implement ragged slicing.
This touches _gather_batching_rule because slicing is implemented as a
gather, but we only test the case exercised by the slice that occurs
in our test transformer model, namely the unstack operation
  q, k, v = qkv
(which turns into three slices on an non-batched and non-ragged axis).

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:33 -04:00
Alexey Radul
6f09fe840e Better error message when broadcasting ragged to static shape.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:29 -04:00
jax authors
63415a9184 Merge pull request #16386 from axch:ragged-einsum
PiperOrigin-RevId: 542887557
2023-06-23 10:00:07 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Alexey Radul
63f912c220 Test and implement ragged einsum. 2023-06-13 17:04:43 -04:00
Alexey Radul
4250aa6777 Test and implement transposition of ragged arrays. 2023-06-13 17:04:40 -04:00
Alexey Radul
978899d7db Test and implement ragged outputs of dot_general. 2023-06-13 17:04:09 -04:00
Alexey Radul
effaf674ae Test and fix jnp.broadcast_to. 2023-06-08 16:17:43 -04:00
Alexey Radul
5aa6cc3542 Test and implement squeeze under ragged batching. 2023-06-08 16:17:43 -04:00
Alexey Radul
611c12ba25 Test and implement broadcast_in_dim for all permutations of ragged axes.
Add tests for
- Broadcasting an already-ragged array
- Broadcasting that creates an array that's ragged in two dimensions
2023-06-08 16:17:43 -04:00
Alexey Radul
241502157c Tidy up in response to first review. 2023-06-03 16:34:30 -04:00
Alexey Radul
e959dd2de3 Force the mask computation into 'int32' more consistently
so as to make tests pass with JAX_ENABLE_X64=1.
2023-05-19 13:49:21 -07:00
Matthew Johnson
1c6a892c7e Improve printing of bints and piles, and allow bints in convert_element_type. 2023-05-19 13:14:48 -07:00
Alexey Radul
2daeec83ce Redefine the pile representation from concatenated to stacked-and-padded.
The advantage (already being realized) is that the batching rules
become much simpler: we just batch along the stacked axis as always,
and when a reduction is about to occur, also mask out the padding
elements, replacing them with the identity element of the reduction.

This commit

- Changes the intended representation of data for piles and the
  corresponding BatchTracers.
- Re-defines ConcatAxis as RaggedAxis to represent the metadata.
- Updates `defreducer` to require the identity function (in case
  masking is needed), and supplies it everywhere.
- Flushes batching.segment_sum, as it is dead code now.
- Deletes unpack_concat_axes and reassemble_concat_axes, because they
  are irrelevant to the padded representation.
2023-05-19 13:13:15 -07:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
jax authors
13fe3810d2 Merge pull request #15694 from mattjj:djax-reshape
PiperOrigin-RevId: 526194423
2023-04-21 19:42:27 -07:00
Matthew Johnson
84ae14e7d3 [djax] handle simple reshapes and size-0 checks
One of the main changes here is that we don't do division in handling
x.reshape(..., -1) unless we have to.
2023-04-21 19:20:48 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Matthew Johnson
7743fcd758 [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit 2023-03-23 20:20:01 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Matthew Johnson
54b889ca7f [dynamic-shapes] don't require buf objects have dtype attribute
Fixes iree-org/iree-jax#57

An alternative fix would've been just to add the dtype attribute to IreeBuffer.
But it seems better not to make demands on the underlying runtime objects when
we don't need to.

I had to run the test with:

`JAX_PLATFORM_NAME=iree JAX_ARRAY=0 JAX_JIT_PJIT_API_MERGE=0 python tests/dynamic_api_test.py DynamicShapeTest.test_iree_buffer_doesnt_need_dtype_attribute`
2023-03-15 12:53:43 -07:00