Add _raw_platform to work around extra platform normalization logic and enable

GPU aot compilation without a GPU present.

Fixes https://github.com/jax-ml/jax/issues/23971

PiperOrigin-RevId: 702506848
This commit is contained in:
Parker Schuh 2024-12-03 16:24:34 -08:00 committed by jax authors
parent ceeed909dc
commit fcf0b6d3da

View File

@ -2220,7 +2220,10 @@ def lower_sharding_computation(
out_shardings = _concretize_abstract_shardings(
out_shardings, global_out_avals, device_assignment)
platforms = lowering_platforms or (backend.platform,)
# TODO(parkers): One _raw_platform has been unified with platform,
# change this back to just read platform.
platforms = lowering_platforms or (
getattr(backend, "_raw_platform", backend.platform),)
committed = bool(
devices_from_context or