Prefer generator expressions over list comprehensions

This commit is contained in:
Lukas Geiger 2020-05-26 22:15:07 +01:00
parent e7cc47384f
commit bc99898e99
10 changed files with 11 additions and 12 deletions
tensorflow
lite/python
python
tools/dockerfiles
third_party/gpus

View File

@ -279,7 +279,7 @@ class QuantizationMode(object):
}) })
for node_def in self._graph_def.node: for node_def in self._graph_def.node:
if any([op in node_def.name for op in training_quant_ops]): if any(op in node_def.name for op in training_quant_ops):
return True return True
return False return False

View File

@ -726,7 +726,7 @@ class _DelayedRewriteGradientFunctions(object):
# pylint: enable=protected-access # pylint: enable=protected-access
capture_mapping = dict( capture_mapping = dict(
zip([ops.tensor_id(t) for t in self._func_graph.outputs], op.outputs)) zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
remapped_captures = [ remapped_captures = [
capture_mapping.get(ops.tensor_id(capture), capture) capture_mapping.get(ops.tensor_id(capture), capture)
for capture in backwards_function.captured_inputs for capture in backwards_function.captured_inputs

View File

@ -58,8 +58,7 @@ def _recursive_apply(tensors, apply_fn):
return tuple(tensors) return tuple(tensors)
return tensors_type(*tensors) # collections.namedtuple return tensors_type(*tensors) # collections.namedtuple
elif tensors_type is dict: elif tensors_type is dict:
return dict([(k, _recursive_apply(v, apply_fn)) for k, v in tensors.items() return dict((k, _recursive_apply(v, apply_fn)) for k, v in tensors.items())
])
else: else:
raise TypeError('_recursive_apply argument %r has invalid type %r' % raise TypeError('_recursive_apply argument %r has invalid type %r' %
(tensors, tensors_type)) (tensors, tensors_type))

View File

@ -486,7 +486,7 @@ def skip_if_error(test_obj, error_type, messages=None):
try: try:
yield yield
except error_type as e: except error_type as e:
if not messages or any([message in str(e) for message in messages]): if not messages or any(message in str(e) for message in messages):
test_obj.skipTest("Skipping error: {}".format(str(e))) test_obj.skipTest("Skipping error: {}".format(str(e)))
else: else:
raise raise

View File

@ -2341,7 +2341,7 @@ class CSVLogger(Callback):
if self.model.stop_training: if self.model.stop_training:
# We set NA so that csv parsers do not fail for this last epoch. # We set NA so that csv parsers do not fail for this last epoch.
logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys]) logs = dict((k, logs[k]) if k in logs else (k, 'NA') for k in self.keys)
if not self.writer: if not self.writer:

View File

@ -140,9 +140,9 @@ class CategoryCrossing(Layer):
def call(self, inputs): def call(self, inputs):
depth_tuple = self._depth_tuple if self.depth else (len(inputs),) depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
ragged_out = sparse_out = False ragged_out = sparse_out = False
if any([ragged_tensor.is_ragged(inp) for inp in inputs]): if any(ragged_tensor.is_ragged(inp) for inp in inputs):
ragged_out = True ragged_out = True
elif any([isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs]): elif any(isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs):
sparse_out = True sparse_out = True
outputs = [] outputs = []

View File

@ -168,7 +168,7 @@ class Hashing(Layer):
def _process_input_list(self, inputs): def _process_input_list(self, inputs):
# TODO(momernick): support ragged_cross_hashed with corrected fingerprint # TODO(momernick): support ragged_cross_hashed with corrected fingerprint
# and siphash. # and siphash.
if any([isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs]): if any(isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs):
raise ValueError('Hashing with ragged input is not supported yet.') raise ValueError('Hashing with ragged input is not supported yet.')
sparse_inputs = [ sparse_inputs = [
inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor) inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)

View File

@ -876,7 +876,7 @@ def _legacy_weights(layer):
non_trainable_weights. non_trainable_weights.
""" """
weights = layer.trainable_weights + layer.non_trainable_weights weights = layer.trainable_weights + layer.non_trainable_weights
if any([not isinstance(w, variables_module.Variable) for w in weights]): if any(not isinstance(w, variables_module.Variable) for w in weights):
raise NotImplementedError( raise NotImplementedError(
'Save or restore weights that is not an instance of `tf.Variable` is ' 'Save or restore weights that is not an instance of `tf.Variable` is '
'not supported in h5, use `save_format=\'tf\'` instead. Got a model ' 'not supported in h5, use `save_format=\'tf\'` instead. Got a model '

View File

@ -558,7 +558,7 @@ def main(argv):
# Only build images for host architecture # Only build images for host architecture
proc_arch = platform.processor() proc_arch = platform.processor()
is_x86 = proc_arch.startswith('x86') is_x86 = proc_arch.startswith('x86')
if (is_x86 and any([arch in tag for arch in ['ppc64le']]) or if (is_x86 and any(arch in tag for arch in ['ppc64le']) or
not is_x86 and proc_arch not in tag): not is_x86 and proc_arch not in tag):
continue continue

View File

@ -62,7 +62,7 @@ def check_cuda_lib(path, check_soname=True):
output = subprocess.check_output([objdump, "-p", path]).decode("utf-8") output = subprocess.check_output([objdump, "-p", path]).decode("utf-8")
output = [line for line in output.splitlines() if "SONAME" in line] output = [line for line in output.splitlines() if "SONAME" in line]
sonames = [line.strip().split(" ")[-1] for line in output] sonames = [line.strip().split(" ")[-1] for line in output]
if not any([soname == os.path.basename(path) for soname in sonames]): if not any(soname == os.path.basename(path) for soname in sonames):
raise ConfigError("None of the libraries match their SONAME: " + path) raise ConfigError("None of the libraries match their SONAME: " + path)