Merge pull request #1902 from martinwicke/branch_119712558
Merge internal changes
This commit is contained in:
parent
d118d1d31c
commit
44a6b91ce8
@ -79,9 +79,9 @@ REGISTER_OP("CountExtremelyRandomStats")
|
||||
gives the j-th feature of the i-th input.
|
||||
input_labels: The training batch's labels; `input_labels[i]` is the class
|
||||
of the i-th input.
|
||||
tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
|
||||
of the i-th node, `tree[0][i] + 1` gives the index of the right child of
|
||||
the i-th node, and `tree[1][i]` gives the index of the feature used to
|
||||
tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
|
||||
of the i-th node, `tree[i][0] + 1` gives the index of the right child of
|
||||
the i-th node, and `tree[i][1]` gives the index of the feature used to
|
||||
split the i-th node.
|
||||
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
|
||||
node.
|
||||
|
@ -44,9 +44,9 @@ REGISTER_OP("TreePredictions")
|
||||
|
||||
input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
|
||||
gives the j-th feature of the i-th input.
|
||||
tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
|
||||
of the i-th node, `tree[0][i] + 1` gives the index of the right child of
|
||||
the i-th node, and `tree[1][i]` gives the index of the feature used to
|
||||
tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
|
||||
of the i-th node, `tree[i][0] + 1` gives the index of the right child of
|
||||
the i-th node, and `tree[i][1]` gives the index of the feature used to
|
||||
split the i-th node.
|
||||
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
|
||||
node.
|
||||
|
@ -25,6 +25,12 @@ import tensorflow as tf
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('inference_library_base_dir', '',
|
||||
'Directory to look for inference library file.')
|
||||
|
||||
INFERENCE_OPS_FILE = '_inference_ops.so'
|
||||
|
||||
_inference_ops = None
|
||||
@ -54,7 +60,8 @@ def Load():
|
||||
with _ops_lock:
|
||||
global _inference_ops
|
||||
if not _inference_ops:
|
||||
data_files_path = tf.resource_loader.get_data_files_path()
|
||||
data_files_path = os.path.join(FLAGS.inference_library_base_dir,
|
||||
tf.resource_loader.get_data_files_path())
|
||||
tf.logging.info('data path: %s', data_files_path)
|
||||
_inference_ops = tf.load_op_library(os.path.join(
|
||||
data_files_path, INFERENCE_OPS_FILE))
|
||||
|
@ -25,6 +25,12 @@ import tensorflow as tf
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('training_library_base_dir', '',
|
||||
'Directory to look for inference library file.')
|
||||
|
||||
TRAINING_OPS_FILE = '_training_ops.so'
|
||||
|
||||
_training_ops = None
|
||||
@ -101,7 +107,8 @@ def Load():
|
||||
with _ops_lock:
|
||||
global _training_ops
|
||||
if not _training_ops:
|
||||
data_files_path = tf.resource_loader.get_data_files_path()
|
||||
data_files_path = os.path.join(FLAGS.training_library_base_dir,
|
||||
tf.resource_loader.get_data_files_path())
|
||||
tf.logging.info('data path: %s', data_files_path)
|
||||
_training_ops = tf.load_op_library(os.path.join(
|
||||
data_files_path, TRAINING_OPS_FILE))
|
||||
|
@ -37,6 +37,13 @@ flags.DEFINE_float(
|
||||
'samples_to_decide', 25.0,
|
||||
'Only decide on a split, or only fully use a leaf, after this many '
|
||||
'training samples have been seen.')
|
||||
flags.DEFINE_float('bagging_fraction', 1.0,
|
||||
'Use this fraction of the input, randomly chosen, to train '
|
||||
'each tree in the forest.')
|
||||
flags.DEFINE_integer(
|
||||
'num_splits_to_consider', 0,
|
||||
'If non-zero, consider this many candidates for a splitting '
|
||||
'rule at a fertile node.')
|
||||
|
||||
# If tree[i][0] equals this value, then i is a leaf node.
|
||||
LEAF_NODE = -1
|
||||
@ -69,6 +76,9 @@ class ForestHParams(object):
|
||||
# Fail fast if num_classes isn't set.
|
||||
_ = getattr(self, 'num_classes')
|
||||
|
||||
self.bagging_fraction = getattr(self, 'bagging_fraction',
|
||||
FLAGS.bagging_fraction)
|
||||
|
||||
self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees)
|
||||
self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes)
|
||||
|
||||
@ -79,7 +89,9 @@ class ForestHParams(object):
|
||||
# The Random Forest literature recommends sqrt(# features) for
|
||||
# classification problems, and p/3 for regression problems.
|
||||
# TODO(thomaswc): Consider capping this for large number of features.
|
||||
if not getattr(self, 'num_splits_to_consider', None):
|
||||
self.num_splits_to_consider = getattr(self, 'num_splits_to_consider',
|
||||
FLAGS.num_splits_to_consider)
|
||||
if not self.num_splits_to_consider:
|
||||
self.num_splits_to_consider = max(10, int(
|
||||
math.ceil(math.sqrt(self.num_features))))
|
||||
|
||||
@ -94,8 +106,8 @@ class ForestHParams(object):
|
||||
self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile)
|
||||
# But it also never needs to be larger than the number of leaves,
|
||||
# which is max_nodes / 2.
|
||||
self.max_fertile_nodes = min(self.max_nodes,
|
||||
int(math.ceil(self.max_fertile_nodes / 2.0)))
|
||||
self.max_fertile_nodes = min(self.max_fertile_nodes,
|
||||
int(math.ceil(self.max_nodes / 2.0)))
|
||||
|
||||
# split_after_samples and valid_leaf_threshold should be about the same.
|
||||
# Therefore, if either is set, use it to set the other. Otherwise, fall
|
||||
@ -184,23 +196,6 @@ class TreeStats(object):
|
||||
self.num_leaves = num_leaves
|
||||
|
||||
|
||||
def get_tree_stats(variables, unused_params, session):
|
||||
num_nodes = variables.end_of_tree.eval(session=session) - 1
|
||||
num_leaves = tf.where(
|
||||
tf.equal(tf.squeeze(tf.slice(variables.tree, [0, 0], [-1, 1])),
|
||||
LEAF_NODE)).eval(session=session).shape[0]
|
||||
return TreeStats(num_nodes, num_leaves)
|
||||
|
||||
|
||||
def get_forest_stats(variables, params, session):
|
||||
|
||||
tree_stats = []
|
||||
for i in range(params.num_trees):
|
||||
tree_stats.append(get_tree_stats(variables[i], params, session))
|
||||
|
||||
return ForestStats(tree_stats, params)
|
||||
|
||||
|
||||
class ForestTrainingVariables(object):
|
||||
"""A container for a forests training data, consisting of multiple trees.
|
||||
|
||||
@ -212,9 +207,11 @@ class ForestTrainingVariables(object):
|
||||
... forest_variables.tree ...
|
||||
"""
|
||||
|
||||
def __init__(self, params):
|
||||
self.variables = [TreeTrainingVariables(params)
|
||||
for _ in range(params.num_trees)]
|
||||
def __init__(self, params, device_assigner):
|
||||
self.variables = []
|
||||
for i in range(params.num_trees):
|
||||
with tf.device(device_assigner.get_device(i)):
|
||||
self.variables.append(TreeTrainingVariables(params))
|
||||
|
||||
def __setitem__(self, t, val):
|
||||
self.variables[t] = val
|
||||
@ -223,12 +220,35 @@ class ForestTrainingVariables(object):
|
||||
return self.variables[t]
|
||||
|
||||
|
||||
class RandomForestDeviceAssigner(object):
|
||||
"""A device assigner that uses the default device.
|
||||
|
||||
Write subclasses that implement get_device for control over how trees
|
||||
get assigned to devices. This assumes that whole trees are assigned
|
||||
to a device.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.cached = None
|
||||
|
||||
def get_device(self, unused_tree_num):
|
||||
if not self.cached:
|
||||
dummy = tf.constant(0)
|
||||
self.cached = dummy.device
|
||||
|
||||
return self.cached
|
||||
|
||||
|
||||
class RandomForestGraphs(object):
|
||||
"""Builds TF graphs for random forest training and inference."""
|
||||
|
||||
def __init__(self, params):
|
||||
def __init__(self, params, device_assigner=None, variables=None):
|
||||
self.params = params
|
||||
self.variables = ForestTrainingVariables(self.params)
|
||||
self.device_assigner = device_assigner or RandomForestDeviceAssigner()
|
||||
tf.logging.info('Constructing forest with params = ')
|
||||
tf.logging.info(self.params.__dict__)
|
||||
self.variables = variables or ForestTrainingVariables(
|
||||
self.params, device_assigner=self.device_assigner)
|
||||
self.trees = [RandomTreeGraphs(self.variables[i], self.params,
|
||||
training_ops.Load(), inference_ops.Load())
|
||||
for i in range(self.params.num_trees)]
|
||||
@ -246,12 +266,26 @@ class RandomForestGraphs(object):
|
||||
"""
|
||||
tree_graphs = []
|
||||
for i in range(self.params.num_trees):
|
||||
tf.logging.info('Constructing tree %d', i)
|
||||
seed = self.params.base_random_seed
|
||||
if seed != 0:
|
||||
seed += i
|
||||
tree_graphs.append(self.trees[i].training_graph(
|
||||
input_data, input_labels, seed))
|
||||
with tf.device(self.device_assigner.get_device(i)):
|
||||
seed = self.params.base_random_seed
|
||||
if seed != 0:
|
||||
seed += i
|
||||
# If using bagging, randomly select some of the input.
|
||||
tree_data = input_data
|
||||
tree_labels = input_labels
|
||||
if self.params.bagging_fraction < 1.0:
|
||||
# TODO(thomaswc): This does sampling without replacment. Consider
|
||||
# also allowing sampling with replacement as an option.
|
||||
batch_size = tf.slice(tf.shape(input_data), [0], [1])
|
||||
r = tf.random_uniform(batch_size, seed=seed)
|
||||
mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction)
|
||||
gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1])
|
||||
# TODO(thomaswc): Calculate out-of-bag data and labels, and store
|
||||
# them for use in calculating statistics later.
|
||||
tree_data = tf.gather(input_data, gather_indices)
|
||||
tree_labels = tf.gather(input_labels, gather_indices)
|
||||
tree_graphs.append(
|
||||
self.trees[i].training_graph(tree_data, tree_labels, seed))
|
||||
return tf.group(*tree_graphs)
|
||||
|
||||
def inference_graph(self, input_data):
|
||||
@ -265,9 +299,23 @@ class RandomForestGraphs(object):
|
||||
"""
|
||||
probabilities = []
|
||||
for i in range(self.params.num_trees):
|
||||
probabilities.append(self.trees[i].inference_graph(input_data))
|
||||
all_predict = tf.pack(probabilities)
|
||||
return tf.reduce_sum(all_predict, 0) / self.params.num_trees
|
||||
with tf.device(self.device_assigner.get_device(i)):
|
||||
probabilities.append(self.trees[i].inference_graph(input_data))
|
||||
with tf.device(self.device_assigner.get_device(0)):
|
||||
all_predict = tf.pack(probabilities)
|
||||
return tf.reduce_sum(all_predict, 0) / self.params.num_trees
|
||||
|
||||
def average_size(self):
|
||||
"""Constructs a TF graph for evaluating the average size of a forest.
|
||||
|
||||
Returns:
|
||||
The average number of nodes over the trees.
|
||||
"""
|
||||
sizes = []
|
||||
for i in range(self.params.num_trees):
|
||||
with tf.device(self.device_assigner.get_device(i)):
|
||||
sizes.append(self.trees[i].size())
|
||||
return tf.reduce_mean(tf.pack(sizes))
|
||||
|
||||
def average_impurity(self):
|
||||
"""Constructs a TF graph for evaluating the leaf impurity of a forest.
|
||||
@ -277,9 +325,17 @@ class RandomForestGraphs(object):
|
||||
"""
|
||||
impurities = []
|
||||
for i in range(self.params.num_trees):
|
||||
impurities.append(self.trees[i].average_impurity(self.variables[i]))
|
||||
with tf.device(self.device_assigner.get_device(i)):
|
||||
impurities.append(self.trees[i].average_impurity())
|
||||
return tf.reduce_mean(tf.pack(impurities))
|
||||
|
||||
def get_stats(self, session):
|
||||
tree_stats = []
|
||||
for i in range(self.params.num_trees):
|
||||
with tf.device(self.device_assigner.get_device(i)):
|
||||
tree_stats.append(self.trees[i].get_stats(session))
|
||||
return ForestStats(tree_stats, self.params)
|
||||
|
||||
|
||||
class RandomTreeGraphs(object):
|
||||
"""Builds TF graphs for random tree training and inference."""
|
||||
@ -394,6 +450,7 @@ class RandomTreeGraphs(object):
|
||||
with tf.control_dependencies([node_update_op]):
|
||||
def f1():
|
||||
return self.variables.non_fertile_leaf_scores
|
||||
|
||||
def f2():
|
||||
counts = tf.gather(self.variables.node_per_class_weights,
|
||||
self.variables.non_fertile_leaves)
|
||||
@ -535,3 +592,18 @@ class RandomTreeGraphs(object):
|
||||
counts = tf.gather(self.variables.node_per_class_weights, leaves)
|
||||
impurity = self._weighted_gini(counts)
|
||||
return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0)
|
||||
|
||||
def size(self):
|
||||
"""Constructs a TF graph for evaluating the current number of nodes.
|
||||
|
||||
Returns:
|
||||
The current number of nodes in the tree.
|
||||
"""
|
||||
return self.variables.end_of_tree - 1
|
||||
|
||||
def get_stats(self, session):
|
||||
num_nodes = self.variables.end_of_tree.eval(session=session) - 1
|
||||
num_leaves = tf.where(
|
||||
tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])),
|
||||
LEAF_NODE)).eval(session=session).shape[0]
|
||||
return TreeStats(num_nodes, num_leaves)
|
||||
|
@ -27,6 +27,37 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
class TensorForestTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testForestHParams(self):
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_classes=2, num_trees=100, max_nodes=1000,
|
||||
num_features=60).fill()
|
||||
self.assertEquals(2, hparams.num_classes)
|
||||
# 2 * ceil(log_2(1000)) = 20
|
||||
self.assertEquals(20, hparams.max_depth)
|
||||
# sqrt(num_features) < 10, so num_splits_to_consider should be 10.
|
||||
self.assertEquals(10, hparams.num_splits_to_consider)
|
||||
# Don't have more fertile nodes than max # leaves, which is 500.
|
||||
self.assertEquals(500, hparams.max_fertile_nodes)
|
||||
# We didn't set either of these, so they should be equal
|
||||
self.assertEquals(hparams.split_after_samples,
|
||||
hparams.valid_leaf_threshold)
|
||||
# split_after_samples is larger than 10
|
||||
self.assertEquals(1, hparams.split_initializations_per_input)
|
||||
self.assertEquals(0, hparams.base_random_seed)
|
||||
|
||||
def testForestHParamsBigTree(self):
|
||||
hparams = tensor_forest.ForestHParams(
|
||||
num_classes=2, num_trees=100, max_nodes=1000000,
|
||||
split_after_samples=25,
|
||||
num_features=1000).fill()
|
||||
self.assertEquals(40, hparams.max_depth)
|
||||
# sqrt(1000) = 31.63...
|
||||
self.assertEquals(32, hparams.num_splits_to_consider)
|
||||
# 1000000 / 32 = 31250
|
||||
self.assertEquals(31250, hparams.max_fertile_nodes)
|
||||
# floor(31.63 / 25) = 1
|
||||
self.assertEquals(1, hparams.split_initializations_per_input)
|
||||
|
||||
def testTrainingConstruction(self):
|
||||
input_data = [[-1., 0.], [-1., 2.], # node 1
|
||||
[1., 0.], [1., -2.]] # node 2
|
||||
@ -50,6 +81,14 @@ class TensorForestTest(test_util.TensorFlowTestCase):
|
||||
graph = graph_builder.inference_graph(input_data)
|
||||
self.assertTrue(isinstance(graph, tf.Tensor))
|
||||
|
||||
def testImpurityConstruction(self):
|
||||
params = tensor_forest.ForestHParams(
|
||||
num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill()
|
||||
|
||||
graph_builder = tensor_forest.RandomForestGraphs(params)
|
||||
graph = graph_builder.average_impurity()
|
||||
self.assertTrue(isinstance(graph, tf.Tensor))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
@ -143,7 +143,6 @@ cc_library(
|
||||
"lib/core/bits.h",
|
||||
"lib/core/casts.h",
|
||||
"lib/core/coding.h",
|
||||
"lib/core/command_line_flags.h", # TODO(vrv): Delete.
|
||||
"lib/core/errors.h",
|
||||
"lib/core/notification.h",
|
||||
"lib/core/status.h",
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -713,4 +714,36 @@ void Tensor::FillDescription(TensorDescription* description) const {
|
||||
}
|
||||
}
|
||||
|
||||
gtl::InlinedVector<int64, 5> Tensor::ComputeFlatInnerDims(
|
||||
int64 num_out_dims) const {
|
||||
gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0);
|
||||
const int64 num_elements = NumElements();
|
||||
if (num_elements != 0) {
|
||||
int64 prod_out_dims = 1;
|
||||
for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) {
|
||||
const int64 in_dim = out_dim + (dims() - num_out_dims);
|
||||
out_dims[out_dim] =
|
||||
(in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim);
|
||||
prod_out_dims *= out_dims[out_dim];
|
||||
}
|
||||
out_dims[0] = num_elements / prod_out_dims;
|
||||
}
|
||||
return out_dims;
|
||||
}
|
||||
|
||||
gtl::InlinedVector<int64, 5> Tensor::ComputeFlatOuterDims(
|
||||
int64 num_out_dims) const {
|
||||
gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0);
|
||||
const int64 num_elements = NumElements();
|
||||
if (num_elements != 0) {
|
||||
int64 prod_out_dims = 1;
|
||||
for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) {
|
||||
out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim);
|
||||
prod_out_dims *= out_dims[out_dim];
|
||||
}
|
||||
out_dims[num_out_dims - 1] = num_elements / prod_out_dims;
|
||||
}
|
||||
return out_dims;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -243,40 +244,28 @@ class Tensor {
|
||||
///
|
||||
/// ```
|
||||
template <typename T>
|
||||
typename TTypes<T>::Flat flat();
|
||||
typename TTypes<T>::Flat flat() {
|
||||
return shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedFlat unaligned_flat() {
|
||||
return unaligned_shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
|
||||
/// Tensor dimensions but the last one into the first dimension of the result.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Matrix flat_inner_dims() {
|
||||
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
|
||||
if (last_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({NumElements() / last_size, last_size});
|
||||
}
|
||||
}
|
||||
/// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
|
||||
/// Tensor dimensions but the last NDIMS-1 into the first dimension of the
|
||||
/// result. If NDIMS > dims() then leading dimensions of size 1 will be
|
||||
/// added to make the output rank NDIMS.
|
||||
template <typename T, size_t NDIMS = 2>
|
||||
typename TTypes<T, NDIMS>::Tensor flat_inner_dims();
|
||||
|
||||
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
|
||||
/// Tensor dimensions but the first one into the last dimension of the result.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Matrix flat_outer_dims() {
|
||||
int64 first_size = dims() > 0 ? dim_size(0) : 1;
|
||||
if (first_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({first_size, NumElements() / first_size});
|
||||
}
|
||||
}
|
||||
/// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
|
||||
/// Tensor dimensions but the first NDIMS-1 into the last dimension of the
|
||||
/// result. If NDIMS > dims() then trailing dimensions of size 1 will be
|
||||
/// added to make the output rank NDIMS.
|
||||
template <typename T, size_t NDIMS = 2>
|
||||
typename TTypes<T, NDIMS>::Tensor flat_outer_dims();
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);
|
||||
@ -308,31 +297,19 @@ class Tensor {
|
||||
typename TTypes<T, NDIMS>::ConstTensor tensor() const;
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstFlat flat() const;
|
||||
typename TTypes<T>::ConstFlat flat() const {
|
||||
return shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
|
||||
return unaligned_shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix flat_inner_dims() const {
|
||||
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
|
||||
if (last_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({NumElements() / last_size, last_size});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix flat_outer_dims() const;
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const;
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const;
|
||||
@ -340,6 +317,12 @@ class Tensor {
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstScalar scalar() const;
|
||||
|
||||
template <typename T, size_t NDIMS = 2>
|
||||
typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const;
|
||||
|
||||
template <typename T, size_t NDIMS = 2>
|
||||
typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const;
|
||||
|
||||
/// Render the first `max_entries` values in `*this` into a string.
|
||||
string SummarizeValue(int64 max_entries) const;
|
||||
|
||||
@ -378,6 +361,8 @@ class Tensor {
|
||||
void FillDimsAndValidateCompatibleShape(
|
||||
gtl::ArraySlice<int64> new_sizes,
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
|
||||
gtl::InlinedVector<int64, 5> ComputeFlatInnerDims(int64 num_out_dims) const;
|
||||
gtl::InlinedVector<int64, 5> ComputeFlatOuterDims(int64 num_out_dims) const;
|
||||
|
||||
TensorShape shape_;
|
||||
TensorBuffer* buf_;
|
||||
@ -534,26 +519,24 @@ typename TTypes<T>::ConstScalar Tensor::scalar() const {
|
||||
return typename TTypes<T>::ConstScalar(base<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::Flat Tensor::flat() {
|
||||
return shaped<T, 1>({NumElements()});
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() {
|
||||
return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstFlat Tensor::flat() const {
|
||||
return shaped<T, 1>({NumElements()});
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() {
|
||||
return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix Tensor::flat_outer_dims() const {
|
||||
int64 first_size = dims() > 0 ? dim_size(0) : 1;
|
||||
if (first_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({first_size, NumElements() / first_size});
|
||||
}
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const {
|
||||
return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS));
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const {
|
||||
return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -224,6 +224,49 @@ TEST(Tensor_Float, Reshape) {
|
||||
EXPECT_EQ(flat_inner_dims(0, 0), 0.01f);
|
||||
EXPECT_EQ(flat_inner_dims(23, 4), 0.02f);
|
||||
}
|
||||
{
|
||||
auto flat_outer_dims = t.flat_outer_dims<float>();
|
||||
EXPECT_EQ(2, flat_outer_dims.dimension(0));
|
||||
EXPECT_EQ(60, flat_outer_dims.dimension(1));
|
||||
EXPECT_EQ(flat_outer_dims(0, 0), 0.01f);
|
||||
EXPECT_EQ(flat_outer_dims(1, 59), 0.02f);
|
||||
}
|
||||
{
|
||||
auto flat_inner_dims = t.flat_inner_dims<float, 3>();
|
||||
EXPECT_EQ(6, flat_inner_dims.dimension(0));
|
||||
EXPECT_EQ(4, flat_inner_dims.dimension(1));
|
||||
EXPECT_EQ(5, flat_inner_dims.dimension(2));
|
||||
EXPECT_EQ(flat_inner_dims(0, 0, 0), 0.01f);
|
||||
EXPECT_EQ(flat_inner_dims(5, 3, 4), 0.02f);
|
||||
}
|
||||
{
|
||||
auto flat_outer_dims = t.flat_outer_dims<float, 3>();
|
||||
EXPECT_EQ(2, flat_outer_dims.dimension(0));
|
||||
EXPECT_EQ(3, flat_outer_dims.dimension(1));
|
||||
EXPECT_EQ(20, flat_outer_dims.dimension(2));
|
||||
EXPECT_EQ(flat_outer_dims(0, 0, 0), 0.01f);
|
||||
EXPECT_EQ(flat_outer_dims(1, 2, 19), 0.02f);
|
||||
}
|
||||
{
|
||||
auto flat_inner_dims = t.flat_inner_dims<float, 5>();
|
||||
EXPECT_EQ(1, flat_inner_dims.dimension(0));
|
||||
EXPECT_EQ(2, flat_inner_dims.dimension(1));
|
||||
EXPECT_EQ(3, flat_inner_dims.dimension(2));
|
||||
EXPECT_EQ(4, flat_inner_dims.dimension(3));
|
||||
EXPECT_EQ(5, flat_inner_dims.dimension(4));
|
||||
EXPECT_EQ(flat_inner_dims(0, 0, 0, 0, 0), 0.01f);
|
||||
EXPECT_EQ(flat_inner_dims(0, 1, 2, 3, 4), 0.02f);
|
||||
}
|
||||
{
|
||||
auto flat_outer_dims = t.flat_outer_dims<float, 5>();
|
||||
EXPECT_EQ(2, flat_outer_dims.dimension(0));
|
||||
EXPECT_EQ(3, flat_outer_dims.dimension(1));
|
||||
EXPECT_EQ(4, flat_outer_dims.dimension(2));
|
||||
EXPECT_EQ(5, flat_outer_dims.dimension(3));
|
||||
EXPECT_EQ(1, flat_outer_dims.dimension(4));
|
||||
EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f);
|
||||
EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Tensor_Scalar, Basics) {
|
||||
|
@ -305,6 +305,23 @@ tf_kernel_libraries(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "batch_norm_op_test",
|
||||
size = "small",
|
||||
deps = [
|
||||
":batch_norm_op",
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "concat_op_test",
|
||||
size = "small",
|
||||
|
62
tensorflow/core/kernels/batch_norm_op_test.cc
Normal file
62
tensorflow/core/kernels/batch_norm_op_test.cc
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class BatchNormOpTest : public OpsTestBase {};
|
||||
|
||||
TEST_F(BatchNormOpTest, Simple) {
|
||||
TF_EXPECT_OK(
|
||||
NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("scale_after_normalization", false)
|
||||
.Attr("variance_epsilon", 0.001)
|
||||
.Finalize(node_def()));
|
||||
TF_EXPECT_OK(InitOpWithGraphVersion(8));
|
||||
AddInputFromArray<float>(TensorShape({1, 1, 6, 2}),
|
||||
{1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6});
|
||||
AddInputFromArray<float>(TensorShape({2}), {10, 20});
|
||||
AddInputFromArray<float>(TensorShape({2}), {0.25, 0.5});
|
||||
AddInputFromArray<float>(TensorShape({2}), {0.1, 0.6});
|
||||
AddInputFromArray<float>(TensorShape({2}), {0.0, 0.0});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2}));
|
||||
test::FillValues<float>(
|
||||
&expected, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86,
|
||||
-33.31, -23.85, -34.72, -25.85, -36.13});
|
||||
test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -94,10 +94,13 @@ class OpsTestBase : public ::testing::Test {
|
||||
// and output types as output.
|
||||
//
|
||||
// Returns the status of initialization.
|
||||
Status InitOp() {
|
||||
Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); }
|
||||
|
||||
// Only use this directly if you have a deprecated op that you need to test.
|
||||
Status InitOpWithGraphVersion(int graph_def_version) {
|
||||
Status status;
|
||||
kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(),
|
||||
node_def_, TF_GRAPH_DEF_VERSION, &status);
|
||||
node_def_, graph_def_version, &status);
|
||||
if (kernel_ != nullptr) input_types_ = kernel_->input_types();
|
||||
return status;
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) {
|
||||
#define MAYBE_CONJ(T) \
|
||||
template <> \
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj<T>(T v) { \
|
||||
return std::conj(v); \
|
||||
return Eigen::numext::conj(v); \
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -1,121 +0,0 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/core/command_line_flags.h"
|
||||
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Templated function to convert a string to target values.
|
||||
// Return true if the conversion is successful. Otherwise, return false.
|
||||
template <typename T>
|
||||
bool StringToValue(const string& content, T* value);
|
||||
|
||||
template <>
|
||||
bool StringToValue<int32>(const string& content, int32* value) {
|
||||
return strings::safe_strto32(content, value);
|
||||
}
|
||||
|
||||
template <>
|
||||
bool StringToValue<string>(const string& content, string* value) {
|
||||
*value = content;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parse a single argument by linearly searching through the command table.
|
||||
// The input format is: --argument=value.
|
||||
// Return OK if the argument is used. It store the extracted value into the
|
||||
// matching flag.
|
||||
// Return NOT_FOUND if the argument is not recognized.
|
||||
// Return INVALID_ARGUMENT if the command is recognized, but fails to extract
|
||||
// its value.
|
||||
template <typename T>
|
||||
Status ParseArgument(const string& argument) {
|
||||
for (auto& command :
|
||||
internal::CommandLineFlagRegistry<T>::Instance()->commands) {
|
||||
string prefix = strings::StrCat("--", command.name, "=");
|
||||
if (tensorflow::StringPiece(argument).starts_with(prefix)) {
|
||||
string content = argument.substr(prefix.length());
|
||||
if (StringToValue<T>(content, command.value)) {
|
||||
return Status::OK();
|
||||
}
|
||||
return Status(error::INVALID_ARGUMENT,
|
||||
strings::StrCat("Cannot parse integer in: ", argument));
|
||||
}
|
||||
}
|
||||
return Status(error::NOT_FOUND,
|
||||
strings::StrCat("Unknown command: ", argument));
|
||||
}
|
||||
|
||||
// A specialization for booleans. The input format is:
|
||||
// "--argument" or "--noargument".
|
||||
// Parse a single argument by linearly searching through the command table.
|
||||
// Return OK if the argument is used. The value is stored in the matching flag.
|
||||
// Return NOT_FOUND if the argument is not recognized.
|
||||
template <>
|
||||
Status ParseArgument<bool>(const string& argument) {
|
||||
for (auto& command :
|
||||
internal::CommandLineFlagRegistry<bool>::Instance()->commands) {
|
||||
if (argument == strings::StrCat("--", command.name)) {
|
||||
*command.value = true;
|
||||
return Status::OK();
|
||||
} else if (argument == strings::StrCat("--no", command.name)) {
|
||||
*command.value = false;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
return Status(error::NOT_FOUND,
|
||||
strings::StrCat("Unknown command: ", argument));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ParseCommandLineFlags(int* argc, char* argv[]) {
|
||||
int unused_argc = 1;
|
||||
for (int index = 1; index < *argc; ++index) {
|
||||
Status s;
|
||||
// Search bool commands.
|
||||
s = ParseArgument<bool>(argv[index]);
|
||||
if (s.ok()) {
|
||||
continue;
|
||||
}
|
||||
if (s.code() != error::NOT_FOUND) {
|
||||
return s;
|
||||
}
|
||||
// Search int32 commands.
|
||||
s = ParseArgument<int32>(argv[index]);
|
||||
if (s.ok()) {
|
||||
continue;
|
||||
}
|
||||
// Search string commands.
|
||||
s = ParseArgument<string>(argv[index]);
|
||||
if (s.ok()) {
|
||||
continue;
|
||||
}
|
||||
if (s.code() != error::NOT_FOUND) {
|
||||
return s;
|
||||
}
|
||||
// Pointer swap the unused argument to the front.
|
||||
std::swap(argv[unused_argc++], argv[index]);
|
||||
}
|
||||
*argc = unused_argc;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,80 +0,0 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
|
||||
#define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
template <typename T>
|
||||
struct CommandLineFlagRegistry {
|
||||
static CommandLineFlagRegistry* Instance() {
|
||||
static CommandLineFlagRegistry instance_;
|
||||
return &instance_;
|
||||
}
|
||||
struct Command {
|
||||
string name;
|
||||
T* value;
|
||||
string text;
|
||||
};
|
||||
std::vector<Command> commands;
|
||||
|
||||
private:
|
||||
CommandLineFlagRegistry() {}
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CommandLineFlagRegister {
|
||||
CommandLineFlagRegister(const string& name, T* val, const string& text) {
|
||||
CommandLineFlagRegistry<T>::Instance()->commands.push_back(
|
||||
{name, val, text});
|
||||
}
|
||||
};
|
||||
|
||||
#define TF_DEFINE_variable(type, name, default_value, text) \
|
||||
type FLAGS_##name = default_value; \
|
||||
namespace TF_flags_internal { \
|
||||
tensorflow::internal::CommandLineFlagRegister<type> \
|
||||
TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \
|
||||
} // namespace TF_flags_internal
|
||||
|
||||
} // namespace internal
|
||||
|
||||
#define TF_DEFINE_int32(name, default_value, text) \
|
||||
TF_DEFINE_variable(tensorflow::int32, name, default_value, text);
|
||||
|
||||
#define TF_DEFINE_bool(name, default_value, text) \
|
||||
TF_DEFINE_variable(bool, name, default_value, text);
|
||||
|
||||
#define TF_DEFINE_string(name, default_value, text) \
|
||||
TF_DEFINE_variable(string, name, default_value, text);
|
||||
|
||||
// Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv.
|
||||
// Returned the number of unused arguments in *argc.
|
||||
// Return error Status if the parsing encounters errors.
|
||||
// TODO(opensource): switch to a command line argument parser that can be
|
||||
// shared with other tests.
|
||||
Status ParseCommandLineFlags(int* argc, char* argv[]);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
|
@ -120,9 +120,12 @@ variable to its initial value.
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`initial_value`</b>: A `Tensor`, or Python object convertible to a `Tensor`.
|
||||
The initial value for the Variable. Must have a shape specified unless
|
||||
`validate_shape` is set to False.
|
||||
* <b>`initial_value`</b>: A `Tensor`, or Python object convertible to a `Tensor`,
|
||||
which is the initial value for the Variable. The initial value must have
|
||||
a shape specified unless `validate_shape` is set to False. Can also be a
|
||||
callable with no argument that returns the initial value when called. In
|
||||
that case, `dtype` must be specified. (Note that initializer functions
|
||||
from init_ops.py must first be bound to a shape before being used here.)
|
||||
* <b>`trainable`</b>: If `True`, the default, also adds the variable to the graph
|
||||
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
|
||||
the default list of variables to use by the `Optimizer` classes.
|
||||
|
@ -1955,6 +1955,7 @@ on the parameters to the constructor and may include:
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`RuntimeError`</b>: If called with a non-chief Supervisor.
|
||||
* <b>`ValueError`</b>: If not `logdir` was passed to the constructor as the
|
||||
services need a log directory.
|
||||
|
||||
@ -2182,6 +2183,7 @@ on the parameters to the constructor and may include:
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`RuntimeError`</b>: If called with a non-chief Supervisor.
|
||||
* <b>`ValueError`</b>: If not `logdir` was passed to the constructor as the
|
||||
services need a log directory.
|
||||
|
||||
@ -2409,7 +2411,7 @@ Start threads for `QueueRunners`.
|
||||
|
||||
#### `tf.train.Supervisor.summary_op` {#Supervisor.summary_op}
|
||||
|
||||
Return the Summary Tensor used by the supervisor.
|
||||
Return the Summary Tensor used by the chief supervisor.
|
||||
|
||||
##### Returns:
|
||||
|
||||
@ -2420,7 +2422,7 @@ Return the Summary Tensor used by the supervisor.
|
||||
|
||||
#### `tf.train.Supervisor.summary_writer` {#Supervisor.summary_writer}
|
||||
|
||||
Return the SummaryWriter used by the supervisor.
|
||||
Return the SummaryWriter used by the chief supervisor.
|
||||
|
||||
##### Returns:
|
||||
|
||||
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@ -208,5 +209,59 @@ class RNNCellTest(tf.test.TestCase):
|
||||
0.13248, 0.13248]])
|
||||
|
||||
|
||||
class SlimRNNCellTest(tf.test.TestCase):
|
||||
|
||||
def testBasicRNNCell(self):
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
x = tf.zeros([1, 2])
|
||||
m = tf.zeros([1, 2])
|
||||
my_cell = functools.partial(basic_rnn_cell, num_units=2)
|
||||
g, _ = tf.nn.rnn_cell.SlimRNNCell(my_cell)(x, m)
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
res = sess.run([g], {x.name: np.array([[1., 1.]]),
|
||||
m.name: np.array([[0.1, 0.1]])})
|
||||
self.assertEqual(res[0].shape, (1, 2))
|
||||
|
||||
def testBasicRNNCellMatch(self):
|
||||
batch_size = 32
|
||||
input_size = 100
|
||||
num_units = 10
|
||||
with self.test_session() as sess:
|
||||
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
|
||||
inputs = tf.random_uniform((batch_size, input_size))
|
||||
_, initial_state = basic_rnn_cell(inputs, None, num_units)
|
||||
my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
|
||||
slim_cell = tf.nn.rnn_cell.SlimRNNCell(my_cell)
|
||||
slim_outputs, slim_state = slim_cell(inputs, initial_state)
|
||||
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
|
||||
outputs, state = rnn_cell(inputs, initial_state)
|
||||
self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
|
||||
self.assertEqual(slim_state.get_shape(), state.get_shape())
|
||||
sess.run([tf.initialize_all_variables()])
|
||||
res = sess.run([slim_outputs, slim_state, outputs, state])
|
||||
self.assertAllClose(res[0], res[2])
|
||||
self.assertAllClose(res[1], res[3])
|
||||
|
||||
|
||||
def basic_rnn_cell(inputs, state, num_units, scope=None):
|
||||
if state is None:
|
||||
if inputs is not None:
|
||||
batch_size = inputs.get_shape()[0]
|
||||
dtype = inputs.dtype
|
||||
else:
|
||||
batch_size = 0
|
||||
dtype = tf.float32
|
||||
init_output = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype)
|
||||
init_state = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype)
|
||||
init_output.set_shape([batch_size, num_units])
|
||||
init_state.set_shape([batch_size, num_units])
|
||||
return init_output, init_state
|
||||
else:
|
||||
with tf.variable_op_scope([inputs, state], scope, "BasicRNNCell"):
|
||||
output = tf.tanh(tf.nn.rnn_cell.linear([inputs, state],
|
||||
num_units, True))
|
||||
return output, output
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
@ -302,6 +302,53 @@ class VariablesTestCase(tf.test.TestCase):
|
||||
self.assertEqual(var.op.device, init_op.device)
|
||||
sess.run(init_op)
|
||||
|
||||
def testInitializerFunction(self):
|
||||
value = [[-42], [133.7]]
|
||||
shape = [2, 1]
|
||||
with self.test_session():
|
||||
initializer = lambda: tf.constant(value)
|
||||
with self.assertRaises(ValueError):
|
||||
# Checks that dtype must be specified.
|
||||
tf.Variable(initializer)
|
||||
|
||||
v1 = tf.Variable(initializer, dtype=tf.float32)
|
||||
self.assertEqual(shape, v1.get_shape())
|
||||
self.assertAllClose(value, v1.initial_value.eval())
|
||||
with self.assertRaises(tf.errors.FailedPreconditionError):
|
||||
v1.eval()
|
||||
|
||||
v2 = tf.Variable(tf.neg(v1.initialized_value()), dtype=tf.float32)
|
||||
self.assertEqual(v1.get_shape(), v2.get_shape())
|
||||
self.assertAllClose(np.negative(value), v2.initial_value.eval())
|
||||
|
||||
# Once v2.initial_value.eval() has been called, v1 has effectively been
|
||||
# initialized.
|
||||
self.assertAllClose(value, v1.eval())
|
||||
|
||||
with self.assertRaises(tf.errors.FailedPreconditionError):
|
||||
v2.eval()
|
||||
tf.initialize_all_variables().run()
|
||||
self.assertAllClose(np.negative(value), v2.eval())
|
||||
|
||||
def testInitializerFunctionDevicePlacement(self):
|
||||
with self.test_session():
|
||||
initializer = lambda: tf.constant(42.0)
|
||||
with tf.device("/cpu:100"):
|
||||
v1 = tf.Variable(initializer, dtype=tf.float32, name="v1")
|
||||
expected_device = "/device:CPU:100"
|
||||
expected_group_v1 = [b"loc:@v1"]
|
||||
self.assertEqual(expected_device, v1.op.device)
|
||||
self.assertEqual(expected_group_v1, v1.op.colocation_groups())
|
||||
for i in v1.initializer.inputs:
|
||||
self.assertEqual(expected_device, i.op.device)
|
||||
self.assertEqual(expected_group_v1, i.op.colocation_groups())
|
||||
|
||||
v2 = tf.Variable(initializer, dtype=tf.float32, name="v2")
|
||||
expected_group_v2 = [b"loc:@v2"]
|
||||
self.assertEqual(expected_group_v2, v2.op.colocation_groups())
|
||||
for i in v2.initializer.inputs:
|
||||
self.assertEqual(expected_group_v2, i.op.colocation_groups())
|
||||
|
||||
|
||||
class IsInitializedTest(tf.test.TestCase):
|
||||
|
||||
|
@ -167,19 +167,22 @@ def create_partitioned_variables(
|
||||
slice_offset[slice_dim] += var_shape[slice_dim]
|
||||
|
||||
if callable(initializer):
|
||||
init_val = initializer(var_shape, dtype=dtype)
|
||||
init_val = ops.convert_to_tensor(init_val, dtype=dtype)
|
||||
init = initializer
|
||||
init_shape = var_shape
|
||||
elif isinstance(initializer, ops.Tensor):
|
||||
init_val = array_ops.slice(initializer, var_offset, var_shape)
|
||||
init = array_ops.slice(initializer, var_offset, var_shape)
|
||||
# Use the dtype of the given tensor.
|
||||
dtype = init_val.dtype.base_dtype
|
||||
dtype = init.dtype.base_dtype
|
||||
init_shape = None
|
||||
else:
|
||||
init_val = ops.convert_to_tensor(initializer, dtype=dtype)
|
||||
init_val = array_ops.slice(init_val, var_offset, var_shape)
|
||||
init = ops.convert_to_tensor(initializer, dtype=dtype)
|
||||
init = array_ops.slice(init, var_offset, var_shape)
|
||||
init_shape = None
|
||||
|
||||
var = variable_scope.get_variable(name="part_%d" % i,
|
||||
shape=init_shape,
|
||||
dtype=dtype,
|
||||
initializer=init_val,
|
||||
initializer=init,
|
||||
trainable=trainable,
|
||||
collections=collections)
|
||||
|
||||
|
@ -661,6 +661,42 @@ class MultiRNNCell(RNNCell):
|
||||
return cur_inp, array_ops.concat(1, new_states)
|
||||
|
||||
|
||||
class SlimRNNCell(RNNCell):
|
||||
"""A simple wrapper for slim.rnn_cells."""
|
||||
|
||||
def __init__(self, cell_fn):
|
||||
"""Create a SlimRNNCell from a cell_fn.
|
||||
|
||||
Args:
|
||||
cell_fn: a function which takes (inputs, state, scope) and produces the
|
||||
outputs and the new_state. Additionally when called with inputs=None and
|
||||
state=None it should return (initial_outputs, initial_state).
|
||||
|
||||
Raises:
|
||||
TypeError: if cell_fn is not callable
|
||||
ValueError: if cell_fn cannot produce a valid initial state.
|
||||
"""
|
||||
if not callable(cell_fn):
|
||||
raise TypeError("cell_fn %s needs to be callable", cell_fn)
|
||||
self._cell_fn = cell_fn
|
||||
self._cell_name = cell_fn.func.__name__
|
||||
_, init_state = self._cell_fn(None, None)
|
||||
state_shape = init_state.get_shape()
|
||||
self._state_size = state_shape.with_rank(2)[1].value
|
||||
if self._state_size is None:
|
||||
raise ValueError("Initial state created by %s has invalid shape %s",
|
||||
self._cell_name, state_shape)
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._state_size
|
||||
|
||||
def __call__(self, inputs, state, scope=None):
|
||||
scope = scope or self._cell_name
|
||||
output, state = self._cell_fn(inputs, state, scope=scope)
|
||||
return output, state
|
||||
|
||||
|
||||
def linear(args, output_size, bias, bias_start=0.0, scope=None):
|
||||
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
|
||||
|
||||
|
@ -144,14 +144,19 @@ class _VariableStore(object):
|
||||
with ops.control_dependencies(None):
|
||||
if initializing_from_value:
|
||||
init_val = initializer
|
||||
variable_dtype = None
|
||||
else:
|
||||
with ops.name_scope(name + "/Initializer/"):
|
||||
init_val = initializer(shape.as_list(), dtype=dtype)
|
||||
init_val = lambda: initializer(shape.as_list(), dtype=dtype)
|
||||
variable_dtype = dtype.base_dtype
|
||||
|
||||
# Create the variable.
|
||||
v = variables.Variable(init_val, name=name, trainable=trainable,
|
||||
v = variables.Variable(initial_value=init_val,
|
||||
name=name,
|
||||
trainable=trainable,
|
||||
collections=collections,
|
||||
caching_device=caching_device)
|
||||
caching_device=caching_device,
|
||||
dtype=variable_dtype)
|
||||
|
||||
self._vars[name] = v
|
||||
logging.info("Created variable %s with shape %s and init %s", v.name,
|
||||
format(shape), initializer)
|
||||
|
@ -156,9 +156,12 @@ class Variable(object):
|
||||
variable to its initial value.
|
||||
|
||||
Args:
|
||||
initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
|
||||
The initial value for the Variable. Must have a shape specified unless
|
||||
`validate_shape` is set to False.
|
||||
initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
|
||||
which is the initial value for the Variable. The initial value must have
|
||||
a shape specified unless `validate_shape` is set to False. Can also be a
|
||||
callable with no argument that returns the initial value when called. In
|
||||
that case, `dtype` must be specified. (Note that initializer functions
|
||||
from init_ops.py must first be bound to a shape before being used here.)
|
||||
trainable: If `True`, the default, also adds the variable to the graph
|
||||
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
|
||||
the default list of variables to use by the `Optimizer` classes.
|
||||
@ -211,9 +214,12 @@ class Variable(object):
|
||||
"""Creates a new variable from arguments.
|
||||
|
||||
Args:
|
||||
initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
|
||||
The initial value for the Variable. Must have a shape specified unless
|
||||
`validate_shape` is set to False.
|
||||
initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
|
||||
which is the initial value for the Variable. The initial value must have
|
||||
a shape specified unless `validate_shape` is set to False. Can also be a
|
||||
callable with no argument that returns the initial value when called. In
|
||||
that case, `dtype` must be specified. (Note that initializer functions
|
||||
from init_ops.py must first be bound to a shape before being used here.)
|
||||
trainable: If `True`, the default, also adds the variable to the graph
|
||||
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
|
||||
the default list of variables to use by the `Optimizer` classes.
|
||||
@ -240,25 +246,62 @@ class Variable(object):
|
||||
"""
|
||||
if initial_value is None:
|
||||
raise ValueError("initial_value must be specified.")
|
||||
init_from_fn = callable(initial_value)
|
||||
if init_from_fn and dtype is None:
|
||||
raise ValueError(
|
||||
"dtype must also be specified when initial_value is callable.")
|
||||
|
||||
if collections is None:
|
||||
collections = [ops.GraphKeys.VARIABLES]
|
||||
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
|
||||
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
|
||||
with ops.control_dependencies(None):
|
||||
with ops.op_scope([initial_value], name, "Variable") as name:
|
||||
self._initial_value = ops.convert_to_tensor(initial_value,
|
||||
name="initial_value",
|
||||
dtype=dtype)
|
||||
initial_value_shape = self._initial_value.get_shape()
|
||||
if validate_shape and not initial_value_shape.is_fully_defined():
|
||||
raise ValueError("initial_value must have a shape specified: %s"
|
||||
% self._initial_value)
|
||||
shape_to_set = initial_value_shape if validate_shape else []
|
||||
with ops.op_scope(
|
||||
[] if init_from_fn else [initial_value], name, "Variable") as name:
|
||||
|
||||
self._variable = state_ops.variable_op(
|
||||
shape_to_set, self._initial_value.dtype.base_dtype,
|
||||
set_shape=validate_shape, name=name)
|
||||
# Get the initial value from a callable function. The real shape of the
|
||||
# variable will be set later, since under the init_from_fn case, the
|
||||
# shape won't be known until after the function is invoked.
|
||||
if init_from_fn:
|
||||
self._variable = state_ops.variable_op(
|
||||
[],
|
||||
dtype.base_dtype,
|
||||
set_shape=False,
|
||||
name=name)
|
||||
with ops.colocate_with(self._variable.op):
|
||||
with ops.name_scope("Initializer"):
|
||||
# Colocate the tensors created by the initial_value() function
|
||||
# with the variable itself.
|
||||
self._initial_value = ops.convert_to_tensor(initial_value(),
|
||||
name="initial_value",
|
||||
dtype=dtype)
|
||||
|
||||
# Or get the initial value from a Tensor or Python object.
|
||||
else:
|
||||
self._initial_value = ops.convert_to_tensor(initial_value,
|
||||
name="initial_value",
|
||||
dtype=dtype)
|
||||
# In this case, the variable op can't be created until after the
|
||||
# initial_value has been converted to a Tensor with a known type.
|
||||
self._variable = state_ops.variable_op(
|
||||
[],
|
||||
self._initial_value.dtype.base_dtype,
|
||||
set_shape=False,
|
||||
name=name)
|
||||
|
||||
# Manually overrides the variable's shape with the initial value's.
|
||||
if validate_shape:
|
||||
initial_value_shape = self._initial_value.get_shape()
|
||||
if not initial_value_shape.is_fully_defined():
|
||||
raise ValueError("initial_value must have a shape specified: %s"
|
||||
% self._initial_value)
|
||||
self._variable.set_shape(initial_value_shape)
|
||||
# TODO(b/28152992): Remove the below hack modifying the node_def shape
|
||||
# directly once set_shape() handles it.
|
||||
self._variable.op.node_def.attr["shape"].shape.CopyFrom(
|
||||
initial_value_shape.as_proto())
|
||||
|
||||
# Assigns initial value.
|
||||
with ops.colocate_with(self._variable.op):
|
||||
self._initializer_op = state_ops.assign(
|
||||
self._variable, self._initial_value,
|
||||
|
@ -79,3 +79,7 @@ def get_path_to_datafile(path):
|
||||
"""
|
||||
data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1)))
|
||||
return os.path.join(data_files_path, path)
|
||||
|
||||
def readahead_file_path(path, unused_readahead=None):
|
||||
"""Readahead files not implemented; simply returns given path."""
|
||||
return path
|
||||
|
@ -22,6 +22,7 @@ from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import logging
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@ -31,6 +32,7 @@ class EventFileLoader(object):
|
||||
def __init__(self, file_path):
|
||||
if file_path is None:
|
||||
raise ValueError('A file path is required')
|
||||
file_path = resource_loader.readahead_file_path(file_path)
|
||||
logging.debug('Opening a record reader pointing at %s', file_path)
|
||||
self._reader = pywrap_tensorflow.PyRecordReader_New(
|
||||
compat.as_bytes(file_path), 0)
|
||||
|
@ -326,19 +326,29 @@ class Supervisor(object):
|
||||
self._init_global_step(global_step=global_step)
|
||||
self._graph = graph
|
||||
self._is_chief = is_chief
|
||||
self._logdir = logdir
|
||||
self._save_summaries_secs = save_summaries_secs
|
||||
self._save_model_secs = save_model_secs
|
||||
self._recovery_wait_secs = recovery_wait_secs
|
||||
self._coord = coordinator.Coordinator()
|
||||
if logdir:
|
||||
self._started_threads = []
|
||||
self._recovery_wait_secs = recovery_wait_secs
|
||||
|
||||
# Only chief supervisors write event files, so only chief supervisors
|
||||
# should have event-writing properties. Set to None for non-chiefs.
|
||||
if self._is_chief:
|
||||
self._logdir = logdir
|
||||
self._save_summaries_secs = save_summaries_secs
|
||||
self._save_model_secs = save_model_secs
|
||||
else:
|
||||
self._logdir = None
|
||||
self._save_summaries_secs = None
|
||||
self._save_model_secs = None
|
||||
|
||||
if self._is_chief and self._logdir:
|
||||
self._save_path = os.path.join(self._logdir, checkpoint_basename)
|
||||
self._summary_writer = summary_io.SummaryWriter(self._logdir)
|
||||
else:
|
||||
self._save_path = None
|
||||
self._summary_writer = None
|
||||
|
||||
self._init_session_manager(session_manager=session_manager)
|
||||
self._started_threads = []
|
||||
self._verify_setup()
|
||||
# The graph is not allowed to change anymore.
|
||||
graph.finalize()
|
||||
@ -520,7 +530,7 @@ class Supervisor(object):
|
||||
|
||||
@property
|
||||
def summary_writer(self):
|
||||
"""Return the SummaryWriter used by the supervisor.
|
||||
"""Return the SummaryWriter used by the chief supervisor.
|
||||
|
||||
Returns:
|
||||
A SummaryWriter.
|
||||
@ -529,7 +539,7 @@ class Supervisor(object):
|
||||
|
||||
@property
|
||||
def summary_op(self):
|
||||
"""Return the Summary Tensor used by the supervisor.
|
||||
"""Return the Summary Tensor used by the chief supervisor.
|
||||
|
||||
Returns:
|
||||
A string Tensor for the summary or `None`.
|
||||
@ -583,8 +593,7 @@ class Supervisor(object):
|
||||
|
||||
def _write_graph(self):
|
||||
"""Writes graph_def to `logdir` and adds it to summary if applicable."""
|
||||
if not self._is_chief:
|
||||
return
|
||||
assert self._is_chief
|
||||
if self._logdir:
|
||||
training_util.write_graph(self._graph.as_graph_def(),
|
||||
self._logdir, "graph.pbtxt")
|
||||
@ -610,11 +619,13 @@ class Supervisor(object):
|
||||
sv.coord.Join(<list of threads>)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If called with a non-chief Supervisor.
|
||||
ValueError: If not `logdir` was passed to the constructor as the
|
||||
services need a log directory.
|
||||
"""
|
||||
if not self._is_chief:
|
||||
return
|
||||
raise RuntimeError("Only chief supervisor can start standard services. "
|
||||
"Because only cheif supervisors can write events.")
|
||||
if not self._logdir:
|
||||
logging.warning("Standard services need a 'logdir' "
|
||||
"passed to the SessionManager")
|
||||
@ -812,14 +823,18 @@ class Supervisor(object):
|
||||
TypeError: if 'summary' is not a Summary proto or a string.
|
||||
RuntimeError: if the Supervisor was created without a `logdir`.
|
||||
"""
|
||||
if not self._logdir:
|
||||
raise RuntimeError("summary_computed() requires a logdir")
|
||||
if not self._summary_writer:
|
||||
raise RuntimeError("Writing a summary requires a summary writer.")
|
||||
if global_step is None and self.global_step is not None:
|
||||
global_step = training_util.global_step(sess, self.global_step)
|
||||
if self._summary_writer:
|
||||
self._summary_writer.add_summary(summary, global_step)
|
||||
self._summary_writer.add_summary(summary, global_step)
|
||||
|
||||
def _default_global_step_tensor(self):
|
||||
"""Returns the global_step from the default graph.
|
||||
|
||||
Returns:
|
||||
The global step `Tensor` or `None`.
|
||||
"""
|
||||
try:
|
||||
gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
|
||||
if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
|
||||
|
@ -73,12 +73,11 @@ class SupervisorTest(tf.test.TestCase):
|
||||
sess.close()
|
||||
sv.stop()
|
||||
|
||||
def testSummary(self):
|
||||
def testChiefCanWriteEvents(self):
|
||||
logdir = self._TestDir("basics")
|
||||
with tf.Graph().as_default():
|
||||
const = tf.constant([1.0, 2.0, 3.0])
|
||||
summ = tf.scalar_summary(["c1", "c2", "c3"], const)
|
||||
sv = tf.train.Supervisor(logdir=logdir, summary_op=None)
|
||||
summ = tf.scalar_summary(["c1", "c2", "c3"], tf.constant([1.0, 2.0, 3.0]))
|
||||
sv = tf.train.Supervisor(is_chief=True, logdir=logdir, summary_op=None)
|
||||
sess = sv.prepare_or_wait_for_session("")
|
||||
sv.summary_computed(sess, sess.run(summ))
|
||||
sess.close()
|
||||
@ -113,13 +112,31 @@ class SupervisorTest(tf.test.TestCase):
|
||||
# We should be done.
|
||||
self.assertRaises(StopIteration, lambda: next(rr))
|
||||
|
||||
def testNonChiefCannotWriteEvents(self):
|
||||
|
||||
def _summary_computed():
|
||||
with tf.Graph().as_default():
|
||||
sv = tf.train.Supervisor(is_chief=False)
|
||||
sess = sv.prepare_or_wait_for_session("")
|
||||
summ = tf.scalar_summary(["c1", "c2"], tf.constant([1.0, 2.0]))
|
||||
sv.summary_computed(sess, sess.run(summ))
|
||||
|
||||
def _start_standard_services():
|
||||
with tf.Graph().as_default():
|
||||
sv = tf.train.Supervisor(is_chief=False)
|
||||
sess = sv.prepare_or_wait_for_session("")
|
||||
sv.start_standard_services(sess)
|
||||
|
||||
self.assertRaises(RuntimeError, _summary_computed)
|
||||
self.assertRaises(RuntimeError, _start_standard_services)
|
||||
|
||||
def testNoLogdirButWantSummary(self):
|
||||
with tf.Graph().as_default():
|
||||
const = tf.constant([1.0, 2.0, 3.0])
|
||||
summ = tf.scalar_summary(["c1", "c2", "c3"], const)
|
||||
sv = tf.train.Supervisor(logdir="", summary_op=None)
|
||||
sess = sv.prepare_or_wait_for_session("")
|
||||
with self.assertRaisesRegexp(RuntimeError, "requires a logdir"):
|
||||
with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
|
||||
sv.summary_computed(sess, sess.run(summ))
|
||||
|
||||
def testNoLogdirSucceeds(self):
|
||||
|
66
tensorflow/tools/benchmark/BUILD
Normal file
66
tensorflow/tools/benchmark/BUILD
Normal file
@ -0,0 +1,66 @@
|
||||
# Description:
|
||||
# Benchmark utility that can run on desktop and Android.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_copts")
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_model_lib",
|
||||
srcs = [
|
||||
"benchmark_model.cc",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
# This binary may be built for either desktop or Android.
|
||||
# A typical Android build command will look like the following:
|
||||
# bazel build -c opt tensorflow/core:android_tensorflow_lib \
|
||||
# --crosstool_top=//external:android/crosstool \
|
||||
# --cpu=armeabi-v7a \
|
||||
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||
#
|
||||
# NOTE: currently '-pthread' must be removed from the LINK_OPTS variable
|
||||
# in google/protobuf/BUILD to sucessfully build for Android. This is temporary
|
||||
# pending an update of the version of the protobuf library that Tensorflow
|
||||
# uses.
|
||||
cc_binary(
|
||||
name = "benchmark_model",
|
||||
copts = tf_copts(),
|
||||
linkopts = select({
|
||||
"//tensorflow:android": [
|
||||
"-pie",
|
||||
"-s",
|
||||
"-landroid",
|
||||
"-ljnigraphics",
|
||||
"-llog",
|
||||
"-lm",
|
||||
"-z defs",
|
||||
"-s",
|
||||
"-Wl,--icf=all", # Identical Code Folding
|
||||
"-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":benchmark_model_lib"],
|
||||
)
|
57
tensorflow/tools/benchmark/README.md
Normal file
57
tensorflow/tools/benchmark/README.md
Normal file
@ -0,0 +1,57 @@
|
||||
# Tensorflow Model Benchmark Tool
|
||||
|
||||
## Description
|
||||
|
||||
A simple C++ binary to benchmark a compute graph and its individual operators,
|
||||
both on desktop machines and on Android.
|
||||
|
||||
## To build/install/run
|
||||
|
||||
### On Android:
|
||||
|
||||
(1) build for your specific platform, e.g.:
|
||||
```bash
|
||||
$bazel build -c opt \
|
||||
--crosstool_top=//external:android/crosstool \
|
||||
--cpu=armeabi-v7a \
|
||||
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
tensorflow/tools/benchmark:benchmark_model
|
||||
```
|
||||
|
||||
(2) Connect your phone. Push the binary to your phone with adb push
|
||||
(make the directory if required):
|
||||
```bash
|
||||
$adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp
|
||||
```
|
||||
|
||||
(3) Push the compute graph that you need to test. For example:
|
||||
adb push tensorflow_inception_graph.pb /data/local/tmp
|
||||
|
||||
(4) Run the benchmark. For example:
|
||||
```bash
|
||||
$adb shell "/data/local/tmp/benchmark_model \
|
||||
--graph=/data/local/tmp/tensorflow_inception_graph.pb \
|
||||
--input_layer="input:0" \
|
||||
--input_layer_shape="1,224,224,3" \
|
||||
--input_layer_type="float" \
|
||||
--output_layer="output:0"
|
||||
```
|
||||
### On desktop:
|
||||
(1) build the binary
|
||||
```bash
|
||||
$bazel build -c opt tensorflow/tools/benchmark:benchmark_model
|
||||
```
|
||||
|
||||
(2) Run on your compute graph, similar to the Android case but without the need of adb shell.
|
||||
For example:
|
||||
```bash
|
||||
$bazel-bin/tensorflow/tools/benchmark/benchmark_model \
|
||||
--graph=tensorflow_inception_graph.pb \
|
||||
--input_layer="input:0" \
|
||||
--input_layer_shape="1,224,224,3" \
|
||||
--input_layer_type="float" \
|
||||
--output_layer="output:0"
|
||||
```
|
||||
|
||||
The Inception graph used as an example here may be downloaded from
|
||||
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
|
225
tensorflow/tools/benchmark/benchmark_model.cc
Normal file
225
tensorflow/tools/benchmark/benchmark_model.cc
Normal file
@ -0,0 +1,225 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// A C++ binary to benchmark a compute graph and its individual operators,
|
||||
// both on desktop machines and on Android.
|
||||
//
|
||||
// See README.md for usage instructions.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/core/util/stat_summarizer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Global variables that holds the Tensorflow classifier.
|
||||
static std::unique_ptr<tensorflow::Session> session;
|
||||
|
||||
static StatSummarizer g_stats;
|
||||
|
||||
struct Flags {
|
||||
string graph = "/data/local/tmp/tensorflow_inception_graph.pb";
|
||||
string input_layer = "input:0";
|
||||
string input_layer_shape = "1,224,224,3";
|
||||
string input_layer_type = "float";
|
||||
string output_layer = "output:0";
|
||||
int num_runs = 50;
|
||||
string run_delay = "-1.0";
|
||||
int num_threads = -1;
|
||||
};
|
||||
|
||||
static Flags* flags; // Filled in by main()
|
||||
|
||||
static bool InitializeBenchmark() {
|
||||
g_stats.Reset();
|
||||
|
||||
LOG(INFO) << "Loading Tensorflow.";
|
||||
|
||||
tensorflow::SessionOptions options;
|
||||
tensorflow::ConfigProto& config = options.config;
|
||||
if (flags->num_threads > 0) {
|
||||
config.set_intra_op_parallelism_threads(flags->num_threads);
|
||||
}
|
||||
LOG(INFO) << "Got config, " << config.device_count_size() << " devices";
|
||||
|
||||
session.reset(tensorflow::NewSession(options));
|
||||
tensorflow::GraphDef tensorflow_graph;
|
||||
Status s = ReadBinaryProto(Env::Default(), flags->graph, &tensorflow_graph);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not create Tensorflow Graph: " << s;
|
||||
return false;
|
||||
}
|
||||
|
||||
s = session->Create(tensorflow_graph);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not create Tensorflow Session: " << s;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Clear the proto to save memory space.
|
||||
tensorflow_graph.Clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool RunBenchmark() {
|
||||
DataType input_data_type;
|
||||
CHECK(DataTypeFromString(flags->input_layer_type, &input_data_type))
|
||||
<< flags->input_layer_type << " was an invalid type";
|
||||
|
||||
std::vector<int32> sizes;
|
||||
CHECK(str_util::SplitAndParseAsInts(flags->input_layer_shape, ',', &sizes))
|
||||
<< "Incorrect size string specified: " << flags->input_layer_shape;
|
||||
TensorShape input_shape;
|
||||
for (int i = 0; i < sizes.size(); ++i) {
|
||||
input_shape.AddDim(sizes[i]);
|
||||
}
|
||||
|
||||
Tensor input_tensor(input_data_type, input_shape);
|
||||
|
||||
switch (input_data_type) {
|
||||
case DT_INT32: {
|
||||
auto int_tensor = input_tensor.flat<int32>();
|
||||
int_tensor = int_tensor.constant(0.0);
|
||||
break;
|
||||
}
|
||||
case DT_FLOAT: {
|
||||
auto float_tensor = input_tensor.flat<float>();
|
||||
float_tensor = float_tensor.constant(0.0);
|
||||
break;
|
||||
}
|
||||
case DT_QUINT8: {
|
||||
auto int_tensor = input_tensor.flat<quint8>();
|
||||
int_tensor = int_tensor.constant(0.0);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported input type: " << flags->input_layer_type;
|
||||
}
|
||||
|
||||
std::vector<std::pair<string, tensorflow::Tensor> > input_tensors(
|
||||
{{flags->input_layer, input_tensor}});
|
||||
|
||||
std::vector<tensorflow::Tensor> output_tensors;
|
||||
std::vector<string> output_names({flags->output_layer});
|
||||
|
||||
tensorflow::Status s;
|
||||
|
||||
RunOptions run_options;
|
||||
run_options.set_trace_level(RunOptions::FULL_TRACE);
|
||||
RunMetadata run_metadata;
|
||||
|
||||
s = session->Run(run_options, input_tensors, output_names, {},
|
||||
&output_tensors, &run_metadata);
|
||||
|
||||
assert(run_metadata.has_step_stats());
|
||||
|
||||
const StepStats& stats = run_metadata.step_stats();
|
||||
|
||||
g_stats.ProcessStepStats(stats);
|
||||
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Error during inference: " << s;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::flags = new tensorflow::Flags();
|
||||
|
||||
const bool parse_result = tensorflow::ParseFlags(
|
||||
&argc, argv,
|
||||
{
|
||||
tensorflow::Flag("graph", &tensorflow::flags->graph),
|
||||
tensorflow::Flag("input_layer", &tensorflow::flags->input_layer),
|
||||
tensorflow::Flag("input_layer_shape",
|
||||
&tensorflow::flags->input_layer_shape),
|
||||
tensorflow::Flag("input_layer_type",
|
||||
&tensorflow::flags->input_layer_type),
|
||||
tensorflow::Flag("output_layer", &tensorflow::flags->output_layer),
|
||||
tensorflow::Flag("num_runs", &tensorflow::flags->num_runs),
|
||||
tensorflow::Flag("run_delay", &tensorflow::flags->run_delay),
|
||||
tensorflow::Flag("num_threads", &tensorflow::flags->num_threads),
|
||||
});
|
||||
|
||||
if (!parse_result) {
|
||||
LOG(ERROR) << "Error parsing command-line flags.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
::tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
if (argc > 1) {
|
||||
LOG(ERROR) << "Unknown argument " << argv[1];
|
||||
return -1;
|
||||
}
|
||||
|
||||
LOG(INFO) << "Graph: [" << tensorflow::flags->graph << "]";
|
||||
LOG(INFO) << "Input layer: [" << tensorflow::flags->input_layer << "]";
|
||||
LOG(INFO) << "Input shape: [" << tensorflow::flags->input_layer_shape << "]";
|
||||
LOG(INFO) << "Input type: [" << tensorflow::flags->input_layer_type << "]";
|
||||
LOG(INFO) << "Output layer: [" << tensorflow::flags->output_layer << "]";
|
||||
LOG(INFO) << "Num runs: [" << tensorflow::flags->num_runs << "]";
|
||||
LOG(INFO) << "Inter-run delay (seconds): [" << tensorflow::flags->run_delay
|
||||
<< "]";
|
||||
LOG(INFO) << "Num threads: [" << tensorflow::flags->num_threads << "]";
|
||||
|
||||
if (!tensorflow::InitializeBenchmark()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Convert the run_delay string into a timespec.
|
||||
const double sleep_seconds =
|
||||
std::strtod(tensorflow::flags->run_delay.c_str(), nullptr);
|
||||
timespec req;
|
||||
req.tv_sec = static_cast<time_t>(sleep_seconds);
|
||||
req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000;
|
||||
|
||||
LOG(INFO) << "Running benchmark";
|
||||
for (int i = 0; i < tensorflow::flags->num_runs; ++i) {
|
||||
if (!tensorflow::RunBenchmark()) {
|
||||
LOG(INFO) << "Failed on run " << i;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// If requested, sleep between runs for an arbitrary amount of time.
|
||||
// This can be helpful to determine the effect of mobile processor
|
||||
// scaling and thermal throttling.
|
||||
if (sleep_seconds > 0.0) {
|
||||
nanosleep(&req, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::g_stats.PrintStepStats();
|
||||
return 0;
|
||||
}
|
@ -139,6 +139,16 @@ else
|
||||
# Assume: PYTHON_BIN_PATH is exported by the script above
|
||||
fi
|
||||
|
||||
# Obtain the path to head/ghead binary (for log file printing)
|
||||
HEAD_BIN="ghead"
|
||||
if [[ -z $(which "${HEAD_BIN}") ]]; then
|
||||
# This is not Mac (which uses coreutils/ghead), use head.
|
||||
HEAD_BIN="head"
|
||||
if [[ -z $(which "${HEAD_BIN}") ]]; then
|
||||
die "Unable to obtain path to head or ghead"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "${PYTHON_BIN_PATH}" ]]; then
|
||||
die "PYTHON_BIN_PATH was not provided. If this is not virtualenv, "\
|
||||
"did you run configure?"
|
||||
@ -371,7 +381,7 @@ while true; do
|
||||
|
||||
echo " Log @: ${TEST_LOGS[K]}"
|
||||
echo "============== BEGINS failure log content =============="
|
||||
head --lines=-1 "${TEST_LOGS[K]}"
|
||||
"${HEAD_BIN}" --lines=-1 "${TEST_LOGS[K]}"
|
||||
echo "============== ENDS failure log content =============="
|
||||
echo ""
|
||||
fi
|
||||
|
Loading…
x
Reference in New Issue
Block a user