mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18048 from superbobry:all-any-list-comp
PiperOrigin-RevId: 572384827
This commit is contained in:
commit
899cc30419
@ -785,7 +785,7 @@ def trace_to_subjaxpr_nounits(
|
||||
main: core.MainTrace,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals
|
||||
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
|
||||
out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
|
||||
main, instantiate, in_pvals)
|
||||
out_pvals = [t.pval for t in out_tracers]
|
||||
@ -820,7 +820,7 @@ def trace_to_subjaxpr_nounits_fwd(
|
||||
main: core.MainTrace,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals
|
||||
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
|
||||
out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
|
||||
main, instantiate, in_pvals)
|
||||
out_pvals = [t.pval for t in out_tracers]
|
||||
@ -2608,7 +2608,7 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
|
||||
@lu.transformation
|
||||
def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool],
|
||||
pvals: Sequence[PartialVal]):
|
||||
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
|
||||
assert all(isinstance(pv, PartialVal) for pv in pvals), pvals
|
||||
trace = main.with_cur_sublevel()
|
||||
in_tracers = map(trace.new_arg, pvals)
|
||||
ans = yield in_tracers, {}
|
||||
|
@ -1451,7 +1451,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
|
||||
cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits( # type: ignore
|
||||
cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)
|
||||
|
||||
if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
|
||||
if cond_uk[0] or all(not uk for uk in unknowns) or all(unknowns):
|
||||
# If conditional is unknown, or all inputs are known, or all are unknown,
|
||||
# just do the default processing.
|
||||
return trace.default_process_primitive(while_p, tracers, params)
|
||||
|
@ -4134,7 +4134,7 @@ def take_along_axis(
|
||||
lst[axis_int] = val
|
||||
return tuple(lst)
|
||||
|
||||
use_64bit_index = any([not core.is_constant_dim(d) or d >= (1 << 31) for d in a.shape])
|
||||
use_64bit_index = any(not core.is_constant_dim(d) or d >= (1 << 31) for d in a.shape)
|
||||
index_dtype = dtype(int64 if use_64bit_index else int32)
|
||||
indices = lax.convert_element_type(indices, index_dtype)
|
||||
|
||||
@ -4468,7 +4468,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
collapsed_slice_dims: Sequence[int] = []
|
||||
start_index_map: Sequence[int] = []
|
||||
|
||||
use_64bit_index = any([not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape])
|
||||
use_64bit_index = any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape)
|
||||
index_dtype = int64 if use_64bit_index else int32
|
||||
|
||||
# Gather indices.
|
||||
|
@ -430,7 +430,7 @@ def _call_tf_abstract_eval(
|
||||
return output_avals, effects
|
||||
|
||||
def is_fully_known_shape(s):
|
||||
return s.rank is not None and all([d is not None for d in s])
|
||||
return s.rank is not None and all(d is not None for d in s)
|
||||
|
||||
if all(is_fully_known_shape(s)
|
||||
for s in concrete_function_flat_tf.output_shapes):
|
||||
|
@ -272,8 +272,8 @@ def _conv_general_dilated(
|
||||
in_channels = lhs_shape[-1]
|
||||
*rhs_spatial_shapes, _, rhs_out_channel = rhs_shape
|
||||
|
||||
is_transpose = any([d != 1 for d in lhs_dilation])
|
||||
is_atrous = any([d != 1 for d in rhs_dilation])
|
||||
is_transpose = any(d != 1 for d in lhs_dilation)
|
||||
is_atrous = any(d != 1 for d in rhs_dilation)
|
||||
is_depthwise = in_channels == feature_group_count and feature_group_count > 1
|
||||
_validate_conv_features(is_transpose, is_atrous, is_depthwise,
|
||||
feature_group_count, batch_group_count,
|
||||
@ -416,7 +416,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
|
||||
squeeze_idxs.append(len(rhs.shape) - 1)
|
||||
result = tf.linalg.matmul(lhs, rhs)
|
||||
if len(squeeze_idxs) != 0:
|
||||
assert all([result.shape[i] == 1 for i in squeeze_idxs])
|
||||
assert all(result.shape[i] == 1 for i in squeeze_idxs)
|
||||
result = tf.squeeze(result, squeeze_idxs)
|
||||
return convert_result(result)
|
||||
|
||||
|
@ -62,7 +62,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
|
||||
jax_unimpl = [l for l in harness.jax_unimplemented
|
||||
if l.filter(device=jtu.device_under_test(),
|
||||
dtype=harness.dtype)]
|
||||
if any([lim.skip_run for lim in jax_unimpl]):
|
||||
if any(lim.skip_run for lim in jax_unimpl):
|
||||
logging.info(
|
||||
"Skipping run with expected JAX limitations: %s in harness %s",
|
||||
[u.description for u in jax_unimpl], harness.fullname)
|
||||
|
@ -229,11 +229,11 @@ class Harness:
|
||||
include_jax_unimpl: bool = False,
|
||||
one_containing: Optional[str] = None) -> bool:
|
||||
if not include_jax_unimpl:
|
||||
if any([
|
||||
if any(
|
||||
device_under_test in l.devices
|
||||
for l in self.jax_unimplemented
|
||||
if l.filter(device=device_under_test, dtype=self.dtype)
|
||||
]):
|
||||
):
|
||||
return False
|
||||
|
||||
if one_containing is not None and one_containing not in self.fullname:
|
||||
@ -248,32 +248,32 @@ def dtypes_to_str(dtype_list: Sequence[DType], empty_means_all=False) -> str:
|
||||
|
||||
names = {np.dtype(dt).name for dt in dtype_list}
|
||||
signed = {"int8", "int16", "int32", "int64"}
|
||||
if all([t in names for t in signed]):
|
||||
if signed <= names:
|
||||
names = (names - signed) | {"signed"}
|
||||
integers = {"uint8", "uint16", "uint32", "uint64"}
|
||||
if all([t in names for t in integers]):
|
||||
if integers <= names:
|
||||
names = (names - integers) | {"unsigned"}
|
||||
integer = {"signed", "unsigned"}
|
||||
if all([t in names for t in integer]):
|
||||
if integer <= names:
|
||||
names = (names - integer) | {"integer"}
|
||||
|
||||
floating = {"bfloat16", "float16", "float32", "float64"}
|
||||
if all([t in names for t in floating]):
|
||||
if floating <= names:
|
||||
names = (names - floating) | {"floating"}
|
||||
|
||||
complex = {"complex64", "complex128"}
|
||||
if all([t in names for t in complex]):
|
||||
if complex <= names:
|
||||
names = (names - complex) | {"complex"}
|
||||
|
||||
inexact = {"floating", "complex"}
|
||||
if all([t in names for t in inexact]):
|
||||
if inexact <= names:
|
||||
names = (names - inexact) | {"inexact"}
|
||||
|
||||
all_types = {"integer", "inexact", "bool"}
|
||||
if all([t in names for t in all_types]):
|
||||
if all_types <= names:
|
||||
names = (names - all_types) | {"all"}
|
||||
|
||||
return ", ".join(sorted(list(names)))
|
||||
return ", ".join(sorted(names))
|
||||
|
||||
|
||||
##### All harnesses in this file.
|
||||
|
@ -42,8 +42,7 @@ except ImportError:
|
||||
tf = None
|
||||
|
||||
|
||||
dlpack_dtypes = sorted(list(jax.dlpack.SUPPORTED_DTYPES),
|
||||
key=lambda x: x.__name__)
|
||||
dlpack_dtypes = sorted(jax.dlpack.SUPPORTED_DTYPES, key=lambda x: x.__name__)
|
||||
|
||||
numpy_dtypes = sorted(
|
||||
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
|
||||
|
@ -366,7 +366,7 @@ class CoreTest(jtu.JaxTestCase):
|
||||
b = newsym(core.ShapedArray((), np.dtype('int32')))
|
||||
c = newsym(core.ShapedArray((), np.dtype('int32')))
|
||||
for ordering in it.permutations([a, b, c]):
|
||||
assert sorted(list(ordering)) == [a, b, c]
|
||||
assert sorted(ordering) == [a, b, c]
|
||||
|
||||
def test_var_compared_by_identity(self):
|
||||
a1 = core.gensym()(core.ShapedArray((), np.dtype('int32')))
|
||||
|
@ -737,7 +737,7 @@ class SparseObjectTest(sptu.SparseTestCase):
|
||||
M = sparse.empty((2, 4), sparse_format=sparse_format)
|
||||
self.assertIsInstance(M, cls)
|
||||
buffers, tree = tree_util.tree_flatten(M)
|
||||
self.assertTrue(all([isinstance(buffer, jax.Array) for buffer in buffers]))
|
||||
self.assertTrue(all(isinstance(buffer, jax.Array) for buffer in buffers))
|
||||
M_out = tree_util.tree_unflatten(tree, buffers)
|
||||
self.assertEqual(M.dtype, M_out.dtype)
|
||||
self.assertEqual(M.shape, M_out.shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user