var = dict() --> var = {}
parent d3c1e397f6
commit 3ebc496a1b
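For context, both spellings build an empty dictionary; the `{}` literal is the idiomatic form and skips the global name lookup plus call that `dict()` incurs. A minimal illustrative sketch of the difference (not taken from the changed files; timings vary by machine and interpreter):

    import timeit

    # dict() and {} produce equal empty dictionaries.
    assert dict() == {}

    # The literal avoids a global name lookup and a function call,
    # so it is marginally faster; exact numbers are machine-dependent.
    print("dict():", timeit.timeit("dict()", number=1000000))
    print("{}    :", timeit.timeit("{}", number=1000000))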
@@ -66,7 +66,7 @@ def parse_disabled_manifest(manifest_content):
       raise ValueError('Bad entry in manifest file.')

   disabled_regex = '|'.join(disabled_tests)
-  method_types_filter = dict()
+  method_types_filter = {}
   for method, types in disabled_method_types:
     method_types_filter[method] = set([
         dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
@@ -47,7 +47,7 @@ class ReportingBenchmark(tf.test.Benchmark):

     avg_time = np.average(all_times)

-    extras = dict()
+    extras = {}
     extras['all_times'] = all_times

     if isinstance(name, tuple):
@@ -44,7 +44,7 @@ def _export_outputs_to_output_alternatives(export_outputs):
   Returns:
     converted output_alternatives.
   """
-  output = dict()
+  output = {}
   if export_outputs is not None:
     for key, value in export_outputs.items():
       if isinstance(value, export_output.ClassificationOutput):
@@ -316,13 +316,13 @@ class MockOsEnv(collections.Mapping):
   """A class that allows per-thread TF_CONFIG."""

   def __init__(self, *args):
-    self._dict = dict()
+    self._dict = {}
     self._thread_local = threading.local()
     super(MockOsEnv, self).__init__(*args)

   def get(self, key, default=None):
     if not hasattr(self._thread_local, 'dict'):
-      self._thread_local.dict = dict()
+      self._thread_local.dict = {}
     if key == 'TF_CONFIG':
       return dict.get(self._thread_local.dict, key, default)
     else:
@@ -330,7 +330,7 @@ class MockOsEnv(collections.Mapping):

   def __getitem__(self, key):
     if not hasattr(self._thread_local, 'dict'):
-      self._thread_local.dict = dict()
+      self._thread_local.dict = {}
     if key == 'TF_CONFIG':
       return dict.__getitem__(self._thread_local.dict, key)
     else:
@@ -338,7 +338,7 @@ class MockOsEnv(collections.Mapping):

   def __setitem__(self, key, val):
     if not hasattr(self._thread_local, 'dict'):
-      self._thread_local.dict = dict()
+      self._thread_local.dict = {}
     if key == 'TF_CONFIG':
       return dict.__setitem__(self._thread_local.dict, key, val)
     else:
@@ -346,7 +346,7 @@ class MockOsEnv(collections.Mapping):

   def __iter__(self):
     if not hasattr(self._thread_local, 'dict'):
-      self._thread_local.dict = dict()
+      self._thread_local.dict = {}
     for x in self._thread_local.dict:
       yield x
     for x in self._dict:
@@ -354,7 +354,7 @@ class MockOsEnv(collections.Mapping):

   def __len__(self):
     if not hasattr(self._thread_local, 'dict'):
-      self._thread_local.dict = dict()
+      self._thread_local.dict = {}
     return self._thread_local.dict.__len__() + self._dict.__len__()


@@ -179,7 +179,7 @@ def load_word_vectors(data_root, vocab):

   print("Loading word vectors...")

-  word2index = dict()
+  word2index = {}
   embed = []

   embed.append([0] * WORD_VECTOR_LEN) # <unk>
@@ -85,7 +85,7 @@ def restore_variables_on_create(save_path, map_func=None):
       raise ValueError("map_func must be callable.")
     map_func_wrapper = lambda self, x: map_func(x)

-  ckpt_var_cache = dict()
+  ckpt_var_cache = {}
   reader = checkpoint_utils.load_checkpoint(save_path)
   for k, _ in checkpoint_utils.list_variables(save_path):
     ckpt_var_cache[k] = reader.get_tensor(k)
@@ -77,7 +77,7 @@ def _update_features_and_columns(features, feature_columns,
     return features, feature_columns

   # First construct new columns and features affected by kernel_mappers_dict.
-  mapped_features = dict()
+  mapped_features = {}
   mapped_columns = set()
   for feature_column in kernel_mappers_dict:
     column_name = feature_column.name
@@ -488,7 +488,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
       default_name='weighted_sum_from_feature_columns',
       values=columns_to_tensors.values()):
     output_tensors = []
-    column_to_variable = dict()
+    column_to_variable = {}
     transformer = _Transformer(columns_to_tensors)
     # pylint: disable=protected-access
     for column in sorted(set(feature_columns), key=lambda x: x.key):
@@ -50,7 +50,7 @@ class _BaseEstimator(object):
       params : mapping of string to any
         Parameter names mapped to their values.
     """
-    out = dict()
+    out = {}
     param_names = [name for name in self.__dict__ if not name.startswith('_')]
     for key in param_names:
       value = getattr(self, key, None)
@@ -545,7 +545,7 @@ class HParams(object):
       ValueError: If `values` cannot be parsed or a hyperparameter in `values`
         doesn't exist.
     """
-    type_map = dict()
+    type_map = {}
     for name, t in self._hparam_types.items():
       param_type, _ = t
       type_map[name] = param_type
@@ -366,7 +366,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
       update_2 = next_batch.save_state("state2",
                                        -1 + next_batch.state("state2"))

-      original_values = dict()
+      original_values = {}

       def insert(which):
         for i in range(20):
@@ -81,7 +81,7 @@ def word2vec_basic(log_dir):
     """Process raw inputs into a dataset."""
     count = [['UNK', -1]]
     count.extend(collections.Counter(words).most_common(n_words - 1))
-    dictionary = dict()
+    dictionary = {}
     for word, _ in count:
       dictionary[word] = len(dictionary)
     data = []
@@ -251,7 +251,7 @@
 "def build_dataset(words):\n",
 " count = [['UNK', -1]]\n",
 " count.extend(collections.Counter(words).most_common(vocabulary_size - 1))\n",
-" dictionary = dict()\n",
+" dictionary = {}\n",
 " for word, _ in count:\n",
 " dictionary[word] = len(dictionary)\n",
 " data = []\n",
@@ -661,7 +661,7 @@
 " mean_loss = 0\n",
 " for step in range(num_steps):\n",
 " batches = train_batches.next()\n",
-" feed_dict = dict()\n",
+" feed_dict = {}\n",
 " for i in range(num_unrollings + 1):\n",
 " feed_dict[train_data[i]] = batches[i]\n",
 " _, l, predictions, lr = session.run(\n",
@@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops
 def fake_tf():
   """Creates a fake module that looks like TensorFlow, for testing."""
   mod = imp.new_module('tensorflow')
-  mod_contents = dict()
+  mod_contents = {}
   mod_contents.update(gen_math_ops.__dict__)
   mod_contents.update(math_ops.__dict__)
   mod_contents.update(ops.__dict__)
@@ -95,11 +95,11 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,

     # Expected sum of all batches with an equal sequence length.
     # <seq-length>: <expected-total-sum>
-    expected_sums = dict()
+    expected_sums = {}

     # Expected batch sizes of batches depending on the sequence length.
     # <seq-length>: [batch1_size, ..., batchN_size]
-    expected_batch_sizes = dict()
+    expected_batch_sizes = {}

     for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
                                                    n_bucket_elements):
@@ -155,10 +155,10 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
     generated_lengths = []

     # <seq-length>: <total-sum>
-    generated_sums = dict()
+    generated_sums = {}

     # <seq-length>: [<batch_size>, ...]
-    generated_batch_sizes = dict()
+    generated_batch_sizes = {}

     for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
                                                    n_bucket_elements):
@@ -47,7 +47,7 @@ class CLIConfig(object):
         self._config[key] = value
       self._save_to_file()

-    self._set_callbacks = dict()
+    self._set_callbacks = {}

   def get(self, property_name):
     if property_name not in self._config:
@@ -113,7 +113,7 @@ class ExpressionEvaluator(object):
       dump: an instance of `DebugDumpDir`.
     """
     self._dump = dump
-    self._cached_tensor_values = dict()
+    self._cached_tensor_values = {}

   def evaluate(self, expression):
     """Parse an expression.
@@ -949,7 +949,7 @@ class DebugDumpDir(object):
     Returns:
       A dict mapping device names (`str`s) to reconstructed `tf.GraphDef`s.
     """
-    non_debug_graphs = dict()
+    non_debug_graphs = {}
     for key in self._debug_graphs:
       non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def
     return non_debug_graphs
@@ -249,7 +249,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):

   def _initialize_toggle_watch_state(self, toggle_watches):
     self._toggle_watches = toggle_watches
-    self._toggle_watch_state = dict()
+    self._toggle_watch_state = {}
     if self._toggle_watches:
       for watch_key in self._toggle_watches:
         self._toggle_watch_state[watch_key] = False
@@ -59,7 +59,7 @@ def _format_origin_stack(origin_stack, call_traceback_proto):
     call_traceback_proto: A `CallTraceback` proto whose fields are to be
       populated.
   """
-  string_to_id = dict()
+  string_to_id = {}
   string_to_id[None] = 0
   for frame in origin_stack:
     file_path, lineno, func_name, line_text = frame
@@ -243,7 +243,7 @@ class NodeStepper(object):
     done = set() # Keep track of visited graph elements.

     # A list of str: Names of the topologically-sorted graph elements.
-    node_inputs = dict() # New: Input map of nodes in the transitive closure.
+    node_inputs = {} # New: Input map of nodes in the transitive closure.

     elem_stack = copy.copy(elem_list)

@@ -396,7 +396,7 @@ class BaseDebugWrapperSession(session.SessionInterface):
     self._default_session_context_manager = None

     # A cache for callables created from CallableOptions.
-    self._cached_callables_from_options = dict()
+    self._cached_callables_from_options = {}

   @property
   def graph(self):
@@ -264,10 +264,10 @@ class CollectiveKeys(object):
       recorded with an id.
     """
     self._group_key = group_key_start
-    self._group_key_table = dict()
+    self._group_key_table = {}

     # For instance keys with ids
-    self._instance_key_id_to_key_table = dict()
+    self._instance_key_id_to_key_table = {}
     self._instance_key_with_id_counter = instance_key_with_id_start

     # For instance keys without ids
@@ -2251,7 +2251,7 @@ def _normalize_feature_columns(feature_columns):
         'Given (type {}): {}.'.format(type(column), column))
   if not feature_columns:
     raise ValueError('feature_columns must not be empty.')
-  name_to_column = dict()
+  name_to_column = {}
   for column in feature_columns:
     if column.name in name_to_column:
       raise ValueError('Duplicate feature column name found for columns: {} '
@@ -2691,7 +2691,7 @@ def _normalize_feature_columns(feature_columns):
         'Given (type {}): {}.'.format(type(column), column))
   if not feature_columns:
     raise ValueError('feature_columns must not be empty.')
-  name_to_column = dict()
+  name_to_column = {}
   for column in feature_columns:
     if column.name in name_to_column:
       raise ValueError('Duplicate feature column name found for columns: {} '
@@ -262,7 +262,7 @@ class _DefinedFunction(object):
     self._definition = None
     # Constructed only when C API is enabled, lazily
     self._c_func = None
-    self._sub_functions = dict() # Constructed with _definition or _c_func
+    self._sub_functions = {} # Constructed with _definition or _c_func
     # pylint: disable=protected-access
     device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
     # pylint: enable=protected-access
@@ -2995,9 +2995,9 @@ class Graph(object):
     # Similarly, if one or more Session.run calls are going on, all mutate ops
     # have to wait until all Session.run calls have finished.
     self._group_lock = lock_util.GroupLock(num_groups=2)
-    self._nodes_by_id = dict() # GUARDED_BY(self._lock)
+    self._nodes_by_id = {} # GUARDED_BY(self._lock)
     self._next_id_counter = 0 # GUARDED_BY(self._lock)
-    self._nodes_by_name = dict() # GUARDED_BY(self._lock)
+    self._nodes_by_name = {} # GUARDED_BY(self._lock)
     self._version = 0 # GUARDED_BY(self._lock)
     # Maps a name used in the graph to the next id to use for that name.
     self._names_in_use = {}
@@ -39,7 +39,7 @@ class Registry(object):
   def __init__(self, name):
     """Creates a new registry."""
     self._name = name
-    self._registry = dict()
+    self._registry = {}

   def register(self, candidate, name=None):
     """Registers a Python object "candidate" for the given "name".
@@ -2989,7 +2989,7 @@ class GraphExecutionFunction(object):
     # output from a fetch in `fetches`: { fetch: function(fetch_output) }
     # A Callback can use this to register a function with access to the
     # output values for a fetch it added.
-    self.fetch_callbacks = dict()
+    self.fetch_callbacks = {}

     if session_kwargs:
       raise ValueError('Some keys in session_kwargs are not supported at this '
@@ -104,8 +104,8 @@ class _Mapping(collections.namedtuple(

   def _merge_dicts(self, old=None, new=None):
     """Helper to merge two dictionaries."""
-    old = dict() if old is None else old
-    new = dict() if new is None else new
+    old = {} if old is None else old
+    new = {} if new is None else new
     for k, v in six.iteritems(new):
       val = old.get(k, None)
       if val is not None and val != v:
@@ -201,7 +201,7 @@ class RunMetadataTest(test.TestCase):
     graph = ops.get_default_graph()
     forward_op = set()
     backward_op = set()
-    back_to_forward = dict()
+    back_to_forward = {}
     for op in graph.get_operations():
       if op.name.find('gradients/') > 0 and op.name.find('_grad/') > 0:
         backward_op.add(op.name)
@@ -93,7 +93,7 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True,

   op_missing_shape = 0
   logged_ops = {}
-  string_to_id = dict()
+  string_to_id = {}
   string_to_id['none'] = len(string_to_id)
   # TODO(xpan): Work with Profiler more efficiently.
   for op in graph.get_operations():
@@ -169,7 +169,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None,
   if not op_log:
     tmp_op_log.log_entries.extend(logged_ops.values())
   else:
-    all_ops = dict()
+    all_ops = {}
     for entry in op_log.log_entries:
       all_ops[entry.name] = entry
     for op_name, entry in six.iteritems(logged_ops):
@@ -64,12 +64,12 @@ class SignatureDefUtilsTest(test.TestCase):
   def testBuildSignatureDef(self):
     x = array_ops.placeholder(dtypes.float32, 1, name="x")
     x_tensor_info = utils.build_tensor_info(x)
-    inputs = dict()
+    inputs = {}
     inputs["foo-input"] = x_tensor_info

     y = array_ops.placeholder(dtypes.float32, name="y")
     y_tensor_info = utils.build_tensor_info(y)
-    outputs = dict()
+    outputs = {}
     outputs["foo-output"] = y_tensor_info

     signature_def = signature_def_utils_impl.build_signature_def(
@@ -93,7 +93,7 @@ class AutoTrackable(base.Trackable):

   def _list_functions_for_serialization(self):
     """Return a dict of `Function`s of a trackable."""
-    functions = dict()
+    functions = {}
     for attribute_name in dir(self):
       try:
         attribute_value = getattr(self, attribute_name, None)
@@ -104,7 +104,7 @@ def _new_mark_used(self, *args, **kwargs):
     pass


-_WRAPPERS = dict()
+_WRAPPERS = {}


 def _get_wrapper(x, tf_should_use_helper):
@@ -339,7 +339,7 @@ def get_slice_sets_and_required_args(slice_sets, tag_spec):

 def gather_tag_args(slices, cli_input_args, required_args):
   """Build a dictionary of all the CLI and slice-specified args for a tag."""
-  args = dict()
+  args = {}

   for s in slices:
     args = update_args_dict(args, s['args'])
@@ -452,7 +452,7 @@ def gather_existing_partials(partial_path):
     Dict[string, string] of partial short names (like "ubuntu/python" or
     "bazel") to the full contents of that partial.
   """
-  partials = dict()
+  partials = {}
   for path, _, files in os.walk(partial_path):
     for name in files:
       fullpath = os.path.join(path, name)
third_party/repo.bzl (vendored, 2 changes)
@@ -185,7 +185,7 @@ def _third_party_http_archive(ctx):
         _apply_patch(ctx, ctx.attr.patch_file)
     ctx.symlink(Label(ctx.attr.build_file), buildfile_path)

-    link_dict = dict()
+    link_dict = {}
     if use_syslib:
         link_dict.update(ctx.attr.system_link_files)
