removes the jax.mask and jax.shapecheck APIs.

PiperOrigin-RevId: 463026577
This commit is contained in:
George Necula 2022-07-25 01:23:11 -07:00 committed by jax authors
parent f5f650fc1c
commit 66dc95e2de
3 changed files with 2 additions and 65 deletions

View File

@ -29,7 +29,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
following a similar deprecation in {func}`scipy.linalg.solve`.
* Deprecations:
* {func}`jax.mask` {func}`jax.shapecheck` are being deprecated.
* {func}`jax.mask` {func}`jax.shapecheck` APIs have been removed.
See {jax-issue}`#11557`.
## jaxlib 0.3.15 (Unreleased)

View File

@ -95,7 +95,7 @@ from jax._src.api import (
linearize as linearize,
linear_transpose as linear_transpose,
make_jaxpr as make_jaxpr,
mask as mask,
named_call as named_call,
named_scope as named_scope,
pmap as pmap,
@ -103,7 +103,6 @@ from jax._src.api import (
process_index as process_index,
pxla, # TODO(phawkins): update users to avoid this.
remat as remat,
shapecheck as shapecheck,
ShapedArray as ShapedArray,
ShapeDtypeStruct as ShapeDtypeStruct,
value_and_grad as value_and_grad,

View File

@ -82,7 +82,6 @@ from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import masking
from jax._src.config import (
flags, config, bool_env,
@ -2247,67 +2246,6 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
return lower
def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
warn("`jax.mask` is deprecated and will be removed soon. ",
DeprecationWarning)
_check_callable(fun)
unique_ids = masking.UniqueIds()
in_specs, in_shapes_tree = tree_flatten(in_shapes)
in_specs = map(masking.parse_spec, in_specs)
in_specs = map(partial(masking.remap_ids, unique_ids), in_specs)
if out_shape is not None:
out_specs, out_spec_tree = tree_flatten(out_shape)
out_specs = map(masking.parse_spec, out_specs)
out_specs = map(partial(masking.remap_ids, unique_ids), out_specs)
def wrapped_fun(args, logical_env):
args_flat, in_tree = tree_flatten(args)
if in_tree != in_shapes_tree:
raise TypeError(f"Tree mismatch: Input {in_tree} and shape spec {in_shapes_tree}.")
logical_env = {unique_ids[name] : val for name, val in logical_env.items()}
in_shapes = map(masking.finalize_spec, in_specs, map(np.shape, args_flat))
padded_env = masking.bind_shapes(in_shapes, [x.shape for x in args_flat])
f = lu.wrap_init(fun)
flat_fun, out_tree_thunk = flatten_fun_nokwargs(f, in_tree)
outs, out_shapes = masking.mask_fun(
flat_fun, logical_env, padded_env, args_flat, in_shapes)
out_tree = out_tree_thunk()
if out_shape is None:
def logical_shape(poly_shape, padded_val):
shape = masking.eval_poly_shape(poly_shape, logical_env)
return ShapeDtypeStruct(shape, core.get_aval(padded_val).dtype)
out_logicals = map(logical_shape, out_shapes, outs)
return tree_unflatten(out_tree, outs), tree_unflatten(out_tree, out_logicals)
else:
masking.check_shapes(out_specs, out_spec_tree, list(out_shapes), out_tree)
def padded_spec(shape_spec):
return tuple(dim if dim is masking._monomorphic_dim else
masking.eval_poly(dim, padded_env) for dim in shape_spec)
masking.check_shapes(map(padded_spec, out_specs), out_spec_tree,
map(np.shape, outs), out_tree, "Padded output")
return tree_unflatten(out_tree, outs)
return wrapped_fun
@curry
def shapecheck(in_shapes, out_shape, fun: Callable):
warn("`jax.shapecheck` is deprecated and will be removed soon. ",
DeprecationWarning)
_check_callable(fun)
in_shapes, in_tree = tree_flatten(in_shapes)
in_shapes = map(masking.parse_spec, in_shapes)
out_specs, out_spec_tree = tree_flatten(out_shape)
out_specs = map(masking.parse_spec, out_specs)
flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
avals = map(partial(ShapedArray, dtype=np.float32), in_shapes)
out_shapes = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
masking.check_shapes(map(tuple, out_specs), out_spec_tree,
map(tuple, out_shapes), out_tree_thunk())
return fun
def jvp(
fun: Callable, primals, tangents, has_aux: bool = False
) -> Tuple[Any, ...]: