mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #23741 from awshaichen:neuron
PiperOrigin-RevId: 678006682
This commit is contained in:
commit
2ac1d0b8d0
@ -72,7 +72,7 @@ def is_cache_used(backend: xla_client.Client) -> bool:
|
||||
# backend that supports serialization of executables.
|
||||
# TODO(skye): add warning when initializing cache on unsupported default
|
||||
# platform
|
||||
supported_platforms = ["tpu", "gpu", "cpu"]
|
||||
supported_platforms = ["tpu", "gpu", "cpu", "neuron"]
|
||||
|
||||
if not _is_cache_enabled():
|
||||
monitoring.record_event('/jax/compilation_cache/task_disabled_cache')
|
||||
|
@ -951,7 +951,7 @@ class LoweringResult(NamedTuple):
|
||||
shape_poly_state: ShapePolyLoweringState
|
||||
|
||||
|
||||
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
|
||||
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu", "neuron"]
|
||||
|
||||
|
||||
def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):
|
||||
|
Loading…
x
Reference in New Issue
Block a user