mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Refactor approx_top_k lowering to make it easier to understand
There have been two recent changes in this area: 1) migrating TPU lowering from XLA fallback to MLIR, 2) migrating lowering for other platforms from XLA fallback to MLIR. As I was trying to understand whether the versioning code makes sense, I had a hard time doing that. I think this refactoring makes this easier. PiperOrigin-RevId: 532102337
This commit is contained in:
parent
f7f1ddbb1e
commit
843106b73c
@ -443,12 +443,14 @@ approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
|
||||
if xc.mlir_api_version > 48:
|
||||
mlir.register_lowering(approx_top_k_p,
|
||||
partial(_approx_top_k_lowering, fallback=True))
|
||||
else:
|
||||
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
|
||||
platform='tpu')
|
||||
elif xc.mlir_api_version == 48:
|
||||
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
|
||||
if xc.mlir_api_version > 47:
|
||||
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
|
||||
platform='tpu')
|
||||
else:
|
||||
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
|
||||
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
|
||||
platform='tpu')
|
||||
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
|
||||
|
Loading…
x
Reference in New Issue
Block a user