Remove some dead code from gpu_prng.py

This commit is contained in:
Peter Hawkins 2024-10-29 09:29:56 -04:00
parent 5ccfc8d716
commit bee2bc443a

View File

@ -13,12 +13,9 @@
# limitations under the License.
from __future__ import annotations
import functools
from functools import partial
import importlib
import itertools
import operator
import jaxlib.mlir.ir as ir
@ -61,8 +58,6 @@ if _hip_prng:
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def _threefry2x32_lowering(prng, platform: str, keys, data,
length: int | ir.Value | None = None,