Merge pull request #21504 from gnecula:poly_approx

PiperOrigin-RevId: 638550165
This commit is contained in:
jax authors 2024-05-30 00:22:24 -07:00
commit f72b0f0ca6
2 changed files with 40 additions and 5 deletions

View File

@ -229,11 +229,19 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
if not dtypes.issubdtype(operand.dtype, np.floating):
raise ValueError('operand must be a floating type')
reduction_input_size = dims[reduction_dimension]
dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
reduction_input_size_override)[0]
return (operand.update(
shape=dims, dtype=operand.dtype, weak_type=operand.weak_type),
if aggregate_to_topk:
dims[reduction_dimension] = k
elif core.is_constant_shape((reduction_input_size, k)):
dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
reduction_input_size_override)[0]
else:
raise NotImplementedError(
"approx_top_k with aggregate_to_topk=False not yet implemented when "
f"either the `k` ({k}) or the "
f" reduction dimension size ({reduction_input_size}) are symbolic")
return (operand.update(shape=dims, dtype=operand.dtype,
weak_type=operand.weak_type),
operand.update(shape=dims, dtype=np.dtype(np.int32)))

View File

@ -1997,6 +1997,32 @@ _POLY_SHAPE_TEST_HARNESSES = [
+ jnp.sin(x))),
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b, ..."]),
[ # approx_max_k
# x: f32[b, {n}, 32] with n being either 8 or the symbol "n"
# we reduce on dim=1, with size n
# k is either the constant 4 or the symbol "k"
PolyHarness("approx_max_k", f"n_{n}_k_{k}_agg={agg}",
lambda x, x_k, agg: lax.approx_max_k(
x, k=x_k.shape[0], reduction_dimension=1,
aggregate_to_topk=agg),
arg_descriptors=[RandArg((3, 8, 32), _f32),
RandArg((4,), _f32),
StaticArg(agg)],
polymorphic_shapes=[f"b, {n}, 32", f"{k},"],
# k must be at most the reduction dimension size
symbolic_constraints=[f"{k} <= {n}"],
expect_error=(
(NotImplementedError, "aggregate_to_topk=False") if (
not agg and (isinstance(k, str) or
isinstance(n, str))) else
# TODO(b/339398482) fix case with k symbolic
(TypeError, "get") if (agg and isinstance(k, str)) else
None
))
for n in [8, "n"]
for k in [4, "k"]
for agg in [True, False]
],
[ # arange
PolyHarness("arange", name,
f_jax,
@ -3311,5 +3337,6 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
for fname, _ in config_flags.items():
jax.config.update(fname, prev_jax_config_flags[fname])
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())