* Added checking that the output shape has higher or equal rank to input
* Added checking that the broadcast_dims are sorted (required by XLA)
* Relaxed check that operand dimension size can be 1
* Added lax.broadcast_in_dim docstring
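For reference, a quick illustrative sketch (not code from the commit) of how these checks play out with the public lax.broadcast_in_dim API:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.0)                        # shape (3,)

# Output rank (2) >= input rank (1), broadcast_dimensions is sorted, and each
# operand dimension either matches the output dimension or has size 1.
y = lax.broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))
print(y.shape)                             # (2, 3)

# The relaxed check: a size-1 operand dimension may broadcast to a larger one.
z = lax.broadcast_in_dim(x[None, :], shape=(4, 3), broadcast_dimensions=(0, 1))
print(z.shape)                             # (4, 3)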
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com>
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
  let b = reduce_sum[ axes=(0,) ] a
  in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
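At the user level this is just the gradient of a sum: the backward pass has to
broadcast the cotangent back out to the input shape. A quick sketch with the
public API:

import jax
import jax.numpy as jnp

# The transpose of reduce_sum is a broadcast: the scalar cotangent is
# broadcast back to the shape of the input.
x = jnp.arange(3.0)
print(jax.grad(lambda a: jnp.sum(a, axis=0))(x))   # [1. 1. 1.], shape (3,)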
Our hack was to have the primitives that required shape information for
transposition acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
  let b = reduce_sum[ axes=(0,)
                      input_shape=(3,) ] a
  in b }
That's not only aesthetically unpleasant, but it also meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
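For reference, a sketch of how the shape information now shows up in the
jaxpr's avals via the public tracing API (the exact printed form varies
across JAX versions):

import jax
import jax.numpy as jnp

# The avals (e.g. f32[3]) carry the shapes, so reduce_sum no longer needs an
# input_shape parameter in order to be transposable.
print(jax.make_jaxpr(lambda a: jnp.sum(a, axis=0))(jnp.ones(3)))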
Co-authored-by: Roy Frostig <frostig@google.com>
* Canonicalize the shape in the wrapper functions in random.py.
This lets the user be more relaxed about passing numpy arrays or statically
known DeviceArrays for shapes and still hit the jit cache. When a shape is
not statically known, the error message is improved. (A rough sketch of the
idea follows this list.)
* Fix some errors.
* No need for the Poly workaround.
* Bypass canonicalization for None shapes in random.py.
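A rough sketch of the canonicalization idea, using a hypothetical helper name
rather than the actual code in random.py:

import numpy as np

def canonicalize_shape(shape):
  # Hypothetical sketch: coerce an array-like shape to a tuple of Python ints
  # so that (2, 3), [2, 3], and np.array([2, 3]) all produce the same static
  # value, and hence the same jit cache entry.
  if shape is None:
    return None  # None shapes bypass canonicalization, as noted above
  return tuple(map(int, np.asarray(shape, dtype=np.int64).reshape(-1)))

assert canonicalize_shape((2, 3)) == canonicalize_shape(np.array([2, 3])) == (2, 3)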
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure. We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.
The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
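At the user level, the mapped-versus-broadcast distinction corresponds to
pmap's in_axes; a small sketch (mapped_invars itself is an internal
parameter, not something users pass):

import jax
import jax.numpy as jnp

n = jax.local_device_count()

# `x` is mapped over devices (in_axes=0) while `w` is broadcast to every
# device (in_axes=None); internally, this is the distinction a
# mapped_invars-style bitmap encodes.
f = jax.pmap(lambda x, w: x * w, in_axes=(0, None))
print(f(jnp.arange(float(n)), 2.0))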
* Shapecheck of jit, device_put, and broadcast_in_dim; better errors for unsupported ops; parse multi-digit integer literals
* WIP shapecheck np.pad
* Implement shapecheck of gather, pad
* Fix shapecheck of pad
* Implement polymorphic shape rule for (strided/dilated) convolution, refactor
* Cleanup
* Fix
* Remove all polymorphic shape rules, reuse shape rules instead.
* Register shape_rule for all standard_primitives
* Remove ShapeExpr, canonicalize_poly, renames
* Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes
* Allow Poly of form d*poly + k to be divided by d (a rough sketch of this rule follows this list)
* Fix bug, inline poly_without_zeros.
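A purely hypothetical sketch of the Poly division rule mentioned above (not
JAX's actual Poly class, and the real rule may be more permissive):

def divide_poly(coeffs, const, d):
  # A symbolic size of the form d*poly + k, i.e. sum(c_i * x_i) + k with every
  # c_i divisible by d, can be divided by d by dividing each coefficient; this
  # sketch also requires d to divide the constant term.
  if any(c % d for c in coeffs) or const % d:
    raise ValueError("division would not yield a polynomial")
  return [c // d for c in coeffs], const // d

# e.g. (2*n + 4) / 2  ->  n + 2
print(divide_poly([2], 4, 2))   # ([1], 2)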
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
  m, n = x.shape
  return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
  m, n = x.shape
  return x + np.arange(m)[:, None]
The difference is that, previously, operations like broadcasts, transposes,
and reshapes that add singleton dimensions (as above) would force otherwise
lazy values to be materialized. After this commit, broadcasts, transposes,
and reshapes are all lazy operations that only update metadata on their
input, rather than compiling and executing XLA computations and producing
new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
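A minimal sketch of the kind of expression that can now stay lazy (which
operations remain lazy varies across JAX versions):

import jax.numpy as jnp
from jax import jit

@jit
def g(x):
  n = x.shape[-1]
  # With lazy constants plus lazy broadcasts/transposes/reshapes, neither
  # jnp.eye(n) nor the broadcast of jnp.ones(n) needs to be materialized at
  # trace time; they can be staged straight into the XLA computation.
  return x @ jnp.eye(n) + jnp.ones(n)[None, :]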
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431.
See https://github.com/google/jax/pull/1668 for more.