mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
This commit is contained in:
parent
f5f650fc1c
commit
66dc95e2de
@ -29,7 +29,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
|
||||
following a similar deprecation in {func}`scipy.linalg.solve`.
|
||||
* Deprecations:
|
||||
* {func}`jax.mask` and {func}`jax.shapecheck` are being deprecated.
|
||||
* The {func}`jax.mask` and {func}`jax.shapecheck` APIs have been removed.
|
||||
See {jax-issue}`#11557`.
|
||||
|
||||
## jaxlib 0.3.15 (Unreleased)
|
||||
|
@ -95,7 +95,7 @@ from jax._src.api import (
|
||||
linearize as linearize,
|
||||
linear_transpose as linear_transpose,
|
||||
make_jaxpr as make_jaxpr,
|
||||
mask as mask,
|
||||
|
||||
named_call as named_call,
|
||||
named_scope as named_scope,
|
||||
pmap as pmap,
|
||||
@ -103,7 +103,6 @@ from jax._src.api import (
|
||||
process_index as process_index,
|
||||
pxla, # TODO(phawkins): update users to avoid this.
|
||||
remat as remat,
|
||||
shapecheck as shapecheck,
|
||||
ShapedArray as ShapedArray,
|
||||
ShapeDtypeStruct as ShapeDtypeStruct,
|
||||
value_and_grad as value_and_grad,
|
||||
|
@ -82,7 +82,6 @@ from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import masking
|
||||
|
||||
from jax._src.config import (
|
||||
flags, config, bool_env,
|
||||
@ -2247,67 +2246,6 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
|
||||
|
||||
return lower
|
||||
|
||||
|
||||
def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
  """Wrap ``fun`` so it runs on padded arguments while tracking logical shapes.

  Deprecated: emits a ``DeprecationWarning`` on every call (see the warning
  below); this API is being removed.

  Args:
    fun: the function to be masked. Must be callable (checked eagerly).
    in_shapes: a pytree of shape-spec strings/objects, one per flat argument;
      parsed with ``masking.parse_spec``.
    out_shape: optional pytree of output shape specs. When given, outputs are
      checked against it; when ``None``, the wrapper also returns the computed
      logical output shapes.

  Returns:
    ``wrapped_fun(args, logical_env)`` — applies ``fun`` to the (padded)
    ``args`` under the dimension-size bindings in ``logical_env``.
  """
  warn("`jax.mask` is deprecated and will be removed soon. ",
       DeprecationWarning)
  _check_callable(fun)
  # Shared id namespace so the same dimension name in different specs refers
  # to the same polymorphic dimension.
  unique_ids = masking.UniqueIds()

  in_specs, in_shapes_tree = tree_flatten(in_shapes)
  in_specs = map(masking.parse_spec, in_specs)
  in_specs = map(partial(masking.remap_ids, unique_ids), in_specs)

  # Output specs are only parsed when the caller asked for output checking.
  if out_shape is not None:
    out_specs, out_spec_tree = tree_flatten(out_shape)
    out_specs = map(masking.parse_spec, out_specs)
    out_specs = map(partial(masking.remap_ids, unique_ids), out_specs)

  def wrapped_fun(args, logical_env):
    """Run ``fun`` on padded ``args`` with logical sizes from ``logical_env``."""
    args_flat, in_tree = tree_flatten(args)
    # The argument pytree must match the spec pytree exactly.
    if in_tree != in_shapes_tree:
      raise TypeError(f"Tree mismatch: Input {in_tree} and shape spec {in_shapes_tree}.")
    # Re-key the user-provided env by internal unique ids.
    logical_env = {unique_ids[name] : val for name, val in logical_env.items()}
    in_shapes = map(masking.finalize_spec, in_specs, map(np.shape, args_flat))
    # Padded sizes are taken from the actual (padded) argument shapes.
    padded_env = masking.bind_shapes(in_shapes, [x.shape for x in args_flat])
    f = lu.wrap_init(fun)
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(f, in_tree)
    outs, out_shapes = masking.mask_fun(
        flat_fun, logical_env, padded_env, args_flat, in_shapes)
    out_tree = out_tree_thunk()

    if out_shape is None:
      # No declared output spec: also return the logical output shapes so the
      # caller can slice the padded results.
      def logical_shape(poly_shape, padded_val):
        shape = masking.eval_poly_shape(poly_shape, logical_env)
        return ShapeDtypeStruct(shape, core.get_aval(padded_val).dtype)
      out_logicals = map(logical_shape, out_shapes, outs)
      return tree_unflatten(out_tree, outs), tree_unflatten(out_tree, out_logicals)
    else:
      # Check logical output shapes against the declared spec...
      masking.check_shapes(out_specs, out_spec_tree, list(out_shapes), out_tree)
      def padded_spec(shape_spec):
        # Monomorphic dims pass through; polymorphic dims are evaluated in the
        # padded environment.
        return tuple(dim if dim is masking._monomorphic_dim else
                     masking.eval_poly(dim, padded_env) for dim in shape_spec)
      # ...and the actual padded output shapes against the evaluated spec.
      masking.check_shapes(map(padded_spec, out_specs), out_spec_tree,
                           map(np.shape, outs), out_tree, "Padded output")
      return tree_unflatten(out_tree, outs)
  return wrapped_fun
|
||||
|
||||
@curry
def shapecheck(in_shapes, out_shape, fun: Callable):
  """Statically verify that ``fun`` maps ``in_shapes`` to ``out_shape``.

  Deprecated: emits a ``DeprecationWarning`` on every call; this API is being
  removed.

  Abstractly evaluates ``fun`` (at dtype float32) on the parsed input shape
  specs and checks the resulting output shapes against ``out_shape``.

  Args:
    in_shapes: pytree of input shape specs (parsed via ``masking.parse_spec``).
    out_shape: pytree of expected output shape specs.
    fun: the callable to check.

  Returns:
    ``fun`` unchanged, so ``shapecheck`` can be used as a decorator.
  """
  warn("`jax.shapecheck` is deprecated and will be removed soon. ",
       DeprecationWarning)
  _check_callable(fun)

  flat_in_specs, in_tree = tree_flatten(in_shapes)
  flat_in_specs = [masking.parse_spec(spec) for spec in flat_in_specs]

  flat_out_specs, out_spec_tree = tree_flatten(out_shape)
  flat_out_specs = [masking.parse_spec(spec) for spec in flat_out_specs]

  flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  # Dtype is irrelevant for shape checking; float32 is an arbitrary choice.
  in_avals = [ShapedArray(spec, dtype=np.float32) for spec in flat_in_specs]
  out_avals = pe.abstract_eval_fun(flat_fun.call_wrapped, *in_avals)
  actual_shapes = [aval.shape for aval in out_avals]

  masking.check_shapes([tuple(spec) for spec in flat_out_specs], out_spec_tree,
                       [tuple(shape) for shape in actual_shapes],
                       out_tree_thunk())
  return fun
|
||||
|
||||
def jvp(
|
||||
fun: Callable, primals, tangents, has_aux: bool = False
|
||||
) -> Tuple[Any, ...]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user