From fcf0b6d3daca6001cfa190f433553d5e85a86796 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 3 Dec 2024 16:24:34 -0800 Subject: [PATCH] Add _raw_platform to work around extra platform normalization logic and enable GPU aot compilation without a GPU present. Fixes https://github.com/jax-ml/jax/issues/23971 PiperOrigin-RevId: 702506848 --- jax/_src/interpreters/pxla.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 11df2d38f..7e13285c4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2220,7 +2220,10 @@ def lower_sharding_computation( out_shardings = _concretize_abstract_shardings( out_shardings, global_out_avals, device_assignment) - platforms = lowering_platforms or (backend.platform,) + # TODO(parkers): One _raw_platform has been unified with platform, + # change this back to just read platform. + platforms = lowering_platforms or ( + getattr(backend, "_raw_platform", backend.platform),) committed = bool( devices_from_context or