diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc index ab5ac9c8999..2f1602d0ba5 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc @@ -18,6 +18,7 @@ // only op that involves tree traversal, and is constructed so that it can // be run in parallel on separate batches of data. #include +#include #include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h" @@ -25,10 +26,12 @@ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { using std::get; +using std::make_pair; using std::make_tuple; using std::pair; using std::tuple; @@ -42,6 +45,71 @@ using tensorforest::DecideNode; using tensorforest::Initialize; using tensorforest::IsAllInitialized; +// A data structure to store the results of parallel tree traversal. +struct InputDataResult { + // A list of each node that was visited. + std::vector node_indices; + // The accumulator of the leaf that a data point ended up at, or -1 if none. + int32 leaf_accumulator; + // The left-branch taken candidate splits. + std::vector split_adds; + // If the candidate splits for the leaf that a data point arrived at + // were initialized or not, which determines if we add this to total + // pcw counts or not. + bool splits_initialized; +}; + +void Evaluate(const Tensor& input_data, const Tensor& input_labels, + const Tensor& tree_tensor, const Tensor& tree_thresholds, + const Tensor& node_to_accumulator, + const Tensor& candidate_split_features, + const Tensor& candidate_split_thresholds, + InputDataResult* results, int64 start, int64 end) { + const auto tree = tree_tensor.tensor(); + const auto thresholds = tree_thresholds.unaligned_flat(); + const auto node_map = node_to_accumulator.unaligned_flat(); + const auto split_features = candidate_split_features.tensor(); + const auto split_thresholds = candidate_split_thresholds.tensor(); + + const int32 num_splits = candidate_split_features.shape().dim_size(1); + + for (int i = start; i < end; ++i) { + const Tensor point = input_data.Slice(i, i + 1); + int node_index = 0; + results[i].splits_initialized = false; + while (true) { + results[i].node_indices.push_back(node_index); + int32 left_child = tree(node_index, CHILDREN_INDEX); + if (left_child == LEAF_NODE) { + const int32 accumulator = node_map(node_index); + results[i].leaf_accumulator = accumulator; + // If the leaf is not fertile or is not yet initialized, we don't + // count it in the candidate/total split per-class-weights because + // it won't have any candidate splits yet. + if (accumulator >= 0 && + IsAllInitialized(candidate_split_features.Slice( + accumulator, accumulator + 1))) { + results[i].splits_initialized = true; + for (int split = 0; split < num_splits; split++) { + if (!DecideNode(point, split_features(accumulator, split), + split_thresholds(accumulator, split))) { + results[i].split_adds.push_back(split); + } + } + } + break; + } else if (left_child == FREE_NODE) { + LOG(ERROR) << "Reached a free node, not good."; + results[i].node_indices.push_back(FREE_NODE); + break; + } + node_index = + left_child + DecideNode(point, tree(node_index, FEATURE_INDEX), + thresholds(node_index)); + } + } +} + REGISTER_OP("CountExtremelyRandomStats") .Attr("num_classes: int32") .Input("input_data: float") @@ -79,9 +147,9 @@ REGISTER_OP("CountExtremelyRandomStats") gives the j-th feature of the i-th input. input_labels: The training batch's labels; `input_labels[i]` is the class of the i-th input. - tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child - of the i-th node, `tree[0][i] + 1` gives the index of the right child of - the i-th node, and `tree[1][i]` gives the index of the feature used to + tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child + of the i-th node, `tree[i][0] + 1` gives the index of the right child of + the i-th node, and `tree[i][1]` gives the index of the feature used to split the i-th node. tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th node. @@ -176,7 +244,31 @@ class CountExtremelyRandomStats : public OpKernel { "candidate_split_features and candidate_split_thresholds should be " "the same shape.")); - const int32 num_splits = candidate_split_features.shape().dim_size(1); + // Evaluate input data in parallel. + const int64 num_data = input_data.shape().dim_size(0); + std::unique_ptr results(new InputDataResult[num_data]); + auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); + int num_threads = worker_threads->num_threads; + if (num_threads <= 1) { + Evaluate(input_data, input_labels, tree_tensor, tree_thresholds, + node_to_accumulator, candidate_split_features, + candidate_split_thresholds, results.get(), 0, num_data); + } else { + auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds, + &node_to_accumulator, &candidate_split_features, + &candidate_split_thresholds, &num_data, + &results](int64 start, int64 end) { + CHECK(start <= end); + CHECK(end <= num_data); + Evaluate(input_data, input_labels, tree_tensor, tree_thresholds, + node_to_accumulator, candidate_split_features, + candidate_split_thresholds, results.get(), start, end); + }; + Shard(num_threads, worker_threads->workers, num_data, 100, work); + } + + // Set output tensors. + const auto labels = input_labels.unaligned_flat(); // node pcw delta Tensor* output_node_pcw_delta = nullptr; @@ -196,58 +288,28 @@ class CountExtremelyRandomStats : public OpKernel { &output_leaves)); auto out_leaves = output_leaves->unaligned_flat(); - const auto tree = tree_tensor.tensor(); - const auto thresholds = tree_thresholds.unaligned_flat(); - const auto labels = input_labels.unaligned_flat(); - const auto node_map = node_to_accumulator.unaligned_flat(); - const auto split_features = candidate_split_features.tensor(); - const auto split_thresholds = candidate_split_thresholds.tensor(); - - const int32 num_data = input_data.shape().dim_size(0); - // -> count delta std::unordered_map, int32, PairIntHash> total_delta; // -> count delta std::unordered_map, int32, TupleIntHash> split_delta; - for (int i = 0; i < num_data; i++) { - const Tensor point = input_data.Slice(i, i+1); - int node_index = 0; - while (true) { - const int32 label = labels(i); - ++out_node(node_index, label); - int32 left_child = tree(node_index, CHILDREN_INDEX); - if (left_child == LEAF_NODE) { - out_leaves(i) = node_index; - const int32 accumulator = node_map(node_index); - // If the leaf is not fertile or is not yet initialized, we don't - // count it in the candidate/total split per-class-weights because - // it won't have any candidate splits yet. - if (accumulator >= 0 && - IsAllInitialized( - candidate_split_features.Slice(accumulator, - accumulator + 1))) { - ++total_delta[std::make_pair(accumulator, label)]; - for (int split = 0; split < num_splits; split++) { - if (!DecideNode(point, split_features(accumulator, split), - split_thresholds(accumulator, split))) { - ++split_delta[make_tuple(accumulator, split, label)]; - } - } - } - break; - } else if (left_child == FREE_NODE) { - LOG(ERROR) << "Reached a free node, not good."; - out_leaves(i) = FREE_NODE; - break; + + for (int32 i = 0; i < num_data; ++i) { + const int32 label = labels(i); + const int32 accumulator = results[i].leaf_accumulator; + for (const int32 node : results[i].node_indices) { + ++out_node(node, label); + } + out_leaves(i) = results[i].node_indices.back(); + if (accumulator >= 0 && results[i].splits_initialized) { + ++total_delta[make_pair(accumulator, label)]; + for (const int32 split : results[i].split_adds) { + ++split_delta[make_tuple(accumulator, split, label)]; } - node_index = left_child + - DecideNode(point, tree(node_index, FEATURE_INDEX), - thresholds(node_index)); } } - // candidate splits pcw indices + // candidate splits pcw indices Tensor* output_candidate_pcw_indices = nullptr; TensorShape candidate_pcw_shape; candidate_pcw_shape.AddDim(split_delta.size()); diff --git a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc index 452a68353c8..90ed0420e3a 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc @@ -94,7 +94,7 @@ class SampleInputs : public OpKernel { "split_sampling_random_seed", &split_sampling_random_seed_)); // Set up the random number generator. if (split_sampling_random_seed_ == 0) { - uint64 time_seed = static_cast(std::time(NULL)); + uint64 time_seed = static_cast(std::clock()); single_rand_ = std::unique_ptr( new random::PhiloxRandom(time_seed)); } else { diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc index 3e84534795f..37640c31b63 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc @@ -44,9 +44,9 @@ REGISTER_OP("TreePredictions") input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` gives the j-th feature of the i-th input. - tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child - of the i-th node, `tree[0][i] + 1` gives the index of the right child of - the i-th node, and `tree[1][i]` gives the index of the feature used to + tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child + of the i-th node, `tree[i][0] + 1` gives the index of the right child of + the i-th node, and `tree[i][1]` gives the index of the feature used to split the i-th node. tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th node. diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py index e93bb17522a..5f2a2e49ef6 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py @@ -17,7 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import tensorflow # pylint: disable=unused-import +import tensorflow as tf from tensorflow.contrib.tensor_forest.python.ops import training_ops @@ -47,6 +47,29 @@ class CountExtremelyRandomStatsTest(test_util.TensorFlowTestCase): self.tree_thresholds, self.node_map, self.split_features, self.split_thresholds, num_classes=4)) + self.assertAllEqual( + [[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 1., 1.]], + pcw_node.eval()) + self.assertAllEqual([[0, 0, 0]], pcw_splits_indices.eval()) + self.assertAllEqual([1.], pcw_splits_delta.eval()) + self.assertAllEqual([[0, 1], [0, 0]], pcw_totals_indices.eval()) + self.assertAllEqual([1., 1.], pcw_totals_delta.eval()) + self.assertAllEqual([1, 1, 2, 2], leaves.eval()) + + def testThreaded(self): + with self.test_session( + config=tf.ConfigProto(intra_op_parallelism_threads=2)): + (pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices, + pcw_totals_delta, + leaves) = (self.ops.count_extremely_random_stats(self.input_data, + self.input_labels, + self.tree, + self.tree_thresholds, + self.node_map, + self.split_features, + self.split_thresholds, + num_classes=4)) + self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 1., 1.]], pcw_node.eval()) diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py index 7cad6a8d38f..62add1bf6ce 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py @@ -49,12 +49,13 @@ def TreePredictions(op): # there's not yet any guarantee that the shared object exists. # In which case, "import tensorflow" will always crash, even for users that # never use contrib. -def Load(): +def Load(library_base_dir=''): """Load the inference ops library and return the loaded module.""" with _ops_lock: global _inference_ops if not _inference_ops: - data_files_path = tf.resource_loader.get_data_files_path() + data_files_path = os.path.join(library_base_dir, + tf.resource_loader.get_data_files_path()) tf.logging.info('data path: %s', data_files_path) _inference_ops = tf.load_op_library(os.path.join( data_files_path, INFERENCE_OPS_FILE)) diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py index 8ca2491d608..84bc2cfea6f 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py @@ -25,6 +25,7 @@ import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape + TRAINING_OPS_FILE = '_training_ops.so' _training_ops = None @@ -96,12 +97,13 @@ def _UpdateFertileSlotsShape(unused_op): # there's not yet any guarantee that the shared object exists. # In which case, "import tensorflow" will always crash, even for users that # never use contrib. -def Load(): +def Load(library_base_dir=''): """Load training ops library and return the loaded module.""" with _ops_lock: global _training_ops if not _training_ops: - data_files_path = tf.resource_loader.get_data_files_path() + data_files_path = os.path.join(library_base_dir, + tf.resource_loader.get_data_files_path()) tf.logging.info('data path: %s', data_files_path) _training_ops = tf.load_op_library(os.path.join( data_files_path, TRAINING_OPS_FILE)) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 6257d6481d6..3d254f2d505 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -25,19 +25,6 @@ from tensorflow.contrib.tensor_forest.python.ops import inference_ops from tensorflow.contrib.tensor_forest.python.ops import training_ops -flags = tf.app.flags -FLAGS = flags.FLAGS - - -# Default parameter values. These are all only used if the corresponding -# parameter is not specified when constructing the ForestHParams. -flags.DEFINE_integer('num_trees', 100, 'Number of trees in forest') -flags.DEFINE_integer('max_nodes', 10000, 'Maxmimum number of tree nodes.') -flags.DEFINE_float( - 'samples_to_decide', 25.0, - 'Only decide on a split, or only fully use a leaf, after this many ' - 'training samples have been seen.') - # If tree[i][0] equals this value, then i is a leaf node. LEAF_NODE = -1 @@ -57,7 +44,20 @@ LEAF_NODE = -1 class ForestHParams(object): """A base class for holding hyperparameters and calculating good defaults.""" - def __init__(self, **kwargs): + def __init__(self, num_trees=100, max_nodes=10000, bagging_fraction=1.0, + samples_to_decide=25, max_depth=0, num_splits_to_consider=0, + max_fertile_nodes=0, split_after_samples=0, + valid_leaf_threshold=0, **kwargs): + self.num_trees = num_trees + self.max_nodes = max_nodes + self.bagging_fraction = bagging_fraction + self.samples_to_decide = samples_to_decide + self.max_depth = max_depth + self.num_splits_to_consider = num_splits_to_consider + self.max_fertile_nodes = max_fertile_nodes + self.split_after_samples = split_after_samples + self.valid_leaf_threshold = valid_leaf_threshold + for name, value in kwargs.items(): setattr(self, name, value) @@ -69,19 +69,21 @@ class ForestHParams(object): # Fail fast if num_classes isn't set. _ = getattr(self, 'num_classes') - self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees) - self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes) + self.training_library_base_dir = getattr( + self, 'training_library_base_dir', '') + self.inference_library_base_dir = getattr( + self, 'inference_library_base_dir', '') # Allow each tree to be unbalanced by up to a factor of 2. - self.max_depth = getattr(self, 'max_depth', - int(2 * math.ceil(math.log(self.max_nodes, 2)))) + self.max_depth = (self.max_depth or + int(2 * math.ceil(math.log(self.max_nodes, 2)))) # The Random Forest literature recommends sqrt(# features) for # classification problems, and p/3 for regression problems. # TODO(thomaswc): Consider capping this for large number of features. - if not getattr(self, 'num_splits_to_consider', None): - self.num_splits_to_consider = max(10, int( - math.ceil(math.sqrt(self.num_features)))) + self.num_splits_to_consider = ( + self.num_splits_to_consider or + max(10, int(math.ceil(math.sqrt(self.num_features))))) # max_fertile_nodes doesn't effect performance, only training speed. # We therefore set it primarily based upon space considerations. @@ -91,22 +93,19 @@ class ForestHParams(object): num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider)) # But always use at least 1000 accumulate slots. num_fertile = max(num_fertile, 1000) - self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile) + self.max_fertile_nodes = self.max_fertile_nodes or num_fertile # But it also never needs to be larger than the number of leaves, # which is max_nodes / 2. - self.max_fertile_nodes = min(self.max_nodes, - int(math.ceil(self.max_fertile_nodes / 2.0))) + self.max_fertile_nodes = min(self.max_fertile_nodes, + int(math.ceil(self.max_nodes / 2.0))) # split_after_samples and valid_leaf_threshold should be about the same. # Therefore, if either is set, use it to set the other. Otherwise, fall - # back on FLAGS.samples_to_decide. - samples_to_decide = ( - getattr(self, 'split_after_samples', - getattr(self, 'valid_leaf_threshold', FLAGS.samples_to_decide))) - self.split_after_samples = getattr(self, 'split_after_samples', - samples_to_decide) - self.valid_leaf_threshold = getattr(self, 'valid_leaf_threshold', - samples_to_decide) + # back on samples_to_decide. + samples_to_decide = self.split_after_samples or self.samples_to_decide + + self.split_after_samples = self.split_after_samples or samples_to_decide + self.valid_leaf_threshold = self.valid_leaf_threshold or samples_to_decide # We have num_splits_to_consider slots to fill, and we want to spend # approximately split_after_samples samples initializing them. @@ -184,23 +183,6 @@ class TreeStats(object): self.num_leaves = num_leaves -def get_tree_stats(variables, unused_params, session): - num_nodes = variables.end_of_tree.eval(session=session) - 1 - num_leaves = tf.where( - tf.equal(tf.squeeze(tf.slice(variables.tree, [0, 0], [-1, 1])), - LEAF_NODE)).eval(session=session).shape[0] - return TreeStats(num_nodes, num_leaves) - - -def get_forest_stats(variables, params, session): - - tree_stats = [] - for i in range(params.num_trees): - tree_stats.append(get_tree_stats(variables[i], params, session)) - - return ForestStats(tree_stats, params) - - class ForestTrainingVariables(object): """A container for a forests training data, consisting of multiple trees. @@ -212,9 +194,11 @@ class ForestTrainingVariables(object): ... forest_variables.tree ... """ - def __init__(self, params): - self.variables = [TreeTrainingVariables(params) - for _ in range(params.num_trees)] + def __init__(self, params, device_assigner): + self.variables = [] + for i in range(params.num_trees): + with tf.device(device_assigner.get_device(i)): + self.variables.append(TreeTrainingVariables(params)) def __setitem__(self, t, val): self.variables[t] = val @@ -223,15 +207,41 @@ class ForestTrainingVariables(object): return self.variables[t] +class RandomForestDeviceAssigner(object): + """A device assigner that uses the default device. + + Write subclasses that implement get_device for control over how trees + get assigned to devices. This assumes that whole trees are assigned + to a device. + """ + + def __init__(self): + self.cached = None + + def get_device(self, unused_tree_num): + if not self.cached: + dummy = tf.constant(0) + self.cached = dummy.device + + return self.cached + + class RandomForestGraphs(object): """Builds TF graphs for random forest training and inference.""" - def __init__(self, params): + def __init__(self, params, device_assigner=None, variables=None): self.params = params - self.variables = ForestTrainingVariables(self.params) - self.trees = [RandomTreeGraphs(self.variables[i], self.params, - training_ops.Load(), inference_ops.Load()) - for i in range(self.params.num_trees)] + self.device_assigner = device_assigner or RandomForestDeviceAssigner() + tf.logging.info('Constructing forest with params = ') + tf.logging.info(self.params.__dict__) + self.variables = variables or ForestTrainingVariables( + self.params, device_assigner=self.device_assigner) + self.trees = [ + RandomTreeGraphs( + self.variables[i], self.params, + training_ops.Load(self.params.training_library_base_dir), + inference_ops.Load(self.params.inference_library_base_dir)) + for i in range(self.params.num_trees)] def training_graph(self, input_data, input_labels): """Constructs a TF graph for training a random forest. @@ -246,12 +256,26 @@ class RandomForestGraphs(object): """ tree_graphs = [] for i in range(self.params.num_trees): - tf.logging.info('Constructing tree %d', i) - seed = self.params.base_random_seed - if seed != 0: - seed += i - tree_graphs.append(self.trees[i].training_graph( - input_data, input_labels, seed)) + with tf.device(self.device_assigner.get_device(i)): + seed = self.params.base_random_seed + if seed != 0: + seed += i + # If using bagging, randomly select some of the input. + tree_data = input_data + tree_labels = input_labels + if self.params.bagging_fraction < 1.0: + # TODO(thomaswc): This does sampling without replacment. Consider + # also allowing sampling with replacement as an option. + batch_size = tf.slice(tf.shape(input_data), [0], [1]) + r = tf.random_uniform(batch_size, seed=seed) + mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction) + gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1]) + # TODO(thomaswc): Calculate out-of-bag data and labels, and store + # them for use in calculating statistics later. + tree_data = tf.gather(input_data, gather_indices) + tree_labels = tf.gather(input_labels, gather_indices) + tree_graphs.append( + self.trees[i].training_graph(tree_data, tree_labels, seed)) return tf.group(*tree_graphs) def inference_graph(self, input_data): @@ -265,9 +289,23 @@ class RandomForestGraphs(object): """ probabilities = [] for i in range(self.params.num_trees): - probabilities.append(self.trees[i].inference_graph(input_data)) - all_predict = tf.pack(probabilities) - return tf.reduce_sum(all_predict, 0) / self.params.num_trees + with tf.device(self.device_assigner.get_device(i)): + probabilities.append(self.trees[i].inference_graph(input_data)) + with tf.device(self.device_assigner.get_device(0)): + all_predict = tf.pack(probabilities) + return tf.reduce_sum(all_predict, 0) / self.params.num_trees + + def average_size(self): + """Constructs a TF graph for evaluating the average size of a forest. + + Returns: + The average number of nodes over the trees. + """ + sizes = [] + for i in range(self.params.num_trees): + with tf.device(self.device_assigner.get_device(i)): + sizes.append(self.trees[i].size()) + return tf.reduce_mean(tf.pack(sizes)) def average_impurity(self): """Constructs a TF graph for evaluating the leaf impurity of a forest. @@ -277,9 +315,17 @@ class RandomForestGraphs(object): """ impurities = [] for i in range(self.params.num_trees): - impurities.append(self.trees[i].average_impurity(self.variables[i])) + with tf.device(self.device_assigner.get_device(i)): + impurities.append(self.trees[i].average_impurity()) return tf.reduce_mean(tf.pack(impurities)) + def get_stats(self, session): + tree_stats = [] + for i in range(self.params.num_trees): + with tf.device(self.device_assigner.get_device(i)): + tree_stats.append(self.trees[i].get_stats(session)) + return ForestStats(tree_stats, self.params) + class RandomTreeGraphs(object): """Builds TF graphs for random tree training and inference.""" @@ -394,6 +440,7 @@ class RandomTreeGraphs(object): with tf.control_dependencies([node_update_op]): def f1(): return self.variables.non_fertile_leaf_scores + def f2(): counts = tf.gather(self.variables.node_per_class_weights, self.variables.non_fertile_leaves) @@ -535,3 +582,18 @@ class RandomTreeGraphs(object): counts = tf.gather(self.variables.node_per_class_weights, leaves) impurity = self._weighted_gini(counts) return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0) + + def size(self): + """Constructs a TF graph for evaluating the current number of nodes. + + Returns: + The current number of nodes in the tree. + """ + return self.variables.end_of_tree - 1 + + def get_stats(self, session): + num_nodes = self.variables.end_of_tree.eval(session=session) - 1 + num_leaves = tf.where( + tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])), + LEAF_NODE)).eval(session=session).shape[0] + return TreeStats(num_nodes, num_leaves) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index e4846cb0479..a2cf187bdcb 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -27,6 +27,37 @@ from tensorflow.python.platform import googletest class TensorForestTest(test_util.TensorFlowTestCase): + def testForestHParams(self): + hparams = tensor_forest.ForestHParams( + num_classes=2, num_trees=100, max_nodes=1000, + num_features=60).fill() + self.assertEquals(2, hparams.num_classes) + # 2 * ceil(log_2(1000)) = 20 + self.assertEquals(20, hparams.max_depth) + # sqrt(num_features) < 10, so num_splits_to_consider should be 10. + self.assertEquals(10, hparams.num_splits_to_consider) + # Don't have more fertile nodes than max # leaves, which is 500. + self.assertEquals(500, hparams.max_fertile_nodes) + # We didn't set either of these, so they should be equal + self.assertEquals(hparams.split_after_samples, + hparams.valid_leaf_threshold) + # split_after_samples is larger than 10 + self.assertEquals(1, hparams.split_initializations_per_input) + self.assertEquals(0, hparams.base_random_seed) + + def testForestHParamsBigTree(self): + hparams = tensor_forest.ForestHParams( + num_classes=2, num_trees=100, max_nodes=1000000, + split_after_samples=25, + num_features=1000).fill() + self.assertEquals(40, hparams.max_depth) + # sqrt(1000) = 31.63... + self.assertEquals(32, hparams.num_splits_to_consider) + # 1000000 / 32 = 31250 + self.assertEquals(31250, hparams.max_fertile_nodes) + # floor(31.63 / 25) = 1 + self.assertEquals(1, hparams.split_initializations_per_input) + def testTrainingConstruction(self): input_data = [[-1., 0.], [-1., 2.], # node 1 [1., 0.], [1., -2.]] # node 2 @@ -50,6 +81,14 @@ class TensorForestTest(test_util.TensorFlowTestCase): graph = graph_builder.inference_graph(input_data) self.assertTrue(isinstance(graph, tf.Tensor)) + def testImpurityConstruction(self): + params = tensor_forest.ForestHParams( + num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill() + + graph_builder = tensor_forest.RandomForestGraphs(params) + graph = graph_builder.average_impurity() + self.assertTrue(isinstance(graph, tf.Tensor)) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 41b5e2d723d..ee0a5caaecb 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -143,7 +143,6 @@ cc_library( "lib/core/bits.h", "lib/core/casts.h", "lib/core/coding.h", - "lib/core/command_line_flags.h", # TODO(vrv): Delete. "lib/core/errors.h", "lib/core/notification.h", "lib/core/status.h", diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index f1df7e9debd..c928eccec34 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -713,4 +714,36 @@ void Tensor::FillDescription(TensorDescription* description) const { } } +gtl::InlinedVector Tensor::ComputeFlatInnerDims( + int64 num_out_dims) const { + gtl::InlinedVector out_dims(num_out_dims, 0); + const int64 num_elements = NumElements(); + if (num_elements != 0) { + int64 prod_out_dims = 1; + for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) { + const int64 in_dim = out_dim + (dims() - num_out_dims); + out_dims[out_dim] = + (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim); + prod_out_dims *= out_dims[out_dim]; + } + out_dims[0] = num_elements / prod_out_dims; + } + return out_dims; +} + +gtl::InlinedVector Tensor::ComputeFlatOuterDims( + int64 num_out_dims) const { + gtl::InlinedVector out_dims(num_out_dims, 0); + const int64 num_elements = NumElements(); + if (num_elements != 0) { + int64 prod_out_dims = 1; + for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) { + out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim); + prod_out_dims *= out_dims[out_dim]; + } + out_dims[num_out_dims - 1] = num_elements / prod_out_dims; + } + return out_dims; +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 708c98f409c..5abc9c9f526 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -243,40 +244,28 @@ class Tensor { /// /// ``` template - typename TTypes::Flat flat(); + typename TTypes::Flat flat() { + return shaped({NumElements()}); + } template typename TTypes::UnalignedFlat unaligned_flat() { return unaligned_shaped({NumElements()}); } - /// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all - /// Tensor dimensions but the last one into the first dimension of the result. - template - typename TTypes::Matrix flat_inner_dims() { - int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1; - if (last_size == 0) { - DCHECK_EQ(NumElements(), 0); - // Return something empty, avoiding divide by 0 - return shaped({0, 0}); - } else { - return shaped({NumElements() / last_size, last_size}); - } - } + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all + /// Tensor dimensions but the last NDIMS-1 into the first dimension of the + /// result. If NDIMS > dims() then leading dimensions of size 1 will be + /// added to make the output rank NDIMS. + template + typename TTypes::Tensor flat_inner_dims(); - /// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all - /// Tensor dimensions but the first one into the last dimension of the result. - template - typename TTypes::Matrix flat_outer_dims() { - int64 first_size = dims() > 0 ? dim_size(0) : 1; - if (first_size == 0) { - DCHECK_EQ(NumElements(), 0); - // Return something empty, avoiding divide by 0 - return shaped({0, 0}); - } else { - return shaped({first_size, NumElements() / first_size}); - } - } + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all + /// Tensor dimensions but the first NDIMS-1 into the last dimension of the + /// result. If NDIMS > dims() then trailing dimensions of size 1 will be + /// added to make the output rank NDIMS. + template + typename TTypes::Tensor flat_outer_dims(); template typename TTypes::Tensor shaped(gtl::ArraySlice new_sizes); @@ -308,31 +297,19 @@ class Tensor { typename TTypes::ConstTensor tensor() const; template - typename TTypes::ConstFlat flat() const; + typename TTypes::ConstFlat flat() const { + return shaped({NumElements()}); + } template typename TTypes::UnalignedConstFlat unaligned_flat() const { return unaligned_shaped({NumElements()}); } - template - typename TTypes::ConstMatrix flat_inner_dims() const { - int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1; - if (last_size == 0) { - DCHECK_EQ(NumElements(), 0); - // Return something empty, avoiding divide by 0 - return shaped({0, 0}); - } else { - return shaped({NumElements() / last_size, last_size}); - } - } - - template - typename TTypes::ConstMatrix flat_outer_dims() const; - template typename TTypes::ConstTensor shaped( gtl::ArraySlice new_sizes) const; + template typename TTypes::UnalignedConstTensor unaligned_shaped( gtl::ArraySlice new_sizes) const; @@ -340,6 +317,12 @@ class Tensor { template typename TTypes::ConstScalar scalar() const; + template + typename TTypes::ConstTensor flat_inner_dims() const; + + template + typename TTypes::ConstTensor flat_outer_dims() const; + /// Render the first `max_entries` values in `*this` into a string. string SummarizeValue(int64 max_entries) const; @@ -378,6 +361,8 @@ class Tensor { void FillDimsAndValidateCompatibleShape( gtl::ArraySlice new_sizes, Eigen::array* dims) const; + gtl::InlinedVector ComputeFlatInnerDims(int64 num_out_dims) const; + gtl::InlinedVector ComputeFlatOuterDims(int64 num_out_dims) const; TensorShape shape_; TensorBuffer* buf_; @@ -534,26 +519,24 @@ typename TTypes::ConstScalar Tensor::scalar() const { return typename TTypes::ConstScalar(base()); } -template -typename TTypes::Flat Tensor::flat() { - return shaped({NumElements()}); +template +typename TTypes::Tensor Tensor::flat_inner_dims() { + return shaped(ComputeFlatInnerDims(NDIMS)); } -template -typename TTypes::ConstFlat Tensor::flat() const { - return shaped({NumElements()}); +template +typename TTypes::Tensor Tensor::flat_outer_dims() { + return shaped(ComputeFlatOuterDims(NDIMS)); } -template -typename TTypes::ConstMatrix Tensor::flat_outer_dims() const { - int64 first_size = dims() > 0 ? dim_size(0) : 1; - if (first_size == 0) { - DCHECK_EQ(NumElements(), 0); - // Return something empty, avoiding divide by 0 - return shaped({0, 0}); - } else { - return shaped({first_size, NumElements() / first_size}); - } +template +typename TTypes::ConstTensor Tensor::flat_inner_dims() const { + return shaped(ComputeFlatInnerDims(NDIMS)); +} + +template +typename TTypes::ConstTensor Tensor::flat_outer_dims() const { + return shaped(ComputeFlatOuterDims(NDIMS)); } } // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 13896f9177d..ecc04671038 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -224,6 +224,49 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat_inner_dims(0, 0), 0.01f); EXPECT_EQ(flat_inner_dims(23, 4), 0.02f); } + { + auto flat_outer_dims = t.flat_outer_dims(); + EXPECT_EQ(2, flat_outer_dims.dimension(0)); + EXPECT_EQ(60, flat_outer_dims.dimension(1)); + EXPECT_EQ(flat_outer_dims(0, 0), 0.01f); + EXPECT_EQ(flat_outer_dims(1, 59), 0.02f); + } + { + auto flat_inner_dims = t.flat_inner_dims(); + EXPECT_EQ(6, flat_inner_dims.dimension(0)); + EXPECT_EQ(4, flat_inner_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_dims.dimension(2)); + EXPECT_EQ(flat_inner_dims(0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_dims(5, 3, 4), 0.02f); + } + { + auto flat_outer_dims = t.flat_outer_dims(); + EXPECT_EQ(2, flat_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_outer_dims.dimension(1)); + EXPECT_EQ(20, flat_outer_dims.dimension(2)); + EXPECT_EQ(flat_outer_dims(0, 0, 0), 0.01f); + EXPECT_EQ(flat_outer_dims(1, 2, 19), 0.02f); + } + { + auto flat_inner_dims = t.flat_inner_dims(); + EXPECT_EQ(1, flat_inner_dims.dimension(0)); + EXPECT_EQ(2, flat_inner_dims.dimension(1)); + EXPECT_EQ(3, flat_inner_dims.dimension(2)); + EXPECT_EQ(4, flat_inner_dims.dimension(3)); + EXPECT_EQ(5, flat_inner_dims.dimension(4)); + EXPECT_EQ(flat_inner_dims(0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_dims(0, 1, 2, 3, 4), 0.02f); + } + { + auto flat_outer_dims = t.flat_outer_dims(); + EXPECT_EQ(2, flat_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_outer_dims.dimension(1)); + EXPECT_EQ(4, flat_outer_dims.dimension(2)); + EXPECT_EQ(5, flat_outer_dims.dimension(3)); + EXPECT_EQ(1, flat_outer_dims.dimension(4)); + EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f); + } } TEST(Tensor_Scalar, Basics) { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index d75db3f381d..751c5afb7b5 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -305,6 +305,23 @@ tf_kernel_libraries( ], ) +tf_cc_test( + name = "batch_norm_op_test", + size = "small", + deps = [ + ":batch_norm_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "concat_op_test", size = "small", diff --git a/tensorflow/core/kernels/batch_norm_op_test.cc b/tensorflow/core/kernels/batch_norm_op_test.cc new file mode 100644 index 00000000000..e70bcc5b4c6 --- /dev/null +++ b/tensorflow/core/kernels/batch_norm_op_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class BatchNormOpTest : public OpsTestBase {}; + +TEST_F(BatchNormOpTest, Simple) { + TF_EXPECT_OK( + NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("scale_after_normalization", false) + .Attr("variance_epsilon", 0.001) + .Finalize(node_def())); + TF_EXPECT_OK(InitOpWithGraphVersion(8)); + AddInputFromArray(TensorShape({1, 1, 6, 2}), + {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); + AddInputFromArray(TensorShape({2}), {10, 20}); + AddInputFromArray(TensorShape({2}), {0.25, 0.5}); + AddInputFromArray(TensorShape({2}), {0.1, 0.6}); + AddInputFromArray(TensorShape({2}), {0.0, 0.0}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues( + &expected, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86, + -33.31, -23.85, -34.72, -25.85, -36.13}); + test::ExpectTensorNear(expected, *GetOutput(0), 0.01); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h index d2a9bf293b2..521168cb17d 100644 --- a/tensorflow/core/kernels/ops_testutil.h +++ b/tensorflow/core/kernels/ops_testutil.h @@ -94,10 +94,13 @@ class OpsTestBase : public ::testing::Test { // and output types as output. // // Returns the status of initialization. - Status InitOp() { + Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); } + + // Only use this directly if you have a deprecated op that you need to test. + Status InitOpWithGraphVersion(int graph_def_version) { Status status; kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(), - node_def_, TF_GRAPH_DEF_VERSION, &status); + node_def_, graph_def_version, &status); if (kernel_ != nullptr) input_types_ = kernel_->input_types(); return status; } diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h index 52120d6c772..ac6c0339da5 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h @@ -66,7 +66,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) { #define MAYBE_CONJ(T) \ template <> \ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) { \ - return std::conj(v); \ + return Eigen::numext::conj(v); \ } #endif diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc deleted file mode 100644 index 757b63b6694..00000000000 --- a/tensorflow/core/lib/core/command_line_flags.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* Copyright 2015 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/lib/core/command_line_flags.h" - -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" - -namespace tensorflow { -namespace { - -// Templated function to convert a string to target values. -// Return true if the conversion is successful. Otherwise, return false. -template -bool StringToValue(const string& content, T* value); - -template <> -bool StringToValue(const string& content, int32* value) { - return strings::safe_strto32(content, value); -} - -template <> -bool StringToValue(const string& content, string* value) { - *value = content; - return true; -} - -// Parse a single argument by linearly searching through the command table. -// The input format is: --argument=value. -// Return OK if the argument is used. It store the extracted value into the -// matching flag. -// Return NOT_FOUND if the argument is not recognized. -// Return INVALID_ARGUMENT if the command is recognized, but fails to extract -// its value. -template -Status ParseArgument(const string& argument) { - for (auto& command : - internal::CommandLineFlagRegistry::Instance()->commands) { - string prefix = strings::StrCat("--", command.name, "="); - if (tensorflow::StringPiece(argument).starts_with(prefix)) { - string content = argument.substr(prefix.length()); - if (StringToValue(content, command.value)) { - return Status::OK(); - } - return Status(error::INVALID_ARGUMENT, - strings::StrCat("Cannot parse integer in: ", argument)); - } - } - return Status(error::NOT_FOUND, - strings::StrCat("Unknown command: ", argument)); -} - -// A specialization for booleans. The input format is: -// "--argument" or "--noargument". -// Parse a single argument by linearly searching through the command table. -// Return OK if the argument is used. The value is stored in the matching flag. -// Return NOT_FOUND if the argument is not recognized. -template <> -Status ParseArgument(const string& argument) { - for (auto& command : - internal::CommandLineFlagRegistry::Instance()->commands) { - if (argument == strings::StrCat("--", command.name)) { - *command.value = true; - return Status::OK(); - } else if (argument == strings::StrCat("--no", command.name)) { - *command.value = false; - return Status::OK(); - } - } - return Status(error::NOT_FOUND, - strings::StrCat("Unknown command: ", argument)); -} - -} // namespace - -Status ParseCommandLineFlags(int* argc, char* argv[]) { - int unused_argc = 1; - for (int index = 1; index < *argc; ++index) { - Status s; - // Search bool commands. - s = ParseArgument(argv[index]); - if (s.ok()) { - continue; - } - if (s.code() != error::NOT_FOUND) { - return s; - } - // Search int32 commands. - s = ParseArgument(argv[index]); - if (s.ok()) { - continue; - } - // Search string commands. - s = ParseArgument(argv[index]); - if (s.ok()) { - continue; - } - if (s.code() != error::NOT_FOUND) { - return s; - } - // Pointer swap the unused argument to the front. - std::swap(argv[unused_argc++], argv[index]); - } - *argc = unused_argc; - return Status::OK(); -} - -} // namespace tensorflow diff --git a/tensorflow/core/lib/core/command_line_flags.h b/tensorflow/core/lib/core/command_line_flags.h deleted file mode 100644 index d6f6f795145..00000000000 --- a/tensorflow/core/lib/core/command_line_flags.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2015 Google Inc. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ -#define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ - -#include -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace internal { - -template -struct CommandLineFlagRegistry { - static CommandLineFlagRegistry* Instance() { - static CommandLineFlagRegistry instance_; - return &instance_; - } - struct Command { - string name; - T* value; - string text; - }; - std::vector commands; - - private: - CommandLineFlagRegistry() {} - TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry); -}; - -template -struct CommandLineFlagRegister { - CommandLineFlagRegister(const string& name, T* val, const string& text) { - CommandLineFlagRegistry::Instance()->commands.push_back( - {name, val, text}); - } -}; - -#define TF_DEFINE_variable(type, name, default_value, text) \ - type FLAGS_##name = default_value; \ - namespace TF_flags_internal { \ - tensorflow::internal::CommandLineFlagRegister \ - TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \ - } // namespace TF_flags_internal - -} // namespace internal - -#define TF_DEFINE_int32(name, default_value, text) \ - TF_DEFINE_variable(tensorflow::int32, name, default_value, text); - -#define TF_DEFINE_bool(name, default_value, text) \ - TF_DEFINE_variable(bool, name, default_value, text); - -#define TF_DEFINE_string(name, default_value, text) \ - TF_DEFINE_variable(string, name, default_value, text); - -// Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv. -// Returned the number of unused arguments in *argc. -// Return error Status if the parsing encounters errors. -// TODO(opensource): switch to a command line argument parser that can be -// shared with other tests. -Status ParseCommandLineFlags(int* argc, char* argv[]); - -} // namespace tensorflow - -#endif // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index 172c4785007..da2d66d9dd8 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -120,9 +120,12 @@ variable to its initial value. ##### Args: -* `initial_value`: A `Tensor`, or Python object convertible to a `Tensor`. - The initial value for the Variable. Must have a shape specified unless - `validate_shape` is set to False. +* `initial_value`: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. In + that case, `dtype` must be specified. (Note that initializer functions + from init_ops.py must first be bound to a shape before being used here.) * `trainable`: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index 24b9f9f8142..55463482034 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -1955,6 +1955,7 @@ on the parameters to the constructor and may include: ##### Raises: +* `RuntimeError`: If called with a non-chief Supervisor. * `ValueError`: If not `logdir` was passed to the constructor as the services need a log directory. @@ -2182,6 +2183,7 @@ on the parameters to the constructor and may include: ##### Raises: +* `RuntimeError`: If called with a non-chief Supervisor. * `ValueError`: If not `logdir` was passed to the constructor as the services need a log directory. @@ -2409,7 +2411,7 @@ Start threads for `QueueRunners`. #### `tf.train.Supervisor.summary_op` {#Supervisor.summary_op} -Return the Summary Tensor used by the supervisor. +Return the Summary Tensor used by the chief supervisor. ##### Returns: @@ -2420,7 +2422,7 @@ Return the Summary Tensor used by the supervisor. #### `tf.train.Supervisor.summary_writer` {#Supervisor.summary_writer} -Return the SummaryWriter used by the supervisor. +Return the SummaryWriter used by the chief supervisor. ##### Returns: diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index 5261af4aabf..47d4af075ee 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import numpy as np import tensorflow as tf @@ -208,5 +209,59 @@ class RNNCellTest(tf.test.TestCase): 0.13248, 0.13248]]) +class SlimRNNCellTest(tf.test.TestCase): + + def testBasicRNNCell(self): + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([1, 2]) + m = tf.zeros([1, 2]) + my_cell = functools.partial(basic_rnn_cell, num_units=2) + g, _ = tf.nn.rnn_cell.SlimRNNCell(my_cell)(x, m) + sess.run([tf.initialize_all_variables()]) + res = sess.run([g], {x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]])}) + self.assertEqual(res[0].shape, (1, 2)) + + def testBasicRNNCellMatch(self): + batch_size = 32 + input_size = 100 + num_units = 10 + with self.test_session() as sess: + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + inputs = tf.random_uniform((batch_size, input_size)) + _, initial_state = basic_rnn_cell(inputs, None, num_units) + my_cell = functools.partial(basic_rnn_cell, num_units=num_units) + slim_cell = tf.nn.rnn_cell.SlimRNNCell(my_cell) + slim_outputs, slim_state = slim_cell(inputs, initial_state) + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units) + outputs, state = rnn_cell(inputs, initial_state) + self.assertEqual(slim_outputs.get_shape(), outputs.get_shape()) + self.assertEqual(slim_state.get_shape(), state.get_shape()) + sess.run([tf.initialize_all_variables()]) + res = sess.run([slim_outputs, slim_state, outputs, state]) + self.assertAllClose(res[0], res[2]) + self.assertAllClose(res[1], res[3]) + + +def basic_rnn_cell(inputs, state, num_units, scope=None): + if state is None: + if inputs is not None: + batch_size = inputs.get_shape()[0] + dtype = inputs.dtype + else: + batch_size = 0 + dtype = tf.float32 + init_output = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype) + init_state = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype) + init_output.set_shape([batch_size, num_units]) + init_state.set_shape([batch_size, num_units]) + return init_output, init_state + else: + with tf.variable_op_scope([inputs, state], scope, "BasicRNNCell"): + output = tf.tanh(tf.nn.rnn_cell.linear([inputs, state], + num_units, True)) + return output, output + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 593cd5f25a0..f2325779c10 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -302,6 +302,53 @@ class VariablesTestCase(tf.test.TestCase): self.assertEqual(var.op.device, init_op.device) sess.run(init_op) + def testInitializerFunction(self): + value = [[-42], [133.7]] + shape = [2, 1] + with self.test_session(): + initializer = lambda: tf.constant(value) + with self.assertRaises(ValueError): + # Checks that dtype must be specified. + tf.Variable(initializer) + + v1 = tf.Variable(initializer, dtype=tf.float32) + self.assertEqual(shape, v1.get_shape()) + self.assertAllClose(value, v1.initial_value.eval()) + with self.assertRaises(tf.errors.FailedPreconditionError): + v1.eval() + + v2 = tf.Variable(tf.neg(v1.initialized_value()), dtype=tf.float32) + self.assertEqual(v1.get_shape(), v2.get_shape()) + self.assertAllClose(np.negative(value), v2.initial_value.eval()) + + # Once v2.initial_value.eval() has been called, v1 has effectively been + # initialized. + self.assertAllClose(value, v1.eval()) + + with self.assertRaises(tf.errors.FailedPreconditionError): + v2.eval() + tf.initialize_all_variables().run() + self.assertAllClose(np.negative(value), v2.eval()) + + def testInitializerFunctionDevicePlacement(self): + with self.test_session(): + initializer = lambda: tf.constant(42.0) + with tf.device("/cpu:100"): + v1 = tf.Variable(initializer, dtype=tf.float32, name="v1") + expected_device = "/device:CPU:100" + expected_group_v1 = [b"loc:@v1"] + self.assertEqual(expected_device, v1.op.device) + self.assertEqual(expected_group_v1, v1.op.colocation_groups()) + for i in v1.initializer.inputs: + self.assertEqual(expected_device, i.op.device) + self.assertEqual(expected_group_v1, i.op.colocation_groups()) + + v2 = tf.Variable(initializer, dtype=tf.float32, name="v2") + expected_group_v2 = [b"loc:@v2"] + self.assertEqual(expected_group_v2, v2.op.colocation_groups()) + for i in v2.initializer.inputs: + self.assertEqual(expected_group_v2, i.op.colocation_groups()) + class IsInitializedTest(tf.test.TestCase): diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py index 9d4d19668af..c16ba0f814a 100644 --- a/tensorflow/python/ops/partitioned_variables.py +++ b/tensorflow/python/ops/partitioned_variables.py @@ -167,19 +167,22 @@ def create_partitioned_variables( slice_offset[slice_dim] += var_shape[slice_dim] if callable(initializer): - init_val = initializer(var_shape, dtype=dtype) - init_val = ops.convert_to_tensor(init_val, dtype=dtype) + init = initializer + init_shape = var_shape elif isinstance(initializer, ops.Tensor): - init_val = array_ops.slice(initializer, var_offset, var_shape) + init = array_ops.slice(initializer, var_offset, var_shape) # Use the dtype of the given tensor. - dtype = init_val.dtype.base_dtype + dtype = init.dtype.base_dtype + init_shape = None else: - init_val = ops.convert_to_tensor(initializer, dtype=dtype) - init_val = array_ops.slice(init_val, var_offset, var_shape) + init = ops.convert_to_tensor(initializer, dtype=dtype) + init = array_ops.slice(init, var_offset, var_shape) + init_shape = None var = variable_scope.get_variable(name="part_%d" % i, + shape=init_shape, dtype=dtype, - initializer=init_val, + initializer=init, trainable=trainable, collections=collections) diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py index 6650d3b53b8..9aa2314e5e6 100644 --- a/tensorflow/python/ops/rnn_cell.py +++ b/tensorflow/python/ops/rnn_cell.py @@ -661,6 +661,42 @@ class MultiRNNCell(RNNCell): return cur_inp, array_ops.concat(1, new_states) +class SlimRNNCell(RNNCell): + """A simple wrapper for slim.rnn_cells.""" + + def __init__(self, cell_fn): + """Create a SlimRNNCell from a cell_fn. + + Args: + cell_fn: a function which takes (inputs, state, scope) and produces the + outputs and the new_state. Additionally when called with inputs=None and + state=None it should return (initial_outputs, initial_state). + + Raises: + TypeError: if cell_fn is not callable + ValueError: if cell_fn cannot produce a valid initial state. + """ + if not callable(cell_fn): + raise TypeError("cell_fn %s needs to be callable", cell_fn) + self._cell_fn = cell_fn + self._cell_name = cell_fn.func.__name__ + _, init_state = self._cell_fn(None, None) + state_shape = init_state.get_shape() + self._state_size = state_shape.with_rank(2)[1].value + if self._state_size is None: + raise ValueError("Initial state created by %s has invalid shape %s", + self._cell_name, state_shape) + + @property + def state_size(self): + return self._state_size + + def __call__(self, inputs, state, scope=None): + scope = scope or self._cell_name + output, state = self._cell_fn(inputs, state, scope=scope) + return output, state + + def linear(args, output_size, bias, bias_start=0.0, scope=None): """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 64ad23674bd..f60816e0225 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -144,14 +144,19 @@ class _VariableStore(object): with ops.control_dependencies(None): if initializing_from_value: init_val = initializer + variable_dtype = None else: - with ops.name_scope(name + "/Initializer/"): - init_val = initializer(shape.as_list(), dtype=dtype) + init_val = lambda: initializer(shape.as_list(), dtype=dtype) + variable_dtype = dtype.base_dtype # Create the variable. - v = variables.Variable(init_val, name=name, trainable=trainable, + v = variables.Variable(initial_value=init_val, + name=name, + trainable=trainable, collections=collections, - caching_device=caching_device) + caching_device=caching_device, + dtype=variable_dtype) + self._vars[name] = v logging.info("Created variable %s with shape %s and init %s", v.name, format(shape), initializer) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index a21724194b0..3f2571bc065 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -156,9 +156,12 @@ class Variable(object): variable to its initial value. Args: - initial_value: A `Tensor`, or Python object convertible to a `Tensor`. - The initial value for the Variable. Must have a shape specified unless - `validate_shape` is set to False. + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. In + that case, `dtype` must be specified. (Note that initializer functions + from init_ops.py must first be bound to a shape before being used here.) trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. @@ -211,9 +214,12 @@ class Variable(object): """Creates a new variable from arguments. Args: - initial_value: A `Tensor`, or Python object convertible to a `Tensor`. - The initial value for the Variable. Must have a shape specified unless - `validate_shape` is set to False. + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. In + that case, `dtype` must be specified. (Note that initializer functions + from init_ops.py must first be bound to a shape before being used here.) trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. @@ -240,25 +246,62 @@ class Variable(object): """ if initial_value is None: raise ValueError("initial_value must be specified.") + init_from_fn = callable(initial_value) + if init_from_fn and dtype is None: + raise ValueError( + "dtype must also be specified when initial_value is callable.") + if collections is None: collections = [ops.GraphKeys.VARIABLES] if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] with ops.control_dependencies(None): - with ops.op_scope([initial_value], name, "Variable") as name: - self._initial_value = ops.convert_to_tensor(initial_value, - name="initial_value", - dtype=dtype) - initial_value_shape = self._initial_value.get_shape() - if validate_shape and not initial_value_shape.is_fully_defined(): - raise ValueError("initial_value must have a shape specified: %s" - % self._initial_value) - shape_to_set = initial_value_shape if validate_shape else [] + with ops.op_scope( + [] if init_from_fn else [initial_value], name, "Variable") as name: - self._variable = state_ops.variable_op( - shape_to_set, self._initial_value.dtype.base_dtype, - set_shape=validate_shape, name=name) + # Get the initial value from a callable function. The real shape of the + # variable will be set later, since under the init_from_fn case, the + # shape won't be known until after the function is invoked. + if init_from_fn: + self._variable = state_ops.variable_op( + [], + dtype.base_dtype, + set_shape=False, + name=name) + with ops.colocate_with(self._variable.op): + with ops.name_scope("Initializer"): + # Colocate the tensors created by the initial_value() function + # with the variable itself. + self._initial_value = ops.convert_to_tensor(initial_value(), + name="initial_value", + dtype=dtype) + # Or get the initial value from a Tensor or Python object. + else: + self._initial_value = ops.convert_to_tensor(initial_value, + name="initial_value", + dtype=dtype) + # In this case, the variable op can't be created until after the + # initial_value has been converted to a Tensor with a known type. + self._variable = state_ops.variable_op( + [], + self._initial_value.dtype.base_dtype, + set_shape=False, + name=name) + + # Manually overrides the variable's shape with the initial value's. + if validate_shape: + initial_value_shape = self._initial_value.get_shape() + if not initial_value_shape.is_fully_defined(): + raise ValueError("initial_value must have a shape specified: %s" + % self._initial_value) + self._variable.set_shape(initial_value_shape) + # TODO(b/28152992): Remove the below hack modifying the node_def shape + # directly once set_shape() handles it. + self._variable.op.node_def.attr["shape"].shape.CopyFrom( + initial_value_shape.as_proto()) + + # Assigns initial value. with ops.colocate_with(self._variable.op): self._initializer_op = state_ops.assign( self._variable, self._initial_value, diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py index 2e8de06b919..0188dc6b39b 100644 --- a/tensorflow/python/platform/resource_loader.py +++ b/tensorflow/python/platform/resource_loader.py @@ -79,3 +79,7 @@ def get_path_to_datafile(path): """ data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1))) return os.path.join(data_files_path, path) + +def readahead_file_path(path, unused_readahead=None): + """Readahead files not implemented; simply returns given path.""" + return path diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py index 7c7f6ca70dd..150d31edb7c 100644 --- a/tensorflow/python/summary/impl/event_file_loader.py +++ b/tensorflow/python/summary/impl/event_file_loader.py @@ -22,6 +22,7 @@ from tensorflow.core.util import event_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.platform import app from tensorflow.python.platform import logging +from tensorflow.python.platform import resource_loader from tensorflow.python.util import compat @@ -31,6 +32,7 @@ class EventFileLoader(object): def __init__(self, file_path): if file_path is None: raise ValueError('A file path is required') + file_path = resource_loader.readahead_file_path(file_path) logging.debug('Opening a record reader pointing at %s', file_path) self._reader = pywrap_tensorflow.PyRecordReader_New( compat.as_bytes(file_path), 0) diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index a434d675407..94ec9b92f0f 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -326,19 +326,29 @@ class Supervisor(object): self._init_global_step(global_step=global_step) self._graph = graph self._is_chief = is_chief - self._logdir = logdir - self._save_summaries_secs = save_summaries_secs - self._save_model_secs = save_model_secs - self._recovery_wait_secs = recovery_wait_secs self._coord = coordinator.Coordinator() - if logdir: + self._started_threads = [] + self._recovery_wait_secs = recovery_wait_secs + + # Only chief supervisors write event files, so only chief supervisors + # should have event-writing properties. Set to None for non-chiefs. + if self._is_chief: + self._logdir = logdir + self._save_summaries_secs = save_summaries_secs + self._save_model_secs = save_model_secs + else: + self._logdir = None + self._save_summaries_secs = None + self._save_model_secs = None + + if self._is_chief and self._logdir: self._save_path = os.path.join(self._logdir, checkpoint_basename) self._summary_writer = summary_io.SummaryWriter(self._logdir) else: self._save_path = None self._summary_writer = None + self._init_session_manager(session_manager=session_manager) - self._started_threads = [] self._verify_setup() # The graph is not allowed to change anymore. graph.finalize() @@ -520,7 +530,7 @@ class Supervisor(object): @property def summary_writer(self): - """Return the SummaryWriter used by the supervisor. + """Return the SummaryWriter used by the chief supervisor. Returns: A SummaryWriter. @@ -529,7 +539,7 @@ class Supervisor(object): @property def summary_op(self): - """Return the Summary Tensor used by the supervisor. + """Return the Summary Tensor used by the chief supervisor. Returns: A string Tensor for the summary or `None`. @@ -583,8 +593,7 @@ class Supervisor(object): def _write_graph(self): """Writes graph_def to `logdir` and adds it to summary if applicable.""" - if not self._is_chief: - return + assert self._is_chief if self._logdir: training_util.write_graph(self._graph.as_graph_def(), self._logdir, "graph.pbtxt") @@ -610,11 +619,13 @@ class Supervisor(object): sv.coord.Join() Raises: + RuntimeError: If called with a non-chief Supervisor. ValueError: If not `logdir` was passed to the constructor as the services need a log directory. """ if not self._is_chief: - return + raise RuntimeError("Only chief supervisor can start standard services. " + "Because only cheif supervisors can write events.") if not self._logdir: logging.warning("Standard services need a 'logdir' " "passed to the SessionManager") @@ -812,14 +823,18 @@ class Supervisor(object): TypeError: if 'summary' is not a Summary proto or a string. RuntimeError: if the Supervisor was created without a `logdir`. """ - if not self._logdir: - raise RuntimeError("summary_computed() requires a logdir") + if not self._summary_writer: + raise RuntimeError("Writing a summary requires a summary writer.") if global_step is None and self.global_step is not None: global_step = training_util.global_step(sess, self.global_step) - if self._summary_writer: - self._summary_writer.add_summary(summary, global_step) + self._summary_writer.add_summary(summary, global_step) def _default_global_step_tensor(self): + """Returns the global_step from the default graph. + + Returns: + The global step `Tensor` or `None`. + """ try: gs = ops.get_default_graph().get_tensor_by_name("global_step:0") if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]: diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index 01d4a74e18c..4bd73e4fc3e 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -73,12 +73,11 @@ class SupervisorTest(tf.test.TestCase): sess.close() sv.stop() - def testSummary(self): + def testChiefCanWriteEvents(self): logdir = self._TestDir("basics") with tf.Graph().as_default(): - const = tf.constant([1.0, 2.0, 3.0]) - summ = tf.scalar_summary(["c1", "c2", "c3"], const) - sv = tf.train.Supervisor(logdir=logdir, summary_op=None) + summ = tf.scalar_summary(["c1", "c2", "c3"], tf.constant([1.0, 2.0, 3.0])) + sv = tf.train.Supervisor(is_chief=True, logdir=logdir, summary_op=None) sess = sv.prepare_or_wait_for_session("") sv.summary_computed(sess, sess.run(summ)) sess.close() @@ -113,13 +112,31 @@ class SupervisorTest(tf.test.TestCase): # We should be done. self.assertRaises(StopIteration, lambda: next(rr)) + def testNonChiefCannotWriteEvents(self): + + def _summary_computed(): + with tf.Graph().as_default(): + sv = tf.train.Supervisor(is_chief=False) + sess = sv.prepare_or_wait_for_session("") + summ = tf.scalar_summary(["c1", "c2"], tf.constant([1.0, 2.0])) + sv.summary_computed(sess, sess.run(summ)) + + def _start_standard_services(): + with tf.Graph().as_default(): + sv = tf.train.Supervisor(is_chief=False) + sess = sv.prepare_or_wait_for_session("") + sv.start_standard_services(sess) + + self.assertRaises(RuntimeError, _summary_computed) + self.assertRaises(RuntimeError, _start_standard_services) + def testNoLogdirButWantSummary(self): with tf.Graph().as_default(): const = tf.constant([1.0, 2.0, 3.0]) summ = tf.scalar_summary(["c1", "c2", "c3"], const) sv = tf.train.Supervisor(logdir="", summary_op=None) sess = sv.prepare_or_wait_for_session("") - with self.assertRaisesRegexp(RuntimeError, "requires a logdir"): + with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"): sv.summary_computed(sess, sess.run(summ)) def testNoLogdirSucceeds(self): diff --git a/tensorflow/tools/benchmark/BUILD b/tensorflow/tools/benchmark/BUILD new file mode 100644 index 00000000000..364197bebdb --- /dev/null +++ b/tensorflow/tools/benchmark/BUILD @@ -0,0 +1,66 @@ +# Description: +# Benchmark utility that can run on desktop and Android. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_copts") + +exports_files(["LICENSE"]) + +cc_library( + name = "benchmark_model_lib", + srcs = [ + "benchmark_model.cc", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], + }), +) + +# This binary may be built for either desktop or Android. +# A typical Android build command will look like the following: +# bazel build -c opt tensorflow/core:android_tensorflow_lib \ +# --crosstool_top=//external:android/crosstool \ +# --cpu=armeabi-v7a \ +# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain +# +# NOTE: currently '-pthread' must be removed from the LINK_OPTS variable +# in google/protobuf/BUILD to sucessfully build for Android. This is temporary +# pending an update of the version of the protobuf library that Tensorflow +# uses. +cc_binary( + name = "benchmark_model", + copts = tf_copts(), + linkopts = select({ + "//tensorflow:android": [ + "-pie", + "-s", + "-landroid", + "-ljnigraphics", + "-llog", + "-lm", + "-z defs", + "-s", + "-Wl,--icf=all", # Identical Code Folding + "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export + ], + "//conditions:default": [], + }), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [":benchmark_model_lib"], +) diff --git a/tensorflow/tools/benchmark/README.md b/tensorflow/tools/benchmark/README.md new file mode 100644 index 00000000000..bcfed4ff142 --- /dev/null +++ b/tensorflow/tools/benchmark/README.md @@ -0,0 +1,57 @@ +# Tensorflow Model Benchmark Tool + +## Description + +A simple C++ binary to benchmark a compute graph and its individual operators, +both on desktop machines and on Android. + +## To build/install/run + +### On Android: + +(1) build for your specific platform, e.g.: +```bash +$bazel build -c opt \ + --crosstool_top=//external:android/crosstool \ + --cpu=armeabi-v7a \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + tensorflow/tools/benchmark:benchmark_model +``` + +(2) Connect your phone. Push the binary to your phone with adb push + (make the directory if required): +```bash +$adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp +``` + +(3) Push the compute graph that you need to test. For example: + adb push tensorflow_inception_graph.pb /data/local/tmp + +(4) Run the benchmark. For example: +```bash +$adb shell "/data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/tensorflow_inception_graph.pb \ + --input_layer="input:0" \ + --input_layer_shape="1,224,224,3" \ + --input_layer_type="float" \ + --output_layer="output:0" +``` +### On desktop: +(1) build the binary +```bash +$bazel build -c opt tensorflow/tools/benchmark:benchmark_model +``` + +(2) Run on your compute graph, similar to the Android case but without the need of adb shell. +For example: +```bash +$bazel-bin/tensorflow/tools/benchmark/benchmark_model \ + --graph=tensorflow_inception_graph.pb \ + --input_layer="input:0" \ + --input_layer_shape="1,224,224,3" \ + --input_layer_type="float" \ + --output_layer="output:0" +``` + +The Inception graph used as an example here may be downloaded from +https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip \ No newline at end of file diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc new file mode 100644 index 00000000000..556f702fed4 --- /dev/null +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -0,0 +1,225 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A C++ binary to benchmark a compute graph and its individual operators, +// both on desktop machines and on Android. +// +// See README.md for usage instructions. + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/core/util/stat_summarizer.h" + +namespace tensorflow { + +// Global variables that holds the Tensorflow classifier. +static std::unique_ptr session; + +static StatSummarizer g_stats; + +struct Flags { + string graph = "/data/local/tmp/tensorflow_inception_graph.pb"; + string input_layer = "input:0"; + string input_layer_shape = "1,224,224,3"; + string input_layer_type = "float"; + string output_layer = "output:0"; + int num_runs = 50; + string run_delay = "-1.0"; + int num_threads = -1; +}; + +static Flags* flags; // Filled in by main() + +static bool InitializeBenchmark() { + g_stats.Reset(); + + LOG(INFO) << "Loading Tensorflow."; + + tensorflow::SessionOptions options; + tensorflow::ConfigProto& config = options.config; + if (flags->num_threads > 0) { + config.set_intra_op_parallelism_threads(flags->num_threads); + } + LOG(INFO) << "Got config, " << config.device_count_size() << " devices"; + + session.reset(tensorflow::NewSession(options)); + tensorflow::GraphDef tensorflow_graph; + Status s = ReadBinaryProto(Env::Default(), flags->graph, &tensorflow_graph); + if (!s.ok()) { + LOG(ERROR) << "Could not create Tensorflow Graph: " << s; + return false; + } + + s = session->Create(tensorflow_graph); + if (!s.ok()) { + LOG(ERROR) << "Could not create Tensorflow Session: " << s; + return false; + } + + // Clear the proto to save memory space. + tensorflow_graph.Clear(); + return true; +} + +static bool RunBenchmark() { + DataType input_data_type; + CHECK(DataTypeFromString(flags->input_layer_type, &input_data_type)) + << flags->input_layer_type << " was an invalid type"; + + std::vector sizes; + CHECK(str_util::SplitAndParseAsInts(flags->input_layer_shape, ',', &sizes)) + << "Incorrect size string specified: " << flags->input_layer_shape; + TensorShape input_shape; + for (int i = 0; i < sizes.size(); ++i) { + input_shape.AddDim(sizes[i]); + } + + Tensor input_tensor(input_data_type, input_shape); + + switch (input_data_type) { + case DT_INT32: { + auto int_tensor = input_tensor.flat(); + int_tensor = int_tensor.constant(0.0); + break; + } + case DT_FLOAT: { + auto float_tensor = input_tensor.flat(); + float_tensor = float_tensor.constant(0.0); + break; + } + case DT_QUINT8: { + auto int_tensor = input_tensor.flat(); + int_tensor = int_tensor.constant(0.0); + break; + } + default: + LOG(FATAL) << "Unsupported input type: " << flags->input_layer_type; + } + + std::vector > input_tensors( + {{flags->input_layer, input_tensor}}); + + std::vector output_tensors; + std::vector output_names({flags->output_layer}); + + tensorflow::Status s; + + RunOptions run_options; + run_options.set_trace_level(RunOptions::FULL_TRACE); + RunMetadata run_metadata; + + s = session->Run(run_options, input_tensors, output_names, {}, + &output_tensors, &run_metadata); + + assert(run_metadata.has_step_stats()); + + const StepStats& stats = run_metadata.step_stats(); + + g_stats.ProcessStepStats(stats); + + if (!s.ok()) { + LOG(ERROR) << "Error during inference: " << s; + return false; + } + return true; +} + +} // namespace tensorflow + +int main(int argc, char** argv) { + tensorflow::flags = new tensorflow::Flags(); + + const bool parse_result = tensorflow::ParseFlags( + &argc, argv, + { + tensorflow::Flag("graph", &tensorflow::flags->graph), + tensorflow::Flag("input_layer", &tensorflow::flags->input_layer), + tensorflow::Flag("input_layer_shape", + &tensorflow::flags->input_layer_shape), + tensorflow::Flag("input_layer_type", + &tensorflow::flags->input_layer_type), + tensorflow::Flag("output_layer", &tensorflow::flags->output_layer), + tensorflow::Flag("num_runs", &tensorflow::flags->num_runs), + tensorflow::Flag("run_delay", &tensorflow::flags->run_delay), + tensorflow::Flag("num_threads", &tensorflow::flags->num_threads), + }); + + if (!parse_result) { + LOG(ERROR) << "Error parsing command-line flags."; + return -1; + } + + ::tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1]; + return -1; + } + + LOG(INFO) << "Graph: [" << tensorflow::flags->graph << "]"; + LOG(INFO) << "Input layer: [" << tensorflow::flags->input_layer << "]"; + LOG(INFO) << "Input shape: [" << tensorflow::flags->input_layer_shape << "]"; + LOG(INFO) << "Input type: [" << tensorflow::flags->input_layer_type << "]"; + LOG(INFO) << "Output layer: [" << tensorflow::flags->output_layer << "]"; + LOG(INFO) << "Num runs: [" << tensorflow::flags->num_runs << "]"; + LOG(INFO) << "Inter-run delay (seconds): [" << tensorflow::flags->run_delay + << "]"; + LOG(INFO) << "Num threads: [" << tensorflow::flags->num_threads << "]"; + + if (!tensorflow::InitializeBenchmark()) { + return -1; + } + + // Convert the run_delay string into a timespec. + const double sleep_seconds = + std::strtod(tensorflow::flags->run_delay.c_str(), nullptr); + timespec req; + req.tv_sec = static_cast(sleep_seconds); + req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; + + LOG(INFO) << "Running benchmark"; + for (int i = 0; i < tensorflow::flags->num_runs; ++i) { + if (!tensorflow::RunBenchmark()) { + LOG(INFO) << "Failed on run " << i; + return -1; + } + + // If requested, sleep between runs for an arbitrary amount of time. + // This can be helpful to determine the effect of mobile processor + // scaling and thermal throttling. + if (sleep_seconds > 0.0) { + nanosleep(&req, nullptr); + } + } + + tensorflow::g_stats.PrintStepStats(); + return 0; +} diff --git a/tensorflow/tools/ci_build/builds/test_installation.sh b/tensorflow/tools/ci_build/builds/test_installation.sh index 0d9a192d479..6629c8fbc5c 100755 --- a/tensorflow/tools/ci_build/builds/test_installation.sh +++ b/tensorflow/tools/ci_build/builds/test_installation.sh @@ -139,6 +139,16 @@ else # Assume: PYTHON_BIN_PATH is exported by the script above fi +# Obtain the path to head/ghead binary (for log file printing) +HEAD_BIN="ghead" +if [[ -z $(which "${HEAD_BIN}") ]]; then + # This is not Mac (which uses coreutils/ghead), use head. + HEAD_BIN="head" + if [[ -z $(which "${HEAD_BIN}") ]]; then + die "Unable to obtain path to head or ghead" + fi +fi + if [[ -z "${PYTHON_BIN_PATH}" ]]; then die "PYTHON_BIN_PATH was not provided. If this is not virtualenv, "\ "did you run configure?" @@ -371,7 +381,7 @@ while true; do echo " Log @: ${TEST_LOGS[K]}" echo "============== BEGINS failure log content ==============" - head --lines=-1 "${TEST_LOGS[K]}" + "${HEAD_BIN}" --lines=-1 "${TEST_LOGS[K]}" echo "============== ENDS failure log content ==============" echo "" fi