mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #12486 from hawkinsp:debugging
PiperOrigin-RevId: 476445041
This commit is contained in:
commit
43bbce0cc6
@ -382,7 +382,7 @@ def inspect_sharding_infer_sharding_from_operands(arg_shapes, arg_shardings,
|
||||
del arg_shapes, shape, backend_string
|
||||
return arg_shardings[0]
|
||||
|
||||
if jaxlib.xla_extension_version >= 94:
|
||||
if jaxlib.xla_extension_version >= 95:
|
||||
xc.register_custom_call_partitioner( # pytype: disable=module-attr
|
||||
_INSPECT_SHARDING_CALL_NAME, inspect_sharding_prop_user_sharding,
|
||||
inspect_sharding_partition, inspect_sharding_infer_sharding_from_operands,
|
||||
|
Loading…
x
Reference in New Issue
Block a user