Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00.

In JAX, the platform on which a computation actually runs is determined very late, e.g., based on where the input data is located. With AOT lowering or serialization, the computation may execute on a different machine, or even on a platform that is not available at lowering time. It is therefore not safe to write platform-dependent code using Python conditionals, e.g., ones based on the current default JAX platform: such a conditional is evaluated once, at tracing time, on the machine doing the lowering, and its outcome is baked into the lowered computation. The proper way to do this is to introduce a primitive with platform-specific lowering rules.

This change introduces such a primitive along with a user-facing API. See the docstring of lax.platform_dependent for more details.
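A minimal sketch of the difference, assuming a JAX version that exposes `jax.lax.platform_dependent` with per-platform keyword branches (here `cpu=`) plus a `default=` fallback. The branch functions `cpu_impl` and `default_impl` are hypothetical placeholders; like the branches of `lax.switch`, they must return values of the same shape and dtype. The first function shows the unsafe Python-conditional pattern; the second defers the choice to lowering.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cpu_impl(x):
    # Hypothetical CPU-specific implementation.
    return x + 1.0


def default_impl(x):
    # Hypothetical fallback for every other platform.
    return x * 2.0


def f_unsafe(x):
    # Anti-pattern: this Python `if` runs at trace time, so the branch is
    # fixed by the default backend of the machine doing the tracing, not by
    # the platform the computation eventually runs on.
    if jax.default_backend() == "cpu":
        return cpu_impl(x)
    return default_impl(x)


@jax.jit
def f(x):
    # Both branches are kept; the platform-specific lowering rules of the
    # primitive select the branch for the platform actually targeted.
    return lax.platform_dependent(x, cpu=cpu_impl, default=default_impl)


x = jnp.arange(3.0)
print(f(x))  # on a CPU machine: [1. 2. 3.]
```

Because the choice is made during lowering rather than in Python, the same function can be lowered or serialized for a platform other than the one doing the tracing and still pick the matching branch.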