11 Commits

Author SHA1 Message Date
George Necula
7d452adfd3 Add support for dynamic shapes to GPU threefry2x32 custom call.
In presence of dynamic shapes the ThreeFry2x32Descriptor will contain the
value n=-1, and the actual desired output length will be passed as
an additional operand. If the shape is static then the length will be
passed as part of the descriptor.

PiperOrigin-RevId: 497945778
2022-12-27 04:48:26 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
George Necula
ac7740513d Raise error for unsupported shape polymorphism for custom call and fallback lowering 2022-12-14 12:31:18 +01:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
840c96692e Internal change
PiperOrigin-RevId: 468509799
2022-08-18 11:39:07 -07:00
Peter Hawkins
1e241dcf16 Catch ModuleNotFoundError instead of ImportError.
We frequently use the pattern
try:
  import m
except ImportError:
  # do something else.

This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
2022-08-18 15:22:49 +00:00
Rohit Santhanam
a5c385327c Fix unit test failures on ROCm. 2022-05-09 13:36:48 +00:00
Peter Hawkins
883cf2b1e9 Refactor custom call building code in jaxlib to use a helper function.
Refactoring only, no functional changes intended.

This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.

Fixes https://github.com/google/jax/issues/10474

PiperOrigin-RevId: 447076192
2022-05-06 14:51:24 -07:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00
Peter Hawkins
4618f9ce03 Consolidate hip_prng and cuda_prng.
The Python code in jaxlib to build AMD HIP (ROCM) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.

This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.

PiperOrigin-RevId: 446761784
2022-05-05 10:55:29 -07:00