* Shapecheck of jit, device_put, and broadcast_in_dim; better error for unsupported ops; parse multi-digit integer literals
* WIP shapecheck np.pad
* Implement shapecheck of gather, pad
* Fix shapecheck of pad
* Implement polymorphic shape rule for (strided/dilated) convolution, refactor
* Cleanup
* Fix
* Remove all polymorphic shape rules, reuse shape rules instead.
* Register shape_rule for all standard_primitives
* Remove ShapeExpr, canonicalize_poly, renames
* Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes
* Allow a Poly of the form d*poly + k to be divided by d (see the sketch after this list)
* Fix a bug; inline poly_without_zeros.
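For illustration, here is a minimal self-contained sketch (not JAX's actual Poly class) of the division rule above: a shape polynomial of the form d*poly + k can be divided by d, with the constant term k supplying the quotient's constant part and the remainder.

  # Illustrative only: shape polynomials as {monomial: coefficient} dicts,
  # with '1' keying the constant term.
  def divide_poly(coeffs, d):
    quotient, remainder = {}, 0
    for mono, coef in coeffs.items():
      if mono == '1':
        q, remainder = divmod(coef, d)
        if q:
          quotient['1'] = q
      elif coef % d == 0:
        quotient[mono] = coef // d
      else:
        raise ValueError("not of the form d*poly + k")
    return quotient, remainder

  # (2*n + 3) divided by 2 gives n + 1, remainder 1:
  print(divide_poly({'n': 2, '1': 3}, 2))  # ({'n': 1, '1': 1}, 1)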
One issue with nested pmaps on multihost platforms is inferring the global
pmap axis size without communication. This commit sidesteps the issue by adding
an `axis_size` argument to manually provide this information.
This change only enables a single cross-host pmap; all inner pmaps must be
single-host.
Addressing: #1753
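A minimal sketch of the new argument, assuming a hypothetical two-host setup where each host runs the same program (the cluster size and axis names here are made up):

  import functools
  import jax
  from jax import lax, pmap

  n_hosts = 2  # hypothetical cluster size
  global_axis_size = n_hosts * jax.local_device_count()

  # The outer, cross-host pmap cannot infer its global axis size without
  # communication, so we pass it explicitly; the inner pmap only maps
  # over this host's local devices.
  @functools.partial(pmap, axis_name='i', axis_size=global_axis_size)
  @functools.partial(pmap, axis_name='j')
  def f(x):
    return lax.psum(x, 'j')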
This is useful for building higher-level array libraries around JAX, because it
makes it possible to override operations like `jax_array + other`.
I think I covered all the array types that JAX should be able to handle:
- Python builtin numbers: int, float, and complex
- NumPy scalars
- NumPy arrays
- JAX array types and tracers
Did I miss anything? Maybe bfloat16 scalars?
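As a quick, hypothetical illustration (the Wrapper class below is not part of JAX), a user-defined type can now intercept `jax_array + other` via the reflected operator, because JAX's `__add__` returns NotImplemented for types it does not recognize:

  import jax.numpy as jnp

  class Wrapper:
    def __init__(self, value):
      self.value = value

    def __radd__(self, other):
      # Reached because the JAX array's __add__ returns NotImplemented
      # for unrecognized types, deferring to this reflected method.
      return Wrapper(other + self.value)

  out = jnp.ones(3) + Wrapper(jnp.arange(3.0))  # dispatches to Wrapper.__radd__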
* Added batching to the CPU triangular_solver (see the sketch after this list)
* Addressed comments about int overflows; reverted triangular solve to use XLA instead of LAPACK
* Add a TODO to benchmark LAPACK vs. XLA
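A hedged illustration of what batched solves look like from user code (the shapes and names are made up; whether this reaches the batched CPU kernel depends on the lowering):

  import jax
  import jax.numpy as jnp
  from jax.scipy.linalg import solve_triangular

  # A stack of 4 well-conditioned lower-triangular systems, solved in
  # one vmapped call.
  key = jax.random.PRNGKey(0)
  a = jnp.tril(jax.random.normal(key, (4, 3, 3))) + 3 * jnp.eye(3)
  b = jnp.ones((4, 3))
  x = jax.vmap(lambda ai, bi: solve_triangular(ai, bi, lower=True))(a, b)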
* WIP: real-valued FFT functions
Note: the transpose rule is not correct yet (hence the failing tests).
* Fix transpose rules for rfft and irfft (see the gradient example after this list)
* Typo fix
* Fix test failures in x64 mode
* Add 1D/2D real FFT functions, plus docs
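As an illustration of why the transpose rule matters (this example is mine, not from the PR), differentiating through jnp.fft.rfft exercises exactly that rule:

  import jax
  import jax.numpy as jnp

  def loss(x):
    # Real-input FFT; the gradient flows back through rfft's transpose.
    return jnp.sum(jnp.abs(jnp.fft.rfft(x)) ** 2)

  x = jnp.linspace(0.0, 1.0, 8)
  print(jax.grad(loss)(x))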
Before this commit, this computation would avoid materializing the iota
array at trace time:
  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]
The difference is that, previously, operations like broadcasts, transposes,
and reshapes that add singleton dimensions (as above) would force
otherwise-lazy values to be materialized. After this commit, broadcasts,
transposes, and reshapes are all lazy operations that only update metadata
on their input, rather than compiling and executing XLA computations and
producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.
Adds support for complex batched Cholesky decomposition on both platforms.
Fix a concurrency bug in batched cuBLAS kernels where a host-to-device memcpy could take place before the device buffer was ready.
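A minimal illustration of the new capability (the example matrices are made up):

  import jax.numpy as jnp

  # Batch of 4 Hermitian positive-definite matrices (scaled identities),
  # factored in a single batched call.
  a = jnp.stack([(i + 2) * jnp.eye(3, dtype=jnp.complex64) for i in range(4)])
  l = jnp.linalg.cholesky(a)  # shape (4, 3, 3): lower-triangular factors
  assert jnp.allclose(l @ jnp.conj(jnp.swapaxes(l, -1, -2)), a)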