mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add _raw_platform to work around extra platform normalization logic and enable
GPU aot compilation without a GPU present. Fixes https://github.com/jax-ml/jax/issues/23971 PiperOrigin-RevId: 702506848
This commit is contained in:
parent
ceeed909dc
commit
fcf0b6d3da
@ -2220,7 +2220,10 @@ def lower_sharding_computation(
|
||||
out_shardings = _concretize_abstract_shardings(
|
||||
out_shardings, global_out_avals, device_assignment)
|
||||
|
||||
platforms = lowering_platforms or (backend.platform,)
|
||||
# TODO(parkers): One _raw_platform has been unified with platform,
|
||||
# change this back to just read platform.
|
||||
platforms = lowering_platforms or (
|
||||
getattr(backend, "_raw_platform", backend.platform),)
|
||||
|
||||
committed = bool(
|
||||
devices_from_context or
|
||||
|
Loading…
x
Reference in New Issue
Block a user