This should allow us to try out xmap not only in a simulation (i.e.
faking the devices using vmap, which we still support), but also on real
hardware.
Limitations:
- No compilation caching yet
- Nested xmaps not supported yet
- Transforms (AD, vmap, etc.) of xmaps not supported yet
Benefits:
- An xmap over multiple mesh axes already implements a more efficient
lowering than the one used for nested pmaps.
The `resources` context-manager is now called `fake_resources`, while
real meshes can be defined in a specific context using the
`mesh(devices, axis_names)` manager. `devices` is supposed to be an
`ndarray` of JAX device objects (e.g. obtained from `jax.devices()`),
while `axis_names` should be a tuple of length matching the rank of
`devices` and specifying mesh axis names.
For concrete examples see the changes in `gmap_tests.py`.
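
As a rough illustration of the shape requirement described above, the sketch below builds a device array with matching axis names. It assumes a recent JAX release and uses `jax.sharding.Mesh` as a stand-in for the `mesh(devices, axis_names)` manager named here, whose import path is not shown in this message. The axis names `'x'` and `'y'` are arbitrary.

```python
import jax
import numpy as np
from jax.sharding import Mesh

# Arrange all available devices in a 2D grid of shape (n, 1); the reshape is
# only there so the array has rank 2 even on a single-device machine.
devices = np.array(jax.devices()).reshape(-1, 1)

# One axis name per dimension of `devices`.
axis_names = ('x', 'y')
assert len(axis_names) == devices.ndim

# Assumption: jax.sharding.Mesh plays the role of the `mesh` manager here.
device_mesh = Mesh(devices, axis_names)
mesh_shape = dict(device_mesh.shape)  # maps axis name -> axis size
```

On a machine with a single CPU device, `mesh_shape` has size 1 along both axes.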
In principle the current version of the code should also work in a
multi-host setting, but I haven't tested it just yet.
... and in map primitives in general (which is why the patch touches
most traces).
This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
All initial style primitives currently use `batch_jaxpr` in their
batching rules, but that function wasn't updated to support
`axis_name` when I added support for vmap collectives.
Prior to this it was possible, e.g., for code that contains a Literal
to result in FLOPS during checking.
The assertion is broken by many tests unless we raise_to_shape for Literals.
I have timed the checks on my laptop and I do not see a reduction in the
total test time.
The main change is that we use `core.new_base_main` to use an
omnistaging-based tracer. This has the benefit that we can
convert to TF even functions with no arguments (previously
they would be constant-folded by JAX prior to the conversion).
We also add an explicit error if the jax2tf.convert transformation
is nested under other JAX transformations.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
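
To make the first bullet concrete, here is a small sketch of where a float0 value shows up from the user's side. It assumes a recent JAX release and relies on the `allow_int` option of `jax.grad`, which is not part of this change. The function `total` and its input are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

def total(x):
  # Integer input, floating-point scalar output.
  return jnp.sum(x + 0.)

int_input = np.ones((2, 4), dtype=int)

# Differentiating with respect to an integer input yields a float0 array
# of the same shape instead of a floating-point gradient.
g = jax.grad(total, allow_int=True)(int_input)
is_float0 = g.dtype == jax.dtypes.float0
```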
The primitive was moved to `lax_parallel.py` some time ago, so the one
in `core` should no longer be used. This is probably a result of a
botched rebase.
Previously, given this function:
```python
import jax

@jax.jit
def f(x, y):
  if x > y:
    return x
  else:
    return y
```
we'd get an error message like this (after #4038, improved to help with
omnistaging debugging):
```
...
While tracing the function f at tim.py:3, this value became a tracer due to JAX operations on these lines:
operation c:bool[] = gt a:int32[] b:int32[]
from line tim.py:5 (f)
...
```
But this message is buggy! In this case, the value is a tracer because
it has a data dependence on arguments to a jitted function.
After this change, we instead produce this error message:
```
...
While tracing the function f at tim.py:3, this concrete value was not available in Python because it depends on the value of the arguments to f at tim.py:3 at positions [0, 1], and the computation of these values is being staged out.
...
```
I'm eager to iterate with further improvements, but for now I want to
fix this buggy message.
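
For context, two common ways to avoid this error in user code are sketched below. They are not part of this change and assume a recent JAX release. The names `f_traced`, `f_python` and `f_static` are made up for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_traced(x, y):
  # The branch is expressed as data flow, so no concrete bool is needed.
  return jnp.where(x > y, x, y)

def f_python(x, y):
  if x > y:
    return x
  else:
    return y

# Marking both arguments as static makes them concrete Python values during
# tracing, at the cost of recompiling for every new pair of values.
f_static = jax.jit(f_python, static_argnums=(0, 1))

result_traced = int(f_traced(3, 2))
result_static = int(f_static(3, 2))
```

The first variant keeps a single compiled program for all inputs; the second only suits arguments that take few distinct values.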
rename and simplify TypedJaxpr -> ClosedJaxpr
This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
in_avals / out_avals no longer need to be constructed), making them
easier to work with;
* correspondingly rules out a class of errors (mismatches between
invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
but they're closed in that they are packaged with their constant
values).
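
The points above can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions: `ToyVar`, `ToyJaxpr` and `ToyClosedJaxpr` are hypothetical stand-ins, not the real classes, and avals are represented as strings for brevity.

```python
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True)
class ToyVar:
  name: str
  aval: str  # abstract value, e.g. "float32[3]"

@dataclass(frozen=True)
class ToyJaxpr:
  constvars: Tuple[ToyVar, ...]
  invars: Tuple[ToyVar, ...]
  outvars: Tuple[ToyVar, ...]

@dataclass(frozen=True)
class ToyClosedJaxpr:
  """A jaxpr packaged with the values of its constants."""
  jaxpr: ToyJaxpr
  consts: Tuple[Any, ...]

  def __post_init__(self):
    if len(self.consts) != len(self.jaxpr.constvars):
      raise ValueError("expected one constant value per constvar")

  # The avals are derived from the variables instead of being stored
  # separately, so they cannot disagree with invars/outvars.
  @property
  def in_avals(self):
    return [v.aval for v in self.jaxpr.invars]

  @property
  def out_avals(self):
    return [v.aval for v in self.jaxpr.outvars]

c = ToyVar("c", "float32[]")
x = ToyVar("x", "float32[3]")
y = ToyVar("y", "float32[3]")
closed = ToyClosedJaxpr(ToyJaxpr((c,), (x,), (y,)), consts=(2.0,))
```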
This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)
Co-authored-by: George Necula <gcnecula@gmail.com>
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
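
For readers unfamiliar with these lint fixes, the sketch below shows one plausible before/after for most of the items. The snippets are invented for illustration and are not taken from the patch. The "before" forms appear only in comments.

```python
import os
import tempfile

counts = {"a": 1, "b": 2, "c": 3}

# Unnecessary generator. Before: set(k.upper() for k in counts)
upper_keys = {k.upper() for k in counts}

# Iterate the dictionary directly. Before: for key in counts.keys(): ...
total = 0
for key in counts:
  total += counts[key]

# list() instead of an identity list comprehension.
# Before: [k for k in counts]
key_list = list(counts)

# with statement, so the file is closed even if an exception is raised.
# Before: f = open(path); contents = f.read(); f.close()
path = os.path.join(tempfile.mkdtemp(), "example.txt")
with open(path, "w") as f:
  f.write("ok")
with open(path) as f:
  contents = f.read()

# Merged isinstance calls.
# Before: isinstance(v, int) or isinstance(v, float)
def is_number(v):
  return isinstance(v, (int, float))
```

The remaining item, a `global` statement at module level, has no effect because module-level names are already global, so removing it changes nothing at runtime.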
* improve an escaped tracer error message
Before this commit, encountering an escaped tracer in a specific way
would lead to a bad internal error. This change
1. raises an UnexpectedTracerError instead, and
2. includes in the error message the user source line which created the
tracer.
* deflake
* replace _live property with _assert_live method
Thanks @jekbradbury !
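
The idea behind the escaped tracer change above can be sketched with a toy model in plain Python. Everything here is a hypothetical stand-in: `ToyTrace`, `ToyTracer` and the message text are invented, and only the names `UnexpectedTracerError` and `_assert_live` come from this message.

```python
import traceback

class UnexpectedTracerError(Exception):
  """Toy stand-in for the error type named in this message."""

class ToyTrace:
  """Hypothetical trace that is live only inside its `with` block."""
  def __init__(self):
    self.live = False

  def __enter__(self):
    self.live = True
    return self

  def __exit__(self, *exc):
    self.live = False
    return False

class ToyTracer:
  def __init__(self, trace):
    self._trace = trace
    # extract_stack(limit=2) returns [caller, this frame]; keep the caller so
    # the error can point at the user line that created the tracer.
    caller = traceback.extract_stack(limit=2)[0]
    self._origin = f"{caller.filename}:{caller.lineno}"

  def _assert_live(self):
    if not self._trace.live:
      raise UnexpectedTracerError(
          f"Encountered an escaped tracer created at {self._origin}")

leaked = []
with ToyTrace() as trace:
  leaked.append(ToyTracer(trace))  # the tracer escapes via a side effect

try:
  leaked[0]._assert_live()
  message = None
except UnexpectedTracerError as e:
  message = str(e)
```

Recording the creation site eagerly costs a stack lookup per tracer; whether the real implementation does this eagerly or lazily is not stated here.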
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
case of scan! Its translation rule reevaluates the jaxpr as a function,
and if it contains collectives such as `axis_index` it can fail because the
axis is missing.
Some of the vmap and gmap collective tests have been failing on master
and I can't reproduce the failures locally. Hopefully, if
this happens again, this extra bit of information will be useful in
debugging the problem.