Add C0326 bad-whitespace error to pylint sanity check.
PiperOrigin-RevId: 183689499
This commit is contained in:
parent
730071d0dc
commit
fd63d4e30a
@ -31,7 +31,6 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
||||||
"""Tests of the new, differentiable version of accumulate_n"""
|
"""Tests of the new, differentiable version of accumulate_n"""
|
||||||
|
|
||||||
@ -62,8 +61,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
|||||||
accum_n = av2.accumulate_n_v2(input_vars)
|
accum_n = av2.accumulate_n_v2(input_vars)
|
||||||
sess.run(variables.global_variables_initializer())
|
sess.run(variables.global_variables_initializer())
|
||||||
accum_n_grad = gradients.gradients(accum_n, input_vars)
|
accum_n_grad = gradients.gradients(accum_n, input_vars)
|
||||||
self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1
|
self.assertAllEqual(
|
||||||
[g.eval() for g in accum_n_grad])
|
np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1
|
||||||
|
[g.eval() for g in accum_n_grad])
|
||||||
|
|
||||||
# The tests below used to be in a separate class under cwise_ops_test.py,
|
# The tests below used to be in a separate class under cwise_ops_test.py,
|
||||||
# which did not run in the default test target.
|
# which did not run in the default test target.
|
||||||
@ -75,8 +75,8 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
|||||||
np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
|
np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
|
||||||
]
|
]
|
||||||
random_tensors = [
|
random_tensors = [
|
||||||
ops.convert_to_tensor(
|
ops.convert_to_tensor(x, dtype=dtypes_lib.float32)
|
||||||
x, dtype=dtypes_lib.float32) for x in random_arrays
|
for x in random_arrays
|
||||||
]
|
]
|
||||||
tf_val = av2.accumulate_n_v2(random_tensors)
|
tf_val = av2.accumulate_n_v2(random_tensors)
|
||||||
np_val = random_arrays[0]
|
np_val = random_arrays[0]
|
||||||
@ -95,21 +95,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
a = variables.Variable(0.2)
|
a = variables.Variable(0.2)
|
||||||
b = variables.Variable(0.1)
|
b = variables.Variable(0.1)
|
||||||
tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[]
|
tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2]) # Should be shape=[]
|
||||||
|
|
||||||
def testIncompatibleShapes(self):
|
def testIncompatibleShapes(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
a = variables.Variable(np.array([0.1,0.2]))
|
a = variables.Variable(np.array([0.1, 0.2]))
|
||||||
b = variables.Variable(np.array([[0.3],[0.4]]))
|
b = variables.Variable(np.array([[0.3], [0.4]]))
|
||||||
tf_val = av2.accumulate_n_v2([a,b])
|
tf_val = av2.accumulate_n_v2([a, b])
|
||||||
|
|
||||||
def testWrongType(self):
|
def testWrongType(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
a = variables.Variable(0.2, dtype=np.float32)
|
a = variables.Variable(0.2, dtype=np.float32)
|
||||||
b = variables.Variable(0.1, dtype=np.float32)
|
b = variables.Variable(0.1, dtype=np.float32)
|
||||||
tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32)
|
tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32)
|
||||||
|
|
||||||
def testWrongTypeOneInput(self):
|
def testWrongTypeOneInput(self):
|
||||||
# Scenario that used to trigger a bug, even when testWrongType() worked
|
# Scenario that used to trigger a bug, even when testWrongType() worked
|
||||||
|
|||||||
@ -12,7 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Base utilities for loading datasets."""
|
"""Base utilities for loading datasets."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
@ -100,9 +99,7 @@ def load_iris(data_path=None):
|
|||||||
module_path = path.dirname(__file__)
|
module_path = path.dirname(__file__)
|
||||||
data_path = path.join(module_path, 'data', 'iris.csv')
|
data_path = path.join(module_path, 'data', 'iris.csv')
|
||||||
return load_csv_with_header(
|
return load_csv_with_header(
|
||||||
data_path,
|
data_path, target_dtype=np.int, features_dtype=np.float)
|
||||||
target_dtype=np.int,
|
|
||||||
features_dtype=np.float)
|
|
||||||
|
|
||||||
|
|
||||||
def load_boston(data_path=None):
|
def load_boston(data_path=None):
|
||||||
@ -118,16 +115,10 @@ def load_boston(data_path=None):
|
|||||||
module_path = path.dirname(__file__)
|
module_path = path.dirname(__file__)
|
||||||
data_path = path.join(module_path, 'data', 'boston_house_prices.csv')
|
data_path = path.join(module_path, 'data', 'boston_house_prices.csv')
|
||||||
return load_csv_with_header(
|
return load_csv_with_header(
|
||||||
data_path,
|
data_path, target_dtype=np.float, features_dtype=np.float)
|
||||||
target_dtype=np.float,
|
|
||||||
features_dtype=np.float)
|
|
||||||
|
|
||||||
|
|
||||||
def retry(initial_delay,
|
def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None):
|
||||||
max_delay,
|
|
||||||
factor=2.0,
|
|
||||||
jitter=0.25,
|
|
||||||
is_retriable=None):
|
|
||||||
"""Simple decorator for wrapping retriable functions.
|
"""Simple decorator for wrapping retriable functions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -152,7 +143,7 @@ def retry(initial_delay,
|
|||||||
def delays():
|
def delays():
|
||||||
delay = initial_delay
|
delay = initial_delay
|
||||||
while delay <= max_delay:
|
while delay <= max_delay:
|
||||||
yield delay * random.uniform(1 - jitter, 1 + jitter)
|
yield delay * random.uniform(1 - jitter, 1 + jitter)
|
||||||
delay *= factor
|
delay *= factor
|
||||||
|
|
||||||
def wrap(fn):
|
def wrap(fn):
|
||||||
@ -172,7 +163,9 @@ def retry(initial_delay,
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
return wrapped_fn
|
return wrapped_fn
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -198,10 +198,13 @@ class TensorMapper(object):
|
|||||||
|
|
||||||
def GenerateGraph(subgraph_idx, g, opcode_mapper):
|
def GenerateGraph(subgraph_idx, g, opcode_mapper):
|
||||||
"""Produces the HTML required to have a d3 visualization of the dag."""
|
"""Produces the HTML required to have a d3 visualization of the dag."""
|
||||||
|
|
||||||
def TensorName(idx):
|
def TensorName(idx):
|
||||||
return "t%d"%idx
|
return "t%d" % idx
|
||||||
|
|
||||||
def OpName(idx):
|
def OpName(idx):
|
||||||
return "o%d"%idx
|
return "o%d" % idx
|
||||||
|
|
||||||
edges = []
|
edges = []
|
||||||
nodes = []
|
nodes = []
|
||||||
first = {}
|
first = {}
|
||||||
@ -210,27 +213,35 @@ def GenerateGraph(subgraph_idx, g, opcode_mapper):
|
|||||||
for tensor_input_position, tensor_index in enumerate(op["inputs"]):
|
for tensor_input_position, tensor_index in enumerate(op["inputs"]):
|
||||||
if tensor_index not in first:
|
if tensor_index not in first:
|
||||||
first[tensor_index] = (
|
first[tensor_index] = (
|
||||||
op_index*pixel_mult,
|
op_index * pixel_mult,
|
||||||
tensor_input_position*pixel_mult - pixel_mult/2)
|
tensor_input_position * pixel_mult - pixel_mult / 2)
|
||||||
edges.append(
|
edges.append({
|
||||||
{"source": TensorName(tensor_index), "target": OpName(op_index)})
|
"source": TensorName(tensor_index),
|
||||||
|
"target": OpName(op_index)
|
||||||
|
})
|
||||||
for tensor_index in op["outputs"]:
|
for tensor_index in op["outputs"]:
|
||||||
edges.append(
|
edges.append({
|
||||||
{"target": TensorName(tensor_index), "source": OpName(op_index)})
|
"target": TensorName(tensor_index),
|
||||||
nodes.append({"id": OpName(op_index),
|
"source": OpName(op_index)
|
||||||
"name": opcode_mapper(op["opcode_index"]),
|
})
|
||||||
"group": 2,
|
nodes.append({
|
||||||
"x": pixel_mult,
|
"id": OpName(op_index),
|
||||||
"y": op_index * pixel_mult})
|
"name": opcode_mapper(op["opcode_index"]),
|
||||||
|
"group": 2,
|
||||||
|
"x": pixel_mult,
|
||||||
|
"y": op_index * pixel_mult
|
||||||
|
})
|
||||||
for tensor_index, tensor in enumerate(g["tensors"]):
|
for tensor_index, tensor in enumerate(g["tensors"]):
|
||||||
initial_y = (first[tensor_index] if tensor_index in first
|
initial_y = (
|
||||||
else len(g["operators"]))
|
first[tensor_index] if tensor_index in first else len(g["operators"]))
|
||||||
|
|
||||||
nodes.append({"id": TensorName(tensor_index),
|
nodes.append({
|
||||||
"name": "%s (%d)" % (tensor["name"], tensor_index),
|
"id": TensorName(tensor_index),
|
||||||
"group": 1,
|
"name": "%s (%d)" % (tensor["name"], tensor_index),
|
||||||
"x": 2,
|
"group": 1,
|
||||||
"y": initial_y})
|
"x": 2,
|
||||||
|
"y": initial_y
|
||||||
|
})
|
||||||
graph_str = json.dumps({"nodes": nodes, "edges": edges})
|
graph_str = json.dumps({"nodes": nodes, "edges": edges})
|
||||||
|
|
||||||
html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
|
html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
|
||||||
@ -267,7 +278,7 @@ def GenerateTableHtml(items, keys_to_print, display_index=True):
|
|||||||
for h, mapper in keys_to_print:
|
for h, mapper in keys_to_print:
|
||||||
val = tensor[h] if h in tensor else None
|
val = tensor[h] if h in tensor else None
|
||||||
val = val if mapper is None else mapper(val)
|
val = val if mapper is None else mapper(val)
|
||||||
html += "<td>%s</td>\n"%val
|
html += "<td>%s</td>\n" % val
|
||||||
|
|
||||||
html += "</tr>\n"
|
html += "</tr>\n"
|
||||||
html += "</table>\n"
|
html += "</table>\n"
|
||||||
@ -279,18 +290,19 @@ def CreateHtmlFile(tflite_input, html_output):
|
|||||||
|
|
||||||
# Convert the model into a JSON flatbuffer using flatc (build if doesn't
|
# Convert the model into a JSON flatbuffer using flatc (build if doesn't
|
||||||
# exist.
|
# exist.
|
||||||
if not os.path.exists(tflite_input):
|
if not os.path.exists(tflite_input):
|
||||||
raise RuntimeError("Invalid filename %r" % tflite_input)
|
raise RuntimeError("Invalid filename %r" % tflite_input)
|
||||||
if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"):
|
if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"):
|
||||||
|
|
||||||
# Run convert
|
# Run convert
|
||||||
cmd = (_BINARY + " -t "
|
cmd = (
|
||||||
"--strict-json --defaults-json -o /tmp {schema} -- {input}".format(
|
_BINARY + " -t "
|
||||||
input=tflite_input, schema=_SCHEMA))
|
"--strict-json --defaults-json -o /tmp {schema} -- {input}".format(
|
||||||
|
input=tflite_input, schema=_SCHEMA))
|
||||||
print(cmd)
|
print(cmd)
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
real_output = ("/tmp/"+ os.path.splitext(os.path.split(tflite_input)[-1])[0]
|
real_output = ("/tmp/" + os.path.splitext(
|
||||||
+ ".json")
|
os.path.split(tflite_input)[-1])[0] + ".json")
|
||||||
|
|
||||||
data = json.load(open(real_output))
|
data = json.load(open(real_output))
|
||||||
elif tflite_input.endswith(".json"):
|
elif tflite_input.endswith(".json"):
|
||||||
@ -302,12 +314,13 @@ def CreateHtmlFile(tflite_input, html_output):
|
|||||||
html += "<h1>TensorFlow Lite Model</h2>"
|
html += "<h1>TensorFlow Lite Model</h2>"
|
||||||
|
|
||||||
data["filename"] = tflite_input # Avoid special case
|
data["filename"] = tflite_input # Avoid special case
|
||||||
toplevel_stuff = [("filename", None), ("version", None),
|
toplevel_stuff = [("filename", None), ("version", None), ("description",
|
||||||
("description", None)]
|
None)]
|
||||||
|
|
||||||
html += "<table>\n"
|
html += "<table>\n"
|
||||||
for key, mapping in toplevel_stuff:
|
for key, mapping in toplevel_stuff:
|
||||||
if not mapping: mapping = lambda x: x
|
if not mapping:
|
||||||
|
mapping = lambda x: x
|
||||||
html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data[key]))
|
html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data[key]))
|
||||||
html += "</table>\n"
|
html += "</table>\n"
|
||||||
|
|
||||||
@ -320,22 +333,22 @@ def CreateHtmlFile(tflite_input, html_output):
|
|||||||
html += "<div class='subgraph'>"
|
html += "<div class='subgraph'>"
|
||||||
tensor_mapper = TensorMapper(g)
|
tensor_mapper = TensorMapper(g)
|
||||||
opcode_mapper = OpCodeMapper(data)
|
opcode_mapper = OpCodeMapper(data)
|
||||||
op_keys_to_display = [
|
op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
|
||||||
("inputs", tensor_mapper), ("outputs", tensor_mapper),
|
("builtin_options", None), ("opcode_index",
|
||||||
("builtin_options", None), ("opcode_index", opcode_mapper)]
|
opcode_mapper)]
|
||||||
tensor_keys_to_display = [
|
tensor_keys_to_display = [("name", None), ("type", None), ("shape", None),
|
||||||
("name", None), ("type", None), ("shape", None), ("buffer", None),
|
("buffer", None), ("quantization", None)]
|
||||||
("quantization", None)]
|
|
||||||
|
|
||||||
html += "<h2>Subgraph %d</h2>\n" % subgraph_idx
|
html += "<h2>Subgraph %d</h2>\n" % subgraph_idx
|
||||||
|
|
||||||
# Inputs and outputs.
|
# Inputs and outputs.
|
||||||
html += "<h3>Inputs/Outputs</h3>\n"
|
html += "<h3>Inputs/Outputs</h3>\n"
|
||||||
html += GenerateTableHtml([{"inputs": g["inputs"],
|
html += GenerateTableHtml(
|
||||||
"outputs": g["outputs"]}],
|
[{
|
||||||
[("inputs", tensor_mapper),
|
"inputs": g["inputs"],
|
||||||
("outputs", tensor_mapper)],
|
"outputs": g["outputs"]
|
||||||
display_index=False)
|
}], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
|
||||||
|
display_index=False)
|
||||||
|
|
||||||
# Print the tensors.
|
# Print the tensors.
|
||||||
html += "<h3>Tensors</h3>\n"
|
html += "<h3>Tensors</h3>\n"
|
||||||
@ -357,8 +370,7 @@ def CreateHtmlFile(tflite_input, html_output):
|
|||||||
|
|
||||||
# Operator codes
|
# Operator codes
|
||||||
html += "<h2>Operator Codes</h2>\n"
|
html += "<h2>Operator Codes</h2>\n"
|
||||||
html += GenerateTableHtml(data["operator_codes"],
|
html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)
|
||||||
operator_keys_to_display)
|
|
||||||
|
|
||||||
html += "</body></html>\n"
|
html += "</body></html>\n"
|
||||||
|
|
||||||
@ -370,10 +382,10 @@ def main(argv):
|
|||||||
tflite_input = argv[1]
|
tflite_input = argv[1]
|
||||||
html_output = argv[2]
|
html_output = argv[2]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
print ("Usage: %s <input tflite> <output html>" % (argv[0]))
|
print("Usage: %s <input tflite> <output html>" % (argv[0]))
|
||||||
else:
|
else:
|
||||||
CreateHtmlFile(tflite_input, html_output)
|
CreateHtmlFile(tflite_input, html_output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main(sys.argv)
|
main(sys.argv)
|
||||||
|
|
||||||
|
|||||||
@ -58,6 +58,7 @@ def read_cifar10(filename_queue):
|
|||||||
|
|
||||||
class CIFAR10Record(object):
|
class CIFAR10Record(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
result = CIFAR10Record()
|
result = CIFAR10Record()
|
||||||
|
|
||||||
# Dimensions of the images in the CIFAR-10 dataset.
|
# Dimensions of the images in the CIFAR-10 dataset.
|
||||||
@ -147,8 +148,9 @@ def distorted_inputs(data_dir, batch_size):
|
|||||||
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
||||||
labels: Labels. 1D tensor of [batch_size] size.
|
labels: Labels. 1D tensor of [batch_size] size.
|
||||||
"""
|
"""
|
||||||
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
|
filenames = [
|
||||||
for i in xrange(1, 6)]
|
os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)
|
||||||
|
]
|
||||||
for f in filenames:
|
for f in filenames:
|
||||||
if not tf.gfile.Exists(f):
|
if not tf.gfile.Exists(f):
|
||||||
raise ValueError('Failed to find file: ' + f)
|
raise ValueError('Failed to find file: ' + f)
|
||||||
@ -174,10 +176,9 @@ def distorted_inputs(data_dir, batch_size):
|
|||||||
|
|
||||||
# Because these operations are not commutative, consider randomizing
|
# Because these operations are not commutative, consider randomizing
|
||||||
# the order their operation.
|
# the order their operation.
|
||||||
distorted_image = tf.image.random_brightness(distorted_image,
|
distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
|
||||||
max_delta=63)
|
distorted_image = tf.image.random_contrast(
|
||||||
distorted_image = tf.image.random_contrast(distorted_image,
|
distorted_image, lower=0.2, upper=1.8)
|
||||||
lower=0.2, upper=1.8)
|
|
||||||
|
|
||||||
# Subtract off the mean and divide by the variance of the pixels.
|
# Subtract off the mean and divide by the variance of the pixels.
|
||||||
float_image = tf.image.per_image_standardization(distorted_image)
|
float_image = tf.image.per_image_standardization(distorted_image)
|
||||||
@ -188,15 +189,18 @@ def distorted_inputs(data_dir, batch_size):
|
|||||||
|
|
||||||
# Ensure that the random shuffling has good mixing properties.
|
# Ensure that the random shuffling has good mixing properties.
|
||||||
min_fraction_of_examples_in_queue = 0.4
|
min_fraction_of_examples_in_queue = 0.4
|
||||||
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
|
min_queue_examples = int(
|
||||||
min_fraction_of_examples_in_queue)
|
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue)
|
||||||
print ('Filling queue with %d CIFAR images before starting to train. '
|
print('Filling queue with %d CIFAR images before starting to train. '
|
||||||
'This will take a few minutes.' % min_queue_examples)
|
'This will take a few minutes.' % min_queue_examples)
|
||||||
|
|
||||||
# Generate a batch of images and labels by building up a queue of examples.
|
# Generate a batch of images and labels by building up a queue of examples.
|
||||||
return _generate_image_and_label_batch(float_image, read_input.label,
|
return _generate_image_and_label_batch(
|
||||||
min_queue_examples, batch_size,
|
float_image,
|
||||||
shuffle=True)
|
read_input.label,
|
||||||
|
min_queue_examples,
|
||||||
|
batch_size,
|
||||||
|
shuffle=True)
|
||||||
|
|
||||||
|
|
||||||
def inputs(eval_data, data_dir, batch_size):
|
def inputs(eval_data, data_dir, batch_size):
|
||||||
@ -212,8 +216,9 @@ def inputs(eval_data, data_dir, batch_size):
|
|||||||
labels: Labels. 1D tensor of [batch_size] size.
|
labels: Labels. 1D tensor of [batch_size] size.
|
||||||
"""
|
"""
|
||||||
if not eval_data:
|
if not eval_data:
|
||||||
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
|
filenames = [
|
||||||
for i in xrange(1, 6)]
|
os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)
|
||||||
|
]
|
||||||
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
|
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
|
||||||
else:
|
else:
|
||||||
filenames = [os.path.join(data_dir, 'test_batch.bin')]
|
filenames = [os.path.join(data_dir, 'test_batch.bin')]
|
||||||
@ -235,8 +240,8 @@ def inputs(eval_data, data_dir, batch_size):
|
|||||||
|
|
||||||
# Image processing for evaluation.
|
# Image processing for evaluation.
|
||||||
# Crop the central [height, width] of the image.
|
# Crop the central [height, width] of the image.
|
||||||
resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
|
resized_image = tf.image.resize_image_with_crop_or_pad(
|
||||||
width, height)
|
reshaped_image, width, height)
|
||||||
|
|
||||||
# Subtract off the mean and divide by the variance of the pixels.
|
# Subtract off the mean and divide by the variance of the pixels.
|
||||||
float_image = tf.image.per_image_standardization(resized_image)
|
float_image = tf.image.per_image_standardization(resized_image)
|
||||||
@ -247,10 +252,13 @@ def inputs(eval_data, data_dir, batch_size):
|
|||||||
|
|
||||||
# Ensure that the random shuffling has good mixing properties.
|
# Ensure that the random shuffling has good mixing properties.
|
||||||
min_fraction_of_examples_in_queue = 0.4
|
min_fraction_of_examples_in_queue = 0.4
|
||||||
min_queue_examples = int(num_examples_per_epoch *
|
min_queue_examples = int(
|
||||||
min_fraction_of_examples_in_queue)
|
num_examples_per_epoch * min_fraction_of_examples_in_queue)
|
||||||
|
|
||||||
# Generate a batch of images and labels by building up a queue of examples.
|
# Generate a batch of images and labels by building up a queue of examples.
|
||||||
return _generate_image_and_label_batch(float_image, read_input.label,
|
return _generate_image_and_label_batch(
|
||||||
min_queue_examples, batch_size,
|
float_image,
|
||||||
shuffle=False)
|
read_input.label,
|
||||||
|
min_queue_examples,
|
||||||
|
batch_size,
|
||||||
|
shuffle=False)
|
||||||
|
|||||||
@ -42,7 +42,6 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
|
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
|
||||||
|
|
||||||
|
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
|
Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
|
||||||
|
|
||||||
@ -84,19 +83,22 @@ class RNNCellTest(test.TestCase):
|
|||||||
], [v.name for v in cell.trainable_variables])
|
], [v.name for v in cell.trainable_variables])
|
||||||
self.assertFalse(cell.non_trainable_variables)
|
self.assertFalse(cell.non_trainable_variables)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g], {
|
||||||
[g], {x.name: np.array([[1., 1.]]),
|
x.name: np.array([[1., 1.]]),
|
||||||
m.name: np.array([[0.1, 0.1]])})
|
m.name: np.array([[0.1, 0.1]])
|
||||||
|
})
|
||||||
self.assertEqual(res[0].shape, (1, 2))
|
self.assertEqual(res[0].shape, (1, 2))
|
||||||
|
|
||||||
def testBasicRNNCellNotTrainable(self):
|
def testBasicRNNCellNotTrainable(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
|
|
||||||
def not_trainable_getter(getter, *args, **kwargs):
|
def not_trainable_getter(getter, *args, **kwargs):
|
||||||
kwargs["trainable"] = False
|
kwargs["trainable"] = False
|
||||||
return getter(*args, **kwargs)
|
return getter(*args, **kwargs)
|
||||||
|
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
"root", initializer=init_ops.constant_initializer(0.5),
|
"root",
|
||||||
|
initializer=init_ops.constant_initializer(0.5),
|
||||||
custom_getter=not_trainable_getter):
|
custom_getter=not_trainable_getter):
|
||||||
x = array_ops.zeros([1, 2])
|
x = array_ops.zeros([1, 2])
|
||||||
m = array_ops.zeros([1, 2])
|
m = array_ops.zeros([1, 2])
|
||||||
@ -108,9 +110,10 @@ class RNNCellTest(test.TestCase):
|
|||||||
"root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
|
"root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
|
||||||
], [v.name for v in cell.non_trainable_variables])
|
], [v.name for v in cell.non_trainable_variables])
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g], {
|
||||||
[g], {x.name: np.array([[1., 1.]]),
|
x.name: np.array([[1., 1.]]),
|
||||||
m.name: np.array([[0.1, 0.1]])})
|
m.name: np.array([[0.1, 0.1]])
|
||||||
|
})
|
||||||
self.assertEqual(res[0].shape, (1, 2))
|
self.assertEqual(res[0].shape, (1, 2))
|
||||||
|
|
||||||
def testGRUCell(self):
|
def testGRUCell(self):
|
||||||
@ -121,9 +124,10 @@ class RNNCellTest(test.TestCase):
|
|||||||
m = array_ops.zeros([1, 2])
|
m = array_ops.zeros([1, 2])
|
||||||
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
|
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g], {
|
||||||
[g], {x.name: np.array([[1., 1.]]),
|
x.name: np.array([[1., 1.]]),
|
||||||
m.name: np.array([[0.1, 0.1]])})
|
m.name: np.array([[0.1, 0.1]])
|
||||||
|
})
|
||||||
# Smoke test
|
# Smoke test
|
||||||
self.assertAllClose(res[0], [[0.175991, 0.175991]])
|
self.assertAllClose(res[0], [[0.175991, 0.175991]])
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
@ -133,10 +137,10 @@ class RNNCellTest(test.TestCase):
|
|||||||
m = array_ops.zeros([1, 2])
|
m = array_ops.zeros([1, 2])
|
||||||
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
|
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g], {
|
||||||
[g],
|
x.name: np.array([[1., 1., 1.]]),
|
||||||
{x.name: np.array([[1., 1., 1.]]),
|
m.name: np.array([[0.1, 0.1]])
|
||||||
m.name: np.array([[0.1, 0.1]])})
|
})
|
||||||
# Smoke test
|
# Smoke test
|
||||||
self.assertAllClose(res[0], [[0.156736, 0.156736]])
|
self.assertAllClose(res[0], [[0.156736, 0.156736]])
|
||||||
|
|
||||||
@ -148,11 +152,12 @@ class RNNCellTest(test.TestCase):
|
|||||||
m = array_ops.zeros([1, 2])
|
m = array_ops.zeros([1, 2])
|
||||||
g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
|
g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g], {
|
||||||
[g], {x.name: np.array([[1., 1.]]),
|
x.name: np.array([[1., 1.]]),
|
||||||
m.name: np.array([[0.1, 0.1]])})
|
m.name: np.array([[0.1, 0.1]])
|
||||||
|
})
|
||||||
# Smoke test
|
# Smoke test
|
||||||
self.assertAllClose(res[0], [[0.509682, 0.509682]])
|
self.assertAllClose(res[0], [[0.509682, 0.509682]])
|
||||||
|
|
||||||
def testBasicLSTMCell(self):
|
def testBasicLSTMCell(self):
|
||||||
for dtype in [dtypes.float16, dtypes.float32]:
|
for dtype in [dtypes.float16, dtypes.float32]:
|
||||||
@ -164,8 +169,7 @@ class RNNCellTest(test.TestCase):
|
|||||||
m = array_ops.zeros([1, 8], dtype=dtype)
|
m = array_ops.zeros([1, 8], dtype=dtype)
|
||||||
cell = rnn_cell_impl.MultiRNNCell(
|
cell = rnn_cell_impl.MultiRNNCell(
|
||||||
[
|
[
|
||||||
rnn_cell_impl.BasicLSTMCell(
|
rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
|
||||||
2, state_is_tuple=False)
|
|
||||||
for _ in range(2)
|
for _ in range(2)
|
||||||
],
|
],
|
||||||
state_is_tuple=False)
|
state_is_tuple=False)
|
||||||
@ -183,22 +187,21 @@ class RNNCellTest(test.TestCase):
|
|||||||
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
|
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
|
||||||
rnn_cell_impl._BIAS_VARIABLE_NAME
|
rnn_cell_impl._BIAS_VARIABLE_NAME
|
||||||
]
|
]
|
||||||
self.assertEqual(
|
self.assertEqual(expected_variable_names,
|
||||||
expected_variable_names,
|
[v.name for v in cell.trainable_variables])
|
||||||
[v.name for v in cell.trainable_variables])
|
|
||||||
self.assertFalse(cell.non_trainable_variables)
|
self.assertFalse(cell.non_trainable_variables)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g, out_m], {
|
||||||
[g, out_m],
|
x.name: np.array([[1., 1.]]),
|
||||||
{x.name: np.array([[1., 1.]]),
|
m.name: 0.1 * np.ones([1, 8])
|
||||||
m.name: 0.1 * np.ones([1, 8])})
|
})
|
||||||
self.assertEqual(len(res), 2)
|
self.assertEqual(len(res), 2)
|
||||||
variables = variables_lib.global_variables()
|
variables = variables_lib.global_variables()
|
||||||
self.assertEqual(expected_variable_names, [v.name for v in variables])
|
self.assertEqual(expected_variable_names, [v.name for v in variables])
|
||||||
# The numbers in results were not calculated, this is just a
|
# The numbers in results were not calculated, this is just a
|
||||||
# smoke test.
|
# smoke test.
|
||||||
self.assertAllClose(
|
self.assertAllClose(res[0], np.array(
|
||||||
res[0], np.array([[0.240, 0.240]], dtype=np_dtype), 1e-2)
|
[[0.240, 0.240]], dtype=np_dtype), 1e-2)
|
||||||
expected_mem = np.array(
|
expected_mem = np.array(
|
||||||
[[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]],
|
[[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]],
|
||||||
dtype=np_dtype)
|
dtype=np_dtype)
|
||||||
@ -208,13 +211,13 @@ class RNNCellTest(test.TestCase):
|
|||||||
# Test BasicLSTMCell with input_size != num_units.
|
# Test BasicLSTMCell with input_size != num_units.
|
||||||
x = array_ops.zeros([1, 3], dtype=dtype)
|
x = array_ops.zeros([1, 3], dtype=dtype)
|
||||||
m = array_ops.zeros([1, 4], dtype=dtype)
|
m = array_ops.zeros([1, 4], dtype=dtype)
|
||||||
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m)
|
||||||
2, state_is_tuple=False)(x, m)
|
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run(
|
||||||
[g, out_m],
|
[g, out_m], {
|
||||||
{x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
|
x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
|
||||||
m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)})
|
m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)
|
||||||
|
})
|
||||||
self.assertEqual(len(res), 2)
|
self.assertEqual(len(res), 2)
|
||||||
|
|
||||||
def testBasicLSTMCellDimension0Error(self):
|
def testBasicLSTMCellDimension0Error(self):
|
||||||
@ -232,9 +235,11 @@ class RNNCellTest(test.TestCase):
|
|||||||
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
||||||
num_units, state_is_tuple=False)(x, m)
|
num_units, state_is_tuple=False)(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
sess.run([g, out_m],
|
sess.run(
|
||||||
{x.name: 1 * np.ones([batch_size, input_size]),
|
[g, out_m], {
|
||||||
m.name: 0.1 * np.ones([batch_size - 1, state_size])})
|
x.name: 1 * np.ones([batch_size, input_size]),
|
||||||
|
m.name: 0.1 * np.ones([batch_size - 1, state_size])
|
||||||
|
})
|
||||||
|
|
||||||
def testBasicLSTMCellStateSizeError(self):
|
def testBasicLSTMCellStateSizeError(self):
|
||||||
"""Tests that state_size must be num_units * 2."""
|
"""Tests that state_size must be num_units * 2."""
|
||||||
@ -251,9 +256,11 @@ class RNNCellTest(test.TestCase):
|
|||||||
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
g, out_m = rnn_cell_impl.BasicLSTMCell(
|
||||||
num_units, state_is_tuple=False)(x, m)
|
num_units, state_is_tuple=False)(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
sess.run([g, out_m],
|
sess.run(
|
||||||
{x.name: 1 * np.ones([batch_size, input_size]),
|
[g, out_m], {
|
||||||
m.name: 0.1 * np.ones([batch_size, state_size])})
|
x.name: 1 * np.ones([batch_size, input_size]),
|
||||||
|
m.name: 0.1 * np.ones([batch_size, state_size])
|
||||||
|
})
|
||||||
|
|
||||||
def testBasicLSTMCellStateTupleType(self):
|
def testBasicLSTMCellStateTupleType(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -301,11 +308,12 @@ class RNNCellTest(test.TestCase):
|
|||||||
state_is_tuple=True)
|
state_is_tuple=True)
|
||||||
g, (out_m0, out_m1) = cell(x, (m0, m1))
|
g, (out_m0, out_m1) = cell(x, (m0, m1))
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run([g, out_m0, out_m1], {
|
res = sess.run(
|
||||||
x.name: np.array([[1., 1.]]),
|
[g, out_m0, out_m1], {
|
||||||
m0.name: 0.1 * np.ones([1, 4]),
|
x.name: np.array([[1., 1.]]),
|
||||||
m1.name: 0.1 * np.ones([1, 4])
|
m0.name: 0.1 * np.ones([1, 4]),
|
||||||
})
|
m1.name: 0.1 * np.ones([1, 4])
|
||||||
|
})
|
||||||
self.assertEqual(len(res), 3)
|
self.assertEqual(len(res), 3)
|
||||||
# The numbers in results were not calculated, this is just a smoke test.
|
# The numbers in results were not calculated, this is just a smoke test.
|
||||||
# Note, however, these values should match the original
|
# Note, however, these values should match the original
|
||||||
@ -336,10 +344,11 @@ class RNNCellTest(test.TestCase):
|
|||||||
state_is_tuple=False)
|
state_is_tuple=False)
|
||||||
output, state = cell(x, m)
|
output, state = cell(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run([output, state], {
|
res = sess.run(
|
||||||
x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
|
[output, state], {
|
||||||
m.name: 0.1 * np.ones((batch_size, state_size))
|
x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
|
||||||
})
|
m.name: 0.1 * np.ones((batch_size, state_size))
|
||||||
|
})
|
||||||
self.assertEqual(len(res), 2)
|
self.assertEqual(len(res), 2)
|
||||||
# The numbers in results were not calculated, this is mostly just a
|
# The numbers in results were not calculated, this is mostly just a
|
||||||
# smoke test.
|
# smoke test.
|
||||||
@ -442,10 +451,10 @@ class RNNCellTest(test.TestCase):
|
|||||||
rnn_cell_impl.GRUCell(3), num_proj=3)
|
rnn_cell_impl.GRUCell(3), num_proj=3)
|
||||||
g, new_m = cell(x, m)
|
g, new_m = cell(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g, new_m], {
|
||||||
[g, new_m],
|
x.name: np.array([[1., 1.]]),
|
||||||
{x.name: np.array([[1., 1.]]),
|
m.name: np.array([[0.1, 0.1, 0.1]])
|
||||||
m.name: np.array([[0.1, 0.1, 0.1]])})
|
})
|
||||||
self.assertEqual(res[1].shape, (1, 3))
|
self.assertEqual(res[1].shape, (1, 3))
|
||||||
# The numbers in results were not calculated, this is just a smoke test.
|
# The numbers in results were not calculated, this is just a smoke test.
|
||||||
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
|
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
|
||||||
@ -479,9 +488,11 @@ class RNNCellTest(test.TestCase):
|
|||||||
base_cell = rnn_cell_impl.GRUCell(3)
|
base_cell = rnn_cell_impl.GRUCell(3)
|
||||||
g, m_new = base_cell(x, m)
|
g, m_new = base_cell(x, m)
|
||||||
variable_scope.get_variable_scope().reuse_variables()
|
variable_scope.get_variable_scope().reuse_variables()
|
||||||
|
|
||||||
def residual_with_slice_fn(inp, out):
|
def residual_with_slice_fn(inp, out):
|
||||||
inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
|
inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
|
||||||
return inp_sliced + out
|
return inp_sliced + out
|
||||||
|
|
||||||
g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
|
g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
|
||||||
base_cell, residual_with_slice_fn)(x, m)
|
base_cell, residual_with_slice_fn)(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
@ -551,10 +562,10 @@ class RNNCellTest(test.TestCase):
|
|||||||
self.assertEqual(embedding_cell.output_size, 2)
|
self.assertEqual(embedding_cell.output_size, 2)
|
||||||
g, new_m = embedding_cell(x, m)
|
g, new_m = embedding_cell(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g, new_m], {
|
||||||
[g, new_m],
|
x.name: np.array([[1]]),
|
||||||
{x.name: np.array([[1]]),
|
m.name: np.array([[0.1, 0.1]])
|
||||||
m.name: np.array([[0.1, 0.1]])})
|
})
|
||||||
self.assertEqual(res[1].shape, (1, 2))
|
self.assertEqual(res[1].shape, (1, 2))
|
||||||
# The numbers in results were not calculated, this is just a smoke test.
|
# The numbers in results were not calculated, this is just a smoke test.
|
||||||
self.assertAllClose(res[0], [[0.17139, 0.17139]])
|
self.assertAllClose(res[0], [[0.17139, 0.17139]])
|
||||||
@ -584,8 +595,8 @@ class RNNCellTest(test.TestCase):
|
|||||||
x = array_ops.zeros([1, 2])
|
x = array_ops.zeros([1, 2])
|
||||||
m = array_ops.zeros([1, 4])
|
m = array_ops.zeros([1, 4])
|
||||||
_, ml = rnn_cell_impl.MultiRNNCell(
|
_, ml = rnn_cell_impl.MultiRNNCell(
|
||||||
[rnn_cell_impl.GRUCell(2)
|
[rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||||
for _ in range(2)], state_is_tuple=False)(x, m)
|
state_is_tuple=False)(x, m)
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(ml, {
|
res = sess.run(ml, {
|
||||||
x.name: np.array([[1., 1.]]),
|
x.name: np.array([[1., 1.]]),
|
||||||
@ -605,19 +616,20 @@ class RNNCellTest(test.TestCase):
|
|||||||
# Test incorrectness of state
|
# Test incorrectness of state
|
||||||
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
|
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
|
||||||
rnn_cell_impl.MultiRNNCell(
|
rnn_cell_impl.MultiRNNCell(
|
||||||
[rnn_cell_impl.GRUCell(2)
|
[rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||||
for _ in range(2)], state_is_tuple=True)(x, m_bad)
|
state_is_tuple=True)(x, m_bad)
|
||||||
|
|
||||||
_, ml = rnn_cell_impl.MultiRNNCell(
|
_, ml = rnn_cell_impl.MultiRNNCell(
|
||||||
[rnn_cell_impl.GRUCell(2)
|
[rnn_cell_impl.GRUCell(2) for _ in range(2)],
|
||||||
for _ in range(2)], state_is_tuple=True)(x, m_good)
|
state_is_tuple=True)(x, m_good)
|
||||||
|
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(ml, {
|
res = sess.run(
|
||||||
x.name: np.array([[1., 1.]]),
|
ml, {
|
||||||
m_good[0].name: np.array([[0.1, 0.1]]),
|
x.name: np.array([[1., 1.]]),
|
||||||
m_good[1].name: np.array([[0.1, 0.1]])
|
m_good[0].name: np.array([[0.1, 0.1]]),
|
||||||
})
|
m_good[1].name: np.array([[0.1, 0.1]])
|
||||||
|
})
|
||||||
|
|
||||||
# The numbers in results were not calculated, this is just a
|
# The numbers in results were not calculated, this is just a
|
||||||
# smoke test. However, these numbers should match those of
|
# smoke test. However, these numbers should match those of
|
||||||
@ -628,8 +640,11 @@ class RNNCellTest(test.TestCase):
|
|||||||
|
|
||||||
class DropoutWrapperTest(test.TestCase):
|
class DropoutWrapperTest(test.TestCase):
|
||||||
|
|
||||||
def _testDropoutWrapper(self, batch_size=None, time_steps=None,
|
def _testDropoutWrapper(self,
|
||||||
parallel_iterations=None, **kwargs):
|
batch_size=None,
|
||||||
|
time_steps=None,
|
||||||
|
parallel_iterations=None,
|
||||||
|
**kwargs):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||||
@ -640,14 +655,14 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
x = constant_op.constant(
|
x = constant_op.constant(
|
||||||
[[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
|
[[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
|
||||||
m = rnn_cell_impl.LSTMStateTuple(
|
m = rnn_cell_impl.LSTMStateTuple(
|
||||||
*[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)
|
*[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32
|
||||||
] * 2)
|
)] * 2)
|
||||||
else:
|
else:
|
||||||
x = constant_op.constant(
|
x = constant_op.constant(
|
||||||
np.random.randn(time_steps, batch_size, 3).astype(np.float32))
|
np.random.randn(time_steps, batch_size, 3).astype(np.float32))
|
||||||
m = rnn_cell_impl.LSTMStateTuple(*[
|
m = rnn_cell_impl.LSTMStateTuple(*[
|
||||||
constant_op.constant(
|
constant_op.
|
||||||
[[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
|
constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
|
||||||
] * 2)
|
] * 2)
|
||||||
outputs, final_state = rnn.dynamic_rnn(
|
outputs, final_state = rnn.dynamic_rnn(
|
||||||
cell=rnn_cell_impl.DropoutWrapper(
|
cell=rnn_cell_impl.DropoutWrapper(
|
||||||
@ -674,8 +689,8 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
res = self._testDropoutWrapper(
|
res = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
|
input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
|
||||||
true_full_output = np.array(
|
true_full_output = np.array(
|
||||||
[[[0.751109, 0.751109, 0.751109]],
|
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
dtype=np.float32)
|
||||||
true_full_final_c = np.array(
|
true_full_final_c = np.array(
|
||||||
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
||||||
self.assertAllClose(true_full_output, res[0])
|
self.assertAllClose(true_full_output, res[0])
|
||||||
@ -687,8 +702,8 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
res = self._testDropoutWrapper(
|
res = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
|
input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
|
||||||
true_full_output = np.array(
|
true_full_output = np.array(
|
||||||
[[[0.751109, 0.751109, 0.751109]],
|
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
dtype=np.float32)
|
||||||
true_full_final_c = np.array(
|
true_full_final_c = np.array(
|
||||||
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
||||||
self.assertAllClose(true_full_output, res[0])
|
self.assertAllClose(true_full_output, res[0])
|
||||||
@ -703,16 +718,20 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
## consistent across both calls. Otherwise the seed may not end
|
## consistent across both calls. Otherwise the seed may not end
|
||||||
## up being munged consistently across both graphs.
|
## up being munged consistently across both graphs.
|
||||||
res_standard_1 = self._testDropoutWrapper(
|
res_standard_1 = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
input_keep_prob=keep_some,
|
||||||
state_keep_prob=keep_some, seed=10,
|
output_keep_prob=keep_some,
|
||||||
|
state_keep_prob=keep_some,
|
||||||
|
seed=10,
|
||||||
parallel_iterations=1)
|
parallel_iterations=1)
|
||||||
# Clear away the graph and the test session (which keeps variables around)
|
# Clear away the graph and the test session (which keeps variables around)
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
self._ClearCachedSession()
|
self._ClearCachedSession()
|
||||||
random_seed.set_random_seed(2)
|
random_seed.set_random_seed(2)
|
||||||
res_standard_2 = self._testDropoutWrapper(
|
res_standard_2 = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
input_keep_prob=keep_some,
|
||||||
state_keep_prob=keep_some, seed=10,
|
output_keep_prob=keep_some,
|
||||||
|
state_keep_prob=keep_some,
|
||||||
|
seed=10,
|
||||||
parallel_iterations=1)
|
parallel_iterations=1)
|
||||||
self.assertAllClose(res_standard_1[0], res_standard_2[0])
|
self.assertAllClose(res_standard_1[0], res_standard_2[0])
|
||||||
self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c)
|
self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c)
|
||||||
@ -722,11 +741,12 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
||||||
keep_none = variable_scope.get_variable("none", initializer=1e-10)
|
keep_none = variable_scope.get_variable("none", initializer=1e-10)
|
||||||
res = self._testDropoutWrapper(
|
res = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_all, output_keep_prob=keep_none,
|
input_keep_prob=keep_all,
|
||||||
|
output_keep_prob=keep_none,
|
||||||
state_keep_prob=keep_all)
|
state_keep_prob=keep_all)
|
||||||
true_full_output = np.array(
|
true_full_output = np.array(
|
||||||
[[[0.751109, 0.751109, 0.751109]],
|
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
dtype=np.float32)
|
||||||
true_full_final_c = np.array(
|
true_full_final_c = np.array(
|
||||||
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
||||||
self.assertAllClose(np.zeros(res[0].shape), res[0])
|
self.assertAllClose(np.zeros(res[0].shape), res[0])
|
||||||
@ -739,13 +759,13 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
# Even though we dropout state, by default DropoutWrapper never
|
# Even though we dropout state, by default DropoutWrapper never
|
||||||
# drops out the memory ("c") term of an LSTMStateTuple.
|
# drops out the memory ("c") term of an LSTMStateTuple.
|
||||||
res = self._testDropoutWrapper(
|
res = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_all, output_keep_prob=keep_all,
|
input_keep_prob=keep_all,
|
||||||
|
output_keep_prob=keep_all,
|
||||||
state_keep_prob=keep_none)
|
state_keep_prob=keep_none)
|
||||||
true_c_state = np.array(
|
true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32)
|
||||||
[[1.713925, 1.713925, 1.713925]], dtype=np.float32)
|
|
||||||
true_full_output = np.array(
|
true_full_output = np.array(
|
||||||
[[[0.751109, 0.751109, 0.751109]],
|
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
dtype=np.float32)
|
||||||
self.assertAllClose(true_full_output[0], res[0][0])
|
self.assertAllClose(true_full_output[0], res[0][0])
|
||||||
# Second output is modified by zero input state
|
# Second output is modified by zero input state
|
||||||
self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4)
|
self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4)
|
||||||
@ -758,13 +778,14 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
||||||
keep_none = variable_scope.get_variable("none", initializer=1e-10)
|
keep_none = variable_scope.get_variable("none", initializer=1e-10)
|
||||||
true_full_output = np.array(
|
true_full_output = np.array(
|
||||||
[[[0.751109, 0.751109, 0.751109]],
|
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
|
||||||
[[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
|
dtype=np.float32)
|
||||||
true_full_final_c = np.array(
|
true_full_final_c = np.array(
|
||||||
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
|
||||||
# All outputs are different because inputs are zeroed out
|
# All outputs are different because inputs are zeroed out
|
||||||
res = self._testDropoutWrapper(
|
res = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_none, output_keep_prob=keep_all,
|
input_keep_prob=keep_none,
|
||||||
|
output_keep_prob=keep_all,
|
||||||
state_keep_prob=keep_all)
|
state_keep_prob=keep_all)
|
||||||
self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4)
|
self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4)
|
||||||
self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4)
|
self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4)
|
||||||
@ -774,9 +795,13 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
keep_some = 0.8
|
keep_some = 0.8
|
||||||
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
keep_all = variable_scope.get_variable("all", initializer=1.0)
|
||||||
res = self._testDropoutWrapper(
|
res = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_all, output_keep_prob=keep_some,
|
input_keep_prob=keep_all,
|
||||||
state_keep_prob=keep_all, variational_recurrent=True,
|
output_keep_prob=keep_some,
|
||||||
input_size=3, batch_size=5, time_steps=7)
|
state_keep_prob=keep_all,
|
||||||
|
variational_recurrent=True,
|
||||||
|
input_size=3,
|
||||||
|
batch_size=5,
|
||||||
|
time_steps=7)
|
||||||
# Ensure the same dropout pattern for all time steps
|
# Ensure the same dropout pattern for all time steps
|
||||||
output_mask = np.abs(res[0]) > 1e-6
|
output_mask = np.abs(res[0]) > 1e-6
|
||||||
for m in output_mask[1:]:
|
for m in output_mask[1:]:
|
||||||
@ -785,9 +810,13 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
def testDropoutWrapperRecurrentStateInputAndOutput(self):
|
def testDropoutWrapperRecurrentStateInputAndOutput(self):
|
||||||
keep_some = 0.9
|
keep_some = 0.9
|
||||||
res = self._testDropoutWrapper(
|
res = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
input_keep_prob=keep_some,
|
||||||
state_keep_prob=keep_some, variational_recurrent=True,
|
output_keep_prob=keep_some,
|
||||||
input_size=3, batch_size=5, time_steps=7)
|
state_keep_prob=keep_some,
|
||||||
|
variational_recurrent=True,
|
||||||
|
input_size=3,
|
||||||
|
batch_size=5,
|
||||||
|
time_steps=7)
|
||||||
|
|
||||||
# Smoke test for the state/input masks.
|
# Smoke test for the state/input masks.
|
||||||
output_mask = np.abs(res[0]) > 1e-6
|
output_mask = np.abs(res[0]) > 1e-6
|
||||||
@ -811,17 +840,27 @@ class DropoutWrapperTest(test.TestCase):
|
|||||||
random_seed.set_random_seed(2347)
|
random_seed.set_random_seed(2347)
|
||||||
np.random.seed(23487)
|
np.random.seed(23487)
|
||||||
res0 = self._testDropoutWrapper(
|
res0 = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
input_keep_prob=keep_some,
|
||||||
state_keep_prob=keep_some, variational_recurrent=True,
|
output_keep_prob=keep_some,
|
||||||
input_size=3, batch_size=5, time_steps=7, seed=-234987)
|
state_keep_prob=keep_some,
|
||||||
|
variational_recurrent=True,
|
||||||
|
input_size=3,
|
||||||
|
batch_size=5,
|
||||||
|
time_steps=7,
|
||||||
|
seed=-234987)
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
self._ClearCachedSession()
|
self._ClearCachedSession()
|
||||||
random_seed.set_random_seed(2347)
|
random_seed.set_random_seed(2347)
|
||||||
np.random.seed(23487)
|
np.random.seed(23487)
|
||||||
res1 = self._testDropoutWrapper(
|
res1 = self._testDropoutWrapper(
|
||||||
input_keep_prob=keep_some, output_keep_prob=keep_some,
|
input_keep_prob=keep_some,
|
||||||
state_keep_prob=keep_some, variational_recurrent=True,
|
output_keep_prob=keep_some,
|
||||||
input_size=3, batch_size=5, time_steps=7, seed=-234987)
|
state_keep_prob=keep_some,
|
||||||
|
variational_recurrent=True,
|
||||||
|
input_size=3,
|
||||||
|
batch_size=5,
|
||||||
|
time_steps=7,
|
||||||
|
seed=-234987)
|
||||||
|
|
||||||
output_mask = np.abs(res0[0]) > 1e-6
|
output_mask = np.abs(res0[0]) > 1e-6
|
||||||
for time_step in output_mask:
|
for time_step in output_mask:
|
||||||
@ -858,9 +897,10 @@ class SlimRNNCellTest(test.TestCase):
|
|||||||
g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
|
g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
sess.run([variables_lib.global_variables_initializer()])
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
res = sess.run(
|
res = sess.run([g], {
|
||||||
[g], {x.name: np.array([[1., 1.]]),
|
x.name: np.array([[1., 1.]]),
|
||||||
m.name: np.array([[0.1, 0.1]])})
|
m.name: np.array([[0.1, 0.1]])
|
||||||
|
})
|
||||||
self.assertEqual(res[0].shape, (1, 2))
|
self.assertEqual(res[0].shape, (1, 2))
|
||||||
|
|
||||||
def testBasicRNNCellMatch(self):
|
def testBasicRNNCellMatch(self):
|
||||||
|
|||||||
@ -34,8 +34,7 @@ MAX_LABEL = 15
|
|||||||
WORDS_FEATURE = 'words' # Name of the input words feature.
|
WORDS_FEATURE = 'words' # Name of the input words feature.
|
||||||
|
|
||||||
|
|
||||||
def estimator_spec_for_softmax_classification(
|
def estimator_spec_for_softmax_classification(logits, labels, mode):
|
||||||
logits, labels, mode):
|
|
||||||
"""Returns EstimatorSpec instance for softmax classification."""
|
"""Returns EstimatorSpec instance for softmax classification."""
|
||||||
predicted_classes = tf.argmax(logits, 1)
|
predicted_classes = tf.argmax(logits, 1)
|
||||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
if mode == tf.estimator.ModeKeys.PREDICT:
|
||||||
@ -53,8 +52,8 @@ def estimator_spec_for_softmax_classification(
|
|||||||
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
|
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
|
||||||
|
|
||||||
eval_metric_ops = {
|
eval_metric_ops = {
|
||||||
'accuracy': tf.metrics.accuracy(
|
'accuracy':
|
||||||
labels=labels, predictions=predicted_classes)
|
tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
|
||||||
}
|
}
|
||||||
return tf.estimator.EstimatorSpec(
|
return tf.estimator.EstimatorSpec(
|
||||||
mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
|
mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
|
||||||
@ -67,8 +66,7 @@ def bag_of_words_model(features, labels, mode):
|
|||||||
bow_embedding_column = tf.feature_column.embedding_column(
|
bow_embedding_column = tf.feature_column.embedding_column(
|
||||||
bow_column, dimension=EMBEDDING_SIZE)
|
bow_column, dimension=EMBEDDING_SIZE)
|
||||||
bow = tf.feature_column.input_layer(
|
bow = tf.feature_column.input_layer(
|
||||||
features,
|
features, feature_columns=[bow_embedding_column])
|
||||||
feature_columns=[bow_embedding_column])
|
|
||||||
logits = tf.layers.dense(bow, MAX_LABEL, activation=None)
|
logits = tf.layers.dense(bow, MAX_LABEL, activation=None)
|
||||||
|
|
||||||
return estimator_spec_for_softmax_classification(
|
return estimator_spec_for_softmax_classification(
|
||||||
@ -110,9 +108,9 @@ def main(unused_argv):
|
|||||||
# Prepare training and testing data
|
# Prepare training and testing data
|
||||||
dbpedia = tf.contrib.learn.datasets.load_dataset(
|
dbpedia = tf.contrib.learn.datasets.load_dataset(
|
||||||
'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)
|
'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)
|
||||||
x_train = pandas.Series(dbpedia.train.data[:,1])
|
x_train = pandas.Series(dbpedia.train.data[:, 1])
|
||||||
y_train = pandas.Series(dbpedia.train.target)
|
y_train = pandas.Series(dbpedia.train.target)
|
||||||
x_test = pandas.Series(dbpedia.test.data[:,1])
|
x_test = pandas.Series(dbpedia.test.data[:, 1])
|
||||||
y_test = pandas.Series(dbpedia.test.target)
|
y_test = pandas.Series(dbpedia.test.target)
|
||||||
|
|
||||||
# Process vocabulary
|
# Process vocabulary
|
||||||
@ -152,10 +150,7 @@ def main(unused_argv):
|
|||||||
|
|
||||||
# Predict.
|
# Predict.
|
||||||
test_input_fn = tf.estimator.inputs.numpy_input_fn(
|
test_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||||
x={WORDS_FEATURE: x_test},
|
x={WORDS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
|
||||||
y=y_test,
|
|
||||||
num_epochs=1,
|
|
||||||
shuffle=False)
|
|
||||||
predictions = classifier.predict(input_fn=test_input_fn)
|
predictions = classifier.predict(input_fn=test_input_fn)
|
||||||
y_predicted = np.array(list(p['class'] for p in predictions))
|
y_predicted = np.array(list(p['class'] for p in predictions))
|
||||||
y_predicted = y_predicted.reshape(np.array(y_test).shape)
|
y_predicted = y_predicted.reshape(np.array(y_test).shape)
|
||||||
|
|||||||
@ -12,7 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Notebook front-end to TensorFlow.
|
"""Notebook front-end to TensorFlow.
|
||||||
|
|
||||||
When you run this binary, you'll see something like below, which indicates
|
When you run this binary, you'll see something like below, which indicates
|
||||||
@ -43,10 +42,8 @@ from tensorflow.python.platform import app
|
|||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
|
||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
|
||||||
|
|
||||||
|
|
||||||
FLAGS = None
|
FLAGS = None
|
||||||
|
|
||||||
|
|
||||||
ORIG_ARGV = sys.argv
|
ORIG_ARGV = sys.argv
|
||||||
# Main notebook process calls itself with argv[1]="kernel" to start kernel
|
# Main notebook process calls itself with argv[1]="kernel" to start kernel
|
||||||
# subprocesses.
|
# subprocesses.
|
||||||
@ -73,8 +70,8 @@ def main(unused_argv):
|
|||||||
notebookapp.ip = "0.0.0.0"
|
notebookapp.ip = "0.0.0.0"
|
||||||
notebookapp.password = passwd(FLAGS.password)
|
notebookapp.password = passwd(FLAGS.password)
|
||||||
else:
|
else:
|
||||||
print ("\nNo password specified; Notebook server will only be available"
|
print("\nNo password specified; Notebook server will only be available"
|
||||||
" on the local machine.\n")
|
" on the local machine.\n")
|
||||||
notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir])
|
notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir])
|
||||||
|
|
||||||
if notebookapp.ip == "0.0.0.0":
|
if notebookapp.ip == "0.0.0.0":
|
||||||
@ -125,8 +122,8 @@ if __name__ == "__main__":
|
|||||||
# kernel app.
|
# kernel app.
|
||||||
if IS_KERNEL:
|
if IS_KERNEL:
|
||||||
# Drop everything except --flagfile.
|
# Drop everything except --flagfile.
|
||||||
sys.argv = ([sys.argv[0]] +
|
sys.argv = (
|
||||||
[x for x in sys.argv[1:] if x.startswith("--flagfile")])
|
[sys.argv[0]] + [x for x in sys.argv[1:] if x.startswith("--flagfile")])
|
||||||
|
|
||||||
FLAGS, unparsed = parser.parse_known_args()
|
FLAGS, unparsed = parser.parse_known_args()
|
||||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||||
|
|||||||
@ -12,7 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Classes and functions related to train_and_evaluate."""
|
"""Classes and functions related to train_and_evaluate."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
@ -37,7 +36,6 @@ from tensorflow.python.training import server_lib
|
|||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
|
|
||||||
_MAX_DELAY_SECS = 60
|
_MAX_DELAY_SECS = 60
|
||||||
_DELAY_SECS_PER_WORKER = 5
|
_DELAY_SECS_PER_WORKER = 5
|
||||||
_TF_CONFIG_ENV = 'TF_CONFIG'
|
_TF_CONFIG_ENV = 'TF_CONFIG'
|
||||||
@ -50,8 +48,7 @@ _TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER,
|
|||||||
def _validate_input_fn(input_fn):
|
def _validate_input_fn(input_fn):
|
||||||
"""Validates the `input_fn`."""
|
"""Validates the `input_fn`."""
|
||||||
if not callable(input_fn):
|
if not callable(input_fn):
|
||||||
raise TypeError(
|
raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn))
|
||||||
'`input_fn` must be callable, given: {}'.format(input_fn))
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_hooks(hooks):
|
def _validate_hooks(hooks):
|
||||||
@ -125,10 +122,7 @@ class TrainSpec(
|
|||||||
duration. Optional hooks run at various stages of training.
|
duration. Optional hooks run at various stages of training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls,
|
def __new__(cls, input_fn, max_steps=None, hooks=None):
|
||||||
input_fn,
|
|
||||||
max_steps=None,
|
|
||||||
hooks=None):
|
|
||||||
"""Creates a validated `TrainSpec` instance.
|
"""Creates a validated `TrainSpec` instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -161,16 +155,13 @@ class TrainSpec(
|
|||||||
hooks = _validate_hooks(hooks)
|
hooks = _validate_hooks(hooks)
|
||||||
|
|
||||||
return super(TrainSpec, cls).__new__(
|
return super(TrainSpec, cls).__new__(
|
||||||
cls,
|
cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
|
||||||
input_fn=input_fn,
|
|
||||||
max_steps=max_steps,
|
|
||||||
hooks=hooks)
|
|
||||||
|
|
||||||
|
|
||||||
class EvalSpec(
|
class EvalSpec(
|
||||||
collections.namedtuple('EvalSpec', [
|
collections.namedtuple('EvalSpec', [
|
||||||
'input_fn', 'steps', 'name', 'hooks', 'exporters',
|
'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
|
||||||
'start_delay_secs', 'throttle_secs'
|
'throttle_secs'
|
||||||
])):
|
])):
|
||||||
"""Configuration for the "eval" part for the `train_and_evaluate` call.
|
"""Configuration for the "eval" part for the `train_and_evaluate` call.
|
||||||
|
|
||||||
@ -417,8 +408,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if environment variable `TF_CONFIG` is incorrectly set.
|
ValueError: if environment variable `TF_CONFIG` is incorrectly set.
|
||||||
"""
|
"""
|
||||||
executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec,
|
executor = _TrainingExecutor(
|
||||||
eval_spec=eval_spec)
|
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
|
||||||
|
|
||||||
config = estimator.config
|
config = estimator.config
|
||||||
if (config.task_type == run_config_lib.TaskType.EVALUATOR and
|
if (config.task_type == run_config_lib.TaskType.EVALUATOR and
|
||||||
@ -561,9 +552,8 @@ class _TrainingExecutor(object):
|
|||||||
self._timer.update_last_triggered_step(global_step_value)
|
self._timer.update_last_triggered_step(global_step_value)
|
||||||
self._evaluator.evaluate_and_export()
|
self._evaluator.evaluate_and_export()
|
||||||
else:
|
else:
|
||||||
logging.info(
|
logging.info('Skip the current checkpoint eval due to throttle secs '
|
||||||
'Skip the current checkpoint eval due to throttle secs '
|
'({} secs).'.format(self._eval_throttle_secs))
|
||||||
'({} secs).'.format(self._eval_throttle_secs))
|
|
||||||
|
|
||||||
# Final export signal: For any eval result with global_step >= train
|
# Final export signal: For any eval result with global_step >= train
|
||||||
# max_steps, the evaluator will send the final export signal. There is a
|
# max_steps, the evaluator will send the final export signal. There is a
|
||||||
@ -576,8 +566,8 @@ class _TrainingExecutor(object):
|
|||||||
#
|
#
|
||||||
# But here, throttle_secs will skip the next intermediate checkpoint and,
|
# But here, throttle_secs will skip the next intermediate checkpoint and,
|
||||||
# so, the double final export chance is very small.
|
# so, the double final export chance is very small.
|
||||||
evaluator = _TrainingExecutor._Evaluator(
|
evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
|
||||||
self._estimator, self._eval_spec, self._train_spec.max_steps)
|
self._train_spec.max_steps)
|
||||||
|
|
||||||
# When the underlying `Estimator` object saves a new checkpoint, we would
|
# When the underlying `Estimator` object saves a new checkpoint, we would
|
||||||
# like this callback to be called so that evaluation and export can trigger.
|
# like this callback to be called so that evaluation and export can trigger.
|
||||||
@ -617,8 +607,7 @@ class _TrainingExecutor(object):
|
|||||||
raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
|
raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
|
||||||
'It is used do determine how long each training '
|
'It is used do determine how long each training '
|
||||||
'iteration should go when train and evaluate '
|
'iteration should go when train and evaluate '
|
||||||
'locally.'.format(
|
'locally.'.format(self._eval_spec.throttle_secs))
|
||||||
self._eval_spec.throttle_secs))
|
|
||||||
|
|
||||||
stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
|
stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
|
||||||
train_hooks = (
|
train_hooks = (
|
||||||
@ -663,8 +652,9 @@ class _TrainingExecutor(object):
|
|||||||
|
|
||||||
if not config.master:
|
if not config.master:
|
||||||
jobs = config.cluster_spec.jobs
|
jobs = config.cluster_spec.jobs
|
||||||
if (len(jobs) == 1 and len(config.cluster_spec.job_tasks(jobs[0])) == 1
|
if (len(jobs) == 1 and
|
||||||
and config.task_type in _TRAINER_JOBS):
|
len(config.cluster_spec.job_tasks(jobs[0])) == 1 and
|
||||||
|
config.task_type in _TRAINER_JOBS):
|
||||||
# For distributed training, config.master is empty if and only if it has
|
# For distributed training, config.master is empty if and only if it has
|
||||||
# a single node in the cluster spec. In this case, we should not start
|
# a single node in the cluster spec. In this case, we should not start
|
||||||
# the server.
|
# the server.
|
||||||
@ -679,9 +669,9 @@ class _TrainingExecutor(object):
|
|||||||
logging.info('Start Tensorflow server.')
|
logging.info('Start Tensorflow server.')
|
||||||
|
|
||||||
if config.session_config is None:
|
if config.session_config is None:
|
||||||
session_config=config_pb2.ConfigProto(log_device_placement=False)
|
session_config = config_pb2.ConfigProto(log_device_placement=False)
|
||||||
else:
|
else:
|
||||||
session_config=config_pb2.ConfigProto(
|
session_config = config_pb2.ConfigProto(
|
||||||
log_device_placement=False,
|
log_device_placement=False,
|
||||||
gpu_options=config.session_config.gpu_options)
|
gpu_options=config.session_config.gpu_options)
|
||||||
|
|
||||||
@ -744,8 +734,7 @@ class _TrainingExecutor(object):
|
|||||||
global_step >= self._train_spec.max_steps):
|
global_step >= self._train_spec.max_steps):
|
||||||
logging.info(
|
logging.info(
|
||||||
'Exiting evaluation, global_step=%s >= train max_steps=%s',
|
'Exiting evaluation, global_step=%s >= train max_steps=%s',
|
||||||
global_step,
|
global_step, self._train_spec.max_steps)
|
||||||
self._train_spec.max_steps)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
latest_eval_result, should_early_stop = self._execute_evaluator_once(
|
latest_eval_result, should_early_stop = self._execute_evaluator_once(
|
||||||
@ -781,10 +770,9 @@ class _TrainingExecutor(object):
|
|||||||
|
|
||||||
# Throttle if necessary.
|
# Throttle if necessary.
|
||||||
elapsed_time = time.time() - start
|
elapsed_time = time.time() - start
|
||||||
difference = throttle_secs - elapsed_time
|
difference = throttle_secs - elapsed_time
|
||||||
if difference > 0:
|
if difference > 0:
|
||||||
logging.info('Waiting %f secs before starting next eval run.',
|
logging.info('Waiting %f secs before starting next eval run.', difference)
|
||||||
difference)
|
|
||||||
time.sleep(difference)
|
time.sleep(difference)
|
||||||
|
|
||||||
return (eval_result, should_early_stop)
|
return (eval_result, should_early_stop)
|
||||||
@ -929,8 +917,8 @@ class _EvalResult(
|
|||||||
if checkpoint_path:
|
if checkpoint_path:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'checkpoint must be `None` if status is not {}; got status {}, '
|
'checkpoint must be `None` if status is not {}; got status {}, '
|
||||||
'checkpoint_path {}'.format(
|
'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status,
|
||||||
_EvalStatus.EVALUATED, status, checkpoint_path))
|
checkpoint_path))
|
||||||
return super(_EvalResult, cls).__new__(cls, status, metrics,
|
return super(_EvalResult, cls).__new__(cls, status, metrics,
|
||||||
checkpoint_path)
|
checkpoint_path)
|
||||||
|
|
||||||
|
|||||||
@ -56,9 +56,11 @@ class TensordotTest(test_lib.TestCase):
|
|||||||
axes_ph = array_ops.placeholder(dtypes.int32)
|
axes_ph = array_ops.placeholder(dtypes.int32)
|
||||||
output = math_ops.tensordot(a_ph, b_ph, axes_ph)
|
output = math_ops.tensordot(a_ph, b_ph, axes_ph)
|
||||||
_ = sess.run(
|
_ = sess.run(
|
||||||
[output], feed_dict={a_ph: a,
|
[output], feed_dict={
|
||||||
b_ph: b,
|
a_ph: a,
|
||||||
axes_ph: (a_axes, b_axes)})
|
b_ph: b,
|
||||||
|
axes_ph: (a_axes, b_axes)
|
||||||
|
})
|
||||||
|
|
||||||
def test_invalid_axes(self):
|
def test_invalid_axes(self):
|
||||||
a = [[1, 2], [3, 4]]
|
a = [[1, 2], [3, 4]]
|
||||||
@ -81,26 +83,27 @@ class TensordotTest(test_lib.TestCase):
|
|||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||||
_ = sess.run(
|
_ = sess.run(
|
||||||
[output], feed_dict={a_ph: a,
|
[output], feed_dict={
|
||||||
b_ph: b,
|
a_ph: a,
|
||||||
axes_ph: axes_value})
|
b_ph: b,
|
||||||
|
axes_ph: axes_value
|
||||||
|
})
|
||||||
|
|
||||||
# Test case for 11950
|
# Test case for 11950
|
||||||
def test_valid_axis(self):
|
def test_valid_axis(self):
|
||||||
for axes_value in [1, 2], [[1], [2]]:
|
for axes_value in [1, 2], [[1], [2]]:
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
np_a = np.ones((3,3))
|
np_a = np.ones((3, 3))
|
||||||
np_b = np.array([2, 3, 1])[None, None]
|
np_b = np.array([2, 3, 1])[None, None]
|
||||||
np_ans = np.tensordot(np_a, np_b, axes_value)
|
np_ans = np.tensordot(np_a, np_b, axes_value)
|
||||||
|
|
||||||
tf_a = array_ops.ones((3,3), dtype=dtypes.float32)
|
tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||||
tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
|
tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
|
||||||
tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value).eval()
|
tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value).eval()
|
||||||
|
|
||||||
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
self.assertAllEqual(tf_ans.shape, np_ans.shape)
|
||||||
self.assertAllEqual(tf_ans, np_ans)
|
self.assertAllEqual(tf_ans, np_ans)
|
||||||
|
|
||||||
|
|
||||||
def test_partial_shape_inference(self):
|
def test_partial_shape_inference(self):
|
||||||
a = array_ops.placeholder(dtypes.float32)
|
a = array_ops.placeholder(dtypes.float32)
|
||||||
b = array_ops.placeholder(dtypes.float32)
|
b = array_ops.placeholder(dtypes.float32)
|
||||||
@ -169,9 +172,11 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
|
|||||||
axes = array_ops.placeholder(dtypes.int32)
|
axes = array_ops.placeholder(dtypes.int32)
|
||||||
c = math_ops.tensordot(a, b, axes)
|
c = math_ops.tensordot(a, b, axes)
|
||||||
tf_ans = sess.run(
|
tf_ans = sess.run(
|
||||||
c, feed_dict={a: a_np,
|
c, feed_dict={
|
||||||
b: b_np,
|
a: a_np,
|
||||||
axes: (a_dims_np, b_dims_np)})
|
b: b_np,
|
||||||
|
axes: (a_dims_np, b_dims_np)
|
||||||
|
})
|
||||||
else:
|
else:
|
||||||
tf_ans = math_ops.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval()
|
tf_ans = math_ops.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval()
|
||||||
self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
|
self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
|
||||||
|
|||||||
@ -184,7 +184,8 @@ do_pylint() {
|
|||||||
# W0312 mixed-indentation
|
# W0312 mixed-indentation
|
||||||
# C0330 bad-continuation
|
# C0330 bad-continuation
|
||||||
# C0301 line-too-long
|
# C0301 line-too-long
|
||||||
grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301)' ${OUTPUT_FILE} > ${ERRORS_FILE}
|
# C0326 bad-whitespace
|
||||||
|
grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326)' ${OUTPUT_FILE} > ${ERRORS_FILE}
|
||||||
|
|
||||||
N_ERRORS=0
|
N_ERRORS=0
|
||||||
while read -r LINE; do
|
while read -r LINE; do
|
||||||
|
|||||||
@ -46,8 +46,9 @@ class APIChangeSpec(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class _FileEditTuple(collections.namedtuple(
|
class _FileEditTuple(
|
||||||
"_FileEditTuple", ["comment", "line", "start", "old", "new"])):
|
collections.namedtuple("_FileEditTuple",
|
||||||
|
["comment", "line", "start", "old", "new"])):
|
||||||
"""Each edit that is recorded by a _FileEditRecorder.
|
"""Each edit that is recorded by a _FileEditRecorder.
|
||||||
|
|
||||||
Fields:
|
Fields:
|
||||||
@ -179,8 +180,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
function_renames = self._api_change_spec.function_renames
|
function_renames = self._api_change_spec.function_renames
|
||||||
try:
|
try:
|
||||||
new_name = function_renames[full_name]
|
new_name = function_renames[full_name]
|
||||||
self._file_edit.add("Renamed function %r to %r" % (full_name,
|
self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
|
||||||
new_name),
|
|
||||||
node.lineno, node.col_offset, full_name, new_name)
|
node.lineno, node.col_offset, full_name, new_name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
@ -227,7 +227,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
# loop over lines
|
# loop over lines
|
||||||
while 1:
|
while 1:
|
||||||
# Reverse the text to and regular expression search for whitespace
|
# Reverse the text to and regular expression search for whitespace
|
||||||
text = self._lines[line-1]
|
text = self._lines[line - 1]
|
||||||
reversed_preceding_text = text[:col][::-1]
|
reversed_preceding_text = text[:col][::-1]
|
||||||
# First find if a [ can be found with only whitespace between it and
|
# First find if a [ can be found with only whitespace between it and
|
||||||
# col.
|
# col.
|
||||||
@ -248,8 +248,8 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
# node ranges to filter out spurious #'s that appear in string
|
# node ranges to filter out spurious #'s that appear in string
|
||||||
# literals.
|
# literals.
|
||||||
comment_start = prev_line.find("#")
|
comment_start = prev_line.find("#")
|
||||||
if comment_start == -1:
|
if comment_start == -1:
|
||||||
col = len(prev_line) -1
|
col = len(prev_line) - 1
|
||||||
elif find_string_chars.search(prev_line[comment_start:]) is None:
|
elif find_string_chars.search(prev_line[comment_start:]) is None:
|
||||||
col = comment_start
|
col = comment_start
|
||||||
else:
|
else:
|
||||||
@ -260,7 +260,6 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
# it is not possible to use that in an argument.
|
# it is not possible to use that in an argument.
|
||||||
return node.lineno, node.col_offset
|
return node.lineno, node.col_offset
|
||||||
|
|
||||||
|
|
||||||
def visit_Call(self, node): # pylint: disable=invalid-name
|
def visit_Call(self, node): # pylint: disable=invalid-name
|
||||||
"""Handle visiting a call node in the AST.
|
"""Handle visiting a call node in the AST.
|
||||||
|
|
||||||
@ -268,7 +267,6 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
node: Current Node
|
node: Current Node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Find a simple attribute name path e.g. "tf.foo.bar"
|
# Find a simple attribute name path e.g. "tf.foo.bar"
|
||||||
full_name = self._get_attribute_full_path(node.func)
|
full_name = self._get_attribute_full_path(node.func)
|
||||||
|
|
||||||
@ -293,18 +291,21 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
lineno, col_offset = self._find_true_position(arg)
|
lineno, col_offset = self._find_true_position(arg)
|
||||||
if lineno is None or col_offset is None:
|
if lineno is None or col_offset is None:
|
||||||
self._file_edit.add(
|
self._file_edit.add(
|
||||||
"Failed to add keyword %r to reordered function %r"
|
"Failed to add keyword %r to reordered function %r" %
|
||||||
% (reordered[idx], full_name), arg.lineno, arg.col_offset,
|
(reordered[idx], full_name),
|
||||||
"", "",
|
arg.lineno,
|
||||||
|
arg.col_offset,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
error="A necessary keyword argument failed to be inserted.")
|
error="A necessary keyword argument failed to be inserted.")
|
||||||
else:
|
else:
|
||||||
keyword_arg = reordered[idx]
|
keyword_arg = reordered[idx]
|
||||||
if (full_name in function_keyword_renames and
|
if (full_name in function_keyword_renames and
|
||||||
keyword_arg in function_keyword_renames[full_name]):
|
keyword_arg in function_keyword_renames[full_name]):
|
||||||
keyword_arg = function_keyword_renames[full_name][keyword_arg]
|
keyword_arg = function_keyword_renames[full_name][keyword_arg]
|
||||||
self._file_edit.add("Added keyword %r to reordered function %r"
|
self._file_edit.add("Added keyword %r to reordered function %r" %
|
||||||
% (reordered[idx], full_name), lineno,
|
(reordered[idx], full_name), lineno, col_offset,
|
||||||
col_offset, "", keyword_arg + "=")
|
"", keyword_arg + "=")
|
||||||
|
|
||||||
# Examine each keyword argument and convert it to the final renamed form
|
# Examine each keyword argument and convert it to the final renamed form
|
||||||
renamed_keywords = ({} if full_name not in function_keyword_renames else
|
renamed_keywords = ({} if full_name not in function_keyword_renames else
|
||||||
@ -322,11 +323,11 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
# value.
|
# value.
|
||||||
key_start = argval_col_offset - len(argkey) - 1
|
key_start = argval_col_offset - len(argkey) - 1
|
||||||
key_end = key_start + len(argkey) + 1
|
key_end = key_start + len(argkey) + 1
|
||||||
if (self._lines[argval_lineno - 1][key_start:key_end] ==
|
if (self._lines[argval_lineno - 1][key_start:key_end] == argkey +
|
||||||
argkey + "="):
|
"="):
|
||||||
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
||||||
(argkey, renamed_keywords[argkey]),
|
(argkey,
|
||||||
argval_lineno,
|
renamed_keywords[argkey]), argval_lineno,
|
||||||
argval_col_offset - len(argkey) - 1,
|
argval_col_offset - len(argkey) - 1,
|
||||||
argkey + "=", renamed_keywords[argkey] + "=")
|
argkey + "=", renamed_keywords[argkey] + "=")
|
||||||
continue
|
continue
|
||||||
@ -335,7 +336,8 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
(argkey, renamed_keywords[argkey]),
|
(argkey, renamed_keywords[argkey]),
|
||||||
argval.lineno,
|
argval.lineno,
|
||||||
argval.col_offset - len(argkey) - 1,
|
argval.col_offset - len(argkey) - 1,
|
||||||
"", "",
|
"",
|
||||||
|
"",
|
||||||
error="Failed to find keyword lexographically. Fix manually.")
|
error="Failed to find keyword lexographically. Fix manually.")
|
||||||
|
|
||||||
ast.NodeVisitor.generic_visit(self, node)
|
ast.NodeVisitor.generic_visit(self, node)
|
||||||
@ -352,7 +354,7 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
|||||||
if full_name in self._api_change_spec.change_to_function:
|
if full_name in self._api_change_spec.change_to_function:
|
||||||
if not hasattr(node, "is_function_for_call"):
|
if not hasattr(node, "is_function_for_call"):
|
||||||
new_text = full_name + "()"
|
new_text = full_name + "()"
|
||||||
self._file_edit.add("Changed %r to %r"%(full_name, new_text),
|
self._file_edit.add("Changed %r to %r" % (full_name, new_text),
|
||||||
node.lineno, node.col_offset, full_name, new_text)
|
node.lineno, node.col_offset, full_name, new_text)
|
||||||
|
|
||||||
ast.NodeVisitor.generic_visit(self, node)
|
ast.NodeVisitor.generic_visit(self, node)
|
||||||
@ -380,8 +382,8 @@ class ASTCodeUpgrader(object):
|
|||||||
# Write to a temporary file, just in case we are doing an implace modify.
|
# Write to a temporary file, just in case we are doing an implace modify.
|
||||||
with open(in_filename, "r") as in_file, \
|
with open(in_filename, "r") as in_file, \
|
||||||
tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
|
tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
|
||||||
ret = self.process_opened_file(
|
ret = self.process_opened_file(in_filename, in_file, out_filename,
|
||||||
in_filename, in_file, out_filename, temp_file)
|
temp_file)
|
||||||
|
|
||||||
shutil.move(temp_file.name, out_filename)
|
shutil.move(temp_file.name, out_filename)
|
||||||
return ret
|
return ret
|
||||||
@ -424,6 +426,7 @@ class ASTCodeUpgrader(object):
|
|||||||
out_file.write(out_text)
|
out_file.write(out_text)
|
||||||
text += "\n"
|
text += "\n"
|
||||||
return 1, text, process_errors
|
return 1, text, process_errors
|
||||||
|
|
||||||
# pylint: enable=broad-except
|
# pylint: enable=broad-except
|
||||||
|
|
||||||
def process_tree(self, root_directory, output_root_directory,
|
def process_tree(self, root_directory, output_root_directory,
|
||||||
@ -444,16 +447,16 @@ class ASTCodeUpgrader(object):
|
|||||||
|
|
||||||
# make sure output directory doesn't exist
|
# make sure output directory doesn't exist
|
||||||
if output_root_directory and os.path.exists(output_root_directory):
|
if output_root_directory and os.path.exists(output_root_directory):
|
||||||
print("Output directory %r must not already exist." % (
|
print("Output directory %r must not already exist." %
|
||||||
output_root_directory))
|
(output_root_directory))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# make sure output directory does not overlap with root_directory
|
# make sure output directory does not overlap with root_directory
|
||||||
norm_root = os.path.split(os.path.normpath(root_directory))
|
norm_root = os.path.split(os.path.normpath(root_directory))
|
||||||
norm_output = os.path.split(os.path.normpath(output_root_directory))
|
norm_output = os.path.split(os.path.normpath(output_root_directory))
|
||||||
if norm_root == norm_output:
|
if norm_root == norm_output:
|
||||||
print("Output directory %r same as input directory %r" % (
|
print("Output directory %r same as input directory %r" %
|
||||||
root_directory, output_root_directory))
|
(root_directory, output_root_directory))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Collect list of files to process (we do this to correctly handle if the
|
# Collect list of files to process (we do this to correctly handle if the
|
||||||
@ -465,14 +468,16 @@ class ASTCodeUpgrader(object):
|
|||||||
copy_files = [f for f in file_list if not f.endswith(".py")]
|
copy_files = [f for f in file_list if not f.endswith(".py")]
|
||||||
for filename in py_files:
|
for filename in py_files:
|
||||||
fullpath = os.path.join(dir_name, filename)
|
fullpath = os.path.join(dir_name, filename)
|
||||||
fullpath_output = os.path.join(
|
fullpath_output = os.path.join(output_root_directory,
|
||||||
output_root_directory, os.path.relpath(fullpath, root_directory))
|
os.path.relpath(fullpath,
|
||||||
|
root_directory))
|
||||||
files_to_process.append((fullpath, fullpath_output))
|
files_to_process.append((fullpath, fullpath_output))
|
||||||
if copy_other_files:
|
if copy_other_files:
|
||||||
for filename in copy_files:
|
for filename in copy_files:
|
||||||
fullpath = os.path.join(dir_name, filename)
|
fullpath = os.path.join(dir_name, filename)
|
||||||
fullpath_output = os.path.join(
|
fullpath_output = os.path.join(output_root_directory,
|
||||||
output_root_directory, os.path.relpath(fullpath, root_directory))
|
os.path.relpath(
|
||||||
|
fullpath, root_directory))
|
||||||
files_to_copy.append((fullpath, fullpath_output))
|
files_to_copy.append((fullpath, fullpath_output))
|
||||||
|
|
||||||
file_count = 0
|
file_count = 0
|
||||||
@ -641,18 +646,17 @@ class TFAPIChangeSpec(APIChangeSpec):
|
|||||||
"tf.concat": ["concat_dim", "values", "name"],
|
"tf.concat": ["concat_dim", "values", "name"],
|
||||||
"tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
|
"tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
|
||||||
"tf.nn.softmax_cross_entropy_with_logits": [
|
"tf.nn.softmax_cross_entropy_with_logits": [
|
||||||
"logits", "labels", "dim", "name"],
|
"logits", "labels", "dim", "name"
|
||||||
|
],
|
||||||
"tf.nn.sparse_softmax_cross_entropy_with_logits": [
|
"tf.nn.sparse_softmax_cross_entropy_with_logits": [
|
||||||
"logits", "labels", "name"],
|
"logits", "labels", "name"
|
||||||
"tf.nn.sigmoid_cross_entropy_with_logits": [
|
],
|
||||||
"logits", "labels", "name"],
|
"tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"],
|
||||||
"tf.op_scope": ["values", "name", "default_name"],
|
"tf.op_scope": ["values", "name", "default_name"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Specially handled functions.
|
# Specially handled functions.
|
||||||
self.function_handle = {
|
self.function_handle = {"tf.reverse": self._reverse_handler}
|
||||||
"tf.reverse": self._reverse_handler
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reverse_handler(file_edit_recorder, node):
|
def _reverse_handler(file_edit_recorder, node):
|
||||||
@ -661,12 +665,13 @@ class TFAPIChangeSpec(APIChangeSpec):
|
|||||||
comment = ("ERROR: tf.reverse has had its argument semantics changed\n"
|
comment = ("ERROR: tf.reverse has had its argument semantics changed\n"
|
||||||
"significantly the converter cannot detect this reliably, so you"
|
"significantly the converter cannot detect this reliably, so you"
|
||||||
"need to inspect this usage manually.\n")
|
"need to inspect this usage manually.\n")
|
||||||
file_edit_recorder.add(comment,
|
file_edit_recorder.add(
|
||||||
node.lineno,
|
comment,
|
||||||
node.col_offset,
|
node.lineno,
|
||||||
"tf.reverse",
|
node.col_offset,
|
||||||
"tf.reverse",
|
"tf.reverse",
|
||||||
error="tf.reverse requires manual check.")
|
"tf.reverse",
|
||||||
|
error="tf.reverse requires manual check.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user