Rename Piles to Jumbles, to avoid unfortunate Imperial entanglements.

This commit is contained in:
Alexey Radul 2023-07-13 15:46:18 -04:00
parent f348366041
commit fbb587232c
3 changed files with 102 additions and 102 deletions

View File

@ -42,11 +42,11 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Piles
# Jumbles
# i:(Fin 3) => f32[[3, 1, 4].i]
@dataclasses.dataclass(frozen=True)
class PileTy:
class JumbleTy:
binder: core.Var
length: Union[int, Tracer, core.Var]
elt_ty: core.DShapedArray
@ -63,41 +63,41 @@ class IndexedAxisSize:
return f'{str(self.lengths)}.Var{id(self.idx)}'
replace = dataclasses.replace
# Pile(aval=a:3 => f32[[3 1 4].a],
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
# Jumble(aval=a:3 => f32[[3 1 4].a],
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
@dataclasses.dataclass(frozen=True)
class Pile:
aval: PileTy
class Jumble:
aval: JumbleTy
data: Array
# To vmap over a pile, one must specify the axis as PileAxis.
class PileAxis: pass
pile_axis = PileAxis()
# To vmap over a jumble, one must specify the axis as JumbleAxis.
class JumbleAxis: pass
jumble_axis = JumbleAxis()
# As a temporary measure before we have more general JITable / ADable interfaces
# (analogues to vmappable), to enable Piles to be used with other
# (analogues to vmappable), to enable Jumbles to be used with other
# transformations and higher-order primitives (primarily jit, though also grad
# with allow_int=True) we register them as pytrees.
# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
def _pile_flatten(pile):
def _jumble_flatten(jumble):
lengths = []
new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths))
if type(d) is IndexedAxisSize else d
for d in pile.aval.elt_ty.shape]
elt_ty = pile.aval.elt_ty.update(shape=tuple(new_shape))
aval = pile.aval.replace(elt_ty=elt_ty)
return (lengths, pile.data), aval
def _pile_unflatten(aval, x):
for d in jumble.aval.elt_ty.shape]
elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape))
aval = jumble.aval.replace(elt_ty=elt_ty)
return (lengths, jumble.data), aval
def _jumble_unflatten(aval, x):
lengths, data = x
new_shape = [d.replace(lengths=lengths[d.lengths - 1])
if type(d) is IndexedAxisSize else d
for d in aval.elt_ty.shape]
elt_ty = aval.elt_ty.update(shape=tuple(new_shape))
aval = aval.replace(elt_ty=elt_ty)
return Pile(aval, data)
register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
return Jumble(aval, data)
register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten)
def _pile_result(axis_size, stacked_axis, ragged_axes, x):
def _jumble_result(axis_size, stacked_axis, ragged_axes, x):
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
if stacked_axis != 0:
raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0
@ -106,14 +106,14 @@ def _pile_result(axis_size, stacked_axis, ragged_axes, x):
for ragged_axis, segment_lens in ragged_axes:
shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens)
elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
return Pile(PileTy(binder, axis_size, elt_ty), x)
return Jumble(JumbleTy(binder, axis_size, elt_ty), x)
@dataclasses.dataclass(frozen=True)
class RaggedAxis:
stacked_axis: int
# For each axis, we store its index and the corresponding segment lengths.
# For example, the pile i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# would be represented with ragged_axes = [(1, lens1), (3, lens2)]
ragged_axes: tuple[tuple[int, Array], ...]
@ -234,9 +234,9 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
handler = to_elt_handlers.get(type(x))
if handler:
return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
elif type(x) is Pile:
if spec is not pile_axis:
raise TypeError("pile input without using pile_axis in_axes spec")
elif type(x) is Jumble:
if spec is not jumble_axis:
raise TypeError("jumble input without using jumble_axis in_axes spec")
ias: IndexedAxisSize # Not present in the AxisSize union in core.py
(d, ias), = ((i, sz) # type: ignore
for i, sz in enumerate(x.aval.elt_ty.shape)
@ -259,10 +259,10 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
x_ = trace.full_raise(x)
val, bdim = x_.val, x_.batch_dim
if type(bdim) is RaggedAxis:
if spec is not pile_axis:
if spec is not jumble_axis:
# TODO(mattjj): improve this error message
raise TypeError("ragged output without using pile_axis out_axes spec")
return _pile_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
raise TypeError("ragged output without using jumble_axis out_axes spec")
return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
else:
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
from_elt_handlers: dict[type, FromEltHandler] = {}
@ -284,7 +284,7 @@ def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
from_elt_handlers[data_type] = from_elt
if make_iota: make_iota_handlers[axis_size_type] = make_iota
vmappables: dict[type, tuple[type, type]] = {}
spec_types: set[type] = {PileAxis}
spec_types: set[type] = {JumbleAxis}
def unregister_vmappable(data_type: type) -> None:
spec_type, axis_size_type = vmappables.pop(data_type)
@ -295,7 +295,7 @@ def unregister_vmappable(data_type: type) -> None:
del make_iota_handlers[axis_size_type]
def is_vmappable(x: Any) -> bool:
return type(x) is Pile or type(x) in vmappables
return type(x) is Jumble or type(x) in vmappables
@lu.transformation_with_aux
def flatten_fun_for_vmap(in_tree, *args_flat):
@ -1089,12 +1089,12 @@ def broadcast(x, sz, axis):
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
if dst == pile_axis:
if dst == jumble_axis:
x = bdim_at_front(x, src, sz)
elt_ty = x.aval.update(shape=x.shape[1:])
aval = PileTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
x.shape[0], elt_ty)
return Pile(aval, x)
aval = JumbleTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
x.shape[0], elt_ty)
return Jumble(aval, x)
try:
_ = core.get_aval(x)
except TypeError as e:

View File

@ -29,9 +29,9 @@ from jax._src.interpreters.batching import (
MakeIotaHandler as MakeIotaHandler,
MapSpec as MapSpec,
NotMapped as NotMapped,
Pile as Pile,
PileAxis as PileAxis,
PileTy as PileTy,
Jumble as Jumble,
JumbleAxis as JumbleAxis,
JumbleTy as JumbleTy,
ToEltHandler as ToEltHandler,
Vmappable as Vmappable,
Zero as Zero,
@ -60,7 +60,7 @@ from jax._src.interpreters.batching import (
matchaxis as matchaxis,
moveaxis as moveaxis,
not_mapped as not_mapped,
pile_axis as pile_axis,
jumble_axis as jumble_axis,
primitive_batchers as primitive_batchers,
reducer_batcher as reducer_batcher,
register_vmappable as register_vmappable,

View File

@ -1490,68 +1490,69 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow",
jax_traceback_filtering='off')
class PileTest(jtu.JaxTestCase):
class JumbleTest(jtu.JaxTestCase):
@parameterized.parameters((True,), (False,))
def test_internal_pile(self, disable_jit):
def test_internal_jumble(self, disable_jit):
config.update('jax_disable_jit', disable_jit)
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
xs = jax.vmap(lambda n: jax.lax.iota('int32', n).sum())(ins)
self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False)
def test_pile_escapes(self):
def test_jumble_escapes(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
xs = jax.vmap(jax.jit(lambda n: jax.lax.iota('int32', n)),
out_axes=batching.pile_axis)(ins)
self.assertIsInstance(xs, batching.Pile)
out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(xs, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(xs.data, data, check_dtypes=False)
def test_make_pile_from_dynamic_shape(self):
# We may not want to support returning piles from vmapped functions (instead
# preferring to have a separate API which allows piles). But for now it
# makes for a convenient way to construct piles for the other tests!
def test_make_jumble_from_dynamic_shape(self):
# We may not want to support returning jumbles from vmapped functions
# (instead preferring to have a separate API which allows jumbles). But for
# now it makes for a convenient way to construct jumbles for the other
# tests!
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data, check_dtypes=False)
def test_pile_map_eltwise(self):
def test_jumble_map_eltwise(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
p = pile_map(jax.jit(lambda x: x ** 2))(p)
self.assertIsInstance(p, batching.Pile)
out_axes=batching.jumble_axis)(ins)
p = jumble_map(jax.jit(lambda x: x ** 2))(p)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
self.assertAllClose(p.data, data, check_dtypes=False)
def test_pile_map_vector_dot(self):
def test_jumble_map_vector_dot(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
y = pile_map(jnp.dot)(p, p)
self.assertIsInstance(y, batching.Pile)
out_axes=batching.jumble_axis)(ins)
y = jumble_map(jnp.dot)(p, p)
self.assertIsInstance(y, batching.Jumble)
self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32'))
@parameterized.parameters((True,), (False,))
def test_pile_map_matrix_dot_ragged_contract(self, disable_jit):
def test_jumble_map_matrix_dot_ragged_contract(self, disable_jit):
config.update('jax_disable_jit', disable_jit)
sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.pile_axis
p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.jumble_axis
)(sizes)
p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.pile_axis
p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.jumble_axis
)(sizes)
y = jax.vmap(jnp.dot, in_axes=batching.pile_axis, out_axes=0,
y = jax.vmap(jnp.dot, in_axes=batching.jumble_axis, out_axes=0,
axis_size=3)(p1, p2)
self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)),
check_dtypes=False)
@parameterized.parameters((True,), (False,))
def test_pile_map_matrix_dot_ragged_tensor(self, disable_jit):
def test_jumble_map_matrix_dot_ragged_tensor(self, disable_jit):
config.update('jax_disable_jit', disable_jit)
sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def func(size):
@ -1559,8 +1560,8 @@ class PileTest(jtu.JaxTestCase):
lhs_two_d = jax.lax.broadcast_in_dim(lhs_one_d, (size, 2), (0,))
rhs = jax.lax.broadcasted_iota('int32', (2, 4), 0) + 1
return jnp.dot(lhs_two_d, rhs)
p = jax.vmap(func, out_axes=batching.pile_axis)(sizes)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(sizes)
self.assertIsInstance(p, batching.Jumble)
self.assertEqual(p.data.shape, (3, 5, 4))
def test_broadcast_in_dim_while_ragged(self):
@ -1569,8 +1570,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size, dtype='int32')
two_d = jax.lax.broadcast_in_dim(one_d, (size, 7), (0,))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
self.assertAllClose(p.data, data)
@ -1580,8 +1581,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(12, dtype='int32')
two_d = jax.lax.broadcast_in_dim(one_d, (size, 12), (1,))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 12), 2)
self.assertAllClose(p.data, data)
@ -1595,7 +1596,7 @@ class PileTest(jtu.JaxTestCase):
return two_d
msg = r"got operand of shape \(\[dynamic\],\), target broadcast shape \(4, 5\)"
with self.assertRaisesRegex(TypeError, msg):
jax.vmap(func, out_axes=batching.pile_axis)(ins)
jax.vmap(func, out_axes=batching.jumble_axis)(ins)
def test_broadcast_in_dim_to_doubly_ragged(self):
ins1 = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
@ -1604,8 +1605,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size1, dtype='int32')
two_d = jax.lax.broadcast_in_dim(one_d, (size1, size2), (0,))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins1, ins2)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins1, ins2)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1)
self.assertAllClose(p.data, data)
@ -1616,8 +1617,8 @@ class PileTest(jtu.JaxTestCase):
two_d = jax.lax.broadcast_in_dim(one_d, (size, 1), (0,))
one_again = jax.lax.squeeze(two_d, dimensions=[1])
return one_again
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data)
@ -1627,8 +1628,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size, dtype='int32')
two_d = jnp.broadcast_to(one_d, (4, size))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 4, 5), 2)
self.assertAllClose(p.data, data)
@ -1638,8 +1639,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size, dtype='int32')
two_d = jnp.broadcast_to(one_d, (size, size))
return two_d
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 5), 2)
self.assertAllClose(p.data, data)
@ -1649,8 +1650,8 @@ class PileTest(jtu.JaxTestCase):
one_d = jnp.arange(size, dtype='int32')
two_d = jnp.broadcast_to(one_d, (7, size))
return jnp.transpose(two_d, [1, 0])
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
self.assertAllClose(p.data, data)
@ -1662,8 +1663,8 @@ class PileTest(jtu.JaxTestCase):
wqkv = jax.lax.broadcasted_iota('int32', (3, 2, 7, 11), 1)
qkv = jnp.einsum('te,ihqe->ithq', x, wqkv)
return qkv
p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(x_sizes)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(x_sizes)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
self.assertEqual(p.data.shape, (3, 3, 5, 2, 7))
@ -1677,8 +1678,8 @@ class PileTest(jtu.JaxTestCase):
v = jax.lax.broadcast_in_dim(one_d, (ragged_size, 2, 7), [0])
inner = jnp.einsum('tsh,shq->thq', alpha, v)
return inner
p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(ragged_sizes)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(ragged_sizes)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
self.assertEqual(p.data.shape, (3, 5, 2, 7))
@ -1689,14 +1690,14 @@ class PileTest(jtu.JaxTestCase):
two_d = jnp.broadcast_to(one_d, (2, size))
part_1, part_2 = two_d
return part_1
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
p = jax.vmap(func, out_axes=batching.jumble_axis)(ins)
self.assertIsInstance(p, batching.Jumble)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data)
@parameterized.parameters((True,), (False,))
def test_pile_map_end_to_end_fprop_layer(self, disable_jit):
def test_jumble_map_end_to_end_fprop_layer(self, disable_jit):
config.update('jax_disable_jit', disable_jit)
def fprop_layer(params, x):
@ -1731,13 +1732,12 @@ class PileTest(jtu.JaxTestCase):
jnp.zeros((420, 128)),
]
def pile_stack(xs: list[jax.Array]) -> batching.Pile:
def jumble_stack(xs: list[jax.Array]) -> batching.Jumble:
max_length = max(len(x) for x in xs)
lengths = jnp.array([len(x) for x in xs])
lengths = jax.lax.convert_element_type(lengths, core.bint(max_length))
xs_padded = jnp.stack([jnp.zeros((max_length, 128), dtype=x.dtype
).at[:x.shape[0]].set(x) for x in xs])
# jax.vmap(lambda l, xp: xp[:l, :], out_axes=pile_axis)(lengths, xs_padded)
# binder = i
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
@ -1745,26 +1745,26 @@ class PileTest(jtu.JaxTestCase):
elt_ty = core.DShapedArray((batching.IndexedAxisSize(binder, lengths), 128),
xs_padded.dtype)
# aval = i:(Fin 3) => f32[[3, 1, 4].i, 128]
aval = batching.PileTy(binder, len(xs), elt_ty)
xs_pile = batching.Pile(aval, xs_padded)
return xs_pile
aval = batching.JumbleTy(binder, len(xs), elt_ty)
xs_jumble = batching.Jumble(aval, xs_padded)
return xs_jumble
xs_pile = pile_stack(xs)
xs_jumble = jumble_stack(xs)
fprop_batched = jax.vmap(fprop_layer,
in_axes=(None, batching.pile_axis),
out_axes=batching.pile_axis,
in_axes=(None, batching.jumble_axis),
out_axes=batching.jumble_axis,
axis_size=3)
result_jumble = fprop_batched(params, xs_jumble)
self.assertIsInstance(result_jumble, batching.Jumble)
regex = r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]'
self.assertRegex(str(result_jumble.aval), regex)
self.assertAllClose(result_jumble.data.shape, (3, 512, 128))
result_pile = fprop_batched(params, xs_pile)
self.assertIsInstance(result_pile, batching.Pile)
self.assertRegex(str(result_pile.aval), r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]')
self.assertAllClose(result_pile.data.shape, (3, 512, 128))
def pile_map(f):
def mapped(*piles):
return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,
axis_size=piles[0].aval.length)(*piles)
def jumble_map(f):
def mapped(*jumbles):
return jax.vmap(f, in_axes=batching.jumble_axis, out_axes=batching.jumble_axis,
axis_size=jumbles[0].aval.length)(*jumbles)
return mapped
if __name__ == '__main__':