Prefer generator expressions over list comprehensions
This commit is contained in:
parent
e7cc47384f
commit
bc99898e99
|
@ -279,7 +279,7 @@ class QuantizationMode(object):
|
|||
})
|
||||
|
||||
for node_def in self._graph_def.node:
|
||||
if any([op in node_def.name for op in training_quant_ops]):
|
||||
if any(op in node_def.name for op in training_quant_ops):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
|
|
@ -726,7 +726,7 @@ class _DelayedRewriteGradientFunctions(object):
|
|||
# pylint: enable=protected-access
|
||||
|
||||
capture_mapping = dict(
|
||||
zip([ops.tensor_id(t) for t in self._func_graph.outputs], op.outputs))
|
||||
zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
|
||||
remapped_captures = [
|
||||
capture_mapping.get(ops.tensor_id(capture), capture)
|
||||
for capture in backwards_function.captured_inputs
|
||||
|
|
|
@ -58,8 +58,7 @@ def _recursive_apply(tensors, apply_fn):
|
|||
return tuple(tensors)
|
||||
return tensors_type(*tensors) # collections.namedtuple
|
||||
elif tensors_type is dict:
|
||||
return dict([(k, _recursive_apply(v, apply_fn)) for k, v in tensors.items()
|
||||
])
|
||||
return dict((k, _recursive_apply(v, apply_fn)) for k, v in tensors.items())
|
||||
else:
|
||||
raise TypeError('_recursive_apply argument %r has invalid type %r' %
|
||||
(tensors, tensors_type))
|
||||
|
|
|
@ -486,7 +486,7 @@ def skip_if_error(test_obj, error_type, messages=None):
|
|||
try:
|
||||
yield
|
||||
except error_type as e:
|
||||
if not messages or any([message in str(e) for message in messages]):
|
||||
if not messages or any(message in str(e) for message in messages):
|
||||
test_obj.skipTest("Skipping error: {}".format(str(e)))
|
||||
else:
|
||||
raise
|
||||
|
|
|
@ -2341,7 +2341,7 @@ class CSVLogger(Callback):
|
|||
|
||||
if self.model.stop_training:
|
||||
# We set NA so that csv parsers do not fail for this last epoch.
|
||||
logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])
|
||||
logs = dict((k, logs[k]) if k in logs else (k, 'NA') for k in self.keys)
|
||||
|
||||
if not self.writer:
|
||||
|
||||
|
|
|
@ -140,9 +140,9 @@ class CategoryCrossing(Layer):
|
|||
def call(self, inputs):
|
||||
depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
|
||||
ragged_out = sparse_out = False
|
||||
if any([ragged_tensor.is_ragged(inp) for inp in inputs]):
|
||||
if any(ragged_tensor.is_ragged(inp) for inp in inputs):
|
||||
ragged_out = True
|
||||
elif any([isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs]):
|
||||
elif any(isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs):
|
||||
sparse_out = True
|
||||
|
||||
outputs = []
|
||||
|
|
|
@ -168,7 +168,7 @@ class Hashing(Layer):
|
|||
def _process_input_list(self, inputs):
|
||||
# TODO(momernick): support ragged_cross_hashed with corrected fingerprint
|
||||
# and siphash.
|
||||
if any([isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs]):
|
||||
if any(isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs):
|
||||
raise ValueError('Hashing with ragged input is not supported yet.')
|
||||
sparse_inputs = [
|
||||
inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)
|
||||
|
|
|
@ -876,7 +876,7 @@ def _legacy_weights(layer):
|
|||
non_trainable_weights.
|
||||
"""
|
||||
weights = layer.trainable_weights + layer.non_trainable_weights
|
||||
if any([not isinstance(w, variables_module.Variable) for w in weights]):
|
||||
if any(not isinstance(w, variables_module.Variable) for w in weights):
|
||||
raise NotImplementedError(
|
||||
'Save or restore weights that is not an instance of `tf.Variable` is '
|
||||
'not supported in h5, use `save_format=\'tf\'` instead. Got a model '
|
||||
|
|
|
@ -558,7 +558,7 @@ def main(argv):
|
|||
# Only build images for host architecture
|
||||
proc_arch = platform.processor()
|
||||
is_x86 = proc_arch.startswith('x86')
|
||||
if (is_x86 and any([arch in tag for arch in ['ppc64le']]) or
|
||||
if (is_x86 and any(arch in tag for arch in ['ppc64le']) or
|
||||
not is_x86 and proc_arch not in tag):
|
||||
continue
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ def check_cuda_lib(path, check_soname=True):
|
|||
output = subprocess.check_output([objdump, "-p", path]).decode("utf-8")
|
||||
output = [line for line in output.splitlines() if "SONAME" in line]
|
||||
sonames = [line.strip().split(" ")[-1] for line in output]
|
||||
if not any([soname == os.path.basename(path) for soname in sonames]):
|
||||
if not any(soname == os.path.basename(path) for soname in sonames):
|
||||
raise ConfigError("None of the libraries match their SONAME: " + path)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue