Add C0326 bad-whitespace error to pylint sanity check.
PiperOrigin-RevId: 183689499
This commit is contained in:
parent
730071d0dc
commit
fd63d4e30a
tensorflow
contrib
framework/python/ops
learn/python/learn/datasets
lite/tools
model_pruning/examples/cifar10
rnn/python/kernel_tests
examples/learn
python
tools
@ -31,7 +31,6 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
|
||||
class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
||||
"""Tests of the new, differentiable version of accumulate_n"""
|
||||
|
||||
@ -62,8 +61,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
||||
accum_n = av2.accumulate_n_v2(input_vars)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
accum_n_grad = gradients.gradients(accum_n, input_vars)
|
||||
self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1
|
||||
[g.eval() for g in accum_n_grad])
|
||||
self.assertAllEqual(
|
||||
np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1
|
||||
[g.eval() for g in accum_n_grad])
|
||||
|
||||
# The tests below used to be in a separate class under cwise_ops_test.py,
|
||||
# which did not run in the default test target.
|
||||
@ -75,8 +75,8 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
||||
np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
|
||||
]
|
||||
random_tensors = [
|
||||
ops.convert_to_tensor(
|
||||
x, dtype=dtypes_lib.float32) for x in random_arrays
|
||||
ops.convert_to_tensor(x, dtype=dtypes_lib.float32)
|
||||
for x in random_arrays
|
||||
]
|
||||
tf_val = av2.accumulate_n_v2(random_tensors)
|
||||
np_val = random_arrays[0]
|
||||
@ -95,21 +95,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
a = variables.Variable(0.2)
|
||||
b = variables.Variable(0.1)
|
||||
tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[]
|
||||
tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2]) # Should be shape=[]
|
||||
|
||||
def testIncompatibleShapes(self):
|
||||
with self.test_session():
|
||||
with self.assertRaises(ValueError):
|
||||
a = variables.Variable(np.array([0.1,0.2]))
|
||||
b = variables.Variable(np.array([[0.3],[0.4]]))
|
||||
tf_val = av2.accumulate_n_v2([a,b])
|
||||
a = variables.Variable(np.array([0.1, 0.2]))
|
||||
b = variables.Variable(np.array([[0.3], [0.4]]))
|
||||
tf_val = av2.accumulate_n_v2([a, b])
|
||||
|
||||
def testWrongType(self):
|
||||
with self.test_session():
|
||||
with self.assertRaises(TypeError):
|
||||
a = variables.Variable(0.2, dtype=np.float32)
|
||||
b = variables.Variable(0.1, dtype=np.float32)
|
||||
tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32)
|
||||
tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32)
|
||||
|
||||
def testWrongTypeOneInput(self):
|
||||
# Scenario that used to trigger a bug, even when testWrongType() worked
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Base utilities for loading datasets."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -100,9 +99,7 @@ def load_iris(data_path=None):
|
||||
module_path = path.dirname(__file__)
|
||||
data_path = path.join(module_path, 'data', 'iris.csv')
|
||||
return load_csv_with_header(
|
||||
data_path,
|
||||
target_dtype=np.int,
|
||||
features_dtype=np.float)
|
||||
data_path, target_dtype=np.int, features_dtype=np.float)
|
||||
|
||||
|
||||
def load_boston(data_path=None):
|
||||
@ -118,16 +115,10 @@ def load_boston(data_path=None):
|
||||
module_path = path.dirname(__file__)
|
||||
data_path = path.join(module_path, 'data', 'boston_house_prices.csv')
|
||||
return load_csv_with_header(
|
||||
data_path,
|
||||
target_dtype=np.float,
|
||||
features_dtype=np.float)
|
||||
data_path, target_dtype=np.float, features_dtype=np.float)
|
||||
|
||||
|
||||
def retry(initial_delay,
|
||||
max_delay,
|
||||
factor=2.0,
|
||||
jitter=0.25,
|
||||
is_retriable=None):
|
||||
def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None):
|
||||
"""Simple decorator for wrapping retriable functions.
|
||||
|
||||
Args:
|
||||
@ -152,7 +143,7 @@ def retry(initial_delay,
|
||||
def delays():
|
||||
delay = initial_delay
|
||||
while delay <= max_delay:
|
||||
yield delay * random.uniform(1 - jitter, 1 + jitter)
|
||||
yield delay * random.uniform(1 - jitter, 1 + jitter)
|
||||
delay *= factor
|
||||
|
||||
def wrap(fn):
|
||||
@ -172,7 +163,9 @@ def retry(initial_delay,
|
||||
else:
|
||||
raise
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapped_fn
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
|
@ -198,10 +198,13 @@ class TensorMapper(object):
|
||||
|
||||
def GenerateGraph(subgraph_idx, g, opcode_mapper):
|
||||
"""Produces the HTML required to have a d3 visualization of the dag."""
|
||||
|
||||
def TensorName(idx):
|
||||
return "t%d"%idx
|
||||
return "t%d" % idx
|
||||
|
||||
def OpName(idx):
|
||||
return "o%d"%idx
|
||||
return "o%d" % idx
|
||||
|
||||
edges = []
|
||||
nodes = []
|
||||
first = {}
|
||||
@ -210,27 +213,35 @@ def GenerateGraph(subgraph_idx, g, opcode_mapper):
|
||||
for tensor_input_position, tensor_index in enumerate(op["inputs"]):
|
||||
if tensor_index not in first:
|
||||
first[tensor_index] = (
|
||||
op_index*pixel_mult,
|
||||
tensor_input_position*pixel_mult - pixel_mult/2)
|
||||
edges.append(
|
||||
{"source": TensorName(tensor_index), "target": OpName(op_index)})
|
||||
op_index * pixel_mult,
|
||||
tensor_input_position * pixel_mult - pixel_mult / 2)
|
||||
edges.append({
|
||||
"source": TensorName(tensor_index),
|
||||
"target": OpName(op_index)
|
||||
})
|
||||
for tensor_index in op["outputs"]:
|
||||
edges.append(
|
||||
{"target": TensorName(tensor_index), "source": OpName(op_index)})
|
||||
nodes.append({"id": OpName(op_index),
|
||||
"name": opcode_mapper(op["opcode_index"]),
|
||||
"group": 2,
|
||||
"x": pixel_mult,
|
||||
"y": op_index * pixel_mult})
|
||||
edges.append({
|
||||
"target": TensorName(tensor_index),
|
||||
"source": OpName(op_index)
|
||||
})
|
||||
nodes.append({
|
||||
"id": OpName(op_index),
|
||||
"name": opcode_mapper(op["opcode_index"]),
|
||||
"group": 2,
|
||||
"x": pixel_mult,
|
||||
"y": op_index * pixel_mult
|
||||
})
|
||||
for tensor_index, tensor in enumerate(g["tensors"]):
|
||||
initial_y = (first[tensor_index] if tensor_index in first
|
||||
else len(g["operators"]))
|
||||
initial_y = (
|
||||
first[tensor_index] if tensor_index in first else len(g["operators"]))
|
||||
|
||||
nodes.append({"id": TensorName(tensor_index),
|
||||
"name": "%s (%d)" % (tensor["name"], tensor_index),
|
||||
"group": 1,
|
||||
"x": 2,
|
||||
"y": initial_y})
|
||||
nodes.append({
|
||||
"id": TensorName(tensor_index),
|
||||
"name": "%s (%d)" % (tensor["name"], tensor_index),
|
||||
"group": 1,
|
||||
"x": 2,
|
||||
"y": initial_y
|
||||
})
|
||||
graph_str = json.dumps({"nodes": nodes, "edges": edges})
|
||||
|
||||
html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
|
||||
@ -267,7 +278,7 @@ def GenerateTableHtml(items, keys_to_print, display_index=True):
|
||||
for h, mapper in keys_to_print:
|
||||
val = tensor[h] if h in tensor else None
|
||||
val = val if mapper is None else mapper(val)
|
||||
html += "<td>%s</td>\n"%val
|
||||
html += "<td>%s</td>\n" % val
|
||||
|
||||
html += "</tr>\n"
|
||||
html += "</table>\n"
|
||||
@ -279,18 +290,19 @@ def CreateHtmlFile(tflite_input, html_output):
|
||||
|
||||
# Convert the model into a JSON flatbuffer using flatc (build if doesn't
|
||||
# exist.
|
||||
if not os.path.exists(tflite_input):
|
||||
if not os.path.exists(tflite_input):
|
||||
raise RuntimeError("Invalid filename %r" % tflite_input)
|
||||
if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"):
|
||||
|
||||
# Run convert
|
||||
cmd = (_BINARY + " -t "
|
||||
"--strict-json --defaults-json -o /tmp {schema} -- {input}".format(
|
||||
input=tflite_input, schema=_SCHEMA))
|
||||
cmd = (
|
||||
_BINARY + " -t "
|
||||
"--strict-json --defaults-json -o /tmp {schema} -- {input}".format(
|
||||
input=tflite_input, schema=_SCHEMA))
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
real_output = ("/tmp/"+ os.path.splitext(os.path.split(tflite_input)[-1])[0]
|
||||
+ ".json")
|
||||
real_output = ("/tmp/" + os.path.splitext(
|
||||
os.path.split(tflite_input)[-1])[0] + ".json")
|
||||
|
||||
data = json.load(open(real_output))
|
||||
elif tflite_input.endswith(".json"):
|
||||
@ -302,12 +314,13 @@ def CreateHtmlFile(tflite_input, html_output):
|
||||
html += "<h1>TensorFlow Lite Model</h2>"
|
||||
|
||||
data["filename"] = tflite_input # Avoid special case
|
||||
toplevel_stuff = [("filename", None), ("version", None),
|
||||
("description", None)]
|
||||
toplevel_stuff = [("filename", None), ("version", None), ("description",
|
||||
None)]
|
||||
|
||||
html += "<table>\n"
|
||||
for key, mapping in toplevel_stuff:
|
||||
if not mapping: mapping = lambda x: x
|
||||
if not mapping:
|
||||
mapping = lambda x: x
|
||||
html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data[key]))
|
||||
html += "</table>\n"
|
||||
|
||||
@ -320,22 +333,22 @@ def CreateHtmlFile(tflite_input, html_output):
|
||||
html += "<div class='subgraph'>"
|
||||
tensor_mapper = TensorMapper(g)
|
||||
opcode_mapper = OpCodeMapper(data)
|
||||
op_keys_to_display = [
|
||||
("inputs", tensor_mapper), ("outputs", tensor_mapper),
|
||||
("builtin_options", None), ("opcode_index", opcode_mapper)]
|
||||
tensor_keys_to_display = [
|
||||
("name", None), ("type", None), ("shape", None), ("buffer", None),
|
||||
("quantization", None)]
|
||||
op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
|
||||
("builtin_options", None), ("opcode_index",
|
||||
opcode_mapper)]
|
||||
tensor_keys_to_display = [("name", None), ("type", None), ("shape", None),
|
||||
("buffer", None), ("quantization", None)]
|
||||
|
||||
html += "<h2>Subgraph %d</h2>\n" % subgraph_idx
|
||||
|
||||
# Inputs and outputs.
|
||||
html += "<h3>Inputs/Outputs</h3>\n"
|
||||
html += GenerateTableHtml([{"inputs": g["inputs"],
|
||||
"outputs": g["outputs"]}],
|
||||
[("inputs", tensor_mapper),
|
||||
("outputs", tensor_mapper)],
|
||||
display_index=False)
|
||||
html += GenerateTableHtml(
|
||||
[{
|
||||
"inputs": g["inputs"],
|
||||
"outputs": g["outputs"]
|
||||
}], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
|
||||
display_index=False)
|
||||
|
||||
# Print the tensors.
|
||||
html += "<h3>Tensors</h3>\n"
|
||||
@ -357,8 +370,7 @@ def CreateHtmlFile(tflite_input, html_output):
|
||||
|
||||
# Operator codes
|
||||
html += "<h2>Operator Codes</h2>\n"
|
||||
html += GenerateTableHtml(data["operator_codes"],
|
||||
operator_keys_to_display)
|
||||
html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)
|
||||
|
||||
html += "</body></html>\n"
|
||||
|
||||
@ -370,10 +382,10 @@ def main(argv):
|
||||
tflite_input = argv[1]
|
||||
html_output = argv[2]
|
||||
except IndexError:
|
||||
print ("Usage: %s <input tflite> <output html>" % (argv[0]))
|
||||
print("Usage: %s <input tflite> <output html>" % (argv[0]))
|
||||
else:
|
||||
CreateHtmlFile(tflite_input, html_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv)
|
||||
|
||||
|
@ -58,6 +58,7 @@ def read_cifar10(filename_queue):
|
||||
|
||||
class CIFAR10Record(object):
|
||||
pass
|
||||
|
||||
result = CIFAR10Record()
|
||||
|
||||
# Dimensions of the images in the CIFAR-10 dataset.
|
||||
@ -147,8 +148,9 @@ def distorted_inputs(data_dir, batch_size):
|
||||
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
||||
labels: Labels. 1D tensor of [batch_size] size.
|
||||
"""
|
||||
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
|
||||
for i in xrange(1, 6)]
|
||||
filenames = [
|
||||
os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)
|
||||
]
|
||||
for f in filenames:
|
||||
if not tf.gfile.Exists(f):
|
||||
raise ValueError('Failed to find file: ' + f)
|
||||
@ -174,10 +176,9 @@ def distorted_inputs(data_dir, batch_size):
|
||||
|
||||
# Because these operations are not commutative, consider randomizing
|
||||
# the order their operation.
|
||||
distorted_image = tf.image.random_brightness(distorted_image,
|
||||
max_delta=63)
|
||||
distorted_image = tf.image.random_contrast(distorted_image,
|
||||
lower=0.2, upper=1.8)
|
||||
distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
|
||||
distorted_image = tf.image.random_contrast(
|
||||
distorted_image, lower=0.2, upper=1.8)
|
||||
|
||||
# Subtract off the mean and divide by the variance of the pixels.
|
||||
float_image = tf.image.per_image_standardization(distorted_image)
|
||||
@ -188,15 +189,18 @@ def distorted_inputs(data_dir, batch_size):
|
||||
|
||||
# Ensure that the random shuffling has good mixing properties.
|
||||
min_fraction_of_examples_in_queue = 0.4
|
||||
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
|
||||
min_fraction_of_examples_in_queue)
|
||||
print ('Filling queue with %d CIFAR images before starting to train. '
|
||||
'This will take a few minutes.' % min_queue_examples)
|
||||
min_queue_examples = int(
|
||||
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue)
|
||||
print('Filling queue with %d CIFAR images before starting to train. '
|
||||
'This will take a few minutes.' % min_queue_examples)
|
||||
|
||||
# Generate a batch of images and labels by building up a queue of examples.
|
||||
return _generate_image_and_label_batch(float_image, read_input.label,
|
||||
min_queue_examples, batch_size,
|
||||
shuffle=True)
|
||||
return _generate_image_and_label_batch(
|
||||
float_image,
|
||||
read_input.label,
|
||||
min_queue_examples,
|
||||
batch_size,
|
||||
shuffle=True)
|
||||
|
||||
|
||||
def inputs(eval_data, data_dir, batch_size):
|
||||
@ -212,8 +216,9 @@ def inputs(eval_data, data_dir, batch_size):
|
||||
labels: Labels. 1D tensor of [batch_size] size.
|
||||
"""
|
||||
if not eval_data:
|
||||
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
|
||||
for i in xrange(1, 6)]
|
||||
filenames = [
|
||||
os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)
|
||||
]
|
||||
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
|
||||
else:
|
||||
filenames = [os.path.join(data_dir, 'test_batch.bin')]
|
||||
@ -235,8 +240,8 @@ def inputs(eval_data, data_dir, batch_size):
|
||||
|
||||
# Image processing for evaluation.
|
||||
# Crop the central [height, width] of the image.
|
||||
resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
|
||||
width, height)
|
||||
resized_image = tf.image.resize_image_with_crop_or_pad(
|
||||
reshaped_image, width, height)
|
||||
|
||||
# Subtract off the mean and divide by the variance of the pixels.
|
||||
float_image = tf.image.per_image_standardization(resized_image)
|
||||
@ -247,10 +252,13 @@ def inputs(eval_data, data_dir, batch_size):
|
||||
|
||||
# Ensure that the random shuffling has good mixing properties.
|
||||
min_fraction_of_examples_in_queue = 0.4
|
||||
min_queue_examples = int(num_examples_per_epoch *
|
||||
min_fraction_of_examples_in_queue)
|
||||
min_queue_examples = int(
|
||||
num_examples_per_epoch * min_fraction_of_examples_in_queue)
|
||||
|
||||
# Generate a batch of images and labels by building up a queue of examples.
|
||||
return _generate_image_and_label_batch(float_image, read_input.label,
|
||||
min_queue_examples, batch_size,
|
||||
shuffle=False)
|
||||
return _generate_image_and_label_batch(
|
||||
float_image,
|
||||
read_input.label,
|
||||
min_queue_examples,
|
||||
batch_size,
|
||||
shuffle=False)
|
||||
|
@ -42,7 +42,6 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
|
||||
|
||||
|
||||
# pylint: enable=protected-access
|
||||
Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
|
||||
|
||||
@ -84,19 +83,22 @@ class RNNCellTest(test.TestCase):
|
||||
], [v.name for v in cell.trainable_variables])
|
||||
self.assertFalse(cell.non_trainable_variables)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g], {x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
res = sess.run([g], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
self.assertEqual(res[0].shape, (1, 2))
|
||||
|
||||
def testBasicRNNCellNotTrainable(self):
|
||||
with self.test_session() as sess:
|
||||
|
||||
def not_trainable_getter(getter, *args, **kwargs):
|
||||
kwargs["trainable"] = False
|
||||
return getter(*args, **kwargs)
|
||||
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5),
|
||||
"root",
|
||||
initializer=init_ops.constant_initializer(0.5),
|
||||
custom_getter=not_trainable_getter):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = array_ops.zeros([1, 2])
|
||||
@ -108,9 +110,10 @@ class RNNCellTest(test.TestCase):
|
||||
"root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
|
||||
], [v.name for v in cell.non_trainable_variables])
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g], {x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
res = sess.run([g], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
self.assertEqual(res[0].shape, (1, 2))
|
||||
|
||||
def testGRUCell(self):
|
||||
@ -121,9 +124,10 @@ class RNNCellTest(test.TestCase):
|
||||
m = array_ops.zeros([1, 2])
|
||||
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g], {x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
res = sess.run([g], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
# Smoke test
|
||||
self.assertAllClose(res[0], [[0.175991, 0.175991]])
|
||||
with variable_scope.variable_scope(
|
||||
@ -133,10 +137,10 @@ class RNNCellTest(test.TestCase):
|
||||
m = array_ops.zeros([1, 2])
|
||||
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g],
|
||||
{x.name: np.array([[1., 1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
res = sess.run([g], {
|
||||
x.name: np.array([[1., 1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
# Smoke test
|
||||
self.assertAllClose(res[0], [[0.156736, 0.156736]])
|
||||
|
||||
@ -148,11 +152,12 @@ class RNNCellTest(test.TestCase):
|
||||
m = array_ops.zeros([1, 2])
|
||||
g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g], {x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
res = sess.run([g], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
# Smoke test
|
||||
self.assertAllClose(res[0], [[0.509682, 0.509682]])
|
||||
self.assertAllClose(res[0], [[0.509682, 0.509682]])
|
||||
|
||||
def testBasicLSTMCell(self):
|
||||
for dtype in [dtypes.float16, dtypes.float32]:
|
||||
@ -164,8 +169,7 @@ class RNNCellTest(test.TestCase):
|
||||
m = array_ops.zeros([1, 8], dtype=dtype)
|
||||
cell = rnn_cell_impl.MultiRNNCell(
|
||||
[
|
||||
rnn_cell_impl.BasicLSTMCell(
|
||||
2, state_is_tuple=False)
|
||||
rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
|
||||
for _ in range(2)
|
||||
],
|
||||
state_is_tuple=False)
|
||||
@ -183,22 +187,21 @@ class RNNCellTest(test.TestCase):
|
||||
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
|
||||
rnn_cell_impl._BIAS_VARIABLE_NAME
|
||||
]
|
||||
self.assertEqual(
|
||||
expected_variable_names,
|
||||
[v.name for v in cell.trainable_variables])
|
||||
self.assertEqual(expected_variable_names,
|
||||
[v.name for v in cell.trainable_variables])
|
||||
self.assertFalse(cell.non_trainable_variables)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g, out_m],
|
||||
{x.name: np.array([[1., 1.]]),
|
||||
m.name: 0.1 * np.ones([1, 8])})
|
||||
res = sess.run([g, out_m], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m.name: 0.1 * np.ones([1, 8])
|
||||
})
|
||||
self.assertEqual(len(res), 2)
|
||||
variables = variables_lib.global_variables()
|
||||
self.assertEqual(expected_variable_names, [v.name for v in variables])
|
||||
# The numbers in results were not calculated, this is just a
|
||||
# smoke test.
|
||||
self.assertAllClose(
|
||||
res[0], np.array([[0.240, 0.240]], dtype=np_dtype), 1e-2)
|
||||
self.assertAllClose(res[0], np.array(
|
||||
[[0.240, 0.240]], dtype=np_dtype), 1e-2)
|
||||
expected_mem = np.array(
|
||||
[[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]],
|
||||
dtype=np_dtype)
|
||||
@ -208,13 +211,13 @@ class RNNCellTest(test.TestCase):
|
||||
# Test BasicLSTMCell with input_size != num_units.
|
||||
x = array_ops.zeros([1, 3], dtype=dtype)
|
||||
m = array_ops.zeros([1, 4], dtype=dtype)
|
||||
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
||||
2, state_is_tuple=False)(x, m)
|
||||
g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g, out_m],
|
||||
{x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
|
||||
m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)})
|
||||
[g, out_m], {
|
||||
x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
|
||||
m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)
|
||||
})
|
||||
self.assertEqual(len(res), 2)
|
||||
|
||||
def testBasicLSTMCellDimension0Error(self):
|
||||
@ -232,9 +235,11 @@ class RNNCellTest(test.TestCase):
|
||||
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
||||
num_units, state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
sess.run([g, out_m],
|
||||
{x.name: 1 * np.ones([batch_size, input_size]),
|
||||
m.name: 0.1 * np.ones([batch_size - 1, state_size])})
|
||||
sess.run(
|
||||
[g, out_m], {
|
||||
x.name: 1 * np.ones([batch_size, input_size]),
|
||||
m.name: 0.1 * np.ones([batch_size - 1, state_size])
|
||||
})
|
||||
|
||||
def testBasicLSTMCellStateSizeError(self):
|
||||
"""Tests that state_size must be num_units * 2."""
|
||||
@ -251,9 +256,11 @@ class RNNCellTest(test.TestCase):
|
||||
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
||||
num_units, state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
sess.run([g, out_m],
|
||||
{x.name: 1 * np.ones([batch_size, input_size]),
|
||||
m.name: 0.1 * np.ones([batch_size, state_size])})
|
||||
sess.run(
|
||||
[g, out_m], {
|
||||
x.name: 1 * np.ones([batch_size, input_size]),
|
||||
m.name: 0.1 * np.ones([batch_size, state_size])
|
||||
})
|
||||
|
||||
def testBasicLSTMCellStateTupleType(self):
|
||||
with self.test_session():
|
||||
@ -301,11 +308,12 @@ class RNNCellTest(test.TestCase):
|
||||
state_is_tuple=True)
|
||||
g, (out_m0, out_m1) = cell(x, (m0, m1))
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run([g, out_m0, out_m1], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m0.name: 0.1 * np.ones([1, 4]),
|
||||
m1.name: 0.1 * np.ones([1, 4])
|
||||
})
|
||||
res = sess.run(
|
||||
[g, out_m0, out_m1], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m0.name: 0.1 * np.ones([1, 4]),
|
||||
m1.name: 0.1 * np.ones([1, 4])
|
||||
})
|
||||
self.assertEqual(len(res), 3)
|
||||
# The numbers in results were not calculated, this is just a smoke test.
|
||||
# Note, however, these values should match the original
|
||||
@ -336,10 +344,11 @@ class RNNCellTest(test.TestCase):
|
||||
state_is_tuple=False)
|
||||
output, state = cell(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run([output, state], {
|
||||
x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
|
||||
m.name: 0.1 * np.ones((batch_size, state_size))
|
||||
})
|
||||
res = sess.run(
|
||||
[output, state], {
|
||||
x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
|
||||
m.name: 0.1 * np.ones((batch_size, state_size))
|
||||
})
|
||||
self.assertEqual(len(res), 2)
|
||||
# The numbers in results were not calculated, this is mostly just a
|
||||
# smoke test.
|
||||
@ -442,10 +451,10 @@ class RNNCellTest(test.TestCase):
|
||||
rnn_cell_impl.GRUCell(3), num_proj=3)
|
||||
g, new_m = cell(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g, new_m],
|
||||
{x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1, 0.1]])})
|
||||
res = sess.run([g, new_m], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1, 0.1]])
|
||||
})
|
||||
self.assertEqual(res[1].shape, (1, 3))
|
||||
# The numbers in results were not calculated, this is just a smoke test.
|
||||
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
|
||||
@ -479,9 +488,11 @@ class RNNCellTest(test.TestCase):
|
||||
base_cell = rnn_cell_impl.GRUCell(3)
|
||||
g, m_new = base_cell(x, m)
|
||||
variable_scope.get_variable_scope().reuse_variables()
|
||||
|
||||
def residual_with_slice_fn(inp, out):
|
||||
inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
|
||||
return inp_sliced + out
|
||||
|
||||
g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
|
||||
base_cell, residual_with_slice_fn)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
@ -551,10 +562,10 @@ class RNNCellTest(test.TestCase):
|
||||
self.assertEqual(embedding_cell.output_size, 2)
|
||||
g, new_m = embedding_cell(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g, new_m],
|
||||
{x.name: np.array([[1]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
res = sess.run([g, new_m], {
|
||||
x.name: np.array([[1]]),
|
||||
m.name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
self.assertEqual(res[1].shape, (1, 2))
|
||||
# The numbers in results were not calculated, this is just a smoke test.
|
||||
self.assertAllClose(res[0], [[0.17139, 0.17139]])
|
||||
@ -584,8 +595,8 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = array_ops.zeros([1, 4])
|
||||
_, ml = rnn_cell_impl.MultiRNNCell(
|
||||
[rnn_cell_impl.GRUCell(2)
|
||||
for _ in range(2)], state_is_tuple=False)(x, m)
|
||||
[rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||
state_is_tuple=False)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(ml, {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
@ -605,19 +616,20 @@ class RNNCellTest(test.TestCase):
|
||||
# Test incorrectness of state
|
||||
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
|
||||
rnn_cell_impl.MultiRNNCell(
|
||||
[rnn_cell_impl.GRUCell(2)
|
||||
for _ in range(2)], state_is_tuple=True)(x, m_bad)
|
||||
[rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, m_bad)
|
||||
|
||||
_, ml = rnn_cell_impl.MultiRNNCell(
|
||||
[rnn_cell_impl.GRUCell(2)
|
||||
for _ in range(2)], state_is_tuple=True)(x, m_good)
|
||||
[rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, m_good)
|
||||
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(ml, {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m_good[0].name: np.array([[0.1, 0.1]]),
|
||||
m_good[1].name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
res = sess.run(
|
||||
ml, {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m_good[0].name: np.array([[0.1, 0.1]]),
|
||||
m_good[1].name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
|
||||
# The numbers in results were not calculated, this is just a
|
||||
# smoke test. However, these numbers should match those of
|
||||
@ -628,8 +640,11 @@ class RNNCellTest(test.TestCase):
|
||||
|
||||
class DropoutWrapperTest(test.TestCase):
|
||||
|
||||
def _testDropoutWrapper(self, batch_size=None, time_steps=None,
|
||||
parallel_iterations=None, **kwargs):
|
||||
def _testDropoutWrapper(self,
|
||||
batch_size=None,
|
||||
time_steps=None,
|
||||
parallel_iterations=None,
|
||||
**kwargs):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
@ -640,14 +655,14 @@ class DropoutWrapperTest(test.TestCase):
|
||||
x = constant_op.constant(
|
||||
[[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
|
||||
m = rnn_cell_impl.LSTMStateTuple(
|
||||
*[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)
|
||||
] * 2)
|
||||
*[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32
|
||||
)] * 2)
|
||||
else:
|
||||
x = constant_op.constant(
|
||||
np.random.randn(time_steps, batch_size, 3).astype(np.float32))
|
||||
m = rnn_cell_impl.LSTMStateTuple(*[
|
||||
constant_op.constant(
|
||||
[[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
|
||||
constant_op.
|
||||
constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
|
||||
] * 2)
|
||||
outputs, final_state = rnn.dynamic_rnn(
|
||||
cell=rnn_cell_impl.DropoutWrapper(
|
||||
@ -674,8 +689,8 @@ class DropoutWrapperTest(test.TestCase):
|
||||
res = self._testDropoutWrapper(
|
||||
input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
|
||||
true_full_output = np.array(
|
||||
[[[0.751109, 0.751109, 0.751109]],
|
||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
||||
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||
dtype=np.float32)
|
||||
true_full_final_c = np.array(
|
||||
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
||||
self.assertAllClose(true_full_output, res[0])
|
||||
@ -687,8 +702,8 @@ class DropoutWrapperTest(test.TestCase):
|
||||
res = self._testDropoutWrapper(
|
||||
input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
|
||||
true_full_output = np.array(
|
||||
[[[0.751109, 0.751109, 0.751109]],
|
||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
||||
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||
dtype=np.float32)
|
||||
true_full_final_c = np.array(
|
||||
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
||||
self.assertAllClose(true_full_output, res[0])
|
||||
@ -703,16 +718,20 @@ class DropoutWrapperTest(test.TestCase):
|
||||
## consistent across both calls. Otherwise the seed may not end
|
||||
## up being munged consistently across both graphs.
|
||||
res_standard_1 = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some, seed=10,
|
||||
input_keep_prob=keep_some,
|
||||
output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some,
|
||||
seed=10,
|
||||
parallel_iterations=1)
|
||||
# Clear away the graph and the test session (which keeps variables around)
|
||||
ops.reset_default_graph()
|
||||
self._ClearCachedSession()
|
||||
random_seed.set_random_seed(2)
|
||||
res_standard_2 = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some, seed=10,
|
||||
input_keep_prob=keep_some,
|
||||
output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some,
|
||||
seed=10,
|
||||
parallel_iterations=1)
|
||||
self.assertAllClose(res_standard_1[0], res_standard_2[0])
|
||||
self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c)
|
||||
@ -722,11 +741,12 @@ class DropoutWrapperTest(test.TestCase):
|
||||
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
||||
keep_none = variable_scope.get_variable("none", initializer=1e-10)
|
||||
res = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_all, output_keep_prob=keep_none,
|
||||
input_keep_prob=keep_all,
|
||||
output_keep_prob=keep_none,
|
||||
state_keep_prob=keep_all)
|
||||
true_full_output = np.array(
|
||||
[[[0.751109, 0.751109, 0.751109]],
|
||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
||||
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||
dtype=np.float32)
|
||||
true_full_final_c = np.array(
|
||||
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
||||
self.assertAllClose(np.zeros(res[0].shape), res[0])
|
||||
@ -739,13 +759,13 @@ class DropoutWrapperTest(test.TestCase):
|
||||
# Even though we dropout state, by default DropoutWrapper never
|
||||
# drops out the memory ("c") term of an LSTMStateTuple.
|
||||
res = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_all, output_keep_prob=keep_all,
|
||||
input_keep_prob=keep_all,
|
||||
output_keep_prob=keep_all,
|
||||
state_keep_prob=keep_none)
|
||||
true_c_state = np.array(
|
||||
[[1.713925, 1.713925, 1.713925]], dtype=np.float32)
|
||||
true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32)
|
||||
true_full_output = np.array(
|
||||
[[[0.751109, 0.751109, 0.751109]],
|
||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
||||
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||
dtype=np.float32)
|
||||
self.assertAllClose(true_full_output[0], res[0][0])
|
||||
# Second output is modified by zero input state
|
||||
self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4)
|
||||
@ -758,13 +778,14 @@ class DropoutWrapperTest(test.TestCase):
|
||||
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
||||
keep_none = variable_scope.get_variable("none", initializer=1e-10)
|
||||
true_full_output = np.array(
|
||||
[[[0.751109, 0.751109, 0.751109]],
|
||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
||||
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||
dtype=np.float32)
|
||||
true_full_final_c = np.array(
|
||||
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
||||
# All outputs are different because inputs are zeroed out
|
||||
res = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_none, output_keep_prob=keep_all,
|
||||
input_keep_prob=keep_none,
|
||||
output_keep_prob=keep_all,
|
||||
state_keep_prob=keep_all)
|
||||
self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4)
|
||||
self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4)
|
||||
@ -774,9 +795,13 @@ class DropoutWrapperTest(test.TestCase):
|
||||
keep_some = 0.8
|
||||
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
||||
res = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_all, output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_all, variational_recurrent=True,
|
||||
input_size=3, batch_size=5, time_steps=7)
|
||||
input_keep_prob=keep_all,
|
||||
output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_all,
|
||||
variational_recurrent=True,
|
||||
input_size=3,
|
||||
batch_size=5,
|
||||
time_steps=7)
|
||||
# Ensure the same dropout pattern for all time steps
|
||||
output_mask = np.abs(res[0]) > 1e-6
|
||||
for m in output_mask[1:]:
|
||||
@ -785,9 +810,13 @@ class DropoutWrapperTest(test.TestCase):
|
||||
def testDropoutWrapperRecurrentStateInputAndOutput(self):
|
||||
keep_some = 0.9
|
||||
res = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some, variational_recurrent=True,
|
||||
input_size=3, batch_size=5, time_steps=7)
|
||||
input_keep_prob=keep_some,
|
||||
output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some,
|
||||
variational_recurrent=True,
|
||||
input_size=3,
|
||||
batch_size=5,
|
||||
time_steps=7)
|
||||
|
||||
# Smoke test for the state/input masks.
|
||||
output_mask = np.abs(res[0]) > 1e-6
|
||||
@ -811,17 +840,27 @@ class DropoutWrapperTest(test.TestCase):
|
||||
random_seed.set_random_seed(2347)
|
||||
np.random.seed(23487)
|
||||
res0 = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some, variational_recurrent=True,
|
||||
input_size=3, batch_size=5, time_steps=7, seed=-234987)
|
||||
input_keep_prob=keep_some,
|
||||
output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some,
|
||||
variational_recurrent=True,
|
||||
input_size=3,
|
||||
batch_size=5,
|
||||
time_steps=7,
|
||||
seed=-234987)
|
||||
ops.reset_default_graph()
|
||||
self._ClearCachedSession()
|
||||
random_seed.set_random_seed(2347)
|
||||
np.random.seed(23487)
|
||||
res1 = self._testDropoutWrapper(
|
||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some, variational_recurrent=True,
|
||||
input_size=3, batch_size=5, time_steps=7, seed=-234987)
|
||||
input_keep_prob=keep_some,
|
||||
output_keep_prob=keep_some,
|
||||
state_keep_prob=keep_some,
|
||||
variational_recurrent=True,
|
||||
input_size=3,
|
||||
batch_size=5,
|
||||
time_steps=7,
|
||||
seed=-234987)
|
||||
|
||||
output_mask = np.abs(res0[0]) > 1e-6
|
||||
for time_step in output_mask:
|
||||
@ -858,9 +897,10 @@ class SlimRNNCellTest(test.TestCase):
|
||||
g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
|
||||
# pylint: enable=protected-access
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g], {x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
res = sess.run([g], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])
|
||||
})
|
||||
self.assertEqual(res[0].shape, (1, 2))
|
||||
|
||||
def testBasicRNNCellMatch(self):
|
||||
|
@ -34,8 +34,7 @@ MAX_LABEL = 15
|
||||
WORDS_FEATURE = 'words' # Name of the input words feature.
|
||||
|
||||
|
||||
def estimator_spec_for_softmax_classification(
|
||||
logits, labels, mode):
|
||||
def estimator_spec_for_softmax_classification(logits, labels, mode):
|
||||
"""Returns EstimatorSpec instance for softmax classification."""
|
||||
predicted_classes = tf.argmax(logits, 1)
|
||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
||||
@ -53,8 +52,8 @@ def estimator_spec_for_softmax_classification(
|
||||
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
|
||||
|
||||
eval_metric_ops = {
|
||||
'accuracy': tf.metrics.accuracy(
|
||||
labels=labels, predictions=predicted_classes)
|
||||
'accuracy':
|
||||
tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
|
||||
}
|
||||
return tf.estimator.EstimatorSpec(
|
||||
mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
|
||||
@ -67,8 +66,7 @@ def bag_of_words_model(features, labels, mode):
|
||||
bow_embedding_column = tf.feature_column.embedding_column(
|
||||
bow_column, dimension=EMBEDDING_SIZE)
|
||||
bow = tf.feature_column.input_layer(
|
||||
features,
|
||||
feature_columns=[bow_embedding_column])
|
||||
features, feature_columns=[bow_embedding_column])
|
||||
logits = tf.layers.dense(bow, MAX_LABEL, activation=None)
|
||||
|
||||
return estimator_spec_for_softmax_classification(
|
||||
@ -110,9 +108,9 @@ def main(unused_argv):
|
||||
# Prepare training and testing data
|
||||
dbpedia = tf.contrib.learn.datasets.load_dataset(
|
||||
'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)
|
||||
x_train = pandas.Series(dbpedia.train.data[:,1])
|
||||
x_train = pandas.Series(dbpedia.train.data[:, 1])
|
||||
y_train = pandas.Series(dbpedia.train.target)
|
||||
x_test = pandas.Series(dbpedia.test.data[:,1])
|
||||
x_test = pandas.Series(dbpedia.test.data[:, 1])
|
||||
y_test = pandas.Series(dbpedia.test.target)
|
||||
|
||||
# Process vocabulary
|
||||
@ -152,10 +150,7 @@ def main(unused_argv):
|
||||
|
||||
# Predict.
|
||||
test_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
x={WORDS_FEATURE: x_test},
|
||||
y=y_test,
|
||||
num_epochs=1,
|
||||
shuffle=False)
|
||||
x={WORDS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
|
||||
predictions = classifier.predict(input_fn=test_input_fn)
|
||||
y_predicted = np.array(list(p['class'] for p in predictions))
|
||||
y_predicted = y_predicted.reshape(np.array(y_test).shape)
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Notebook front-end to TensorFlow.
|
||||
|
||||
When you run this binary, you'll see something like below, which indicates
|
||||
@ -43,10 +42,8 @@ from tensorflow.python.platform import app
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
|
||||
|
||||
|
||||
FLAGS = None
|
||||
|
||||
|
||||
ORIG_ARGV = sys.argv
|
||||
# Main notebook process calls itself with argv[1]="kernel" to start kernel
|
||||
# subprocesses.
|
||||
@ -73,8 +70,8 @@ def main(unused_argv):
|
||||
notebookapp.ip = "0.0.0.0"
|
||||
notebookapp.password = passwd(FLAGS.password)
|
||||
else:
|
||||
print ("\nNo password specified; Notebook server will only be available"
|
||||
" on the local machine.\n")
|
||||
print("\nNo password specified; Notebook server will only be available"
|
||||
" on the local machine.\n")
|
||||
notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir])
|
||||
|
||||
if notebookapp.ip == "0.0.0.0":
|
||||
@ -125,8 +122,8 @@ if __name__ == "__main__":
|
||||
# kernel app.
|
||||
if IS_KERNEL:
|
||||
# Drop everything except --flagfile.
|
||||
sys.argv = ([sys.argv[0]] +
|
||||
[x for x in sys.argv[1:] if x.startswith("--flagfile")])
|
||||
sys.argv = (
|
||||
[sys.argv[0]] + [x for x in sys.argv[1:] if x.startswith("--flagfile")])
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Classes and functions related to train_and_evaluate."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -37,7 +36,6 @@ from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
_MAX_DELAY_SECS = 60
|
||||
_DELAY_SECS_PER_WORKER = 5
|
||||
_TF_CONFIG_ENV = 'TF_CONFIG'
|
||||
@ -50,8 +48,7 @@ _TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER,
|
||||
def _validate_input_fn(input_fn):
|
||||
"""Validates the `input_fn`."""
|
||||
if not callable(input_fn):
|
||||
raise TypeError(
|
||||
'`input_fn` must be callable, given: {}'.format(input_fn))
|
||||
raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn))
|
||||
|
||||
|
||||
def _validate_hooks(hooks):
|
||||
@ -125,10 +122,7 @@ class TrainSpec(
|
||||
duration. Optional hooks run at various stages of training.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
input_fn,
|
||||
max_steps=None,
|
||||
hooks=None):
|
||||
def __new__(cls, input_fn, max_steps=None, hooks=None):
|
||||
"""Creates a validated `TrainSpec` instance.
|
||||
|
||||
Args:
|
||||
@ -161,16 +155,13 @@ class TrainSpec(
|
||||
hooks = _validate_hooks(hooks)
|
||||
|
||||
return super(TrainSpec, cls).__new__(
|
||||
cls,
|
||||
input_fn=input_fn,
|
||||
max_steps=max_steps,
|
||||
hooks=hooks)
|
||||
cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
|
||||
|
||||
|
||||
class EvalSpec(
|
||||
collections.namedtuple('EvalSpec', [
|
||||
'input_fn', 'steps', 'name', 'hooks', 'exporters',
|
||||
'start_delay_secs', 'throttle_secs'
|
||||
'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
|
||||
'throttle_secs'
|
||||
])):
|
||||
"""Configuration for the "eval" part for the `train_and_evaluate` call.
|
||||
|
||||
@ -417,8 +408,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
|
||||
Raises:
|
||||
ValueError: if environment variable `TF_CONFIG` is incorrectly set.
|
||||
"""
|
||||
executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec,
|
||||
eval_spec=eval_spec)
|
||||
executor = _TrainingExecutor(
|
||||
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
|
||||
|
||||
config = estimator.config
|
||||
if (config.task_type == run_config_lib.TaskType.EVALUATOR and
|
||||
@ -561,9 +552,8 @@ class _TrainingExecutor(object):
|
||||
self._timer.update_last_triggered_step(global_step_value)
|
||||
self._evaluator.evaluate_and_export()
|
||||
else:
|
||||
logging.info(
|
||||
'Skip the current checkpoint eval due to throttle secs '
|
||||
'({} secs).'.format(self._eval_throttle_secs))
|
||||
logging.info('Skip the current checkpoint eval due to throttle secs '
|
||||
'({} secs).'.format(self._eval_throttle_secs))
|
||||
|
||||
# Final export signal: For any eval result with global_step >= train
|
||||
# max_steps, the evaluator will send the final export signal. There is a
|
||||
@ -576,8 +566,8 @@ class _TrainingExecutor(object):
|
||||
#
|
||||
# But here, throttle_secs will skip the next intermediate checkpoint and,
|
||||
# so, the double final export chance is very small.
|
||||
evaluator = _TrainingExecutor._Evaluator(
|
||||
self._estimator, self._eval_spec, self._train_spec.max_steps)
|
||||
evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
|
||||
self._train_spec.max_steps)
|
||||
|
||||
# When the underlying `Estimator` object saves a new checkpoint, we would
|
||||
# like this callback to be called so that evaluation and export can trigger.
|
||||
@ -617,8 +607,7 @@ class _TrainingExecutor(object):
|
||||
raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
|
||||
'It is used do determine how long each training '
|
||||
'iteration should go when train and evaluate '
|
||||
'locally.'.format(
|
||||
self._eval_spec.throttle_secs))
|
||||
'locally.'.format(self._eval_spec.throttle_secs))
|
||||
|
||||
stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
|
||||
train_hooks = (
|
||||
@ -663,8 +652,9 @@ class _TrainingExecutor(object):
|
||||
|
||||
if not config.master:
|
||||
jobs = config.cluster_spec.jobs
|
||||
if (len(jobs) == 1 and len(config.cluster_spec.job_tasks(jobs[0])) == 1
|
||||
and config.task_type in _TRAINER_JOBS):
|
||||
if (len(jobs) == 1 and
|
||||
len(config.cluster_spec.job_tasks(jobs[0])) == 1 and
|
||||
config.task_type in _TRAINER_JOBS):
|
||||
# For distributed training, config.master is empty if and only if it has
|
||||
# a single node in the cluster spec. In this case, we should not start
|
||||
# the server.
|
||||
@ -679,9 +669,9 @@ class _TrainingExecutor(object):
|
||||
logging.info('Start Tensorflow server.')
|
||||
|
||||
if config.session_config is None:
|
||||
session_config=config_pb2.ConfigProto(log_device_placement=False)
|
||||
session_config = config_pb2.ConfigProto(log_device_placement=False)
|
||||
else:
|
||||
session_config=config_pb2.ConfigProto(
|
||||
session_config = config_pb2.ConfigProto(
|
||||
log_device_placement=False,
|
||||
gpu_options=config.session_config.gpu_options)
|
||||
|
||||
@ -744,8 +734,7 @@ class _TrainingExecutor(object):
|
||||
global_step >= self._train_spec.max_steps):
|
||||
logging.info(
|
||||
'Exiting evaluation, global_step=%s >= train max_steps=%s',
|
||||
global_step,
|
||||
self._train_spec.max_steps)
|
||||
global_step, self._train_spec.max_steps)
|
||||
return
|
||||
|
||||
latest_eval_result, should_early_stop = self._execute_evaluator_once(
|
||||
@ -781,10 +770,9 @@ class _TrainingExecutor(object):
|
||||
|
||||
# Throttle if necessary.
|
||||
elapsed_time = time.time() - start
|
||||
difference = throttle_secs - elapsed_time
|
||||
difference = throttle_secs - elapsed_time
|
||||
if difference > 0:
|
||||
logging.info('Waiting %f secs before starting next eval run.',
|
||||
difference)
|
||||
logging.info('Waiting %f secs before starting next eval run.', difference)
|
||||
time.sleep(difference)
|
||||
|
||||
return (eval_result, should_early_stop)
|
||||
@ -929,8 +917,8 @@ class _EvalResult(
|
||||
if checkpoint_path:
|
||||
raise ValueError(
|
||||
'checkpoint must be `None` if status is not {}; got status {}, '
|
||||
'checkpoint_path {}'.format(
|
||||
_EvalStatus.EVALUATED, status, checkpoint_path))
|
||||
'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status,
|
||||
checkpoint_path))
|
||||
return super(_EvalResult, cls).__new__(cls, status, metrics,
|
||||
checkpoint_path)
|
||||
|
||||
|
@ -56,9 +56,11 @@ class TensordotTest(test_lib.TestCase):
|
||||
axes_ph = array_ops.placeholder(dtypes.int32)
|
||||
output = math_ops.tensordot(a_ph, b_ph, axes_ph)
|
||||
_ = sess.run(
|
||||
[output], feed_dict={a_ph: a,
|
||||
b_ph: b,
|
||||
axes_ph: (a_axes, b_axes)})
|
||||
[output], feed_dict={
|
||||
a_ph: a,
|
||||
b_ph: b,
|
||||
axes_ph: (a_axes, b_axes)
|
||||
})
|
||||
|
||||
def test_invalid_axes(self):
|
||||
a = [[1, 2], [3, 4]]
|
||||
@ -81,26 +83,27 @@ class TensordotTest(test_lib.TestCase):
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||
_ = sess.run(
|
||||
[output], feed_dict={a_ph: a,
|
||||
b_ph: b,
|
||||
axes_ph: axes_value})
|
||||
[output], feed_dict={
|
||||
a_ph: a,
|
||||
b_ph: b,
|
||||
axes_ph: axes_value
|
||||
})
|
||||
|
||||
# Test case for 11950
|
||||
def test_valid_axis(self):
|
||||
for axes_value in [1, 2], [[1], [2]]:
|
||||
with self.test_session() as sess:
|
||||
np_a = np.ones((3,3))
|
||||
np_a = np.ones((3, 3))
|
||||
np_b = np.array([2, 3, 1])[None, None]
|
||||
np_ans = np.tensordot(np_a, np_b, axes_value)
|
||||
|
||||
tf_a = array_ops.ones((3,3), dtype=dtypes.float32)
|
||||
tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||
tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
|
||||
tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value).eval()
|
||||
|
||||
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
||||
self.assertAllEqual(tf_ans, np_ans)
|
||||
|
||||
|
||||
def test_partial_shape_inference(self):
|
||||
a = array_ops.placeholder(dtypes.float32)
|
||||
b = array_ops.placeholder(dtypes.float32)
|
||||
@ -169,9 +172,11 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
|
||||
axes = array_ops.placeholder(dtypes.int32)
|
||||
c = math_ops.tensordot(a, b, axes)
|
||||
tf_ans = sess.run(
|
||||
c, feed_dict={a: a_np,
|
||||
b: b_np,
|
||||
axes: (a_dims_np, b_dims_np)})
|
||||
c, feed_dict={
|
||||
a: a_np,
|
||||
b: b_np,
|
||||
axes: (a_dims_np, b_dims_np)
|
||||
})
|
||||
else:
|
||||
tf_ans = math_ops.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval()
|
||||
self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
|
||||
|
@ -184,7 +184,8 @@ do_pylint() {
|
||||
# W0312 mixed-indentation
|
||||
# C0330 bad-continuation
|
||||
# C0301 line-too-long
|
||||
grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301)' ${OUTPUT_FILE} > ${ERRORS_FILE}
|
||||
# C0326 bad-whitespace
|
||||
grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326)' ${OUTPUT_FILE} > ${ERRORS_FILE}
|
||||
|
||||
N_ERRORS=0
|
||||
while read -r LINE; do
|
||||
|
@ -46,8 +46,9 @@ class APIChangeSpec(object):
|
||||
"""
|
||||
|
||||
|
||||
class _FileEditTuple(collections.namedtuple(
|
||||
"_FileEditTuple", ["comment", "line", "start", "old", "new"])):
|
||||
class _FileEditTuple(
|
||||
collections.namedtuple("_FileEditTuple",
|
||||
["comment", "line", "start", "old", "new"])):
|
||||
"""Each edit that is recorded by a _FileEditRecorder.
|
||||
|
||||
Fields:
|
||||
@ -179,8 +180,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
function_renames = self._api_change_spec.function_renames
|
||||
try:
|
||||
new_name = function_renames[full_name]
|
||||
self._file_edit.add("Renamed function %r to %r" % (full_name,
|
||||
new_name),
|
||||
self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
|
||||
node.lineno, node.col_offset, full_name, new_name)
|
||||
except KeyError:
|
||||
pass
|
||||
@ -227,7 +227,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
# loop over lines
|
||||
while 1:
|
||||
# Reverse the text to and regular expression search for whitespace
|
||||
text = self._lines[line-1]
|
||||
text = self._lines[line - 1]
|
||||
reversed_preceding_text = text[:col][::-1]
|
||||
# First find if a [ can be found with only whitespace between it and
|
||||
# col.
|
||||
@ -248,8 +248,8 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
# node ranges to filter out spurious #'s that appear in string
|
||||
# literals.
|
||||
comment_start = prev_line.find("#")
|
||||
if comment_start == -1:
|
||||
col = len(prev_line) -1
|
||||
if comment_start == -1:
|
||||
col = len(prev_line) - 1
|
||||
elif find_string_chars.search(prev_line[comment_start:]) is None:
|
||||
col = comment_start
|
||||
else:
|
||||
@ -260,7 +260,6 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
# it is not possible to use that in an argument.
|
||||
return node.lineno, node.col_offset
|
||||
|
||||
|
||||
def visit_Call(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting a call node in the AST.
|
||||
|
||||
@ -268,7 +267,6 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
node: Current Node
|
||||
"""
|
||||
|
||||
|
||||
# Find a simple attribute name path e.g. "tf.foo.bar"
|
||||
full_name = self._get_attribute_full_path(node.func)
|
||||
|
||||
@ -293,18 +291,21 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
lineno, col_offset = self._find_true_position(arg)
|
||||
if lineno is None or col_offset is None:
|
||||
self._file_edit.add(
|
||||
"Failed to add keyword %r to reordered function %r"
|
||||
% (reordered[idx], full_name), arg.lineno, arg.col_offset,
|
||||
"", "",
|
||||
"Failed to add keyword %r to reordered function %r" %
|
||||
(reordered[idx], full_name),
|
||||
arg.lineno,
|
||||
arg.col_offset,
|
||||
"",
|
||||
"",
|
||||
error="A necessary keyword argument failed to be inserted.")
|
||||
else:
|
||||
keyword_arg = reordered[idx]
|
||||
if (full_name in function_keyword_renames and
|
||||
keyword_arg in function_keyword_renames[full_name]):
|
||||
keyword_arg = function_keyword_renames[full_name][keyword_arg]
|
||||
self._file_edit.add("Added keyword %r to reordered function %r"
|
||||
% (reordered[idx], full_name), lineno,
|
||||
col_offset, "", keyword_arg + "=")
|
||||
self._file_edit.add("Added keyword %r to reordered function %r" %
|
||||
(reordered[idx], full_name), lineno, col_offset,
|
||||
"", keyword_arg + "=")
|
||||
|
||||
# Examine each keyword argument and convert it to the final renamed form
|
||||
renamed_keywords = ({} if full_name not in function_keyword_renames else
|
||||
@ -322,11 +323,11 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
# value.
|
||||
key_start = argval_col_offset - len(argkey) - 1
|
||||
key_end = key_start + len(argkey) + 1
|
||||
if (self._lines[argval_lineno - 1][key_start:key_end] ==
|
||||
argkey + "="):
|
||||
if (self._lines[argval_lineno - 1][key_start:key_end] == argkey +
|
||||
"="):
|
||||
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
||||
(argkey, renamed_keywords[argkey]),
|
||||
argval_lineno,
|
||||
(argkey,
|
||||
renamed_keywords[argkey]), argval_lineno,
|
||||
argval_col_offset - len(argkey) - 1,
|
||||
argkey + "=", renamed_keywords[argkey] + "=")
|
||||
continue
|
||||
@ -335,7 +336,8 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
(argkey, renamed_keywords[argkey]),
|
||||
argval.lineno,
|
||||
argval.col_offset - len(argkey) - 1,
|
||||
"", "",
|
||||
"",
|
||||
"",
|
||||
error="Failed to find keyword lexographically. Fix manually.")
|
||||
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
@ -352,7 +354,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
if full_name in self._api_change_spec.change_to_function:
|
||||
if not hasattr(node, "is_function_for_call"):
|
||||
new_text = full_name + "()"
|
||||
self._file_edit.add("Changed %r to %r"%(full_name, new_text),
|
||||
self._file_edit.add("Changed %r to %r" % (full_name, new_text),
|
||||
node.lineno, node.col_offset, full_name, new_text)
|
||||
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
@ -380,8 +382,8 @@ class ASTCodeUpgrader(object):
|
||||
# Write to a temporary file, just in case we are doing an implace modify.
|
||||
with open(in_filename, "r") as in_file, \
|
||||
tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
|
||||
ret = self.process_opened_file(
|
||||
in_filename, in_file, out_filename, temp_file)
|
||||
ret = self.process_opened_file(in_filename, in_file, out_filename,
|
||||
temp_file)
|
||||
|
||||
shutil.move(temp_file.name, out_filename)
|
||||
return ret
|
||||
@ -424,6 +426,7 @@ class ASTCodeUpgrader(object):
|
||||
out_file.write(out_text)
|
||||
text += "\n"
|
||||
return 1, text, process_errors
|
||||
|
||||
# pylint: enable=broad-except
|
||||
|
||||
def process_tree(self, root_directory, output_root_directory,
|
||||
@ -444,16 +447,16 @@ class ASTCodeUpgrader(object):
|
||||
|
||||
# make sure output directory doesn't exist
|
||||
if output_root_directory and os.path.exists(output_root_directory):
|
||||
print("Output directory %r must not already exist." % (
|
||||
output_root_directory))
|
||||
print("Output directory %r must not already exist." %
|
||||
(output_root_directory))
|
||||
sys.exit(1)
|
||||
|
||||
# make sure output directory does not overlap with root_directory
|
||||
norm_root = os.path.split(os.path.normpath(root_directory))
|
||||
norm_output = os.path.split(os.path.normpath(output_root_directory))
|
||||
if norm_root == norm_output:
|
||||
print("Output directory %r same as input directory %r" % (
|
||||
root_directory, output_root_directory))
|
||||
print("Output directory %r same as input directory %r" %
|
||||
(root_directory, output_root_directory))
|
||||
sys.exit(1)
|
||||
|
||||
# Collect list of files to process (we do this to correctly handle if the
|
||||
@ -465,14 +468,16 @@ class ASTCodeUpgrader(object):
|
||||
copy_files = [f for f in file_list if not f.endswith(".py")]
|
||||
for filename in py_files:
|
||||
fullpath = os.path.join(dir_name, filename)
|
||||
fullpath_output = os.path.join(
|
||||
output_root_directory, os.path.relpath(fullpath, root_directory))
|
||||
fullpath_output = os.path.join(output_root_directory,
|
||||
os.path.relpath(fullpath,
|
||||
root_directory))
|
||||
files_to_process.append((fullpath, fullpath_output))
|
||||
if copy_other_files:
|
||||
for filename in copy_files:
|
||||
fullpath = os.path.join(dir_name, filename)
|
||||
fullpath_output = os.path.join(
|
||||
output_root_directory, os.path.relpath(fullpath, root_directory))
|
||||
fullpath_output = os.path.join(output_root_directory,
|
||||
os.path.relpath(
|
||||
fullpath, root_directory))
|
||||
files_to_copy.append((fullpath, fullpath_output))
|
||||
|
||||
file_count = 0
|
||||
@ -641,18 +646,17 @@ class TFAPIChangeSpec(APIChangeSpec):
|
||||
"tf.concat": ["concat_dim", "values", "name"],
|
||||
"tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
|
||||
"tf.nn.softmax_cross_entropy_with_logits": [
|
||||
"logits", "labels", "dim", "name"],
|
||||
"logits", "labels", "dim", "name"
|
||||
],
|
||||
"tf.nn.sparse_softmax_cross_entropy_with_logits": [
|
||||
"logits", "labels", "name"],
|
||||
"tf.nn.sigmoid_cross_entropy_with_logits": [
|
||||
"logits", "labels", "name"],
|
||||
"logits", "labels", "name"
|
||||
],
|
||||
"tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"],
|
||||
"tf.op_scope": ["values", "name", "default_name"],
|
||||
}
|
||||
|
||||
# Specially handled functions.
|
||||
self.function_handle = {
|
||||
"tf.reverse": self._reverse_handler
|
||||
}
|
||||
self.function_handle = {"tf.reverse": self._reverse_handler}
|
||||
|
||||
@staticmethod
|
||||
def _reverse_handler(file_edit_recorder, node):
|
||||
@ -661,12 +665,13 @@ class TFAPIChangeSpec(APIChangeSpec):
|
||||
comment = ("ERROR: tf.reverse has had its argument semantics changed\n"
|
||||
"significantly the converter cannot detect this reliably, so you"
|
||||
"need to inspect this usage manually.\n")
|
||||
file_edit_recorder.add(comment,
|
||||
node.lineno,
|
||||
node.col_offset,
|
||||
"tf.reverse",
|
||||
"tf.reverse",
|
||||
error="tf.reverse requires manual check.")
|
||||
file_edit_recorder.add(
|
||||
comment,
|
||||
node.lineno,
|
||||
node.col_offset,
|
||||
"tf.reverse",
|
||||
"tf.reverse",
|
||||
error="tf.reverse requires manual check.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user