mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #18338 from froystig:partitionable-threefry-ctx-mgr
PiperOrigin-RevId: 578325021
This commit is contained in:
commit
49fedb1c52
@ -67,6 +67,7 @@ from jax._src.config import (
|
||||
numpy_rank_promotion as numpy_rank_promotion,
|
||||
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
|
||||
legacy_prng_key as legacy_prng_key,
|
||||
threefry_partitionable as threefry_partitionable,
|
||||
transfer_guard as transfer_guard,
|
||||
transfer_guard_host_to_device as transfer_guard_host_to_device,
|
||||
transfer_guard_device_to_device as transfer_guard_device_to_device,
|
||||
|
@ -683,6 +683,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
|
||||
numpy_dtype_promotion: Optional[str] = None
|
||||
default_matmul_precision: Optional[Any] = None
|
||||
dynamic_shapes: bool = False
|
||||
threefry_partitionable: bool = False
|
||||
softmax_custom_jvp: bool = False
|
||||
xla_profile_version: int = 0
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user