commit 899cc30419
jax authors, 2023-10-10 15:32:55 -07:00

Merge pull request #18048 from superbobry:all-any-list-comp

PiperOrigin-RevId: 572384827

10 changed files with 24 additions and 25 deletions

View File

@@ -785,7 +785,7 @@ def trace_to_subjaxpr_nounits(
     main: core.MainTrace,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):
-  assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals
+  assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
   out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
       main, instantiate, in_pvals)
   out_pvals = [t.pval for t in out_tracers]
@@ -820,7 +820,7 @@ def trace_to_subjaxpr_nounits_fwd(
     main: core.MainTrace,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):
-  assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals
+  assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
   out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
       main, instantiate, in_pvals)
   out_pvals = [t.pval for t in out_tracers]
@@ -2608,7 +2608,7 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
 @lu.transformation
 def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool],
                       pvals: Sequence[PartialVal]):
-  assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
+  assert all(isinstance(pv, PartialVal) for pv in pvals), pvals
   trace = main.with_cur_sublevel()
   in_tracers = map(trace.new_arg, pvals)
   ans = yield in_tracers, {}
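The change in all three hunks above swaps a list comprehension inside all() for a generator expression. A minimal standalone sketch of the difference (illustration only, not code from this diff): the generator form builds no intermediate list and lets all()/any() stop at the first deciding element.

# Sketch: all() short-circuits over a generator expression but not
# over a list comprehension, which is fully built first.
def checked(x, seen):
    seen.append(x)          # record which elements were inspected
    return isinstance(x, int)

values = [1, 2, "three", 4, 5]
seen_list, seen_gen = [], []

all([checked(v, seen_list) for v in values])   # list built eagerly
assert seen_list == values                     # every element touched

all(checked(v, seen_gen) for v in values)      # lazy generator
assert seen_gen == [1, 2, "three"]             # stops at first failure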

View File

@@ -1451,7 +1451,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
   cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits(  # type: ignore
       cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)
-  if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
+  if cond_uk[0] or all(not uk for uk in unknowns) or all(unknowns):
     # If conditional is unknown, or all inputs are known, or all are unknown,
     # just do the default processing.
     return trace.default_process_primitive(while_p, tracers, params)
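A note on the rewritten condition (illustration only): all(not uk for uk in unknowns) is the De Morgan dual of any(unknowns), so the early-exit branch covers exactly three cases: the conditional is unknown, every carry is known, or every carry is unknown.

# Sketch: the middle clause is equivalent to "no input is unknown",
# checked exhaustively over small boolean tuples.
import itertools

for unknowns in itertools.product([False, True], repeat=4):
    assert all(not uk for uk in unknowns) == (not any(unknowns))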

View File

@@ -4134,7 +4134,7 @@ def take_along_axis(
     lst[axis_int] = val
     return tuple(lst)
-  use_64bit_index = any([not core.is_constant_dim(d) or d >= (1 << 31) for d in a.shape])
+  use_64bit_index = any(not core.is_constant_dim(d) or d >= (1 << 31) for d in a.shape)
   index_dtype = dtype(int64 if use_64bit_index else int32)
   indices = lax.convert_element_type(indices, index_dtype)
@@ -4468,7 +4468,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
   collapsed_slice_dims: Sequence[int] = []
   start_index_map: Sequence[int] = []
-  use_64bit_index = any([not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape])
+  use_64bit_index = any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape)
   index_dtype = int64 if use_64bit_index else int32
   # Gather indices.
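The two hunks above apply the same rewrite to the 64-bit index check. A hypothetical standalone version of that predicate (the helper name and the simplified condition are illustrative, not JAX API; the real check also treats non-constant, symbolic dimensions as requiring 64-bit indices):

# Hypothetical helper, not part of JAX: any() stops scanning the shape
# at the first dimension that cannot be addressed by a signed int32.
def needs_64bit_index(shape):
    return any(d >= (1 << 31) for d in shape)

assert not needs_64bit_index((1024, 1024, 3))
assert needs_64bit_index((2**31, 2))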

View File

@@ -430,7 +430,7 @@ def _call_tf_abstract_eval(
     return output_avals, effects
   def is_fully_known_shape(s):
-    return s.rank is not None and all([d is not None for d in s])
+    return s.rank is not None and all(d is not None for d in s)
   if all(is_fully_known_shape(s)
          for s in concrete_function_flat_tf.output_shapes):

View File

@@ -272,8 +272,8 @@ def _conv_general_dilated(
   in_channels = lhs_shape[-1]
   *rhs_spatial_shapes, _, rhs_out_channel = rhs_shape
-  is_transpose = any([d != 1 for d in lhs_dilation])
-  is_atrous = any([d != 1 for d in rhs_dilation])
+  is_transpose = any(d != 1 for d in lhs_dilation)
+  is_atrous = any(d != 1 for d in rhs_dilation)
   is_depthwise = in_channels == feature_group_count and feature_group_count > 1
   _validate_conv_features(is_transpose, is_atrous, is_depthwise,
                           feature_group_count, batch_group_count,
@@ -416,7 +416,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
       squeeze_idxs.append(len(rhs.shape) - 1)
   result = tf.linalg.matmul(lhs, rhs)
   if len(squeeze_idxs) != 0:
-    assert all([result.shape[i] == 1 for i in squeeze_idxs])
+    assert all(result.shape[i] == 1 for i in squeeze_idxs)
     result = tf.squeeze(result, squeeze_idxs)
   return convert_result(result)

View File

@@ -62,7 +62,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
     jax_unimpl = [l for l in harness.jax_unimplemented
                   if l.filter(device=jtu.device_under_test(),
                               dtype=harness.dtype)]
-    if any([lim.skip_run for lim in jax_unimpl]):
+    if any(lim.skip_run for lim in jax_unimpl):
       logging.info(
           "Skipping run with expected JAX limitations: %s in harness %s",
           [u.description for u in jax_unimpl], harness.fullname)

View File

@@ -229,11 +229,11 @@ class Harness:
              include_jax_unimpl: bool = False,
              one_containing: Optional[str] = None) -> bool:
     if not include_jax_unimpl:
-      if any([
+      if any(
           device_under_test in l.devices
           for l in self.jax_unimplemented
           if l.filter(device=device_under_test, dtype=self.dtype)
-      ]):
+      ):
         return False
     if one_containing is not None and one_containing not in self.fullname:
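One detail worth noting in the hunk above: when a generator expression is the only argument to a call, it needs no surrounding brackets even across multiple lines, so only the two delimiter lines change. The trailing if clause filters items before the membership test, as in this small sketch (made-up data, not the Harness class):

# Sketch: a multi-line generator expression as the sole argument of
# any(), including a filtering "if" clause.
limitations = [
    {"devices": {"cpu", "gpu"}, "enabled": True},
    {"devices": {"tpu"}, "enabled": False},
]

def unimplemented_on(device):
    return any(
        device in l["devices"]
        for l in limitations
        if l["enabled"]        # filter, like l.filter(...) above
    )

assert unimplemented_on("gpu")
assert not unimplemented_on("tpu")   # its limitation is filtered out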
@@ -248,32 +248,32 @@ def dtypes_to_str(dtype_list: Sequence[DType], empty_means_all=False) -> str:
   names = {np.dtype(dt).name for dt in dtype_list}
   signed = {"int8", "int16", "int32", "int64"}
-  if all([t in names for t in signed]):
+  if signed <= names:
     names = (names - signed) | {"signed"}
   integers = {"uint8", "uint16", "uint32", "uint64"}
-  if all([t in names for t in integers]):
+  if integers <= names:
     names = (names - integers) | {"unsigned"}
   integer = {"signed", "unsigned"}
-  if all([t in names for t in integer]):
+  if integer <= names:
     names = (names - integer) | {"integer"}
   floating = {"bfloat16", "float16", "float32", "float64"}
-  if all([t in names for t in floating]):
+  if floating <= names:
     names = (names - floating) | {"floating"}
   complex = {"complex64", "complex128"}
-  if all([t in names for t in complex]):
+  if complex <= names:
     names = (names - complex) | {"complex"}
   inexact = {"floating", "complex"}
-  if all([t in names for t in inexact]):
+  if inexact <= names:
     names = (names - inexact) | {"inexact"}
   all_types = {"integer", "inexact", "bool"}
-  if all([t in names for t in all_types]):
+  if all_types <= names:
     names = (names - all_types) | {"all"}
-  return ", ".join(sorted(list(names)))
+  return ", ".join(sorted(names))
 ##### All harnesses in this file.
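The dtypes_to_str hunk uses a different simplification: over sets, all(t in names for t in signed) is exactly the subset test, so the <= operator (or issubset) states the intent directly. A small sketch, separate from the harness code:

# Sketch: three equivalent subset checks, plus the collapsing step
# that dtypes_to_str performs on a match.
names = {"int8", "int16", "int32", "int64", "float32"}
signed = {"int8", "int16", "int32", "int64"}

assert all(t in names for t in signed)
assert signed.issubset(names)
assert signed <= names               # the form used above

if signed <= names:
    names = (names - signed) | {"signed"}
assert names == {"signed", "float32"}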

View File

@@ -42,8 +42,7 @@ except ImportError:
   tf = None
-dlpack_dtypes = sorted(list(jax.dlpack.SUPPORTED_DTYPES),
-                       key=lambda x: x.__name__)
+dlpack_dtypes = sorted(jax.dlpack.SUPPORTED_DTYPES, key=lambda x: x.__name__)
 numpy_dtypes = sorted(
     [dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
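The dlpack change (and the sorted(ordering) change in the core test below) rely on sorted() accepting any iterable and always returning a new list, so the intermediate list() call adds nothing. A quick illustration with builtin types rather than jax.dlpack.SUPPORTED_DTYPES:

# Sketch: sorted() consumes any iterable and returns a fresh list,
# so sorted(list(x)) and sorted(x) are interchangeable.
dtypes = {float, int, bool, complex}

a = sorted(list(dtypes), key=lambda t: t.__name__)
b = sorted(dtypes, key=lambda t: t.__name__)
assert a == b == [bool, complex, float, int]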

View File

@@ -366,7 +366,7 @@ class CoreTest(jtu.JaxTestCase):
     b = newsym(core.ShapedArray((), np.dtype('int32')))
     c = newsym(core.ShapedArray((), np.dtype('int32')))
     for ordering in it.permutations([a, b, c]):
-      assert sorted(list(ordering)) == [a, b, c]
+      assert sorted(ordering) == [a, b, c]
   def test_var_compared_by_identity(self):
     a1 = core.gensym()(core.ShapedArray((), np.dtype('int32')))

View File

@@ -737,7 +737,7 @@ class SparseObjectTest(sptu.SparseTestCase):
     M = sparse.empty((2, 4), sparse_format=sparse_format)
     self.assertIsInstance(M, cls)
     buffers, tree = tree_util.tree_flatten(M)
-    self.assertTrue(all([isinstance(buffer, jax.Array) for buffer in buffers]))
+    self.assertTrue(all(isinstance(buffer, jax.Array) for buffer in buffers))
     M_out = tree_util.tree_unflatten(tree, buffers)
     self.assertEqual(M.dtype, M_out.dtype)
     self.assertEqual(M.shape, M_out.shape)
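None of these rewrites change behaviour on empty inputs, which is the main edge case to keep in mind when dropping the brackets (illustration only):

# Sketch: both spellings agree on empty iterables.
assert all([x for x in []]) is all(x for x in [])     # both True
assert any([x for x in []]) is any(x for x in [])     # both False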