Change from deflinear to deflinear2

Jake VanderPlas 2020-12-30 17:42:04 -08:00
parent afe612217b
commit 98aac23d92
7 changed files with 25 additions and 26 deletions
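
Every hunk below makes the same mechanical change: ad.deflinear passes a primitive's transpose rule only the output cotangent, while ad.deflinear2 also passes the primal operand, so each rule gains one extra positional parameter (often unused, hence the `_` placeholders). A minimal sketch of the signature change, using lax.rev as a stand-in linear op; the names here are illustrative, not part of the commit:

    from jax import lax

    # Old style (ad.deflinear): the rule sees only the output cotangent `t`.
    def _rev_transpose_old(t, dimensions):
      return [lax.rev(t, dimensions)]

    # New style (ad.deflinear2): the rule also receives the primal operand,
    # which a purely linear primitive is free to ignore.
    def _rev_transpose_new(t, operand, dimensions):
      del operand  # unused: only the cotangent matters for a linear op
      return [lax.rev(t, dimensions)]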


@@ -2522,7 +2522,7 @@ def _cumred_shape_rule(x, *, axis: int, reverse: bool):
"axis {} is out of bounds for array of shape {}".format(axis, x.shape))
return x.shape
-def _cumsum_transpose_rule(t, *, axis: int, reverse: bool):
+def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool):
return [cumsum(t, axis=axis, reverse=not reverse)]
@@ -2557,7 +2557,7 @@ def _cumred_dtype_rule(name, operand, *args, **kw):
cumsum_p = lax.standard_primitive(
_cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"),
'cumsum')
-ad.deflinear(cumsum_p, _cumsum_transpose_rule)
+ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
partial(_cumred_tpu_translation_rule, lax._reduce_window_sum),
multiple_results=False)
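
A quick way to exercise the cumsum rule after this change is jax.linear_transpose (usage sketch; the output noted in the comment is worked out by hand):

    import jax
    import jax.numpy as jnp

    x = jnp.arange(4.0)
    # cumsum is linear, so linear_transpose drives cumsum_p's transpose rule,
    # now registered via ad.deflinear2.
    cumsum_t = jax.linear_transpose(jnp.cumsum, x)
    ct, = cumsum_t(jnp.ones_like(x))  # reversed cumulative sum: [4., 3., 2., 1.]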


@@ -116,7 +116,7 @@ def _irfft_transpose(t, fft_lengths):
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
return out
-def fft_transpose_rule(t, fft_type, fft_lengths):
+def fft_transpose_rule(t, operand, fft_type, fft_lengths):
if fft_type == xla_client.FftType.RFFT:
result = _rfft_transpose(t, fft_lengths)
elif fft_type == xla_client.FftType.IRFFT:
@@ -135,7 +135,7 @@ fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
-ad.deflinear(fft_p, fft_transpose_rule)
+ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
if pocketfft:
xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft


@@ -2134,7 +2134,7 @@ _any = _int | _float | _complex | _bool
_bool_or_int = _int | _bool
neg_p = standard_unop(_num, 'neg')
-ad.deflinear(neg_p, lambda t: [neg(t)])
+ad.deflinear2(neg_p, lambda t, operand: [neg(t)])
def _sign_translation_rule(c, x):
shape = c.get_shape(x)
@@ -2373,15 +2373,15 @@ ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
mul(g, exp(square(ans)))))
real_p = unop(_complex_basetype, _complex, 'real')
-ad.deflinear(real_p, lambda t: [complex(t, np.zeros((), _dtype(t)))])
+ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
imag_p = unop(_complex_basetype, _complex, 'imag')
-ad.deflinear(imag_p, lambda t: [complex(np.zeros((), _dtype(t)), neg(t))])
+ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])
_complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
'complex')
-ad.deflinear(complex_p, lambda t: [real(t), imag(neg(t))])
+ad.deflinear2(complex_p, lambda t, *args: [real(t), imag(neg(t))])
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
@@ -3248,7 +3248,7 @@ def _broadcast_batch_rule(batched_args, batch_dims, *, sizes):
broadcast_p = standard_primitive(
_broadcast_shape_rule, _input_dtype, 'broadcast')
-ad.deflinear(broadcast_p, lambda t, sizes: [_reduce_sum(t, range(len(sizes)))])
+ad.deflinear2(broadcast_p, lambda t, _, sizes: [_reduce_sum(t, range(len(sizes)))])
batching.primitive_batchers[broadcast_p] = _broadcast_batch_rule
def _broadcast_in_dim_impl(operand, *, shape, broadcast_dimensions):
@@ -3416,7 +3416,7 @@ def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
concatenate_p = standard_primitive(
_concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
_concatenate_translation_rule)
-ad.deflinear(concatenate_p, _concatenate_transpose_rule)
+ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
@@ -3486,8 +3486,7 @@ def _pad_masking_rule(padded_vals, logical_shapes, padding_config):
pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad',
translation_rule=_pad_translation_rule)
-ad.deflinear(pad_p, _pad_transpose)
-ad.primitive_transposes[pad_p] = _pad_transpose
+ad.deflinear2(pad_p, _pad_transpose)
batching.primitive_batchers[pad_p] = _pad_batch_rule
masking.masking_rules[pad_p] = _pad_masking_rule
@@ -3677,7 +3676,7 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
return rev(operand, new_dimensions), bdim
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
-ad.deflinear(rev_p, lambda t, dimensions: [rev(t, dimensions)])
+ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule
@@ -3714,8 +3713,8 @@ def _transpose_masking_rule(padded_vals, logical_shapes, permutation):
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
'transpose')
transpose_p.def_impl(_transpose_impl)
-ad.deflinear(transpose_p,
-lambda t, permutation: [transpose(t, np.argsort(permutation))])
+ad.deflinear2(transpose_p,
+lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule
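
The real_p/imag_p/complex_p rules above now take (and ignore) the primal arguments; any reverse-mode pass through complex construction exercises them. A small hand-checked sketch, not from the commit:

    import jax
    from jax import lax

    def f(x):
      z = lax.complex(x, 2.0 * x)        # complex_p
      return lax.real(z) + lax.imag(z)   # real_p, imag_p

    jax.grad(f)(1.0)  # d/dx (x + 2x) = 3.0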


@@ -461,7 +461,7 @@ def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
else all_reduce(x) for x in args]
return xops.Tuple(c, outs)
-def _psum_transpose_rule(cts, axis_name, axis_index_groups):
+def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
axis_index_groups=axis_index_groups)
@@ -473,7 +473,7 @@ psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p) # type: ignore
-ad.deflinear(psum_p, _psum_transpose_rule)
+ad.deflinear2(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
batching.collective_rules[psum_p] = \
@@ -534,7 +534,7 @@ def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
full_perm.extend((grp[src], grp[dst]) for src, dst in perm)
return xops.CollectivePermute(x, full_perm)
-def _ppermute_transpose_rule(t, perm, axis_name):
+def _ppermute_transpose_rule(t, x, perm, axis_name):
srcs, dsts = unzip2(perm)
inverse_perm = list(zip(dsts, srcs))
return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]
@@ -551,7 +551,7 @@ def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
ppermute_p = core.Primitive('ppermute')
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
-ad.deflinear(ppermute_p, _ppermute_transpose_rule)
+ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
pxla.multi_host_supported_collectives.add(ppermute_p)
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
@@ -603,7 +603,7 @@ def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)), multiple_results=False)(c, x)
return x
-def _all_to_all_transpose_rule(cts, axis_name, split_axis, concat_axis, axis_index_groups):
+def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
return (all_to_all(
cts,
axis_name=axis_name,
@@ -658,7 +658,7 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_
all_to_all_p = core.Primitive('all_to_all')
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule
-ad.deflinear(all_to_all_p, _all_to_all_transpose_rule)
+ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
pxla.multi_host_supported_collectives.add(all_to_all_p)
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
batching.collective_rules[all_to_all_p] = _all_to_all_batched_collective
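
For the collectives, the transpose of psum is (up to axis groups) another psum over the cotangents, as _psum_transpose_rule above shows. A pmap-based usage sketch, not from the commit; it runs on however many devices are available, and each gradient entry comes out equal to the device count:

    import jax
    import jax.numpy as jnp

    def loss(x):
      return jax.lax.psum(x, axis_name='i')

    # Reverse-mode through a pmapped psum re-applies psum to the cotangents.
    grads = jax.pmap(jax.grad(loss), axis_name='i')(jnp.ones(jax.device_count()))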


@@ -510,8 +510,8 @@ def zero_jvp(primitive, primals, tangents, **params):
return r, Zero.from_value(r)
-deflinear(zeros_like_p, lambda t: [Zero.from_value(t)])
-deflinear(add_jaxvals_p, lambda t: (t, t))
+deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
+deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
def instantiate_zeros(tangent):
if type(tangent) is Zero:


@@ -398,8 +398,8 @@ def _sharding_constraint_translation_rule(c, x_node, partitions):
sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
sharding_constraint_p.def_abstract_eval(lambda x, partitions: x)
-ad.deflinear(sharding_constraint_p,
-lambda ct, partitions: (with_sharding_constraint(ct, partitions),))
+ad.deflinear2(sharding_constraint_p,
+lambda ct, _, partitions: (with_sharding_constraint(ct, partitions),))
xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule
def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):


@@ -1384,7 +1384,7 @@ device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
translations[device_put_p] = lambda c, x, device=None: x
-ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])
+ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)
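
device_put is the identity on values, so its rule simply forwards the cotangent while ignoring the new operand argument; differentiating through it leaves gradients unchanged. A tiny hand-checked sketch, not from the commit:

    import jax
    import jax.numpy as jnp

    # device_put's transpose passes the cotangent straight through,
    # so gradients are unaffected by the transfer.
    g = jax.grad(lambda x: jnp.sum(jax.device_put(x) * 3.0))(jnp.ones(2))  # [3., 3.]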