mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #22735 from dfm:custom-vjp-remat-opt
PiperOrigin-RevId: 658043956
This commit is contained in:
commit
c7eb023746
@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
from functools import update_wrapper, reduce, partial
|
||||
from functools import update_wrapper, reduce, partial, wraps
|
||||
import inspect
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
@ -497,6 +497,7 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None
|
||||
self.bwd: Callable[..., tuple[Any, ...]] | None = None
|
||||
self.symbolic_zeros = False
|
||||
self.optimize_remat = False
|
||||
|
||||
__getattr__ = custom_api_util.forward_attr
|
||||
|
||||
@ -504,6 +505,7 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
fwd: Callable[..., tuple[ReturnValue, Any]],
|
||||
bwd: Callable[..., tuple[Any, ...]],
|
||||
symbolic_zeros: bool = False,
|
||||
optimize_remat: bool = False,
|
||||
) -> None:
|
||||
"""Define a custom VJP rule for the function represented by this instance.
|
||||
|
||||
@ -560,6 +562,10 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
objects that are given as input leaves to the ``fwd`` rule.
|
||||
|
||||
Default ``False``.
|
||||
optimize_remat: boolean, an experimental flag to enable an automatic
|
||||
optimization when this function is used under :func:`jax.remat`. This
|
||||
will be most useful when the ``fwd`` rule is an opaque call such as a
|
||||
Pallas kernel or a custom call. Default ``False``.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
@ -582,6 +588,10 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
self.fwd = fwd
|
||||
self.bwd = bwd
|
||||
self.symbolic_zeros = symbolic_zeros
|
||||
self.optimize_remat = optimize_remat
|
||||
if self.symbolic_zeros and self.optimize_remat:
|
||||
raise NotImplementedError(
|
||||
"remat optimization for custom_vjp does not support symbolic zeros")
|
||||
|
||||
@traceback_util.api_boundary
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
|
||||
@ -591,6 +601,12 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
raise AttributeError(msg)
|
||||
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
|
||||
args = _resolve_kwargs(self.fun, args, kwargs)
|
||||
if self.optimize_remat:
|
||||
fwd = optimize_remat_of_custom_vjp_fwd(
|
||||
self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums,
|
||||
symbolic_zeros=self.symbolic_zeros)
|
||||
else:
|
||||
fwd = self.fwd
|
||||
if config.enable_custom_vjp_by_custom_transpose.value:
|
||||
if self.nondiff_argnums:
|
||||
raise NotImplementedError(
|
||||
@ -604,16 +620,16 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
|
||||
args, require_static_args_hashable=False)
|
||||
static_args = [args[i] for i in self.nondiff_argnums]
|
||||
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
|
||||
fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
|
||||
require_static_args_hashable=False)
|
||||
bwd = _add_args(lu.wrap_init(self.bwd), static_args)
|
||||
else:
|
||||
f_, dyn_args = lu.wrap_init(self.fun), args
|
||||
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
|
||||
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd, self.symbolic_zeros, primal_name,
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
|
||||
fwd_name, in_tree, out_type)
|
||||
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
|
||||
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
|
||||
@ -1407,3 +1423,226 @@ def custom_vjp_by_custom_transpose(fun, fwd, bwd):
|
||||
|
||||
# TODO(mattjj): remove these stubs, which exist to avoid breaking internal users
|
||||
custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
|
||||
|
||||
|
||||
# The following is a helper for optimizing the behavior of custom_vjp when used
|
||||
# under remat. This is really only useful when the `fwd` function to custom_vjp
|
||||
# executes a black box kernel. Otherwise, DCE will perform this optimization
|
||||
# automatically.
|
||||
#
|
||||
# TODO(dfm): Eventually this should probably be the default behavior for
|
||||
# custom_vjp, if we can make it so that it is a no-op for most cases. Right now,
|
||||
# it is written in "initial-style" so it doesn't support eager mode. This was
|
||||
# a reasonable compromise when written because it made the implementation
|
||||
# simpler, but it would be worth revisiting this.
|
||||
def optimize_remat_of_custom_vjp_fwd(
    fun: Callable[..., ReturnValue],
    fwd: Callable[..., tuple[ReturnValue, Any]],
    nondiff_argnums: tuple[int, ...] = (),
    symbolic_zeros: bool = False,
) -> Callable[..., tuple[ReturnValue, Any]]:
  """Wrap a ``custom_vjp`` forward rule so that it is bound via ``remat_opt_p``.

  When called, the returned function traces ``fwd`` to a jaxpr and binds it
  through ``remat_opt_p``, handing the primitive a memoized thunk that traces
  ``fun`` on demand.

  Args:
    fun: the primal function.
    fwd: the forward rule, returning a ``(primal_out, residuals)`` pair.
    nondiff_argnums: positions of arguments that are held static while tracing.
    symbolic_zeros: must be ``False``; symbolic zeros are not supported here.

  Returns:
    A function that accepts the same arguments as ``fwd`` and returns the same
    ``(primal_out, residuals)`` structure.

  Raises:
    NotImplementedError: if ``symbolic_zeros`` is true, or (when the wrapper is
      called) if the traced ``fwd`` jaxpr has effects.
  """
  if symbolic_zeros:
    # TODO(dfm): This probably shouldn't be too hard to support.
    raise NotImplementedError(
        "remat optimization for custom_vjp does not support symbolic zeros")

  @wraps(fwd)
  def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
    # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__
    # above and it would be good to consolidate it.
    primal_name = getattr(fun, "__name__", str(fun))
    fwd_name = getattr(fwd, "__name__", str(fwd))
    args = _resolve_kwargs(fwd, args, kwargs)
    if not nondiff_argnums:
      f_, dyn_args = lu.wrap_init(fun), args
      fwd_ = lu.wrap_init(fwd)
    else:
      for i in nondiff_argnums: _check_for_tracers(args[i])
      static_argnums = set(nondiff_argnums)
      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
      f_, dyn_args = argnums_partial(lu.wrap_init(fun), dyn_argnums,
                                     args, require_static_args_hashable=False)
      fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
                                require_static_args_hashable=False)
    args_flat, in_tree = tree_flatten(dyn_args)
    flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
    flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
                                       in_tree, out_type)

    # Trace fwd once, then turn its constvars into leading explicit inputs so
    # that the consts can be passed as ordinary operands of the primitive.
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
    fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
    fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
    prim_tree, res_tree = out_trees()
    num_res = res_tree.num_leaves

    if fwd_jaxpr.effects:
      raise NotImplementedError(
          "remat optimization for custom_vjp does not support forward "
          f"functions with side effects, but {fwd_name} has the following "
          f"effects: {fwd_jaxpr.effects}")

    # The primal function is only traced if something asks for it.
    @pe._memoize
    def fun_jaxpr_thunk():
      fun_jaxpr, _, fun_consts, () = pe.trace_to_jaxpr_dynamic(
          flat_fun, in_avals)
      return fun_jaxpr, fun_consts

    flat_outs = remat_opt_p.bind(*consts, *args_flat,
                                 num_consts=len(consts),
                                 num_res=num_res,
                                 fwd_jaxpr=fwd_jaxpr,
                                 fun_jaxpr_thunk=fun_jaxpr_thunk)
    # The flat outputs list the residuals first, followed by the primal outputs.
    res_flat, prim_flat = split_list(flat_outs, [num_res])
    out_tree = treedef_tuple((prim_tree, res_tree))
    return tree_unflatten(out_tree, (*prim_flat, *res_flat))

  return wrapped_fwd
|
||||
|
||||
def _remat_opt_impl(
|
||||
*args,
|
||||
num_consts: int,
|
||||
num_res: int,
|
||||
fwd_jaxpr: core.ClosedJaxpr,
|
||||
fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
|
||||
):
|
||||
del num_consts, num_res, fun_jaxpr_thunk # unused
|
||||
return core.jaxpr_as_fun(fwd_jaxpr)(*args)
|
||||
|
||||
def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
|
||||
del args
|
||||
return fwd_jaxpr.out_avals, fwd_jaxpr.effects
|
||||
|
||||
def _remat_opt_vmap(
    spmd_axis_name, axis_size, axis_name, main_type, args, in_dims,
    *,
    num_consts: int,
    num_res: int,
    fwd_jaxpr: core.ClosedJaxpr,
    fun_jaxpr_thunk: Callable[[], tuple[core.Jaxpr, Sequence[Any]]],
):
  """Batching rule for ``remat_opt_p``.

  Mapped operands are moved so that their batch axis is 0, ``fwd_jaxpr`` is
  batched right away, and the primal jaxpr is batched lazily inside a new
  memoized thunk. The primitive is then re-bound with the batched pieces.

  Args:
    spmd_axis_name, axis_size, axis_name, main_type: batching context,
      forwarded to ``batching.batch_jaxpr``.
    args: operands of the primitive (consts first, then primal inputs).
    in_dims: batch dimension of each operand, or ``not_mapped``.
    num_consts: number of leading operands that are consts of ``fwd_jaxpr``.
    num_res: number of residual outputs; passed through unchanged.
    fwd_jaxpr: closed jaxpr of the forward rule.
    fun_jaxpr_thunk: thunk returning a ``(jaxpr, consts)`` pair for the primal
      function.

  Returns:
    A ``(batched_outs, out_dims)`` pair, where each entry of ``out_dims`` is
    ``0`` for a batched output and ``not_mapped`` otherwise.
  """
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]

  in_batched = [d is not not_mapped for d in in_dims]
  batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
      fwd_jaxpr, axis_size, in_batched, False,
      axis_name, spmd_axis_name, main_type)
  out_dims = [0 if b else not_mapped for b in out_batched]

  # The primal jaxpr takes only the non-const operands, so it is batched with
  # the batching flags of those operands alone.
  _, prim_batched = split_list(in_batched, [num_consts])

  @pe._memoize
  def batched_fun_jaxpr_thunk():
    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
    # The output batching flags of the primal are not needed here.
    batched_fun_jaxpr, _ = batching.batch_jaxpr(
        fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name,
        main_type)
    return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts

  batched_outs = remat_opt_p.bind(*args, num_consts=num_consts,
                                  num_res=num_res,
                                  fwd_jaxpr=batched_fwd_jaxpr,
                                  fun_jaxpr_thunk=batched_fun_jaxpr_thunk)

  return batched_outs, out_dims
|
||||
|
||||
def _remat_opt_jvp(
    primals,
    tangents,
    *,
    num_consts: int,
    num_res: int,
    fwd_jaxpr: core.ClosedJaxpr,
    fun_jaxpr_thunk: Callable[[], tuple[core.Jaxpr, Sequence[Any]]],
):
  """JVP rule for ``remat_opt_p``.

  Builds the JVP of ``fwd_jaxpr``, re-binds ``remat_opt_p`` with it, and
  supplies a memoized thunk that computes the JVP of the primal jaxpr on
  demand.

  Args:
    primals: operands of the primitive (consts first, then primal inputs).
    tangents: tangents matching ``primals``; may contain symbolic zeros.
    num_consts: number of leading operands that are consts of ``fwd_jaxpr``.
    num_res: number of residual outputs of ``fwd_jaxpr``.
    fwd_jaxpr: closed jaxpr of the forward rule.
    fun_jaxpr_thunk: thunk returning a ``(jaxpr, consts)`` pair for the primal
      function.

  Returns:
    A ``(primals_out, tangents_out)`` pair; each lists the residual entries
    first, followed by the primal-output entries.
  """
  consts, primals = split_list(primals, [num_consts])
  consts_dot, tangents = split_list(tangents, [num_consts])
  # Tangents must be instantiated in case we end up DCEing later.
  # A list (not a lazy iterator) is required because len() is taken below.
  tangents = [ad.instantiate_zeros(t) for t in tangents]
  consts_nz = [not isinstance(t, Zero) for t in consts_dot]
  consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz]
  in_nz = consts_nz + [True] * len(tangents)
  fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True)
  num_out = len(out_nz) - num_res
  fwd_jaxpr_jvp_ = ad.rearrange_binders(
      fwd_jaxpr_jvp_, [num_consts, len(primals)],
      [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out])
  fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr))

  @pe._memoize
  def fun_jvp_jaxpr_thunk():
    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
    fun_in_nz = [True] * len(primals)
    fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, fun_in_nz, True)
    return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts

  # Everything that precedes the primal inputs is a const of the new
  # primitive: consts of the JVP jaxpr, the original consts, and the non-zero
  # const tangents.
  new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot)
  outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot,
                          *primals, *tangents, num_consts=new_num_consts,
                          num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp,
                          fun_jaxpr_thunk=fun_jvp_jaxpr_thunk)
  res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out])
  return (*res, *outs), (*res_dot, *outs_dot)
|
||||
|
||||
def _remat_opt_transpose(
|
||||
cts, *args,
|
||||
num_consts: int,
|
||||
num_res: int,
|
||||
fwd_jaxpr: core.ClosedJaxpr,
|
||||
fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
|
||||
):
|
||||
# TODO(dfm): It shouldn't be too hard to implement this as needed in the
|
||||
# future.
|
||||
raise NotImplementedError(
|
||||
"remat optimization for custom_vjp does not support higher-order AD")
|
||||
|
||||
def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
  """DCE rule for ``remat_opt_p``.

  If any residual output is still used, the equation stays a ``remat_opt_p``
  equation with a pruned ``fwd_jaxpr``. If no residual is used, the equation
  is replaced by a ``closed_call_p`` equation running the pruned primal jaxpr.

  Args:
    used_outs: one flag per output of ``eqn`` (residuals first).
    eqn: the ``remat_opt_p`` equation being pruned.

  Returns:
    A ``(used_ins, new_eqn)`` pair.
  """
  params = eqn.params
  num_consts = params["num_consts"]
  used_res, used_prims = split_list(used_outs, [params["num_res"]])
  live_outvars = [var for keep, var in zip(used_outs, eqn.outvars) if keep]

  if not any(used_res):
    # If none of the residuals are used, we run the primal computation instead.
    # At this point we drop this custom DCE behavior, but since the primal might
    # have different consts than fwd, we build a new JaxprEqn with a closed_call
    # primitive.
    fun_jaxpr, fun_consts = params["fun_jaxpr_thunk"]()
    pruned_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(
        fun_jaxpr, used_prims)
    kept_consts = [c for keep, c in zip(used_consts, fun_consts) if keep]
    call_jaxpr = core.ClosedJaxpr(pruned_jaxpr, kept_consts)
    _, prim_invars = split_list(eqn.invars, [num_consts])
    live_invars = [var for keep, var in zip(used_ins, prim_invars) if keep]
    new_eqn = pe.new_jaxpr_eqn(
        live_invars, live_outvars, core.closed_call_p,
        dict(call_jaxpr=call_jaxpr),
        call_jaxpr.effects, eqn.source_info, eqn.ctx)
    # None of the consts of fwd are inputs of the new equation.
    return [False] * num_consts + used_ins, new_eqn

  # If any of the residuals are used, we still need to run fwd at this point,
  # but we may end up DCEing again in the future, so we must instantiate all
  # the input primals.
  instantiate = [False] * num_consts
  instantiate += [True] * (len(eqn.invars) - num_consts)
  pruned_jaxpr, used_ins = pe.dce_jaxpr(params["fwd_jaxpr"].jaxpr, used_outs,
                                        instantiate=instantiate)
  pruned_fwd = pe.close_jaxpr(pruned_jaxpr)
  live_invars = [var for keep, var in zip(used_ins, eqn.invars) if keep]
  used_const_ins, _ = split_list(used_ins, [num_consts])
  new_params = dict(params)
  new_params["num_consts"] = sum(used_const_ins)
  new_params["fwd_jaxpr"] = pruned_fwd
  new_params["num_res"] = sum(used_res)
  new_eqn = pe.new_jaxpr_eqn(
      live_invars, live_outvars, remat_opt_p, new_params, pruned_fwd.effects,
      eqn.source_info, eqn.ctx)
  return used_ins, new_eqn
|
||||
|
||||
# Primitive that evaluates a custom_vjp ``fwd`` jaxpr while carrying a thunk
# for the primal jaxpr, which the DCE rule below can substitute.
remat_opt_p = core.Primitive("remat_opt")
remat_opt_p.multiple_results = True
remat_opt_p.def_impl(_remat_opt_impl)
remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval)
xla.register_initial_style_primitive(remat_opt_p)
# Lowering reuses the implementation rule.
mlir.register_lowering(remat_opt_p, mlir.lower_fun(
    _remat_opt_impl, multiple_results=True))
# The plain batcher is the SPMD batcher with spmd_axis_name fixed to None.
batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None)
ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
pe.dce_rules[remat_opt_p] = _remat_opt_dce
|
||||
|
@ -30,6 +30,7 @@ from jax._src.custom_derivatives import (
|
||||
custom_vjp_primal_tree_values as custom_vjp_primal_tree_values,
|
||||
CustomVJPPrimal as CustomVJPPrimal,
|
||||
linear_call as linear_call,
|
||||
remat_opt_p as remat_opt_p,
|
||||
)
|
||||
|
||||
from jax._src.ad_util import (
|
||||
|
@ -3448,6 +3448,17 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
|
||||
tf_impl[ad.custom_lin_p] = _custom_lin
|
||||
|
||||
|
||||
def _remat_opt(*args: TfVal, num_consts: int, num_res: int,
               fwd_jaxpr: core.ClosedJaxpr,
               fun_jaxpr_thunk: Callable) -> Sequence[TfVal]:
  """Conversion rule for ``remat_opt_p``: interpret ``fwd_jaxpr`` on ``args``.

  ``num_consts``, ``num_res`` and ``fun_jaxpr_thunk`` are ignored.
  """
  del num_consts, num_res, fun_jaxpr_thunk
  interpreted = _interpret_jaxpr(
      fwd_jaxpr, *args,
      extra_name_stack="remat_opt",
      fresh_constant_cache=False)
  return interpreted


tf_impl[custom_derivatives.remat_opt_p] = _remat_opt
|
||||
|
||||
|
||||
PartitionsOrReplicated = Union[tuple[int, ...], None]
|
||||
|
||||
def split_to_logical_devices(tensor: TfVal,
|
||||
|
@ -9653,6 +9653,83 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
jax.grad(f)(1., 2.) # don't crash
|
||||
|
||||
def test_optimize_remat(self):
  """The wrapped fwd runs when residuals are kept and the primal otherwise."""
  def fun(x):
    # This array is included to make sure that we handle consts appropriately
    one = np.array([1.0])
    return one*x

  def fwd(x):
    two, one = np.array([2.0]), np.array([1.0])
    return two*x*x/one, (x,)

  wrapped = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
  xs = jnp.linspace(0, 5.0, 10)

  # Shouldn't hit custom DCE
  primal_and_res = jax.jit(wrapped)(xs)
  self.assertAllClose(primal_and_res[0], 2*xs*xs)

  # Should be DCEed
  primal_only = jax.jit(lambda x: wrapped(x)[0])
  self.assertAllClose(primal_only(xs), xs)
|
||||
|
||||
def test_optimize_remat_vmap(self):
  """Same check as test_optimize_remat, with the wrapped fwd under vmap."""
  def fun(x):
    scaled = np.array([1.0])*x
    return scaled[0]

  def fwd(x):
    squared = np.array([2.0])*x*x/np.array([1.0])
    return squared[0], (x,)

  wrapped = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
  xs = jnp.linspace(0, 5.0, 10)

  primal_and_res = jax.jit(jax.vmap(wrapped))(xs)
  self.assertAllClose(primal_and_res[0], 2*xs*xs)

  primal_only = jax.jit(lambda x: jax.vmap(wrapped)(x)[0])
  self.assertAllClose(primal_only(xs), xs)
|
||||
|
||||
def test_optimize_remat_cond(self):
  """Same check as test_optimize_remat, with the wrapped fwd in a cond branch."""
  def fun(x):
    return x

  def fwd(x):
    return x*x, (x,)

  wrapped = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)

  def other_branch(x):
    return 2.0 * x, (x,)

  def g(x):
    return jax.lax.cond(True, wrapped, other_branch, x)

  xs = jnp.linspace(0, 5.0, 10)

  primal_and_res = jax.jit(g)(xs)
  self.assertAllClose(primal_and_res[0], xs*xs)

  primal_only = jax.jit(lambda x: g(x)[0])
  self.assertAllClose(primal_only(xs), xs)
|
||||
|
||||
def test_optimize_remat_jvp(self):
  """JVP of the wrapped fwd matches fwd, or fun once residuals are dropped."""
  def fun(x):
    return x**2

  def fwd_(x):
    return x*x, (x,)

  fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_)
  point, direction = (3.2,), (1.0,)

  # With every output kept, the wrapper must agree with the raw fwd rule.
  self.assertAllClose(jax.jvp(fwd, point, direction),
                      jax.jvp(fwd_, point, direction))

  @jax.jit
  def g(x, t):
    # The residual outputs and their tangents are deliberately discarded.
    (y, _), (y_dot, _) = jax.jvp(fwd, (x,), (t,))
    return y, y_dot

  # With the residuals discarded, the result must agree with the primal.
  self.assertAllClose(g(3.2, 1.0), jax.jvp(fun, point, direction))
|
||||
|
||||
def test_optimize_remat_gh21303(self):
  """With optimize_remat=True, the value under remat comes from the primal."""
  @jax.custom_vjp
  def f(x):
    return jnp.tan(x)

  # fwd deliberately disagrees with the primal (sin vs. tan), so the value
  # check below can tell which of the two produced the output.
  def f_fwd(x):
    return jnp.sin(x), (x,)

  def f_bwd(res, g):
    (x,) = res
    return (jnp.cos(x) * g,)

  f.defvjp(f_fwd, f_bwd, optimize_remat=True)

  def temp(x):
    return jax.remat(f)(x) ** 2

  value, _ = jax.value_and_grad(temp)(3.2)
  self.assertAllClose(value, jnp.tan(3.2)**2)
|
||||
|
||||
|
||||
def transpose_unary(f, x_example):
|
||||
def transposed(y):
|
||||
|
Loading…
x
Reference in New Issue
Block a user