Add abstracted axes to pjit to make jax2tf tests pass. abstracted_axes and dynamic_shapes is not supported by pjit yet.

PiperOrigin-RevId: 502138836
This commit is contained in:
Yash Katariya 2023-01-14 20:16:57 -08:00 committed by jax authors
parent 4601928277
commit 1209ab17e4
2 changed files with 19 additions and 8 deletions

View File

@ -295,11 +295,12 @@ if jax.config.jax_jit_pjit_api_merge:
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
inline: bool = False,
abstracted_axes: Optional[Any] = None,
) -> stages.Wrapped:
(in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
static_argnames) = pjit.pre_infer_params(
fun, in_axis_resources, out_axis_resources, donate_argnums,
static_argnums, static_argnames, device, backend)
static_argnums, static_argnames, device, backend, abstracted_axes)
def infer_params(*args, **kwargs):
pjit_info_args = pjit.PjitInfo(
@ -311,7 +312,7 @@ if jax.config.jax_jit_pjit_api_merge:
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
return pjit.post_infer_params(fun, infer_params, static_argnums,
static_argnames)
static_argnames, abstracted_axes)
def _jit(

View File

@ -174,8 +174,14 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames):
def pre_infer_params(fun, in_axis_resources, out_axis_resources,
donate_argnums, static_argnums, static_argnames, device,
backend):
donate_argnums, static_argnums, static_argnames, device,
backend, abstracted_axes):
# TODO(yashkatariya, mattjj): Remove when pjit supports dynamic shapes.
if config.jax_dynamic_shapes:
raise ValueError("Dynamic shapes is not supported with pjit yet.")
if abstracted_axes and not config.jax_dynamic_shapes:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
check_callable(fun)
if not config.jax_array and (_is_unspecified(in_axis_resources) or
@ -224,8 +230,10 @@ def pre_infer_params(fun, in_axis_resources, out_axis_resources,
static_argnames)
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames):
if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
abstracted_axes):
if (FLAGS.experimental_cpp_pjit and xla_extension_version >= 115 and
abstracted_axes is None):
wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames)
else:
wrapped = _python_pjit(fun, infer_params_fn)
@ -417,6 +425,7 @@ def pjit(
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
inline: bool = False,
abstracted_axes: Optional[Any] = None,
) -> stages.Wrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
@ -556,7 +565,7 @@ def pjit(
(in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
static_argnames) = pre_infer_params(
fun, in_axis_resources, out_axis_resources, donate_argnums,
static_argnums, static_argnames, device, backend)
static_argnums, static_argnames, device, backend, abstracted_axes)
def infer_params(*args, **kwargs):
# Putting this outside of wrapped would make resources lexically scoped
@ -569,7 +578,8 @@ def pjit(
inline=inline, resource_env=resource_env)
return common_infer_params(pjit_info_args, *args, **kwargs)
return post_infer_params(fun, infer_params, static_argnums, static_argnames)
return post_infer_params(fun, infer_params, static_argnums, static_argnames,
abstracted_axes)
class _ListWithW(list):