Add C0326 bad-whitespace error to pylint sanity check.

PiperOrigin-RevId: 183689499
Author: Yifei Feng (committed by TensorFlower Gardener)
Date: 2018-01-29 10:42:32 -08:00
Parent: 730071d0dc
Commit: fd63d4e30a

11 changed files with 353 additions and 309 deletions
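For context, C0326 was pylint's "bad-whitespace" message: it reports missing
or unexpected spacing around commas, colons, and binary operators, which is
exactly the pattern cleaned up throughout the diffs below. A minimal sketch of
what the check flags (the variable names are illustrative, not from this
commit):

    idx = 3
    vals = [0.1,0.2]    # C0326: exactly one space required after comma
    name = "t%d"%idx    # C0326: missing whitespace around the % operator

    vals = [0.1, 0.2]   # accepted
    name = "t%d" % idx  # accepted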


@@ -31,7 +31,6 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
-
 
 class AccumulateNV2Test(test_util.TensorFlowTestCase):
   """Tests of the new, differentiable version of accumulate_n"""
@@ -62,8 +61,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
       accum_n = av2.accumulate_n_v2(input_vars)
       sess.run(variables.global_variables_initializer())
       accum_n_grad = gradients.gradients(accum_n, input_vars)
-      self.assertAllEqual(np.repeat(1.0, num_inputs),  # d/dx (x + y + ...) = 1
-                          [g.eval() for g in accum_n_grad])
+      self.assertAllEqual(
+          np.repeat(1.0, num_inputs),  # d/dx (x + y + ...) = 1
+          [g.eval() for g in accum_n_grad])
 
   # The tests below used to be in a separate class under cwise_ops_test.py,
   # which did not run in the default test target.
@@ -75,8 +75,8 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
         np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
     ]
     random_tensors = [
-        ops.convert_to_tensor(
-            x, dtype=dtypes_lib.float32) for x in random_arrays
+        ops.convert_to_tensor(x, dtype=dtypes_lib.float32)
+        for x in random_arrays
     ]
     tf_val = av2.accumulate_n_v2(random_tensors)
     np_val = random_arrays[0]
@@ -95,21 +95,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
       with self.assertRaises(ValueError):
         a = variables.Variable(0.2)
         b = variables.Variable(0.1)
-        tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[]
+        tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2])  # Should be shape=[]
 
   def testIncompatibleShapes(self):
     with self.test_session():
       with self.assertRaises(ValueError):
-        a = variables.Variable(np.array([0.1,0.2]))
-        b = variables.Variable(np.array([[0.3],[0.4]]))
-        tf_val = av2.accumulate_n_v2([a,b])
+        a = variables.Variable(np.array([0.1, 0.2]))
+        b = variables.Variable(np.array([[0.3], [0.4]]))
+        tf_val = av2.accumulate_n_v2([a, b])
 
   def testWrongType(self):
     with self.test_session():
       with self.assertRaises(TypeError):
         a = variables.Variable(0.2, dtype=np.float32)
         b = variables.Variable(0.1, dtype=np.float32)
-        tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32)
+        tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32)
 
   def testWrongTypeOneInput(self):
     # Scenario that used to trigger a bug, even when testWrongType() worked


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Base utilities for loading datasets."""
 
 from __future__ import absolute_import
@@ -100,9 +99,7 @@ def load_iris(data_path=None):
     module_path = path.dirname(__file__)
     data_path = path.join(module_path, 'data', 'iris.csv')
   return load_csv_with_header(
-      data_path,
-      target_dtype=np.int,
-      features_dtype=np.float)
+      data_path, target_dtype=np.int, features_dtype=np.float)
 
 
 def load_boston(data_path=None):
@@ -118,16 +115,10 @@ def load_boston(data_path=None):
     module_path = path.dirname(__file__)
     data_path = path.join(module_path, 'data', 'boston_house_prices.csv')
   return load_csv_with_header(
-      data_path,
-      target_dtype=np.float,
-      features_dtype=np.float)
+      data_path, target_dtype=np.float, features_dtype=np.float)
 
 
-def retry(initial_delay,
-          max_delay,
-          factor=2.0,
-          jitter=0.25,
-          is_retriable=None):
+def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None):
   """Simple decorator for wrapping retriable functions.
 
   Args:
@@ -152,7 +143,7 @@ def retry(initial_delay,
   def delays():
     delay = initial_delay
     while delay <= max_delay:
       yield delay * random.uniform(1 - jitter, 1 + jitter)
       delay *= factor
 
   def wrap(fn):
@@ -172,7 +163,9 @@ def retry(initial_delay,
         else:
           raise
       return fn(*args, **kwargs)
+
     return wrapped_fn
+
   return wrap
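
As a side note, the delays() generator inside the retry decorator above
produces an exponential backoff schedule with multiplicative jitter, stopping
once the delay would exceed max_delay. A standalone sketch of the same
schedule (identical logic, lifted out of the decorator for illustration):

    import random

    def delays(initial_delay, max_delay, factor=2.0, jitter=0.25):
      # Same schedule as the generator inside retry() above.
      delay = initial_delay
      while delay <= max_delay:
        yield delay * random.uniform(1 - jitter, 1 + jitter)
        delay *= factor

    # initial_delay=1.0, max_delay=16.0 yields roughly 1, 2, 4, 8, 16 seconds,
    # each perturbed by up to +/-25%.
    print([round(d, 2) for d in delays(1.0, 16.0)])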


@@ -198,10 +198,13 @@ class TensorMapper(object):
 def GenerateGraph(subgraph_idx, g, opcode_mapper):
   """Produces the HTML required to have a d3 visualization of the dag."""
+
   def TensorName(idx):
-    return "t%d"%idx
+    return "t%d" % idx
+
   def OpName(idx):
-    return "o%d"%idx
+    return "o%d" % idx
+
   edges = []
   nodes = []
   first = {}
@@ -210,27 +213,35 @@ def GenerateGraph(subgraph_idx, g, opcode_mapper):
     for tensor_input_position, tensor_index in enumerate(op["inputs"]):
       if tensor_index not in first:
         first[tensor_index] = (
-            op_index*pixel_mult,
-            tensor_input_position*pixel_mult - pixel_mult/2)
-      edges.append(
-          {"source": TensorName(tensor_index), "target": OpName(op_index)})
+            op_index * pixel_mult,
+            tensor_input_position * pixel_mult - pixel_mult / 2)
+      edges.append({
+          "source": TensorName(tensor_index),
+          "target": OpName(op_index)
+      })
     for tensor_index in op["outputs"]:
-      edges.append(
-          {"target": TensorName(tensor_index), "source": OpName(op_index)})
-    nodes.append({"id": OpName(op_index),
-                  "name": opcode_mapper(op["opcode_index"]),
-                  "group": 2,
-                  "x": pixel_mult,
-                  "y": op_index * pixel_mult})
+      edges.append({
+          "target": TensorName(tensor_index),
+          "source": OpName(op_index)
+      })
+    nodes.append({
+        "id": OpName(op_index),
+        "name": opcode_mapper(op["opcode_index"]),
+        "group": 2,
+        "x": pixel_mult,
+        "y": op_index * pixel_mult
+    })
   for tensor_index, tensor in enumerate(g["tensors"]):
-    initial_y = (first[tensor_index] if tensor_index in first
-                 else len(g["operators"]))
-    nodes.append({"id": TensorName(tensor_index),
-                  "name": "%s (%d)" % (tensor["name"], tensor_index),
-                  "group": 1,
-                  "x": 2,
-                  "y": initial_y})
+    initial_y = (
+        first[tensor_index] if tensor_index in first else len(g["operators"]))
+    nodes.append({
+        "id": TensorName(tensor_index),
+        "name": "%s (%d)" % (tensor["name"], tensor_index),
+        "group": 1,
+        "x": 2,
+        "y": initial_y
+    })
 
   graph_str = json.dumps({"nodes": nodes, "edges": edges})
   html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
@@ -267,7 +278,7 @@ def GenerateTableHtml(items, keys_to_print, display_index=True):
     for h, mapper in keys_to_print:
       val = tensor[h] if h in tensor else None
       val = val if mapper is None else mapper(val)
-      html += "<td>%s</td>\n"%val
+      html += "<td>%s</td>\n" % val
     html += "</tr>\n"
   html += "</table>\n"
@@ -279,18 +290,19 @@ def CreateHtmlFile(tflite_input, html_output):
   # Convert the model into a JSON flatbuffer using flatc (build if doesn't
   # exist.
   if not os.path.exists(tflite_input):
     raise RuntimeError("Invalid filename %r" % tflite_input)
 
   if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"):
     # Run convert
-    cmd = (_BINARY + " -t "
-           "--strict-json --defaults-json -o /tmp {schema} -- {input}".format(
-               input=tflite_input, schema=_SCHEMA))
+    cmd = (
+        _BINARY + " -t "
+        "--strict-json --defaults-json -o /tmp {schema} -- {input}".format(
+            input=tflite_input, schema=_SCHEMA))
     print(cmd)
     os.system(cmd)
-    real_output = ("/tmp/"+ os.path.splitext(os.path.split(tflite_input)[-1])[0]
-                   + ".json")
+    real_output = ("/tmp/" + os.path.splitext(
+        os.path.split(tflite_input)[-1])[0] + ".json")
 
     data = json.load(open(real_output))
   elif tflite_input.endswith(".json"):
@@ -302,12 +314,13 @@ def CreateHtmlFile(tflite_input, html_output):
   html += "<h1>TensorFlow Lite Model</h2>"
 
   data["filename"] = tflite_input  # Avoid special case
-  toplevel_stuff = [("filename", None), ("version", None),
-                    ("description", None)]
+  toplevel_stuff = [("filename", None), ("version", None), ("description",
+                                                             None)]
 
   html += "<table>\n"
   for key, mapping in toplevel_stuff:
-    if not mapping: mapping = lambda x: x
+    if not mapping:
+      mapping = lambda x: x
     html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data[key]))
   html += "</table>\n"
@@ -320,22 +333,22 @@ def CreateHtmlFile(tflite_input, html_output):
     html += "<div class='subgraph'>"
     tensor_mapper = TensorMapper(g)
    opcode_mapper = OpCodeMapper(data)
-    op_keys_to_display = [
-        ("inputs", tensor_mapper), ("outputs", tensor_mapper),
-        ("builtin_options", None), ("opcode_index", opcode_mapper)]
-    tensor_keys_to_display = [
-        ("name", None), ("type", None), ("shape", None), ("buffer", None),
-        ("quantization", None)]
+    op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
+                          ("builtin_options", None), ("opcode_index",
+                                                      opcode_mapper)]
+    tensor_keys_to_display = [("name", None), ("type", None), ("shape", None),
+                              ("buffer", None), ("quantization", None)]
 
     html += "<h2>Subgraph %d</h2>\n" % subgraph_idx
 
     # Inputs and outputs.
     html += "<h3>Inputs/Outputs</h3>\n"
-    html += GenerateTableHtml([{"inputs": g["inputs"],
-                                "outputs": g["outputs"]}],
-                              [("inputs", tensor_mapper),
-                               ("outputs", tensor_mapper)],
-                              display_index=False)
+    html += GenerateTableHtml(
+        [{
+            "inputs": g["inputs"],
+            "outputs": g["outputs"]
+        }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
+        display_index=False)
 
     # Print the tensors.
     html += "<h3>Tensors</h3>\n"
@@ -357,8 +370,7 @@ def CreateHtmlFile(tflite_input, html_output):
   # Operator codes
   html += "<h2>Operator Codes</h2>\n"
-  html += GenerateTableHtml(data["operator_codes"],
-                            operator_keys_to_display)
+  html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)
 
   html += "</body></html>\n"
@@ -370,10 +382,10 @@ def main(argv):
     tflite_input = argv[1]
     html_output = argv[2]
   except IndexError:
-    print ("Usage: %s <input tflite> <output html>" % (argv[0]))
+    print("Usage: %s <input tflite> <output html>" % (argv[0]))
   else:
     CreateHtmlFile(tflite_input, html_output)
 
 
 if __name__ == "__main__":
   main(sys.argv)
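
For reference, the node and edge dictionaries assembled in GenerateGraph above
serialize to JSON of the following shape for the d3 template. The opcode name
and pixel coordinates below are made-up values; only the keys and the
"t%d"/"o%d" naming come from the code:

    import json

    nodes = [
        {"id": "t0", "name": "input (0)", "group": 1, "x": 2, "y": 0},
        {"id": "o0", "name": "CONV_2D", "group": 2, "x": 100, "y": 0},
    ]
    edges = [{"source": "t0", "target": "o0"}]
    print(json.dumps({"nodes": nodes, "edges": edges}))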


@@ -58,6 +58,7 @@ def read_cifar10(filename_queue):
   class CIFAR10Record(object):
     pass
+
   result = CIFAR10Record()
 
   # Dimensions of the images in the CIFAR-10 dataset.
@@ -147,8 +148,9 @@ def distorted_inputs(data_dir, batch_size):
     images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
     labels: Labels. 1D tensor of [batch_size] size.
   """
-  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
-               for i in xrange(1, 6)]
+  filenames = [
+      os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)
+  ]
   for f in filenames:
     if not tf.gfile.Exists(f):
       raise ValueError('Failed to find file: ' + f)
@@ -174,10 +176,9 @@ def distorted_inputs(data_dir, batch_size):
   # Because these operations are not commutative, consider randomizing
   # the order their operation.
-  distorted_image = tf.image.random_brightness(distorted_image,
-                                               max_delta=63)
-  distorted_image = tf.image.random_contrast(distorted_image,
-                                             lower=0.2, upper=1.8)
+  distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
+  distorted_image = tf.image.random_contrast(
+      distorted_image, lower=0.2, upper=1.8)
 
   # Subtract off the mean and divide by the variance of the pixels.
   float_image = tf.image.per_image_standardization(distorted_image)
@@ -188,15 +189,18 @@ def distorted_inputs(data_dir, batch_size):
   # Ensure that the random shuffling has good mixing properties.
   min_fraction_of_examples_in_queue = 0.4
-  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
-                           min_fraction_of_examples_in_queue)
-  print ('Filling queue with %d CIFAR images before starting to train. '
-         'This will take a few minutes.' % min_queue_examples)
+  min_queue_examples = int(
+      NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue)
+  print('Filling queue with %d CIFAR images before starting to train. '
+        'This will take a few minutes.' % min_queue_examples)
 
   # Generate a batch of images and labels by building up a queue of examples.
-  return _generate_image_and_label_batch(float_image, read_input.label,
-                                         min_queue_examples, batch_size,
-                                         shuffle=True)
+  return _generate_image_and_label_batch(
+      float_image,
+      read_input.label,
+      min_queue_examples,
+      batch_size,
+      shuffle=True)
 
 
 def inputs(eval_data, data_dir, batch_size):
@@ -212,8 +216,9 @@ def inputs(eval_data, data_dir, batch_size):
     labels: Labels. 1D tensor of [batch_size] size.
   """
   if not eval_data:
-    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
-                 for i in xrange(1, 6)]
+    filenames = [
+        os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)
+    ]
     num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
   else:
     filenames = [os.path.join(data_dir, 'test_batch.bin')]
@@ -235,8 +240,8 @@ def inputs(eval_data, data_dir, batch_size):
   # Image processing for evaluation.
   # Crop the central [height, width] of the image.
-  resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
-                                                         width, height)
+  resized_image = tf.image.resize_image_with_crop_or_pad(
+      reshaped_image, width, height)
 
   # Subtract off the mean and divide by the variance of the pixels.
   float_image = tf.image.per_image_standardization(resized_image)
@@ -247,10 +252,13 @@ def inputs(eval_data, data_dir, batch_size):
   # Ensure that the random shuffling has good mixing properties.
   min_fraction_of_examples_in_queue = 0.4
-  min_queue_examples = int(num_examples_per_epoch *
-                           min_fraction_of_examples_in_queue)
+  min_queue_examples = int(
+      num_examples_per_epoch * min_fraction_of_examples_in_queue)
 
   # Generate a batch of images and labels by building up a queue of examples.
-  return _generate_image_and_label_batch(float_image, read_input.label,
-                                         min_queue_examples, batch_size,
-                                         shuffle=False)
+  return _generate_image_and_label_batch(
+      float_image,
+      read_input.label,
+      min_queue_examples,
+      batch_size,
+      shuffle=False)
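
As a worked example of the queue sizing above: the CIFAR-10 training set has
50,000 images, so assuming NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000, the
shuffle queue is primed with at least int(50000 * 0.4) = 20000 examples
before training starts.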


@@ -42,7 +42,6 @@ from tensorflow.python.platform import test
 from tensorflow.python.framework import test_util
 from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
 # pylint: enable=protected-access
-
 Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name
@@ -84,19 +83,22 @@ class RNNCellTest(test.TestCase):
         ], [v.name for v in cell.trainable_variables])
         self.assertFalse(cell.non_trainable_variables)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(
-            [g], {x.name: np.array([[1., 1.]]),
-                  m.name: np.array([[0.1, 0.1]])})
+        res = sess.run([g], {
+            x.name: np.array([[1., 1.]]),
+            m.name: np.array([[0.1, 0.1]])
+        })
         self.assertEqual(res[0].shape, (1, 2))
 
   def testBasicRNNCellNotTrainable(self):
     with self.test_session() as sess:
+
       def not_trainable_getter(getter, *args, **kwargs):
         kwargs["trainable"] = False
         return getter(*args, **kwargs)
 
       with variable_scope.variable_scope(
-          "root", initializer=init_ops.constant_initializer(0.5),
+          "root",
+          initializer=init_ops.constant_initializer(0.5),
           custom_getter=not_trainable_getter):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 2])
@@ -108,9 +110,10 @@ class RNNCellTest(test.TestCase):
           "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
         ], [v.name for v in cell.non_trainable_variables])
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(
-            [g], {x.name: np.array([[1., 1.]]),
-                  m.name: np.array([[0.1, 0.1]])})
+        res = sess.run([g], {
+            x.name: np.array([[1., 1.]]),
+            m.name: np.array([[0.1, 0.1]])
+        })
         self.assertEqual(res[0].shape, (1, 2))
 
   def testGRUCell(self):
@@ -121,9 +124,10 @@ class RNNCellTest(test.TestCase):
         m = array_ops.zeros([1, 2])
         g, _ = rnn_cell_impl.GRUCell(2)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(
-            [g], {x.name: np.array([[1., 1.]]),
-                  m.name: np.array([[0.1, 0.1]])})
+        res = sess.run([g], {
+            x.name: np.array([[1., 1.]]),
+            m.name: np.array([[0.1, 0.1]])
+        })
         # Smoke test
         self.assertAllClose(res[0], [[0.175991, 0.175991]])
       with variable_scope.variable_scope(
@@ -133,10 +137,10 @@ class RNNCellTest(test.TestCase):
         m = array_ops.zeros([1, 2])
         g, _ = rnn_cell_impl.GRUCell(2)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(
-            [g],
-            {x.name: np.array([[1., 1., 1.]]),
-             m.name: np.array([[0.1, 0.1]])})
+        res = sess.run([g], {
+            x.name: np.array([[1., 1., 1.]]),
+            m.name: np.array([[0.1, 0.1]])
+        })
         # Smoke test
         self.assertAllClose(res[0], [[0.156736, 0.156736]])
@@ -148,11 +152,12 @@ class RNNCellTest(test.TestCase):
         m = array_ops.zeros([1, 2])
         g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(
-            [g], {x.name: np.array([[1., 1.]]),
-                  m.name: np.array([[0.1, 0.1]])})
+        res = sess.run([g], {
+            x.name: np.array([[1., 1.]]),
+            m.name: np.array([[0.1, 0.1]])
+        })
         # Smoke test
         self.assertAllClose(res[0], [[0.509682, 0.509682]])
 
   def testBasicLSTMCell(self):
     for dtype in [dtypes.float16, dtypes.float32]:
@@ -164,8 +169,7 @@ class RNNCellTest(test.TestCase):
         m = array_ops.zeros([1, 8], dtype=dtype)
         cell = rnn_cell_impl.MultiRNNCell(
             [
-                rnn_cell_impl.BasicLSTMCell(
-                    2, state_is_tuple=False)
+                rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
                 for _ in range(2)
             ],
             state_is_tuple=False)
@@ -183,22 +187,21 @@ class RNNCellTest(test.TestCase):
             "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
             rnn_cell_impl._BIAS_VARIABLE_NAME
         ]
-        self.assertEqual(
-            expected_variable_names,
-            [v.name for v in cell.trainable_variables])
+        self.assertEqual(expected_variable_names,
+                         [v.name for v in cell.trainable_variables])
         self.assertFalse(cell.non_trainable_variables)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(
-            [g, out_m],
-            {x.name: np.array([[1., 1.]]),
-             m.name: 0.1 * np.ones([1, 8])})
+        res = sess.run([g, out_m], {
+            x.name: np.array([[1., 1.]]),
+            m.name: 0.1 * np.ones([1, 8])
+        })
         self.assertEqual(len(res), 2)
         variables = variables_lib.global_variables()
         self.assertEqual(expected_variable_names, [v.name for v in variables])
         # The numbers in results were not calculated, this is just a
         # smoke test.
-        self.assertAllClose(
-            res[0], np.array([[0.240, 0.240]], dtype=np_dtype), 1e-2)
+        self.assertAllClose(res[0], np.array(
+            [[0.240, 0.240]], dtype=np_dtype), 1e-2)
         expected_mem = np.array(
             [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]],
             dtype=np_dtype)
@@ -208,13 +211,13 @@ class RNNCellTest(test.TestCase):
         # Test BasicLSTMCell with input_size != num_units.
         x = array_ops.zeros([1, 3], dtype=dtype)
         m = array_ops.zeros([1, 4], dtype=dtype)
-        g, out_m = rnn_cell_impl.BasicLSTMCell(
-            2, state_is_tuple=False)(x, m)
+        g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
-            [g, out_m],
-            {x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
-             m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)})
+            [g, out_m], {
+                x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
+                m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)
+            })
         self.assertEqual(len(res), 2)
 
   def testBasicLSTMCellDimension0Error(self):
@@ -232,9 +235,11 @@ class RNNCellTest(test.TestCase):
         g, out_m = rnn_cell_impl.BasicLSTMCell(
             num_units, state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
-        sess.run([g, out_m],
-                 {x.name: 1 * np.ones([batch_size, input_size]),
-                  m.name: 0.1 * np.ones([batch_size - 1, state_size])})
+        sess.run(
+            [g, out_m], {
+                x.name: 1 * np.ones([batch_size, input_size]),
+                m.name: 0.1 * np.ones([batch_size - 1, state_size])
+            })
 
   def testBasicLSTMCellStateSizeError(self):
     """Tests that state_size must be num_units * 2."""
@@ -251,9 +256,11 @@ class RNNCellTest(test.TestCase):
         g, out_m = rnn_cell_impl.BasicLSTMCell(
            num_units, state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
-        sess.run([g, out_m],
-                 {x.name: 1 * np.ones([batch_size, input_size]),
-                  m.name: 0.1 * np.ones([batch_size, state_size])})
+        sess.run(
+            [g, out_m], {
+                x.name: 1 * np.ones([batch_size, input_size]),
+                m.name: 0.1 * np.ones([batch_size, state_size])
+            })
 
   def testBasicLSTMCellStateTupleType(self):
     with self.test_session():
@@ -301,11 +308,12 @@ class RNNCellTest(test.TestCase):
             state_is_tuple=True)
         g, (out_m0, out_m1) = cell(x, (m0, m1))
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run([g, out_m0, out_m1], {
-            x.name: np.array([[1., 1.]]),
-            m0.name: 0.1 * np.ones([1, 4]),
-            m1.name: 0.1 * np.ones([1, 4])
-        })
+        res = sess.run(
+            [g, out_m0, out_m1], {
+                x.name: np.array([[1., 1.]]),
+                m0.name: 0.1 * np.ones([1, 4]),
+                m1.name: 0.1 * np.ones([1, 4])
+            })
         self.assertEqual(len(res), 3)
         # The numbers in results were not calculated, this is just a smoke test.
         # Note, however, these values should match the original
@@ -336,10 +344,11 @@ class RNNCellTest(test.TestCase):
             state_is_tuple=False)
         output, state = cell(x, m)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run([output, state], {
-            x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
-            m.name: 0.1 * np.ones((batch_size, state_size))
-        })
+        res = sess.run(
+            [output, state], {
+                x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
+                m.name: 0.1 * np.ones((batch_size, state_size))
+            })
         self.assertEqual(len(res), 2)
         # The numbers in results were not calculated, this is mostly just a
         # smoke test.
@@ -442,10 +451,10 @@ class RNNCellTest(test.TestCase):
             rnn_cell_impl.GRUCell(3), num_proj=3)
         g, new_m = cell(x, m)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(
-            [g, new_m],
-            {x.name: np.array([[1., 1.]]),
-             m.name: np.array([[0.1, 0.1, 0.1]])})
+        res = sess.run([g, new_m], {
+            x.name: np.array([[1., 1.]]),
+            m.name: np.array([[0.1, 0.1, 0.1]])
+        })
         self.assertEqual(res[1].shape, (1, 3))
         # The numbers in results were not calculated, this is just a smoke test.
         self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
@@ -479,9 +488,11 @@ class RNNCellTest(test.TestCase):
         base_cell = rnn_cell_impl.GRUCell(3)
         g, m_new = base_cell(x, m)
         variable_scope.get_variable_scope().reuse_variables()
+
         def residual_with_slice_fn(inp, out):
           inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
           return inp_sliced + out
+
         g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
             base_cell, residual_with_slice_fn)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
@@ -551,10 +562,10 @@ class RNNCellTest(test.TestCase):
         self.assertEqual(embedding_cell.output_size, 2)
         g, new_m = embedding_cell(x, m)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(
-            [g, new_m],
-            {x.name: np.array([[1]]),
-             m.name: np.array([[0.1, 0.1]])})
+        res = sess.run([g, new_m], {
+            x.name: np.array([[1]]),
+            m.name: np.array([[0.1, 0.1]])
+        })
         self.assertEqual(res[1].shape, (1, 2))
         # The numbers in results were not calculated, this is just a smoke test.
         self.assertAllClose(res[0], [[0.17139, 0.17139]])
@@ -584,8 +595,8 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 4])
         _, ml = rnn_cell_impl.MultiRNNCell(
-            [rnn_cell_impl.GRUCell(2)
-             for _ in range(2)], state_is_tuple=False)(x, m)
+            [rnn_cell_impl.GRUCell(2) for _ in range(2)],
+            state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
             x.name: np.array([[1., 1.]]),
@@ -605,19 +616,20 @@ class RNNCellTest(test.TestCase):
         # Test incorrectness of state
         with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
           rnn_cell_impl.MultiRNNCell(
-              [rnn_cell_impl.GRUCell(2)
-               for _ in range(2)], state_is_tuple=True)(x, m_bad)
+              [rnn_cell_impl.GRUCell(2) for _ in range(2)],
+              state_is_tuple=True)(x, m_bad)
 
         _, ml = rnn_cell_impl.MultiRNNCell(
-            [rnn_cell_impl.GRUCell(2)
-             for _ in range(2)], state_is_tuple=True)(x, m_good)
+            [rnn_cell_impl.GRUCell(2) for _ in range(2)],
+            state_is_tuple=True)(x, m_good)
         sess.run([variables_lib.global_variables_initializer()])
-        res = sess.run(ml, {
-            x.name: np.array([[1., 1.]]),
-            m_good[0].name: np.array([[0.1, 0.1]]),
-            m_good[1].name: np.array([[0.1, 0.1]])
-        })
+        res = sess.run(
+            ml, {
+                x.name: np.array([[1., 1.]]),
+                m_good[0].name: np.array([[0.1, 0.1]]),
+                m_good[1].name: np.array([[0.1, 0.1]])
+            })
 
         # The numbers in results were not calculated, this is just a
         # smoke test.  However, these numbers should match those of
@@ -628,8 +640,11 @@ class RNNCellTest(test.TestCase):
 class DropoutWrapperTest(test.TestCase):
 
-  def _testDropoutWrapper(self, batch_size=None, time_steps=None,
-                          parallel_iterations=None, **kwargs):
+  def _testDropoutWrapper(self,
+                          batch_size=None,
+                          time_steps=None,
+                          parallel_iterations=None,
+                          **kwargs):
     with self.test_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
@@ -640,14 +655,14 @@ class DropoutWrapperTest(test.TestCase):
           x = constant_op.constant(
              [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
           m = rnn_cell_impl.LSTMStateTuple(
-              *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)
-              ] * 2)
+              *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32
+                                    )] * 2)
         else:
           x = constant_op.constant(
              np.random.randn(time_steps, batch_size, 3).astype(np.float32))
           m = rnn_cell_impl.LSTMStateTuple(*[
-              constant_op.constant(
-                  [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
+              constant_op.
+              constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
           ] * 2)
         outputs, final_state = rnn.dynamic_rnn(
             cell=rnn_cell_impl.DropoutWrapper(
@@ -674,8 +689,8 @@ class DropoutWrapperTest(test.TestCase):
     res = self._testDropoutWrapper(
         input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
     true_full_output = np.array(
-        [[[0.751109, 0.751109, 0.751109]],
-         [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+        dtype=np.float32)
     true_full_final_c = np.array(
         [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
     self.assertAllClose(true_full_output, res[0])
@@ -687,8 +702,8 @@ class DropoutWrapperTest(test.TestCase):
     res = self._testDropoutWrapper(
         input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
     true_full_output = np.array(
-        [[[0.751109, 0.751109, 0.751109]],
-         [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+        dtype=np.float32)
     true_full_final_c = np.array(
         [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
     self.assertAllClose(true_full_output, res[0])
@@ -703,16 +718,20 @@ class DropoutWrapperTest(test.TestCase):
     ## consistent across both calls.  Otherwise the seed may not end
     ## up being munged consistently across both graphs.
     res_standard_1 = self._testDropoutWrapper(
-        input_keep_prob=keep_some, output_keep_prob=keep_some,
-        state_keep_prob=keep_some, seed=10,
+        input_keep_prob=keep_some,
+        output_keep_prob=keep_some,
+        state_keep_prob=keep_some,
+        seed=10,
         parallel_iterations=1)
     # Clear away the graph and the test session (which keeps variables around)
     ops.reset_default_graph()
     self._ClearCachedSession()
     random_seed.set_random_seed(2)
     res_standard_2 = self._testDropoutWrapper(
-        input_keep_prob=keep_some, output_keep_prob=keep_some,
-        state_keep_prob=keep_some, seed=10,
+        input_keep_prob=keep_some,
+        output_keep_prob=keep_some,
+        state_keep_prob=keep_some,
+        seed=10,
         parallel_iterations=1)
     self.assertAllClose(res_standard_1[0], res_standard_2[0])
     self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c)
@@ -722,11 +741,12 @@ class DropoutWrapperTest(test.TestCase):
     keep_all = variable_scope.get_variable("all", initializer=1.0)
     keep_none = variable_scope.get_variable("none", initializer=1e-10)
     res = self._testDropoutWrapper(
-        input_keep_prob=keep_all, output_keep_prob=keep_none,
+        input_keep_prob=keep_all,
+        output_keep_prob=keep_none,
         state_keep_prob=keep_all)
     true_full_output = np.array(
-        [[[0.751109, 0.751109, 0.751109]],
-         [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+        dtype=np.float32)
     true_full_final_c = np.array(
         [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
     self.assertAllClose(np.zeros(res[0].shape), res[0])
@@ -739,13 +759,13 @@ class DropoutWrapperTest(test.TestCase):
     # Even though we dropout state, by default DropoutWrapper never
     # drops out the memory ("c") term of an LSTMStateTuple.
     res = self._testDropoutWrapper(
-        input_keep_prob=keep_all, output_keep_prob=keep_all,
+        input_keep_prob=keep_all,
+        output_keep_prob=keep_all,
         state_keep_prob=keep_none)
-    true_c_state = np.array(
-        [[1.713925, 1.713925, 1.713925]], dtype=np.float32)
+    true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32)
     true_full_output = np.array(
-        [[[0.751109, 0.751109, 0.751109]],
-         [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+        dtype=np.float32)
     self.assertAllClose(true_full_output[0], res[0][0])
     # Second output is modified by zero input state
     self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4)
@@ -758,13 +778,14 @@ class DropoutWrapperTest(test.TestCase):
     keep_all = variable_scope.get_variable("all", initializer=1.0)
     keep_none = variable_scope.get_variable("none", initializer=1e-10)
     true_full_output = np.array(
-        [[[0.751109, 0.751109, 0.751109]],
-         [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+        dtype=np.float32)
     true_full_final_c = np.array(
         [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
     # All outputs are different because inputs are zeroed out
     res = self._testDropoutWrapper(
-        input_keep_prob=keep_none, output_keep_prob=keep_all,
+        input_keep_prob=keep_none,
+        output_keep_prob=keep_all,
         state_keep_prob=keep_all)
     self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4)
     self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4)
@@ -774,9 +795,13 @@ class DropoutWrapperTest(test.TestCase):
     keep_some = 0.8
     keep_all = variable_scope.get_variable("all", initializer=1.0)
     res = self._testDropoutWrapper(
-        input_keep_prob=keep_all, output_keep_prob=keep_some,
-        state_keep_prob=keep_all, variational_recurrent=True,
-        input_size=3, batch_size=5, time_steps=7)
+        input_keep_prob=keep_all,
+        output_keep_prob=keep_some,
+        state_keep_prob=keep_all,
+        variational_recurrent=True,
+        input_size=3,
+        batch_size=5,
+        time_steps=7)
     # Ensure the same dropout pattern for all time steps
     output_mask = np.abs(res[0]) > 1e-6
     for m in output_mask[1:]:
@@ -785,9 +810,13 @@ class DropoutWrapperTest(test.TestCase):
   def testDropoutWrapperRecurrentStateInputAndOutput(self):
     keep_some = 0.9
     res = self._testDropoutWrapper(
-        input_keep_prob=keep_some, output_keep_prob=keep_some,
-        state_keep_prob=keep_some, variational_recurrent=True,
-        input_size=3, batch_size=5, time_steps=7)
+        input_keep_prob=keep_some,
+        output_keep_prob=keep_some,
+        state_keep_prob=keep_some,
+        variational_recurrent=True,
+        input_size=3,
+        batch_size=5,
+        time_steps=7)
 
     # Smoke test for the state/input masks.
     output_mask = np.abs(res[0]) > 1e-6
@@ -811,17 +840,27 @@ class DropoutWrapperTest(test.TestCase):
     random_seed.set_random_seed(2347)
     np.random.seed(23487)
     res0 = self._testDropoutWrapper(
-        input_keep_prob=keep_some, output_keep_prob=keep_some,
-        state_keep_prob=keep_some, variational_recurrent=True,
-        input_size=3, batch_size=5, time_steps=7, seed=-234987)
+        input_keep_prob=keep_some,
+        output_keep_prob=keep_some,
+        state_keep_prob=keep_some,
+        variational_recurrent=True,
+        input_size=3,
+        batch_size=5,
+        time_steps=7,
+        seed=-234987)
 
     ops.reset_default_graph()
     self._ClearCachedSession()
     random_seed.set_random_seed(2347)
     np.random.seed(23487)
     res1 = self._testDropoutWrapper(
-        input_keep_prob=keep_some, output_keep_prob=keep_some,
-        state_keep_prob=keep_some, variational_recurrent=True,
-        input_size=3, batch_size=5, time_steps=7, seed=-234987)
+        input_keep_prob=keep_some,
+        output_keep_prob=keep_some,
+        state_keep_prob=keep_some,
+        variational_recurrent=True,
+        input_size=3,
+        batch_size=5,
+        time_steps=7,
+        seed=-234987)
 
     output_mask = np.abs(res0[0]) > 1e-6
     for time_step in output_mask:
@@ -858,9 +897,10 @@ class SlimRNNCellTest(test.TestCase):
       g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
       # pylint: enable=protected-access
       sess.run([variables_lib.global_variables_initializer()])
-      res = sess.run(
-          [g], {x.name: np.array([[1., 1.]]),
-                m.name: np.array([[0.1, 0.1]])})
+      res = sess.run([g], {
+          x.name: np.array([[1., 1.]]),
+          m.name: np.array([[0.1, 0.1]])
+      })
       self.assertEqual(res[0].shape, (1, 2))
 
   def testBasicRNNCellMatch(self):


@@ -34,8 +34,7 @@ MAX_LABEL = 15
 WORDS_FEATURE = 'words'  # Name of the input words feature.
 
 
-def estimator_spec_for_softmax_classification(
-    logits, labels, mode):
+def estimator_spec_for_softmax_classification(logits, labels, mode):
   """Returns EstimatorSpec instance for softmax classification."""
   predicted_classes = tf.argmax(logits, 1)
   if mode == tf.estimator.ModeKeys.PREDICT:
@@ -53,8 +52,8 @@ def estimator_spec_for_softmax_classification(
     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
 
   eval_metric_ops = {
-      'accuracy': tf.metrics.accuracy(
-          labels=labels, predictions=predicted_classes)
+      'accuracy':
+          tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
   }
   return tf.estimator.EstimatorSpec(
       mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
@@ -67,8 +66,7 @@ def bag_of_words_model(features, labels, mode):
   bow_embedding_column = tf.feature_column.embedding_column(
       bow_column, dimension=EMBEDDING_SIZE)
   bow = tf.feature_column.input_layer(
-      features,
-      feature_columns=[bow_embedding_column])
+      features, feature_columns=[bow_embedding_column])
   logits = tf.layers.dense(bow, MAX_LABEL, activation=None)
 
   return estimator_spec_for_softmax_classification(
@@ -110,9 +108,9 @@ def main(unused_argv):
   # Prepare training and testing data
   dbpedia = tf.contrib.learn.datasets.load_dataset(
       'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)
-  x_train = pandas.Series(dbpedia.train.data[:,1])
+  x_train = pandas.Series(dbpedia.train.data[:, 1])
   y_train = pandas.Series(dbpedia.train.target)
-  x_test = pandas.Series(dbpedia.test.data[:,1])
+  x_test = pandas.Series(dbpedia.test.data[:, 1])
   y_test = pandas.Series(dbpedia.test.target)
 
   # Process vocabulary
@@ -152,10 +150,7 @@
   # Predict.
   test_input_fn = tf.estimator.inputs.numpy_input_fn(
-      x={WORDS_FEATURE: x_test},
-      y=y_test,
-      num_epochs=1,
-      shuffle=False)
+      x={WORDS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
   predictions = classifier.predict(input_fn=test_input_fn)
   y_predicted = np.array(list(p['class'] for p in predictions))
   y_predicted = y_predicted.reshape(np.array(y_test).shape)


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Notebook front-end to TensorFlow.
 
 When you run this binary, you'll see something like below, which indicates
@@ -43,10 +42,8 @@ from tensorflow.python.platform import app
 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
 
 FLAGS = None
-
 ORIG_ARGV = sys.argv
-
 # Main notebook process calls itself with argv[1]="kernel" to start kernel
 # subprocesses.
@@ -73,8 +70,8 @@
     notebookapp.ip = "0.0.0.0"
     notebookapp.password = passwd(FLAGS.password)
   else:
-    print ("\nNo password specified; Notebook server will only be available"
-           " on the local machine.\n")
+    print("\nNo password specified; Notebook server will only be available"
+          " on the local machine.\n")
   notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir])
 
   if notebookapp.ip == "0.0.0.0":
@@ -125,8 +122,8 @@ if __name__ == "__main__":
   # kernel app.
   if IS_KERNEL:
     # Drop everything except --flagfile.
-    sys.argv = ([sys.argv[0]] +
-                [x for x in sys.argv[1:] if x.startswith("--flagfile")])
+    sys.argv = (
+        [sys.argv[0]] + [x for x in sys.argv[1:] if x.startswith("--flagfile")])
   FLAGS, unparsed = parser.parse_known_args()
   app.run(main=main, argv=[sys.argv[0]] + unparsed)


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Classes and functions related to train_and_evaluate."""
 
 from __future__ import absolute_import
@@ -37,7 +36,6 @@ from tensorflow.python.training import server_lib
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.util import compat
 
-
 _MAX_DELAY_SECS = 60
 _DELAY_SECS_PER_WORKER = 5
 _TF_CONFIG_ENV = 'TF_CONFIG'
@ -50,8 +48,7 @@ _TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER,
def _validate_input_fn(input_fn): def _validate_input_fn(input_fn):
"""Validates the `input_fn`.""" """Validates the `input_fn`."""
if not callable(input_fn): if not callable(input_fn):
raise TypeError( raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn))
'`input_fn` must be callable, given: {}'.format(input_fn))
def _validate_hooks(hooks): def _validate_hooks(hooks):
@ -125,10 +122,7 @@ class TrainSpec(
duration. Optional hooks run at various stages of training. duration. Optional hooks run at various stages of training.
""" """
def __new__(cls, def __new__(cls, input_fn, max_steps=None, hooks=None):
input_fn,
max_steps=None,
hooks=None):
"""Creates a validated `TrainSpec` instance. """Creates a validated `TrainSpec` instance.
Args: Args:
@ -161,16 +155,13 @@ class TrainSpec(
hooks = _validate_hooks(hooks) hooks = _validate_hooks(hooks)
return super(TrainSpec, cls).__new__( return super(TrainSpec, cls).__new__(
cls, cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
input_fn=input_fn,
max_steps=max_steps,
hooks=hooks)
class EvalSpec( class EvalSpec(
collections.namedtuple('EvalSpec', [ collections.namedtuple('EvalSpec', [
'input_fn', 'steps', 'name', 'hooks', 'exporters', 'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
'start_delay_secs', 'throttle_secs' 'throttle_secs'
])): ])):
"""Configuration for the "eval" part for the `train_and_evaluate` call. """Configuration for the "eval" part for the `train_and_evaluate` call.
@ -417,8 +408,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
Raises: Raises:
ValueError: if environment variable `TF_CONFIG` is incorrectly set. ValueError: if environment variable `TF_CONFIG` is incorrectly set.
""" """
-  executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec,
-                               eval_spec=eval_spec)
+  executor = _TrainingExecutor(
+      estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
   config = estimator.config
   if (config.task_type == run_config_lib.TaskType.EVALUATOR and
@@ -561,9 +552,8 @@ class _TrainingExecutor(object):
       self._timer.update_last_triggered_step(global_step_value)
       self._evaluator.evaluate_and_export()
     else:
-      logging.info(
-          'Skip the current checkpoint eval due to throttle secs '
-          '({} secs).'.format(self._eval_throttle_secs))
+      logging.info('Skip the current checkpoint eval due to throttle secs '
+                   '({} secs).'.format(self._eval_throttle_secs))

     # Final export signal: For any eval result with global_step >= train
     # max_steps, the evaluator will send the final export signal. There is a
@@ -576,8 +566,8 @@ class _TrainingExecutor(object):
     #
     # But here, throttle_secs will skip the next intermediate checkpoint and,
     # so, the double final export chance is very small.
-    evaluator = _TrainingExecutor._Evaluator(
-        self._estimator, self._eval_spec, self._train_spec.max_steps)
+    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
+                                             self._train_spec.max_steps)

     # When the underlying `Estimator` object saves a new checkpoint, we would
     # like this callback to be called so that evaluation and export can trigger.
@@ -617,8 +607,7 @@ class _TrainingExecutor(object):
       raise ValueError('eval_spec.throttle_secs should be positive, given: {}. '
                        'It is used to determine how long each training '
                        'iteration should go when train and evaluate '
-                       'locally.'.format(
-                           self._eval_spec.throttle_secs))
+                       'locally.'.format(self._eval_spec.throttle_secs))

     stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
     train_hooks = (
@@ -663,8 +652,9 @@ class _TrainingExecutor(object):
     if not config.master:
       jobs = config.cluster_spec.jobs
-      if (len(jobs) == 1 and len(config.cluster_spec.job_tasks(jobs[0])) == 1
-          and config.task_type in _TRAINER_JOBS):
+      if (len(jobs) == 1 and
+          len(config.cluster_spec.job_tasks(jobs[0])) == 1 and
+          config.task_type in _TRAINER_JOBS):
         # For distributed training, config.master is empty if and only if it has
         # a single node in the cluster spec. In this case, we should not start
         # the server.
@@ -679,9 +669,9 @@ class _TrainingExecutor(object):
     logging.info('Start Tensorflow server.')

     if config.session_config is None:
-      session_config=config_pb2.ConfigProto(log_device_placement=False)
+      session_config = config_pb2.ConfigProto(log_device_placement=False)
     else:
-      session_config=config_pb2.ConfigProto(
+      session_config = config_pb2.ConfigProto(
           log_device_placement=False,
           gpu_options=config.session_config.gpu_options)
@@ -744,8 +734,7 @@ class _TrainingExecutor(object):
           global_step >= self._train_spec.max_steps):
         logging.info(
            'Exiting evaluation, global_step=%s >= train max_steps=%s',
-            global_step,
-            self._train_spec.max_steps)
+            global_step, self._train_spec.max_steps)
         return

       latest_eval_result, should_early_stop = self._execute_evaluator_once(
@@ -781,10 +770,9 @@ class _TrainingExecutor(object):
       # Throttle if necessary.
       elapsed_time = time.time() - start
       difference = throttle_secs - elapsed_time
       if difference > 0:
-        logging.info('Waiting %f secs before starting next eval run.',
-                     difference)
+        logging.info('Waiting %f secs before starting next eval run.', difference)
         time.sleep(difference)

     return (eval_result, should_early_stop)
@@ -929,8 +917,8 @@ class _EvalResult(
     if checkpoint_path:
       raise ValueError(
           'checkpoint must be `None` if status is not {}; got status {}, '
-          'checkpoint_path {}'.format(
-              _EvalStatus.EVALUATED, status, checkpoint_path))
+          'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status,
+                                      checkpoint_path))

     return super(_EvalResult, cls).__new__(cls, status, metrics,
                                            checkpoint_path)
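
One pattern in the _EvalResult hunk above is worth calling out: a namedtuple subclass that validates its fields in __new__, so an inconsistent record can never be constructed. A minimal standalone sketch of that pattern — class, field, and status names here are made up for illustration:

import collections


class Result(collections.namedtuple('Result', ['status', 'checkpoint_path'])):
  """Immutable record that rejects inconsistent fields at construction."""

  def __new__(cls, status, checkpoint_path=None):
    # A checkpoint path only makes sense once an evaluation has happened.
    if status != 'evaluated' and checkpoint_path:
      raise ValueError(
          'checkpoint must be `None` if status is not evaluated; got '
          'status {}, checkpoint_path {}'.format(status, checkpoint_path))
    return super(Result, cls).__new__(cls, status, checkpoint_path)


print(Result('empty'))                       # Result(status='empty', checkpoint_path=None)
print(Result('evaluated', '/tmp/ckpt-100'))  # ok
# Result('empty', '/tmp/ckpt-100')           # would raise ValueError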

View File

@@ -56,9 +56,11 @@ class TensordotTest(test_lib.TestCase):
       axes_ph = array_ops.placeholder(dtypes.int32)
       output = math_ops.tensordot(a_ph, b_ph, axes_ph)
       _ = sess.run(
-          [output], feed_dict={a_ph: a,
-                               b_ph: b,
-                               axes_ph: (a_axes, b_axes)})
+          [output], feed_dict={
+              a_ph: a,
+              b_ph: b,
+              axes_ph: (a_axes, b_axes)
+          })

   def test_invalid_axes(self):
     a = [[1, 2], [3, 4]]
@@ -81,26 +83,27 @@ class TensordotTest(test_lib.TestCase):
     with self.test_session() as sess:
       with self.assertRaises(errors_impl.InvalidArgumentError):
         _ = sess.run(
-            [output], feed_dict={a_ph: a,
-                                 b_ph: b,
-                                 axes_ph: axes_value})
+            [output], feed_dict={
+                a_ph: a,
+                b_ph: b,
+                axes_ph: axes_value
+            })

   # Test case for 11950
   def test_valid_axis(self):
     for axes_value in [1, 2], [[1], [2]]:
       with self.test_session() as sess:
-        np_a = np.ones((3,3))
+        np_a = np.ones((3, 3))
         np_b = np.array([2, 3, 1])[None, None]
         np_ans = np.tensordot(np_a, np_b, axes_value)
-        tf_a = array_ops.ones((3,3), dtype=dtypes.float32)
+        tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
         tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
         tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value).eval()
         self.assertAllEqual(tf_ans.shape, np_ans.shape)
         self.assertAllEqual(tf_ans, np_ans)

   def test_partial_shape_inference(self):
     a = array_ops.placeholder(dtypes.float32)
     b = array_ops.placeholder(dtypes.float32)
@@ -169,9 +172,11 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
         axes = array_ops.placeholder(dtypes.int32)
         c = math_ops.tensordot(a, b, axes)
         tf_ans = sess.run(
-            c, feed_dict={a: a_np,
-                          b: b_np,
-                          axes: (a_dims_np, b_dims_np)})
+            c, feed_dict={
+                a: a_np,
+                b: b_np,
+                axes: (a_dims_np, b_dims_np)
+            })
       else:
         tf_ans = math_ops.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval()
       self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
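
For reference on the axes forms test_valid_axis exercises: numpy's tensordot accepts both a pair (a_axes, b_axes) and a list-of-lists, and the two spellings below perform the same contraction. A minimal numpy-only sketch (variable names are illustrative):

import numpy as np

a = np.ones((3, 3))
b = np.array([2, 3, 1])[None, None]  # shape (1, 1, 3)

# Pair form: contract axis 1 of `a` with axis 2 of `b`.
pair_form = np.tensordot(a, b, [1, 2])

# List-of-lists form: one axis list per operand, same contraction.
list_form = np.tensordot(a, b, [[1], [2]])

assert pair_form.shape == list_form.shape == (3, 1, 1)
assert np.allclose(pair_form, list_form)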

View File

@@ -184,7 +184,8 @@ do_pylint() {
   # W0312 mixed-indentation
   # C0330 bad-continuation
   # C0301 line-too-long
-  grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301)' ${OUTPUT_FILE} > ${ERRORS_FILE}
+  # C0326 bad-whitespace
+  grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326)' ${OUTPUT_FILE} > ${ERRORS_FILE}

   N_ERRORS=0
   while read -r LINE; do
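
As background, the grep above is an allowlist of pylint message codes that fail the sanity check; with [C0326 added, bad-whitespace findings now count as errors. A minimal Python sketch of the same filter — the sample output lines are assumed pylint-style messages, not taken from this commit:

import re

# Assumed pylint-style output lines (the exact format depends on pylint's
# message template); only the bracketed code matters to the filter.
sample_lines = [
    "foo.py:12: [C0326(bad-whitespace), ] Exactly one space required after comma",
    "foo.py:34: [C0111(missing-docstring), f] Missing function docstring",
    "bar.py:56: [E1101(no-member), g] Instance of 'X' has no 'y' member",
]

# Same alternation the shell grep uses, with C0326 included.
allowlisted = re.compile(r'(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326)')
errors = [line for line in sample_lines if allowlisted.search(line)]
print(errors)  # keeps the C0326 and E1101 lines, drops C0111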

View File

@@ -46,8 +46,9 @@ class APIChangeSpec(object):
   """


-class _FileEditTuple(collections.namedtuple(
-    "_FileEditTuple", ["comment", "line", "start", "old", "new"])):
+class _FileEditTuple(
+    collections.namedtuple("_FileEditTuple",
+                           ["comment", "line", "start", "old", "new"])):
   """Each edit that is recorded by a _FileEditRecorder.

   Fields:
@@ -179,8 +180,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
     function_renames = self._api_change_spec.function_renames
     try:
       new_name = function_renames[full_name]
-      self._file_edit.add("Renamed function %r to %r" % (full_name,
-                                                         new_name),
+      self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
                           node.lineno, node.col_offset, full_name, new_name)
     except KeyError:
       pass
@@ -227,7 +227,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
     # loop over lines
     while 1:
       # Reverse the text and do a regular expression search for whitespace.
-      text = self._lines[line-1]
+      text = self._lines[line - 1]
       reversed_preceding_text = text[:col][::-1]
       # First find if a [ can be found with only whitespace between it and
       # col.
@@ -248,8 +248,8 @@ class _ASTCallVisitor(ast.NodeVisitor):
         # node ranges to filter out spurious #'s that appear in string
         # literals.
         comment_start = prev_line.find("#")
         if comment_start == -1:
-          col = len(prev_line) -1
+          col = len(prev_line) - 1
         elif find_string_chars.search(prev_line[comment_start:]) is None:
           col = comment_start
         else:
@@ -260,7 +260,6 @@ class _ASTCallVisitor(ast.NodeVisitor):
     # it is not possible to use that in an argument.
     return node.lineno, node.col_offset

-
   def visit_Call(self, node):  # pylint: disable=invalid-name
     """Handle visiting a call node in the AST.
@@ -268,7 +267,6 @@ class _ASTCallVisitor(ast.NodeVisitor):
       node: Current Node
     """
-
     # Find a simple attribute name path e.g. "tf.foo.bar"
     full_name = self._get_attribute_full_path(node.func)
@@ -293,18 +291,21 @@ class _ASTCallVisitor(ast.NodeVisitor):
         lineno, col_offset = self._find_true_position(arg)
         if lineno is None or col_offset is None:
           self._file_edit.add(
-              "Failed to add keyword %r to reordered function %r"
-              % (reordered[idx], full_name), arg.lineno, arg.col_offset,
-              "", "",
+              "Failed to add keyword %r to reordered function %r" %
+              (reordered[idx], full_name),
+              arg.lineno,
+              arg.col_offset,
+              "",
+              "",
               error="A necessary keyword argument failed to be inserted.")
         else:
           keyword_arg = reordered[idx]
           if (full_name in function_keyword_renames and
               keyword_arg in function_keyword_renames[full_name]):
             keyword_arg = function_keyword_renames[full_name][keyword_arg]
-          self._file_edit.add("Added keyword %r to reordered function %r"
-                              % (reordered[idx], full_name), lineno,
-                              col_offset, "", keyword_arg + "=")
+          self._file_edit.add("Added keyword %r to reordered function %r" %
+                              (reordered[idx], full_name), lineno, col_offset,
+                              "", keyword_arg + "=")

     # Examine each keyword argument and convert it to the final renamed form
     renamed_keywords = ({} if full_name not in function_keyword_renames else
@@ -322,11 +323,11 @@ class _ASTCallVisitor(ast.NodeVisitor):
           # value.
           key_start = argval_col_offset - len(argkey) - 1
           key_end = key_start + len(argkey) + 1
-          if (self._lines[argval_lineno - 1][key_start:key_end] ==
-              argkey + "="):
+          if (self._lines[argval_lineno - 1][key_start:key_end] == argkey +
+              "="):
             self._file_edit.add("Renamed keyword argument from %r to %r" %
-                                (argkey, renamed_keywords[argkey]),
-                                argval_lineno,
+                                (argkey,
+                                 renamed_keywords[argkey]), argval_lineno,
                                 argval_col_offset - len(argkey) - 1,
                                 argkey + "=", renamed_keywords[argkey] + "=")
             continue
@@ -335,7 +336,8 @@ class _ASTCallVisitor(ast.NodeVisitor):
               (argkey, renamed_keywords[argkey]),
               argval.lineno,
               argval.col_offset - len(argkey) - 1,
-              "", "",
+              "",
+              "",
               error="Failed to find keyword lexicographically. Fix manually.")

     ast.NodeVisitor.generic_visit(self, node)
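
For orientation, _ASTCallVisitor above follows the standard ast.NodeVisitor protocol: override visit_Call to handle call sites, then delegate to generic_visit so traversal continues into child nodes. A self-contained sketch of that pattern — the collector class is hypothetical, not part of tf_upgrade.py:

import ast


class CallNameCollector(ast.NodeVisitor):
  """Records dotted names of call sites such as tf.foo.bar."""

  def __init__(self):
    self.calls = []

  def _full_name(self, node):
    # Walk an Attribute chain (tf.foo.bar) down to its root Name.
    parts = []
    while isinstance(node, ast.Attribute):
      parts.append(node.attr)
      node = node.value
    if isinstance(node, ast.Name):
      parts.append(node.id)
      return ".".join(reversed(parts))
    return None

  def visit_Call(self, node):  # pylint: disable=invalid-name
    name = self._full_name(node.func)
    if name:
      self.calls.append(name)
    # Keep walking child nodes, as _ASTCallVisitor does via generic_visit.
    self.generic_visit(node)


collector = CallNameCollector()
collector.visit(ast.parse("tf.reverse(x, [1])\ntf.foo.bar(y)"))
print(collector.calls)  # ['tf.reverse', 'tf.foo.bar']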
@@ -352,7 +354,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
     if full_name in self._api_change_spec.change_to_function:
       if not hasattr(node, "is_function_for_call"):
         new_text = full_name + "()"
-        self._file_edit.add("Changed %r to %r"%(full_name, new_text),
+        self._file_edit.add("Changed %r to %r" % (full_name, new_text),
                             node.lineno, node.col_offset, full_name, new_text)

     ast.NodeVisitor.generic_visit(self, node)
@@ -380,8 +382,8 @@ class ASTCodeUpgrader(object):
     # Write to a temporary file, just in case we are doing an in-place modify.
     with open(in_filename, "r") as in_file, \
         tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
-      ret = self.process_opened_file(
-          in_filename, in_file, out_filename, temp_file)
+      ret = self.process_opened_file(in_filename, in_file, out_filename,
+                                     temp_file)

     shutil.move(temp_file.name, out_filename)
     return ret
@@ -424,6 +426,7 @@ class ASTCodeUpgrader(object):
       out_file.write(out_text)
       text += "\n"
     return 1, text, process_errors
+
   # pylint: enable=broad-except

   def process_tree(self, root_directory, output_root_directory,
@@ -444,16 +447,16 @@ class ASTCodeUpgrader(object):
     # make sure output directory doesn't exist
     if output_root_directory and os.path.exists(output_root_directory):
-      print("Output directory %r must not already exist." % (
-          output_root_directory))
+      print("Output directory %r must not already exist." %
+            (output_root_directory))
       sys.exit(1)

     # make sure output directory does not overlap with root_directory
     norm_root = os.path.split(os.path.normpath(root_directory))
     norm_output = os.path.split(os.path.normpath(output_root_directory))
     if norm_root == norm_output:
-      print("Output directory %r same as input directory %r" % (
-          root_directory, output_root_directory))
+      print("Output directory %r same as input directory %r" %
+            (root_directory, output_root_directory))
       sys.exit(1)

     # Collect list of files to process (we do this to correctly handle if the
@@ -465,14 +468,16 @@ class ASTCodeUpgrader(object):
       copy_files = [f for f in file_list if not f.endswith(".py")]
       for filename in py_files:
         fullpath = os.path.join(dir_name, filename)
-        fullpath_output = os.path.join(
-            output_root_directory, os.path.relpath(fullpath, root_directory))
+        fullpath_output = os.path.join(output_root_directory,
+                                       os.path.relpath(fullpath,
+                                                       root_directory))
         files_to_process.append((fullpath, fullpath_output))
       if copy_other_files:
         for filename in copy_files:
           fullpath = os.path.join(dir_name, filename)
-          fullpath_output = os.path.join(
-              output_root_directory, os.path.relpath(fullpath, root_directory))
+          fullpath_output = os.path.join(output_root_directory,
+                                         os.path.relpath(
+                                             fullpath, root_directory))
           files_to_copy.append((fullpath, fullpath_output))

     file_count = 0
@@ -641,18 +646,17 @@ class TFAPIChangeSpec(APIChangeSpec):
         "tf.concat": ["concat_dim", "values", "name"],
         "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
         "tf.nn.softmax_cross_entropy_with_logits": [
-            "logits", "labels", "dim", "name"],
+            "logits", "labels", "dim", "name"
+        ],
         "tf.nn.sparse_softmax_cross_entropy_with_logits": [
-            "logits", "labels", "name"],
-        "tf.nn.sigmoid_cross_entropy_with_logits": [
-            "logits", "labels", "name"],
+            "logits", "labels", "name"
+        ],
+        "tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"],
         "tf.op_scope": ["values", "name", "default_name"],
     }

     # Specially handled functions.
-    self.function_handle = {
-        "tf.reverse": self._reverse_handler
-    }
+    self.function_handle = {"tf.reverse": self._reverse_handler}

   @staticmethod
   def _reverse_handler(file_edit_recorder, node):
@@ -661,12 +665,13 @@ class TFAPIChangeSpec(APIChangeSpec):
     comment = ("ERROR: tf.reverse has had its argument semantics changed\n"
                "significantly; the converter cannot detect this reliably, so "
                "you need to inspect this usage manually.\n")
-    file_edit_recorder.add(comment,
-                           node.lineno,
-                           node.col_offset,
-                           "tf.reverse",
-                           "tf.reverse",
-                           error="tf.reverse requires manual check.")
+    file_edit_recorder.add(
+        comment,
+        node.lineno,
+        node.col_offset,
+        "tf.reverse",
+        "tf.reverse",
+        error="tf.reverse requires manual check.")

 if __name__ == "__main__":