Merge pull request #23741 from awshaichen:neuron

PiperOrigin-RevId: 678006682
This commit is contained in:
jax authors 2024-09-23 17:41:11 -07:00
commit 2ac1d0b8d0
2 changed files with 2 additions and 2 deletions

View File

@ -72,7 +72,7 @@ def is_cache_used(backend: xla_client.Client) -> bool:
# backend that supports serialization of executables.
# TODO(skye): add warning when initializing cache on unsupported default
# platform
supported_platforms = ["tpu", "gpu", "cpu"]
supported_platforms = ["tpu", "gpu", "cpu", "neuron"]
if not _is_cache_enabled():
monitoring.record_event('/jax/compilation_cache/task_disabled_cache')

View File

@ -951,7 +951,7 @@ class LoweringResult(NamedTuple):
shape_poly_state: ShapePolyLoweringState
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu", "neuron"]
def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):