Merge pull request #12486 from hawkinsp:debugging

PiperOrigin-RevId: 476445041
This commit is contained in:
jax authors 2022-09-23 13:09:26 -07:00
commit 43bbce0cc6

View File

@ -382,7 +382,7 @@ def inspect_sharding_infer_sharding_from_operands(arg_shapes, arg_shardings,
del arg_shapes, shape, backend_string
return arg_shardings[0]
if jaxlib.xla_extension_version >= 94:
if jaxlib.xla_extension_version >= 95:
xc.register_custom_call_partitioner( # pytype: disable=module-attr
_INSPECT_SHARDING_CALL_NAME, inspect_sharding_prop_user_sharding,
inspect_sharding_partition, inspect_sharding_infer_sharding_from_operands,