Merge pull request from caisq/r0.8-tensorforest-2

R0.8 tensorforest cherry-pick
Martin Wicke 2016-04-19 13:51:30 -07:00
commit dc19800ee1
34 changed files with 1160 additions and 442 deletions

View File

@ -18,6 +18,7 @@
// only op that involves tree traversal, and is constructed so that it can
// be run in parallel on separate batches of data.
#include <unordered_map>
#include <vector>
#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
@ -25,10 +26,12 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
using std::get;
using std::make_pair;
using std::make_tuple;
using std::pair;
using std::tuple;
@ -42,6 +45,71 @@ using tensorforest::DecideNode;
using tensorforest::Initialize;
using tensorforest::IsAllInitialized;
// A data structure to store the results of parallel tree traversal.
struct InputDataResult {
// A list of each node that was visited.
std::vector<int32> node_indices;
// The accumulator of the leaf that a data point ended up at, or -1 if none.
int32 leaf_accumulator;
// The left-branch taken candidate splits.
std::vector<int32> split_adds;
// Whether the candidate splits for the leaf that a data point arrived
// at were initialized, which determines whether we add this point to
// the total pcw counts.
bool splits_initialized;
};
void Evaluate(const Tensor& input_data, const Tensor& input_labels,
const Tensor& tree_tensor, const Tensor& tree_thresholds,
const Tensor& node_to_accumulator,
const Tensor& candidate_split_features,
const Tensor& candidate_split_thresholds,
InputDataResult* results, int64 start, int64 end) {
const auto tree = tree_tensor.tensor<int32, 2>();
const auto thresholds = tree_thresholds.unaligned_flat<float>();
const auto node_map = node_to_accumulator.unaligned_flat<int32>();
const auto split_features = candidate_split_features.tensor<int32, 2>();
const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
const int32 num_splits = candidate_split_features.shape().dim_size(1);
for (int i = start; i < end; ++i) {
const Tensor point = input_data.Slice(i, i + 1);
int node_index = 0;
results[i].splits_initialized = false;
while (true) {
results[i].node_indices.push_back(node_index);
int32 left_child = tree(node_index, CHILDREN_INDEX);
if (left_child == LEAF_NODE) {
const int32 accumulator = node_map(node_index);
results[i].leaf_accumulator = accumulator;
// If the leaf is not fertile or is not yet initialized, we don't
// count it in the candidate/total split per-class-weights because
// it won't have any candidate splits yet.
if (accumulator >= 0 &&
IsAllInitialized(candidate_split_features.Slice(
accumulator, accumulator + 1))) {
results[i].splits_initialized = true;
for (int split = 0; split < num_splits; split++) {
if (!DecideNode(point, split_features(accumulator, split),
split_thresholds(accumulator, split))) {
results[i].split_adds.push_back(split);
}
}
}
break;
} else if (left_child == FREE_NODE) {
LOG(ERROR) << "Reached a free node, not good.";
results[i].node_indices.push_back(FREE_NODE);
break;
}
node_index =
left_child + DecideNode(point, tree(node_index, FEATURE_INDEX),
thresholds(node_index));
}
}
}
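Evaluate is deliberately range-parallel: each call handles the half-open batch range [start, end), so the kernel below can fan it out over worker threads. A rough Python sketch of that pattern (the real Shard helper is C++ and also weighs a per-unit cost estimate; the even split here is a simplification):

from concurrent.futures import ThreadPoolExecutor

def evaluate_sharded(evaluate_fn, num_data, num_threads):
  # Split [0, num_data) into contiguous blocks, one per thread, and run
  # evaluate_fn(start, end) on each block concurrently. evaluate_fn
  # stands in for the Evaluate function above.
  block = (num_data + num_threads - 1) // num_threads
  with ThreadPoolExecutor(max_workers=num_threads) as pool:
    for start in range(0, num_data, block):
      pool.submit(evaluate_fn, start, min(start + block, num_data))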
REGISTER_OP("CountExtremelyRandomStats")
.Attr("num_classes: int32")
.Input("input_data: float")
@ -79,9 +147,9 @@ REGISTER_OP("CountExtremelyRandomStats")
gives the j-th feature of the i-th input.
input_labels: The training batch's labels; `input_labels[i]` is the class
of the i-th input.
tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
of the i-th node, `tree[0][i] + 1` gives the index of the right child of
the i-th node, and `tree[1][i]` gives the index of the feature used to
tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
of the i-th node, `tree[i][0] + 1` gives the index of the right child of
the i-th node, and `tree[i][1]` gives the index of the feature used to
split the i-th node.
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
node.
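To make the tree layout concrete, here is a small NumPy sketch of how one data point walks this encoding. This is an illustration only, not the op's code; the `>=` comparison standing in for DecideNode is an assumption, and -1 marks a leaf, matching LEAF_NODE elsewhere in this change:

import numpy as np

LEAF_NODE = -1

def traverse(tree, thresholds, point):
  """Returns the index of the leaf node that `point` lands in."""
  node = 0
  while tree[node][0] != LEAF_NODE:
    left_child = tree[node][0]   # right child is left_child + 1
    feature = tree[node][1]
    node = left_child + int(point[feature] >= thresholds[node])
  return node

# Root (node 0) splits on feature 0 at 0.0; nodes 1 and 2 are leaves.
tree = np.array([[1, 0], [LEAF_NODE, 0], [LEAF_NODE, 0]])
thresholds = np.array([0.0, 0.0, 0.0])
assert traverse(tree, thresholds, np.array([-1.0])) == 1
assert traverse(tree, thresholds, np.array([2.0])) == 2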
@ -176,7 +244,31 @@ class CountExtremelyRandomStats : public OpKernel {
"candidate_split_features and candidate_split_thresholds should be "
"the same shape."));
const int32 num_splits = candidate_split_features.shape().dim_size(1);
// Evaluate input data in parallel.
const int64 num_data = input_data.shape().dim_size(0);
std::unique_ptr<InputDataResult[]> results(new InputDataResult[num_data]);
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
if (num_threads <= 1) {
Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
node_to_accumulator, candidate_split_features,
candidate_split_thresholds, results.get(), 0, num_data);
} else {
auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds,
&node_to_accumulator, &candidate_split_features,
&candidate_split_thresholds, &num_data,
&results](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
node_to_accumulator, candidate_split_features,
candidate_split_thresholds, results.get(), start, end);
};
Shard(num_threads, worker_threads->workers, num_data, 100, work);
}
// Set output tensors.
const auto labels = input_labels.unaligned_flat<int32>();
// node pcw delta
Tensor* output_node_pcw_delta = nullptr;
@ -196,58 +288,28 @@ class CountExtremelyRandomStats : public OpKernel {
&output_leaves));
auto out_leaves = output_leaves->unaligned_flat<int32>();
const auto tree = tree_tensor.tensor<int32, 2>();
const auto thresholds = tree_thresholds.unaligned_flat<float>();
const auto labels = input_labels.unaligned_flat<int32>();
const auto node_map = node_to_accumulator.unaligned_flat<int32>();
const auto split_features = candidate_split_features.tensor<int32, 2>();
const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
const int32 num_data = input_data.shape().dim_size(0);
// <accumulator, class> -> count delta
std::unordered_map<pair<int32, int32>, int32, PairIntHash> total_delta;
// <accumulator, split, class> -> count delta
std::unordered_map<tuple<int32, int32, int32>,
int32, TupleIntHash> split_delta;
for (int i = 0; i < num_data; i++) {
const Tensor point = input_data.Slice(i, i+1);
int node_index = 0;
while (true) {
const int32 label = labels(i);
++out_node(node_index, label);
int32 left_child = tree(node_index, CHILDREN_INDEX);
if (left_child == LEAF_NODE) {
out_leaves(i) = node_index;
const int32 accumulator = node_map(node_index);
// If the leaf is not fertile or is not yet initialized, we don't
// count it in the candidate/total split per-class-weights because
// it won't have any candidate splits yet.
if (accumulator >= 0 &&
IsAllInitialized(
candidate_split_features.Slice(accumulator,
accumulator + 1))) {
++total_delta[std::make_pair(accumulator, label)];
for (int split = 0; split < num_splits; split++) {
if (!DecideNode(point, split_features(accumulator, split),
split_thresholds(accumulator, split))) {
++split_delta[make_tuple(accumulator, split, label)];
}
}
}
break;
} else if (left_child == FREE_NODE) {
LOG(ERROR) << "Reached a free node, not good.";
out_leaves(i) = FREE_NODE;
break;
for (int32 i = 0; i < num_data; ++i) {
const int32 label = labels(i);
const int32 accumulator = results[i].leaf_accumulator;
for (const int32 node : results[i].node_indices) {
++out_node(node, label);
}
out_leaves(i) = results[i].node_indices.back();
if (accumulator >= 0 && results[i].splits_initialized) {
++total_delta[make_pair(accumulator, label)];
for (const int32 split : results[i].split_adds) {
++split_delta[make_tuple(accumulator, split, label)];
}
node_index = left_child +
DecideNode(point, tree(node_index, FEATURE_INDEX),
thresholds(node_index));
}
}
// candidate splits pcw indices
Tensor* output_candidate_pcw_indices = nullptr;
TensorShape candidate_pcw_shape;
candidate_pcw_shape.AddDim(split_delta.size());

View File

@ -94,7 +94,7 @@ class SampleInputs : public OpKernel {
"split_sampling_random_seed", &split_sampling_random_seed_));
// Set up the random number generator.
if (split_sampling_random_seed_ == 0) {
uint64 time_seed = static_cast<uint64>(std::time(NULL));
uint64 time_seed = static_cast<uint64>(std::clock());
single_rand_ = std::unique_ptr<random::PhiloxRandom>(
new random::PhiloxRandom(time_seed));
} else {

View File

@ -44,9 +44,9 @@ REGISTER_OP("TreePredictions")
input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
gives the j-th feature of the i-th input.
tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
of the i-th node, `tree[0][i] + 1` gives the index of the right child of
the i-th node, and `tree[1][i]` gives the index of the feature used to
tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
of the i-th node, `tree[i][0] + 1` gives the index of the right child of
the i-th node, and `tree[i][1]` gives the index of the feature used to
split the i-th node.
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
node.

View File

@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow # pylint: disable=unused-import
import tensorflow as tf
from tensorflow.contrib.tensor_forest.python.ops import training_ops
@ -47,6 +47,29 @@ class CountExtremelyRandomStatsTest(test_util.TensorFlowTestCase):
self.tree_thresholds, self.node_map,
self.split_features, self.split_thresholds, num_classes=4))
self.assertAllEqual(
[[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 1., 1.]],
pcw_node.eval())
self.assertAllEqual([[0, 0, 0]], pcw_splits_indices.eval())
self.assertAllEqual([1.], pcw_splits_delta.eval())
self.assertAllEqual([[0, 1], [0, 0]], pcw_totals_indices.eval())
self.assertAllEqual([1., 1.], pcw_totals_delta.eval())
self.assertAllEqual([1, 1, 2, 2], leaves.eval())
def testThreaded(self):
with self.test_session(
config=tf.ConfigProto(intra_op_parallelism_threads=2)):
(pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
pcw_totals_delta,
leaves) = (self.ops.count_extremely_random_stats(self.input_data,
self.input_labels,
self.tree,
self.tree_thresholds,
self.node_map,
self.split_features,
self.split_thresholds,
num_classes=4))
self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.],
[0., 0., 1., 1.]],
pcw_node.eval())

View File

@ -49,12 +49,13 @@ def TreePredictions(op):
# there's not yet any guarantee that the shared object exists.
# In which case, "import tensorflow" will always crash, even for users that
# never use contrib.
def Load():
def Load(library_base_dir=''):
"""Load the inference ops library and return the loaded module."""
with _ops_lock:
global _inference_ops
if not _inference_ops:
data_files_path = tf.resource_loader.get_data_files_path()
data_files_path = os.path.join(library_base_dir,
tf.resource_loader.get_data_files_path())
tf.logging.info('data path: %s', data_files_path)
_inference_ops = tf.load_op_library(os.path.join(
data_files_path, INFERENCE_OPS_FILE))

View File

@ -25,6 +25,7 @@ import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
TRAINING_OPS_FILE = '_training_ops.so'
_training_ops = None
@ -96,12 +97,13 @@ def _UpdateFertileSlotsShape(unused_op):
# there's not yet any guarantee that the shared object exists.
# In which case, "import tensorflow" will always crash, even for users that
# never use contrib.
def Load():
def Load(library_base_dir=''):
"""Load training ops library and return the loaded module."""
with _ops_lock:
global _training_ops
if not _training_ops:
data_files_path = tf.resource_loader.get_data_files_path()
data_files_path = os.path.join(library_base_dir,
tf.resource_loader.get_data_files_path())
tf.logging.info('data path: %s', data_files_path)
_training_ops = tf.load_op_library(os.path.join(
data_files_path, TRAINING_OPS_FILE))

View File

@ -25,19 +25,6 @@ from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.contrib.tensor_forest.python.ops import training_ops
flags = tf.app.flags
FLAGS = flags.FLAGS
# Default parameter values. These are all only used if the corresponding
# parameter is not specified when constructing the ForestHParams.
flags.DEFINE_integer('num_trees', 100, 'Number of trees in forest')
flags.DEFINE_integer('max_nodes', 10000, 'Maximum number of tree nodes.')
flags.DEFINE_float(
'samples_to_decide', 25.0,
'Only decide on a split, or only fully use a leaf, after this many '
'training samples have been seen.')
# If tree[i][0] equals this value, then i is a leaf node.
LEAF_NODE = -1
@ -57,7 +44,20 @@ LEAF_NODE = -1
class ForestHParams(object):
"""A base class for holding hyperparameters and calculating good defaults."""
def __init__(self, **kwargs):
def __init__(self, num_trees=100, max_nodes=10000, bagging_fraction=1.0,
samples_to_decide=25, max_depth=0, num_splits_to_consider=0,
max_fertile_nodes=0, split_after_samples=0,
valid_leaf_threshold=0, **kwargs):
self.num_trees = num_trees
self.max_nodes = max_nodes
self.bagging_fraction = bagging_fraction
self.samples_to_decide = samples_to_decide
self.max_depth = max_depth
self.num_splits_to_consider = num_splits_to_consider
self.max_fertile_nodes = max_fertile_nodes
self.split_after_samples = split_after_samples
self.valid_leaf_threshold = valid_leaf_threshold
for name, value in kwargs.items():
setattr(self, name, value)
@ -69,19 +69,21 @@ class ForestHParams(object):
# Fail fast if num_classes isn't set.
_ = getattr(self, 'num_classes')
self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees)
self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes)
self.training_library_base_dir = getattr(
self, 'training_library_base_dir', '')
self.inference_library_base_dir = getattr(
self, 'inference_library_base_dir', '')
# Allow each tree to be unbalanced by up to a factor of 2.
self.max_depth = getattr(self, 'max_depth',
int(2 * math.ceil(math.log(self.max_nodes, 2))))
self.max_depth = (self.max_depth or
int(2 * math.ceil(math.log(self.max_nodes, 2))))
# The Random Forest literature recommends sqrt(# features) for
# classification problems, and p/3 for regression problems.
# TODO(thomaswc): Consider capping this for large number of features.
if not getattr(self, 'num_splits_to_consider', None):
self.num_splits_to_consider = max(10, int(
math.ceil(math.sqrt(self.num_features))))
self.num_splits_to_consider = (
self.num_splits_to_consider or
max(10, int(math.ceil(math.sqrt(self.num_features)))))
# max_fertile_nodes doesn't affect performance, only training speed.
# We therefore set it primarily based upon space considerations.
@ -91,22 +93,19 @@ class ForestHParams(object):
num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider))
# But always use at least 1000 accumulate slots.
num_fertile = max(num_fertile, 1000)
self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile)
self.max_fertile_nodes = self.max_fertile_nodes or num_fertile
# But it also never needs to be larger than the number of leaves,
# which is max_nodes / 2.
self.max_fertile_nodes = min(self.max_nodes,
int(math.ceil(self.max_fertile_nodes / 2.0)))
self.max_fertile_nodes = min(self.max_fertile_nodes,
int(math.ceil(self.max_nodes / 2.0)))
# split_after_samples and valid_leaf_threshold should be about the same.
# Therefore, if either is set, use it to set the other. Otherwise, fall
# back on FLAGS.samples_to_decide.
samples_to_decide = (
getattr(self, 'split_after_samples',
getattr(self, 'valid_leaf_threshold', FLAGS.samples_to_decide)))
self.split_after_samples = getattr(self, 'split_after_samples',
samples_to_decide)
self.valid_leaf_threshold = getattr(self, 'valid_leaf_threshold',
samples_to_decide)
# back on samples_to_decide.
samples_to_decide = self.split_after_samples or self.samples_to_decide
self.split_after_samples = self.split_after_samples or samples_to_decide
self.valid_leaf_threshold = self.valid_leaf_threshold or samples_to_decide
# We have num_splits_to_consider slots to fill, and we want to spend
# approximately split_after_samples samples initializing them.
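The defaults derived in fill() fit in a few lines. The sketch below reproduces the values asserted by the unit tests in this change; the split_initializations_per_input formula is inferred from those tests rather than shown in this hunk:

import math

def fill_defaults(max_nodes=1000000, num_features=1000, split_after_samples=25):
  max_depth = int(2 * math.ceil(math.log(max_nodes, 2)))           # 40
  num_splits = max(10, int(math.ceil(math.sqrt(num_features))))    # 32
  num_fertile = max(int(math.ceil(max_nodes / num_splits)), 1000)  # 31250
  max_fertile = min(num_fertile, int(math.ceil(max_nodes / 2.0)))  # 31250
  split_inits = max(1, int(num_splits / split_after_samples))      # 1
  return max_depth, num_splits, max_fertile, split_inits

assert fill_defaults() == (40, 32, 31250, 1)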
@ -184,23 +183,6 @@ class TreeStats(object):
self.num_leaves = num_leaves
def get_tree_stats(variables, unused_params, session):
num_nodes = variables.end_of_tree.eval(session=session) - 1
num_leaves = tf.where(
tf.equal(tf.squeeze(tf.slice(variables.tree, [0, 0], [-1, 1])),
LEAF_NODE)).eval(session=session).shape[0]
return TreeStats(num_nodes, num_leaves)
def get_forest_stats(variables, params, session):
tree_stats = []
for i in range(params.num_trees):
tree_stats.append(get_tree_stats(variables[i], params, session))
return ForestStats(tree_stats, params)
class ForestTrainingVariables(object):
"""A container for a forests training data, consisting of multiple trees.
@ -212,9 +194,11 @@ class ForestTrainingVariables(object):
... forest_variables.tree ...
"""
def __init__(self, params):
self.variables = [TreeTrainingVariables(params)
for _ in range(params.num_trees)]
def __init__(self, params, device_assigner):
self.variables = []
for i in range(params.num_trees):
with tf.device(device_assigner.get_device(i)):
self.variables.append(TreeTrainingVariables(params))
def __setitem__(self, t, val):
self.variables[t] = val
@ -223,15 +207,41 @@ class ForestTrainingVariables(object):
return self.variables[t]
class RandomForestDeviceAssigner(object):
"""A device assigner that uses the default device.
Write subclasses that implement get_device for control over how trees
get assigned to devices. This assumes that whole trees are assigned
to a device.
"""
def __init__(self):
self.cached = None
def get_device(self, unused_tree_num):
if not self.cached:
dummy = tf.constant(0)
self.cached = dummy.device
return self.cached
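As the docstring says, placement policy lives entirely in get_device. A hypothetical subclass (not part of this commit) that round-robins whole trees over an explicit device list:

class RoundRobinDeviceAssigner(object):
  """Assigns tree i to devices[i % len(devices)] (a sketch)."""

  def __init__(self, devices):
    self.devices = devices

  def get_device(self, tree_num):
    return self.devices[tree_num % len(self.devices)]

# Hypothetical usage:
#   RandomForestGraphs(params,
#       device_assigner=RoundRobinDeviceAssigner(['/cpu:0', '/cpu:1']))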
class RandomForestGraphs(object):
"""Builds TF graphs for random forest training and inference."""
def __init__(self, params):
def __init__(self, params, device_assigner=None, variables=None):
self.params = params
self.variables = ForestTrainingVariables(self.params)
self.trees = [RandomTreeGraphs(self.variables[i], self.params,
training_ops.Load(), inference_ops.Load())
for i in range(self.params.num_trees)]
self.device_assigner = device_assigner or RandomForestDeviceAssigner()
tf.logging.info('Constructing forest with params = ')
tf.logging.info(self.params.__dict__)
self.variables = variables or ForestTrainingVariables(
self.params, device_assigner=self.device_assigner)
self.trees = [
RandomTreeGraphs(
self.variables[i], self.params,
training_ops.Load(self.params.training_library_base_dir),
inference_ops.Load(self.params.inference_library_base_dir))
for i in range(self.params.num_trees)]
def training_graph(self, input_data, input_labels):
"""Constructs a TF graph for training a random forest.
@ -246,12 +256,26 @@ class RandomForestGraphs(object):
"""
tree_graphs = []
for i in range(self.params.num_trees):
tf.logging.info('Constructing tree %d', i)
seed = self.params.base_random_seed
if seed != 0:
seed += i
tree_graphs.append(self.trees[i].training_graph(
input_data, input_labels, seed))
with tf.device(self.device_assigner.get_device(i)):
seed = self.params.base_random_seed
if seed != 0:
seed += i
# If using bagging, randomly select some of the input.
tree_data = input_data
tree_labels = input_labels
if self.params.bagging_fraction < 1.0:
# TODO(thomaswc): This does sampling without replacement. Consider
# also allowing sampling with replacement as an option.
batch_size = tf.slice(tf.shape(input_data), [0], [1])
r = tf.random_uniform(batch_size, seed=seed)
mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction)
gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1])
# TODO(thomaswc): Calculate out-of-bag data and labels, and store
# them for use in calculating statistics later.
tree_data = tf.gather(input_data, gather_indices)
tree_labels = tf.gather(input_labels, gather_indices)
tree_graphs.append(
self.trees[i].training_graph(tree_data, tree_labels, seed))
return tf.group(*tree_graphs)
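The bagging step keeps each row independently with probability bagging_fraction, i.e. it subsamples without replacement. The same idea outside the graph, as a NumPy sketch:

import numpy as np

def bag(data, labels, bagging_fraction, rng=np.random):
  # Keep each example independently with probability bagging_fraction;
  # the dropped rows would be the out-of-bag set mentioned in the TODO.
  mask = rng.uniform(size=len(data)) < bagging_fraction
  return data[mask], labels[mask]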
def inference_graph(self, input_data):
@ -265,9 +289,23 @@ class RandomForestGraphs(object):
"""
probabilities = []
for i in range(self.params.num_trees):
probabilities.append(self.trees[i].inference_graph(input_data))
all_predict = tf.pack(probabilities)
return tf.reduce_sum(all_predict, 0) / self.params.num_trees
with tf.device(self.device_assigner.get_device(i)):
probabilities.append(self.trees[i].inference_graph(input_data))
with tf.device(self.device_assigner.get_device(0)):
all_predict = tf.pack(probabilities)
return tf.reduce_sum(all_predict, 0) / self.params.num_trees
def average_size(self):
"""Constructs a TF graph for evaluating the average size of a forest.
Returns:
The average number of nodes over the trees.
"""
sizes = []
for i in range(self.params.num_trees):
with tf.device(self.device_assigner.get_device(i)):
sizes.append(self.trees[i].size())
return tf.reduce_mean(tf.pack(sizes))
def average_impurity(self):
"""Constructs a TF graph for evaluating the leaf impurity of a forest.
@ -277,9 +315,17 @@ class RandomForestGraphs(object):
"""
impurities = []
for i in range(self.params.num_trees):
impurities.append(self.trees[i].average_impurity(self.variables[i]))
with tf.device(self.device_assigner.get_device(i)):
impurities.append(self.trees[i].average_impurity())
return tf.reduce_mean(tf.pack(impurities))
def get_stats(self, session):
tree_stats = []
for i in range(self.params.num_trees):
with tf.device(self.device_assigner.get_device(i)):
tree_stats.append(self.trees[i].get_stats(session))
return ForestStats(tree_stats, self.params)
class RandomTreeGraphs(object):
"""Builds TF graphs for random tree training and inference."""
@ -394,6 +440,7 @@ class RandomTreeGraphs(object):
with tf.control_dependencies([node_update_op]):
def f1():
return self.variables.non_fertile_leaf_scores
def f2():
counts = tf.gather(self.variables.node_per_class_weights,
self.variables.non_fertile_leaves)
@ -535,3 +582,18 @@ class RandomTreeGraphs(object):
counts = tf.gather(self.variables.node_per_class_weights, leaves)
impurity = self._weighted_gini(counts)
return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0)
def size(self):
"""Constructs a TF graph for evaluating the current number of nodes.
Returns:
The current number of nodes in the tree.
"""
return self.variables.end_of_tree - 1
def get_stats(self, session):
num_nodes = self.variables.end_of_tree.eval(session=session) - 1
num_leaves = tf.where(
tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])),
LEAF_NODE)).eval(session=session).shape[0]
return TreeStats(num_nodes, num_leaves)

View File

@ -27,6 +27,37 @@ from tensorflow.python.platform import googletest
class TensorForestTest(test_util.TensorFlowTestCase):
def testForestHParams(self):
hparams = tensor_forest.ForestHParams(
num_classes=2, num_trees=100, max_nodes=1000,
num_features=60).fill()
self.assertEquals(2, hparams.num_classes)
# 2 * ceil(log_2(1000)) = 20
self.assertEquals(20, hparams.max_depth)
# sqrt(num_features) < 10, so num_splits_to_consider should be 10.
self.assertEquals(10, hparams.num_splits_to_consider)
# Don't have more fertile nodes than max # leaves, which is 500.
self.assertEquals(500, hparams.max_fertile_nodes)
# We didn't set either of these, so they should be equal
self.assertEquals(hparams.split_after_samples,
hparams.valid_leaf_threshold)
# split_after_samples is larger than 10
self.assertEquals(1, hparams.split_initializations_per_input)
self.assertEquals(0, hparams.base_random_seed)
def testForestHParamsBigTree(self):
hparams = tensor_forest.ForestHParams(
num_classes=2, num_trees=100, max_nodes=1000000,
split_after_samples=25,
num_features=1000).fill()
self.assertEquals(40, hparams.max_depth)
# sqrt(1000) = 31.63...
self.assertEquals(32, hparams.num_splits_to_consider)
# 1000000 / 32 = 31250
self.assertEquals(31250, hparams.max_fertile_nodes)
# floor(31.63 / 25) = 1
self.assertEquals(1, hparams.split_initializations_per_input)
def testTrainingConstruction(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
@ -50,6 +81,14 @@ class TensorForestTest(test_util.TensorFlowTestCase):
graph = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(graph, tf.Tensor))
def testImpurityConstruction(self):
params = tensor_forest.ForestHParams(
num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)
graph = graph_builder.average_impurity()
self.assertTrue(isinstance(graph, tf.Tensor))
if __name__ == '__main__':
googletest.main()

View File

@ -143,7 +143,6 @@ cc_library(
"lib/core/bits.h",
"lib/core/casts.h",
"lib/core/coding.h",
"lib/core/command_line_flags.h", # TODO(vrv): Delete.
"lib/core/errors.h",
"lib/core/notification.h",
"lib/core/status.h",

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -713,4 +714,36 @@ void Tensor::FillDescription(TensorDescription* description) const {
}
}
gtl::InlinedVector<int64, 5> Tensor::ComputeFlatInnerDims(
int64 num_out_dims) const {
gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0);
const int64 num_elements = NumElements();
if (num_elements != 0) {
int64 prod_out_dims = 1;
for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) {
const int64 in_dim = out_dim + (dims() - num_out_dims);
out_dims[out_dim] =
(in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim);
prod_out_dims *= out_dims[out_dim];
}
out_dims[0] = num_elements / prod_out_dims;
}
return out_dims;
}
gtl::InlinedVector<int64, 5> Tensor::ComputeFlatOuterDims(
int64 num_out_dims) const {
gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0);
const int64 num_elements = NumElements();
if (num_elements != 0) {
int64 prod_out_dims = 1;
for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) {
out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim);
prod_out_dims *= out_dims[out_dim];
}
out_dims[num_out_dims - 1] = num_elements / prod_out_dims;
}
return out_dims;
}
} // namespace tensorflow
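The new flat_inner_dims / flat_outer_dims overloads collapse everything but the last (respectively first) NDIMS-1 dimensions, padding with 1s when NDIMS exceeds the rank. A NumPy sketch of the same semantics, checked against the shapes the tensor test below asserts:

import numpy as np

def flat_inner_dims(x, ndims=2):
  # Pad with leading 1s if ndims > rank, then collapse the leading
  # dimensions into one.
  shape = (1,) * max(0, ndims - x.ndim) + x.shape
  return x.reshape((-1,) + shape[len(shape) - (ndims - 1):])

def flat_outer_dims(x, ndims=2):
  # Pad with trailing 1s if ndims > rank, then collapse the trailing
  # dimensions into one.
  shape = x.shape + (1,) * max(0, ndims - x.ndim)
  return x.reshape(shape[:ndims - 1] + (-1,))

t = np.arange(120.0).reshape(2, 3, 4, 5)
assert flat_inner_dims(t).shape == (24, 5)
assert flat_outer_dims(t).shape == (2, 60)
assert flat_inner_dims(t, 3).shape == (6, 4, 5)
assert flat_outer_dims(t, 5).shape == (2, 3, 4, 5, 1)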

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -243,40 +244,28 @@ class Tensor {
///
/// ```
template <typename T>
typename TTypes<T>::Flat flat();
typename TTypes<T>::Flat flat() {
return shaped<T, 1>({NumElements()});
}
template <typename T>
typename TTypes<T>::UnalignedFlat unaligned_flat() {
return unaligned_shaped<T, 1>({NumElements()});
}
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
/// Tensor dimensions but the last one into the first dimension of the result.
template <typename T>
typename TTypes<T>::Matrix flat_inner_dims() {
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
if (last_size == 0) {
DCHECK_EQ(NumElements(), 0);
// Return something empty, avoiding divide by 0
return shaped<T, 2>({0, 0});
} else {
return shaped<T, 2>({NumElements() / last_size, last_size});
}
}
/// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
/// Tensor dimensions but the last NDIMS-1 into the first dimension of the
/// result. If NDIMS > dims() then leading dimensions of size 1 will be
/// added to make the output rank NDIMS.
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::Tensor flat_inner_dims();
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
/// Tensor dimensions but the first one into the last dimension of the result.
template <typename T>
typename TTypes<T>::Matrix flat_outer_dims() {
int64 first_size = dims() > 0 ? dim_size(0) : 1;
if (first_size == 0) {
DCHECK_EQ(NumElements(), 0);
// Return something empty, avoiding divide by 0
return shaped<T, 2>({0, 0});
} else {
return shaped<T, 2>({first_size, NumElements() / first_size});
}
}
/// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
/// Tensor dimensions but the first NDIMS-1 into the last dimension of the
/// result. If NDIMS > dims() then trailing dimensions of size 1 will be
/// added to make the output rank NDIMS.
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::Tensor flat_outer_dims();
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);
@ -308,31 +297,19 @@ class Tensor {
typename TTypes<T, NDIMS>::ConstTensor tensor() const;
template <typename T>
typename TTypes<T>::ConstFlat flat() const;
typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>({NumElements()});
}
template <typename T>
typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
return unaligned_shaped<T, 1>({NumElements()});
}
template <typename T>
typename TTypes<T>::ConstMatrix flat_inner_dims() const {
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
if (last_size == 0) {
DCHECK_EQ(NumElements(), 0);
// Return something empty, avoiding divide by 0
return shaped<T, 2>({0, 0});
} else {
return shaped<T, 2>({NumElements() / last_size, last_size});
}
}
template <typename T>
typename TTypes<T>::ConstMatrix flat_outer_dims() const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor shaped(
gtl::ArraySlice<int64> new_sizes) const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
gtl::ArraySlice<int64> new_sizes) const;
@ -340,6 +317,12 @@ class Tensor {
template <typename T>
typename TTypes<T>::ConstScalar scalar() const;
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const;
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const;
/// Render the first `max_entries` values in `*this` into a string.
string SummarizeValue(int64 max_entries) const;
@ -378,6 +361,8 @@ class Tensor {
void FillDimsAndValidateCompatibleShape(
gtl::ArraySlice<int64> new_sizes,
Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
gtl::InlinedVector<int64, 5> ComputeFlatInnerDims(int64 num_out_dims) const;
gtl::InlinedVector<int64, 5> ComputeFlatOuterDims(int64 num_out_dims) const;
TensorShape shape_;
TensorBuffer* buf_;
@ -534,26 +519,24 @@ typename TTypes<T>::ConstScalar Tensor::scalar() const {
return typename TTypes<T>::ConstScalar(base<T>());
}
template <typename T>
typename TTypes<T>::Flat Tensor::flat() {
return shaped<T, 1>({NumElements()});
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() {
return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS));
}
template <typename T>
typename TTypes<T>::ConstFlat Tensor::flat() const {
return shaped<T, 1>({NumElements()});
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() {
return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS));
}
template <typename T>
typename TTypes<T>::ConstMatrix Tensor::flat_outer_dims() const {
int64 first_size = dims() > 0 ? dim_size(0) : 1;
if (first_size == 0) {
DCHECK_EQ(NumElements(), 0);
// Return something empty, avoiding divide by 0
return shaped<T, 2>({0, 0});
} else {
return shaped<T, 2>({first_size, NumElements() / first_size});
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const {
return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS));
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const {
return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS));
}
} // namespace tensorflow

View File

@ -224,6 +224,49 @@ TEST(Tensor_Float, Reshape) {
EXPECT_EQ(flat_inner_dims(0, 0), 0.01f);
EXPECT_EQ(flat_inner_dims(23, 4), 0.02f);
}
{
auto flat_outer_dims = t.flat_outer_dims<float>();
EXPECT_EQ(2, flat_outer_dims.dimension(0));
EXPECT_EQ(60, flat_outer_dims.dimension(1));
EXPECT_EQ(flat_outer_dims(0, 0), 0.01f);
EXPECT_EQ(flat_outer_dims(1, 59), 0.02f);
}
{
auto flat_inner_dims = t.flat_inner_dims<float, 3>();
EXPECT_EQ(6, flat_inner_dims.dimension(0));
EXPECT_EQ(4, flat_inner_dims.dimension(1));
EXPECT_EQ(5, flat_inner_dims.dimension(2));
EXPECT_EQ(flat_inner_dims(0, 0, 0), 0.01f);
EXPECT_EQ(flat_inner_dims(5, 3, 4), 0.02f);
}
{
auto flat_outer_dims = t.flat_outer_dims<float, 3>();
EXPECT_EQ(2, flat_outer_dims.dimension(0));
EXPECT_EQ(3, flat_outer_dims.dimension(1));
EXPECT_EQ(20, flat_outer_dims.dimension(2));
EXPECT_EQ(flat_outer_dims(0, 0, 0), 0.01f);
EXPECT_EQ(flat_outer_dims(1, 2, 19), 0.02f);
}
{
auto flat_inner_dims = t.flat_inner_dims<float, 5>();
EXPECT_EQ(1, flat_inner_dims.dimension(0));
EXPECT_EQ(2, flat_inner_dims.dimension(1));
EXPECT_EQ(3, flat_inner_dims.dimension(2));
EXPECT_EQ(4, flat_inner_dims.dimension(3));
EXPECT_EQ(5, flat_inner_dims.dimension(4));
EXPECT_EQ(flat_inner_dims(0, 0, 0, 0, 0), 0.01f);
EXPECT_EQ(flat_inner_dims(0, 1, 2, 3, 4), 0.02f);
}
{
auto flat_outer_dims = t.flat_outer_dims<float, 5>();
EXPECT_EQ(2, flat_outer_dims.dimension(0));
EXPECT_EQ(3, flat_outer_dims.dimension(1));
EXPECT_EQ(4, flat_outer_dims.dimension(2));
EXPECT_EQ(5, flat_outer_dims.dimension(3));
EXPECT_EQ(1, flat_outer_dims.dimension(4));
EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f);
EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f);
}
}
TEST(Tensor_Scalar, Basics) {

View File

@ -305,6 +305,23 @@ tf_kernel_libraries(
],
)
tf_cc_test(
name = "batch_norm_op_test",
size = "small",
deps = [
":batch_norm_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test(
name = "concat_op_test",
size = "small",

View File

@ -0,0 +1,62 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
class BatchNormOpTest : public OpsTestBase {};
TEST_F(BatchNormOpTest, Simple) {
TF_EXPECT_OK(
NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("scale_after_normalization", false)
.Attr("variance_epsilon", 0.001)
.Finalize(node_def()));
TF_EXPECT_OK(InitOpWithGraphVersion(8));
AddInputFromArray<float>(TensorShape({1, 1, 6, 2}),
{1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6});
AddInputFromArray<float>(TensorShape({2}), {10, 20});
AddInputFromArray<float>(TensorShape({2}), {0.25, 0.5});
AddInputFromArray<float>(TensorShape({2}), {0.1, 0.6});
AddInputFromArray<float>(TensorShape({2}), {0.0, 0.0});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2}));
test::FillValues<float>(
&expected, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86,
-33.31, -23.85, -34.72, -25.85, -36.13});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
}
} // namespace tensorflow

View File

@ -94,10 +94,13 @@ class OpsTestBase : public ::testing::Test {
// and output types as output.
//
// Returns the status of initialization.
Status InitOp() {
Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); }
// Only use this directly if you have a deprecated op that you need to test.
Status InitOpWithGraphVersion(int graph_def_version) {
Status status;
kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(),
node_def_, TF_GRAPH_DEF_VERSION, &status);
node_def_, graph_def_version, &status);
if (kernel_ != nullptr) input_types_ = kernel_->input_types();
return status;
}

View File

@ -66,7 +66,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) {
#define MAYBE_CONJ(T) \
template <> \
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj<T>(T v) { \
return std::conj(v); \
return Eigen::numext::conj(v); \
}
#endif

View File

@ -1,121 +0,0 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/core/command_line_flags.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
namespace {
// Templated function to convert a string to target values.
// Return true if the conversion is successful. Otherwise, return false.
template <typename T>
bool StringToValue(const string& content, T* value);
template <>
bool StringToValue<int32>(const string& content, int32* value) {
return strings::safe_strto32(content, value);
}
template <>
bool StringToValue<string>(const string& content, string* value) {
*value = content;
return true;
}
// Parse a single argument by linearly searching through the command table.
// The input format is: --argument=value.
// Return OK if the argument is used. It stores the extracted value into the
// matching flag.
// Return NOT_FOUND if the argument is not recognized.
// Return INVALID_ARGUMENT if the command is recognized, but fails to extract
// its value.
template <typename T>
Status ParseArgument(const string& argument) {
for (auto& command :
internal::CommandLineFlagRegistry<T>::Instance()->commands) {
string prefix = strings::StrCat("--", command.name, "=");
if (tensorflow::StringPiece(argument).starts_with(prefix)) {
string content = argument.substr(prefix.length());
if (StringToValue<T>(content, command.value)) {
return Status::OK();
}
return Status(error::INVALID_ARGUMENT,
strings::StrCat("Cannot parse integer in: ", argument));
}
}
return Status(error::NOT_FOUND,
strings::StrCat("Unknown command: ", argument));
}
// A specialization for booleans. The input format is:
// "--argument" or "--noargument".
// Parse a single argument by linearly searching through the command table.
// Return OK if the argument is used. The value is stored in the matching flag.
// Return NOT_FOUND if the argument is not recognized.
template <>
Status ParseArgument<bool>(const string& argument) {
for (auto& command :
internal::CommandLineFlagRegistry<bool>::Instance()->commands) {
if (argument == strings::StrCat("--", command.name)) {
*command.value = true;
return Status::OK();
} else if (argument == strings::StrCat("--no", command.name)) {
*command.value = false;
return Status::OK();
}
}
return Status(error::NOT_FOUND,
strings::StrCat("Unknown command: ", argument));
}
} // namespace
Status ParseCommandLineFlags(int* argc, char* argv[]) {
int unused_argc = 1;
for (int index = 1; index < *argc; ++index) {
Status s;
// Search bool commands.
s = ParseArgument<bool>(argv[index]);
if (s.ok()) {
continue;
}
if (s.code() != error::NOT_FOUND) {
return s;
}
// Search int32 commands.
s = ParseArgument<int32>(argv[index]);
if (s.ok()) {
continue;
}
// Search string commands.
s = ParseArgument<string>(argv[index]);
if (s.ok()) {
continue;
}
if (s.code() != error::NOT_FOUND) {
return s;
}
// Pointer swap the unused argument to the front.
std::swap(argv[unused_argc++], argv[index]);
}
*argc = unused_argc;
return Status::OK();
}
} // namespace tensorflow

View File

@ -1,80 +0,0 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
#define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
#include <vector>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
template <typename T>
struct CommandLineFlagRegistry {
static CommandLineFlagRegistry* Instance() {
static CommandLineFlagRegistry instance_;
return &instance_;
}
struct Command {
string name;
T* value;
string text;
};
std::vector<Command> commands;
private:
CommandLineFlagRegistry() {}
TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry);
};
template <typename T>
struct CommandLineFlagRegister {
CommandLineFlagRegister(const string& name, T* val, const string& text) {
CommandLineFlagRegistry<T>::Instance()->commands.push_back(
{name, val, text});
}
};
#define TF_DEFINE_variable(type, name, default_value, text) \
type FLAGS_##name = default_value; \
namespace TF_flags_internal { \
tensorflow::internal::CommandLineFlagRegister<type> \
TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \
} // namespace TF_flags_internal
} // namespace internal
#define TF_DEFINE_int32(name, default_value, text) \
TF_DEFINE_variable(tensorflow::int32, name, default_value, text);
#define TF_DEFINE_bool(name, default_value, text) \
TF_DEFINE_variable(bool, name, default_value, text);
#define TF_DEFINE_string(name, default_value, text) \
TF_DEFINE_variable(string, name, default_value, text);
// Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv.
// Returns the number of unused arguments in *argc.
// Return error Status if the parsing encounters errors.
// TODO(opensource): switch to a command line argument parser that can be
// shared with other tests.
Status ParseCommandLineFlags(int* argc, char* argv[]);
} // namespace tensorflow
#endif // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_

View File

@ -120,9 +120,12 @@ variable to its initial value.
##### Args:
* <b>`initial_value`</b>: A `Tensor`, or Python object convertible to a `Tensor`.
The initial value for the Variable. Must have a shape specified unless
`validate_shape` is set to False.
* <b>`initial_value`</b>: A `Tensor`, or Python object convertible to a `Tensor`,
which is the initial value for the Variable. The initial value must have
a shape specified unless `validate_shape` is set to False. Can also be a
callable with no argument that returns the initial value when called. In
that case, `dtype` must be specified. (Note that initializer functions
from init_ops.py must first be bound to a shape before being used here.)
* <b>`trainable`</b>: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
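A minimal usage sketch of the callable form described above; note that dtype must be given explicitly in that case:

import tensorflow as tf

# initial_value is a zero-argument callable, so dtype is mandatory.
v = tf.Variable(lambda: tf.random_uniform([2, 3]), dtype=tf.float32)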

View File

@ -1955,6 +1955,7 @@ on the parameters to the constructor and may include:
##### Raises:
* <b>`RuntimeError`</b>: If called with a non-chief Supervisor.
* <b>`ValueError`</b>: If no `logdir` was passed to the constructor, as the
services need a log directory.
@ -2182,6 +2183,7 @@ on the parameters to the constructor and may include:
##### Raises:
* <b>`RuntimeError`</b>: If called with a non-chief Supervisor.
* <b>`ValueError`</b>: If no `logdir` was passed to the constructor, as the
services need a log directory.
@ -2409,7 +2411,7 @@ Start threads for `QueueRunners`.
#### `tf.train.Supervisor.summary_op` {#Supervisor.summary_op}
Return the Summary Tensor used by the supervisor.
Return the Summary Tensor used by the chief supervisor.
##### Returns:
@ -2420,7 +2422,7 @@ Return the Summary Tensor used by the supervisor.
#### `tf.train.Supervisor.summary_writer` {#Supervisor.summary_writer}
Return the SummaryWriter used by the supervisor.
Return the SummaryWriter used by the chief supervisor.
##### Returns:

View File

@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
import tensorflow as tf
@ -208,5 +209,59 @@ class RNNCellTest(tf.test.TestCase):
0.13248, 0.13248]])
class SlimRNNCellTest(tf.test.TestCase):
def testBasicRNNCell(self):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 2])
m = tf.zeros([1, 2])
my_cell = functools.partial(basic_rnn_cell, num_units=2)
g, _ = tf.nn.rnn_cell.SlimRNNCell(my_cell)(x, m)
sess.run([tf.initialize_all_variables()])
res = sess.run([g], {x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1]])})
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellMatch(self):
batch_size = 32
input_size = 100
num_units = 10
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
inputs = tf.random_uniform((batch_size, input_size))
_, initial_state = basic_rnn_cell(inputs, None, num_units)
my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
slim_cell = tf.nn.rnn_cell.SlimRNNCell(my_cell)
slim_outputs, slim_state = slim_cell(inputs, initial_state)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
outputs, state = rnn_cell(inputs, initial_state)
self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
self.assertEqual(slim_state.get_shape(), state.get_shape())
sess.run([tf.initialize_all_variables()])
res = sess.run([slim_outputs, slim_state, outputs, state])
self.assertAllClose(res[0], res[2])
self.assertAllClose(res[1], res[3])
def basic_rnn_cell(inputs, state, num_units, scope=None):
if state is None:
if inputs is not None:
batch_size = inputs.get_shape()[0]
dtype = inputs.dtype
else:
batch_size = 0
dtype = tf.float32
init_output = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype)
init_state = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype)
init_output.set_shape([batch_size, num_units])
init_state.set_shape([batch_size, num_units])
return init_output, init_state
else:
with tf.variable_op_scope([inputs, state], scope, "BasicRNNCell"):
output = tf.tanh(tf.nn.rnn_cell.linear([inputs, state],
num_units, True))
return output, output
if __name__ == "__main__":
tf.test.main()

View File

@ -302,6 +302,53 @@ class VariablesTestCase(tf.test.TestCase):
self.assertEqual(var.op.device, init_op.device)
sess.run(init_op)
def testInitializerFunction(self):
value = [[-42], [133.7]]
shape = [2, 1]
with self.test_session():
initializer = lambda: tf.constant(value)
with self.assertRaises(ValueError):
# Checks that dtype must be specified.
tf.Variable(initializer)
v1 = tf.Variable(initializer, dtype=tf.float32)
self.assertEqual(shape, v1.get_shape())
self.assertAllClose(value, v1.initial_value.eval())
with self.assertRaises(tf.errors.FailedPreconditionError):
v1.eval()
v2 = tf.Variable(tf.neg(v1.initialized_value()), dtype=tf.float32)
self.assertEqual(v1.get_shape(), v2.get_shape())
self.assertAllClose(np.negative(value), v2.initial_value.eval())
# Once v2.initial_value.eval() has been called, v1 has effectively been
# initialized.
self.assertAllClose(value, v1.eval())
with self.assertRaises(tf.errors.FailedPreconditionError):
v2.eval()
tf.initialize_all_variables().run()
self.assertAllClose(np.negative(value), v2.eval())
def testInitializerFunctionDevicePlacement(self):
with self.test_session():
initializer = lambda: tf.constant(42.0)
with tf.device("/cpu:100"):
v1 = tf.Variable(initializer, dtype=tf.float32, name="v1")
expected_device = "/device:CPU:100"
expected_group_v1 = [b"loc:@v1"]
self.assertEqual(expected_device, v1.op.device)
self.assertEqual(expected_group_v1, v1.op.colocation_groups())
for i in v1.initializer.inputs:
self.assertEqual(expected_device, i.op.device)
self.assertEqual(expected_group_v1, i.op.colocation_groups())
v2 = tf.Variable(initializer, dtype=tf.float32, name="v2")
expected_group_v2 = [b"loc:@v2"]
self.assertEqual(expected_group_v2, v2.op.colocation_groups())
for i in v2.initializer.inputs:
self.assertEqual(expected_group_v2, i.op.colocation_groups())
class IsInitializedTest(tf.test.TestCase):

View File

@ -167,19 +167,22 @@ def create_partitioned_variables(
slice_offset[slice_dim] += var_shape[slice_dim]
if callable(initializer):
init_val = initializer(var_shape, dtype=dtype)
init_val = ops.convert_to_tensor(init_val, dtype=dtype)
init = initializer
init_shape = var_shape
elif isinstance(initializer, ops.Tensor):
init_val = array_ops.slice(initializer, var_offset, var_shape)
init = array_ops.slice(initializer, var_offset, var_shape)
# Use the dtype of the given tensor.
dtype = init_val.dtype.base_dtype
dtype = init.dtype.base_dtype
init_shape = None
else:
init_val = ops.convert_to_tensor(initializer, dtype=dtype)
init_val = array_ops.slice(init_val, var_offset, var_shape)
init = ops.convert_to_tensor(initializer, dtype=dtype)
init = array_ops.slice(init, var_offset, var_shape)
init_shape = None
var = variable_scope.get_variable(name="part_%d" % i,
shape=init_shape,
dtype=dtype,
initializer=init_val,
initializer=init,
trainable=trainable,
collections=collections)

View File

@ -661,6 +661,42 @@ class MultiRNNCell(RNNCell):
return cur_inp, array_ops.concat(1, new_states)
class SlimRNNCell(RNNCell):
"""A simple wrapper for slim.rnn_cells."""
def __init__(self, cell_fn):
"""Create a SlimRNNCell from a cell_fn.
Args:
cell_fn: a function which takes (inputs, state, scope) and produces the
outputs and the new_state. Additionally, when called with inputs=None and
state=None, it should return (initial_outputs, initial_state).
Raises:
TypeError: if cell_fn is not callable
ValueError: if cell_fn cannot produce a valid initial state.
"""
if not callable(cell_fn):
raise TypeError("cell_fn %s needs to be callable" % cell_fn)
self._cell_fn = cell_fn
self._cell_name = cell_fn.func.__name__
_, init_state = self._cell_fn(None, None)
state_shape = init_state.get_shape()
self._state_size = state_shape.with_rank(2)[1].value
if self._state_size is None:
raise ValueError("Initial state created by %s has invalid shape %s" %
                 (self._cell_name, state_shape))
@property
def state_size(self):
return self._state_size
def __call__(self, inputs, state, scope=None):
scope = scope or self._cell_name
output, state = self._cell_fn(inputs, state, scope=scope)
return output, state
def linear(args, output_size, bias, bias_start=0.0, scope=None):
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

View File

@ -144,14 +144,19 @@ class _VariableStore(object):
with ops.control_dependencies(None):
if initializing_from_value:
init_val = initializer
variable_dtype = None
else:
with ops.name_scope(name + "/Initializer/"):
init_val = initializer(shape.as_list(), dtype=dtype)
init_val = lambda: initializer(shape.as_list(), dtype=dtype)
variable_dtype = dtype.base_dtype
# Create the variable.
v = variables.Variable(init_val, name=name, trainable=trainable,
v = variables.Variable(initial_value=init_val,
name=name,
trainable=trainable,
collections=collections,
caching_device=caching_device)
caching_device=caching_device,
dtype=variable_dtype)
self._vars[name] = v
logging.info("Created variable %s with shape %s and init %s", v.name,
format(shape), initializer)

View File

@ -156,9 +156,12 @@ class Variable(object):
variable to its initial value.
Args:
initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
The initial value for the Variable. Must have a shape specified unless
`validate_shape` is set to False.
initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
which is the initial value for the Variable. The initial value must have
a shape specified unless `validate_shape` is set to False. Can also be a
callable with no argument that returns the initial value when called. In
that case, `dtype` must be specified. (Note that initializer functions
from init_ops.py must first be bound to a shape before being used here.)
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
@ -211,9 +214,12 @@ class Variable(object):
"""Creates a new variable from arguments.
Args:
initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
The initial value for the Variable. Must have a shape specified unless
`validate_shape` is set to False.
initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
which is the initial value for the Variable. The initial value must have
a shape specified unless `validate_shape` is set to False. Can also be a
callable with no argument that returns the initial value when called. In
that case, `dtype` must be specified. (Note that initializer functions
from init_ops.py must first be bound to a shape before being used here.)
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
@ -240,25 +246,62 @@ class Variable(object):
"""
if initial_value is None:
raise ValueError("initial_value must be specified.")
init_from_fn = callable(initial_value)
if init_from_fn and dtype is None:
raise ValueError(
"dtype must also be specified when initial_value is callable.")
if collections is None:
collections = [ops.GraphKeys.VARIABLES]
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
with ops.control_dependencies(None):
with ops.op_scope([initial_value], name, "Variable") as name:
self._initial_value = ops.convert_to_tensor(initial_value,
name="initial_value",
dtype=dtype)
initial_value_shape = self._initial_value.get_shape()
if validate_shape and not initial_value_shape.is_fully_defined():
raise ValueError("initial_value must have a shape specified: %s"
% self._initial_value)
shape_to_set = initial_value_shape if validate_shape else []
with ops.op_scope(
[] if init_from_fn else [initial_value], name, "Variable") as name:
self._variable = state_ops.variable_op(
shape_to_set, self._initial_value.dtype.base_dtype,
set_shape=validate_shape, name=name)
# Get the initial value from a callable function. The real shape of the
# variable will be set later, since under the init_from_fn case, the
# shape won't be known until after the function is invoked.
if init_from_fn:
self._variable = state_ops.variable_op(
[],
dtype.base_dtype,
set_shape=False,
name=name)
with ops.colocate_with(self._variable.op):
with ops.name_scope("Initializer"):
# Colocate the tensors created by the initial_value() function
# with the variable itself.
self._initial_value = ops.convert_to_tensor(initial_value(),
name="initial_value",
dtype=dtype)
# Or get the initial value from a Tensor or Python object.
else:
self._initial_value = ops.convert_to_tensor(initial_value,
name="initial_value",
dtype=dtype)
# In this case, the variable op can't be created until after the
# initial_value has been converted to a Tensor with a known type.
self._variable = state_ops.variable_op(
[],
self._initial_value.dtype.base_dtype,
set_shape=False,
name=name)
# Manually overrides the variable's shape with the initial value's.
if validate_shape:
initial_value_shape = self._initial_value.get_shape()
if not initial_value_shape.is_fully_defined():
raise ValueError("initial_value must have a shape specified: %s"
% self._initial_value)
self._variable.set_shape(initial_value_shape)
# TODO(b/28152992): Remove the below hack modifying the node_def shape
# directly once set_shape() handles it.
self._variable.op.node_def.attr["shape"].shape.CopyFrom(
initial_value_shape.as_proto())
# Assigns initial value.
with ops.colocate_with(self._variable.op):
self._initializer_op = state_ops.assign(
self._variable, self._initial_value,
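To make the two paths above easier to follow, here is a simplified, runnable sketch of just the callable-vs-tensor dispatch (our simplification, not the actual implementation, which also creates the variable op and colocates the initializer with it):

```python
import tensorflow as tf

def to_initial_value(initial_value, dtype=None):
    # Mirrors the dispatch above: callables are invoked lazily and
    # require an explicit dtype; everything else is converted directly.
    if callable(initial_value):
        if dtype is None:
            raise ValueError(
                "dtype must also be specified when initial_value is callable.")
        return tf.convert_to_tensor(initial_value(), name="initial_value",
                                    dtype=dtype)
    return tf.convert_to_tensor(initial_value, name="initial_value",
                                dtype=dtype)
```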

View File

@ -79,3 +79,7 @@ def get_path_to_datafile(path):
"""
data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1)))
return os.path.join(data_files_path, path)
def readahead_file_path(path, unused_readahead=None):
"""Readahead files not implemented; simply returns given path."""
return path
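Since readahead is not implemented, the hook is a pass-through; a hypothetical usage (the path below is made up):

```python
from tensorflow.python.platform import resource_loader

# Pass-through today; a real implementation could prefetch file contents.
path = resource_loader.readahead_file_path("/tmp/events.out.tfevents.0")
assert path == "/tmp/events.out.tfevents.0"
```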

View File

@ -22,6 +22,7 @@ from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import logging
from tensorflow.python.platform import resource_loader
from tensorflow.python.util import compat
@ -31,6 +32,7 @@ class EventFileLoader(object):
def __init__(self, file_path):
if file_path is None:
raise ValueError('A file path is required')
file_path = resource_loader.readahead_file_path(file_path)
logging.debug('Opening a record reader pointing at %s', file_path)
self._reader = pywrap_tensorflow.PyRecordReader_New(
compat.as_bytes(file_path), 0)

View File

@ -326,19 +326,29 @@ class Supervisor(object):
self._init_global_step(global_step=global_step)
self._graph = graph
self._is_chief = is_chief
self._logdir = logdir
self._save_summaries_secs = save_summaries_secs
self._save_model_secs = save_model_secs
self._recovery_wait_secs = recovery_wait_secs
self._coord = coordinator.Coordinator()
if logdir:
self._started_threads = []
self._recovery_wait_secs = recovery_wait_secs
# Only chief supervisors write event files, so only chief supervisors
# should have event-writing properties. Set to None for non-chiefs.
if self._is_chief:
self._logdir = logdir
self._save_summaries_secs = save_summaries_secs
self._save_model_secs = save_model_secs
else:
self._logdir = None
self._save_summaries_secs = None
self._save_model_secs = None
if self._is_chief and self._logdir:
self._save_path = os.path.join(self._logdir, checkpoint_basename)
self._summary_writer = summary_io.SummaryWriter(self._logdir)
else:
self._save_path = None
self._summary_writer = None
self._init_session_manager(session_manager=session_manager)
self._started_threads = []
self._verify_setup()
# The graph is not allowed to change anymore.
graph.finalize()
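A rough sketch of how the chief/non-chief split above looks to callers (the graphs and logdir here are illustrative, not part of this change):

```python
import tensorflow as tf

# Chief: gets a checkpoint save path and a SummaryWriter for the logdir.
with tf.Graph().as_default():
    tf.Variable([1.0], name="v")
    chief_sv = tf.train.Supervisor(is_chief=True, logdir="/tmp/train_logs")

# Non-chief: logdir, save path, and summary writer are forced to None,
# so it never writes events even if a logdir is passed.
with tf.Graph().as_default():
    tf.Variable([1.0], name="v")
    worker_sv = tf.train.Supervisor(is_chief=False)
```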
@ -520,7 +530,7 @@ class Supervisor(object):
@property
def summary_writer(self):
"""Return the SummaryWriter used by the supervisor.
"""Return the SummaryWriter used by the chief supervisor.
Returns:
A SummaryWriter.
@ -529,7 +539,7 @@ class Supervisor(object):
@property
def summary_op(self):
"""Return the Summary Tensor used by the supervisor.
"""Return the Summary Tensor used by the chief supervisor.
Returns:
A string Tensor for the summary or `None`.
@ -583,8 +593,7 @@ class Supervisor(object):
def _write_graph(self):
"""Writes graph_def to `logdir` and adds it to summary if applicable."""
if not self._is_chief:
return
assert self._is_chief
if self._logdir:
training_util.write_graph(self._graph.as_graph_def(),
self._logdir, "graph.pbtxt")
@ -610,11 +619,13 @@ class Supervisor(object):
sv.coord.Join(<list of threads>)
Raises:
RuntimeError: If called with a non-chief Supervisor.
ValueError: If no `logdir` was passed to the constructor, as the
services need a log directory.
"""
if not self._is_chief:
return
raise RuntimeError("Only the chief supervisor can start standard "
"services, because only chief supervisors can write events.")
if not self._logdir:
logging.warning("Standard services need a 'logdir' "
"passed to the SessionManager")
@ -812,14 +823,18 @@ class Supervisor(object):
TypeError: if 'summary' is not a Summary proto or a string.
RuntimeError: if the Supervisor was created without a `logdir`.
"""
if not self._logdir:
raise RuntimeError("summary_computed() requires a logdir")
if not self._summary_writer:
raise RuntimeError("Writing a summary requires a summary writer.")
if global_step is None and self.global_step is not None:
global_step = training_util.global_step(sess, self.global_step)
if self._summary_writer:
self._summary_writer.add_summary(summary, global_step)
self._summary_writer.add_summary(summary, global_step)
def _default_global_step_tensor(self):
"""Returns the global_step from the default graph.
Returns:
The global step `Tensor` or `None`.
"""
try:
gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:

View File

@ -73,12 +73,11 @@ class SupervisorTest(tf.test.TestCase):
sess.close()
sv.stop()
def testSummary(self):
def testChiefCanWriteEvents(self):
logdir = self._TestDir("basics")
with tf.Graph().as_default():
const = tf.constant([1.0, 2.0, 3.0])
summ = tf.scalar_summary(["c1", "c2", "c3"], const)
sv = tf.train.Supervisor(logdir=logdir, summary_op=None)
summ = tf.scalar_summary(["c1", "c2", "c3"], tf.constant([1.0, 2.0, 3.0]))
sv = tf.train.Supervisor(is_chief=True, logdir=logdir, summary_op=None)
sess = sv.prepare_or_wait_for_session("")
sv.summary_computed(sess, sess.run(summ))
sess.close()
@ -113,13 +112,31 @@ class SupervisorTest(tf.test.TestCase):
# We should be done.
self.assertRaises(StopIteration, lambda: next(rr))
def testNonChiefCannotWriteEvents(self):
def _summary_computed():
with tf.Graph().as_default():
sv = tf.train.Supervisor(is_chief=False)
sess = sv.prepare_or_wait_for_session("")
summ = tf.scalar_summary(["c1", "c2"], tf.constant([1.0, 2.0]))
sv.summary_computed(sess, sess.run(summ))
def _start_standard_services():
with tf.Graph().as_default():
sv = tf.train.Supervisor(is_chief=False)
sess = sv.prepare_or_wait_for_session("")
sv.start_standard_services(sess)
self.assertRaises(RuntimeError, _summary_computed)
self.assertRaises(RuntimeError, _start_standard_services)
def testNoLogdirButWantSummary(self):
with tf.Graph().as_default():
const = tf.constant([1.0, 2.0, 3.0])
summ = tf.scalar_summary(["c1", "c2", "c3"], const)
sv = tf.train.Supervisor(logdir="", summary_op=None)
sess = sv.prepare_or_wait_for_session("")
with self.assertRaisesRegexp(RuntimeError, "requires a logdir"):
with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
sv.summary_computed(sess, sess.run(summ))
def testNoLogdirSucceeds(self):

View File

@ -0,0 +1,66 @@
# Description:
# Benchmark utility that can run on desktop and Android.
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_copts")
exports_files(["LICENSE"])
cc_library(
name = "benchmark_model_lib",
srcs = [
"benchmark_model.cc",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
}),
)
# This binary may be built for either desktop or Android.
# A typical Android build command will look like the following:
# bazel build -c opt tensorflow/core:android_tensorflow_lib \
# --crosstool_top=//external:android/crosstool \
# --cpu=armeabi-v7a \
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
#
# NOTE: currently '-pthread' must be removed from the LINK_OPTS variable
# in google/protobuf/BUILD to successfully build for Android. This is
# temporary, pending an update of the protobuf version that TensorFlow
# uses.
cc_binary(
name = "benchmark_model",
copts = tf_copts(),
linkopts = select({
"//tensorflow:android": [
"-pie",
"-s",
"-landroid",
"-ljnigraphics",
"-llog",
"-lm",
"-z defs",
"-s",
"-Wl,--icf=all", # Identical Code Folding
"-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export
],
"//conditions:default": [],
}),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [":benchmark_model_lib"],
)

View File

@ -0,0 +1,57 @@
# TensorFlow Model Benchmark Tool
## Description
A simple C++ binary to benchmark a compute graph and its individual operators,
both on desktop machines and on Android.
## To build/install/run
### On Android:
(1) Build for your specific platform, e.g.:
```bash
$ bazel build -c opt \
--crosstool_top=//external:android/crosstool \
--cpu=armeabi-v7a \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
tensorflow/tools/benchmark:benchmark_model
```
(2) Connect your phone. Push the binary to your phone with adb push
(make the directory if required):
```bash
$ adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp
```
(3) Push the compute graph that you need to test. For example:

```bash
$ adb push tensorflow_inception_graph.pb /data/local/tmp
```
(4) Run the benchmark. For example:
```bash
$ adb shell "/data/local/tmp/benchmark_model \
  --graph=/data/local/tmp/tensorflow_inception_graph.pb \
  --input_layer=\"input:0\" \
  --input_layer_shape=\"1,224,224,3\" \
  --input_layer_type=\"float\" \
  --output_layer=\"output:0\""
```
### On desktop:
(1) Build the binary:
```bash
$ bazel build -c opt tensorflow/tools/benchmark:benchmark_model
```
(2) Run it on your compute graph, as in the Android case but without the
need for adb shell. For example:
```bash
$ bazel-bin/tensorflow/tools/benchmark/benchmark_model \
--graph=tensorflow_inception_graph.pb \
--input_layer="input:0" \
--input_layer_shape="1,224,224,3" \
--input_layer_type="float" \
--output_layer="output:0"
```
The Inception graph used as an example here may be downloaded from
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

View File

@ -0,0 +1,225 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A C++ binary to benchmark a compute graph and its individual operators,
// both on desktop machines and on Android.
//
// See README.md for usage instructions.
#include <cstdlib>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/core/util/stat_summarizer.h"
namespace tensorflow {
// Global variables that hold the TensorFlow session and summary stats.
static std::unique_ptr<tensorflow::Session> session;
static StatSummarizer g_stats;
struct Flags {
string graph = "/data/local/tmp/tensorflow_inception_graph.pb";
string input_layer = "input:0";
string input_layer_shape = "1,224,224,3";
string input_layer_type = "float";
string output_layer = "output:0";
int num_runs = 50;
string run_delay = "-1.0";
int num_threads = -1;
};
static Flags* flags; // Filled in by main()
static bool InitializeBenchmark() {
g_stats.Reset();
LOG(INFO) << "Loading Tensorflow.";
tensorflow::SessionOptions options;
tensorflow::ConfigProto& config = options.config;
if (flags->num_threads > 0) {
config.set_intra_op_parallelism_threads(flags->num_threads);
}
LOG(INFO) << "Got config, " << config.device_count_size() << " devices";
session.reset(tensorflow::NewSession(options));
tensorflow::GraphDef tensorflow_graph;
Status s = ReadBinaryProto(Env::Default(), flags->graph, &tensorflow_graph);
if (!s.ok()) {
LOG(ERROR) << "Could not create Tensorflow Graph: " << s;
return false;
}
s = session->Create(tensorflow_graph);
if (!s.ok()) {
LOG(ERROR) << "Could not create Tensorflow Session: " << s;
return false;
}
// Clear the proto to save memory space.
tensorflow_graph.Clear();
return true;
}
static bool RunBenchmark() {
DataType input_data_type;
CHECK(DataTypeFromString(flags->input_layer_type, &input_data_type))
<< flags->input_layer_type << " was an invalid type";
std::vector<int32> sizes;
CHECK(str_util::SplitAndParseAsInts(flags->input_layer_shape, ',', &sizes))
<< "Incorrect size string specified: " << flags->input_layer_shape;
TensorShape input_shape;
for (int i = 0; i < sizes.size(); ++i) {
input_shape.AddDim(sizes[i]);
}
Tensor input_tensor(input_data_type, input_shape);
switch (input_data_type) {
case DT_INT32: {
auto int_tensor = input_tensor.flat<int32>();
int_tensor = int_tensor.constant(0);
break;
}
case DT_FLOAT: {
auto float_tensor = input_tensor.flat<float>();
float_tensor = float_tensor.constant(0.0);
break;
}
case DT_QUINT8: {
auto int_tensor = input_tensor.flat<quint8>();
int_tensor = int_tensor.constant(0);
break;
}
default:
LOG(FATAL) << "Unsupported input type: " << flags->input_layer_type;
}
std::vector<std::pair<string, tensorflow::Tensor> > input_tensors(
{{flags->input_layer, input_tensor}});
std::vector<tensorflow::Tensor> output_tensors;
std::vector<string> output_names({flags->output_layer});
tensorflow::Status s;
RunOptions run_options;
run_options.set_trace_level(RunOptions::FULL_TRACE);
RunMetadata run_metadata;
s = session->Run(run_options, input_tensors, output_names, {},
&output_tensors, &run_metadata);
assert(run_metadata.has_step_stats());
const StepStats& stats = run_metadata.step_stats();
g_stats.ProcessStepStats(stats);
if (!s.ok()) {
LOG(ERROR) << "Error during inference: " << s;
return false;
}
return true;
}
} // namespace tensorflow
int main(int argc, char** argv) {
tensorflow::flags = new tensorflow::Flags();
const bool parse_result = tensorflow::ParseFlags(
&argc, argv,
{
tensorflow::Flag("graph", &tensorflow::flags->graph),
tensorflow::Flag("input_layer", &tensorflow::flags->input_layer),
tensorflow::Flag("input_layer_shape",
&tensorflow::flags->input_layer_shape),
tensorflow::Flag("input_layer_type",
&tensorflow::flags->input_layer_type),
tensorflow::Flag("output_layer", &tensorflow::flags->output_layer),
tensorflow::Flag("num_runs", &tensorflow::flags->num_runs),
tensorflow::Flag("run_delay", &tensorflow::flags->run_delay),
tensorflow::Flag("num_threads", &tensorflow::flags->num_threads),
});
if (!parse_result) {
LOG(ERROR) << "Error parsing command-line flags.";
return -1;
}
::tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1];
return -1;
}
LOG(INFO) << "Graph: [" << tensorflow::flags->graph << "]";
LOG(INFO) << "Input layer: [" << tensorflow::flags->input_layer << "]";
LOG(INFO) << "Input shape: [" << tensorflow::flags->input_layer_shape << "]";
LOG(INFO) << "Input type: [" << tensorflow::flags->input_layer_type << "]";
LOG(INFO) << "Output layer: [" << tensorflow::flags->output_layer << "]";
LOG(INFO) << "Num runs: [" << tensorflow::flags->num_runs << "]";
LOG(INFO) << "Inter-run delay (seconds): [" << tensorflow::flags->run_delay
<< "]";
LOG(INFO) << "Num threads: [" << tensorflow::flags->num_threads << "]";
if (!tensorflow::InitializeBenchmark()) {
return -1;
}
// Convert the run_delay string into a timespec.
const double sleep_seconds =
std::strtod(tensorflow::flags->run_delay.c_str(), nullptr);
timespec req;
req.tv_sec = static_cast<time_t>(sleep_seconds);
req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000;
LOG(INFO) << "Running benchmark";
for (int i = 0; i < tensorflow::flags->num_runs; ++i) {
if (!tensorflow::RunBenchmark()) {
LOG(INFO) << "Failed on run " << i;
return -1;
}
// If requested, sleep between runs for an arbitrary amount of time.
// This can be helpful to determine the effect of mobile processor
// scaling and thermal throttling.
if (sleep_seconds > 0.0) {
nanosleep(&req, nullptr);
}
}
tensorflow::g_stats.PrintStepStats();
return 0;
}

View File

@ -139,6 +139,16 @@ else
# Assume: PYTHON_BIN_PATH is exported by the script above
fi
# Obtain the path to head/ghead binary (for log file printing)
HEAD_BIN="ghead"
if [[ -z $(which "${HEAD_BIN}") ]]; then
# This is not a Mac (which would use coreutils' ghead), so use head.
HEAD_BIN="head"
if [[ -z $(which "${HEAD_BIN}") ]]; then
die "Unable to obtain path to head or ghead"
fi
fi
if [[ -z "${PYTHON_BIN_PATH}" ]]; then
die "PYTHON_BIN_PATH was not provided. If this is not virtualenv, "\
"did you run configure?"
@ -371,7 +381,7 @@ while true; do
echo " Log @: ${TEST_LOGS[K]}"
echo "============== BEGINS failure log content =============="
head --lines=-1 "${TEST_LOGS[K]}"
"${HEAD_BIN}" --lines=-1 "${TEST_LOGS[K]}"
echo "============== ENDS failure log content =============="
echo ""
fi