Refactor approx_top_k lowering to make it easier to understand

There have been two recent changes in this area: 1) migrating TPU lowering from
XLA fallback to MLIR, 2) migrating lowering for other platforms from XLA
fallback to MLIR.

As I was trying to understand whether the versioning code makes sense, I had
a hard time doing that. I think this refactoring makes this easier.

PiperOrigin-RevId: 532102337
This commit is contained in:
Eugene Burmako 2023-05-15 07:20:55 -07:00 committed by jax authors
parent f7f1ddbb1e
commit 843106b73c

View File

@ -443,12 +443,14 @@ approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
if xc.mlir_api_version > 48:
mlir.register_lowering(approx_top_k_p,
partial(_approx_top_k_lowering, fallback=True))
else:
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
platform='tpu')
elif xc.mlir_api_version == 48:
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
if xc.mlir_api_version > 47:
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
platform='tpu')
else:
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
platform='tpu')
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule