mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Remove some dead code from gpu_prng.py
This commit is contained in:
parent
5ccfc8d716
commit
bee2bc443a
@ -13,12 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
import importlib
|
||||
import itertools
|
||||
import operator
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
|
||||
@ -61,8 +58,6 @@ if _hip_prng:
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
|
||||
api_version=api_version)
|
||||
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
||||
|
||||
def _threefry2x32_lowering(prng, platform: str, keys, data,
|
||||
length: int | ir.Value | None = None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user