mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add abstracted axes to pjit to make jax2tf tests pass. abstracted_axes and dynamic_shapes is not supported by pjit yet.
PiperOrigin-RevId: 502138836
This commit is contained in:
parent
4601928277
commit
1209ab17e4
@ -295,11 +295,12 @@ if jax.config.jax_jit_pjit_api_merge:
|
||||
device: Optional[xc.Device] = None,
|
||||
backend: Optional[str] = None,
|
||||
inline: bool = False,
|
||||
abstracted_axes: Optional[Any] = None,
|
||||
) -> stages.Wrapped:
|
||||
(in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
|
||||
static_argnames) = pjit.pre_infer_params(
|
||||
fun, in_axis_resources, out_axis_resources, donate_argnums,
|
||||
static_argnums, static_argnames, device, backend)
|
||||
static_argnums, static_argnames, device, backend, abstracted_axes)
|
||||
|
||||
def infer_params(*args, **kwargs):
|
||||
pjit_info_args = pjit.PjitInfo(
|
||||
@ -311,7 +312,7 @@ if jax.config.jax_jit_pjit_api_merge:
|
||||
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
|
||||
|
||||
return pjit.post_infer_params(fun, infer_params, static_argnums,
|
||||
static_argnames)
|
||||
static_argnames, abstracted_axes)
|
||||
|
||||
|
||||
def _jit(
|
||||
|
@ -174,8 +174,14 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames):
|
||||
|
||||
|
||||
def pre_infer_params(fun, in_axis_resources, out_axis_resources,
|
||||
donate_argnums, static_argnums, static_argnames, device,
|
||||
backend):
|
||||
donate_argnums, static_argnums, static_argnames, device,
|
||||
backend, abstracted_axes):
|
||||
# TODO(yashkatariya, mattjj): Remove when pjit supports dynamic shapes.
|
||||
if config.jax_dynamic_shapes:
|
||||
raise ValueError("Dynamic shapes is not supported with pjit yet.")
|
||||
if abstracted_axes and not config.jax_dynamic_shapes:
|
||||
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
|
||||
|
||||
check_callable(fun)
|
||||
|
||||
if not config.jax_array and (_is_unspecified(in_axis_resources) or
|
||||
@ -224,8 +230,10 @@ def pre_infer_params(fun, in_axis_resources, out_axis_resources,
|
||||
static_argnames)
|
||||
|
||||
|
||||
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames):
|
||||
if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
|
||||
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
abstracted_axes):
|
||||
if (FLAGS.experimental_cpp_pjit and xla_extension_version >= 115 and
|
||||
abstracted_axes is None):
|
||||
wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames)
|
||||
else:
|
||||
wrapped = _python_pjit(fun, infer_params_fn)
|
||||
@ -417,6 +425,7 @@ def pjit(
|
||||
device: Optional[xc.Device] = None,
|
||||
backend: Optional[str] = None,
|
||||
inline: bool = False,
|
||||
abstracted_axes: Optional[Any] = None,
|
||||
) -> stages.Wrapped:
|
||||
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
|
||||
|
||||
@ -556,7 +565,7 @@ def pjit(
|
||||
(in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
|
||||
static_argnames) = pre_infer_params(
|
||||
fun, in_axis_resources, out_axis_resources, donate_argnums,
|
||||
static_argnums, static_argnames, device, backend)
|
||||
static_argnums, static_argnames, device, backend, abstracted_axes)
|
||||
|
||||
def infer_params(*args, **kwargs):
|
||||
# Putting this outside of wrapped would make resources lexically scoped
|
||||
@ -569,7 +578,8 @@ def pjit(
|
||||
inline=inline, resource_env=resource_env)
|
||||
return common_infer_params(pjit_info_args, *args, **kwargs)
|
||||
|
||||
return post_infer_params(fun, infer_params, static_argnums, static_argnames)
|
||||
return post_infer_params(fun, infer_params, static_argnums, static_argnames,
|
||||
abstracted_axes)
|
||||
|
||||
|
||||
class _ListWithW(list):
|
||||
|
Loading…
x
Reference in New Issue
Block a user