Merge pull request #1902 from martinwicke/branch_119712558
Merge internal changes
This commit is contained in:
		
							parent
							
								
									d118d1d31c
								
							
						
					
					
						commit
						44a6b91ce8
					
				| @ -79,9 +79,9 @@ REGISTER_OP("CountExtremelyRandomStats") | |||||||
|      gives the j-th feature of the i-th input. |      gives the j-th feature of the i-th input. | ||||||
|    input_labels: The training batch's labels; `input_labels[i]` is the class |    input_labels: The training batch's labels; `input_labels[i]` is the class | ||||||
|      of the i-th input. |      of the i-th input. | ||||||
|    tree:= A 2-d int32 tensor.  `tree[0][i]` gives the index of the left child |    tree:= A 2-d int32 tensor.  `tree[i][0]` gives the index of the left child | ||||||
|      of the i-th node, `tree[0][i] + 1` gives the index of the right child of |      of the i-th node, `tree[i][0] + 1` gives the index of the right child of | ||||||
|      the i-th node, and `tree[1][i]` gives the index of the feature used to |      the i-th node, and `tree[i][1]` gives the index of the feature used to | ||||||
|      split the i-th node. |      split the i-th node. | ||||||
|    tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th |    tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th | ||||||
|      node. |      node. | ||||||
|  | |||||||
| @ -44,9 +44,9 @@ REGISTER_OP("TreePredictions") | |||||||
| 
 | 
 | ||||||
|   input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` |   input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` | ||||||
|    gives the j-th feature of the i-th input. |    gives the j-th feature of the i-th input. | ||||||
|   tree:= A 2-d int32 tensor.  `tree[0][i]` gives the index of the left child |   tree:= A 2-d int32 tensor.  `tree[i][0]` gives the index of the left child | ||||||
|    of the i-th node, `tree[0][i] + 1` gives the index of the right child of |    of the i-th node, `tree[i][0] + 1` gives the index of the right child of | ||||||
|    the i-th node, and `tree[1][i]` gives the index of the feature used to |    the i-th node, and `tree[i][1]` gives the index of the feature used to | ||||||
|    split the i-th node. |    split the i-th node. | ||||||
|   tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th |   tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th | ||||||
|    node. |    node. | ||||||
|  | |||||||
| @ -25,6 +25,12 @@ import tensorflow as tf | |||||||
| from tensorflow.python.framework import ops | from tensorflow.python.framework import ops | ||||||
| from tensorflow.python.framework import tensor_shape | from tensorflow.python.framework import tensor_shape | ||||||
| 
 | 
 | ||||||
|  | flags = tf.app.flags | ||||||
|  | FLAGS = flags.FLAGS | ||||||
|  | 
 | ||||||
|  | flags.DEFINE_string('inference_library_base_dir', '', | ||||||
|  |                     'Directory to look for inference library file.') | ||||||
|  | 
 | ||||||
| INFERENCE_OPS_FILE = '_inference_ops.so' | INFERENCE_OPS_FILE = '_inference_ops.so' | ||||||
| 
 | 
 | ||||||
| _inference_ops = None | _inference_ops = None | ||||||
| @ -54,7 +60,8 @@ def Load(): | |||||||
|   with _ops_lock: |   with _ops_lock: | ||||||
|     global _inference_ops |     global _inference_ops | ||||||
|     if not _inference_ops: |     if not _inference_ops: | ||||||
|       data_files_path = tf.resource_loader.get_data_files_path() |       data_files_path = os.path.join(FLAGS.inference_library_base_dir, | ||||||
|  |                                      tf.resource_loader.get_data_files_path()) | ||||||
|       tf.logging.info('data path: %s', data_files_path) |       tf.logging.info('data path: %s', data_files_path) | ||||||
|       _inference_ops = tf.load_op_library(os.path.join( |       _inference_ops = tf.load_op_library(os.path.join( | ||||||
|           data_files_path, INFERENCE_OPS_FILE)) |           data_files_path, INFERENCE_OPS_FILE)) | ||||||
|  | |||||||
| @ -25,6 +25,12 @@ import tensorflow as tf | |||||||
| from tensorflow.python.framework import ops | from tensorflow.python.framework import ops | ||||||
| from tensorflow.python.framework import tensor_shape | from tensorflow.python.framework import tensor_shape | ||||||
| 
 | 
 | ||||||
|  | flags = tf.app.flags | ||||||
|  | FLAGS = flags.FLAGS | ||||||
|  | 
 | ||||||
|  | flags.DEFINE_string('training_library_base_dir', '', | ||||||
|  |                     'Directory to look for inference library file.') | ||||||
|  | 
 | ||||||
| TRAINING_OPS_FILE = '_training_ops.so' | TRAINING_OPS_FILE = '_training_ops.so' | ||||||
| 
 | 
 | ||||||
| _training_ops = None | _training_ops = None | ||||||
| @ -101,7 +107,8 @@ def Load(): | |||||||
|   with _ops_lock: |   with _ops_lock: | ||||||
|     global _training_ops |     global _training_ops | ||||||
|     if not _training_ops: |     if not _training_ops: | ||||||
|       data_files_path = tf.resource_loader.get_data_files_path() |       data_files_path = os.path.join(FLAGS.training_library_base_dir, | ||||||
|  |                                      tf.resource_loader.get_data_files_path()) | ||||||
|       tf.logging.info('data path: %s', data_files_path) |       tf.logging.info('data path: %s', data_files_path) | ||||||
|       _training_ops = tf.load_op_library(os.path.join( |       _training_ops = tf.load_op_library(os.path.join( | ||||||
|           data_files_path, TRAINING_OPS_FILE)) |           data_files_path, TRAINING_OPS_FILE)) | ||||||
|  | |||||||
| @ -37,6 +37,13 @@ flags.DEFINE_float( | |||||||
|     'samples_to_decide', 25.0, |     'samples_to_decide', 25.0, | ||||||
|     'Only decide on a split, or only fully use a leaf, after this many ' |     'Only decide on a split, or only fully use a leaf, after this many ' | ||||||
|     'training samples have been seen.') |     'training samples have been seen.') | ||||||
|  | flags.DEFINE_float('bagging_fraction', 1.0, | ||||||
|  |                    'Use this fraction of the input, randomly chosen, to train ' | ||||||
|  |                    'each tree in the forest.') | ||||||
|  | flags.DEFINE_integer( | ||||||
|  |     'num_splits_to_consider', 0, | ||||||
|  |     'If non-zero, consider this many candidates for a splitting ' | ||||||
|  |     'rule at a fertile node.') | ||||||
| 
 | 
 | ||||||
| # If tree[i][0] equals this value, then i is a leaf node. | # If tree[i][0] equals this value, then i is a leaf node. | ||||||
| LEAF_NODE = -1 | LEAF_NODE = -1 | ||||||
| @ -69,6 +76,9 @@ class ForestHParams(object): | |||||||
|     # Fail fast if num_classes isn't set. |     # Fail fast if num_classes isn't set. | ||||||
|     _ = getattr(self, 'num_classes') |     _ = getattr(self, 'num_classes') | ||||||
| 
 | 
 | ||||||
|  |     self.bagging_fraction = getattr(self, 'bagging_fraction', | ||||||
|  |                                     FLAGS.bagging_fraction) | ||||||
|  | 
 | ||||||
|     self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees) |     self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees) | ||||||
|     self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes) |     self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes) | ||||||
| 
 | 
 | ||||||
| @ -79,7 +89,9 @@ class ForestHParams(object): | |||||||
|     # The Random Forest literature recommends sqrt(# features) for |     # The Random Forest literature recommends sqrt(# features) for | ||||||
|     # classification problems, and p/3 for regression problems. |     # classification problems, and p/3 for regression problems. | ||||||
|     # TODO(thomaswc): Consider capping this for large number of features. |     # TODO(thomaswc): Consider capping this for large number of features. | ||||||
|     if not getattr(self, 'num_splits_to_consider', None): |     self.num_splits_to_consider = getattr(self, 'num_splits_to_consider', | ||||||
|  |                                           FLAGS.num_splits_to_consider) | ||||||
|  |     if not self.num_splits_to_consider: | ||||||
|       self.num_splits_to_consider = max(10, int( |       self.num_splits_to_consider = max(10, int( | ||||||
|           math.ceil(math.sqrt(self.num_features)))) |           math.ceil(math.sqrt(self.num_features)))) | ||||||
| 
 | 
 | ||||||
| @ -94,8 +106,8 @@ class ForestHParams(object): | |||||||
|     self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile) |     self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile) | ||||||
|     # But it also never needs to be larger than the number of leaves, |     # But it also never needs to be larger than the number of leaves, | ||||||
|     # which is max_nodes / 2. |     # which is max_nodes / 2. | ||||||
|     self.max_fertile_nodes = min(self.max_nodes, |     self.max_fertile_nodes = min(self.max_fertile_nodes, | ||||||
|                                  int(math.ceil(self.max_fertile_nodes / 2.0))) |                                  int(math.ceil(self.max_nodes / 2.0))) | ||||||
| 
 | 
 | ||||||
|     # split_after_samples and valid_leaf_threshold should be about the same. |     # split_after_samples and valid_leaf_threshold should be about the same. | ||||||
|     # Therefore, if either is set, use it to set the other.  Otherwise, fall |     # Therefore, if either is set, use it to set the other.  Otherwise, fall | ||||||
| @ -184,23 +196,6 @@ class TreeStats(object): | |||||||
|     self.num_leaves = num_leaves |     self.num_leaves = num_leaves | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_tree_stats(variables, unused_params, session): |  | ||||||
|   num_nodes = variables.end_of_tree.eval(session=session) - 1 |  | ||||||
|   num_leaves = tf.where( |  | ||||||
|       tf.equal(tf.squeeze(tf.slice(variables.tree, [0, 0], [-1, 1])), |  | ||||||
|                LEAF_NODE)).eval(session=session).shape[0] |  | ||||||
|   return TreeStats(num_nodes, num_leaves) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def get_forest_stats(variables, params, session): |  | ||||||
| 
 |  | ||||||
|   tree_stats = [] |  | ||||||
|   for i in range(params.num_trees): |  | ||||||
|     tree_stats.append(get_tree_stats(variables[i], params, session)) |  | ||||||
| 
 |  | ||||||
|   return ForestStats(tree_stats, params) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class ForestTrainingVariables(object): | class ForestTrainingVariables(object): | ||||||
|   """A container for a forest's training data, consisting of multiple trees. |   """A container for a forest's training data, consisting of multiple trees. | ||||||
| 
 | 
 | ||||||
| @ -212,9 +207,11 @@ class ForestTrainingVariables(object): | |||||||
|     ... forest_variables.tree ... |     ... forest_variables.tree ... | ||||||
|   """ |   """ | ||||||
| 
 | 
 | ||||||
|   def __init__(self, params): |   def __init__(self, params, device_assigner): | ||||||
|     self.variables = [TreeTrainingVariables(params) |     self.variables = [] | ||||||
|                       for _ in range(params.num_trees)] |     for i in range(params.num_trees): | ||||||
|  |       with tf.device(device_assigner.get_device(i)): | ||||||
|  |         self.variables.append(TreeTrainingVariables(params)) | ||||||
| 
 | 
 | ||||||
|   def __setitem__(self, t, val): |   def __setitem__(self, t, val): | ||||||
|     self.variables[t] = val |     self.variables[t] = val | ||||||
| @ -223,12 +220,35 @@ class ForestTrainingVariables(object): | |||||||
|     return self.variables[t] |     return self.variables[t] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class RandomForestDeviceAssigner(object): | ||||||
|  |   """A device assigner that uses the default device. | ||||||
|  | 
 | ||||||
|  |   Write subclasses that implement get_device for control over how trees | ||||||
|  |   get assigned to devices.  This assumes that whole trees are assigned | ||||||
|  |   to a device. | ||||||
|  |   """ | ||||||
|  | 
 | ||||||
|  |   def __init__(self): | ||||||
|  |     self.cached = None | ||||||
|  | 
 | ||||||
|  |   def get_device(self, unused_tree_num): | ||||||
|  |     if not self.cached: | ||||||
|  |       dummy = tf.constant(0) | ||||||
|  |       self.cached = dummy.device | ||||||
|  | 
 | ||||||
|  |     return self.cached | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class RandomForestGraphs(object): | class RandomForestGraphs(object): | ||||||
|   """Builds TF graphs for random forest training and inference.""" |   """Builds TF graphs for random forest training and inference.""" | ||||||
| 
 | 
 | ||||||
|   def __init__(self, params): |   def __init__(self, params, device_assigner=None, variables=None): | ||||||
|     self.params = params |     self.params = params | ||||||
|     self.variables = ForestTrainingVariables(self.params) |     self.device_assigner = device_assigner or RandomForestDeviceAssigner() | ||||||
|  |     tf.logging.info('Constructing forest with params = ') | ||||||
|  |     tf.logging.info(self.params.__dict__) | ||||||
|  |     self.variables = variables or ForestTrainingVariables( | ||||||
|  |         self.params, device_assigner=self.device_assigner) | ||||||
|     self.trees = [RandomTreeGraphs(self.variables[i], self.params, |     self.trees = [RandomTreeGraphs(self.variables[i], self.params, | ||||||
|                                    training_ops.Load(), inference_ops.Load()) |                                    training_ops.Load(), inference_ops.Load()) | ||||||
|                   for i in range(self.params.num_trees)] |                   for i in range(self.params.num_trees)] | ||||||
| @ -246,12 +266,26 @@ class RandomForestGraphs(object): | |||||||
|     """ |     """ | ||||||
|     tree_graphs = [] |     tree_graphs = [] | ||||||
|     for i in range(self.params.num_trees): |     for i in range(self.params.num_trees): | ||||||
|       tf.logging.info('Constructing tree %d', i) |       with tf.device(self.device_assigner.get_device(i)): | ||||||
|         seed = self.params.base_random_seed |         seed = self.params.base_random_seed | ||||||
|         if seed != 0: |         if seed != 0: | ||||||
|           seed += i |           seed += i | ||||||
|       tree_graphs.append(self.trees[i].training_graph( |         # If using bagging, randomly select some of the input. | ||||||
|           input_data, input_labels, seed)) |         tree_data = input_data | ||||||
|  |         tree_labels = input_labels | ||||||
|  |         if self.params.bagging_fraction < 1.0: | ||||||
|  |           # TODO(thomaswc): This does sampling without replacement.  Consider | ||||||
|  |           # also allowing sampling with replacement as an option. | ||||||
|  |           batch_size = tf.slice(tf.shape(input_data), [0], [1]) | ||||||
|  |           r = tf.random_uniform(batch_size, seed=seed) | ||||||
|  |           mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction) | ||||||
|  |           gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1]) | ||||||
|  |           # TODO(thomaswc): Calculate out-of-bag data and labels, and store | ||||||
|  |           # them for use in calculating statistics later. | ||||||
|  |           tree_data = tf.gather(input_data, gather_indices) | ||||||
|  |           tree_labels = tf.gather(input_labels, gather_indices) | ||||||
|  |         tree_graphs.append( | ||||||
|  |             self.trees[i].training_graph(tree_data, tree_labels, seed)) | ||||||
|     return tf.group(*tree_graphs) |     return tf.group(*tree_graphs) | ||||||
| 
 | 
 | ||||||
|   def inference_graph(self, input_data): |   def inference_graph(self, input_data): | ||||||
| @ -265,10 +299,24 @@ class RandomForestGraphs(object): | |||||||
|     """ |     """ | ||||||
|     probabilities = [] |     probabilities = [] | ||||||
|     for i in range(self.params.num_trees): |     for i in range(self.params.num_trees): | ||||||
|  |       with tf.device(self.device_assigner.get_device(i)): | ||||||
|         probabilities.append(self.trees[i].inference_graph(input_data)) |         probabilities.append(self.trees[i].inference_graph(input_data)) | ||||||
|  |     with tf.device(self.device_assigner.get_device(0)): | ||||||
|       all_predict = tf.pack(probabilities) |       all_predict = tf.pack(probabilities) | ||||||
|       return tf.reduce_sum(all_predict, 0) / self.params.num_trees |       return tf.reduce_sum(all_predict, 0) / self.params.num_trees | ||||||
| 
 | 
 | ||||||
|  |   def average_size(self): | ||||||
|  |     """Constructs a TF graph for evaluating the average size of a forest. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |       The average number of nodes over the trees. | ||||||
|  |     """ | ||||||
|  |     sizes = [] | ||||||
|  |     for i in range(self.params.num_trees): | ||||||
|  |       with tf.device(self.device_assigner.get_device(i)): | ||||||
|  |         sizes.append(self.trees[i].size()) | ||||||
|  |     return tf.reduce_mean(tf.pack(sizes)) | ||||||
|  | 
 | ||||||
|   def average_impurity(self): |   def average_impurity(self): | ||||||
|     """Constructs a TF graph for evaluating the leaf impurity of a forest. |     """Constructs a TF graph for evaluating the leaf impurity of a forest. | ||||||
| 
 | 
 | ||||||
| @ -277,9 +325,17 @@ class RandomForestGraphs(object): | |||||||
|     """ |     """ | ||||||
|     impurities = [] |     impurities = [] | ||||||
|     for i in range(self.params.num_trees): |     for i in range(self.params.num_trees): | ||||||
|       impurities.append(self.trees[i].average_impurity(self.variables[i])) |       with tf.device(self.device_assigner.get_device(i)): | ||||||
|  |         impurities.append(self.trees[i].average_impurity()) | ||||||
|     return tf.reduce_mean(tf.pack(impurities)) |     return tf.reduce_mean(tf.pack(impurities)) | ||||||
| 
 | 
 | ||||||
|  |   def get_stats(self, session): | ||||||
|  |     tree_stats = [] | ||||||
|  |     for i in range(self.params.num_trees): | ||||||
|  |       with tf.device(self.device_assigner.get_device(i)): | ||||||
|  |         tree_stats.append(self.trees[i].get_stats(session)) | ||||||
|  |     return ForestStats(tree_stats, self.params) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class RandomTreeGraphs(object): | class RandomTreeGraphs(object): | ||||||
|   """Builds TF graphs for random tree training and inference.""" |   """Builds TF graphs for random tree training and inference.""" | ||||||
| @ -394,6 +450,7 @@ class RandomTreeGraphs(object): | |||||||
|     with tf.control_dependencies([node_update_op]): |     with tf.control_dependencies([node_update_op]): | ||||||
|       def f1(): |       def f1(): | ||||||
|         return self.variables.non_fertile_leaf_scores |         return self.variables.non_fertile_leaf_scores | ||||||
|  | 
 | ||||||
|       def f2(): |       def f2(): | ||||||
|         counts = tf.gather(self.variables.node_per_class_weights, |         counts = tf.gather(self.variables.node_per_class_weights, | ||||||
|                            self.variables.non_fertile_leaves) |                            self.variables.non_fertile_leaves) | ||||||
| @ -535,3 +592,18 @@ class RandomTreeGraphs(object): | |||||||
|     counts = tf.gather(self.variables.node_per_class_weights, leaves) |     counts = tf.gather(self.variables.node_per_class_weights, leaves) | ||||||
|     impurity = self._weighted_gini(counts) |     impurity = self._weighted_gini(counts) | ||||||
|     return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0) |     return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0) | ||||||
|  | 
 | ||||||
|  |   def size(self): | ||||||
|  |     """Constructs a TF graph for evaluating the current number of nodes. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |       The current number of nodes in the tree. | ||||||
|  |     """ | ||||||
|  |     return self.variables.end_of_tree - 1 | ||||||
|  | 
 | ||||||
|  |   def get_stats(self, session): | ||||||
|  |     num_nodes = self.variables.end_of_tree.eval(session=session) - 1 | ||||||
|  |     num_leaves = tf.where( | ||||||
|  |         tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])), | ||||||
|  |                  LEAF_NODE)).eval(session=session).shape[0] | ||||||
|  |     return TreeStats(num_nodes, num_leaves) | ||||||
|  | |||||||
| @ -27,6 +27,37 @@ from tensorflow.python.platform import googletest | |||||||
| 
 | 
 | ||||||
| class TensorForestTest(test_util.TensorFlowTestCase): | class TensorForestTest(test_util.TensorFlowTestCase): | ||||||
| 
 | 
 | ||||||
|  |   def testForestHParams(self): | ||||||
|  |     hparams = tensor_forest.ForestHParams( | ||||||
|  |         num_classes=2, num_trees=100, max_nodes=1000, | ||||||
|  |         num_features=60).fill() | ||||||
|  |     self.assertEquals(2, hparams.num_classes) | ||||||
|  |     # 2 * ceil(log_2(1000)) = 20 | ||||||
|  |     self.assertEquals(20, hparams.max_depth) | ||||||
|  |     # sqrt(num_features) < 10, so num_splits_to_consider should be 10. | ||||||
|  |     self.assertEquals(10, hparams.num_splits_to_consider) | ||||||
|  |     # Don't have more fertile nodes than max # leaves, which is 500. | ||||||
|  |     self.assertEquals(500, hparams.max_fertile_nodes) | ||||||
|  |     # We didn't set either of these, so they should be equal | ||||||
|  |     self.assertEquals(hparams.split_after_samples, | ||||||
|  |                       hparams.valid_leaf_threshold) | ||||||
|  |     # split_after_samples is larger than 10 | ||||||
|  |     self.assertEquals(1, hparams.split_initializations_per_input) | ||||||
|  |     self.assertEquals(0, hparams.base_random_seed) | ||||||
|  | 
 | ||||||
|  |   def testForestHParamsBigTree(self): | ||||||
|  |     hparams = tensor_forest.ForestHParams( | ||||||
|  |         num_classes=2, num_trees=100, max_nodes=1000000, | ||||||
|  |         split_after_samples=25, | ||||||
|  |         num_features=1000).fill() | ||||||
|  |     self.assertEquals(40, hparams.max_depth) | ||||||
|  |     # sqrt(1000) = 31.63... | ||||||
|  |     self.assertEquals(32, hparams.num_splits_to_consider) | ||||||
|  |     # 1000000 / 32 = 31250 | ||||||
|  |     self.assertEquals(31250, hparams.max_fertile_nodes) | ||||||
|  |     # floor(31.63 / 25) = 1 | ||||||
|  |     self.assertEquals(1, hparams.split_initializations_per_input) | ||||||
|  | 
 | ||||||
|   def testTrainingConstruction(self): |   def testTrainingConstruction(self): | ||||||
|     input_data = [[-1., 0.], [-1., 2.],  # node 1 |     input_data = [[-1., 0.], [-1., 2.],  # node 1 | ||||||
|                   [1., 0.], [1., -2.]]  # node 2 |                   [1., 0.], [1., -2.]]  # node 2 | ||||||
| @ -50,6 +81,14 @@ class TensorForestTest(test_util.TensorFlowTestCase): | |||||||
|     graph = graph_builder.inference_graph(input_data) |     graph = graph_builder.inference_graph(input_data) | ||||||
|     self.assertTrue(isinstance(graph, tf.Tensor)) |     self.assertTrue(isinstance(graph, tf.Tensor)) | ||||||
| 
 | 
 | ||||||
|  |   def testImpurityConstruction(self): | ||||||
|  |     params = tensor_forest.ForestHParams( | ||||||
|  |         num_classes=4, num_features=2, num_trees=10, max_nodes=1000).fill() | ||||||
|  | 
 | ||||||
|  |     graph_builder = tensor_forest.RandomForestGraphs(params) | ||||||
|  |     graph = graph_builder.average_impurity() | ||||||
|  |     self.assertTrue(isinstance(graph, tf.Tensor)) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   googletest.main() |   googletest.main() | ||||||
|  | |||||||
| @ -143,7 +143,6 @@ cc_library( | |||||||
|         "lib/core/bits.h", |         "lib/core/bits.h", | ||||||
|         "lib/core/casts.h", |         "lib/core/casts.h", | ||||||
|         "lib/core/coding.h", |         "lib/core/coding.h", | ||||||
|         "lib/core/command_line_flags.h",  # TODO(vrv): Delete. |  | ||||||
|         "lib/core/errors.h", |         "lib/core/errors.h", | ||||||
|         "lib/core/notification.h", |         "lib/core/notification.h", | ||||||
|         "lib/core/status.h", |         "lib/core/status.h", | ||||||
|  | |||||||
| @ -35,6 +35,7 @@ limitations under the License. | |||||||
| #include "tensorflow/core/framework/types.h" | #include "tensorflow/core/framework/types.h" | ||||||
| #include "tensorflow/core/lib/core/coding.h" | #include "tensorflow/core/lib/core/coding.h" | ||||||
| #include "tensorflow/core/lib/core/errors.h" | #include "tensorflow/core/lib/core/errors.h" | ||||||
|  | #include "tensorflow/core/lib/gtl/inlined_vector.h" | ||||||
| #include "tensorflow/core/lib/gtl/stl_util.h" | #include "tensorflow/core/lib/gtl/stl_util.h" | ||||||
| #include "tensorflow/core/lib/strings/str_util.h" | #include "tensorflow/core/lib/strings/str_util.h" | ||||||
| #include "tensorflow/core/lib/strings/strcat.h" | #include "tensorflow/core/lib/strings/strcat.h" | ||||||
| @ -713,4 +714,36 @@ void Tensor::FillDescription(TensorDescription* description) const { | |||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | gtl::InlinedVector<int64, 5> Tensor::ComputeFlatInnerDims( | ||||||
|  |     int64 num_out_dims) const { | ||||||
|  |   gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0); | ||||||
|  |   const int64 num_elements = NumElements(); | ||||||
|  |   if (num_elements != 0) { | ||||||
|  |     int64 prod_out_dims = 1; | ||||||
|  |     for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) { | ||||||
|  |       const int64 in_dim = out_dim + (dims() - num_out_dims); | ||||||
|  |       out_dims[out_dim] = | ||||||
|  |           (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim); | ||||||
|  |       prod_out_dims *= out_dims[out_dim]; | ||||||
|  |     } | ||||||
|  |     out_dims[0] = num_elements / prod_out_dims; | ||||||
|  |   } | ||||||
|  |   return out_dims; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | gtl::InlinedVector<int64, 5> Tensor::ComputeFlatOuterDims( | ||||||
|  |     int64 num_out_dims) const { | ||||||
|  |   gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0); | ||||||
|  |   const int64 num_elements = NumElements(); | ||||||
|  |   if (num_elements != 0) { | ||||||
|  |     int64 prod_out_dims = 1; | ||||||
|  |     for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) { | ||||||
|  |       out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim); | ||||||
|  |       prod_out_dims *= out_dims[out_dim]; | ||||||
|  |     } | ||||||
|  |     out_dims[num_out_dims - 1] = num_elements / prod_out_dims; | ||||||
|  |   } | ||||||
|  |   return out_dims; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace tensorflow
 | }  // namespace tensorflow
 | ||||||
|  | |||||||
| @ -28,6 +28,7 @@ limitations under the License. | |||||||
| #include "tensorflow/core/lib/core/refcount.h" | #include "tensorflow/core/lib/core/refcount.h" | ||||||
| #include "tensorflow/core/lib/core/status.h" | #include "tensorflow/core/lib/core/status.h" | ||||||
| #include "tensorflow/core/lib/core/stringpiece.h" | #include "tensorflow/core/lib/core/stringpiece.h" | ||||||
|  | #include "tensorflow/core/lib/gtl/inlined_vector.h" | ||||||
| #include "tensorflow/core/platform/logging.h" | #include "tensorflow/core/platform/logging.h" | ||||||
| #include "tensorflow/core/platform/macros.h" | #include "tensorflow/core/platform/macros.h" | ||||||
| #include "tensorflow/core/platform/types.h" | #include "tensorflow/core/platform/types.h" | ||||||
| @ -243,40 +244,28 @@ class Tensor { | |||||||
|   ///
 |   ///
 | ||||||
|   /// ```
 |   /// ```
 | ||||||
|   template <typename T> |   template <typename T> | ||||||
|   typename TTypes<T>::Flat flat(); |   typename TTypes<T>::Flat flat() { | ||||||
|  |     return shaped<T, 1>({NumElements()}); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   template <typename T> |   template <typename T> | ||||||
|   typename TTypes<T>::UnalignedFlat unaligned_flat() { |   typename TTypes<T>::UnalignedFlat unaligned_flat() { | ||||||
|     return unaligned_shaped<T, 1>({NumElements()}); |     return unaligned_shaped<T, 1>({NumElements()}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
 |   /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
 | ||||||
|   /// Tensor dimensions but the last one into the first dimension of the result.
 |   /// Tensor dimensions but the last NDIMS-1 into the first dimension of the
 | ||||||
|   template <typename T> |   /// result. If NDIMS > dims() then leading dimensions of size 1 will be
 | ||||||
|   typename TTypes<T>::Matrix flat_inner_dims() { |   /// added to make the output rank NDIMS.
 | ||||||
|     int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1; |   template <typename T, size_t NDIMS = 2> | ||||||
|     if (last_size == 0) { |   typename TTypes<T, NDIMS>::Tensor flat_inner_dims(); | ||||||
|       DCHECK_EQ(NumElements(), 0); |  | ||||||
|       // Return something empty, avoiding divide by 0
 |  | ||||||
|       return shaped<T, 2>({0, 0}); |  | ||||||
|     } else { |  | ||||||
|       return shaped<T, 2>({NumElements() / last_size, last_size}); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   /// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
 |   /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
 | ||||||
|   /// Tensor dimensions but the first one into the last dimension of the result.
 |   /// Tensor dimensions but the first NDIMS-1 into the last dimension of the
 | ||||||
|   template <typename T> |   /// result. If NDIMS > dims() then trailing dimensions of size 1 will be
 | ||||||
|   typename TTypes<T>::Matrix flat_outer_dims() { |   /// added to make the output rank NDIMS.
 | ||||||
|     int64 first_size = dims() > 0 ? dim_size(0) : 1; |   template <typename T, size_t NDIMS = 2> | ||||||
|     if (first_size == 0) { |   typename TTypes<T, NDIMS>::Tensor flat_outer_dims(); | ||||||
|       DCHECK_EQ(NumElements(), 0); |  | ||||||
|       // Return something empty, avoiding divide by 0
 |  | ||||||
|       return shaped<T, 2>({0, 0}); |  | ||||||
|     } else { |  | ||||||
|       return shaped<T, 2>({first_size, NumElements() / first_size}); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   template <typename T, size_t NDIMS> |   template <typename T, size_t NDIMS> | ||||||
|   typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes); |   typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes); | ||||||
| @ -308,31 +297,19 @@ class Tensor { | |||||||
|   typename TTypes<T, NDIMS>::ConstTensor tensor() const; |   typename TTypes<T, NDIMS>::ConstTensor tensor() const; | ||||||
| 
 | 
 | ||||||
|   template <typename T> |   template <typename T> | ||||||
|   typename TTypes<T>::ConstFlat flat() const; |   typename TTypes<T>::ConstFlat flat() const { | ||||||
|  |     return shaped<T, 1>({NumElements()}); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   template <typename T> |   template <typename T> | ||||||
|   typename TTypes<T>::UnalignedConstFlat unaligned_flat() const { |   typename TTypes<T>::UnalignedConstFlat unaligned_flat() const { | ||||||
|     return unaligned_shaped<T, 1>({NumElements()}); |     return unaligned_shaped<T, 1>({NumElements()}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   template <typename T> |  | ||||||
|   typename TTypes<T>::ConstMatrix flat_inner_dims() const { |  | ||||||
|     int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1; |  | ||||||
|     if (last_size == 0) { |  | ||||||
|       DCHECK_EQ(NumElements(), 0); |  | ||||||
|       // Return something empty, avoiding divide by 0
 |  | ||||||
|       return shaped<T, 2>({0, 0}); |  | ||||||
|     } else { |  | ||||||
|       return shaped<T, 2>({NumElements() / last_size, last_size}); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   template <typename T> |  | ||||||
|   typename TTypes<T>::ConstMatrix flat_outer_dims() const; |  | ||||||
| 
 |  | ||||||
|   template <typename T, size_t NDIMS> |   template <typename T, size_t NDIMS> | ||||||
|   typename TTypes<T, NDIMS>::ConstTensor shaped( |   typename TTypes<T, NDIMS>::ConstTensor shaped( | ||||||
|       gtl::ArraySlice<int64> new_sizes) const; |       gtl::ArraySlice<int64> new_sizes) const; | ||||||
|  | 
 | ||||||
|   template <typename T, size_t NDIMS> |   template <typename T, size_t NDIMS> | ||||||
|   typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped( |   typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped( | ||||||
|       gtl::ArraySlice<int64> new_sizes) const; |       gtl::ArraySlice<int64> new_sizes) const; | ||||||
| @ -340,6 +317,12 @@ class Tensor { | |||||||
|   template <typename T> |   template <typename T> | ||||||
|   typename TTypes<T>::ConstScalar scalar() const; |   typename TTypes<T>::ConstScalar scalar() const; | ||||||
| 
 | 
 | ||||||
|  |   template <typename T, size_t NDIMS = 2> | ||||||
|  |   typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const; | ||||||
|  | 
 | ||||||
|  |   template <typename T, size_t NDIMS = 2> | ||||||
|  |   typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const; | ||||||
|  | 
 | ||||||
|   /// Render the first `max_entries` values in `*this` into a string.
 |   /// Render the first `max_entries` values in `*this` into a string.
 | ||||||
|   string SummarizeValue(int64 max_entries) const; |   string SummarizeValue(int64 max_entries) const; | ||||||
| 
 | 
 | ||||||
| @ -378,6 +361,8 @@ class Tensor { | |||||||
|   void FillDimsAndValidateCompatibleShape( |   void FillDimsAndValidateCompatibleShape( | ||||||
|       gtl::ArraySlice<int64> new_sizes, |       gtl::ArraySlice<int64> new_sizes, | ||||||
|       Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const; |       Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const; | ||||||
|  |   gtl::InlinedVector<int64, 5> ComputeFlatInnerDims(int64 num_out_dims) const; | ||||||
|  |   gtl::InlinedVector<int64, 5> ComputeFlatOuterDims(int64 num_out_dims) const; | ||||||
| 
 | 
 | ||||||
|   TensorShape shape_; |   TensorShape shape_; | ||||||
|   TensorBuffer* buf_; |   TensorBuffer* buf_; | ||||||
| @ -534,26 +519,24 @@ typename TTypes<T>::ConstScalar Tensor::scalar() const { | |||||||
|   return typename TTypes<T>::ConstScalar(base<T>()); |   return typename TTypes<T>::ConstScalar(base<T>()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <typename T> | template <typename T, size_t NDIMS> | ||||||
| typename TTypes<T>::Flat Tensor::flat() { | typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() { | ||||||
|   return shaped<T, 1>({NumElements()}); |   return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <typename T> | template <typename T, size_t NDIMS> | ||||||
| typename TTypes<T>::ConstFlat Tensor::flat() const { | typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() { | ||||||
|   return shaped<T, 1>({NumElements()}); |   return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <typename T> | template <typename T, size_t NDIMS> | ||||||
| typename TTypes<T>::ConstMatrix Tensor::flat_outer_dims() const { | typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const { | ||||||
|   int64 first_size = dims() > 0 ? dim_size(0) : 1; |   return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS)); | ||||||
|   if (first_size == 0) { | } | ||||||
|     DCHECK_EQ(NumElements(), 0); | 
 | ||||||
|     // Return something empty, avoiding divide by 0
 | template <typename T, size_t NDIMS> | ||||||
|     return shaped<T, 2>({0, 0}); | typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const { | ||||||
|   } else { |   return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS)); | ||||||
|     return shaped<T, 2>({first_size, NumElements() / first_size}); |  | ||||||
|   } |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace tensorflow
 | }  // namespace tensorflow
 | ||||||
|  | |||||||
| @ -224,6 +224,49 @@ TEST(Tensor_Float, Reshape) { | |||||||
|     EXPECT_EQ(flat_inner_dims(0, 0), 0.01f); |     EXPECT_EQ(flat_inner_dims(0, 0), 0.01f); | ||||||
|     EXPECT_EQ(flat_inner_dims(23, 4), 0.02f); |     EXPECT_EQ(flat_inner_dims(23, 4), 0.02f); | ||||||
|   } |   } | ||||||
|  |   { | ||||||
|  |     auto flat_outer_dims = t.flat_outer_dims<float>(); | ||||||
|  |     EXPECT_EQ(2, flat_outer_dims.dimension(0)); | ||||||
|  |     EXPECT_EQ(60, flat_outer_dims.dimension(1)); | ||||||
|  |     EXPECT_EQ(flat_outer_dims(0, 0), 0.01f); | ||||||
|  |     EXPECT_EQ(flat_outer_dims(1, 59), 0.02f); | ||||||
|  |   } | ||||||
|  |   { | ||||||
|  |     auto flat_inner_dims = t.flat_inner_dims<float, 3>(); | ||||||
|  |     EXPECT_EQ(6, flat_inner_dims.dimension(0)); | ||||||
|  |     EXPECT_EQ(4, flat_inner_dims.dimension(1)); | ||||||
|  |     EXPECT_EQ(5, flat_inner_dims.dimension(2)); | ||||||
|  |     EXPECT_EQ(flat_inner_dims(0, 0, 0), 0.01f); | ||||||
|  |     EXPECT_EQ(flat_inner_dims(5, 3, 4), 0.02f); | ||||||
|  |   } | ||||||
|  |   { | ||||||
|  |     auto flat_outer_dims = t.flat_outer_dims<float, 3>(); | ||||||
|  |     EXPECT_EQ(2, flat_outer_dims.dimension(0)); | ||||||
|  |     EXPECT_EQ(3, flat_outer_dims.dimension(1)); | ||||||
|  |     EXPECT_EQ(20, flat_outer_dims.dimension(2)); | ||||||
|  |     EXPECT_EQ(flat_outer_dims(0, 0, 0), 0.01f); | ||||||
|  |     EXPECT_EQ(flat_outer_dims(1, 2, 19), 0.02f); | ||||||
|  |   } | ||||||
|  |   { | ||||||
|  |     auto flat_inner_dims = t.flat_inner_dims<float, 5>(); | ||||||
|  |     EXPECT_EQ(1, flat_inner_dims.dimension(0)); | ||||||
|  |     EXPECT_EQ(2, flat_inner_dims.dimension(1)); | ||||||
|  |     EXPECT_EQ(3, flat_inner_dims.dimension(2)); | ||||||
|  |     EXPECT_EQ(4, flat_inner_dims.dimension(3)); | ||||||
|  |     EXPECT_EQ(5, flat_inner_dims.dimension(4)); | ||||||
|  |     EXPECT_EQ(flat_inner_dims(0, 0, 0, 0, 0), 0.01f); | ||||||
|  |     EXPECT_EQ(flat_inner_dims(0, 1, 2, 3, 4), 0.02f); | ||||||
|  |   } | ||||||
|  |   { | ||||||
|  |     auto flat_outer_dims = t.flat_outer_dims<float, 5>(); | ||||||
|  |     EXPECT_EQ(2, flat_outer_dims.dimension(0)); | ||||||
|  |     EXPECT_EQ(3, flat_outer_dims.dimension(1)); | ||||||
|  |     EXPECT_EQ(4, flat_outer_dims.dimension(2)); | ||||||
|  |     EXPECT_EQ(5, flat_outer_dims.dimension(3)); | ||||||
|  |     EXPECT_EQ(1, flat_outer_dims.dimension(4)); | ||||||
|  |     EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f); | ||||||
|  |     EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f); | ||||||
|  |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST(Tensor_Scalar, Basics) { | TEST(Tensor_Scalar, Basics) { | ||||||
|  | |||||||
| @ -305,6 +305,23 @@ tf_kernel_libraries( | |||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | tf_cc_test( | ||||||
|  |     name = "batch_norm_op_test", | ||||||
|  |     size = "small", | ||||||
|  |     deps = [ | ||||||
|  |         ":batch_norm_op", | ||||||
|  |         ":ops_testutil", | ||||||
|  |         ":ops_util", | ||||||
|  |         "//tensorflow/core:core_cpu", | ||||||
|  |         "//tensorflow/core:framework", | ||||||
|  |         "//tensorflow/core:lib", | ||||||
|  |         "//tensorflow/core:protos_all_cc", | ||||||
|  |         "//tensorflow/core:test", | ||||||
|  |         "//tensorflow/core:test_main", | ||||||
|  |         "//tensorflow/core:testlib", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| tf_cc_test( | tf_cc_test( | ||||||
|     name = "concat_op_test", |     name = "concat_op_test", | ||||||
|     size = "small", |     size = "small", | ||||||
|  | |||||||
							
								
								
									
										62
									
								
								tensorflow/core/kernels/batch_norm_op_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								tensorflow/core/kernels/batch_norm_op_test.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,62 @@ | |||||||
|  | /* Copyright 2015 Google Inc. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include <vector> | ||||||
|  | #include "tensorflow/core/framework/allocator.h" | ||||||
|  | #include "tensorflow/core/framework/fake_input.h" | ||||||
|  | #include "tensorflow/core/framework/graph.pb.h" | ||||||
|  | #include "tensorflow/core/framework/node_def_builder.h" | ||||||
|  | #include "tensorflow/core/framework/op_kernel.h" | ||||||
|  | #include "tensorflow/core/framework/tensor.h" | ||||||
|  | #include "tensorflow/core/framework/tensor_testutil.h" | ||||||
|  | #include "tensorflow/core/framework/types.h" | ||||||
|  | #include "tensorflow/core/framework/types.pb.h" | ||||||
|  | #include "tensorflow/core/kernels/ops_testutil.h" | ||||||
|  | #include "tensorflow/core/kernels/ops_util.h" | ||||||
|  | #include "tensorflow/core/lib/core/status_test_util.h" | ||||||
|  | #include "tensorflow/core/platform/test.h" | ||||||
|  | 
 | ||||||
|  | namespace tensorflow { | ||||||
|  | 
 | ||||||
|  | class BatchNormOpTest : public OpsTestBase {}; | ||||||
|  | 
 | ||||||
|  | TEST_F(BatchNormOpTest, Simple) { | ||||||
|  |   TF_EXPECT_OK( | ||||||
|  |       NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") | ||||||
|  |           .Input(FakeInput(DT_FLOAT)) | ||||||
|  |           .Input(FakeInput(DT_FLOAT)) | ||||||
|  |           .Input(FakeInput(DT_FLOAT)) | ||||||
|  |           .Input(FakeInput(DT_FLOAT)) | ||||||
|  |           .Input(FakeInput(DT_FLOAT)) | ||||||
|  |           .Attr("scale_after_normalization", false) | ||||||
|  |           .Attr("variance_epsilon", 0.001) | ||||||
|  |           .Finalize(node_def())); | ||||||
|  |   TF_EXPECT_OK(InitOpWithGraphVersion(8)); | ||||||
|  |   AddInputFromArray<float>(TensorShape({1, 1, 6, 2}), | ||||||
|  |                            {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); | ||||||
|  |   AddInputFromArray<float>(TensorShape({2}), {10, 20}); | ||||||
|  |   AddInputFromArray<float>(TensorShape({2}), {0.25, 0.5}); | ||||||
|  |   AddInputFromArray<float>(TensorShape({2}), {0.1, 0.6}); | ||||||
|  |   AddInputFromArray<float>(TensorShape({2}), {0.0, 0.0}); | ||||||
|  |   TF_ASSERT_OK(RunOpKernel()); | ||||||
|  | 
 | ||||||
|  |   Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2})); | ||||||
|  |   test::FillValues<float>( | ||||||
|  |       &expected, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86, | ||||||
|  |                   -33.31, -23.85, -34.72, -25.85, -36.13}); | ||||||
|  |   test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace tensorflow
 | ||||||
| @ -94,10 +94,13 @@ class OpsTestBase : public ::testing::Test { | |||||||
|   // and output types as output.
 |   // and output types as output.
 | ||||||
|   //
 |   //
 | ||||||
|   // Returns the status of initialization.
 |   // Returns the status of initialization.
 | ||||||
|   Status InitOp() { |   Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); } | ||||||
|  | 
 | ||||||
|  |   // Only use this directly if you have a deprecated op that you need to test.
 | ||||||
|  |   Status InitOpWithGraphVersion(int graph_def_version) { | ||||||
|     Status status; |     Status status; | ||||||
|     kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(), |     kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(), | ||||||
|                              node_def_, TF_GRAPH_DEF_VERSION, &status); |                              node_def_, graph_def_version, &status); | ||||||
|     if (kernel_ != nullptr) input_types_ = kernel_->input_types(); |     if (kernel_ != nullptr) input_types_ = kernel_->input_types(); | ||||||
|     return status; |     return status; | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -66,7 +66,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) { | |||||||
| #define MAYBE_CONJ(T)                                         \ | #define MAYBE_CONJ(T)                                         \ | ||||||
|   template <>                                                 \ |   template <>                                                 \ | ||||||
|   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj<T>(T v) { \ |   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj<T>(T v) { \ | ||||||
|     return std::conj(v);                                      \ |     return Eigen::numext::conj(v);                            \ | ||||||
|   } |   } | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1,121 +0,0 @@ | |||||||
| /* Copyright 2015 Google Inc. All Rights Reserved.
 |  | ||||||
| 
 |  | ||||||
| Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| you may not use this file except in compliance with the License. |  | ||||||
| You may obtain a copy of the License at |  | ||||||
| 
 |  | ||||||
|     http://www.apache.org/licenses/LICENSE-2.0
 |  | ||||||
| 
 |  | ||||||
| Unless required by applicable law or agreed to in writing, software |  | ||||||
| distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| See the License for the specific language governing permissions and |  | ||||||
| limitations under the License. |  | ||||||
| ==============================================================================*/ |  | ||||||
| 
 |  | ||||||
| #include "tensorflow/core/lib/core/command_line_flags.h" |  | ||||||
| 
 |  | ||||||
| #include "tensorflow/core/lib/strings/str_util.h" |  | ||||||
| #include "tensorflow/core/lib/strings/strcat.h" |  | ||||||
| #include "tensorflow/core/lib/strings/stringprintf.h" |  | ||||||
| 
 |  | ||||||
| namespace tensorflow { |  | ||||||
| namespace { |  | ||||||
| 
 |  | ||||||
| // Templated function to convert a string to target values.
 |  | ||||||
| // Return true if the conversion is successful. Otherwise, return false.
 |  | ||||||
| template <typename T> |  | ||||||
| bool StringToValue(const string& content, T* value); |  | ||||||
| 
 |  | ||||||
| template <> |  | ||||||
| bool StringToValue<int32>(const string& content, int32* value) { |  | ||||||
|   return strings::safe_strto32(content, value); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| template <> |  | ||||||
| bool StringToValue<string>(const string& content, string* value) { |  | ||||||
|   *value = content; |  | ||||||
|   return true; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Parse a single argument by linearly searching through the command table.
 |  | ||||||
| // The input format is: --argument=value.
 |  | ||||||
| // Return OK if the argument is used. It stores the extracted value into the
 |  | ||||||
| // matching flag.
 |  | ||||||
| // Return NOT_FOUND if the argument is not recognized.
 |  | ||||||
| // Return INVALID_ARGUMENT if the command is recognized, but fails to extract
 |  | ||||||
| // its value.
 |  | ||||||
| template <typename T> |  | ||||||
| Status ParseArgument(const string& argument) { |  | ||||||
|   for (auto& command : |  | ||||||
|        internal::CommandLineFlagRegistry<T>::Instance()->commands) { |  | ||||||
|     string prefix = strings::StrCat("--", command.name, "="); |  | ||||||
|     if (tensorflow::StringPiece(argument).starts_with(prefix)) { |  | ||||||
|       string content = argument.substr(prefix.length()); |  | ||||||
|       if (StringToValue<T>(content, command.value)) { |  | ||||||
|         return Status::OK(); |  | ||||||
|       } |  | ||||||
|       return Status(error::INVALID_ARGUMENT, |  | ||||||
|                     strings::StrCat("Cannot parse integer in: ", argument)); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|   return Status(error::NOT_FOUND, |  | ||||||
|                 strings::StrCat("Unknown command: ", argument)); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // A specialization for booleans. The input format is:
 |  | ||||||
| //   "--argument" or "--noargument".
 |  | ||||||
| // Parse a single argument by linearly searching through the command table.
 |  | ||||||
| // Return OK if the argument is used. The value is stored in the matching flag.
 |  | ||||||
| // Return NOT_FOUND if the argument is not recognized.
 |  | ||||||
| template <> |  | ||||||
| Status ParseArgument<bool>(const string& argument) { |  | ||||||
|   for (auto& command : |  | ||||||
|        internal::CommandLineFlagRegistry<bool>::Instance()->commands) { |  | ||||||
|     if (argument == strings::StrCat("--", command.name)) { |  | ||||||
|       *command.value = true; |  | ||||||
|       return Status::OK(); |  | ||||||
|     } else if (argument == strings::StrCat("--no", command.name)) { |  | ||||||
|       *command.value = false; |  | ||||||
|       return Status::OK(); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|   return Status(error::NOT_FOUND, |  | ||||||
|                 strings::StrCat("Unknown command: ", argument)); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| }  // namespace
 |  | ||||||
| 
 |  | ||||||
| Status ParseCommandLineFlags(int* argc, char* argv[]) { |  | ||||||
|   int unused_argc = 1; |  | ||||||
|   for (int index = 1; index < *argc; ++index) { |  | ||||||
|     Status s; |  | ||||||
|     // Search bool commands.
 |  | ||||||
|     s = ParseArgument<bool>(argv[index]); |  | ||||||
|     if (s.ok()) { |  | ||||||
|       continue; |  | ||||||
|     } |  | ||||||
|     if (s.code() != error::NOT_FOUND) { |  | ||||||
|       return s; |  | ||||||
|     } |  | ||||||
|     // Search int32 commands.
 |  | ||||||
|     s = ParseArgument<int32>(argv[index]); |  | ||||||
|     if (s.ok()) { |  | ||||||
|       continue; |  | ||||||
|     } |  | ||||||
|     // Search string commands.
 |  | ||||||
|     s = ParseArgument<string>(argv[index]); |  | ||||||
|     if (s.ok()) { |  | ||||||
|       continue; |  | ||||||
|     } |  | ||||||
|     if (s.code() != error::NOT_FOUND) { |  | ||||||
|       return s; |  | ||||||
|     } |  | ||||||
|     // Pointer swap the unused argument to the front.
 |  | ||||||
|     std::swap(argv[unused_argc++], argv[index]); |  | ||||||
|   } |  | ||||||
|   *argc = unused_argc; |  | ||||||
|   return Status::OK(); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| }  // namespace tensorflow
 |  | ||||||
| @ -1,80 +0,0 @@ | |||||||
| /* Copyright 2015 Google Inc. All Rights Reserved.
 |  | ||||||
| 
 |  | ||||||
| Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| you may not use this file except in compliance with the License. |  | ||||||
| You may obtain a copy of the License at |  | ||||||
| 
 |  | ||||||
|     http://www.apache.org/licenses/LICENSE-2.0
 |  | ||||||
| 
 |  | ||||||
| Unless required by applicable law or agreed to in writing, software |  | ||||||
| distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| See the License for the specific language governing permissions and |  | ||||||
| limitations under the License. |  | ||||||
| ==============================================================================*/ |  | ||||||
| 
 |  | ||||||
| #ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ |  | ||||||
| #define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_ |  | ||||||
| 
 |  | ||||||
| #include <vector> |  | ||||||
| #include "tensorflow/core/lib/core/status.h" |  | ||||||
| #include "tensorflow/core/platform/macros.h" |  | ||||||
| #include "tensorflow/core/platform/types.h" |  | ||||||
| 
 |  | ||||||
| namespace tensorflow { |  | ||||||
| namespace internal { |  | ||||||
| 
 |  | ||||||
| template <typename T> |  | ||||||
| struct CommandLineFlagRegistry { |  | ||||||
|   static CommandLineFlagRegistry* Instance() { |  | ||||||
|     static CommandLineFlagRegistry instance_; |  | ||||||
|     return &instance_; |  | ||||||
|   } |  | ||||||
|   struct Command { |  | ||||||
|     string name; |  | ||||||
|     T* value; |  | ||||||
|     string text; |  | ||||||
|   }; |  | ||||||
|   std::vector<Command> commands; |  | ||||||
| 
 |  | ||||||
|  private: |  | ||||||
|   CommandLineFlagRegistry() {} |  | ||||||
|   TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry); |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| template <typename T> |  | ||||||
| struct CommandLineFlagRegister { |  | ||||||
|   CommandLineFlagRegister(const string& name, T* val, const string& text) { |  | ||||||
|     CommandLineFlagRegistry<T>::Instance()->commands.push_back( |  | ||||||
|         {name, val, text}); |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| #define TF_DEFINE_variable(type, name, default_value, text)     \ |  | ||||||
|   type FLAGS_##name = default_value;                            \ |  | ||||||
|   namespace TF_flags_internal {                                 \ |  | ||||||
|   tensorflow::internal::CommandLineFlagRegister<type>           \ |  | ||||||
|       TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \ |  | ||||||
|   }  // namespace TF_flags_internal
 |  | ||||||
| 
 |  | ||||||
| }  // namespace internal
 |  | ||||||
| 
 |  | ||||||
| #define TF_DEFINE_int32(name, default_value, text) \ |  | ||||||
|   TF_DEFINE_variable(tensorflow::int32, name, default_value, text); |  | ||||||
| 
 |  | ||||||
| #define TF_DEFINE_bool(name, default_value, text) \ |  | ||||||
|   TF_DEFINE_variable(bool, name, default_value, text); |  | ||||||
| 
 |  | ||||||
| #define TF_DEFINE_string(name, default_value, text) \ |  | ||||||
|   TF_DEFINE_variable(string, name, default_value, text); |  | ||||||
| 
 |  | ||||||
| // Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv.
 |  | ||||||
| // Returns the number of unused arguments in *argc.
 |  | ||||||
| // Return error Status if the parsing encounters errors.
 |  | ||||||
| // TODO(opensource): switch to a command line argument parser that can be
 |  | ||||||
| // shared with other tests.
 |  | ||||||
| Status ParseCommandLineFlags(int* argc, char* argv[]); |  | ||||||
| 
 |  | ||||||
| }  // namespace tensorflow
 |  | ||||||
| 
 |  | ||||||
| #endif  // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
 |  | ||||||
| @ -120,9 +120,12 @@ variable to its initial value. | |||||||
| ##### Args: | ##### Args: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| *  <b>`initial_value`</b>: A `Tensor`, or Python object convertible to a `Tensor`. | *  <b>`initial_value`</b>: A `Tensor`, or Python object convertible to a `Tensor`, | ||||||
|     The initial value for the Variable. Must have a shape specified unless |     which is the initial value for the Variable. The initial value must have | ||||||
|     `validate_shape` is set to False. |     a shape specified unless `validate_shape` is set to False. Can also be a | ||||||
|  |     callable with no argument that returns the initial value when called. In | ||||||
|  |     that case, `dtype` must be specified. (Note that initializer functions | ||||||
|  |     from init_ops.py must first be bound to a shape before being used here.) | ||||||
| *  <b>`trainable`</b>: If `True`, the default, also adds the variable to the graph | *  <b>`trainable`</b>: If `True`, the default, also adds the variable to the graph | ||||||
|     collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as |     collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as | ||||||
|     the default list of variables to use by the `Optimizer` classes. |     the default list of variables to use by the `Optimizer` classes. | ||||||
|  | |||||||
| @ -1955,6 +1955,7 @@ on the parameters to the constructor and may include: | |||||||
| ##### Raises: | ##### Raises: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | *  <b>`RuntimeError`</b>: If called with a non-chief Supervisor. | ||||||
| *  <b>`ValueError`</b>: If no `logdir` was passed to the constructor as the | *  <b>`ValueError`</b>: If no `logdir` was passed to the constructor as the | ||||||
|     services need a log directory. |     services need a log directory. | ||||||
| 
 | 
 | ||||||
| @ -2182,6 +2183,7 @@ on the parameters to the constructor and may include: | |||||||
| ##### Raises: | ##### Raises: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | *  <b>`RuntimeError`</b>: If called with a non-chief Supervisor. | ||||||
| *  <b>`ValueError`</b>: If no `logdir` was passed to the constructor as the | *  <b>`ValueError`</b>: If no `logdir` was passed to the constructor as the | ||||||
|     services need a log directory. |     services need a log directory. | ||||||
| 
 | 
 | ||||||
| @ -2409,7 +2411,7 @@ Start threads for `QueueRunners`. | |||||||
| 
 | 
 | ||||||
| #### `tf.train.Supervisor.summary_op` {#Supervisor.summary_op} | #### `tf.train.Supervisor.summary_op` {#Supervisor.summary_op} | ||||||
| 
 | 
 | ||||||
| Return the Summary Tensor used by the supervisor. | Return the Summary Tensor used by the chief supervisor. | ||||||
| 
 | 
 | ||||||
| ##### Returns: | ##### Returns: | ||||||
| 
 | 
 | ||||||
| @ -2420,7 +2422,7 @@ Return the Summary Tensor used by the supervisor. | |||||||
| 
 | 
 | ||||||
| #### `tf.train.Supervisor.summary_writer` {#Supervisor.summary_writer} | #### `tf.train.Supervisor.summary_writer` {#Supervisor.summary_writer} | ||||||
| 
 | 
 | ||||||
| Return the SummaryWriter used by the supervisor. | Return the SummaryWriter used by the chief supervisor. | ||||||
| 
 | 
 | ||||||
| ##### Returns: | ##### Returns: | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -19,6 +19,7 @@ from __future__ import absolute_import | |||||||
| from __future__ import division | from __future__ import division | ||||||
| from __future__ import print_function | from __future__ import print_function | ||||||
| 
 | 
 | ||||||
|  | import functools | ||||||
| import numpy as np | import numpy as np | ||||||
| import tensorflow as tf | import tensorflow as tf | ||||||
| 
 | 
 | ||||||
| @ -208,5 +209,59 @@ class RNNCellTest(tf.test.TestCase): | |||||||
|                                    0.13248, 0.13248]]) |                                    0.13248, 0.13248]]) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class SlimRNNCellTest(tf.test.TestCase): | ||||||
|  | 
 | ||||||
|  |   def testBasicRNNCell(self): | ||||||
|  |     with self.test_session() as sess: | ||||||
|  |       with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): | ||||||
|  |         x = tf.zeros([1, 2]) | ||||||
|  |         m = tf.zeros([1, 2]) | ||||||
|  |         my_cell = functools.partial(basic_rnn_cell, num_units=2) | ||||||
|  |         g, _ = tf.nn.rnn_cell.SlimRNNCell(my_cell)(x, m) | ||||||
|  |         sess.run([tf.initialize_all_variables()]) | ||||||
|  |         res = sess.run([g], {x.name: np.array([[1., 1.]]), | ||||||
|  |                              m.name: np.array([[0.1, 0.1]])}) | ||||||
|  |         self.assertEqual(res[0].shape, (1, 2)) | ||||||
|  | 
 | ||||||
|  |   def testBasicRNNCellMatch(self): | ||||||
|  |     batch_size = 32 | ||||||
|  |     input_size = 100 | ||||||
|  |     num_units = 10 | ||||||
|  |     with self.test_session() as sess: | ||||||
|  |       with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): | ||||||
|  |         inputs = tf.random_uniform((batch_size, input_size)) | ||||||
|  |         _, initial_state = basic_rnn_cell(inputs, None, num_units) | ||||||
|  |         my_cell = functools.partial(basic_rnn_cell, num_units=num_units) | ||||||
|  |         slim_cell = tf.nn.rnn_cell.SlimRNNCell(my_cell) | ||||||
|  |         slim_outputs, slim_state = slim_cell(inputs, initial_state) | ||||||
|  |         rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units) | ||||||
|  |         outputs, state = rnn_cell(inputs, initial_state) | ||||||
|  |         self.assertEqual(slim_outputs.get_shape(), outputs.get_shape()) | ||||||
|  |         self.assertEqual(slim_state.get_shape(), state.get_shape()) | ||||||
|  |         sess.run([tf.initialize_all_variables()]) | ||||||
|  |         res = sess.run([slim_outputs, slim_state, outputs, state]) | ||||||
|  |         self.assertAllClose(res[0], res[2]) | ||||||
|  |         self.assertAllClose(res[1], res[3]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def basic_rnn_cell(inputs, state, num_units, scope=None): | ||||||
|  |   if state is None: | ||||||
|  |     if inputs is not None: | ||||||
|  |       batch_size = inputs.get_shape()[0] | ||||||
|  |       dtype = inputs.dtype | ||||||
|  |     else: | ||||||
|  |       batch_size = 0 | ||||||
|  |       dtype = tf.float32 | ||||||
|  |     init_output = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype) | ||||||
|  |     init_state = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype) | ||||||
|  |     init_output.set_shape([batch_size, num_units]) | ||||||
|  |     init_state.set_shape([batch_size, num_units]) | ||||||
|  |     return init_output, init_state | ||||||
|  |   else: | ||||||
|  |     with tf.variable_op_scope([inputs, state], scope, "BasicRNNCell"): | ||||||
|  |       output = tf.tanh(tf.nn.rnn_cell.linear([inputs, state], | ||||||
|  |                                              num_units, True)) | ||||||
|  |     return output, output | ||||||
|  | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|   tf.test.main() |   tf.test.main() | ||||||
|  | |||||||
| @ -302,6 +302,53 @@ class VariablesTestCase(tf.test.TestCase): | |||||||
|       self.assertEqual(var.op.device, init_op.device) |       self.assertEqual(var.op.device, init_op.device) | ||||||
|       sess.run(init_op) |       sess.run(init_op) | ||||||
| 
 | 
 | ||||||
|  |   def testInitializerFunction(self): | ||||||
|  |     value = [[-42], [133.7]] | ||||||
|  |     shape = [2, 1] | ||||||
|  |     with self.test_session(): | ||||||
|  |       initializer = lambda: tf.constant(value) | ||||||
|  |       with self.assertRaises(ValueError): | ||||||
|  |         # Checks that dtype must be specified. | ||||||
|  |         tf.Variable(initializer) | ||||||
|  | 
 | ||||||
|  |       v1 = tf.Variable(initializer, dtype=tf.float32) | ||||||
|  |       self.assertEqual(shape, v1.get_shape()) | ||||||
|  |       self.assertAllClose(value, v1.initial_value.eval()) | ||||||
|  |       with self.assertRaises(tf.errors.FailedPreconditionError): | ||||||
|  |         v1.eval() | ||||||
|  | 
 | ||||||
|  |       v2 = tf.Variable(tf.neg(v1.initialized_value()), dtype=tf.float32) | ||||||
|  |       self.assertEqual(v1.get_shape(), v2.get_shape()) | ||||||
|  |       self.assertAllClose(np.negative(value), v2.initial_value.eval()) | ||||||
|  | 
 | ||||||
|  |       # Once v2.initial_value.eval() has been called, v1 has effectively been | ||||||
|  |       # initialized. | ||||||
|  |       self.assertAllClose(value, v1.eval()) | ||||||
|  | 
 | ||||||
|  |       with self.assertRaises(tf.errors.FailedPreconditionError): | ||||||
|  |         v2.eval() | ||||||
|  |       tf.initialize_all_variables().run() | ||||||
|  |       self.assertAllClose(np.negative(value), v2.eval()) | ||||||
|  | 
 | ||||||
|  |   def testInitializerFunctionDevicePlacement(self): | ||||||
|  |     with self.test_session(): | ||||||
|  |       initializer = lambda: tf.constant(42.0) | ||||||
|  |       with tf.device("/cpu:100"): | ||||||
|  |         v1 = tf.Variable(initializer, dtype=tf.float32, name="v1") | ||||||
|  |       expected_device = "/device:CPU:100" | ||||||
|  |       expected_group_v1 = [b"loc:@v1"] | ||||||
|  |       self.assertEqual(expected_device, v1.op.device) | ||||||
|  |       self.assertEqual(expected_group_v1, v1.op.colocation_groups()) | ||||||
|  |       for i in v1.initializer.inputs: | ||||||
|  |         self.assertEqual(expected_device, i.op.device) | ||||||
|  |         self.assertEqual(expected_group_v1, i.op.colocation_groups()) | ||||||
|  | 
 | ||||||
|  |       v2 = tf.Variable(initializer, dtype=tf.float32, name="v2") | ||||||
|  |       expected_group_v2 = [b"loc:@v2"] | ||||||
|  |       self.assertEqual(expected_group_v2, v2.op.colocation_groups()) | ||||||
|  |       for i in v2.initializer.inputs: | ||||||
|  |         self.assertEqual(expected_group_v2, i.op.colocation_groups()) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class IsInitializedTest(tf.test.TestCase): | class IsInitializedTest(tf.test.TestCase): | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -167,19 +167,22 @@ def create_partitioned_variables( | |||||||
|       slice_offset[slice_dim] += var_shape[slice_dim] |       slice_offset[slice_dim] += var_shape[slice_dim] | ||||||
| 
 | 
 | ||||||
|       if callable(initializer): |       if callable(initializer): | ||||||
|         init_val = initializer(var_shape, dtype=dtype) |         init = initializer | ||||||
|         init_val = ops.convert_to_tensor(init_val, dtype=dtype) |         init_shape = var_shape | ||||||
|       elif isinstance(initializer, ops.Tensor): |       elif isinstance(initializer, ops.Tensor): | ||||||
|         init_val = array_ops.slice(initializer, var_offset, var_shape) |         init = array_ops.slice(initializer, var_offset, var_shape) | ||||||
|         # Use the dtype of the given tensor. |         # Use the dtype of the given tensor. | ||||||
|         dtype = init_val.dtype.base_dtype |         dtype = init.dtype.base_dtype | ||||||
|  |         init_shape = None | ||||||
|       else: |       else: | ||||||
|         init_val = ops.convert_to_tensor(initializer, dtype=dtype) |         init = ops.convert_to_tensor(initializer, dtype=dtype) | ||||||
|         init_val = array_ops.slice(init_val, var_offset, var_shape) |         init = array_ops.slice(init, var_offset, var_shape) | ||||||
|  |         init_shape = None | ||||||
| 
 | 
 | ||||||
|       var = variable_scope.get_variable(name="part_%d" % i, |       var = variable_scope.get_variable(name="part_%d" % i, | ||||||
|  |                                         shape=init_shape, | ||||||
|                                         dtype=dtype, |                                         dtype=dtype, | ||||||
|                                         initializer=init_val, |                                         initializer=init, | ||||||
|                                         trainable=trainable, |                                         trainable=trainable, | ||||||
|                                         collections=collections) |                                         collections=collections) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -661,6 +661,42 @@ class MultiRNNCell(RNNCell): | |||||||
|     return cur_inp, array_ops.concat(1, new_states) |     return cur_inp, array_ops.concat(1, new_states) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class SlimRNNCell(RNNCell): | ||||||
|  |   """A simple wrapper for slim.rnn_cells.""" | ||||||
|  | 
 | ||||||
|  |   def __init__(self, cell_fn): | ||||||
|  |     """Create a SlimRNNCell from a cell_fn. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |       cell_fn: a function which takes (inputs, state, scope) and produces the | ||||||
|  |         outputs and the new_state. Additionally when called with inputs=None and | ||||||
|  |         state=None it should return (initial_outputs, initial_state). | ||||||
|  | 
 | ||||||
|  |     Raises: | ||||||
|  |       TypeError: if cell_fn is not callable | ||||||
|  |       ValueError: if cell_fn cannot produce a valid initial state. | ||||||
|  |     """ | ||||||
|  |     if not callable(cell_fn): | ||||||
|  |       raise TypeError("cell_fn %s needs to be callable", cell_fn) | ||||||
|  |     self._cell_fn = cell_fn | ||||||
|  |     self._cell_name = cell_fn.func.__name__ | ||||||
|  |     _, init_state = self._cell_fn(None, None) | ||||||
|  |     state_shape = init_state.get_shape() | ||||||
|  |     self._state_size = state_shape.with_rank(2)[1].value | ||||||
|  |     if self._state_size is None: | ||||||
|  |       raise ValueError("Initial state created by %s has invalid shape %s", | ||||||
|  |                        self._cell_name, state_shape) | ||||||
|  | 
 | ||||||
|  |   @property | ||||||
|  |   def state_size(self): | ||||||
|  |     return self._state_size | ||||||
|  | 
 | ||||||
|  |   def __call__(self, inputs, state, scope=None): | ||||||
|  |     scope = scope or self._cell_name | ||||||
|  |     output, state = self._cell_fn(inputs, state, scope=scope) | ||||||
|  |     return output, state | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def linear(args, output_size, bias, bias_start=0.0, scope=None): | def linear(args, output_size, bias, bias_start=0.0, scope=None): | ||||||
|   """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. |   """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -144,14 +144,19 @@ class _VariableStore(object): | |||||||
|     with ops.control_dependencies(None): |     with ops.control_dependencies(None): | ||||||
|       if initializing_from_value: |       if initializing_from_value: | ||||||
|         init_val = initializer |         init_val = initializer | ||||||
|  |         variable_dtype = None | ||||||
|       else: |       else: | ||||||
|         with ops.name_scope(name + "/Initializer/"): |         init_val = lambda: initializer(shape.as_list(), dtype=dtype) | ||||||
|           init_val = initializer(shape.as_list(), dtype=dtype) |         variable_dtype = dtype.base_dtype | ||||||
| 
 | 
 | ||||||
|     # Create the variable. |     # Create the variable. | ||||||
|     v = variables.Variable(init_val, name=name, trainable=trainable, |     v = variables.Variable(initial_value=init_val, | ||||||
|  |                            name=name, | ||||||
|  |                            trainable=trainable, | ||||||
|                            collections=collections, |                            collections=collections, | ||||||
|                            caching_device=caching_device) |                            caching_device=caching_device, | ||||||
|  |                            dtype=variable_dtype) | ||||||
|  | 
 | ||||||
|     self._vars[name] = v |     self._vars[name] = v | ||||||
|     logging.info("Created variable %s with shape %s and init %s", v.name, |     logging.info("Created variable %s with shape %s and init %s", v.name, | ||||||
|                  format(shape), initializer) |                  format(shape), initializer) | ||||||
|  | |||||||
| @ -156,9 +156,12 @@ class Variable(object): | |||||||
|     variable to its initial value. |     variable to its initial value. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|       initial_value: A `Tensor`, or Python object convertible to a `Tensor`. |       initial_value: A `Tensor`, or Python object convertible to a `Tensor`, | ||||||
|         The initial value for the Variable. Must have a shape specified unless |         which is the initial value for the Variable. The initial value must have | ||||||
|         `validate_shape` is set to False. |         a shape specified unless `validate_shape` is set to False. Can also be a | ||||||
|  |         callable with no argument that returns the initial value when called. In | ||||||
|  |         that case, `dtype` must be specified. (Note that initializer functions | ||||||
|  |         from init_ops.py must first be bound to a shape before being used here.) | ||||||
|       trainable: If `True`, the default, also adds the variable to the graph |       trainable: If `True`, the default, also adds the variable to the graph | ||||||
|         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as |         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as | ||||||
|         the default list of variables to use by the `Optimizer` classes. |         the default list of variables to use by the `Optimizer` classes. | ||||||
| @ -211,9 +214,12 @@ class Variable(object): | |||||||
|     """Creates a new variable from arguments. |     """Creates a new variable from arguments. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|       initial_value: A `Tensor`, or Python object convertible to a `Tensor`. |       initial_value: A `Tensor`, or Python object convertible to a `Tensor`, | ||||||
|         The initial value for the Variable. Must have a shape specified unless |         which is the initial value for the Variable. The initial value must have | ||||||
|         `validate_shape` is set to False. |         a shape specified unless `validate_shape` is set to False. Can also be a | ||||||
|  |         callable with no argument that returns the initial value when called. In | ||||||
|  |         that case, `dtype` must be specified. (Note that initializer functions | ||||||
|  |         from init_ops.py must first be bound to a shape before being used here.) | ||||||
|       trainable: If `True`, the default, also adds the variable to the graph |       trainable: If `True`, the default, also adds the variable to the graph | ||||||
|         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as |         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as | ||||||
|         the default list of variables to use by the `Optimizer` classes. |         the default list of variables to use by the `Optimizer` classes. | ||||||
| @ -240,25 +246,62 @@ class Variable(object): | |||||||
|     """ |     """ | ||||||
|     if initial_value is None: |     if initial_value is None: | ||||||
|       raise ValueError("initial_value must be specified.") |       raise ValueError("initial_value must be specified.") | ||||||
|  |     init_from_fn = callable(initial_value) | ||||||
|  |     if init_from_fn and dtype is None: | ||||||
|  |       raise ValueError( | ||||||
|  |           "dtype must also be specified when initial_value is callable.") | ||||||
|  | 
 | ||||||
|     if collections is None: |     if collections is None: | ||||||
|       collections = [ops.GraphKeys.VARIABLES] |       collections = [ops.GraphKeys.VARIABLES] | ||||||
|     if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: |     if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: | ||||||
|       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] |       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] | ||||||
|     with ops.control_dependencies(None): |     with ops.control_dependencies(None): | ||||||
|       with ops.op_scope([initial_value], name, "Variable") as name: |       with ops.op_scope( | ||||||
|  |           [] if init_from_fn else [initial_value], name, "Variable") as name: | ||||||
|  | 
 | ||||||
|  |         # Get the initial value from a callable function. The real shape of the | ||||||
|  |         # variable will be set later, since under the init_from_fn case, the | ||||||
|  |         # shape won't be known until after the function is invoked. | ||||||
|  |         if init_from_fn: | ||||||
|  |           self._variable = state_ops.variable_op( | ||||||
|  |               [], | ||||||
|  |               dtype.base_dtype, | ||||||
|  |               set_shape=False, | ||||||
|  |               name=name) | ||||||
|  |           with ops.colocate_with(self._variable.op): | ||||||
|  |             with ops.name_scope("Initializer"): | ||||||
|  |               # Colocate the tensors created by the initial_value() function | ||||||
|  |               # with the variable itself. | ||||||
|  |               self._initial_value = ops.convert_to_tensor(initial_value(), | ||||||
|  |                                                           name="initial_value", | ||||||
|  |                                                           dtype=dtype) | ||||||
|  | 
 | ||||||
|  |         # Or get the initial value from a Tensor or Python object. | ||||||
|  |         else: | ||||||
|           self._initial_value = ops.convert_to_tensor(initial_value, |           self._initial_value = ops.convert_to_tensor(initial_value, | ||||||
|                                                       name="initial_value", |                                                       name="initial_value", | ||||||
|                                                       dtype=dtype) |                                                       dtype=dtype) | ||||||
|  |           # In this case, the variable op can't be created until after the | ||||||
|  |           # initial_value has been converted to a Tensor with a known type. | ||||||
|  |           self._variable = state_ops.variable_op( | ||||||
|  |               [], | ||||||
|  |               self._initial_value.dtype.base_dtype, | ||||||
|  |               set_shape=False, | ||||||
|  |               name=name) | ||||||
|  | 
 | ||||||
|  |         # Manually overrides the variable's shape with the initial value's. | ||||||
|  |         if validate_shape: | ||||||
|           initial_value_shape = self._initial_value.get_shape() |           initial_value_shape = self._initial_value.get_shape() | ||||||
|         if validate_shape and not initial_value_shape.is_fully_defined(): |           if not initial_value_shape.is_fully_defined(): | ||||||
|             raise ValueError("initial_value must have a shape specified: %s" |             raise ValueError("initial_value must have a shape specified: %s" | ||||||
|                              % self._initial_value) |                              % self._initial_value) | ||||||
|         shape_to_set = initial_value_shape if validate_shape else [] |           self._variable.set_shape(initial_value_shape) | ||||||
| 
 |           # TODO(b/28152992): Remove the below hack modifying the node_def shape | ||||||
|         self._variable = state_ops.variable_op( |           # directly once set_shape() handles it. | ||||||
|             shape_to_set, self._initial_value.dtype.base_dtype, |           self._variable.op.node_def.attr["shape"].shape.CopyFrom( | ||||||
|             set_shape=validate_shape, name=name) |               initial_value_shape.as_proto()) | ||||||
| 
 | 
 | ||||||
|  |         # Assigns initial value. | ||||||
|         with ops.colocate_with(self._variable.op): |         with ops.colocate_with(self._variable.op): | ||||||
|           self._initializer_op = state_ops.assign( |           self._initializer_op = state_ops.assign( | ||||||
|               self._variable, self._initial_value, |               self._variable, self._initial_value, | ||||||
|  | |||||||
| @ -79,3 +79,7 @@ def get_path_to_datafile(path): | |||||||
|   """ |   """ | ||||||
|   data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1))) |   data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1))) | ||||||
|   return os.path.join(data_files_path, path) |   return os.path.join(data_files_path, path) | ||||||
|  | 
 | ||||||
|  | def readahead_file_path(path, unused_readahead=None): | ||||||
|  |   """Readahead files not implemented; simply returns given path.""" | ||||||
|  |   return path | ||||||
|  | |||||||
| @ -22,6 +22,7 @@ from tensorflow.core.util import event_pb2 | |||||||
| from tensorflow.python import pywrap_tensorflow | from tensorflow.python import pywrap_tensorflow | ||||||
| from tensorflow.python.platform import app | from tensorflow.python.platform import app | ||||||
| from tensorflow.python.platform import logging | from tensorflow.python.platform import logging | ||||||
|  | from tensorflow.python.platform import resource_loader | ||||||
| from tensorflow.python.util import compat | from tensorflow.python.util import compat | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -31,6 +32,7 @@ class EventFileLoader(object): | |||||||
|   def __init__(self, file_path): |   def __init__(self, file_path): | ||||||
|     if file_path is None: |     if file_path is None: | ||||||
|       raise ValueError('A file path is required') |       raise ValueError('A file path is required') | ||||||
|  |     file_path = resource_loader.readahead_file_path(file_path) | ||||||
|     logging.debug('Opening a record reader pointing at %s', file_path) |     logging.debug('Opening a record reader pointing at %s', file_path) | ||||||
|     self._reader = pywrap_tensorflow.PyRecordReader_New( |     self._reader = pywrap_tensorflow.PyRecordReader_New( | ||||||
|         compat.as_bytes(file_path), 0) |         compat.as_bytes(file_path), 0) | ||||||
|  | |||||||
| @ -326,19 +326,29 @@ class Supervisor(object): | |||||||
|       self._init_global_step(global_step=global_step) |       self._init_global_step(global_step=global_step) | ||||||
|     self._graph = graph |     self._graph = graph | ||||||
|     self._is_chief = is_chief |     self._is_chief = is_chief | ||||||
|  |     self._coord = coordinator.Coordinator() | ||||||
|  |     self._started_threads = [] | ||||||
|  |     self._recovery_wait_secs = recovery_wait_secs | ||||||
|  | 
 | ||||||
|  |     # Only chief supervisors write event files, so only chief supervisors | ||||||
|  |     # should have event-writing properties. Set to None for non-chiefs. | ||||||
|  |     if self._is_chief: | ||||||
|       self._logdir = logdir |       self._logdir = logdir | ||||||
|       self._save_summaries_secs = save_summaries_secs |       self._save_summaries_secs = save_summaries_secs | ||||||
|       self._save_model_secs = save_model_secs |       self._save_model_secs = save_model_secs | ||||||
|     self._recovery_wait_secs = recovery_wait_secs |     else: | ||||||
|     self._coord = coordinator.Coordinator() |       self._logdir = None | ||||||
|     if logdir: |       self._save_summaries_secs = None | ||||||
|  |       self._save_model_secs = None | ||||||
|  | 
 | ||||||
|  |     if self._is_chief and self._logdir: | ||||||
|       self._save_path = os.path.join(self._logdir, checkpoint_basename) |       self._save_path = os.path.join(self._logdir, checkpoint_basename) | ||||||
|       self._summary_writer = summary_io.SummaryWriter(self._logdir) |       self._summary_writer = summary_io.SummaryWriter(self._logdir) | ||||||
|     else: |     else: | ||||||
|       self._save_path = None |       self._save_path = None | ||||||
|       self._summary_writer = None |       self._summary_writer = None | ||||||
|  | 
 | ||||||
|     self._init_session_manager(session_manager=session_manager) |     self._init_session_manager(session_manager=session_manager) | ||||||
|     self._started_threads = [] |  | ||||||
|     self._verify_setup() |     self._verify_setup() | ||||||
|     # The graph is not allowed to change anymore. |     # The graph is not allowed to change anymore. | ||||||
|     graph.finalize() |     graph.finalize() | ||||||
| @ -520,7 +530,7 @@ class Supervisor(object): | |||||||
| 
 | 
 | ||||||
|   @property |   @property | ||||||
|   def summary_writer(self): |   def summary_writer(self): | ||||||
|     """Return the SummaryWriter used by the supervisor. |     """Return the SummaryWriter used by the chief supervisor. | ||||||
| 
 | 
 | ||||||
|     Returns: |     Returns: | ||||||
|       A SummaryWriter. |       A SummaryWriter. | ||||||
| @ -529,7 +539,7 @@ class Supervisor(object): | |||||||
| 
 | 
 | ||||||
|   @property |   @property | ||||||
|   def summary_op(self): |   def summary_op(self): | ||||||
|     """Return the Summary Tensor used by the supervisor. |     """Return the Summary Tensor used by the chief supervisor. | ||||||
| 
 | 
 | ||||||
|     Returns: |     Returns: | ||||||
|       A string Tensor for the summary or `None`. |       A string Tensor for the summary or `None`. | ||||||
| @ -583,8 +593,7 @@ class Supervisor(object): | |||||||
| 
 | 
 | ||||||
|   def _write_graph(self): |   def _write_graph(self): | ||||||
|     """Writes graph_def to `logdir` and adds it to summary if applicable.""" |     """Writes graph_def to `logdir` and adds it to summary if applicable.""" | ||||||
|     if not self._is_chief: |     assert self._is_chief | ||||||
|       return |  | ||||||
|     if self._logdir: |     if self._logdir: | ||||||
|       training_util.write_graph(self._graph.as_graph_def(), |       training_util.write_graph(self._graph.as_graph_def(), | ||||||
|                                 self._logdir, "graph.pbtxt") |                                 self._logdir, "graph.pbtxt") | ||||||
| @ -610,11 +619,13 @@ class Supervisor(object): | |||||||
|         sv.coord.Join(<list of threads>) |         sv.coord.Join(<list of threads>) | ||||||
| 
 | 
 | ||||||
|     Raises: |     Raises: | ||||||
|  |       RuntimeError: If called with a non-chief Supervisor. | ||||||
|       ValueError: If not `logdir` was passed to the constructor as the |       ValueError: If not `logdir` was passed to the constructor as the | ||||||
|         services need a log directory. |         services need a log directory. | ||||||
|     """ |     """ | ||||||
|     if not self._is_chief: |     if not self._is_chief: | ||||||
|       return |       raise RuntimeError("Only chief supervisor can start standard services. " | ||||||
|  |                          "Because only cheif supervisors can write events.") | ||||||
|     if not self._logdir: |     if not self._logdir: | ||||||
|       logging.warning("Standard services need a 'logdir' " |       logging.warning("Standard services need a 'logdir' " | ||||||
|                       "passed to the SessionManager") |                       "passed to the SessionManager") | ||||||
| @ -812,14 +823,18 @@ class Supervisor(object): | |||||||
|       TypeError: if 'summary' is not a Summary proto or a string. |       TypeError: if 'summary' is not a Summary proto or a string. | ||||||
|       RuntimeError: if the Supervisor was created without a `logdir`. |       RuntimeError: if the Supervisor was created without a `logdir`. | ||||||
|     """ |     """ | ||||||
|     if not self._logdir: |     if not self._summary_writer: | ||||||
|       raise RuntimeError("summary_computed() requires a logdir") |       raise RuntimeError("Writing a summary requires a summary writer.") | ||||||
|     if global_step is None and self.global_step is not None: |     if global_step is None and self.global_step is not None: | ||||||
|       global_step = training_util.global_step(sess, self.global_step) |       global_step = training_util.global_step(sess, self.global_step) | ||||||
|     if self._summary_writer: |  | ||||||
|     self._summary_writer.add_summary(summary, global_step) |     self._summary_writer.add_summary(summary, global_step) | ||||||
| 
 | 
 | ||||||
|   def _default_global_step_tensor(self): |   def _default_global_step_tensor(self): | ||||||
|  |     """Returns the global_step from the default graph. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |       The global step `Tensor` or `None`. | ||||||
|  |     """ | ||||||
|     try: |     try: | ||||||
|       gs = ops.get_default_graph().get_tensor_by_name("global_step:0") |       gs = ops.get_default_graph().get_tensor_by_name("global_step:0") | ||||||
|       if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]: |       if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]: | ||||||
|  | |||||||
| @ -73,12 +73,11 @@ class SupervisorTest(tf.test.TestCase): | |||||||
|       sess.close() |       sess.close() | ||||||
|       sv.stop() |       sv.stop() | ||||||
| 
 | 
 | ||||||
|   def testSummary(self): |   def testChiefCanWriteEvents(self): | ||||||
|     logdir = self._TestDir("basics") |     logdir = self._TestDir("basics") | ||||||
|     with tf.Graph().as_default(): |     with tf.Graph().as_default(): | ||||||
|       const = tf.constant([1.0, 2.0, 3.0]) |       summ = tf.scalar_summary(["c1", "c2", "c3"], tf.constant([1.0, 2.0, 3.0])) | ||||||
|       summ = tf.scalar_summary(["c1", "c2", "c3"], const) |       sv = tf.train.Supervisor(is_chief=True, logdir=logdir, summary_op=None) | ||||||
|       sv = tf.train.Supervisor(logdir=logdir, summary_op=None) |  | ||||||
|       sess = sv.prepare_or_wait_for_session("") |       sess = sv.prepare_or_wait_for_session("") | ||||||
|       sv.summary_computed(sess, sess.run(summ)) |       sv.summary_computed(sess, sess.run(summ)) | ||||||
|       sess.close() |       sess.close() | ||||||
| @ -113,13 +112,31 @@ class SupervisorTest(tf.test.TestCase): | |||||||
|     # We should be done. |     # We should be done. | ||||||
|     self.assertRaises(StopIteration, lambda: next(rr)) |     self.assertRaises(StopIteration, lambda: next(rr)) | ||||||
| 
 | 
 | ||||||
|  |   def testNonChiefCannotWriteEvents(self): | ||||||
|  | 
 | ||||||
|  |     def _summary_computed(): | ||||||
|  |       with tf.Graph().as_default(): | ||||||
|  |         sv = tf.train.Supervisor(is_chief=False) | ||||||
|  |         sess = sv.prepare_or_wait_for_session("") | ||||||
|  |         summ = tf.scalar_summary(["c1", "c2"], tf.constant([1.0, 2.0])) | ||||||
|  |         sv.summary_computed(sess, sess.run(summ)) | ||||||
|  | 
 | ||||||
|  |     def _start_standard_services(): | ||||||
|  |       with tf.Graph().as_default(): | ||||||
|  |         sv = tf.train.Supervisor(is_chief=False) | ||||||
|  |         sess = sv.prepare_or_wait_for_session("") | ||||||
|  |         sv.start_standard_services(sess) | ||||||
|  | 
 | ||||||
|  |     self.assertRaises(RuntimeError, _summary_computed) | ||||||
|  |     self.assertRaises(RuntimeError, _start_standard_services) | ||||||
|  | 
 | ||||||
|   def testNoLogdirButWantSummary(self): |   def testNoLogdirButWantSummary(self): | ||||||
|     with tf.Graph().as_default(): |     with tf.Graph().as_default(): | ||||||
|       const = tf.constant([1.0, 2.0, 3.0]) |       const = tf.constant([1.0, 2.0, 3.0]) | ||||||
|       summ = tf.scalar_summary(["c1", "c2", "c3"], const) |       summ = tf.scalar_summary(["c1", "c2", "c3"], const) | ||||||
|       sv = tf.train.Supervisor(logdir="", summary_op=None) |       sv = tf.train.Supervisor(logdir="", summary_op=None) | ||||||
|       sess = sv.prepare_or_wait_for_session("") |       sess = sv.prepare_or_wait_for_session("") | ||||||
|       with self.assertRaisesRegexp(RuntimeError, "requires a logdir"): |       with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"): | ||||||
|         sv.summary_computed(sess, sess.run(summ)) |         sv.summary_computed(sess, sess.run(summ)) | ||||||
| 
 | 
 | ||||||
|   def testNoLogdirSucceeds(self): |   def testNoLogdirSucceeds(self): | ||||||
|  | |||||||
							
								
								
									
										66
									
								
								tensorflow/tools/benchmark/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								tensorflow/tools/benchmark/BUILD
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,66 @@ | |||||||
|  | # Description: | ||||||
|  | #   Benchmark utility that can run on desktop and Android. | ||||||
|  | 
 | ||||||
|  | package(default_visibility = ["//visibility:public"]) | ||||||
|  | 
 | ||||||
|  | licenses(["notice"])  # Apache 2.0 | ||||||
|  | 
 | ||||||
|  | load("//tensorflow:tensorflow.bzl", "tf_copts") | ||||||
|  | 
 | ||||||
|  | exports_files(["LICENSE"]) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "benchmark_model_lib", | ||||||
|  |     srcs = [ | ||||||
|  |         "benchmark_model.cc", | ||||||
|  |     ], | ||||||
|  |     copts = tf_copts(), | ||||||
|  |     visibility = ["//visibility:public"], | ||||||
|  |     deps = select({ | ||||||
|  |         "//tensorflow:android": [ | ||||||
|  |             "//tensorflow/core:android_tensorflow_lib", | ||||||
|  |         ], | ||||||
|  |         "//conditions:default": [ | ||||||
|  |             "//tensorflow/core:core_cpu", | ||||||
|  |             "//tensorflow/core:lib", | ||||||
|  |             "//tensorflow/core:framework", | ||||||
|  |             "//tensorflow/core:framework_internal", | ||||||
|  |             "//tensorflow/core:protos_all_cc", | ||||||
|  |             "//tensorflow/core:tensorflow", | ||||||
|  |         ], | ||||||
|  |     }), | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | # This binary may be built for either desktop or Android. | ||||||
|  | # A typical Android build command will look like the following: | ||||||
|  | # bazel build -c opt tensorflow/core:android_tensorflow_lib \ | ||||||
|  | # --crosstool_top=//external:android/crosstool \ | ||||||
|  | # --cpu=armeabi-v7a \ | ||||||
|  | # --host_crosstool_top=@bazel_tools//tools/cpp:toolchain | ||||||
|  | # | ||||||
|  | # NOTE: currently '-pthread' must be removed from the LINK_OPTS variable | ||||||
|  | # in google/protobuf/BUILD to successfully build for Android. This is temporary | ||||||
|  | # pending an update of the version of the protobuf library that Tensorflow | ||||||
|  | # uses. | ||||||
|  | cc_binary( | ||||||
|  |     name = "benchmark_model", | ||||||
|  |     copts = tf_copts(), | ||||||
|  |     linkopts = select({ | ||||||
|  |         "//tensorflow:android": [ | ||||||
|  |             "-pie", | ||||||
|  |             "-s", | ||||||
|  |             "-landroid", | ||||||
|  |             "-ljnigraphics", | ||||||
|  |             "-llog", | ||||||
|  |             "-lm", | ||||||
|  |             "-z defs", | ||||||
|  |             "-s", | ||||||
|  |             "-Wl,--icf=all",  # Identical Code Folding | ||||||
|  |             "-Wl,--exclude-libs,ALL",  # Exclude syms in all libs from auto export | ||||||
|  |         ], | ||||||
|  |         "//conditions:default": [], | ||||||
|  |     }), | ||||||
|  |     linkstatic = 1, | ||||||
|  |     visibility = ["//visibility:public"], | ||||||
|  |     deps = [":benchmark_model_lib"], | ||||||
|  | ) | ||||||
							
								
								
									
										57
									
								
								tensorflow/tools/benchmark/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								tensorflow/tools/benchmark/README.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,57 @@ | |||||||
|  | # TensorFlow Model Benchmark Tool | ||||||
|  | 
 | ||||||
|  | ## Description | ||||||
|  | 
 | ||||||
|  | A simple C++ binary to benchmark a compute graph and its individual operators, | ||||||
|  | both on desktop machines and on Android. | ||||||
|  | 
 | ||||||
|  | ## To build/install/run | ||||||
|  | 
 | ||||||
|  | ### On Android: | ||||||
|  | 
 | ||||||
|  | (1) build for your specific platform, e.g.: | ||||||
|  | ```bash | ||||||
|  | $bazel build -c opt \ | ||||||
|  |   --crosstool_top=//external:android/crosstool \ | ||||||
|  |   --cpu=armeabi-v7a \ | ||||||
|  |   --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ | ||||||
|  |   tensorflow/tools/benchmark:benchmark_model | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | (2) Connect your phone. Push the binary to your phone with adb push | ||||||
|  |      (make the directory if required): | ||||||
|  | ```bash | ||||||
|  | $adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | (3) Push the compute graph that you need to test. For example: | ||||||
|  | ```bash | ||||||
|  | $adb push tensorflow_inception_graph.pb /data/local/tmp | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | (4) Run the benchmark. For example: | ||||||
|  | ```bash | ||||||
|  | $adb shell /data/local/tmp/benchmark_model \ | ||||||
|  |   --graph=/data/local/tmp/tensorflow_inception_graph.pb \ | ||||||
|  |   --input_layer="input:0" \ | ||||||
|  |   --input_layer_shape="1,224,224,3" \ | ||||||
|  |   --input_layer_type="float" \ | ||||||
|  |   --output_layer="output:0" | ||||||
|  | ``` | ||||||
|  | ### On desktop: | ||||||
|  | (1) Build the binary: | ||||||
|  | ```bash | ||||||
|  | $bazel build -c opt tensorflow/tools/benchmark:benchmark_model | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | (2) Run on your compute graph, similar to the Android case but without the need for adb shell. | ||||||
|  | For example: | ||||||
|  | ```bash | ||||||
|  | $bazel-bin/tensorflow/tools/benchmark/benchmark_model \ | ||||||
|  |   --graph=tensorflow_inception_graph.pb \ | ||||||
|  |   --input_layer="input:0" \ | ||||||
|  |   --input_layer_shape="1,224,224,3" \ | ||||||
|  |   --input_layer_type="float" \ | ||||||
|  |   --output_layer="output:0" | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | The Inception graph used as an example here may be downloaded from | ||||||
|  | https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip | ||||||
							
								
								
									
										225
									
								
								tensorflow/tools/benchmark/benchmark_model.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								tensorflow/tools/benchmark/benchmark_model.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,225 @@ | |||||||
|  | /* Copyright 2016 Google Inc. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | // A C++ binary to benchmark a compute graph and its individual operators,
 | ||||||
|  | // both on desktop machines and on Android.
 | ||||||
|  | //
 | ||||||
|  | // See README.md for usage instructions.
 | ||||||
|  | 
 | ||||||
|  | #include <cstdlib> | ||||||
|  | #include <memory> | ||||||
|  | #include <string> | ||||||
|  | #include <unordered_set> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "tensorflow/core/framework/graph.pb.h" | ||||||
|  | #include "tensorflow/core/framework/tensor.h" | ||||||
|  | #include "tensorflow/core/graph/algorithm.h" | ||||||
|  | #include "tensorflow/core/graph/graph.h" | ||||||
|  | #include "tensorflow/core/graph/graph_constructor.h" | ||||||
|  | #include "tensorflow/core/lib/strings/str_util.h" | ||||||
|  | #include "tensorflow/core/lib/strings/strcat.h" | ||||||
|  | #include "tensorflow/core/platform/env.h" | ||||||
|  | #include "tensorflow/core/platform/init_main.h" | ||||||
|  | #include "tensorflow/core/platform/logging.h" | ||||||
|  | #include "tensorflow/core/platform/types.h" | ||||||
|  | #include "tensorflow/core/public/session.h" | ||||||
|  | #include "tensorflow/core/util/command_line_flags.h" | ||||||
|  | #include "tensorflow/core/util/stat_summarizer.h" | ||||||
|  | 
 | ||||||
|  | namespace tensorflow { | ||||||
|  | 
 | ||||||
|  | // Global variables that holds the Tensorflow classifier.
 | ||||||
|  | static std::unique_ptr<tensorflow::Session> session; | ||||||
|  | 
 | ||||||
|  | static StatSummarizer g_stats; | ||||||
|  | 
 | ||||||
|  | struct Flags { | ||||||
|  |   string graph = "/data/local/tmp/tensorflow_inception_graph.pb"; | ||||||
|  |   string input_layer = "input:0"; | ||||||
|  |   string input_layer_shape = "1,224,224,3"; | ||||||
|  |   string input_layer_type = "float"; | ||||||
|  |   string output_layer = "output:0"; | ||||||
|  |   int num_runs = 50; | ||||||
|  |   string run_delay = "-1.0"; | ||||||
|  |   int num_threads = -1; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | static Flags* flags;  // Filled in by main()
 | ||||||
|  | 
 | ||||||
|  | static bool InitializeBenchmark() { | ||||||
|  |   g_stats.Reset(); | ||||||
|  | 
 | ||||||
|  |   LOG(INFO) << "Loading Tensorflow."; | ||||||
|  | 
 | ||||||
|  |   tensorflow::SessionOptions options; | ||||||
|  |   tensorflow::ConfigProto& config = options.config; | ||||||
|  |   if (flags->num_threads > 0) { | ||||||
|  |     config.set_intra_op_parallelism_threads(flags->num_threads); | ||||||
|  |   } | ||||||
|  |   LOG(INFO) << "Got config, " << config.device_count_size() << " devices"; | ||||||
|  | 
 | ||||||
|  |   session.reset(tensorflow::NewSession(options)); | ||||||
|  |   tensorflow::GraphDef tensorflow_graph; | ||||||
|  |   Status s = ReadBinaryProto(Env::Default(), flags->graph, &tensorflow_graph); | ||||||
|  |   if (!s.ok()) { | ||||||
|  |     LOG(ERROR) << "Could not create Tensorflow Graph: " << s; | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   s = session->Create(tensorflow_graph); | ||||||
|  |   if (!s.ok()) { | ||||||
|  |     LOG(ERROR) << "Could not create Tensorflow Session: " << s; | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Clear the proto to save memory space.
 | ||||||
|  |   tensorflow_graph.Clear(); | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | static bool RunBenchmark() { | ||||||
|  |   DataType input_data_type; | ||||||
|  |   CHECK(DataTypeFromString(flags->input_layer_type, &input_data_type)) | ||||||
|  |       << flags->input_layer_type << " was an invalid type"; | ||||||
|  | 
 | ||||||
|  |   std::vector<int32> sizes; | ||||||
|  |   CHECK(str_util::SplitAndParseAsInts(flags->input_layer_shape, ',', &sizes)) | ||||||
|  |       << "Incorrect size string specified: " << flags->input_layer_shape; | ||||||
|  |   TensorShape input_shape; | ||||||
|  |   for (int i = 0; i < sizes.size(); ++i) { | ||||||
|  |     input_shape.AddDim(sizes[i]); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   Tensor input_tensor(input_data_type, input_shape); | ||||||
|  | 
 | ||||||
|  |   switch (input_data_type) { | ||||||
|  |     case DT_INT32: { | ||||||
|  |       auto int_tensor = input_tensor.flat<int32>(); | ||||||
|  |       int_tensor = int_tensor.constant(0.0); | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |     case DT_FLOAT: { | ||||||
|  |       auto float_tensor = input_tensor.flat<float>(); | ||||||
|  |       float_tensor = float_tensor.constant(0.0); | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |     case DT_QUINT8: { | ||||||
|  |       auto int_tensor = input_tensor.flat<quint8>(); | ||||||
|  |       int_tensor = int_tensor.constant(0.0); | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |     default: | ||||||
|  |       LOG(FATAL) << "Unsupported input type: " << flags->input_layer_type; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   std::vector<std::pair<string, tensorflow::Tensor> > input_tensors( | ||||||
|  |       {{flags->input_layer, input_tensor}}); | ||||||
|  | 
 | ||||||
|  |   std::vector<tensorflow::Tensor> output_tensors; | ||||||
|  |   std::vector<string> output_names({flags->output_layer}); | ||||||
|  | 
 | ||||||
|  |   tensorflow::Status s; | ||||||
|  | 
 | ||||||
|  |   RunOptions run_options; | ||||||
|  |   run_options.set_trace_level(RunOptions::FULL_TRACE); | ||||||
|  |   RunMetadata run_metadata; | ||||||
|  | 
 | ||||||
|  |   s = session->Run(run_options, input_tensors, output_names, {}, | ||||||
|  |                    &output_tensors, &run_metadata); | ||||||
|  | 
 | ||||||
|  |   assert(run_metadata.has_step_stats()); | ||||||
|  | 
 | ||||||
|  |   const StepStats& stats = run_metadata.step_stats(); | ||||||
|  | 
 | ||||||
|  |   g_stats.ProcessStepStats(stats); | ||||||
|  | 
 | ||||||
|  |   if (!s.ok()) { | ||||||
|  |     LOG(ERROR) << "Error during inference: " << s; | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace tensorflow
 | ||||||
|  | 
 | ||||||
|  | int main(int argc, char** argv) { | ||||||
|  |   tensorflow::flags = new tensorflow::Flags(); | ||||||
|  | 
 | ||||||
|  |   const bool parse_result = tensorflow::ParseFlags( | ||||||
|  |       &argc, argv, | ||||||
|  |       { | ||||||
|  |           tensorflow::Flag("graph", &tensorflow::flags->graph), | ||||||
|  |           tensorflow::Flag("input_layer", &tensorflow::flags->input_layer), | ||||||
|  |           tensorflow::Flag("input_layer_shape", | ||||||
|  |                            &tensorflow::flags->input_layer_shape), | ||||||
|  |           tensorflow::Flag("input_layer_type", | ||||||
|  |                            &tensorflow::flags->input_layer_type), | ||||||
|  |           tensorflow::Flag("output_layer", &tensorflow::flags->output_layer), | ||||||
|  |           tensorflow::Flag("num_runs", &tensorflow::flags->num_runs), | ||||||
|  |           tensorflow::Flag("run_delay", &tensorflow::flags->run_delay), | ||||||
|  |           tensorflow::Flag("num_threads", &tensorflow::flags->num_threads), | ||||||
|  |       }); | ||||||
|  | 
 | ||||||
|  |   if (!parse_result) { | ||||||
|  |     LOG(ERROR) << "Error parsing command-line flags."; | ||||||
|  |     return -1; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   ::tensorflow::port::InitMain(argv[0], &argc, &argv); | ||||||
|  |   if (argc > 1) { | ||||||
|  |     LOG(ERROR) << "Unknown argument " << argv[1]; | ||||||
|  |     return -1; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   LOG(INFO) << "Graph: [" << tensorflow::flags->graph << "]"; | ||||||
|  |   LOG(INFO) << "Input layer: [" << tensorflow::flags->input_layer << "]"; | ||||||
|  |   LOG(INFO) << "Input shape: [" << tensorflow::flags->input_layer_shape << "]"; | ||||||
|  |   LOG(INFO) << "Input type: [" << tensorflow::flags->input_layer_type << "]"; | ||||||
|  |   LOG(INFO) << "Output layer: [" << tensorflow::flags->output_layer << "]"; | ||||||
|  |   LOG(INFO) << "Num runs: [" << tensorflow::flags->num_runs << "]"; | ||||||
|  |   LOG(INFO) << "Inter-run delay (seconds): [" << tensorflow::flags->run_delay | ||||||
|  |             << "]"; | ||||||
|  |   LOG(INFO) << "Num threads: [" << tensorflow::flags->num_threads << "]"; | ||||||
|  | 
 | ||||||
|  |   if (!tensorflow::InitializeBenchmark()) { | ||||||
|  |     return -1; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Convert the run_delay string into a timespec.
 | ||||||
|  |   const double sleep_seconds = | ||||||
|  |       std::strtod(tensorflow::flags->run_delay.c_str(), nullptr); | ||||||
|  |   timespec req; | ||||||
|  |   req.tv_sec = static_cast<time_t>(sleep_seconds); | ||||||
|  |   req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; | ||||||
|  | 
 | ||||||
|  |   LOG(INFO) << "Running benchmark"; | ||||||
|  |   for (int i = 0; i < tensorflow::flags->num_runs; ++i) { | ||||||
|  |     if (!tensorflow::RunBenchmark()) { | ||||||
|  |       LOG(INFO) << "Failed on run " << i; | ||||||
|  |       return -1; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // If requested, sleep between runs for an arbitrary amount of time.
 | ||||||
|  |     // This can be helpful to determine the effect of mobile processor
 | ||||||
|  |     // scaling and thermal throttling.
 | ||||||
|  |     if (sleep_seconds > 0.0) { | ||||||
|  |       nanosleep(&req, nullptr); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   tensorflow::g_stats.PrintStepStats(); | ||||||
|  |   return 0; | ||||||
|  | } | ||||||
| @ -139,6 +139,16 @@ else | |||||||
|   # Assume: PYTHON_BIN_PATH is exported by the script above |   # Assume: PYTHON_BIN_PATH is exported by the script above | ||||||
| fi | fi | ||||||
| 
 | 
 | ||||||
# Obtain the path to head/ghead binary (for log file printing).
# `command -v` is used rather than `which`: it is a POSIX shell builtin and
# reports a missing command through its exit status, without printing
# anything on stdout.
HEAD_BIN="ghead"
if ! command -v "${HEAD_BIN}" > /dev/null 2>&1; then
  # This is not Mac (which uses coreutils/ghead), use head.
  HEAD_BIN="head"
  if ! command -v "${HEAD_BIN}" > /dev/null 2>&1; then
    die "Unable to obtain path to head or ghead"
  fi
fi
|  | 
 | ||||||
| if [[ -z "${PYTHON_BIN_PATH}" ]]; then | if [[ -z "${PYTHON_BIN_PATH}" ]]; then | ||||||
|   die "PYTHON_BIN_PATH was not provided. If this is not virtualenv, "\ |   die "PYTHON_BIN_PATH was not provided. If this is not virtualenv, "\ | ||||||
| "did you run configure?" | "did you run configure?" | ||||||
| @ -371,7 +381,7 @@ while true; do | |||||||
| 
 | 
 | ||||||
|       echo "  Log @: ${TEST_LOGS[K]}" |       echo "  Log @: ${TEST_LOGS[K]}" | ||||||
|       echo "============== BEGINS failure log content ==============" |       echo "============== BEGINS failure log content ==============" | ||||||
|       head --lines=-1 "${TEST_LOGS[K]}" |       "${HEAD_BIN}" --lines=-1 "${TEST_LOGS[K]}" | ||||||
|       echo "============== ENDS failure log content ==============" |       echo "============== ENDS failure log content ==============" | ||||||
|       echo "" |       echo "" | ||||||
|     fi |     fi | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user