The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.
PiperOrigin-RevId: 483666051
fix some shape and type issues
import into namespace
imports into non-_src library
working logpdf test
cleanup
working tests for cdf and sf after fixing select
relax need for x to be in (a, b)
ensure behavior with invalid input matches scipy
remove enforcing valid parameters in tests
added truncnorm to docs
whoops alphabetical
fix linter error
fix circular import issue
The shape function of DotGeneralOp can't be integrated into MHLO yet: the shape function only predicts return shape but not able to predict element type. However, the current python binding infra will generate the constructor __init__() without the `return` as the first arg, which assumes the shape function can provide a fully inferred type (including an accurate element type). This leads to "inferred type does not match actual result type" errors in JAX. This needs a future solution.
This CL is the corresponding change with https://github.com/openxla/stablehlo/pull/269
Related Python __init__() interface changes (used by JAX):
batch_norm_grad: not used by JAX
batch_norm_inference: not used by JAX
batch_norm_training: not used by JAX
case: no change*
dot_general: open new b/253644255 to track the issue
if: no change*
map: no change*
reduce: no change*
reduce_window: no change*
sort: no change*
triangular_solve: updated in `linalg.py`
while: no change*
no change*: the signature of __init()__ for the op is not changed because of existence of regions https://github.com/llvm/llvm-project/blob/main/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp#L577
PiperOrigin-RevId: 482951512
This ensures all existing JAX buffer types have a `delete` method that can be used to free device buffer allocation eagerly.
User code sometimes have lingering python refs due to cyclic deps and other reasons, yet users may know for sure that certain arrays will no longer be used after a certain point. Calling `foo_array.delete()` for DeviceArray/ShardedDeviceArray/GlobalDeviceArray/Array allows users to force free the device side allocation to minimize device memory usage.
PiperOrigin-RevId: 482892157