Merge pull request from lgeiger:literals

PiperOrigin-RevId: 239507826
This commit is contained in:
TensorFlower Gardener 2019-03-20 17:29:57 -07:00
commit 9e4fc7fdae
45 changed files with 64 additions and 65 deletions
tensorflow
compiler/tests
contrib
autograph/examples/benchmarks
boosted_trees/estimator_batch
distribute/python
distributions/python/ops/bijectors
eager/python
examples/spinn
saver.py
kernel_methods/python
kinesis/python/kernel_tests
layers/python/layers
learn/python/learn/estimators
training/python/training
examples
python
tools/dockerfiles
third_party

View File

@ -66,7 +66,7 @@ def parse_disabled_manifest(manifest_content):
raise ValueError('Bad entry in manifest file.')
disabled_regex = '|'.join(disabled_tests)
method_types_filter = dict()
method_types_filter = {}
for method, types in disabled_method_types:
method_types_filter[method] = set([
dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype

View File

@ -47,7 +47,7 @@ class ReportingBenchmark(tf.test.Benchmark):
avg_time = np.average(all_times)
extras = dict()
extras = {}
extras['all_times'] = all_times
if isinstance(name, tuple):

View File

@ -44,7 +44,7 @@ def _export_outputs_to_output_alternatives(export_outputs):
Returns:
converted output_alternatives.
"""
output = dict()
output = {}
if export_outputs is not None:
for key, value in export_outputs.items():
if isinstance(value, export_output.ClassificationOutput):

View File

@ -321,13 +321,13 @@ class MockOsEnv(collections.Mapping):
"""A class that allows per-thread TF_CONFIG."""
def __init__(self, *args):
self._dict = dict()
self._dict = {}
self._thread_local = threading.local()
super(MockOsEnv, self).__init__(*args)
def get(self, key, default=None):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
self._thread_local.dict = {}
if key == 'TF_CONFIG':
return dict.get(self._thread_local.dict, key, default)
else:
@ -335,7 +335,7 @@ class MockOsEnv(collections.Mapping):
def __getitem__(self, key):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
self._thread_local.dict = {}
if key == 'TF_CONFIG':
return dict.__getitem__(self._thread_local.dict, key)
else:
@ -343,7 +343,7 @@ class MockOsEnv(collections.Mapping):
def __setitem__(self, key, val):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
self._thread_local.dict = {}
if key == 'TF_CONFIG':
return dict.__setitem__(self._thread_local.dict, key, val)
else:
@ -351,7 +351,7 @@ class MockOsEnv(collections.Mapping):
def __iter__(self):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
self._thread_local.dict = {}
for x in self._thread_local.dict:
yield x
for x in self._dict:
@ -359,7 +359,7 @@ class MockOsEnv(collections.Mapping):
def __len__(self):
if not hasattr(self._thread_local, 'dict'):
self._thread_local.dict = dict()
self._thread_local.dict = {}
return self._thread_local.dict.__len__() + self._dict.__len__()

View File

@ -266,7 +266,7 @@ class BatchNormalization(bijector.Bijector):
else:
# At training-time, ildj is computed from the mean and log-variance across
# the current minibatch.
_, v = nn.moments(y, axes=reduction_axes, keep_dims=True)
_, v = nn.moments(y, axes=reduction_axes, keepdims=True)
log_variance = math_ops.log(v + self.batchnorm.epsilon)
# `gamma` and `log Var(y)` reductions over event_dims.

View File

@ -179,7 +179,7 @@ def load_word_vectors(data_root, vocab):
print("Loading word vectors...")
word2index = dict()
word2index = {}
embed = []
embed.append([0] * WORD_VECTOR_LEN) # <unk>

View File

@ -85,7 +85,7 @@ def restore_variables_on_create(save_path, map_func=None):
raise ValueError("map_func must be callable.")
map_func_wrapper = lambda self, x: map_func(x)
ckpt_var_cache = dict()
ckpt_var_cache = {}
reader = checkpoint_utils.load_checkpoint(save_path)
for k, _ in checkpoint_utils.list_variables(save_path):
ckpt_var_cache[k] = reader.get_tensor(k)

View File

@ -77,7 +77,7 @@ def _update_features_and_columns(features, feature_columns,
return features, feature_columns
# First construct new columns and features affected by kernel_mappers_dict.
mapped_features = dict()
mapped_features = {}
mapped_columns = set()
for feature_column in kernel_mappers_dict:
column_name = feature_column.name

View File

@ -110,7 +110,7 @@ class KinesisDatasetTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
data = list()
data = []
with self.cached_session() as sess:
# Basic test: read from shard 0 of stream 2.
sess.run(

View File

@ -488,7 +488,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
default_name='weighted_sum_from_feature_columns',
values=columns_to_tensors.values()):
output_tensors = []
column_to_variable = dict()
column_to_variable = {}
transformer = _Transformer(columns_to_tensors)
# pylint: disable=protected-access
for column in sorted(set(feature_columns), key=lambda x: x.key):

View File

@ -50,7 +50,7 @@ class _BaseEstimator(object):
params : mapping of string to any
Parameter names mapped to their values.
"""
out = dict()
out = {}
param_names = [name for name in self.__dict__ if not name.startswith('_')]
for key in param_names:
value = getattr(self, key, None)

View File

@ -545,7 +545,7 @@ class HParams(object):
ValueError: If `values` cannot be parsed or a hyperparameter in `values`
doesn't exist.
"""
type_map = dict()
type_map = {}
for name, t in self._hparam_types.items():
param_type, _ = t
type_map[name] = param_type

View File

@ -366,7 +366,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
update_2 = next_batch.save_state("state2",
-1 + next_batch.state("state2"))
original_values = dict()
original_values = {}
def insert(which):
for i in range(20):

View File

@ -81,10 +81,10 @@ def word2vec_basic(log_dir):
"""Process raw inputs into a dataset."""
count = [['UNK', -1]]
count.extend(collections.Counter(words).most_common(n_words - 1))
dictionary = dict()
dictionary = {}
for word, _ in count:
dictionary[word] = len(dictionary)
data = list()
data = []
unk_count = 0
for word in words:
index = dictionary.get(word, 0)

View File

@ -251,10 +251,10 @@
"def build_dataset(words):\n",
" count = [['UNK', -1]]\n",
" count.extend(collections.Counter(words).most_common(vocabulary_size - 1))\n",
" dictionary = dict()\n",
" dictionary = {}\n",
" for word, _ in count:\n",
" dictionary[word] = len(dictionary)\n",
" data = list()\n",
" data = []\n",
" unk_count = 0\n",
" for word in words:\n",
" if word in dictionary:\n",

View File

@ -554,7 +554,7 @@
" return output_gate * tf.tanh(state), state\n",
"\n",
" # Input data.\n",
" train_data = list()\n",
" train_data = []\n",
" for _ in range(num_unrollings + 1):\n",
" train_data.append(\n",
" tf.placeholder(tf.float32, shape=[batch_size,vocabulary_size]))\n",
@ -562,7 +562,7 @@
" train_labels = train_data[1:] # labels are inputs shifted by one time step.\n",
"\n",
" # Unrolled LSTM loop.\n",
" outputs = list()\n",
" outputs = []\n",
" output = saved_output\n",
" state = saved_state\n",
" for i in train_inputs:\n",
@ -661,7 +661,7 @@
" mean_loss = 0\n",
" for step in range(num_steps):\n",
" batches = train_batches.next()\n",
" feed_dict = dict()\n",
" feed_dict = {}\n",
" for i in range(num_unrollings + 1):\n",
" feed_dict[train_data[i]] = batches[i]\n",
" _, l, predictions, lr = session.run(\n",

View File

@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops
def fake_tf():
"""Creates a fake module that looks like TensorFlow, for testing."""
mod = imp.new_module('tensorflow')
mod_contents = dict()
mod_contents = {}
mod_contents.update(gen_math_ops.__dict__)
mod_contents.update(math_ops.__dict__)
mod_contents.update(ops.__dict__)

View File

@ -95,11 +95,11 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
# Expected sum of all batches with an equal sequence length.
# <seq-length>: <expected-total-sum>
expected_sums = dict()
expected_sums = {}
# Expected batch sizes of batches depending on the sequence length.
# <seq-length>: [batch1_size, ..., batchN_size]
expected_batch_sizes = dict()
expected_batch_sizes = {}
for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
n_bucket_elements):
@ -155,10 +155,10 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
generated_lengths = []
# <seq-length>: <total-sum>
generated_sums = dict()
generated_sums = {}
# <seq-length>: [<batch_size>, ...]
generated_batch_sizes = dict()
generated_batch_sizes = {}
for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
n_bucket_elements):

View File

@ -47,7 +47,7 @@ class CLIConfig(object):
self._config[key] = value
self._save_to_file()
self._set_callbacks = dict()
self._set_callbacks = {}
def get(self, property_name):
if property_name not in self._config:

View File

@ -113,7 +113,7 @@ class ExpressionEvaluator(object):
dump: an instance of `DebugDumpDir`.
"""
self._dump = dump
self._cached_tensor_values = dict()
self._cached_tensor_values = {}
def evaluate(self, expression):
"""Parse an expression.

View File

@ -949,7 +949,7 @@ class DebugDumpDir(object):
Returns:
A dict mapping device names (`str`s) to reconstructed `tf.GraphDef`s.
"""
non_debug_graphs = dict()
non_debug_graphs = {}
for key in self._debug_graphs:
non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def
return non_debug_graphs

View File

@ -249,7 +249,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
def _initialize_toggle_watch_state(self, toggle_watches):
self._toggle_watches = toggle_watches
self._toggle_watch_state = dict()
self._toggle_watch_state = {}
if self._toggle_watches:
for watch_key in self._toggle_watches:
self._toggle_watch_state[watch_key] = False

View File

@ -59,7 +59,7 @@ def _format_origin_stack(origin_stack, call_traceback_proto):
call_traceback_proto: A `CallTraceback` proto whose fields are to be
populated.
"""
string_to_id = dict()
string_to_id = {}
string_to_id[None] = 0
for frame in origin_stack:
file_path, lineno, func_name, line_text = frame

View File

@ -243,7 +243,7 @@ class NodeStepper(object):
done = set() # Keep track of visited graph elements.
# A list of str: Names of the topologically-sorted graph elements.
node_inputs = dict() # New: Input map of nodes in the transitive closure.
node_inputs = {} # New: Input map of nodes in the transitive closure.
elem_stack = copy.copy(elem_list)

View File

@ -396,7 +396,7 @@ class BaseDebugWrapperSession(session.SessionInterface):
self._default_session_context_manager = None
# A cache for callables created from CallableOptions.
self._cached_callables_from_options = dict()
self._cached_callables_from_options = {}
@property
def graph(self):

View File

@ -264,10 +264,10 @@ class CollectiveKeys(object):
recorded with an id.
"""
self._group_key = group_key_start
self._group_key_table = dict()
self._group_key_table = {}
# For instance keys with ids
self._instance_key_id_to_key_table = dict()
self._instance_key_id_to_key_table = {}
self._instance_key_with_id_counter = instance_key_with_id_start
# For instance keys without ids

View File

@ -582,8 +582,8 @@ class Function(object):
concrete_functions.extend(
self._stateless_fn._function_cache.all_values())
# pylint: enable=protected-access
deduplicated_concrete_functions = list()
seen_signatures = list()
deduplicated_concrete_functions = []
seen_signatures = []
# We are using a list so that:
# - the returned collection is deterministic, and
# - we can use a custom equality operator (is_same_structure).

View File

@ -2251,7 +2251,7 @@ def _normalize_feature_columns(feature_columns):
'Given (type {}): {}.'.format(type(column), column))
if not feature_columns:
raise ValueError('feature_columns must not be empty.')
name_to_column = dict()
name_to_column = {}
for column in feature_columns:
if column.name in name_to_column:
raise ValueError('Duplicate feature column name found for columns: {} '

View File

@ -2691,7 +2691,7 @@ def _normalize_feature_columns(feature_columns):
'Given (type {}): {}.'.format(type(column), column))
if not feature_columns:
raise ValueError('feature_columns must not be empty.')
name_to_column = dict()
name_to_column = {}
for column in feature_columns:
if column.name in name_to_column:
raise ValueError('Duplicate feature column name found for columns: {} '

View File

@ -262,7 +262,7 @@ class _DefinedFunction(object):
self._definition = None
# Constructed only when C API is enabled, lazily
self._c_func = None
self._sub_functions = dict() # Constructed with _definition or _c_func
self._sub_functions = {} # Constructed with _definition or _c_func
# pylint: disable=protected-access
device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
# pylint: enable=protected-access

View File

@ -2995,9 +2995,9 @@ class Graph(object):
# Similarly, if one or more Session.run calls are going on, all mutate ops
# have to wait until all Session.run calls have finished.
self._group_lock = lock_util.GroupLock(num_groups=2)
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
self._nodes_by_id = {} # GUARDED_BY(self._lock)
self._next_id_counter = 0 # GUARDED_BY(self._lock)
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
self._nodes_by_name = {} # GUARDED_BY(self._lock)
self._version = 0 # GUARDED_BY(self._lock)
# Maps a name used in the graph to the next id to use for that name.
self._names_in_use = {}

View File

@ -39,7 +39,7 @@ class Registry(object):
def __init__(self, name):
"""Creates a new registry."""
self._name = name
self._registry = dict()
self._registry = {}
def register(self, candidate, name=None):
"""Registers a Python object "candidate" for the given "name".

View File

@ -3006,7 +3006,7 @@ class GraphExecutionFunction(object):
# output from a fetch in `fetches`: { fetch: function(fetch_output) }
# A Callback can use this to register a function with access to the
# output values for a fetch it added.
self.fetch_callbacks = dict()
self.fetch_callbacks = {}
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '

View File

@ -1252,7 +1252,7 @@ class FIFOQueueTest(test.TestCase):
def testSelectQueue(self):
with self.cached_session():
num_queues = 10
qlist = list()
qlist = []
for _ in xrange(num_queues):
qlist.append(data_flow_ops.FIFOQueue(10, dtypes_lib.float32))
# Enqueue/Dequeue into a dynamically selected queue

View File

@ -1420,7 +1420,7 @@ class PaddingFIFOQueueTest(test.TestCase):
def testSelectQueue(self):
with self.cached_session():
num_queues = 10
qlist = list()
qlist = []
for _ in xrange(num_queues):
qlist.append(
data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)))

View File

@ -1201,7 +1201,7 @@ class RandomShuffleQueueTest(test.TestCase):
def testSelectQueue(self):
with self.cached_session():
num_queues = 10
qlist = list()
qlist = []
for _ in xrange(num_queues):
qlist.append(
data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32))

View File

@ -104,8 +104,8 @@ class _Mapping(collections.namedtuple(
def _merge_dicts(self, old=None, new=None):
"""Helper to merge two dictionaries."""
old = dict() if old is None else old
new = dict() if new is None else new
old = {} if old is None else old
new = {} if new is None else new
for k, v in six.iteritems(new):
val = old.get(k, None)
if val is not None and val != v:

View File

@ -201,7 +201,7 @@ class RunMetadataTest(test.TestCase):
graph = ops.get_default_graph()
forward_op = set()
backward_op = set()
back_to_forward = dict()
back_to_forward = {}
for op in graph.get_operations():
if op.name.find('gradients/') > 0 and op.name.find('_grad/') > 0:
backward_op.add(op.name)

View File

@ -93,7 +93,7 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True,
op_missing_shape = 0
logged_ops = {}
string_to_id = dict()
string_to_id = {}
string_to_id['none'] = len(string_to_id)
# TODO(xpan): Work with Profiler more efficiently.
for op in graph.get_operations():
@ -169,7 +169,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None,
if not op_log:
tmp_op_log.log_entries.extend(logged_ops.values())
else:
all_ops = dict()
all_ops = {}
for entry in op_log.log_entries:
all_ops[entry.name] = entry
for op_name, entry in six.iteritems(logged_ops):

View File

@ -64,12 +64,12 @@ class SignatureDefUtilsTest(test.TestCase):
def testBuildSignatureDef(self):
x = array_ops.placeholder(dtypes.float32, 1, name="x")
x_tensor_info = utils.build_tensor_info(x)
inputs = dict()
inputs = {}
inputs["foo-input"] = x_tensor_info
y = array_ops.placeholder(dtypes.float32, name="y")
y_tensor_info = utils.build_tensor_info(y)
outputs = dict()
outputs = {}
outputs["foo-output"] = y_tensor_info
signature_def = signature_def_utils_impl.build_signature_def(

View File

@ -662,8 +662,7 @@ class TPUEmbedding(object):
Arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
"""
sample_idcs, embedding_idcs, aggregation_weights, table_ids = (
list(), list(), list(), list())
sample_idcs, embedding_idcs, aggregation_weights, table_ids = [], [], [], []
for table_id, table in enumerate(self._table_to_features_dict):
features = self._table_to_features_dict[table]
for feature in features:

View File

@ -93,7 +93,7 @@ class AutoTrackable(base.Trackable):
def _list_functions_for_serialization(self):
"""Return a dict of `Function`s of a trackable."""
functions = dict()
functions = {}
for attribute_name in dir(self):
try:
attribute_value = getattr(self, attribute_name, None)

View File

@ -104,7 +104,7 @@ def _new_mark_used(self, *args, **kwargs):
pass
_WRAPPERS = dict()
_WRAPPERS = {}
def _get_wrapper(x, tf_should_use_helper):

View File

@ -339,7 +339,7 @@ def get_slice_sets_and_required_args(slice_sets, tag_spec):
def gather_tag_args(slices, cli_input_args, required_args):
"""Build a dictionary of all the CLI and slice-specified args for a tag."""
args = dict()
args = {}
for s in slices:
args = update_args_dict(args, s['args'])
@ -452,7 +452,7 @@ def gather_existing_partials(partial_path):
Dict[string, string] of partial short names (like "ubuntu/python" or
"bazel") to the full contents of that partial.
"""
partials = dict()
partials = {}
for path, _, files in os.walk(partial_path):
for name in files:
fullpath = os.path.join(path, name)

View File

@ -185,7 +185,7 @@ def _third_party_http_archive(ctx):
_apply_patch(ctx, ctx.attr.patch_file)
ctx.symlink(Label(ctx.attr.build_file), buildfile_path)
link_dict = dict()
link_dict = {}
if use_syslib:
link_dict.update(ctx.attr.system_link_files)