Replace list comprehension with generator expressions.
PiperOrigin-RevId: 285822581
Change-Id: I679256cc6f5890fa93ff3a2bfb9136b5d679d3ac
parent 708e5729bb
commit a78fa541d8
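The change is mechanical throughout: wherever a comprehension's list is built
only to be consumed once (by any, all, max, set, str.join, or extend), the
brackets are dropped so Python evaluates a generator expression instead of
materializing an intermediate list. A minimal sketch of the difference,
illustrative only and not part of the diff below:

    def is_positive(x):
        return x > 0

    xs = range(10**6)
    # Generator expression: any() consumes lazily and short-circuits,
    # so is_positive runs only until the first True (here, x == 1).
    assert any(is_positive(x) for x in xs)
    # List comprehension: the full million-element list is built first,
    # and is_positive runs for every element before any() looks at it.
    assert any([is_positive(x) for x in xs])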
@@ -586,7 +586,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         trt_op_names.append(node.name)
     # Remove the function name prefix.
     def _Canonicalize(names):
-      return set([self._ToString(name.split("/")[-1]) for name in names])
+      return set(self._ToString(name.split("/")[-1]) for name in names)
 
     all_op_names = _Canonicalize(all_op_names)
     trt_op_names = _Canonicalize(trt_op_names)
@@ -253,8 +253,8 @@ def _check_trt_version_compatibility():
   if loaded_version < linked_version:
     tf_logging.error(
         "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. " %
-        (".".join([str(x) for x in loaded_version]),
-         ".".join([str(x) for x in linked_version])) +
+        (".".join(str(x) for x in loaded_version), ".".join(
+            str(x) for x in linked_version)) +
         "TensorRT does not support forward compatibility. " +
         "It is also required to use the same major version of TensorRT " +
         "during compilation and runtime.")
@@ -262,16 +262,16 @@ def _check_trt_version_compatibility():
   if loaded_version[0] > linked_version[0]:
     tf_logging.error(
         "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. " %
-        (".".join([str(x) for x in loaded_version]),
-         ".".join([str(x) for x in linked_version])) +
+        (".".join(str(x) for x in loaded_version), ".".join(
+            str(x) for x in linked_version)) +
        "It is required to use the same major version " +
        "of TensorRT during compilation and runtime.")
     raise RuntimeError("Incompatible TensorRT major version")
   if loaded_version != linked_version:
     tf_logging.info(
         "Loaded TensorRT %s and linked TensorFlow against TensorRT %s. " %
-        (".".join([str(x) for x in loaded_version]),
-         ".".join([str(x) for x in linked_version])) +
+        (".".join(str(x) for x in loaded_version), ".".join(
+            str(x) for x in linked_version)) +
        "This is supported because TensorRT " +
        " minor/patch upgrades are backward compatible")
 
@@ -54,8 +54,8 @@ class CsvDatasetBenchmark(test.Benchmark):
       with open(fn, 'wb') as f:
         # Just write 100 rows and use `repeat`... Assumes the cost
        # of creating an iterator is not significant
-        row = ','.join([str_val for _ in range(n)])
-        f.write('\n'.join([row for _ in range(100)]))
+        row = ','.join(str_val for _ in range(n))
+        f.write('\n'.join(row for _ in range(100)))
       self._filenames.append(fn)
 
   def _tear_down(self):
@@ -221,7 +221,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
           " To enable rewrites, use resource variables instead by "
           "calling `tf.enable_resource_variables()` at the start of the "
           "program." % (", ".join(options._graph_rewrites())))
-      self.assertTrue(any([expected in str(warning) for warning in w]))
+      self.assertTrue(any(expected in str(warning) for warning in w))
 
     # Check that outputs are the same in the optimized and unoptimized cases,
     # when the variable value is changing.
@@ -120,7 +120,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
     graph = graph_pb2.GraphDef().FromString(
         self.evaluate(dataset._as_serialized_graph()))
     self.assertTrue(
-        any([node.op != "MaxIntraOpParallelismDataset" for node in graph.node]))
+        any(node.op != "MaxIntraOpParallelismDataset" for node in graph.node))
 
 
 if __name__ == "__main__":
@@ -52,7 +52,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     dataset = dataset_ops.Dataset.range(10)
     graph = graph_pb2.GraphDef().FromString(
         self.evaluate(dataset._as_serialized_graph()))
-    self.assertTrue(any([node.op == "RangeDataset" for node in graph.node]))
+    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))
 
   def testAsSerializedGraphStateful(self):
     dataset = dataset_ops.Dataset.range(10).map(
@@ -103,13 +103,13 @@ class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     # Basic test: read from both files.
     expected_output = [self._lineText(0, i) for i in range(5)]
-    expected_output.extend([self._lineText(1, i) for i in range(5)])
+    expected_output.extend(self._lineText(1, i) for i in range(5))
     self.assertDatasetProduces(
         dataset_fn(test_filenames, 1), expected_output=expected_output)
 
     # Test repeated iteration through both files.
     expected_output = [self._lineText(0, i) for i in range(5)]
-    expected_output.extend([self._lineText(1, i) for i in range(5)])
+    expected_output.extend(self._lineText(1, i) for i in range(5))
     self.assertDatasetProduces(
         dataset_fn(test_filenames, 10), expected_output=expected_output * 10)
 
@@ -125,7 +125,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     files = dataset_ops.Dataset.from_tensor_slices(test_filenames).repeat(10)
     expected_output = []
     for j in range(10):
-      expected_output.extend([self._lineText(j, i) for i in range(10)])
+      expected_output.extend(self._lineText(j, i) for i in range(10))
     dataset = readers.TextLineDataset(files, num_parallel_reads=4)
     self.assertDatasetProduces(
         dataset, expected_output=expected_output * 10, assert_items_equal=True)
@@ -311,10 +311,10 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
       # If the captured tensor is an eager tensor, we cannot trace its inputs.
       if isinstance(tensor, ops._EagerTensorBase):  # pylint: disable=protected-access
         return False
-      return any([is_tensor_or_parent_ref(x) for x in tensor.op.inputs])
+      return any(is_tensor_or_parent_ref(x) for x in tensor.op.inputs)
 
     for fn in self._functions():
-      if any([is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs]):
+      if any(is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs):
         return True
 
     return any(
@@ -53,7 +53,7 @@ class TraverseTest(test.TestCase):
     variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
     self.assertSetEqual(
         set(["MapDataset", "RangeDataset"]),
-        set([x.name for x in variant_tensor_ops]))
+        set(x.name for x in variant_tensor_ops))
 
   @test_util.run_deprecated_v1
   def testConcat(self):
@@ -63,7 +63,7 @@ class TraverseTest(test.TestCase):
     variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
     self.assertSetEqual(
         set(["ConcatenateDataset", "RangeDataset", "RangeDataset_1"]),
-        set([x.name for x in variant_tensor_ops]))
+        set(x.name for x in variant_tensor_ops))
 
   @test_util.run_deprecated_v1
   def testZip(self):
@@ -73,7 +73,7 @@ class TraverseTest(test.TestCase):
     variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
     self.assertSetEqual(
         set(["ZipDataset", "RangeDataset", "RangeDataset_1"]),
-        set([x.name for x in variant_tensor_ops]))
+        set(x.name for x in variant_tensor_ops))
 
   @test_util.run_deprecated_v1
   def testMultipleVariantTensors(self):
@@ -82,7 +82,7 @@ class TraverseTest(test.TestCase):
     variant_tensor_ops = traverse.obtain_all_variant_tensor_ops(ds)
     self.assertSetEqual(
         set(["RangeDataset", "ModelDataset", "PrefetchDataset"]),
-        set([x.name for x in variant_tensor_ops]))
+        set(x.name for x in variant_tensor_ops))
 
   @test_util.run_deprecated_v1
   def testFlatMap(self):
@@ -102,7 +102,7 @@ class TraverseTest(test.TestCase):
         set([
             "FlatMapDataset", "PrefetchDataset", "RepeatDataset",
             "RangeDataset", "RangeDataset_1"
-        ]), set([x.name for x in variant_tensor_ops]))
+        ]), set(x.name for x in variant_tensor_ops))
 
 
 if __name__ == "__main__":
@@ -1200,12 +1200,12 @@ class DebugAnalyzer(object):
       return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
 
     path_column_width = max(
-        max([len(item[0]) for item in source_list]), len(path_head)) + 1
+        max(len(item[0]) for item in source_list), len(path_head)) + 1
     num_nodes_column_width = max(
-        max([len(str(item[2])) for item in source_list]),
+        max(len(str(item[2])) for item in source_list),
         len(num_nodes_head)) + 1
     num_tensors_column_width = max(
-        max([len(str(item[3])) for item in source_list]),
+        max(len(str(item[3])) for item in source_list),
         len(num_tensors_head)) + 1
 
     head = RL(path_head + " " * (path_column_width - len(path_head)), color)
@@ -42,7 +42,7 @@ def string_to_codes(cmd):
 
 def codes_to_string(cmd_code):
   # Omit non-ASCII key codes.
-  return "".join([chr(code) for code in cmd_code if code < 256])
+  return "".join(chr(code) for code in cmd_code if code < 256)
 
 
 class MockCursesUI(curses_ui.CursesUI):
@@ -624,7 +624,7 @@ class ProfileAnalyzer(object):
       device_stats = self._run_metadata.step_stats.dev_stats[index]
       if device_name_regex and not device_name_regex.match(device_stats.device):
         continue
-      profile_data.extend([datum for datum in data_generator(device_stats)])
+      profile_data.extend(data_generator(device_stats))
 
     source_annotation = source_utils.annotate_source_against_profile(
         profile_data,
@@ -88,7 +88,7 @@ class SessionDebugGrapplerInteractionTest(test_util.TensorFlowTestCase):
         self._dump_root, partition_graphs=run_metadata.partition_graphs,
         validate=True)
 
-    original_node_names = set([op.name for op in sess.graph.get_operations()])
+    original_node_names = set(op.name for op in sess.graph.get_operations())
     dumped_node_names = set(dump_data.nodes())
     grappler_created_node_names = dumped_node_names - original_node_names
     grappler_removed_node_names = original_node_names - dumped_node_names
@@ -1231,7 +1231,7 @@ def choose_the_best(devices, session_config=None):
   Returns:
     A subclass of `CrossDeviceOps`.
   """
-  requested_devices = set([device_util.canonicalize(d) for d in devices])
+  requested_devices = set(device_util.canonicalize(d) for d in devices)
   if ops.executing_eagerly_outside_functions():
     logical_gpus = context.context().list_logical_devices(device_type="GPU")
     physical_gpus = context.context().list_physical_devices(device_type="GPU")
@@ -722,7 +722,7 @@ def is_indexed_slices(value):
   if isinstance(value, ops.IndexedSlices):
     return True
   assert isinstance(value, value_lib.DistributedValues)
-  return all([isinstance(v, ops.IndexedSlices) for v in value.values])
+  return all(isinstance(v, ops.IndexedSlices) for v in value.values)
 
 
 def split_by_sparsity(values):
@@ -215,7 +215,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           for replica in range(1, num_parameter_devices)
       ]
       variables = list(variables) + extended_variables
-      return set([v + ":0" for v in variables])
+      return set(v + ":0" for v in variables)
 
     self.assertEqual(
         get_expected_variables(len(distribution.extended.parameter_devices)),
@@ -2251,10 +2251,9 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
   """Convert inputs to pass into a function with an explicit signature."""
 
   def format_error_message(inputs, input_signature):
-    return ("  inputs: (\n" + "    " +
-            ",\n    ".join([str(i) for i in inputs]) + ")\n" +
-            "  input_signature: (\n" + "    " +
-            ",\n    ".join([str(i) for i in input_signature]) + ")")
+    return ("  inputs: (\n" + "    " + ",\n    ".join(str(i) for i in inputs) +
+            ")\n" + "  input_signature: (\n" + "    " +
+            ",\n    ".join(str(i) for i in input_signature) + ")")
 
   try:
     # TODO(b/124370185): Use all elements as inputs to throw an error if there
@@ -246,7 +246,7 @@ def lift_to_graph(tensors,
 
   # Check that the initializer does not depend on any placeholders.
   sources = object_identity.ObjectIdentitySet(sources or [])
-  visited_ops = set([x.op for x in sources])
+  visited_ops = set(x.op for x in sources)
   op_outputs = collections.defaultdict(set)
 
   # First we extract the subgraph between init_tensors and sources.
@@ -94,13 +94,13 @@ def convert_structure_to_signature(structure, arg_names=None):
         # of the function argument.
         name = user_specified_name
       else:
-        name = "/".join([str(p) for p in path])
+        name = "/".join(str(p) for p in path)
       return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
     if isinstance(arg, composite_tensor.CompositeTensor):
       # TODO(b/133606651) Do we need to inject arg_name?
       return arg._type_spec  # pylint: disable=protected-access
     if isinstance(arg, resource_variable_ops.BaseResourceVariable):
-      name = "/".join([str(p) for p in path])
+      name = "/".join(str(p) for p in path)
       return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name)
     if isinstance(arg, (
         int,
@@ -1296,7 +1296,7 @@ def get_extra_args():
 def _type_list_to_str(types):
   if any(_ not in _DTYPE_TO_STR for _ in types):
     raise ValueError("Unsupported dtypes: %s" % types)
-  return "".join([_DTYPE_TO_STR[_] for _ in types])
+  return "".join(_DTYPE_TO_STR[_] for _ in types)
 
 
 # NOTE: The list needs to be extended when more data types are added.
@@ -170,7 +170,7 @@ def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
   else:
     func.signature.output_arg.extend(
         [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
-  func_arg_placeholders = set([i.name for i in inputs])
+  func_arg_placeholders = set(i.name for i in inputs)
   input_dict = _create_input_dict(graph, func_arg_placeholders,
                                   initial_value=initial_dict)
 
@@ -297,9 +297,10 @@ def _find_extraneous_saver_nodes(graph_def, saver_def):
   # but it seems unnecessarily complex given the name scope solution.
 
   # load the graph DAG in minimal form, without initializing a full Graph object
-  nodes = {node_def.name:
-           (set([_op_name(x) for x in node_def.input]), node_def.op)
-           for node_def in graph_def.node}
+  nodes = {
+      node_def.name: (set(_op_name(x) for x in node_def.input), node_def.op)
+      for node_def in graph_def.node
+  }
 
   retain_scope_save = None
   retain_scope_restore = None
@@ -313,12 +314,12 @@ def _find_extraneous_saver_nodes(graph_def, saver_def):
     retain_scope_restore = _get_scope(restore_op_name) + "/"
     retain_scope_save = _get_scope(save_op_name) + "/"
 
-  all_saver_node_names = set([name for name, (_, op) in nodes.items()
-                              if op in SAVE_AND_RESTORE_OPS])
+  all_saver_node_names = set(
+      name for name, (_, op) in nodes.items() if op in SAVE_AND_RESTORE_OPS)
 
-  all_saver_scopes = (set([_get_scope(x) for x in all_saver_node_names])
-                      - all_saver_node_names)
-  all_saver_scopes = set([x + "/" for x in all_saver_scopes])
+  all_saver_scopes = (
+      set(_get_scope(x) for x in all_saver_node_names) - all_saver_node_names)
+  all_saver_scopes = set(x + "/" for x in all_saver_scopes)
 
   extraneous_scopes = all_saver_scopes - set([retain_scope_save,
                                               retain_scope_restore])
@@ -766,9 +767,10 @@ def import_scoped_meta_graph_with_return_elements(
           sorted([compat.as_str(v) for v in field.value]) !=
          sorted(input_map)):
         raise ValueError("Graph contains unbound inputs: %s. Must "
-                         "provide these inputs through input_map." %
-                         ",".join([compat.as_str(v) for v in field.value
-                                   if not input_map or v not in input_map]))
+                         "provide these inputs through input_map." % ",".join(
+                             compat.as_str(v)
+                             for v in field.value
+                             if not input_map or v not in input_map))
       break
 
   # Sets graph to default graph if it's not passed in.
@@ -3283,7 +3283,7 @@ class Graph(object):
 
     node_def = _NodeDef(op_type, name, attrs)
 
-    input_ops = set([t.op for t in inputs])
+    input_ops = set(t.op for t in inputs)
     control_inputs = self._control_dependencies_for_inputs(input_ops)
     # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
     # Session.run call cannot occur between creating and mutating the op.
@@ -4442,7 +4442,7 @@ class Graph(object):
       # Don't add a control input if we already have a data dependency on i.
       # NOTE(mrry): We do not currently track transitive data dependencies,
      # so we may add redundant control inputs.
-      ret.extend([c for c in controller.control_inputs if c not in input_ops])
+      ret.extend(c for c in controller.control_inputs if c not in input_ops)
     return ret
 
   def _record_op_seen_by_control_dependencies(self, op):
@@ -2409,7 +2409,7 @@ class TensorFlowTestCase(googletest.TestCase):
                            path=None,
                            msg=None):
     path = path or []
-    path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "")
+    path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
     msg = msg if msg else ""
 
     # Check if a and/or b are namedtuples.
@@ -90,7 +90,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
 
     self.assertEqual(len(graph.node), graph_size + 2)
     self.assertTrue(
-        set([node.name for node in graph.node]) > set(
+        set(node.name for node in graph.node) > set(
            ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))
     for node in graph.node:
       if node.name == 'swap_in_d_0':
@@ -410,7 +410,7 @@ def block3(x,
   output_shape = x_shape + (groups,
                             c) if backend.backend() == 'theano' else None
   x = layers.Lambda(
-      lambda x: sum([x[:, :, :, :, i] for i in range(c)]),
+      lambda x: sum(x[:, :, :, :, i] for i in range(c)),
       output_shape=output_shape,
       name=name + '_2_reduce')(
           x)
@@ -1035,7 +1035,7 @@ def placeholder(shape=None,
     dtype = floatx()
   if not shape:
     if ndim:
-      shape = tuple([None for _ in range(ndim)])
+      shape = (None,) * ndim
   with get_graph().as_default():
     if sparse:
       x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
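Note: the placeholder hunk above is the one change in this commit that goes
beyond dropping brackets: tuple([None for _ in range(ndim)]) becomes
(None,) * ndim, which builds the tuple directly by sequence repetition
instead of feeding a comprehension through tuple(). An equivalent check,
illustrative only and not part of the diff:

    ndim = 3
    assert tuple([None for _ in range(ndim)]) == (None,) * ndim == (None, None, None)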
@@ -1768,7 +1768,7 @@ def batch_dot(x, y, axes=None):
     else:
       axes = [x_ndim - 1, y_ndim - 2]
 
-  if py_any([isinstance(a, (list, tuple)) for a in axes]):
+  if py_any(isinstance(a, (list, tuple)) for a in axes):
     raise ValueError('Multiple target dimensions are not supported. ' +
                      'Expected: None, int, (int, int), ' +
                      'Provided: ' + str(axes))
@@ -4580,7 +4580,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
     target = flatten(target)
     output = array_ops.reshape(output, [-1, output_shape[-1]])
 
-  if py_any([_is_symbolic_tensor(v) for v in [target, output]]):
+  if py_any(_is_symbolic_tensor(v) for v in [target, output]):
     with get_graph().as_default():
       res = nn.sparse_softmax_cross_entropy_with_logits_v2(
           labels=target, logits=output)
@@ -5424,9 +5424,9 @@ def local_conv(inputs,
     if data_format == 'channels_first':
       slices.append(slice(None))
 
-    slices.extend([slice(position[d] * strides[d],
-                         position[d] * strides[d] + kernel_size[d])
-                   for d in spatial_dimensions])
+    slices.extend(
+        slice(position[d] * strides[d], position[d] * strides[d] +
+              kernel_size[d]) for d in spatial_dimensions)
 
     if data_format == 'channels_last':
       slices.append(slice(None))
@@ -113,7 +113,7 @@ def load_data(path='imdb.npz',
                      str(maxlen) + ', no sequence was kept. '
                      'Increase maxlen.')
   if not num_words:
-    num_words = max([max(x) for x in xs])
+    num_words = max(max(x) for x in xs)
 
   # by convention, use 2 as OOV word
   # reserve 'index_from' (=3 by default) characters:
@@ -99,7 +99,7 @@ def load_data(path='reuters.npz',
     xs, labels = _remove_long_seq(maxlen, xs, labels)
 
   if not num_words:
-    num_words = max([max(x) for x in xs])
+    num_words = max(max(x) for x in xs)
 
   # by convention, use 2 as OOV word
   # reserve 'index_from' (=3 by default) characters:
@@ -266,7 +266,7 @@ class TensorLikeDataAdapter(DataAdapter):
       msg = "Data cardinality is ambiguous:\n"
       for label, data in zip(["x", "y", "sample_weight"], inputs):
         msg += "  {} sizes: {}\n".format(
-            label, ", ".join([str(i.shape[0]) for i in nest.flatten(data)]))
+            label, ", ".join(str(i.shape[0]) for i in nest.flatten(data)))
       msg += "Please provide data which shares the same first dimension."
       raise ValueError(msg)
     num_samples = num_samples.pop()
@@ -1319,7 +1319,7 @@ class Model(network.Network, version_utils.VersionSelector):
     """
     if not self._is_compiled:
       return
-    if sample_weights and any([s is not None for s in sample_weights]):
+    if sample_weights and any(s is not None for s in sample_weights):
       for endpoint in self._training_endpoints:
         endpoint.sample_weight_mode = (
            endpoint.sample_weight_mode or 'samplewise')
@@ -1330,8 +1330,8 @@ class Model(network.Network, version_utils.VersionSelector):
   def _recompile_weights_loss_and_weighted_metrics(self):
     if not self._is_compiled:
       return False
-    recompile = any([e.sample_weights_mismatch()
-                     for e in self._training_endpoints])
+    recompile = any(
+        e.sample_weights_mismatch() for e in self._training_endpoints)
 
     if recompile:
       self._compile_weights_loss_and_weighted_metrics()
@@ -1492,7 +1492,7 @@ class Model(training_lib.Model):
     """
     if not self._is_compiled:
       return
-    if sample_weights and any([s is not None for s in sample_weights]):
+    if sample_weights and any(s is not None for s in sample_weights):
       for endpoint in self._training_endpoints:
         endpoint.sample_weight_mode = (
            endpoint.sample_weight_mode or 'samplewise')
@@ -592,8 +592,7 @@ class Flatten(Layer):
     if (self.data_format == 'channels_first'
         and K.ndim(inputs) is not None and K.ndim(inputs) > 1):
       permutation = [0]
-      permutation.extend([i for i in
-                          range(2, K.ndim(inputs))])
+      permutation.extend(range(2, K.ndim(inputs)))
       permutation.append(1)
       inputs = array_ops.transpose(inputs, perm=permutation)
 
@@ -858,7 +857,7 @@ class Lambda(Layer):
     untracked_new_vars = [v for v in created_variables
                           if v.experimental_ref() not in tracked_weights]
     if untracked_new_vars:
-      variable_str = '\n'.join(['  {}'.format(i) for i in untracked_new_vars])
+      variable_str = '\n'.join('  {}'.format(i) for i in untracked_new_vars)
       error_str = textwrap.dedent(
          '''
          The following Variables were created within a Lambda layer ({name})
@@ -875,7 +874,7 @@ class Lambda(Layer):
     untracked_used_vars = [v for v in accessed_variables
                           if v.experimental_ref() not in tracked_weights]
     if untracked_used_vars and not self._already_warned:
-      variable_str = '\n'.join(['  {}'.format(i) for i in untracked_used_vars])
+      variable_str = '\n'.join('  {}'.format(i) for i in untracked_used_vars)
       self._warn(textwrap.dedent(
          '''
          The following Variables were used a Lambda layer's call ({name}), but
@@ -387,7 +387,7 @@ class Concatenate(_Merge):
                        'except for the concat axis. Got inputs shapes: %s' %
                        input_shape)
     # Make sure all the shapes have same ranks.
-    ranks = set([len(shape) for shape in shape_set])
+    ranks = set(len(shape) for shape in shape_set)
     if len(ranks) != 1:
       raise ValueError(err_msg)
     # Get the only rank for the set.
@@ -395,8 +395,8 @@ class Concatenate(_Merge):
     for axis in range(rank):
       # Skip the Nones in the shape since they are dynamic, also the axis for
      # concat has been removed above.
-      unique_dims = set([shape[axis] for shape in shape_set
-                         if shape[axis] is not None])
+      unique_dims = set(
+          shape[axis] for shape in shape_set if shape[axis] is not None)
       if len(unique_dims) > 1:
         raise ValueError(err_msg)
 
@@ -175,7 +175,7 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase):
     _ = rnn_cell(inputs, [state, state])
     weights = base_cell._cells[0].weights
     self.assertLen(weights, expected_len=2)
-    self.assertTrue(all(["_wrapper" in v.name for v in weights]))
+    self.assertTrue(all("_wrapper" in v.name for v in weights))
 
   @parameterized.parameters(
       [rnn_cell_wrapper_v2.DropoutWrapper, rnn_cell_wrapper_v2.ResidualWrapper])
@@ -539,8 +539,8 @@ class AdamOptimizerTest(test.TestCase):
       opt = adam.Adam(1.)
       opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
       # There should be iteration, and two unique slot variables for v1 and v2.
-      self.assertEqual(
-          5, len(set([v.experimental_ref() for v in opt.variables()])))
+      self.assertEqual(5,
+                       len(set(v.experimental_ref() for v in opt.variables())))
       self.assertEqual(
           self.evaluate(opt.variables()[0]), self.evaluate(opt.iterations))
 
@@ -89,7 +89,7 @@ class Optimizer(object):
         function not implemented).
     """
     grads = K.gradients(loss, params)
-    if any([g is None for g in grads]):
+    if any(g is None for g in grads):
       raise ValueError('An operation has `None` for gradient. '
                        'Please make sure that all of your ops have a '
                        'gradient defined (i.e. are differentiable). '
@@ -804,8 +804,7 @@ def save_attributes_to_hdf5_group(group, name, data):
   if bad_attributes:
     raise RuntimeError('The following attributes cannot be saved to HDF5 '
                        'file because they are larger than %d bytes: %s' %
-                       (HDF5_OBJECT_HEADER_LIMIT,
-                        ', '.join([x for x in bad_attributes])))
+                       (HDF5_OBJECT_HEADER_LIMIT, ', '.join(bad_attributes)))
 
   data_npy = np.asarray(data)
 
@@ -681,7 +681,7 @@ def to_list(x):
 def object_list_uid(object_list):
   """Creates a single string from object ids."""
   object_list = nest.flatten(object_list)
-  return ', '.join([str(abs(id(x))) for x in object_list])
+  return ', '.join(str(abs(id(x))) for x in object_list)
 
 
 def to_snake_case(name):
@@ -402,7 +402,7 @@ def assert_no_legacy_layers(layers):
   # isinstance check for tf.layers.Layer introduces a circular dependency.
   legacy_layers = [l for l in layers if getattr(l, '_is_legacy_layer', None)]
   if legacy_layers:
-    layer_str = '\n'.join(['  ' + str(l) for l in legacy_layers])
+    layer_str = '\n'.join('  ' + str(l) for l in legacy_layers)
     raise TypeError(
         'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
         'framework (for instance using the Network, Model, or Sequential '
@@ -273,7 +273,7 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
     for thread in threads:
       thread.join()
 
-    flat_results = set([x for x in itertools.chain(*results)])
+    flat_results = set(itertools.chain(*results))
     self.assertEqual(1,
                      len(flat_results),
                      'Expected single value, got %r' % flat_results)
@@ -330,7 +330,7 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
     # Test against all 3! permutations of fragments, and for each permutation
     # test parsing all possible combination of 2 fields.
     for indices in itertools.permutations(range(len(fragments))):
-      proto = b''.join([fragments[i] for i in indices])
+      proto = b''.join(fragments[i] for i in indices)
       for i in indices:
         if i == 1:
           expected_message_values = [
@@ -118,7 +118,7 @@ class VariableScopeTest(test.TestCase):
       vs.get_variable("v2", [2])
       expected_names = ["%s:0" % name for name in ["v1", "v2"]]
       self.assertEqual(
-          set(expected_names), set([v.name for v in vs._vars.values()]))
+          set(expected_names), set(v.name for v in vs._vars.values()))
 
   # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # TypeError: Expected tf.group() expected Tensor arguments not 'None' with
@@ -231,7 +231,7 @@ def constant_value(pred):
 def object_list_uid(object_list):
   """Creates a single string from object ids."""
   object_list = nest.flatten(object_list)
-  return ', '.join([str(abs(id(x))) for x in object_list])
+  return ', '.join(str(abs(id(x))) for x in object_list)
 
 
 def static_shape(x):
@@ -351,7 +351,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
       raise errors.InvalidArgumentError(
           node_def=None,
          op=None,
-          message=('\n'.join([_pretty_print(d, summarize) for d in data])))
+          message=('\n'.join(_pretty_print(d, summarize) for d in data)))
 
   else:  # not context.executing_eagerly()
     if data is None:
@@ -82,7 +82,7 @@ def _as_shape_list(shapes,
     if any(not shape.is_fully_defined() for shape in shapes):
       raise ValueError("All shapes must be fully defined: %s" % shapes)
   if not unknown_rank_allowed:
-    if any([shape.dims is None for shape in shapes]):
+    if any(shape.dims is None for shape in shapes):
       raise ValueError("All shapes must have a defined rank: %s" % shapes)
 
   return shapes
@@ -834,9 +834,9 @@ def _LogOpGradients(op, out_grads, in_grads):
     return True
 
   logging.vlog(1, "  in  --> %s",
-               ", ".join([x.name for x in out_grads if _FilterGrad(x)]))
+               ", ".join(x.name for x in out_grads if _FilterGrad(x)))
   logging.vlog(1, "  out --> %s",
-               ", ".join([x.name for x in in_grads if _FilterGrad(x)]))
+               ", ".join(x.name for x in in_grads if _FilterGrad(x)))
 
 
 def _MultiDeviceAddN(tensor_list, gradient_uid):
@@ -275,7 +275,7 @@ def _EinsumGrad(op, grad):
                       set(output_subs + other_subs + "."))
   # Obtain the input subscripts with the reduced axis labels removed. E.g.
   # "ac" in the above example.
-  left_subs = "".join([s for s in input_subs if s not in reduced_label_set])
+  left_subs = "".join(s for s in input_subs if s not in reduced_label_set)
 
   # Compute the gradient wrt the input, without accounting for the operation
   # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
@@ -339,7 +339,7 @@ def _path_from(from_op, tensor, sources):
   if isinstance(from_op, ops.Tensor):
     from_op = from_op.op
 
-  visited_ops = set([x.op for x in sources])
+  visited_ops = set(x.op for x in sources)
   ops_to_visit = [_as_operation(tensor)]
   some_op_output = {}
   while ops_to_visit:
@@ -354,7 +354,7 @@ def _path_from(from_op, tensor, sources):
       while path_op != final_op:
         path_op = some_op_output[path_op]
         path.append(path_op)
-      return " <- ".join(["%s (%s)" % (x.name, x.type) for x in reversed(path)])
+      return " <- ".join("%s (%s)" % (x.name, x.type) for x in reversed(path))
     else:
       for inp in graph_inputs(op):
         if inp not in visited_ops and inp not in sources:
@@ -74,7 +74,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
                        len(fn_output)))
     outputs = []
     del is_none_list[:]
-    is_none_list.extend([x is None for x in fn_output])
+    is_none_list.extend(x is None for x in fn_output)
     for out, ta in zip(fn_output, ta_list):
       # TODO(agarwal): support returning Operation objects from loop_fn.
       if out is not None:
@@ -315,7 +315,7 @@ class BitwiseTest(PForTestCase):
       y1 = array_ops.gather(y, i)
       outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
       del output_dtypes[:]
-      output_dtypes.extend([t.dtype for t in outputs])
+      output_dtypes.extend(t.dtype for t in outputs)
       return outputs
     # pylint: enable=cell-var-from-loop
     self._test_loop_fn(loop_fn, 3)
@@ -65,7 +65,7 @@ class MathTest(PForTestCase, parameterized.TestCase):
       if grad is not None:
         outputs.append(grad)
       del output_dtypes[:]
-      output_dtypes.extend([t.dtype for t in outputs])
+      output_dtypes.extend(t.dtype for t in outputs)
       return outputs
 
     # pylint: enable=cell-var-from-loop
@@ -215,7 +215,7 @@ class MathTest(PForTestCase, parameterized.TestCase):
      y1 = array_ops.gather(y, i)
       outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
       del output_dtypes[:]
-      output_dtypes.extend([t.dtype for t in outputs])
+      output_dtypes.extend(t.dtype for t in outputs)
       return outputs
     # pylint: enable=cell-var-from-loop
 
@@ -121,7 +121,7 @@ class WhileOp(object):
     """
     self._pfor_config = pfor_config
     self._pfor_ops = set(pfor_ops)
-    self._pfor_op_ids = set([x._id for x in pfor_ops])
+    self._pfor_op_ids = set(x._id for x in pfor_ops)
     assert isinstance(exit_node, ops.Tensor)
     self._while_context = exit_node.op._get_control_flow_context()
     assert isinstance(self._while_context, control_flow_ops.WhileContext)
@@ -1176,7 +1176,7 @@ class PFor(object):
     self._conversion_map = object_identity.ObjectIdentityDictionary()
     self._conversion_map[loop_var] = wrap(self.all_indices, True)
     self._pfor_ops = set(pfor_ops)
-    self._pfor_op_ids = set([x._id for x in pfor_ops])
+    self._pfor_op_ids = set(x._id for x in pfor_ops)
     self._pfor_config = pfor_config
 
   def op_is_inside_loop(self, op):
@@ -126,8 +126,8 @@ class RaggedTensorDynamicShape(object):
 
     # Convert dimension size tensors to a single dtype.
     if dim_size_dtype is None:
-      dim_size_dtypes = set([p.dtype for p in partitioned_dim_sizes
-                             if p.shape.ndims == 1])
+      dim_size_dtypes = set(
+          p.dtype for p in partitioned_dim_sizes if p.shape.ndims == 1)
       if not dim_size_dtypes:
         dim_size_dtype = dtypes.int64
       elif len(dim_size_dtypes) == 1:
@@ -1237,7 +1237,7 @@ class MultiRNNCell(RNNCell):
     if not nest.is_sequence(cells):
       raise TypeError("cells must be a list or tuple, but saw: %s." % cells)
 
-    if len(set([id(cell) for cell in cells])) < len(cells):
+    if len(set(id(cell) for cell in cells)) < len(cells):
       logging.log_first_n(
           logging.WARN, "At least two cells provided to MultiRNNCell "
          "are the same object and will share weights.", 1)
@@ -360,8 +360,8 @@ def _einsum_v1_parse_and_resolve_equation(equation, input_shapes):
   # tensors of different length and unlabeled output.
   ellipsis_axes = ''
   if '...' in equation:
-    unused = ''.join([c for c in string.ascii_letters
-                      if c not in ''.join(input_axis_labels)])
+    unused = ''.join(
+        c for c in string.ascii_letters if c not in ''.join(input_axis_labels))
     for i, ax in enumerate(input_axis_labels):
       if '...' in ax:
         parts = ax.split('...')
@@ -381,7 +381,7 @@ def _einsum_v1_parse_and_resolve_equation(equation, input_shapes):
       if len(replace_axes) > len(ellipsis_axes):
         ellipsis_axes = replace_axes
 
-  if any(['.' in ax for ax in input_axis_labels]):
+  if any('.' in ax for ax in input_axis_labels):
     raise ValueError('period "." found outside of ellipsis')
 
   if output_axis_labels is not None:
@@ -804,7 +804,7 @@ class _EagerTensorArray(object):
             None, None,
             "Tried to write to index %d but array is not resizeable and size "
             "is: %d" % (index, size))
-      self._tensor_array.extend([None for _ in range(index - size + 1)])
+      self._tensor_array.extend(None for _ in range(index - size + 1))
 
     if not isinstance(value, ops.EagerTensor):
       # TODO(b/129870929): Fix after all callers provide proper init dtype.
@@ -1324,9 +1324,9 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
   @property
   def spec(self):
     """Computes the spec string used for saving."""
-    full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " "
+    full_shape_str = " ".join("%d" % d for d in self.full_shape) + " "
     sl_spec = ":".join(
-        ["%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape)])
+        "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape))
     return full_shape_str + sl_spec
 
   def to_proto(self, export_scope=None):
@@ -229,7 +229,7 @@ class PrintModelAnalysisTest(test.TestCase):
     with gfile.Open(outfile, 'r') as f:
       lines = f.read().split('\n')
       self.assertGreater(len(lines), 5)
-      result = '\n'.join([l[:min(len(l), 80)] for l in lines])
+      result = '\n'.join(l[:min(len(l), 80)] for l in lines)
       self.assertTrue(
           compat.as_text(lib.CheckAndRemoveDoc(result))
          .startswith('node name | # parameters | # float_ops'))
@@ -244,8 +244,7 @@ def recreate_function(saved_function, concrete_functions):
 
   def _pretty_format_positional(positional):
     return "Positional arguments ({} total):\n    * {}".format(
-        len(positional),
-        "\n    * ".join([str(a) for a in positional]))
+        len(positional), "\n    * ".join(str(a) for a in positional))
 
   for index, function_name in enumerate(saved_function.concrete_functions):
     concrete_function = concrete_functions[function_name]
@@ -96,7 +96,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
   Returns:
     the string of the header that should be written as ops_to_register.h.
   """
-  ops = set([op for op, _ in ops_and_kernels])
+  ops = set(op for op, _ in ops_and_kernels)
   result_list = []
 
   def append(s):
@@ -111,7 +111,7 @@ def topological_sort(g):
     if op_in_degree[consumer] < 0:
       raise ValueError('consumer:%s degree mismatch'%consumer.name)
 
-  left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
+  left_ops = set(op for (op, degree) in op_in_degree.items() if degree > 0)
   if left_ops:
     return (True, left_ops)
   else:
@@ -321,8 +321,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
 
   def report_unsupported_operations(self):
     if self._unsupported_ops:
-      op_str = "\n".join(["  %s (%s)" % (op.type, op.name)
-                          for op in self._unsupported_ops[:_MAX_WARNING_LINES]])
+      op_str = "\n".join("  %s (%s)" % (op.type, op.name)
+                         for op in self._unsupported_ops[:_MAX_WARNING_LINES])
       logging.warning("%d unsupported operations found: \n%s",
                       len(self._unsupported_ops), op_str)
     if len(self._unsupported_ops) > _MAX_WARNING_LINES:
@@ -1200,7 +1200,7 @@ def split_compile_and_replicate(computation,
 
   if host_compute_core:
     attr_value = attr_value_pb2.AttrValue()
-    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
+    attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core)
     metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access
 
     with ops.control_dependencies([metadata]):
@@ -199,7 +199,7 @@ def master_job(master, cluster_def):
 
   if (not cluster_def or not cluster_def.job):
     return _DEFAULT_JOB_NAME
-  job_names = set([job.name for job in cluster_def.job])
+  job_names = set(job.name for job in cluster_def.job)
   if _DEFAULT_JOB_NAME in job_names:
     # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
     raise ValueError('Currently, tpu_worker is not an allowed job name.')
@@ -330,7 +330,7 @@ def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
         ))
         var_name = var.name
       else:
-        var_name = ",".join([v.name for v in var])
+        var_name = ",".join(v.name for v in var)
       _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
       logging.debug("Initialize variable %s from checkpoint %s with %s",
                     var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
@@ -401,7 +401,7 @@ class SliceInputProducerTest(test_lib.TestCase):
         frequency[e] = 0
       for _ in range(num_epochs):
         output = [self.evaluate(slices) for _ in range(len(source_strings))]
-        key = b",".join([s + compat.as_bytes(str(i)) for s, i in output])
+        key = b",".join(s + compat.as_bytes(str(i)) for s, i in output)
         self.assertIn(key, expected)
         frequency[key] += 1
 
@@ -1083,7 +1083,7 @@ class BatchJoinTest(test_lib.TestCase):
       self.assertEqual(len(which_a) + len(which_b), batch_size)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
      seen_b += len(which_b)
       self.assertAllEqual([99] * len(which_b),
                           [results[0][i] for i in which_b])
@@ -1185,7 +1185,7 @@ class BatchJoinTest(test_lib.TestCase):
      self.assertEqual(len(which_a) + len(which_b), batch_size)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
       seen_b += len(which_b)
       self.assertAllEqual([99] * len(which_b),
                           [results[0][i] for i in which_b])
@@ -1271,7 +1271,7 @@ class BatchJoinTest(test_lib.TestCase):
       self.assertEqual(len(which_a) + len(which_b), batch_size)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
       seen_b += len(which_b)
       self.assertAllEqual([99] * len(which_b),
                           [results[0][i] for i in which_b])
@@ -1291,7 +1291,7 @@ class BatchJoinTest(test_lib.TestCase):
      self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
       seen_b += len(which_b)
 
      # We'd like to see some minimum level of mixing of the results of both
@@ -1369,7 +1369,7 @@ class BatchJoinTest(test_lib.TestCase):
       self.assertEqual(len(which_a) + len(which_b), batch_size)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
       seen_b += len(which_b)
       self.assertAllEqual([99] * len(which_b),
                           [results[0][i] for i in which_b])
@@ -1389,7 +1389,7 @@ class BatchJoinTest(test_lib.TestCase):
      self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
       seen_b += len(which_b)
 
      # We'd like to see some minimum level of mixing of the results of both
@@ -2099,7 +2099,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
       self.assertEqual(len(which_a) + len(which_b), batch_size)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
      seen_b += len(which_b)
       self.assertAllEqual([99] * len(which_b),
                           [results[0][i] for i in which_b])
@@ -2194,7 +2194,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
      self.assertEqual(len(which_a) + len(which_b), batch_size)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
       seen_b += len(which_b)
       self.assertAllEqual([99] * len(which_b),
                           [results[0][i] for i in which_b])
@@ -2213,7 +2213,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
      self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
       if which_a and which_b:
         saw_both += 1
-      all_a.extend([results[0][i] for i in which_a])
+      all_a.extend(results[0][i] for i in which_a)
       seen_b += len(which_b)
 
      # Some minimum level of mixing of the results of both threads.
@@ -554,7 +554,7 @@ class ExponentialMovingAverage(object):
       for v in moving_avg_variables:
         name_map[self.average_name(v)] = v
       # Make sure we restore variables without moving averages as well.
-      moving_avg_variable_names = set([v.name for v in moving_avg_variables])
+      moving_avg_variable_names = set(v.name for v in moving_avg_variables)
       for v in list(set(variables.global_variables())):
         if v.name not in moving_avg_variable_names and v.op.name not in name_map:
           name_map[v.op.name] = v
@@ -46,8 +46,8 @@ def make_all(module_name, doc_string_modules=None):
   """
   if doc_string_modules is None:
     doc_string_modules = [_sys.modules[module_name]]
-  cur_members = set([name for name, _
-                     in _tf_inspect.getmembers(_sys.modules[module_name])])
+  cur_members = set(
+      name for name, _ in _tf_inspect.getmembers(_sys.modules[module_name]))
 
   results = set()
   for doc_module in doc_string_modules:
@@ -172,7 +172,7 @@ class ObjectIdentitySet(collections_abc.MutableSet):
   """Like the built-in set, but compares objects with "is"."""
 
   def __init__(self, *args):
-    self._storage = set([self._wrap_key(obj) for obj in list(*args)])
+    self._storage = set(self._wrap_key(obj) for obj in list(*args))
 
   @staticmethod
   def _from_storage(storage):
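A closing caveat on these swaps: for any, all, max, and set the generator
form avoids building an intermediate list, and the short-circuiting callers
(any, all) can stop early, but in CPython str.join materializes its argument
into a sequence internally, so at the join call sites the change is stylistic
rather than a performance win. Either way the two forms produce the same
result; an illustrative check (names and values made up, not from the diff):

    names = ["scope/a", "scope/b", "scope/a"]
    assert set([n.split("/")[-1] for n in names]) == \
        set(n.split("/")[-1] for n in names) == {"a", "b"}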