From a68d64495ae90eada08d028d7577343429f2555a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Mar 2017 00:02:39 -0800 Subject: [PATCH 01/81] Add more test cases for DivideTwoScalarsS32. Change: 151798655 --- .../xla/tests/scalar_computations_test.cc | 51 +++++++++++++++++-- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 134eb91a1fe..d5cb98f304f 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -245,13 +246,55 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); - builder.Div(builder.ConstantR0(-5), builder.ConstantR0(2)); +struct DivS32Params { + int32 dividend; + int32 divisor; + int32 result; +}; - ComputeAndCompareR0(&builder, -2, {}); +void PrintTo(const DivS32Params& p, std::ostream* os) { + *os << "{" << p.dividend << ", " << p.divisor << ", " << p.result << "}"; } +class DivS32Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { + DivS32Params p = GetParam(); + ComputationBuilder builder(client_, TestName()); + builder.Div(builder.ConstantR0(p.dividend), + builder.ConstantR0(p.divisor)); + + ComputeAndCompareR0(&builder, p.result, {}); +} + +INSTANTIATE_TEST_CASE_P(DivS32Test_Instantiation, DivS32Test, + ::testing::Values( + // Positive divisors. + DivS32Params{5, 2, 2}, // + DivS32Params{-5, 2, -2}, // + DivS32Params{17, 3, 5}, // + DivS32Params{-17, 3, -5}, // + // Negative divisors. + DivS32Params{5, -2, -2}, // + DivS32Params{-5, -2, 2}, // + DivS32Params{17, -3, -5}, // + DivS32Params{-17, -3, 5}, // + // Large positive divisors. + DivS32Params{INT32_MIN, INT32_MAX, -1}, // + DivS32Params{INT32_MIN + 1, INT32_MAX, -1}, // + DivS32Params{INT32_MIN + 2, INT32_MAX, 0}, // + DivS32Params{INT32_MIN, 0x40000000, -2}, // + DivS32Params{INT32_MIN + 1, 0x40000000, -1}, // + // Large negative divisors. + DivS32Params{INT32_MIN, INT32_MIN, 1}, // + DivS32Params{INT32_MIN, INT32_MIN + 1, 1}, // + DivS32Params{INT32_MIN + 1, INT32_MIN, 0}, // + DivS32Params{INT32_MAX, INT32_MIN, 0}, // + DivS32Params{INT32_MAX, INT32_MIN + 1, -1}, // + DivS32Params{INT32_MIN, -0x40000000, 2}, // + DivS32Params{INT32_MIN + 1, -0x40000000, 1})); + TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) { ComputationBuilder builder(client_, TestName()); builder.Rem(builder.ConstantR0(-5), builder.ConstantR0(2)); From e4a8dc831dbf2894c79659d50aea73999c1ff173 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Mar 2017 02:30:28 -0800 Subject: [PATCH 02/81] Allow variable reuse in function. 
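Inside a Defun body, get_variable calls are forwarded to the outer graph
via _FuncGraph.getvar; previously the reuse flag was dropped on the way, so
variable_scope(..., reuse=True) had no effect inside a function. This change
threads the flag through, and a per-call reuse value now overrides the scope
default in VariableScope.get_variable. A minimal sketch of the pattern this
enables (shapes are illustrative; see the new testVariableReuse below):

    @function.Defun(dtypes.float32)
    def Foo(x):
      with variable_scope.variable_scope("linear"):
        w = variable_scope.get_variable("w", shape=[100, 100])
      with variable_scope.variable_scope("linear", reuse=True):
        w_again = variable_scope.get_variable("w", shape=[100, 100])
      return math_ops.matmul(math_ops.matmul(x, w), w_again)

    # Only one variable, "linear/w:0", is created and hoisted into the
    # outer graph; the second get_variable reuses it.
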
Change: 151807504 --- tensorflow/python/framework/function.py | 7 +++ tensorflow/python/framework/function_test.py | 52 ++++++++++++++++++++ tensorflow/python/ops/variable_scope.py | 5 +- 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 3c663c3a9b2..8b156db6dc4 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -299,6 +299,7 @@ class _FuncGraph(ops.Graph): shape=None, dtype=None, initializer=None, + reuse=None, trainable=True, collections=None, # pylint: disable=redefined-outer-name use_resource=None, @@ -319,6 +320,7 @@ class _FuncGraph(ops.Graph): shape=shape, dtype=dtype, initializer=initializer, + reuse=reuse, trainable=trainable, collections=collections, use_resource=use_resource) @@ -886,6 +888,11 @@ class Defun(object): default graph. Because the addition of the function into the graph is deferred, the decorator can be used anywhere in the program. + Any variables created inside of the function are hoisted into the outer graph. + Note that the variables are created in the variable scope that was active + during the first call to the function. Subsequent function calls will refer to + the same set of variables. + Definitions of functions are frozen in a graph as soon as the graph is used to create a session. Therefore, nodes using the function must be created in the graph before the corresponding session is created. diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 96bf7bde29f..51a9215ac4e 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -722,6 +722,58 @@ class FunctionTest(test.TestCase): y = Bar(array_ops.zeros([1, 2, 3])) self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3]) + def testVariableReuse(self): + def LinearWithReuse(input_tensor, reuse=None): + size = input_tensor.shape.dims[1] + with variable_scope.variable_scope("linear", reuse=reuse): + w = variable_scope.get_variable("w", shape=[size, size], + dtype=input_tensor.dtype) + return math_ops.matmul(input_tensor, w) + + @function.Defun(dtypes.float32) + def Foo(inputs): + inputs = array_ops.reshape(inputs, [32, 100]) + hidden = LinearWithReuse(inputs) + return LinearWithReuse(hidden, reuse=True) + + input_op = array_ops.placeholder(shape=[32, 100], dtype=dtypes.float32) + output_op = Foo(input_op) + + global_vars = variables.global_variables() + self.assertEqual(len(global_vars), 1) + self.assertEqual(global_vars[0].name, "linear/w:0") + + with session.Session() as sess: + sess.run(variables.global_variables_initializer()) + output_val = sess.run(output_op, + feed_dict={input_op: np.random.rand(32, 100)}) + self.assertEqual(output_val.shape, (32, 100)) + + def testFunctionCallInDifferentVariableScopes(self): + @function.Defun(dtypes.float32) + def Foo(inputs): + var = variable_scope.get_variable("var", shape=[10], dtype=dtypes.float32, + initializer=init_ops.ones_initializer()) + return inputs + var + + input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32) + with variable_scope.variable_scope("vs1"): + out1_op = Foo(input_op) + + with variable_scope.variable_scope("vs2"): + out2_op = Foo(input_op) + + global_vars = variables.global_variables() + self.assertEqual(len(global_vars), 1) + self.assertEqual(global_vars[0].name, "vs1/var:0") + + with session.Session() as sess: + sess.run(variables.global_variables_initializer()) + out1, 
out2 = sess.run([out1_op, out2_op], + feed_dict={input_op: np.linspace(1, 10, 10)}) + self.assertAllEqual(out1, np.linspace(2, 11, 10)) + self.assertAllEqual(out2, np.linspace(2, 11, 10)) + class FunctionsFromProtos(test.TestCase): diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 2f97abdc791..19c5d3c3ea0 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -904,6 +904,7 @@ class VariableScope(object): dtype=None, initializer=None, regularizer=None, + reuse=None, trainable=True, collections=None, caching_device=None, @@ -920,6 +921,8 @@ class VariableScope(object): partitioner = self._partitioner if custom_getter is None: custom_getter = self._custom_getter + if reuse is None: + reuse = self._reuse full_name = self.name + "/" + name if self.name else name # Variable names only depend on variable_scope (full_name here), @@ -942,7 +945,7 @@ class VariableScope(object): return var_store.get_variable( full_name, shape=shape, dtype=dtype, initializer=initializer, - regularizer=regularizer, reuse=self.reuse, trainable=trainable, + regularizer=regularizer, reuse=reuse, trainable=trainable, collections=collections, caching_device=caching_device, partitioner=partitioner, validate_shape=validate_shape, use_resource=use_resource, custom_getter=custom_getter) From 78ead4e427eb97b814b32748575f4ad97d690595 Mon Sep 17 00:00:00 2001 From: Malcolm Reynolds Date: Fri, 31 Mar 2017 03:07:38 -0800 Subject: [PATCH 03/81] Only change ._variables_created in template after inner function succeeds. Change: 151809410 --- .../python/kernel_tests/template_test.py | 24 ++++++++++++++++++- tensorflow/python/ops/template.py | 15 +++++++----- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index be2d6a566ab..54e8098e4e6 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -306,7 +306,7 @@ class TemplateTest(test.TestCase): self.assertEqual(custom_getter_count[0], 2) # Test that custom getter is called when the variable scope is created - # during construction + # during construction custom_getter_count[0] = 0 tmpl2 = template.make_template( "s2", @@ -319,6 +319,28 @@ class TemplateTest(test.TestCase): tmpl2() self.assertEqual(custom_getter_count[0], 2) + def test_fails_gracefully(self): + for create_scope_now in [True, False]: + def module_function_with_one_arg(inputs): + w = variable_scope.get_variable( + "w", shape=[1], initializer=init_ops.zeros_initializer()) + return inputs * w + + templatized_function = template.make_template( + "f1", module_function_with_one_arg, + create_scope_now_=create_scope_now) + data = array_ops.zeros(1) + try: + # Try to connect with a kwarg which is unsupported. + templatized_function(data, is_training=True) + except TypeError: + pass + + # The failed __call__ hasn't modified the inner state. + self.assertFalse(templatized_function._variables_created) + templatized_function(data) + self.assertTrue(templatized_function._variables_created) + def test_name_scopes_for_variable_scopes(self): # Test that name scopes are not unnecessarily uniquified (but are # still uniquified when necessary). 
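The template.py change below applies a simple exception-safety ordering: call
the inner function first, and set _variables_created only once it returns, so
a failed first __call__ (exercised by test_fails_gracefully above) leaves the
template in a clean, retryable state. In sketch form:

    result = self._call_func(args, kwargs, check_for_new_variables=False)
    self._variables_created = True  # only reached if _call_func succeeded
    return result
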
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 80dd74521be..48be9e2cdae 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -261,20 +261,23 @@ class Template(object): return self._call_func(args, kwargs, check_for_new_variables=True) else: # This is the first visit to __call__, but the scope has already been - # created in the constructor. Set _variables_created so that subsequent - # calls take the if branch above. - self._variables_created = True + # created in the constructor. Set _variables_created after the inner + # function is successfully called so that subsequent calls take the if + # branch above. with variable_scope.variable_scope(self._variable_scope): - return self._call_func(args, kwargs, check_for_new_variables=False) + result = self._call_func(args, kwargs, check_for_new_variables=False) + self._variables_created = True + return result else: # The scope was not created at construction time, so create it here. # Subsequent calls should reuse variables. - self._variables_created = True with variable_scope.variable_scope( self._unique_name, self._name, custom_getter=self._custom_getter) as vs: self._variable_scope = vs - return self._call_func(args, kwargs, check_for_new_variables=False) + result = self._call_func(args, kwargs, check_for_new_variables=False) + self._variables_created = True + return result @property def variable_scope(self): From 52267c1aceb1b8935f79022040d13b02fb986a8f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Mar 2017 03:28:09 -0800 Subject: [PATCH 04/81] Add more test cases for scalar s32 remainder. Change: 151810398 --- .../xla/tests/scalar_computations_test.cc | 88 ++++++++----------- 1 file changed, 39 insertions(+), 49 deletions(-) diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index d5cb98f304f..0005c0c9e23 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -249,11 +249,13 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { struct DivS32Params { int32 dividend; int32 divisor; - int32 result; + int32 quotient; + int32 remainder; }; void PrintTo(const DivS32Params& p, std::ostream* os) { - *os << "{" << p.dividend << ", " << p.divisor << ", " << p.result << "}"; + *os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", " + << p.remainder << "}"; } class DivS32Test : public ClientLibraryTestBase, @@ -265,60 +267,48 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { builder.Div(builder.ConstantR0(p.dividend), builder.ConstantR0(p.divisor)); - ComputeAndCompareR0(&builder, p.result, {}); + ComputeAndCompareR0(&builder, p.quotient, {}); } -INSTANTIATE_TEST_CASE_P(DivS32Test_Instantiation, DivS32Test, - ::testing::Values( - // Positive divisors. - DivS32Params{5, 2, 2}, // - DivS32Params{-5, 2, -2}, // - DivS32Params{17, 3, 5}, // - DivS32Params{-17, 3, -5}, // - // Negative divisors. - DivS32Params{5, -2, -2}, // - DivS32Params{-5, -2, 2}, // - DivS32Params{17, -3, -5}, // - DivS32Params{-17, -3, 5}, // - // Large positive divisors. - DivS32Params{INT32_MIN, INT32_MAX, -1}, // - DivS32Params{INT32_MIN + 1, INT32_MAX, -1}, // - DivS32Params{INT32_MIN + 2, INT32_MAX, 0}, // - DivS32Params{INT32_MIN, 0x40000000, -2}, // - DivS32Params{INT32_MIN + 1, 0x40000000, -1}, // - // Large negative divisors. 
- DivS32Params{INT32_MIN, INT32_MIN, 1}, // - DivS32Params{INT32_MIN, INT32_MIN + 1, 1}, // - DivS32Params{INT32_MIN + 1, INT32_MIN, 0}, // - DivS32Params{INT32_MAX, INT32_MIN, 0}, // - DivS32Params{INT32_MAX, INT32_MIN + 1, -1}, // - DivS32Params{INT32_MIN, -0x40000000, 2}, // - DivS32Params{INT32_MIN + 1, -0x40000000, 1})); - -TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) { +XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { + DivS32Params p = GetParam(); ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(-5), builder.ConstantR0(2)); + builder.Rem(builder.ConstantR0(p.dividend), + builder.ConstantR0(p.divisor)); - ComputeAndCompareR0(&builder, -1, {}); + ComputeAndCompareR0(&builder, p.remainder, {}); } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) { - ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(INT_MIN), - builder.ConstantR0(7919)); +INSTANTIATE_TEST_CASE_P( + DivS32Test_Instantiation, DivS32Test, + ::testing::Values( + // Positive divisors. + DivS32Params{5, 2, 2, 1}, // + DivS32Params{-5, 2, -2, -1}, // + DivS32Params{17, 3, 5, 2}, // + DivS32Params{-17, 3, -5, -2}, // + // Negative divisors. + DivS32Params{5, -2, -2, 1}, // + DivS32Params{-5, -2, 2, -1}, // + DivS32Params{17, -3, -5, 2}, // + DivS32Params{-17, -3, 5, -2}, // + // Large positive divisors. + DivS32Params{INT32_MIN, 7919, -271181, -1309}, // + DivS32Params{INT32_MIN, INT32_MAX, -1, -1}, // + DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0}, // + DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2}, // + DivS32Params{INT32_MIN, 0x40000000, -2, 0}, // + DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff}, // + // Large negative divisors. + DivS32Params{INT32_MIN, INT32_MIN, 1, 0}, // + DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1}, // + DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1}, // + DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX}, // + DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0}, // + DivS32Params{INT32_MIN, -0x40000000, 2, 0}, // + DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff})); - ComputeAndCompareR0(&builder, -1309, {}); -} - -TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) { - ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(INT_MIN), - builder.ConstantR0(INT_MAX)); - - ComputeAndCompareR0(&builder, -1, {}); -} - -TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) { +TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { ComputationBuilder builder(client_, TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); From 35e9035b8c76306cc85ed4871660ffb78d484a3a Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Fri, 31 Mar 2017 06:51:27 -0800 Subject: [PATCH 05/81] TESTFIX: LinearOperatorFullMatrix placeholder test was not using a placeholder. 
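The use_placeholder branch evaluated the random matrix and then built the
operator from the concrete value, so the placeholder and feed_dict were
effectively dead code:

    matrix = matrix.eval()
    operator = linalg.LinearOperatorFullMatrix(matrix)   # ignores matrix_ph
    feed_dict = {matrix_ph: matrix}                      # fed but unused

Constructing the operator from matrix_ph instead makes the graph actually
depend on the fed value, which is what the placeholder variant of the test
was meant to cover.
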
Change: 151822825 --- .../python/kernel_tests/linear_operator_full_matrix_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py index 93cbb48e1b2..d4a9e97ce7a 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py @@ -45,7 +45,7 @@ class SquareLinearOperatorFullMatrixTest( # values are random and we want the same value used for both mat and # feed_dict. matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix) + operator = linalg.LinearOperatorFullMatrix(matrix_ph) feed_dict = {matrix_ph: matrix} else: operator = linalg.LinearOperatorFullMatrix(matrix) @@ -105,7 +105,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest( # feed_dict. matrix = matrix.eval() operator = linalg.LinearOperatorFullMatrix( - matrix, is_self_adjoint=True, is_positive_definite=True) + matrix_ph, is_self_adjoint=True, is_positive_definite=True) feed_dict = {matrix_ph: matrix} else: operator = linalg.LinearOperatorFullMatrix( @@ -144,7 +144,7 @@ class NonSquareLinearOperatorFullMatrixTest( # values are random and we want the same value used for both mat and # feed_dict. matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix) + operator = linalg.LinearOperatorFullMatrix(matrix_ph) feed_dict = {matrix_ph: matrix} else: operator = linalg.LinearOperatorFullMatrix(matrix) From 7ce8fb4ef2e1bf25db1c1667e71d78cfc42e802d Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Fri, 31 Mar 2017 08:04:13 -0800 Subject: [PATCH 06/81] [tf learn estimators] Bugfix to rnn_common following rollback of RNNCell instance argument support. Change: 151829043 --- .../contrib/learn/python/learn/estimators/rnn_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index f20dc788349..6bb2b8b2aad 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -66,8 +66,8 @@ def _get_single_cell(cell_type, num_units): ValueError: `cell_type` is an invalid `RNNCell` name. TypeError: `cell_type` is not a string or a subclass of `RNNCell`. """ - cell_type = _CELL_TYPES.get(cell_type) - if cell_type is None and not issubclass(cell_type, contrib_rnn.RNNCell): + cell_type = _CELL_TYPES.get(cell_type, cell_type) + if not cell_type or not issubclass(cell_type, contrib_rnn.RNNCell): raise ValueError('The supported cell types are {}; got {}'.format( list(_CELL_TYPES.keys()), cell_type)) return cell_type(num_units=num_units) From 9731b5eb6e8522cf1664c125983b4b1e97b316c4 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Fri, 31 Mar 2017 08:20:26 -0800 Subject: [PATCH 07/81] Add a SliceProcessor specialized for const inputs, used if a constant folding pass is applied before LayoutOptimizer. 
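When a constant folding pass runs first, Slice's begin/size inputs can be
plain Const nodes rather than a ConcatOffset, and the optimizer can then
rewrite those constants in place when switching the data format. Conceptually
the rewrite permutes the 4-D begin/size vectors (illustrative sketch only,
not the grappler code below):

    def nhwc_to_nchw(v):
      # [N, H, W, C] -> [N, C, H, W]
      return [v[0], v[3], v[1], v[2]]

The dispatch in DataLayoutOptimizer now selects SliceProcessorConcatOffset
when the second input is a ConcatOffset, SliceProcessorConst when both the
begin and size inputs are Const, and the gather-based SliceProcessor
otherwise.
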
Change: 151830580 --- .../grappler/optimizers/layout_optimizer.cc | 44 ++++++++++++++----- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index b0988b8a891..791944a1b36 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -615,11 +615,9 @@ class ReluGradProcessor : public AgnosticNodeProcessor { } }; -// This is the older, less optimized gather-based SliceProcessor. We keep it as -// a test case for constant propagation optimization. -class SliceProcessorGatherBased : public AgnosticNodeProcessor { +class SliceProcessor : public AgnosticNodeProcessor { public: - SliceProcessorGatherBased(GraphDef* graph, NodeDef* node, NodeMap* node_map) + SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) : AgnosticNodeProcessor(graph, node, node_map) {} protected: @@ -663,9 +661,30 @@ class SliceProcessorGatherBased : public AgnosticNodeProcessor { } }; -class SliceProcessor : public AgnosticNodeProcessor { +// Specialized SliceProcessor, used if the second and third input are const +// nodes, which could be the case if a constant folding pass is applied +// before this optimization. +class SliceProcessorConst : public AgnosticNodeProcessor { public: - SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) + SliceProcessorConst(GraphDef* graph, NodeDef* node, NodeMap* node_map) + : AgnosticNodeProcessor(graph, node, node_map) {} + + protected: + Status CustomizedProcessing() override { + // Skip the first input, which is the data to be sliced. + for (int i = 1; i < node_->input_size(); i++) { + auto shape_node = node_map_->GetNode(node_->input(i)); + TF_RETURN_IF_ERROR(UpdateAttrValue(shape_node)); + } + return Status::OK(); + } +}; + +// Specialized SliceProcessor, used if the second input is ConcatOffset. An +// example use case is in the gradient computation of Concat for InceptionV3. +class SliceProcessorConcatOffset : public AgnosticNodeProcessor { + public: + SliceProcessorConcatOffset(GraphDef* graph, NodeDef* node, NodeMap* node_map) : AgnosticNodeProcessor(graph, node, node_map) {} protected: @@ -938,14 +957,17 @@ class DataLayoutOptimizer { node_processor.reset( new ReluGradProcessor(graph_, node, &node_map_)); } else if (node->op().compare("Slice") == 0) { - auto maybe_concatoffset_node = - node_map_.GetNode(NodeName(node->input(1))); - if (maybe_concatoffset_node->op() == "ConcatOffset") { + auto input1 = node_map_.GetNode(NodeName(node->input(1))); + auto input2 = node_map_.GetNode(NodeName(node->input(2))); + if (input1->op() == "ConcatOffset") { node_processor.reset( - new SliceProcessor(graph_, node, &node_map_)); + new SliceProcessorConcatOffset(graph_, node, &node_map_)); + } else if (input1->op() == "Const" && input2->op() == "Const") { + node_processor.reset( + new SliceProcessorConst(graph_, node, &node_map_)); } else { node_processor.reset( - new SliceProcessorGatherBased(graph_, node, &node_map_)); + new SliceProcessor(graph_, node, &node_map_)); } } else if (node->op().compare("Squeeze") == 0) { From 50be7aa7d72ded57c11c705e9de80da2bdc2220b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Mar 2017 08:35:08 -0800 Subject: [PATCH 08/81] Migrate trees, models, testutils, and resources libraries to boosted_trees. 
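The models library added here implements prediction for an additive tree
ensemble with dropout bookkeeping. In sketch form (illustrative pseudocode;
the real implementation is the C++ in this patch):

    for example in batch:
      for t in kept_trees:        # not dropped, optionally finalized only
        v = tree_weights[t] * leaf_values(trees[t], example)
        output[example] += v
        no_dropout[example] += v
      for t in dropped_trees:
        no_dropout[example] += tree_weights[t] * leaf_values(trees[t], example)

so no_dropout always reflects the full ensemble, while output omits the
dropped trees.
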
Change: 151832033 --- tensorflow/BUILD | 1 + tensorflow/contrib/boosted_trees/lib/BUILD | 87 ++++ .../lib/models/multiple_additive_trees.cc | 140 +++++++ .../lib/models/multiple_additive_trees.h | 50 +++ .../models/multiple_additive_trees_test.cc | 381 ++++++++++++++++++ .../lib/testutil/batch_features_testutil.cc | 88 ++++ .../lib/testutil/batch_features_testutil.h | 45 +++ .../lib/testutil/random_tree_gen.cc | 211 ++++++++++ .../lib/testutil/random_tree_gen.h | 75 ++++ .../lib/testutil/random_tree_gen_main.cc | 64 +++ .../boosted_trees/lib/trees/decision_tree.cc | 170 ++++++++ .../boosted_trees/lib/trees/decision_tree.h | 49 +++ .../lib/trees/decision_tree_test.cc | 326 +++++++++++++++ tensorflow/contrib/boosted_trees/proto/BUILD | 9 + .../boosted_trees/proto/quantiles.proto | 32 ++ .../contrib/boosted_trees/resources/BUILD | 53 +++ .../decision_tree_ensemble_resource.h | 76 ++++ .../resources/quantile_stream_resource.h | 104 +++++ .../resources/stamped_resource.h | 42 ++ 19 files changed, 2003 insertions(+) create mode 100644 tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc create mode 100644 tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h create mode 100644 tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc create mode 100644 tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.cc create mode 100644 tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h create mode 100644 tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc create mode 100644 tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h create mode 100644 tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen_main.cc create mode 100644 tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc create mode 100644 tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h create mode 100644 tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc create mode 100644 tensorflow/contrib/boosted_trees/proto/quantiles.proto create mode 100644 tensorflow/contrib/boosted_trees/resources/BUILD create mode 100644 tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h create mode 100644 tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h create mode 100644 tensorflow/contrib/boosted_trees/resources/stamped_resource.h diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 38a8c726675..f07ef0cabf1 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -194,6 +194,7 @@ filegroup( "//tensorflow/contrib/boosted_trees:all_files", "//tensorflow/contrib/boosted_trees/lib:all_files", "//tensorflow/contrib/boosted_trees/proto:all_files", + "//tensorflow/contrib/boosted_trees/resources:all_files", "//tensorflow/contrib/cloud:all_files", "//tensorflow/contrib/cloud/kernels:all_files", "//tensorflow/contrib/compiler:all_files", diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 714bd324c2a..011c02d720f 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -160,3 +160,90 @@ cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "models", + srcs = ["models/multiple_additive_trees.cc"], + hdrs = ["models/multiple_additive_trees.h"], + deps = [ + ":trees", + ":utils", + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_test( + name = "multiple_additive_trees_test", + size = "small", 
+ srcs = ["models/multiple_additive_trees_test.cc"], + deps = [ + ":batch_features_testutil", + ":models", + ":random_tree_gen", + "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "trees", + srcs = ["trees/decision_tree.cc"], + hdrs = ["trees/decision_tree.h"], + deps = [ + ":utils", + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_test( + name = "trees_test", + size = "small", + srcs = ["trees/decision_tree_test.cc"], + deps = [ + ":trees", + ":utils", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "batch_features_testutil", + testonly = 1, + srcs = ["testutil/batch_features_testutil.cc"], + hdrs = ["testutil/batch_features_testutil.h"], + deps = [ + ":utils", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "random_tree_gen", + srcs = ["testutil/random_tree_gen.cc"], + hdrs = ["testutil/random_tree_gen.h"], + deps = [ + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", + "//tensorflow/core:lib", + ], +) + +cc_binary( + name = "random_tree_gen_main", + srcs = ["testutil/random_tree_gen_main.cc"], + deps = [ + ":random_tree_gen", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc new file mode 100644 index 00000000000..16bffd9becc --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -0,0 +1,140 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h" +#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" +#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" +#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h" + +namespace tensorflow { +namespace boosted_trees { +namespace models { + +namespace { +void CalculateTreesToKeep( + const boosted_trees::trees::DecisionTreeEnsembleConfig& config, + const std::vector& trees_to_drop, const int32 num_trees, + const bool only_finalized, std::vector* trees_to_keep) { + trees_to_keep->reserve(num_trees - trees_to_drop.size()); + + int32 index = 0; + // This assumes that trees_to_drop is a sorted list of tree ids. 
+ for (int32 tree = 0; tree < num_trees; ++tree) { + if ((!trees_to_drop.empty() && index < trees_to_drop.size() && + trees_to_drop[index] == tree) || + (only_finalized && config.tree_metadata_size() > 0 && + !config.tree_metadata(tree).is_finalized())) { + ++index; + continue; + } + trees_to_keep->push_back(tree); + } +} + +void UpdatePredictions( + const int32 index_1, const int32 index_2, const float value, + tensorflow::TTypes::Matrix* output_predictions, + tensorflow::TTypes::Matrix* additional_output_predictions) { + (*output_predictions)(index_1, index_2) += value; + + if (additional_output_predictions != nullptr) { + (*additional_output_predictions)(index_1, index_2) += value; + } +} + +void UpdatePredictionsBasedOnTree( + const boosted_trees::trees::DecisionTreeEnsembleConfig& config, + const int32 tree_idx, const boosted_trees::utils::Example& example, + tensorflow::TTypes::Matrix* output_predictions, + tensorflow::TTypes::Matrix* additional_output_predictions) { + const boosted_trees::trees::DecisionTreeConfig& tree = config.trees(tree_idx); + const float tree_weight = config.tree_weights(tree_idx); + const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); + QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); + const auto& leaf_node = tree.nodes(leaf_idx); + QCHECK(leaf_node.has_leaf()) + << "Invalid leaf node: " << leaf_node.DebugString(); + if (leaf_node.leaf().has_sparse_vector()) { + const auto& leaf = leaf_node.leaf().sparse_vector(); + QCHECK_EQ(leaf.index_size(), leaf.value_size()); + for (size_t class_idx = 0; class_idx < leaf.index_size(); ++class_idx) { + const float value = tree_weight * leaf.value(class_idx); + + UpdatePredictions(example.example_idx, leaf.index(class_idx), value, + output_predictions, additional_output_predictions); + } + } else { + QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type"; + const auto& leaf = leaf_node.leaf().vector(); + for (size_t i = 0; i < leaf.value_size(); ++i) { + const float value = tree_weight * leaf.value(i); + UpdatePredictions(example.example_idx, i, value, output_predictions, + additional_output_predictions); + } + } +} + +} // namespace + +void MultipleAdditiveTrees::Predict( + const boosted_trees::trees::DecisionTreeEnsembleConfig& config, + const bool only_finalized_trees, const std::vector& trees_to_drop, + const boosted_trees::utils::BatchFeatures& features, + tensorflow::thread::ThreadPool* worker_threads, + tensorflow::TTypes::Matrix output_predictions, + tensorflow::TTypes::Matrix no_dropout_predictions) { + // Zero out predictions as the model is additive. + output_predictions.setZero(); + no_dropout_predictions.setZero(); + + // Get batch size. + const int64 batch_size = features.batch_size(); + if (batch_size <= 0) { + return; + } + + // Prepare the list of trees to keep. + std::vector trees_to_keep; + CalculateTreesToKeep(config, trees_to_drop, config.trees_size(), + only_finalized_trees, &trees_to_keep); + + // Lambda for doing a block of work. 
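+  // Each block covers examples [start, end): kept trees accumulate into
+  // both output matrices, dropped trees only into no_dropout_predictions.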
+ auto update_predictions = [&config, &features, &trees_to_keep, &trees_to_drop, + &output_predictions, + &no_dropout_predictions](int64 start, int64 end) { + auto examples_iterable = features.examples_iterable(start, end); + for (const auto& example : examples_iterable) { + for (const int32 tree_idx : trees_to_keep) { + UpdatePredictionsBasedOnTree(config, tree_idx, example, + &output_predictions, + &no_dropout_predictions); + } + + // Now do predictions for dropped trees + for (const int32 tree_idx : trees_to_drop) { + UpdatePredictionsBasedOnTree(config, tree_idx, example, + &no_dropout_predictions, nullptr); + } + } + }; + + // TODO(salehay): parallelize this for low latency in serving path where + // batch size tends to be small but ensemble size tends to be large. + boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(), + worker_threads, update_predictions); +} + +} // namespace models +} // namespace boosted_trees +} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h new file mode 100644 index 00000000000..fedade20261 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -0,0 +1,50 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ + +#include + +#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" +#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace boosted_trees { +namespace models { + +// Multiple additive trees prediction model. +// This class does not hold state and is thread safe. +class MultipleAdditiveTrees { + public: + // Predict runs tree ensemble on the given batch and updates + // output predictions accordingly. The method also returns predictions that + // we would get if no dropout was applied. 
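+  // Both output matrices are zeroed inside Predict before accumulation,
+  // so callers need not pre-initialize them.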
+ static void Predict( + const boosted_trees::trees::DecisionTreeEnsembleConfig& config, + const bool only_finalized_trees, const std::vector& trees_to_drop, + const boosted_trees::utils::BatchFeatures& features, + thread::ThreadPool* const thread_pool, + TTypes::Matrix output_predictions, + TTypes::Matrix no_dropout_predictions); +}; + +} // namespace models +} // namespace boosted_trees +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc new file mode 100644 index 00000000000..5f0924b48f2 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc @@ -0,0 +1,381 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h" + +#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h" +#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h" +#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +using boosted_trees::trees::DecisionTreeEnsembleConfig; +using test::AsTensor; + +namespace boosted_trees { +namespace models { +namespace { + +const int32 kNumThreadsMultiThreaded = 6; +const int32 kNumThreadsSingleThreaded = 1; + +class MultipleAdditiveTreesTest : public ::testing::Test { + protected: + MultipleAdditiveTreesTest() : batch_features_(2) { + // Create a batch of two examples having one dense feature each. + // The shape of the dense matrix is therefore 2x1 as in one row per example + // and one column per feature per example. + auto dense_matrix = test::AsTensor({7.0f, -2.0f}, {2, 1}); + TF_EXPECT_OK( + batch_features_.Initialize({dense_matrix}, {}, {}, {}, {}, {}, {})); + } + + boosted_trees::utils::BatchFeatures batch_features_; +}; + +TEST_F(MultipleAdditiveTreesTest, Empty) { + // Create empty tree ensemble. + DecisionTreeEnsembleConfig tree_ensemble_config; + auto output_tensor = AsTensor({9.0f, 23.0f}, {2, 1}); + auto output_matrix = output_tensor.matrix(); + auto no_dropout_output_matrix = output_tensor.matrix(); + + // Predict for both instances. 
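+  // Predict zeroes both outputs before accumulating, so the initial
+  // {9, 23} values must be overwritten even by an empty ensemble.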
+ tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", + kNumThreadsSingleThreaded); + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {}, batch_features_, &threads, output_matrix, + no_dropout_output_matrix); + EXPECT_EQ(0, output_matrix(0, 0)); + EXPECT_EQ(0, output_matrix(1, 0)); + + // There was no dropout + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); + } +} + +TEST_F(MultipleAdditiveTreesTest, SingleClass) { + // Add one bias and one stump to ensemble for a single class. + DecisionTreeEnsembleConfig tree_ensemble_config; + auto* tree1 = tree_ensemble_config.add_trees(); + auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector(); + bias_leaf->add_index(0); + bias_leaf->add_value(-0.4f); + auto* tree2 = tree_ensemble_config.add_trees(); + auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split(); + dense_split->set_feature_column(0); + dense_split->set_threshold(5.0f); + dense_split->set_left_id(1); + dense_split->set_right_id(2); + auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector(); + leaf1->add_index(0); + leaf1->add_value(0.9f); + auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector(); + leaf2->add_index(0); + leaf2->add_value(0.2f); + + tree_ensemble_config.add_tree_weights(1.0); + tree_ensemble_config.add_tree_weights(1.0); + + auto output_tensor = AsTensor({0.0f, 0.0f}, {2, 1}); + auto output_matrix = output_tensor.matrix(); + + auto no_dropout_output_tensor = AsTensor({0.0f, 0.0f}, {2, 1}); + auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); + + tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", + kNumThreadsSingleThreaded); + + // Normal case. + { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {}, batch_features_, &threads, output_matrix, + no_dropout_output_matrix); + EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + + // No dropout predictions are the same. + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); + } + } + // Weighted case + { + DecisionTreeEnsembleConfig weighted = tree_ensemble_config; + weighted.set_tree_weights(0, 6.0); + weighted.set_tree_weights(1, 3.2); + MultipleAdditiveTrees::Predict(weighted, + false, // include non-finalized trees + {}, batch_features_, &threads, output_matrix, + no_dropout_output_matrix); + // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0)); + // -0.4 (bias) + 0.9 (leaf 1). + EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9 * 3.2, output_matrix(1, 0)); + + // No dropout predictions are the same. + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); + } + } + // Drop first tree. + { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {0}, batch_features_, &threads, + output_matrix, no_dropout_output_matrix); + EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2). + EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1). + + // No dropout predictions + EXPECT_FLOAT_EQ( + -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ( + 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + } + // Drop second tree. 
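+  // With the split tree dropped, only the bias tree contributes to the
+  // regular output; the no-dropout output still includes both trees.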
+ { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {1}, batch_features_, &threads, + output_matrix, no_dropout_output_matrix); + EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias). + EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias). + + // No dropout predictions + EXPECT_FLOAT_EQ( + -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ( + 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + } + // Drop all trees. + { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {0, 1}, batch_features_, &threads, + output_matrix, no_dropout_output_matrix); + EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); + EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0)); + + // No dropout predictions + EXPECT_FLOAT_EQ( + -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ( + 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + } +} + +TEST_F(MultipleAdditiveTreesTest, MultiClass) { + // Add one bias and one stump to ensemble for two classes. + DecisionTreeEnsembleConfig tree_ensemble_config; + auto* tree1 = tree_ensemble_config.add_trees(); + auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector(); + bias_leaf->add_index(0); + bias_leaf->add_value(-0.4f); + bias_leaf->add_index(1); + bias_leaf->add_value(-0.7f); + auto* tree2 = tree_ensemble_config.add_trees(); + auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split(); + dense_split->set_feature_column(0); + dense_split->set_threshold(5.0f); + dense_split->set_left_id(1); + dense_split->set_right_id(2); + auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector(); + leaf1->add_index(0); + leaf1->add_value(0.9f); + auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector(); + leaf2->add_index(1); + leaf2->add_value(0.2f); + + tree_ensemble_config.add_tree_weights(1.0); + tree_ensemble_config.add_tree_weights(1.0); + + // Predict for both instances. + tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", + kNumThreadsSingleThreaded); + auto output_tensor = AsTensor({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2}); + auto output_matrix = output_tensor.matrix(); + + auto no_dropout_output_tensor = + AsTensor({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2}); + auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); + + // Normal case. + { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {}, batch_features_, &threads, output_matrix, + no_dropout_output_matrix); + EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) + EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) + EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1) + EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias) + + // No dropout predictions are the same. + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j)); + } + } + } + // Weighted case. 
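+  // Same ensemble, but tree weights 6.0 and 3.2 scale each tree's leaf
+  // contributions before accumulation.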
+ { + DecisionTreeEnsembleConfig weighted = tree_ensemble_config; + weighted.set_tree_weights(0, 6.0); + weighted.set_tree_weights(1, 3.2); + MultipleAdditiveTrees::Predict(weighted, + false, // include non-finalized trees + {}, batch_features_, &threads, output_matrix, + no_dropout_output_matrix); + // bias + EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0)); + // bias + leaf 2 + EXPECT_FLOAT_EQ(-0.7f * 6 + 0.2f * 3.2, output_matrix(0, 1)); + // bias + leaf 2 + EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9f * 3.2f, output_matrix(1, 0)); + // bias + EXPECT_FLOAT_EQ(-0.7f * 6, output_matrix(1, 1)); + } + // Dropout first tree. + { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {0}, batch_features_, &threads, + output_matrix, no_dropout_output_matrix); + EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); + EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2) + EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2) + EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1)); + + // No dropout predictions + EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + EXPECT_FLOAT_EQ( + -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) + EXPECT_FLOAT_EQ( + 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) + EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) + } + // Dropout second tree. + { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {1}, batch_features_, &threads, + output_matrix, no_dropout_output_matrix); + EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) + EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias) + EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias) + EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias) + + // No dropout predictions + EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + EXPECT_FLOAT_EQ( + -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) + EXPECT_FLOAT_EQ( + 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) + EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) + } + // Drop both trees. 
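+  // With every tree dropped, the regular output collapses to zero while
+  // the no-dropout output still matches the full ensemble.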
+ { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {0, 1}, batch_features_, &threads, + output_matrix, no_dropout_output_matrix); + EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0)); + EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1)); + EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0)); + EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1)); + + // No dropout predictions + EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + EXPECT_FLOAT_EQ( + -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) + EXPECT_FLOAT_EQ( + 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) + EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) + } +} + +TEST_F(MultipleAdditiveTreesTest, DenseLeaves) { + DecisionTreeEnsembleConfig tree_ensemble_config; + auto* tree1 = tree_ensemble_config.add_trees(); + auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_vector(); + bias_leaf->add_value(-0.4f); + bias_leaf->add_value(-0.7f); + bias_leaf->add_value(3.0f); + auto* tree2 = tree_ensemble_config.add_trees(); + auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split(); + dense_split->set_feature_column(0); + dense_split->set_threshold(5.0f); + dense_split->set_left_id(1); + dense_split->set_right_id(2); + auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_vector(); + leaf1->add_value(0.9f); + leaf1->add_value(0.8f); + leaf1->add_value(0.7f); + auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_vector(); + leaf2->add_value(0.2f); + leaf2->add_value(0.3f); + leaf2->add_value(0.4f); + + tree_ensemble_config.add_tree_weights(1.0); + tree_ensemble_config.add_tree_weights(1.0); + + // Predict for both instances. + tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", + kNumThreadsSingleThreaded); + auto output_tensor = + AsTensor({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3}); + auto output_matrix = output_tensor.matrix(); + + auto no_dropout_output_tensor = + AsTensor({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3}); + auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); + + // Normal case. + { + MultipleAdditiveTrees::Predict(tree_ensemble_config, + false, // include non-finalized trees + {}, batch_features_, &threads, output_matrix, + no_dropout_output_matrix); + EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2) + EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2) + EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2) + EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (tree1) + 0.9 (leaf 1) + EXPECT_FLOAT_EQ(0.1f, output_matrix(1, 1)); // -0.7 (tree1) + 0.8 (leaf 1) + EXPECT_FLOAT_EQ(3.7f, output_matrix(1, 2)); // 3.0 (tree1) + 0.7 (leaf 1) + + // No dropout predictions are the same. + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j)); + } + } + } +} + +} // namespace +} // namespace models +} // namespace boosted_trees +} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.cc b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.cc new file mode 100644 index 00000000000..39c2fbe9c99 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.cc @@ -0,0 +1,88 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace boosted_trees { +namespace testutil { + +using tensorflow::Tensor; + +void RandomlyInitializeBatchFeatures( + tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features, + uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi, + boosted_trees::utils::BatchFeatures* batch_features) { + const int64 batch_size = static_cast(batch_features->batch_size()); + + // Populate dense features. + std::vector dense_float_features_list; + for (int i = 0; i < num_dense_float_features; ++i) { + std::vector values; + for (int64 j = 0; j < batch_size; ++j) { + values.push_back(rng->RandFloat()); + } + auto dense_tensor = Tensor(tensorflow::DT_FLOAT, {batch_size, 1}); + tensorflow::test::FillValues(&dense_tensor, values); + dense_float_features_list.push_back(dense_tensor); + } + + // Populate sparse features. + std::vector sparse_float_feature_indices_list; + std::vector sparse_float_feature_values_list; + std::vector sparse_float_feature_shapes_list; + for (int i = 0; i < num_sparse_float_features; ++i) { + std::set indices; + const double sparsity = + sparsity_lo + rng->RandDouble() * (sparsity_hi - sparsity_lo); + const double density = 1 - sparsity; + for (int64 k = 0; k < static_cast(density * batch_size) + 1; ++k) { + indices.insert(rng->Uniform64(batch_size)); + } + const int64 sparse_values_size = indices.size(); + std::vector indices_vector; + for (auto idx : indices) { + indices_vector.push_back(idx); + indices_vector.push_back(0); + } + auto indices_tensor = Tensor(tensorflow::DT_INT64, {sparse_values_size, 2}); + tensorflow::test::FillValues(&indices_tensor, indices_vector); + sparse_float_feature_indices_list.push_back(indices_tensor); + + std::vector values; + for (int64 j = 0; j < sparse_values_size; ++j) { + values.push_back(rng->RandFloat()); + } + auto values_tensor = Tensor(tensorflow::DT_FLOAT, {sparse_values_size}); + tensorflow::test::FillValues(&values_tensor, values); + sparse_float_feature_values_list.push_back(values_tensor); + + auto shape_tensor = Tensor(tensorflow::DT_INT64, {2}); + tensorflow::test::FillValues(&shape_tensor, {batch_size, 1}); + sparse_float_feature_shapes_list.push_back(shape_tensor); + } + + // TODO(salehay): Add categorical feature generation support. 
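+  // The trailing empty arguments leave the categorical feature inputs
+  // unpopulated, consistent with the TODO above.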
+ TF_EXPECT_OK(batch_features->Initialize( + dense_float_features_list, sparse_float_feature_indices_list, + sparse_float_feature_values_list, sparse_float_feature_shapes_list, {}, + {}, {})); +} + +} // namespace testutil +} // namespace boosted_trees +} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h new file mode 100644 index 00000000000..d95878ec87b --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h @@ -0,0 +1,45 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ + +#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { +namespace boosted_trees { +namespace testutil { + +// This method calls Initialize on the given 'batch_features', which will be +// populated with randomly generated feature values when the call returns. +// 'tensors' returns a vector of all tensors used in the initialization, +// because they must outlive 'batch_features'. +// +// All float features will be either missing or uniformly randomly chosen +// from [0, 1). For sparse (float) features, a sparsity is uniformly randomly +// chosen from ['sparsity_lo', 'sparsity_hi') per feature, and each instance +// will have a probability of sparsity of missing that feature, in other words, +// sparsity = 1 - density. +void RandomlyInitializeBatchFeatures( + tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features, + uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi, + boosted_trees::utils::BatchFeatures* batch_features); + +} // namespace testutil +} // namespace boosted_trees +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc new file mode 100644 index 00000000000..24259ff6035 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc @@ -0,0 +1,211 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace testutil {
+
+using tensorflow::boosted_trees::trees::DecisionTreeConfig;
+using tensorflow::boosted_trees::trees::TreeNode;
+using boosted_trees::trees::DenseFloatBinarySplit;
+
+namespace {
+
+// Append the given nodes to tree with transfer of pointer ownership.
+// nodes will not be usable upon return.
+void AppendNodes(DecisionTreeConfig* tree,
+                 proto2::RepeatedPtrField<TreeNode>* nodes) {
+  std::reverse(nodes->pointer_begin(), nodes->pointer_end());
+  while (!nodes->empty()) {
+    tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast());
+  }
+}
+
+DenseFloatBinarySplit* GetSplit(TreeNode* node) {
+  switch (node->node_case()) {
+    case TreeNode::kSparseFloatBinarySplitDefaultLeft:
+      return node->mutable_sparse_float_binary_split_default_left()
+          ->mutable_split();
+    case TreeNode::kSparseFloatBinarySplitDefaultRight:
+      return node->mutable_sparse_float_binary_split_default_right()
+          ->mutable_split();
+    case TreeNode::kDenseFloatBinarySplit:
+      return node->mutable_dense_float_binary_split();
+    default:
+      LOG(FATAL) << "Unknown node type encountered.";
+  }
+  return nullptr;
+}
+
+}  // namespace
+
+RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng,
+                             int dense_feature_size, int sparse_feature_size)
+    : rng_(rng),
+      dense_feature_size_(dense_feature_size),
+      sparse_feature_size_(sparse_feature_size) {}
+
+namespace {
+void AddWeightAndMetadata(
+    boosted_trees::trees::DecisionTreeEnsembleConfig* ret) {
+  // Assign the weight of the tree to 1 and say that this weight was updated
+  // only once.
+  ret->add_tree_weights(1.0);
+  auto* meta = ret->add_tree_metadata();
+  meta->set_num_tree_weight_updates(1);
+}
+
+}  // namespace
+
+boosted_trees::trees::DecisionTreeEnsembleConfig
+RandomTreeGen::GenerateEnsemble(int depth, int tree_count) {
+  boosted_trees::trees::DecisionTreeEnsembleConfig ret;
+  *(ret.add_trees()) = Generate(depth);
+  AddWeightAndMetadata(&ret);
+  for (int i = 1; i < tree_count; ++i) {
+    *(ret.add_trees()) = Generate(ret.trees(0));
+    AddWeightAndMetadata(&ret);
+  }
+  return ret;
+}
+
+DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) {
+  DecisionTreeConfig ret = tree;
+  for (auto& node : *ret.mutable_nodes()) {
+    if (node.node_case() == TreeNode::kLeaf) {
+      node.mutable_leaf()->mutable_sparse_vector()->set_value(
+          0, rng_->RandFloat());
+      continue;
+    }
+    // Original node is a split. Re-generate its type but retain the split
+    // node indices.
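+    // GetSplit returns the DenseFloatBinarySplit embedded in any of the three
+    // split variants, so the existing child ids can be read here and then
+    // re-attached to whichever split type GenerateSplit picks.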
+    DenseFloatBinarySplit* split = GetSplit(&node);
+    const int left_id = split->left_id();
+    const int right_id = split->right_id();
+    GenerateSplit(&node, left_id, right_id);
+  }
+  return ret;
+}
+
+DecisionTreeConfig RandomTreeGen::Generate(int depth) {
+  DecisionTreeConfig ret;
+  // Add root.
+  TreeNode* node = ret.add_nodes();
+  GenerateSplit(node, 1, 2);
+  if (depth == 1) {
+    // Add left and right leaves.
+    TreeNode* left = ret.add_nodes();
+    left->mutable_leaf()->mutable_sparse_vector()->add_index(0);
+    left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat());
+    TreeNode* right = ret.add_nodes();
+    right->mutable_leaf()->mutable_sparse_vector()->add_index(0);
+    right->mutable_leaf()->mutable_sparse_vector()->add_value(
+        rng_->RandFloat());
+    return ret;
+  } else {
+    DecisionTreeConfig left_branch = Generate(depth - 1);
+    DecisionTreeConfig right_branch = Generate(depth - 1);
+    Combine(&ret, &left_branch, &right_branch);
+    return ret;
+  }
+}
+
+void RandomTreeGen::Combine(DecisionTreeConfig* root,
+                            DecisionTreeConfig* left_branch,
+                            DecisionTreeConfig* right_branch) {
+  const int left_branch_size = left_branch->nodes_size();
+  CHECK_EQ(1, root->nodes_size());
+  // left_branch starts its index at 1. right_branch starts its index at
+  // (left_branch_size + 1).
+  auto* root_node = root->mutable_nodes(0);
+  DenseFloatBinarySplit* root_split = GetSplit(root_node);
+  root_split->set_left_id(1);
+  root_split->set_right_id(left_branch_size + 1);
+  // Shift left/right branch's indices internally so that everything is
+  // consistent.
+  ShiftNodeIndex(left_branch, 1);
+  ShiftNodeIndex(right_branch, left_branch_size + 1);
+
+  // Complexity O(branch node size). No proto copying though.
+  AppendNodes(root, left_branch->mutable_nodes());
+  AppendNodes(root, right_branch->mutable_nodes());
+}
+
+void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) {
+  for (TreeNode& node : *(tree->mutable_nodes())) {
+    DenseFloatBinarySplit* split = nullptr;
+    switch (node.node_case()) {
+      case TreeNode::kLeaf:
+        break;
+      case TreeNode::kSparseFloatBinarySplitDefaultLeft:
+        split = node.mutable_sparse_float_binary_split_default_left()
+                    ->mutable_split();
+        break;
+      case TreeNode::kSparseFloatBinarySplitDefaultRight:
+        split = node.mutable_sparse_float_binary_split_default_right()
+                    ->mutable_split();
+        break;
+      case TreeNode::kDenseFloatBinarySplit:
+        split = node.mutable_dense_float_binary_split();
+        break;
+      default:
+        LOG(FATAL) << "Unknown node type encountered.";
+    }
+    if (split) {
+      split->set_left_id(shift + split->left_id());
+      split->set_right_id(shift + split->right_id());
+    }
+  }
+}
+
+void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) {
+  const double denseSplitProb =
+      sparse_feature_size_ == 0
+          ? 1.0
+          : static_cast<double>(dense_feature_size_) /
+                (dense_feature_size_ + sparse_feature_size_);
+  // Generate the tree such that it has equal probability of going left and
+  // right when the feature is missing.
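+  // For example, with dense_feature_size_ == 100 and sparse_feature_size_ ==
+  // 100, denseSplitProb evaluates to 100 / (100 + 100) = 0.5, so dense and
+  // sparse splits are generated with equal probability; sparse splits are
+  // then divided evenly between default-left and default-right by kLeftProb.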
+  static constexpr float kLeftProb = 0.5;
+
+  DenseFloatBinarySplit* split;
+  int feature_size;
+  if (rng_->RandFloat() < denseSplitProb) {
+    feature_size = dense_feature_size_;
+    split = node->mutable_dense_float_binary_split();
+  } else {
+    feature_size = sparse_feature_size_;
+    if (rng_->RandFloat() < kLeftProb) {
+      split = node->mutable_sparse_float_binary_split_default_left()
+                  ->mutable_split();
+    } else {
+      split = node->mutable_sparse_float_binary_split_default_right()
+                  ->mutable_split();
+    }
+  }
+  split->set_threshold(rng_->RandFloat());
+  split->set_feature_column(rng_->Uniform(feature_size));
+  split->set_left_id(left_id);
+  split->set_right_id(right_id);
+}
+
+}  // namespace testutil
+}  // namespace boosted_trees
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h
new file mode 100644
index 00000000000..dc584bbd3cf
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h
@@ -0,0 +1,75 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"  // NOLINT
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace testutil {
+
+// Randomly generates a balanced tree, for performance benchmarking purposes;
+// for now it assumes all features are sparse float features.
+class RandomTreeGen {
+ public:
+  RandomTreeGen(tensorflow::random::SimplePhilox* rng, int dense_feature_size,
+                int sparse_feature_size);
+
+  // Required: depth must be >= 1.
+  // If one wants to generate multiple trees with the same depth, see also the
+  // overload below.
+  boosted_trees::trees::DecisionTreeConfig Generate(int depth);
+
+  // Randomly generate a new tree with the same depth (and tree structure)
+  // as the given tree. This is faster.
+  boosted_trees::trees::DecisionTreeConfig Generate(
+      const boosted_trees::trees::DecisionTreeConfig& tree);
+
+  // Required: depth >= 1; tree_count >= 1.
+  boosted_trees::trees::DecisionTreeEnsembleConfig GenerateEnsemble(
+      int depth, int tree_count);
+
+ private:
+  tensorflow::random::SimplePhilox* rng_;
+  const int dense_feature_size_;
+  const int sparse_feature_size_;
+
+  // Put together a deeper tree by combining two trees.
+  void Combine(boosted_trees::trees::DecisionTreeConfig* root,
+               boosted_trees::trees::DecisionTreeConfig* left_branch,
+               boosted_trees::trees::DecisionTreeConfig* right_branch);
+
+  // For each node in the provided tree, shift its referenced left/right index
+  // by shift.
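+  // For example, a split with left_id == 1 and right_id == 2 shifted by 3
+  // ends up with left_id == 4 and right_id == 5; leaf nodes are unaffected.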
+  void ShiftNodeIndex(boosted_trees::trees::DecisionTreeConfig* tree,
+                      int shift);
+
+  // Generate a dense or sparse split in the node.
+  void GenerateSplit(boosted_trees::trees::TreeNode* node, int left_id,
+                     int right_id);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RandomTreeGen);
+};
+
+}  // namespace testutil
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen_main.cc b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen_main.cc
new file mode 100644
index 00000000000..7d905d40828
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen_main.cc
@@ -0,0 +1,64 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// Randomly generate a tree ensemble and write to file.
+
+#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+using tensorflow::Flag;
+using tensorflow::Flags;
+using tensorflow::int32;
+using tensorflow::string;
+
+int main(int argc, char* argv[]) {
+  int32 dense_feature_size = 100;
+  int32 sparse_feature_size = 100;
+  int32 depth = 8;
+  int32 tree_count = 10;
+  string filename = "/tmp/trees.pb";
+  std::vector<Flag> flag_list = {
+      Flag("dense_feature_size", &dense_feature_size, "dense feature size"),
+      Flag("sparse_feature_size", &sparse_feature_size, "sparse_feature_size"),
+      Flag("depth", &depth, "tree depth"),
+      Flag("tree_count", &tree_count, "tree count"),
+      Flag("filename", &filename, "Output filename."),
+  };
+  string usage = Flags::Usage(argv[0], flag_list);
+  const bool parse_result = Flags::Parse(&argc, argv, flag_list);
+  // We need to call InitMain (below) to set up global state for TensorFlow.
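+  // An illustrative invocation, using the flags defined above (the binary
+  // name is assumed from the build target):
+  //   random_tree_gen_main --dense_feature_size=100 --sparse_feature_size=100 \
+  //     --depth=8 --tree_count=10 --filename=/tmp/trees.pb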
+ tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return -1; + } + + tensorflow::random::PhiloxRandom philox(1); + tensorflow::random::SimplePhilox rng(&philox); + tensorflow::boosted_trees::testutil::RandomTreeGen tree_gen( + &rng, dense_feature_size, sparse_feature_size); + const auto& trees = tree_gen.GenerateEnsemble(depth, tree_count); + tensorflow::Status status = + tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, trees); + if (!status.ok()) { + LOG(WARNING) << "Failed to write: " << filename << " : " << status; + } else { + LOG(INFO) << "Tree ensemble written to: " << filename; + } + return 0; +} diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc new file mode 100644 index 00000000000..318d8a5296e --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc @@ -0,0 +1,170 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace boosted_trees { +namespace trees { + +constexpr int kInvalidLeaf = -1; +int DecisionTree::Traverse(const DecisionTreeConfig& config, + const int32 sub_root_id, + const utils::Example& example) { + if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { + return kInvalidLeaf; + } + + // Traverse tree starting at the provided sub-root. + int32 node_id = sub_root_id; + while (true) { + const auto& current_node = config.nodes(node_id); + switch (current_node.node_case()) { + case TreeNode::kLeaf: { + return node_id; + } + case TreeNode::kDenseFloatBinarySplit: { + const auto& split = current_node.dense_float_binary_split(); + node_id = example.dense_float_features[split.feature_column()] <= + split.threshold() + ? split.left_id() + : split.right_id(); + break; + } + case TreeNode::kSparseFloatBinarySplitDefaultLeft: { + const auto& split = + current_node.sparse_float_binary_split_default_left().split(); + auto sparse_feature = + example.sparse_float_features[split.feature_column()]; + node_id = !sparse_feature.has_value() || + sparse_feature.get_value() <= split.threshold() + ? split.left_id() + : split.right_id(); + break; + } + case TreeNode::kSparseFloatBinarySplitDefaultRight: { + const auto& split = + current_node.sparse_float_binary_split_default_right().split(); + auto sparse_feature = + example.sparse_float_features[split.feature_column()]; + node_id = sparse_feature.has_value() && + sparse_feature.get_value() <= split.threshold() + ? 
split.left_id()
+                      : split.right_id();
+        break;
+      }
+      case TreeNode::kCategoricalIdBinarySplit: {
+        const auto& split = current_node.categorical_id_binary_split();
+        node_id = example.sparse_int_features[split.feature_column()].count(
+                      split.feature_id()) > 0
+                      ? split.left_id()
+                      : split.right_id();
+        break;
+      }
+      case TreeNode::NODE_NOT_SET: {
+        QCHECK(false) << "Invalid node in tree: " << current_node.DebugString();
+        break;
+      }
+    }
+    DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
+                          << current_node.DebugString();
+  }
+}
+
+void DecisionTree::LinkChildren(const std::vector<int32>& children,
+                                TreeNode* parent_node) {
+  // Decide how to link children depending on the parent node's type.
+  auto children_it = children.begin();
+  switch (parent_node->node_case()) {
+    case TreeNode::kLeaf: {
+      // Essentially no-op.
+      QCHECK(children.empty()) << "A leaf node cannot have children.";
+      break;
+    }
+    case TreeNode::kDenseFloatBinarySplit: {
+      QCHECK(children.size() == 2)
+          << "A binary split node must have exactly two children.";
+      auto* split = parent_node->mutable_dense_float_binary_split();
+      split->set_left_id(*children_it);
+      split->set_right_id(*++children_it);
+      break;
+    }
+    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
+      QCHECK(children.size() == 2)
+          << "A binary split node must have exactly two children.";
+      auto* split =
+          parent_node->mutable_sparse_float_binary_split_default_left()
+              ->mutable_split();
+      split->set_left_id(*children_it);
+      split->set_right_id(*++children_it);
+      break;
+    }
+    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
+      QCHECK(children.size() == 2)
+          << "A binary split node must have exactly two children.";
+      auto* split =
+          parent_node->mutable_sparse_float_binary_split_default_right()
+              ->mutable_split();
+      split->set_left_id(*children_it);
+      split->set_right_id(*++children_it);
+      break;
+    }
+    case TreeNode::kCategoricalIdBinarySplit: {
+      QCHECK(children.size() == 2)
+          << "A binary split node must have exactly two children.";
+      auto* split = parent_node->mutable_categorical_id_binary_split();
+      split->set_left_id(*children_it);
+      split->set_right_id(*++children_it);
+      break;
+    }
+    case TreeNode::NODE_NOT_SET: {
+      QCHECK(false) << "A non-set node cannot have children.";
+      break;
+    }
+  }
+}
+
+std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
+  // A node's children depend on its type.
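+  // For example, after LinkChildren({3, 8}, &node) on a binary split,
+  // GetChildren(node) returns {3, 8}; for a leaf or an unset node it
+  // returns {}.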
+  switch (node.node_case()) {
+    case TreeNode::kLeaf: {
+      return {};
+    }
+    case TreeNode::kDenseFloatBinarySplit: {
+      const auto& split = node.dense_float_binary_split();
+      return {split.left_id(), split.right_id()};
+    }
+    case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
+      const auto& split = node.sparse_float_binary_split_default_left().split();
+      return {split.left_id(), split.right_id()};
+    }
+    case TreeNode::kSparseFloatBinarySplitDefaultRight: {
+      const auto& split =
+          node.sparse_float_binary_split_default_right().split();
+      return {split.left_id(), split.right_id()};
+    }
+    case TreeNode::kCategoricalIdBinarySplit: {
+      const auto& split = node.categorical_id_binary_split();
+      return {split.left_id(), split.right_id()};
+    }
+    case TreeNode::NODE_NOT_SET: {
+      return {};
+    }
+  }
+}
+
+}  // namespace trees
+}  // namespace boosted_trees
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
new file mode 100644
index 00000000000..604ff02744b
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
@@ -0,0 +1,49 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
+
+#include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
+#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"  // NOLINT
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace trees {
+
+// Decision tree class to encapsulate tree traversal and mutation logic.
+// This class does not hold state and is thread safe.
+class DecisionTree {
+ public:
+  // Traverses the tree from the given sub-root for the given example and
+  // returns the reached leaf index, or -1 if the tree is empty or the
+  // sub-root is invalid.
+  static int Traverse(const DecisionTreeConfig& config, int32 sub_root_id,
+                      const utils::Example& example);
+
+  // Links the specified children to the parent. The children must already be
+  // added to the decision tree config, so this method only ensures the nodes
+  // are re-linked.
+  static void LinkChildren(const std::vector<int32>& children,
+                           TreeNode* parent_node);
+
+  // Retrieves node children indices if any.
+  static std::vector<int32> GetChildren(const TreeNode& node);
+};
+
+}  // namespace trees
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc
new file mode 100644
index 00000000000..0f082d7fd54
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc
@@ -0,0 +1,326 @@
+// Copyright 2017 The TensorFlow Authors.
All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
+#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace trees {
+namespace {
+
+class DecisionTreeTest : public ::testing::Test {
+ protected:
+  DecisionTreeTest() : batch_features_(2) {
+    // Create a batch of two examples having one dense float, two sparse float
+    // and one sparse int features.
+    // The first example is missing the second sparse feature column and the
+    // second example is missing the first sparse feature column.
+    // This looks like the following:
+    // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 |
+    // 0        |    7    |    -3    |          |    3     |
+    // 1        |   -2    |          |    4     |          |
+    auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
+    auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
+    auto sparse_float_values1 = test::AsTensor<float>({-3.0f});
+    auto sparse_float_shape1 = test::AsTensor<int64>({2, 1});
+    auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2});
+    auto sparse_float_values2 = test::AsTensor<float>({4.0f});
+    auto sparse_float_shape2 = test::AsTensor<int64>({2, 1});
+    auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
+    auto sparse_int_values1 = test::AsTensor<int64>({3});
+    auto sparse_int_shape1 = test::AsTensor<int64>({2, 1});
+    TF_EXPECT_OK(batch_features_.Initialize(
+        {dense_float_matrix}, {sparse_float_indices1, sparse_float_indices2},
+        {sparse_float_values1, sparse_float_values2},
+        {sparse_float_shape1, sparse_float_shape2}, {sparse_int_indices1},
+        {sparse_int_values1}, {sparse_int_shape1}));
+  }
+
+  template <typename SplitType>
+  void TestLinkChildrenBinary(TreeNode* node, SplitType* split) {
+    // Verify children were linked.
+    DecisionTree::LinkChildren({3, 8}, node);
+    EXPECT_EQ(3, split->left_id());
+    EXPECT_EQ(8, split->right_id());
+
+    // Invalid cases.
+    EXPECT_DEATH(DecisionTree::LinkChildren({}, node),
+                 "A binary split node must have exactly two children.");
+    EXPECT_DEATH(DecisionTree::LinkChildren({3}, node),
+                 "A binary split node must have exactly two children.");
+    EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node),
+                 "A binary split node must have exactly two children.");
+  }
+
+  void TestGetChildren(const TreeNode& node,
+                       const std::vector<uint32>& expected_children) {
+    // Verify the children returned match the expected ones.
+ auto children = DecisionTree::GetChildren(node); + EXPECT_EQ(children.size(), expected_children.size()); + for (size_t idx = 0; idx < children.size(); ++idx) { + EXPECT_EQ(children[idx], expected_children[idx]); + } + } + + utils::BatchFeatures batch_features_; +}; + +TEST_F(DecisionTreeTest, TraverseEmpty) { + DecisionTreeConfig tree_config; + auto example = (*batch_features_.examples_iterable(0, 1).begin()); + EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 0, example)); +} + +TEST_F(DecisionTreeTest, TraverseBias) { + DecisionTreeConfig tree_config; + tree_config.add_nodes()->mutable_leaf(); + auto example = (*batch_features_.examples_iterable(0, 1).begin()); + EXPECT_EQ(0, DecisionTree::Traverse(tree_config, 0, example)); +} + +TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) { + DecisionTreeConfig tree_config; + tree_config.add_nodes()->mutable_leaf(); + auto example = (*batch_features_.examples_iterable(0, 1).begin()); + EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 10, example)); +} + +TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) { + DecisionTreeConfig tree_config; + auto* split_node = + tree_config.add_nodes()->mutable_dense_float_binary_split(); + split_node->set_feature_column(0); + split_node->set_threshold(0.0f); + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + auto example_iterable = batch_features_.examples_iterable(0, 2); + + // Expect right child to be picked as !(7 <= 0); + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect left child to be picked as (-2 <= 0); + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); +} + +TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { + // Test first sparse feature which is missing for the second example. + DecisionTreeConfig tree_config1; + auto* split_node1 = tree_config1.add_nodes() + ->mutable_sparse_float_binary_split_default_left() + ->mutable_split(); + split_node1->set_feature_column(0); + split_node1->set_threshold(-20.0f); + split_node1->set_left_id(1); + split_node1->set_right_id(2); + tree_config1.add_nodes()->mutable_leaf(); + tree_config1.add_nodes()->mutable_leaf(); + auto example_iterable = batch_features_.examples_iterable(0, 2); + + // Expect right child to be picked as !(-3 <= -20). + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config1, 0, *example_it)); + + // Expect left child to be picked as default direction. + EXPECT_EQ(1, DecisionTree::Traverse(tree_config1, 0, *++example_it)); + + // Test second sparse feature which is missing for the first example. + DecisionTreeConfig tree_config2; + auto* split_node2 = tree_config2.add_nodes() + ->mutable_sparse_float_binary_split_default_right() + ->mutable_split(); + split_node2->set_feature_column(1); + split_node2->set_threshold(4.0f); + split_node2->set_left_id(1); + split_node2->set_right_id(2); + tree_config2.add_nodes()->mutable_leaf(); + tree_config2.add_nodes()->mutable_leaf(); + + // Expect right child to be picked as default direction. + example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config2, 0, *example_it)); + + // Expect left child to be picked as (4 <= 4). 
+ EXPECT_EQ(1, DecisionTree::Traverse(tree_config2, 0, *++example_it)); +} + +TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) { + DecisionTreeConfig tree_config; + auto* split_node = + tree_config.add_nodes()->mutable_categorical_id_binary_split(); + split_node->set_feature_column(0); + split_node->set_feature_id(3); + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + auto example_iterable = batch_features_.examples_iterable(0, 2); + + // Expect left child to be picked as 3 == 3; + auto example_it = example_iterable.begin(); + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect right child to be picked as the feature is missing; + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it)); +} + +TEST_F(DecisionTreeTest, TraverseHybridSplits) { + DecisionTreeConfig tree_config; + auto* split_node1 = + tree_config.add_nodes()->mutable_dense_float_binary_split(); + split_node1->set_feature_column(0); + split_node1->set_threshold(9.0f); + split_node1->set_left_id(1); // sparse split. + split_node1->set_right_id(2); // leaf + auto* split_node2 = tree_config.add_nodes() + ->mutable_sparse_float_binary_split_default_left() + ->mutable_split(); + tree_config.add_nodes()->mutable_leaf(); + split_node2->set_feature_column(0); + split_node2->set_threshold(-20.0f); + split_node2->set_left_id(3); + split_node2->set_right_id(4); + auto* split_node3 = + tree_config.add_nodes()->mutable_categorical_id_binary_split(); + split_node3->set_feature_column(0); + split_node3->set_feature_id(2); + split_node3->set_left_id(5); + split_node3->set_right_id(6); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + auto example_iterable = batch_features_.examples_iterable(0, 2); + + // Expect will go left through the first dense split as (7.0f <= 9.0f), + // then will go right through the sparse split as !(-3 <= -20). + auto example_it = example_iterable.begin(); + EXPECT_EQ(4, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect will go left through the first dense split as (-2.0f <= 9.0f), + // then will go left the default direction as the sparse feature is missing, + // then will go right as 2 != 3 on the categorical split. + EXPECT_EQ(6, DecisionTree::Traverse(tree_config, 0, *++example_it)); +} + +TEST_F(DecisionTreeTest, LinkChildrenLeaf) { + // Create leaf node. + TreeNode node; + node.mutable_leaf(); + + // No-op. + DecisionTree::LinkChildren({}, &node); + + // Invalid case. 
+ EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node), + "A leaf node cannot have children."); +} + +TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) { + TreeNode node; + auto* split = node.mutable_dense_float_binary_split(); + split->set_left_id(-1); + split->set_right_id(-1); + TestLinkChildrenBinary(&node, split); +} + +TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) { + TreeNode node; + auto* split = + node.mutable_sparse_float_binary_split_default_left()->mutable_split(); + split->set_left_id(-1); + split->set_right_id(-1); + TestLinkChildrenBinary(&node, split); +} + +TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) { + TreeNode node; + auto* split = + node.mutable_sparse_float_binary_split_default_right()->mutable_split(); + split->set_left_id(-1); + split->set_right_id(-1); + TestLinkChildrenBinary(&node, split); +} + +TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) { + TreeNode node; + auto* split = node.mutable_categorical_id_binary_split(); + split->set_left_id(-1); + split->set_right_id(-1); + TestLinkChildrenBinary(&node, split); +} + +TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) { + // Create unset node. + TreeNode node; + + // Invalid case. + EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node), + "A non-set node cannot have children."); +} + +TEST_F(DecisionTreeTest, GetChildrenLeaf) { + TreeNode node; + node.mutable_leaf(); + TestGetChildren(node, {}); +} + +TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) { + TreeNode node; + auto* split = node.mutable_dense_float_binary_split(); + split->set_left_id(23); + split->set_right_id(24); + TestGetChildren(node, {23, 24}); +} + +TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) { + TreeNode node; + auto* split = + node.mutable_sparse_float_binary_split_default_left()->mutable_split(); + split->set_left_id(12); + split->set_right_id(13); + TestGetChildren(node, {12, 13}); +} + +TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) { + TreeNode node; + auto* split = + node.mutable_sparse_float_binary_split_default_right()->mutable_split(); + split->set_left_id(1); + split->set_right_id(2); + TestGetChildren(node, {1, 2}); +} + +TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) { + TreeNode node; + auto* split = node.mutable_categorical_id_binary_split(); + split->set_left_id(7); + split->set_right_id(8); + TestGetChildren(node, {7, 8}); +} + +TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) { + TreeNode node; + TestGetChildren(node, {}); +} + +} // namespace +} // namespace trees +} // namespace boosted_trees +} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD index 3b6b0339d2e..c99d8849bd5 100644 --- a/tensorflow/contrib/boosted_trees/proto/BUILD +++ b/tensorflow/contrib/boosted_trees/proto/BUILD @@ -24,6 +24,15 @@ tf_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library( + name = "quantiles_proto", + srcs = [ + "quantiles.proto", + ], + cc_api_version = 2, + visibility = ["//visibility:public"], +) + tf_proto_library( name = "tree_config_proto", srcs = ["tree_config.proto"], diff --git a/tensorflow/contrib/boosted_trees/proto/quantiles.proto b/tensorflow/contrib/boosted_trees/proto/quantiles.proto new file mode 100644 index 00000000000..7f872d2aa71 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/proto/quantiles.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +option cc_enable_arenas 
= true; + +package boosted_trees; + +message QuantileConfig { + // Maximum eps error when computing quantile summaries. + double eps = 1; + // Number of quantiles to generate. + int64 num_quantiles = 2; +} + +message QuantileEntry { + // Value for the entry. + float value = 1; + // Weight for the entry. + float weight = 2; + // We need the minimum and maximum rank possible for this entry. + // Rank is 0.0 for the absolute minimum and sum of the weights for the maximum + // value in the input. + float min_rank = 3; + float max_rank = 4; +} + +message QuantileSummaryState { + repeated QuantileEntry entries = 1; +} + +message QuantileStreamState { + repeated QuantileSummaryState summaries = 1; +} diff --git a/tensorflow/contrib/boosted_trees/resources/BUILD b/tensorflow/contrib/boosted_trees/resources/BUILD new file mode 100644 index 00000000000..5dfdf8f4896 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/resources/BUILD @@ -0,0 +1,53 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package( + default_visibility = [ + "//tensorflow/contrib/boosted_trees:__subpackages__", + "//tensorflow/contrib/boosted_trees:friends", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +cc_library( + name = "stamped_resource", + hdrs = ["stamped_resource.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], +) + +cc_library( + name = "quantile_stream_resource", + hdrs = ["quantile_stream_resource.h"], + deps = [ + ":stamped_resource", + "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", + "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], +) + +cc_library( + name = "decision_tree_ensemble_resource", + hdrs = ["decision_tree_ensemble_resource.h"], + deps = [ + ":stamped_resource", + "//tensorflow/contrib/boosted_trees/lib:trees", + "//tensorflow/core:framework_headers_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h new file mode 100644 index 00000000000..90e641f4bd6 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -0,0 +1,76 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
+
+#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
+#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace models {
+
+// Keep a tree ensemble in memory for efficient evaluation and mutation.
+class DecisionTreeEnsembleResource : public StampedResource {
+ public:
+  // Constructor.
+  explicit DecisionTreeEnsembleResource()
+      : decision_tree_ensemble_(
+            proto2::Arena::CreateMessage<
+                boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {}
+
+  string DebugString() override {
+    return strings::StrCat("GTFlowDecisionTreeEnsemble[size=",
+                           decision_tree_ensemble_->trees_size(), "]");
+  }
+
+  const boosted_trees::trees::DecisionTreeEnsembleConfig&
+  decision_tree_ensemble() const {
+    return *decision_tree_ensemble_;
+  }
+
+  boosted_trees::trees::DecisionTreeEnsembleConfig*
+  mutable_decision_tree_ensemble() {
+    return decision_tree_ensemble_;
+  }
+
+  // Resets the resource and frees the protos in arena.
+  // Caller needs to hold the mutex lock while calling this.
+  void Reset() {
+    // Reset stamp.
+    set_stamp(-1);
+
+    // Clear tree ensemble.
+    arena_.Reset();
+    CHECK_EQ(0, arena_.SpaceAllocated());
+    decision_tree_ensemble_ = proto2::Arena::CreateMessage<
+        boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_);
+  }
+
+  mutex* get_mutex() { return &mu_; }
+
+ private:
+  proto2::Arena arena_;
+  mutex mu_;
+  boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_;
+};
+
+}  // namespace models
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h
new file mode 100644
index 00000000000..fb29f79e578
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h
@@ -0,0 +1,104 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
+
+#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h"  // NOLINT
+#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+
+using QuantileStream =
+    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Resource for accumulating summaries for multiple columns.
+class QuantileStreamResource : public StampedResource {
+ public:
+  QuantileStreamResource(const float epsilon, const int32 num_quantiles,
+                         const int64 max_elements, int64 stamp_token)
+      : stream_(epsilon, max_elements),
+        are_buckets_ready_(false),
+        epsilon_(epsilon),
+        num_quantiles_(num_quantiles),
+        max_elements_(max_elements) {
+    set_stamp(stamp_token);
+  }
+
+  string DebugString() override { return "QuantileStreamResource"; }
+
+  tensorflow::mutex* mutex() { return &mu_; }
+
+  QuantileStream* stream(int64 stamp) {
+    CHECK(is_stamp_valid(stamp));
+    return &stream_;
+  }
+
+  const std::vector<float>& boundaries(int64 stamp) {
+    CHECK(is_stamp_valid(stamp));
+    return boundaries_;
+  }
+
+  void set_boundaries(int64 stamp, const std::vector<float>& boundaries) {
+    CHECK(is_stamp_valid(stamp));
+    are_buckets_ready_ = true;
+    boundaries_ = boundaries;
+  }
+
+  float epsilon() const { return epsilon_; }
+  int32 num_quantiles() const { return num_quantiles_; }
+
+  void Reset(int64 stamp) {
+    set_stamp(stamp);
+    stream_ = QuantileStream(epsilon_, max_elements_);
+  }
+
+  bool are_buckets_ready() const { return are_buckets_ready_; }
+  void set_buckets_ready(bool are_buckets_ready) {
+    are_buckets_ready_ = are_buckets_ready;
+  }
+
+ private:
+  ~QuantileStreamResource() override {}
+
+  // Mutex for the whole resource.
+  tensorflow::mutex mu_;
+
+  // Quantile stream.
+  QuantileStream stream_;
+
+  // Stores the boundaries from the previous iteration. Empty during the first
+  // iteration.
+  std::vector<float> boundaries_;
+
+  // Whether boundaries are created. Initially boundaries are empty until
+  // set_boundaries is called.
+  bool are_buckets_ready_;
+
+  const float epsilon_;
+  const int32 num_quantiles_;
+  // An upper-bound for the number of elements.
+  int64 max_elements_;
+  TF_DISALLOW_COPY_AND_ASSIGN(QuantileStreamResource);
+};
+
+}  // namespace boosted_trees
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/contrib/boosted_trees/resources/stamped_resource.h b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h
new file mode 100644
index 00000000000..aabeeb98516
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h
@@ -0,0 +1,42 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace boosted_trees { + +// A StampedResource is a resource that has a stamp token associated with it. +// Before reading from or applying updates to the resource, the stamp should +// be checked to verify that the update is not stale. +class StampedResource : public ResourceBase { + public: + StampedResource() : stamp_(-1) {} + + bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; } + + int64 stamp() const { return stamp_; } + void set_stamp(int64 stamp) { stamp_ = stamp; } + + private: + int64 stamp_; +}; + +} // namespace boosted_trees +} // namespace tensorflow +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ From bbd2047cf3a715a1431889ad8f558576a5382876 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Mar 2017 09:11:32 -0800 Subject: [PATCH 09/81] [XLA:HLO] Minor fix for Clamp shape inference, and add some tests. Previously Clamp(f32[5], f32[], f32[9]) returned success, but it now returns a failure. Noticed while debugging a different problem. Change: 151835981 --- .../compiler/xla/service/shape_inference.cc | 56 +++++++---- .../compiler/xla/service/shape_inference.h | 4 + .../xla/service/shape_inference_test.cc | 93 +++++++++++++++++++ 3 files changed, 133 insertions(+), 20 deletions(-) diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c05cf8c37d8..9472086e2b4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -633,26 +633,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs)); switch (operation) { case TRIOP_CLAMP: - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation")); - if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) && - (ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) { - return rhs; - } - if (ShapeUtil::Rank(rhs) == 0) { - if (ShapeUtil::Compatible(lhs, ehs)) { - return lhs; - } - return ShapeUtil::Rank(ehs) == 0 ? 
lhs : ehs;
-      }
-      return Unimplemented("not yet implemented: %s, %s %s",
-                           lhs.ShortDebugString().c_str(),
-                           ehs.ShortDebugString().c_str(),
-                           rhs.ShortDebugString().c_str());
+      return InferClampShape(lhs, rhs, ehs);
     case TRIOP_SELECT:
       return InferSelectShape(lhs, rhs, ehs);
     case TRIOP_UPDATE:
@@ -1332,6 +1313,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   return ShapeUtil::PermuteDimensions(InversePermutation(dimensions),
                                       operand);
 }
 
+// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
+// "degenerate" cases, as with binary elementwise ops.
+/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
+    const Shape& min, const Shape& operand, const Shape& max) {
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
+  if (!ShapeUtil::SameElementType(min, operand) ||
+      !ShapeUtil::SameElementType(max, operand)) {
+    return InvalidArgument("clamp op with different operand types: %s, %s, %s",
+                           ShapeUtil::HumanString(min).c_str(),
+                           ShapeUtil::HumanString(operand).c_str(),
+                           ShapeUtil::HumanString(max).c_str());
+  }
+  if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
+       (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
+    return operand;
+  }
+  if (ShapeUtil::IsScalar(operand)) {
+    if (ShapeUtil::Compatible(min, max)) {
+      return min;
+    } else if (ShapeUtil::IsScalar(min)) {
+      return max;
+    } else if (ShapeUtil::IsScalar(max)) {
+      return min;
+    }
+  }
+  return Unimplemented(
+      "not yet implemented: %s, %s %s", min.ShortDebugString().c_str(),
+      max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
+}
+
+// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
+// "degenerate" cases, as with binary elementwise ops, as well as scalar
+// broadcast from all operands, not just the predicate.
 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
     const Shape& pred, const Shape& on_true, const Shape& on_false) {
   if (!ShapeUtil::Compatible(on_true, on_false)) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index ced2f4d0017..c2223423e92 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -190,6 +190,10 @@ class ShapeInference {
       BinaryOperation operation, const Shape& lhs, const Shape& rhs,
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
 
+  // Helper for inferring the shape of Clamp ops.
+  static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
+                                         const Shape& max);
+
   // Helper for inferring the shape of Select ops.
  static StatusOr<Shape> InferSelectShape(const Shape& pred,
                                          const Shape& on_true,
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 5a1ae6b0024..6f968ded568 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -157,6 +157,99 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
       testing::ContainsRegex("pred operand must have PRED element type"));
 }
 
+TEST_F(ShapeInferenceTest, ClampAllMatrix) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
+      matrix_64_48_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampAllScalar) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMinScalar) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMaxScalar) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampOperandScalar) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMinMatrix) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
+  auto inferred_status = ShapeInference::InferTernaryOpShape(
+      TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
+  ASSERT_IS_OK(inferred_status.status());
+  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, ClampBadShapes) {
+  // Type mismatch
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
+                   .ok());
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
+                   .ok());
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
+                   .ok());
+  // Dimension mismatch
+  ASSERT_FALSE(
+      ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+                                          vector_64_, vector_32_, vector_32_)
+          .ok());
+  ASSERT_FALSE(
+      ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+                                          vector_32_,
vector_64_, vector_32_)
          .ok());
+  ASSERT_FALSE(
+      ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
+                                          vector_32_, vector_32_, vector_64_)
+          .ok());
+  // Dimension mismatch, where one operand is a scalar
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
+                   .ok());
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
+                   .ok());
+  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
+                   TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
+                   .ok());
+}
+
 TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
   StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
       VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});

From 47cd4cd100800482b57d1b7755dfdcbc04969ffe Mon Sep 17 00:00:00 2001
From: Geoffrey Irving
Date: Fri, 31 Mar 2017 09:16:25 -0800
Subject: [PATCH 10/81] Use a simpler threshold in MultiplyWithoutOverflow

We only need to worry about overflowing uint64, not int64.

Change: 151836590
---
 tensorflow/core/util/overflow.h       | 5 ++---
 tensorflow/core/util/overflow_test.cc | 8 ++++++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/tensorflow/core/util/overflow.h b/tensorflow/core/util/overflow.h
index 13b1f305aa7..04be68a111e 100644
--- a/tensorflow/core/util/overflow.h
+++ b/tensorflow/core/util/overflow.h
@@ -31,9 +31,8 @@ inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) {
   const uint64 uy = y;
   const uint64 uxy = ux * uy;
 
-  // Check for overflow, using a cheap check if both inputs are small
-  static const uint64 kSqrtInt64Max = 3037000500;  // ceil(sqrt(2**63 - 1))
-  if (TF_PREDICT_FALSE(ux >= kSqrtInt64Max || uy >= kSqrtInt64Max)) {
+  // Check if we overflow uint64, using a cheap check if both inputs are small
+  if (TF_PREDICT_FALSE((ux | uy) >> 32 != 0)) {
     // Ensure nonnegativity.  Note that negative numbers will appear "large"
     // to the unsigned comparisons above.
     CHECK(x >= 0 && y >= 0);
diff --git a/tensorflow/core/util/overflow_test.cc b/tensorflow/core/util/overflow_test.cc
index 627f77164e9..f93ba885e6d 100644
--- a/tensorflow/core/util/overflow_test.cc
+++ b/tensorflow/core/util/overflow_test.cc
@@ -30,8 +30,12 @@ TEST(OverflowTest, Nonnegative) {
     interesting.push_back(bit + 1);
     interesting.push_back(bit - 1);
   }
-  auto mid = static_cast<int64>(std::pow(2, 63.0 / 2));
-  for (int i = -5; i < 5; i++) interesting.push_back(mid + i);
+  for (const int64 mid : {static_cast<int64>(1) << 32,
+                          static_cast<int64>(std::pow(2, 63.0 / 2))}) {
+    for (int i = -5; i < 5; i++) {
+      interesting.push_back(mid + i);
+    }
+  }
 
   // Check all pairs
   for (auto x : interesting) {

From 93e822ea4f16ec33110ef4d2bb24d2f9aa2e9eaa Mon Sep 17 00:00:00 2001
From: Zakaria Haque
Date: Fri, 31 Mar 2017 09:29:14 -0800
Subject: [PATCH 11/81] Fixes a bug where heads/pre-canned estimators were not
 exporting a proper classes tensor.

Servo expects classes to be a string tensor with the same shape as scores,
containing the labels for the corresponding scores. While creating output
alternatives, if the classes tensor does not satisfy these conditions, we
create a new tensor that does.
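For example, for a three-class head evaluated on a batch of two examples, a
probabilities tensor of shape [2, 3] is exported together with a string
classes tensor [["0", "1", "2"], ["0", "1", "2"]] (or the configured label
keys, when provided), as exercised by the new tests below.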
Change: 151838225 --- .../learn/python/learn/estimators/head.py | 95 ++++++++++++++----- .../python/learn/estimators/head_test.py | 60 ++++++++++-- 2 files changed, 125 insertions(+), 30 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 028a13ca20a..65f5b49b0e4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.summary import summary @@ -816,7 +817,8 @@ class _BinaryLogisticHead(_SingleHead): loss_fn=self._loss_fn, logits_to_predictions_fn=self._logits_to_predictions, metrics_fn=self._metrics, - create_output_alternatives_fn=self._create_output_alternatives, + create_output_alternatives_fn=_classification_output_alternatives( + self.head_name, self._problem_type), labels=labels, train_op_fn=train_op_fn, logits=logits, @@ -1009,7 +1011,8 @@ class _MultiClassHead(_SingleHead): loss_fn=self._wrapped_loss_fn, logits_to_predictions_fn=self._logits_to_predictions, metrics_fn=self._metrics, - create_output_alternatives_fn=self._create_output_alternatives, + create_output_alternatives_fn=_classification_output_alternatives( + self.head_name, self._problem_type, self._label_keys), labels=labels, train_op_fn=train_op_fn, logits=logits, @@ -1113,25 +1116,6 @@ class _MultiClassHead(_SingleHead): return metrics - def _create_output_alternatives(self, predictions): - """See superclass.""" - probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES] - batch_size = array_ops.shape(probabilities)[0] - if self._label_keys: - classes = array_ops.tile( - input=array_ops.expand_dims(input=self._label_keys, axis=0), - multiples=[batch_size, 1]) - else: - classes = array_ops.tile( - input=array_ops.expand_dims( - input=math_ops.range(self.logits_dimension), axis=0), - multiples=[batch_size, 1]) - predictions_for_serving = { - prediction_key.PredictionKey.CLASSES: classes, - prediction_key.PredictionKey.PROBABILITIES: probabilities, - } - return {self._head_name: (self._problem_type, predictions_for_serving)} - def _to_labels_tensor(labels, label_name): """Returns label as a tensor. @@ -1226,6 +1210,7 @@ class _BinarySvmHead(_SingleHead): loss_fn=self._loss_fn, logits_to_predictions_fn=self._logits_to_predictions, metrics_fn=self._metrics, + # TODO(zakaria): Handle labels for export. 
create_output_alternatives_fn=self._create_output_alternatives, labels=labels, train_op_fn=train_op_fn, @@ -1325,7 +1310,8 @@ class _MultiLabelHead(_SingleHead): loss_fn=self._loss_fn, logits_to_predictions_fn=self._logits_to_predictions, metrics_fn=self._metrics, - create_output_alternatives_fn=self._create_output_alternatives, + create_output_alternatives_fn=_classification_output_alternatives( + self.head_name, self._problem_type), labels=labels, train_op_fn=train_op_fn, logits=logits, @@ -1901,6 +1887,71 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold): return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) +def _classification_output_alternatives(head_name, problem_type, + label_keys=None): + """Creates a func to generate output alternatives for classification. + + Servo expects classes to be a string tensor, and have the same dimensions + as the probabilities tensor. It should contain the labels of the corresponding + entries in probabilities. This function creates a new classes tensor that + satisfies these conditions and can be exported. + + Args: + head_name: Name of the head. + problem_type: `ProblemType` + label_keys: Optional label keys + + Returns: + A function to generate output alternatives. + """ + def _create_output_alternatives(predictions): + """Creates output alternative for the Head. + + Args: + predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a + symbolic name for an output Tensor possibly but not necessarily taken + from `PredictionKey`, and 'Tensor' is the corresponding output Tensor + itself. + + Returns: + `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where + 'submodel_name' is a submodel identifier that should be consistent across + the pipeline (here likely taken from the head_name), + 'problem_type' is a `ProblemType`, + 'tensor_name' is a symbolic name for an output Tensor possibly but not + necessarily taken from `PredictionKey`, and + 'Tensor' is the corresponding output Tensor itself. + + Raises: + ValueError: if predictions does not have PredictionKey.PROBABILITIES key. 
+ """ + probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES) + if probabilities is None: + raise ValueError("%s missing in predictions" % + prediction_key.PredictionKey.PROBABILITIES) + + with ops.name_scope(None, "_classification_output_alternatives", + (probabilities,)): + batch_size = array_ops.shape(probabilities)[0] + if label_keys: + classes = array_ops.tile( + input=array_ops.expand_dims(input=label_keys, axis=0), + multiples=[batch_size, 1], + name="classes_tensor") + else: + n = array_ops.shape(probabilities)[1] + classes = array_ops.tile( + input=array_ops.expand_dims(input=math_ops.range(n), axis=0), + multiples=[batch_size, 1]) + classes = string_ops.as_string(classes, name="classes_tensor") + + exported_predictions = { + prediction_key.PredictionKey.PROBABILITIES: probabilities, + prediction_key.PredictionKey.CLASSES: classes} + return {head_name: (problem_type, exported_predictions)} + + return _create_output_alternatives + # Aliases # TODO(zakaria): Remove these aliases, See b/34751732 _regression_head = regression_head diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index ecc1d9ff9e1..a5f3b4b2703 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -417,7 +417,7 @@ class MultiLabelHeadTest(test.TestCase): {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, logits_input=((0., 0.),), logits=self._logits) - def testMultiLabelEvalMode(self): + def testMultiLabelEval(self): n_classes = 3 head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) @@ -433,7 +433,7 @@ class MultiLabelHeadTest(test.TestCase): _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) - def testMultiClassEvalModeWithLargeLogits(self): + def testMultiClassEvalWithLargeLogits(self): n_classes = 3 head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) @@ -472,6 +472,36 @@ class MultiLabelHeadTest(test.TestCase): _assert_metrics(self, expected_loss, expected_eval_metrics, model_fn_ops) + def testMultiLabelInfer(self): + n_classes = 3 + head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name") + with ops.Graph().as_default(), session.Session(): + model_fn_ops = head.create_model_fn_ops( + {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn, + logits=((1., 0., 0.), (0., 0., 1))) + self.assertIsNone(model_fn_ops.train_op) + _assert_no_variables(self) + with session.Session(): + self.assertListEqual( + [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0]) + self.assertItemsEqual( + ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertEqual( + constants.ProblemType.CLASSIFICATION, + model_fn_ops.output_alternatives["head_name"][0]) + + predictions_for_serving = ( + model_fn_ops.output_alternatives["head_name"][1]) + self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertAllEqual( + [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], + predictions_for_serving["classes"].eval()) + self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertAllClose( + [[0.731059, 0.5, 0.5], + [0.5, 0.5, 0.731059,]], + predictions_for_serving["probabilities"].eval()) + def testMultiLabelWithLabelName(self): n_classes = 3 label_name = "my_label" @@ -691,7 +721,7 @@ class 
BinaryClassificationHeadTest(test.TestCase): {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, logits_input=((0., 0.), (0., 0.)), logits=self._logits) - def testBinaryClassificationEvalMode(self): + def testBinaryClassificationEval(self): n_classes = 2 head = head_lib.multi_class_head(n_classes=n_classes) with ops.Graph().as_default(), session.Session(): @@ -708,18 +738,32 @@ class BinaryClassificationHeadTest(test.TestCase): _assert_metrics(self, expected_loss, self._expected_eval_metrics(expected_loss), model_fn_ops) - def testBinaryClassificationInferMode(self): + def testBinaryClassificationInfer(self): n_classes = 2 - head = head_lib.multi_class_head(n_classes=n_classes) + head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name") with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) model_fn_ops = head.create_model_fn_ops( {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn, logits=self._logits) - self._assert_output_alternatives(model_fn_ops) self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) + with session.Session(): + self.assertListEqual( + [1, 1], list(model_fn_ops.predictions["classes"].eval())) + self.assertItemsEqual( + ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertEqual( + constants.ProblemType.LOGISTIC_REGRESSION, + model_fn_ops.output_alternatives["head_name"][0]) + predictions_for_serving = ( + model_fn_ops.output_alternatives["head_name"][1]) + self.assertIn("classes", six.iterkeys(predictions_for_serving)) + predicted_classes = predictions_for_serving["classes"].eval().tolist() + self.assertListEqual( + [b"0", b"1"], predicted_classes[0]) + self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) def testBinaryClassificationInferMode_withWightColumn(self): n_classes = 2 @@ -1006,7 +1050,7 @@ class MultiClassHeadTest(test.TestCase): "multi_class_head/centered_bias/bias_1", "multi_class_head/centered_bias/bias_2"]) - def testMultiClassEvalMode(self): + def testMultiClassEval(self): n_classes = 3 head = head_lib.multi_class_head( n_classes=n_classes, metric_class_ids=range(n_classes)) @@ -1131,7 +1175,7 @@ class MultiClassHeadTest(test.TestCase): model_fn_ops.output_alternatives["head_name"][1]) self.assertIn("classes", six.iterkeys(predictions_for_serving)) self.assertAllEqual( - [[0, 1, 2], [0, 1, 2]], + [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], predictions_for_serving["classes"].eval()) self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) self.assertAllClose( From 0fe523ac05f1f49145f4b243953a7aad331ea4dc Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Fri, 31 Mar 2017 09:31:36 -0800 Subject: [PATCH 12/81] tfdbg doc: fix code blocks under numbered bullets Change: 151838508 --- tensorflow/docs_src/programmers_guide/debugger.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index 7ecddc548fe..6f442e6e0c4 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -418,8 +418,9 @@ that the graph contains, this kind of disk space issue can happen. There are three possible workarounds or solutions: 1. 
The constructors of `LocalCLIDebugWrapperSession` and `LocalCLIDebugHook` - provide a keyword argument, `dump_root`, with which you can specify the path + provide a keyword argument, `dump_root`, with which you can specify the path to which **tfdbg** dumps the debug data. For example: + ``` python # For LocalCLIDebugWrapperSession sess = tf_debug.LocalCLIDebugWrapperSession(dump_root="/with/lots/of/space") @@ -432,6 +433,7 @@ There are three possible workarounds or solutions: 2. Reduce the batch size used during the runs. 3. Use the filtering options of **tfdbg**'s `run` command to watch only specific nodes in the graph. For example: + ``` tfdbg> run --node_name_filter .*hidden.* tfdbg> run --op_type_filter Variable.* From 997ffb515f33911bde527465ab886a93a3cf9e67 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Mar 2017 09:42:41 -0800 Subject: [PATCH 13/81] Add default saver option to CheckpointSaverHook and improve docstrings. Change: 151839983 --- .../training/basic_session_run_hooks.py | 112 +++++++++--------- .../training/basic_session_run_hooks_test.py | 92 ++++++++++++++ .../python/training/monitored_session.py | 7 +- tensorflow/python/training/saver.py | 27 +++++ 4 files changed, 175 insertions(+), 63 deletions(-) diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index f13b87dfed6..6fd20ce8013 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -40,6 +40,7 @@ from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.training.session_run_hook import SessionRunArgs @@ -124,7 +125,7 @@ class LoggingTensorHook(session_run_hook.SessionRunHook): def __init__(self, tensors, every_n_iter=None, every_n_secs=None, formatter=None): - """Initializes a LoggingHook monitor. + """Initializes a `LoggingTensorHook`. Args: tensors: `dict` that maps string-valued tags to tensors/tensor names, @@ -189,10 +190,10 @@ class LoggingTensorHook(session_run_hook.SessionRunHook): class StopAtStepHook(session_run_hook.SessionRunHook): - """Monitor to request stop at a specified step.""" + """Hook that requests stop at a specified step.""" def __init__(self, num_steps=None, last_step=None): - """Create a StopAtStep Hook. + """Initializes a `StopAtStepHook`. This hook requests stop after either a number of steps have been executed or a last step has been reached. Only one of the two options can be @@ -234,51 +235,48 @@ class StopAtStepHook(session_run_hook.SessionRunHook): class CheckpointSaverListener(object): - """An interface for event hooks that depend on a checkpoint. + """Interface for listeners that take action before or after checkpoint save. - CheckpointSaverListeners are similar to SessionRunHooks, and can be useful to - track training, report progress, and more. The distinction is that - CheckpointSaverListeners run only in steps when CheckpointSaverHook is - triggered, and provide callbacks to run before or after the checkpoint is - generated. This is in contrast to SessionRunHooks, which may run in steps - when no checkpoint is written, and which have no guaranteed execution order - in any case. 
CheckpointSaverListeners use the observer pattern and notify at
-  the following points:
-  - when a session starts being used
+  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is
+  triggered, and provides callbacks at the following points:
+   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
-  - when the session closed
+   - at the end of the session

-  Custom CheckpointSaverListeners look like this:
-  class ExampleCheckpointSaverListerner(CheckpointSaverListener):
-    def begin(self):
-      # You can add ops to the graph here.
-      print('Starting the session.')
-      self.your_tensor = ...
+  To use a listener, implement a class and pass the listener to a
+  `CheckpointSaverHook`, as in this example:

-    def before_save(self, session, global_step_value):
-      print('About to write a checkpoint')
+  ```python
+  class ExampleCheckpointSaverListener(CheckpointSaverListener):
+    def begin(self):
+      # You can add ops to the graph here.
+      print('Starting the session.')
+      self.your_tensor = ...

-    def after_save(self, session, global_step_value):
-      print('Done writing checkpoint.')
+    def before_save(self, session, global_step_value):
+      print('About to write a checkpoint')

-    def end(self, session, global_step_value):
-      print('Done with the session.')
+    def after_save(self, session, global_step_value):
+      print('Done writing checkpoint.')

-  A CheckpointSaverListener may simply take some action after every checkpoint.
-  It is also possible for the listener to use its own schedule to act less
-  frequently, based on wall clock time or on global_step_value. In this case,
-  implementors must be careful about what happens at end(). When end is called,
-  The CheckpointSaverHook will have already triggered after_save() in the same
-  global_step, but the listener may or may not have actually acted on it.
-  The listener may want to be sure to act at end() if there is a fresh
-  checkpoint available, but should not act twice if after_save() already handled
-  it. In this case, end() should have logic to detect the situation and do the
-  right thing, similar to what CheckpointSaverHook.end() does using
-  self._timer.last_triggered_step().
+    def end(self, session, global_step_value):
+      print('Done with the session.')

-  To use such listeners, in your `model_fn` return a `CheckpointSaverHook` as
-  part of `training_chief_hooks`.
+    ...
+  listener = ExampleCheckpointSaverListener()
+  saver_hook = tf.train.CheckpointSaverHook(
+      checkpoint_dir, listeners=[listener])
+  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
+    ...
+  ```
+
+  A `CheckpointSaverListener` may simply take some action after every
+  checkpoint save. It is also possible for the listener to use its own schedule
+  to act less frequently, e.g. based on global_step_value. In this case,
+  implementors should implement the `end()` method to handle actions related to
+  the last checkpoint save. But the listener should not act twice if
+  `after_save()` already handled this last checkpoint save.
   """

   def begin(self):
@@ -305,7 +303,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
                checkpoint_basename="model.ckpt",
                scaffold=None,
                listeners=None):
-    """Initialize CheckpointSaverHook monitor.
+    """Initializes a `CheckpointSaverHook`.

     Args:
       checkpoint_dir: `str`, base directory for the checkpoint files.
@@ -315,18 +313,18 @@
       checkpoint_basename: `str`, base name for the checkpoint files.
scaffold: `Scaffold`, use to get saver object. listeners: List of `CheckpointSaverListener` subclass instances. - Used for callbacks that run immediately after the corresponding - CheckpointSaverHook callbacks, only in steps where the - CheckpointSaverHook was triggered. + Used for callbacks that run immediately before or after this hook saves + the checkpoint. Raises: ValueError: One of `save_steps` or `save_secs` should be set. ValueError: Exactly one of saver or scaffold should be set. """ logging.info("Create CheckpointSaverHook.") - if ((saver is None and scaffold is None) or - (saver is not None and scaffold is not None)): - raise ValueError("Exactly one of saver or scaffold must be provided.") + if saver is not None and scaffold is not None: + raise ValueError("You cannot provide both saver and scaffold.") + if saver is None and scaffold is None: + saver = saver_lib._get_saver_or_default() # pylint: disable=protected-access self._saver = saver self._checkpoint_dir = checkpoint_dir self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) @@ -401,7 +399,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): class StepCounterHook(session_run_hook.SessionRunHook): - """Steps per second monitor.""" + """Hook that counts steps per second.""" def __init__(self, every_n_steps=100, @@ -453,14 +451,13 @@ class NanLossDuringTrainingError(RuntimeError): class NanTensorHook(session_run_hook.SessionRunHook): - """NaN Loss monitor. + """Monitors the loss tensor and stops training if loss is NaN. - Monitors loss and stops training if loss is NaN. Can either fail with exception or just stop training. """ def __init__(self, loss_tensor, fail_on_nan_loss=True): - """Initializes NanLoss monitor. + """Initializes a `NanTensorHook`. Args: loss_tensor: `Tensor`, the loss tensor. @@ -494,7 +491,7 @@ class SummarySaverHook(session_run_hook.SessionRunHook): summary_writer=None, scaffold=None, summary_op=None): - """Initializes a `SummarySaver` monitor. + """Initializes a `SummarySaverHook`. Args: save_steps: `int`, save summaries every N steps. Exactly one of @@ -590,7 +587,7 @@ class SummarySaverHook(session_run_hook.SessionRunHook): class GlobalStepWaiterHook(session_run_hook.SessionRunHook): - """Delay execution until global step reaches to wait_until_step. + """Delays execution until global step reaches `wait_until_step`. This hook delays execution until global step reaches to `wait_until_step`. It is used to gradually start workers in distributed settings. One example usage @@ -599,7 +596,7 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook): """ def __init__(self, wait_until_step): - """Create a _GlobalStepWaiterHook. + """Initializes a `GlobalStepWaiterHook`. Args: wait_until_step: an `int` shows until which global step should we wait. @@ -637,10 +634,10 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook): class FinalOpsHook(session_run_hook.SessionRunHook): - """A run hook which evaluates `Tensors` at the end of a session.""" + """A hook which evaluates `Tensors` at the end of a session.""" def __init__(self, final_ops, final_ops_feed_dict=None): - """Constructs the FinalOpHook with ops to run at the end of the session. + """Initializes `FinalOpHook` with ops to run at the end of the session. 
Args: final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of @@ -666,10 +663,11 @@ class FeedFnHook(session_run_hook.SessionRunHook): """Runs `feed_fn` and sets the `feed_dict` accordingly.""" def __init__(self, feed_fn): - """Constructs the FeedFnHook with given `feed_fn`. + """Initializes a `FeedFnHook`. Args: - feed_fn: function, no arguments and returns `dict` to feed. + feed_fn: function that takes no arguments and returns `dict` of `Tensor` + to feed. """ self.feed_fn = feed_fn diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index c2636d46f59..ecb61d447bf 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -346,6 +346,98 @@ class CheckpointSaverHookTest(test.TestCase): 'end': 1 }, listener.get_counts()) + def test_listener_with_monitored_session(self): + with ops.Graph().as_default(): + scaffold = monitored_session.Scaffold() + global_step = variables.get_or_create_global_step() + train_op = state_ops.assign_add(global_step, 1) + listener = MockCheckpointSaverListener() + hook = basic_session_run_hooks.CheckpointSaverHook( + self.model_dir, + save_steps=1, + scaffold=scaffold, + listeners=[listener]) + with monitored_session.SingularMonitoredSession( + hooks=[hook], + scaffold=scaffold, + checkpoint_dir=self.model_dir) as sess: + sess.run(train_op) + sess.run(train_op) + global_step_val = sess.run(global_step) + listener_counts = listener.get_counts() + self.assertEqual(2, global_step_val) + self.assertEqual({ + 'begin': 1, + 'before_save': 2, + 'after_save': 2, + 'end': 1 + }, listener_counts) + + def test_listener_with_default_saver(self): + with ops.Graph().as_default(): + global_step = variables.get_or_create_global_step() + train_op = state_ops.assign_add(global_step, 1) + listener = MockCheckpointSaverListener() + hook = basic_session_run_hooks.CheckpointSaverHook( + self.model_dir, + save_steps=1, + listeners=[listener]) + with monitored_session.SingularMonitoredSession( + hooks=[hook], + checkpoint_dir=self.model_dir) as sess: + sess.run(train_op) + sess.run(train_op) + global_step_val = sess.run(global_step) + listener_counts = listener.get_counts() + self.assertEqual(2, global_step_val) + self.assertEqual({ + 'begin': 1, + 'before_save': 2, + 'after_save': 2, + 'end': 1 + }, listener_counts) + + with ops.Graph().as_default(): + global_step = variables.get_or_create_global_step() + with monitored_session.SingularMonitoredSession( + checkpoint_dir=self.model_dir) as sess2: + global_step_saved_val = sess2.run(global_step) + self.assertEqual(2, global_step_saved_val) + + def test_two_listeners_with_default_saver(self): + with ops.Graph().as_default(): + global_step = variables.get_or_create_global_step() + train_op = state_ops.assign_add(global_step, 1) + listener1 = MockCheckpointSaverListener() + listener2 = MockCheckpointSaverListener() + hook = basic_session_run_hooks.CheckpointSaverHook( + self.model_dir, + save_steps=1, + listeners=[listener1, listener2]) + with monitored_session.SingularMonitoredSession( + hooks=[hook], + checkpoint_dir=self.model_dir) as sess: + sess.run(train_op) + sess.run(train_op) + global_step_val = sess.run(global_step) + listener1_counts = listener1.get_counts() + listener2_counts = listener2.get_counts() + self.assertEqual(2, global_step_val) + self.assertEqual({ + 'begin': 1, + 'before_save': 2, + 'after_save': 2, + 'end': 1 + }, listener1_counts) + 
self.assertEqual(listener1_counts, listener2_counts) + + with ops.Graph().as_default(): + global_step = variables.get_or_create_global_step() + with monitored_session.SingularMonitoredSession( + checkpoint_dir=self.model_dir) as sess2: + global_step_saved_val = sess2.run(global_step) + self.assertEqual(2, global_step_saved_val) + @test.mock.patch('time.time') def test_save_secs_saves_periodically(self, mock_time): # Let's have a realistic start time diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index ae76a1ab580..cf8692eda13 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -22,7 +22,6 @@ from __future__ import print_function import abc from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -180,11 +179,7 @@ class Scaffold(object): summary.merge_all) # pylint: disable=g-long-lambda if self._saver is None: - self._saver = Scaffold.get_or_default( - 'saver', - ops.GraphKeys.SAVERS, - lambda: training_saver.Saver(sharded=True, allow_empty=True, - write_version=saver_pb2.SaverDef.V2)) + self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access # pylint: enable=g-long-lambda self._saver.build() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index ae5fc54d854..43b61742467 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -712,6 +712,33 @@ class BaseSaverBuilder(object): version=self._write_version) +def _get_saver_or_default(): + """Returns the saver from SAVERS collection, or creates a default one. + + This method is used by other members of the training module, such as + `Scaffold`, or `CheckpointSaverHook`. + + Returns: + `Saver`. + + Raises: + RuntimeError: If the SAVERS collection already has more than one items. + """ + collection_key = ops.GraphKeys.SAVERS + savers = ops.get_collection(collection_key) + if savers: + if len(savers) > 1: + raise RuntimeError( + "More than one item in collection {}. " + "Please indicate which one to use by passing it to the constructor.". + format(collection_key)) + return savers[0] + saver = Saver(sharded=True, allow_empty=True) + if saver is not None: + ops.add_to_collection(collection_key, saver) + return saver + + def _GetCheckpointFilename(save_dir, latest_filename): """Returns a filename for storing the CheckpointState. From ac933c956a01187b626a4095ea00822364cfb6c2 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Fri, 31 Mar 2017 10:05:39 -0800 Subject: [PATCH 14/81] Fix issue in installing latest nightly tensorflow pip wheel in ubuntu:16.04-based gcs_test Dockerfile Change: 151843345 --- tensorflow/tools/gcs_test/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/gcs_test/Dockerfile b/tensorflow/tools/gcs_test/Dockerfile index 1df00692725..581bded65da 100644 --- a/tensorflow/tools/gcs_test/Dockerfile +++ b/tensorflow/tools/gcs_test/Dockerfile @@ -3,7 +3,7 @@ FROM ubuntu:16.04 MAINTAINER Shanqing Cai RUN apt-get update -RUN apt-get install -y --no-install-recommends \ +RUN apt-get install -y \ curl \ libcurl4-openssl-dev \ python \ From bf02d3cfefdf92dc2929a6e257bc31800bf02a60 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 31 Mar 2017 10:51:11 -0800 Subject: [PATCH 15/81] Add an optional activation function to the OutputProjectionWrapper and InputProjectionWrapper. Change: 151849737 --- .../contrib/rnn/python/ops/core_rnn_cell_impl.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index f44302638eb..d828a337f31 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -466,12 +466,13 @@ class OutputProjectionWrapper(RNNCell): if needed or directly feed into a softmax. """ - def __init__(self, cell, output_size, reuse=None): + def __init__(self, cell, output_size, activation=None, reuse=None): """Create a cell with output projection. Args: cell: an RNNCell, a projection to output_size is added to it. output_size: integer, the size of the output after projection. + activation: (optional) an optional activation function. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. @@ -487,6 +488,7 @@ class OutputProjectionWrapper(RNNCell): self._cell = cell self._output_size = output_size self._reuse = reuse + self._activation = activation @property def state_size(self): @@ -507,6 +509,8 @@ class OutputProjectionWrapper(RNNCell): with _checked_scope(self, scope or "output_projection_wrapper", reuse=self._reuse): projected = _linear(output, self._output_size, True) + if self._activation: + projected = self._activation(projected) return projected, res_state @@ -518,12 +522,13 @@ class InputProjectionWrapper(RNNCell): do the projection on this batch-concatenated sequence, then split it. """ - def __init__(self, cell, num_proj, input_size=None): + def __init__(self, cell, num_proj, activation=None, input_size=None): """Create a cell with input projection. Args: cell: an RNNCell, a projection of inputs is added before it. num_proj: Python integer. The dimension to project to. + activation: (optional) an optional activation function. input_size: Deprecated and unused. Raises: @@ -535,6 +540,7 @@ class InputProjectionWrapper(RNNCell): raise TypeError("The parameter cell is not RNNCell.") self._cell = cell self._num_proj = num_proj + self._activation = activation @property def state_size(self): @@ -553,6 +559,8 @@ class InputProjectionWrapper(RNNCell): # Default scope: "InputProjectionWrapper" with vs.variable_scope(scope or "input_projection_wrapper"): projected = _linear(inputs, self._num_proj, True) + if self._activation: + projected = self._activation(projected) return self._cell(projected, state) From 89ed6b49af088fae6c2185f145ca1e2748642a49 Mon Sep 17 00:00:00 2001 From: Zakaria Haque Date: Fri, 31 Mar 2017 10:54:28 -0800 Subject: [PATCH 16/81] Adds area under precision recall curve for binary and multiclass heads. 
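As a hedged illustration of what the new metric keys report, computed with the
public contrib metrics API directly rather than the head internals: the
existing `streaming_auc` op already accepts a `curve` argument, so AUC-PR is
just the same op with `curve="PR"`.

```python
import tensorflow as tf

labels = tf.constant([True, False, True, True])
predictions = tf.constant([0.9, 0.6, 0.4, 0.8])
# Same streaming op as the ROC AUC metric, but integrating precision
# against recall instead of true-positive rate against false-positive rate.
auc_pr, update_op = tf.contrib.metrics.streaming_auc(
    predictions, labels, curve="PR")

with tf.Session() as sess:
  # Streaming metrics accumulate state in local variables.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(auc_pr))  # area under the precision-recall curve
```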
Change: 151850197 --- .../contrib/learn/python/learn/estimators/head.py | 13 +++++++++++-- .../learn/python/learn/estimators/head_test.py | 7 ++++++- .../learn/python/learn/estimators/metric_key.py | 2 ++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 65f5b49b0e4..14f7666b3df 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -887,6 +887,8 @@ class _BinaryLogisticHead(_SingleHead): _indicator_labels_streaming_mean(labels, weights)) metrics[_summary_key(self.head_name, mkey.AUC)] = ( _streaming_auc(logistic, labels, weights)) + metrics[_summary_key(self.head_name, mkey.AUC_PR)] = ( + _streaming_auc(logistic, labels, weights, curve="PR")) for threshold in self._thresholds: metrics[_summary_key( @@ -1360,6 +1362,8 @@ class _MultiLabelHead(_SingleHead): metrics_lib.streaming_accuracy(classes, labels, weights)) metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc( probabilities, labels, weights) + metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc( + probabilities, labels, weights, curve="PR") for class_id in self._metric_class_ids: # TODO(ptucker): Add per-class accuracy, precision, recall. @@ -1377,6 +1381,9 @@ class _MultiLabelHead(_SingleHead): _predictions_streaming_mean(logits, weights, class_id)) metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = ( _streaming_auc(probabilities, labels, weights, class_id)) + metrics[_summary_key(self.head_name, mkey.CLASS_AUC_PR % class_id)] = ( + _streaming_auc(probabilities, labels, weights, class_id, + curve="PR")) return metrics @@ -1843,7 +1850,8 @@ def _class_labels_streaming_mean(labels, weights, class_id): weights=weights) -def _streaming_auc(predictions, labels, weights=None, class_id=None): +def _streaming_auc(predictions, labels, weights=None, class_id=None, + curve="ROC"): predictions = ops.convert_to_tensor(predictions) labels = ops.convert_to_tensor(labels) if class_id is not None: @@ -1852,7 +1860,8 @@ def _streaming_auc(predictions, labels, weights=None, class_id=None): return metrics_lib.streaming_auc( predictions, math_ops.cast(labels, dtypes.bool), - weights=_float_weights_or_none(weights)) + weights=_float_weights_or_none(weights), + curve=curve) def _assert_class_id(class_id, num_classes=None): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index a5f3b4b2703..9b8cba15263 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -297,11 +297,15 @@ class MultiLabelHeadTest(test.TestCase): def _expected_eval_metrics(self, expected_loss): return { "accuracy": 1. / 3, - "auc": 1. / 4, "loss": expected_loss, + "auc": 1. / 4, "auc/class0": 1., "auc/class1": 1., "auc/class2": 0., + "auc_precision_recall": 0.166667, + "auc_precision_recall/class0": 0, + "auc_precision_recall/class1": 0., + "auc_precision_recall/class2": 1., "labels/actual_label_mean/class0": self._labels[0][0], "labels/actual_label_mean/class1": self._labels[0][1], "labels/actual_label_mean/class2": self._labels[0][2], @@ -651,6 +655,7 @@ class BinaryClassificationHeadTest(test.TestCase): "accuracy/baseline_label_mean": label_mean, "accuracy/threshold_0.500000_mean": 1. / 2, "auc": 1. 
/ 2,
+        "auc_precision_recall": 0.749999,
         "labels/actual_label_mean": label_mean,
         "labels/prediction_mean": .731059,  # softmax
         "loss": expected_loss,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py
index 10ac888eca7..99388f116b3 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py
@@ -22,7 +22,9 @@ class MetricKey(object):
   """Metric key strings."""
   LOSS = "loss"
   AUC = "auc"
+  AUC_PR = "auc_precision_recall"
   CLASS_AUC = "auc/class%d"
+  CLASS_AUC_PR = "auc_precision_recall/class%d"
   PREDICTION_MEAN = "labels/prediction_mean"
   CLASS_PREDICTION_MEAN = "labels/prediction_mean/class%d"
   CLASS_LOGITS_MEAN = "labels/logits_mean/class%d"

From a05668a832980531e621586473204312c87b5b6e Mon Sep 17 00:00:00 2001
From: Vinu Rajashekhar
Date: Fri, 31 Mar 2017 11:25:32 -0800
Subject: [PATCH 17/81] Makes GraphRunner a class to explicitly control its
 lifetime.

Change: 151853846
---
 .../core/common_runtime/constant_folding.cc   | 13 ++++++++--
 .../core/common_runtime/graph_runner.cc       | 26 ++++++++++---------
 tensorflow/core/common_runtime/graph_runner.h | 19 +++++++++++---
 .../core/common_runtime/graph_runner_test.cc  | 14 +++++-----
 .../core/common_runtime/shape_refiner.cc      | 16 ++++++++----
 .../core/common_runtime/shape_refiner.h       |  7 +++++
 6 files changed, 66 insertions(+), 29 deletions(-)

diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 5db49aa498c..8c4085425a1 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/graph/subgraph.h"
 #include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/public/session_options.h"
@@ -304,10 +305,18 @@
     tensors_to_replace.push_back({n.second, n.first.second});
   }

+  auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
   // Evaluate the constant foldable nodes.
   std::vector<Tensor> outputs;
+  auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] {
+    // Output tensors need to be cleared before the GraphRunner is deleted.
+    outputs.clear();
+    graph_runner.reset(nullptr);
+  });
+
-  Status s = GraphRunner::Run(constant_graph.get(), function_library, env,
-                              {} /* inputs*/, tensors_to_fetch_names, &outputs);
+  Status s =
+      graph_runner->Run(constant_graph.get(), function_library, {} /* inputs*/,
+                        tensors_to_fetch_names, &outputs);
   if (!s.ok()) {
     VLOG(1) << "Could not fetch constants: " << s;
     *was_mutated = false;
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index c93ff1cdde8..d4dc8f0057e 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/memory_types.h" @@ -95,22 +94,24 @@ class SimpleRendezvous : public Rendezvous { } // namespace -// static +GraphRunner::GraphRunner(Env* env) : cpu_device_(GetCPUDevice(env)) {} + +GraphRunner::~GraphRunner() {} + Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, - Env* env, const NamedTensorList& inputs, + const NamedTensorList& inputs, const std::vector& output_names, std::vector* outputs) { + if (cpu_device_ == nullptr) { + return errors::NotFound("Cannot find a device for GraphRunner."); + } + // TODO(vrv): Instead of copying the entire graph, consider modifying // the existing graph, and then removing those removed edges. // prior to returning. std::unique_ptr graph_to_run(new Graph(graph->op_registry())); CopyGraph(*graph, graph_to_run.get()); - std::unique_ptr device = GetCPUDevice(env); - if (!device) { - return errors::NotFound("Cannot find a device for GraphRunner."); - } - SimpleRendezvous* rendez = new SimpleRendezvous; core::ScopedUnref rendez_unref(rendez); @@ -130,7 +131,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, // Call RewriteGraphForExecution TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( graph_to_run.get(), input_names, output_names, {} /* target nodes */, - device->attributes())); + cpu_device_->attributes())); // Create the local executor and the Rendezvous for fetching back the // constants. @@ -143,10 +144,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, Graph* g = graph_to_run.release(); LocalExecutorParams params; - params.device = device.get(); + // The ownership of the output tensors are bound to this device's lifetime. + params.device = cpu_device_.get(); params.function_library = function_library; - params.create_kernel = [&device, g](const NodeDef& ndef, OpKernel** kernel) { - return CreateNonCachedKernel(device.get(), nullptr, ndef, + params.create_kernel = [this, g](const NodeDef& ndef, OpKernel** kernel) { + return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef, g->versions().producer(), kernel); }; params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; diff --git a/tensorflow/core/common_runtime/graph_runner.h b/tensorflow/core/common_runtime/graph_runner.h index e078c7ffc8c..24e8b04c463 100644 --- a/tensorflow/core/common_runtime/graph_runner.h +++ b/tensorflow/core/common_runtime/graph_runner.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph.h" @@ -44,16 +45,26 @@ namespace tensorflow { // to be particularly lightweight, fast, or efficient. class GraphRunner { public: + // REQUIRES: `env` is not nullptr. + GraphRunner(Env* env); + ~GraphRunner(); + // Function semantics for `inputs`, `output_names` and `outputs` // matches those from Session::Run(). // + // NOTE: The output tensors share lifetime with the GraphRunner, and could + // be destroyed once the GraphRunner is destroyed. + // // REQUIRES: `graph`, `env`, and `outputs` are not nullptr. // `function_library` may be nullptr. 
typedef std::vector> NamedTensorList; - static Status Run(Graph* graph, FunctionLibraryRuntime* function_library, - Env* env, const NamedTensorList& inputs, - const std::vector& output_names, - std::vector* outputs); + Status Run(Graph* graph, FunctionLibraryRuntime* function_library, + const NamedTensorList& inputs, + const std::vector& output_names, + std::vector* outputs); + + private: + std::unique_ptr cpu_device_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_runner_test.cc b/tensorflow/core/common_runtime/graph_runner_test.cc index 5918ba9a22d..ccb44af0ec2 100644 --- a/tensorflow/core/common_runtime/graph_runner_test.cc +++ b/tensorflow/core/common_runtime/graph_runner_test.cc @@ -46,9 +46,9 @@ using test::internal::ExpectEqual; TEST(GraphRunnerTest, SingleConst) { Scope root = Scope::NewRootScope(); auto c = ops::Const(root, 42.0f); + GraphRunner graph_runner(Env::Default()); std::vector outputs; - Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {}, - {c.name()}, &outputs); + Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name()}, &outputs); TF_ASSERT_OK(s); ExpectEqual(42.0f, outputs[0].scalar()()); } @@ -57,9 +57,10 @@ TEST(GraphRunnerTest, MultiFetchConst) { Scope root = Scope::NewRootScope(); auto c = ops::Const(root, 42.0f); auto pi = ops::Const(root, 3.14f); + GraphRunner graph_runner(Env::Default()); std::vector outputs; - Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {}, - {c.name(), pi.name()}, &outputs); + Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name(), pi.name()}, + &outputs); TF_ASSERT_OK(s); ExpectEqual(42.0f, outputs[0].scalar()()); ExpectEqual(3.14f, outputs[1].scalar()()); @@ -78,9 +79,10 @@ TEST(GraphRunnerTest, FeedAndFetch) { std::vector> inputs = {{"p1:0", p1_data}, {"p2:0", p2_data}}; + GraphRunner graph_runner(Env::Default()); std::vector outputs; - Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), inputs, - {"add:0"}, &outputs); + Status s = + graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs); TF_ASSERT_OK(s); ExpectEqual(3.0f, outputs[0].scalar()()); } diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 2f65abde0af..f58faefa9fb 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -33,7 +32,15 @@ using shape_inference::ShapeHandle; ShapeRefiner::ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops) - : graph_def_version_(graph_def_version), ops_registry_(ops) {} + : graph_def_version_(graph_def_version), + ops_registry_(ops), + graph_runner_(Env::Default()) {} + +ShapeRefiner::~ShapeRefiner() { + // The lifetime of the tensors are bound to the GraphRunner, so the tensors + // should be deleted before it. + const_tensor_map_.clear(); +} Status ShapeRefiner::AddNode(const Node* node) { // For each 'input' of this node, fetch the corresponding shape @@ -223,9 +230,8 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node, std::vector outputs; // NOTE; we should pass in a function library runtime if we want // to support constant-expression evaluation on functions. 
-  Status s = GraphRunner::Run(&subgraph, nullptr /* function_library */,
-                              Env::Default(), const_inputs,
-                              {output_tensor_name}, &outputs);
+  Status s = graph_runner_.Run(&subgraph, nullptr /* function_library */,
+                               const_inputs, {output_tensor_name}, &outputs);

   // If all kernels in the constant graph are not registered
   // in the process, GraphRunner::Run may fail, in which case
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index b8d69fc05b8..bbde0924c7f 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -17,6 +17,7 @@ limitations under the License.

 #include

+#include "tensorflow/core/common_runtime/graph_runner.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -32,6 +33,7 @@ namespace tensorflow {
 class ShapeRefiner {
  public:
   ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
+  ~ShapeRefiner();

   // Performs validation of 'node' and runs 'node's shape function,
   // storing its shape outputs.
@@ -101,6 +103,10 @@ class ShapeRefiner {
   const int graph_def_version_;
   const OpRegistryInterface* const ops_registry_;

+  // The lifetime of the tensors is bound to the runner, so the runner should
+  // be deleted after the tensors.
+  GraphRunner graph_runner_;
+
   // Stores a map from a node to its InferenceContext.
   //
   // Owns values.
@@ -118,6 +124,7 @@
   // Only tensors less than 1KiB are currently stored in the cache.
   static constexpr int64 kMaxTensorSize = 1024;
   std::unordered_map const_tensor_map_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);
 };

From 08a3e36c97a644377c07d39d6c707d1abfb2c394 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 31 Mar 2017 11:26:42 -0800
Subject: [PATCH 18/81] Add nccl to tf.contrib.

Compile nccl on windows, now that bazel 0.4.5 is used.
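Once exported, the collective ops are reachable under `tf.contrib.nccl`. A
hypothetical usage sketch (it assumes a machine with at least two GPUs; the
device names and tensor shapes are illustrative only):

```python
import tensorflow as tf
from tensorflow.contrib import nccl

# One input tensor per GPU; all_sum returns one output per input, each
# holding the elementwise sum and placed on its input's device.
towers = []
for i in range(2):
  with tf.device('/gpu:%d' % i):
    towers.append(tf.random_normal([4]))
summed = nccl.all_sum(towers)

with tf.Session() as sess:
  # Fetching all outputs in one run() executes the collective together.
  print(sess.run(summed))  # two identical [4] vectors, one per GPU
```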
Change: 151853954 --- tensorflow/contrib/BUILD | 7 ++---- tensorflow/contrib/__init__.py | 1 + .../contrib/cmake/tf_core_kernels.cmake | 3 +++ tensorflow/contrib/cmake/tf_core_ops.cmake | 1 + tensorflow/contrib/cmake/tf_python.cmake | 7 ++++++ tensorflow/contrib/nccl/__init__.py | 22 ++++++++++++++----- .../ci_build/windows/libtensorflow_cpu.sh | 11 +--------- 7 files changed, 32 insertions(+), 20 deletions(-) diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 595d8997388..a726471d0fb 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -7,8 +7,6 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -load("//tensorflow:tensorflow.bzl", "if_not_windows") - py_library( name = "contrib_py", srcs = glob(["**/*.py"]), @@ -46,6 +44,7 @@ py_library( "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", @@ -65,9 +64,7 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", - ] + if_not_windows([ - "//tensorflow/contrib/nccl:nccl_py", - ]), + ], ) cc_library( diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index d4ddd1cf6a6..9b703cf090a 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -45,6 +45,7 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import quantization diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 33384eed480..0663ba16379 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -37,6 +37,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 4e300056295..126ef6c00c2 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -58,6 +58,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/i GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc") GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats 
"${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 02038da7f85..37bdcec0867 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -405,6 +405,11 @@ add_python_module("tensorflow/contrib/ndlstm/python") add_python_module("tensorflow/contrib/nn") add_python_module("tensorflow/contrib/nn/python") add_python_module("tensorflow/contrib/nn/python/ops") +add_python_module("tensorflow/contrib/nccl") +add_python_module("tensorflow/contrib/nccl/kernels") +add_python_module("tensorflow/contrib/nccl/ops") +add_python_module("tensorflow/contrib/nccl/python") +add_python_module("tensorflow/contrib/nccl/python/ops") add_python_module("tensorflow/contrib/opt") add_python_module("tensorflow/contrib/opt/python") add_python_module("tensorflow/contrib/opt/python/training") @@ -599,6 +604,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py) GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops" diff --git a/tensorflow/contrib/nccl/__init__.py b/tensorflow/contrib/nccl/__init__.py index 0275ed60798..d851c522c03 100644 --- a/tensorflow/contrib/nccl/__init__.py +++ b/tensorflow/contrib/nccl/__init__.py @@ -12,13 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Ops for nccl AllReduce.""" +"""Functions for using NVIDIA nccl collective ops. 
+ +@@all_max +@@all_min +@@all_prod +@@all_sum +@@broadcast + +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.nccl.python.ops.nccl_ops import * -# pylint: enable=wildcard-import +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_max +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum +from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh index b428bebc6f6..a08a3c28741 100755 --- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh @@ -31,15 +31,6 @@ if [ ! -e "WORKSPACE" ]; then exit 1 fi -#### BEGIN HACKS TO BE RESOLVED WITH NEWER BAZEL VERSIONS #### -# Disable nccl. -# This can be removed once we switch to a bazel release that includes -# https://github.com/bazelbuild/bazel/commit/8e0991cb19eadfcb651cd6987255d5f7c0a58e0a -# (the fix for https://github.com/bazelbuild/bazel/issues/2494). -# Most likley bazel 0.4.5 will contain that. -sed -i -e "s/\"@nccl_archive/#\"@nccl_archive/" ./tensorflow/contrib/nccl/BUILD -sed -i -e "s/\"@nccl_archive/#\"@nccl_archive/" ./tensorflow/tools/pip_package/BUILD - # Enable JNI support for Windows in Bazel. # This can be removed once # https://github.com/bazelbuild/bazel/pull/2599 @@ -66,7 +57,7 @@ bazel build -c opt ${BUILD_OPTS} \ tensorflow/tools/lib_package:jnilicenses_generate # Revert the hacks above -git checkout ./tensorflow/contrib/nccl/BUILD ./tensorflow/tools/pip_package/BUILD +git checkout ./tensorflow/tools/pip_package/BUILD git checkout ./tensorflow/java/src/main/native/BUILD rm -f ./tensorflow/java/src/main/native/windows_jni_md.h From 7b735b0a333a029e23feb0fbfe1197a8f670f68e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Mar 2017 11:40:47 -0800 Subject: [PATCH 19/81] Adding bazel clean prior to running tests in run_pip.sh Change: 151855524 --- tensorflow/tools/ci_build/builds/run_pip_tests.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh index 10bed0b786b..553e9652a2f 100755 --- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh +++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh @@ -86,6 +86,9 @@ BAZEL_TEST_TARGETS="//${PIP_TEST_PREFIX}/tensorflow/contrib/... \ //${PIP_TEST_PREFIX}/tensorflow/python/... \ //${PIP_TEST_PREFIX}/tensorflow/tensorboard/..." +# Clean the bazel cache +bazel clean + # Run configure again, we might be using a different python path, due to # virtualenv. export TF_NEED_GCP=0 From d922630aaf1f603af552cb468fd7fef9cec44d64 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 31 Mar 2017 12:05:11 -0800 Subject: [PATCH 20/81] Fix inequalities in Estimator.train(). 
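The checks themselves (`steps <= 0`, `max_steps <= 0`) reject zero as well as
negative values, so "> 0" is the accurate wording. A small sketch of the
validation in isolation (a standalone rewrite for illustration, not the
Estimator code itself):

```python
def _validate_train_args(steps=None, max_steps=None):
  if steps is not None and max_steps is not None:
    raise ValueError('Can not provide both steps and max_steps.')
  if steps is not None and steps <= 0:
    raise ValueError('Must specify steps > 0, given: {}'.format(steps))
  if max_steps is not None and max_steps <= 0:
    raise ValueError('Must specify max_steps > 0, given: {}'.format(max_steps))

try:
  _validate_train_args(steps=0)
except ValueError as e:
  print(e)  # Must specify steps > 0, given: 0
```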
Change: 151858259 --- tensorflow/python/estimator/estimator.py | 6 +++--- tensorflow/python/estimator/estimator_test.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 85f8b24c9f9..9891ae4eaf9 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -198,10 +198,10 @@ class Estimator(object): if (steps is not None) and (max_steps is not None): raise ValueError('Can not provide both steps and max_steps.') if steps is not None and steps <= 0: - raise ValueError('Must specify steps >= 0, given: {}'.format(steps)) + raise ValueError('Must specify steps > 0, given: {}'.format(steps)) if max_steps is not None and max_steps <= 0: raise ValueError( - 'Must specify max_steps >= 0, given: {}'.format(max_steps)) + 'Must specify max_steps > 0, given: {}'.format(max_steps)) if max_steps is not None: start_step = _load_global_step_from_checkpoint_dir(self._model_dir) @@ -256,7 +256,7 @@ class Estimator(object): hooks = _check_hooks_type(hooks) if steps is not None: if steps <= 0: - raise ValueError('Must specify steps >= 0, given: {}'.format(steps)) + raise ValueError('Must specify steps > 0, given: {}'.format(steps)) hooks.append(evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access num_evals=steps)) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 2c9636f797c..398ff20b6b7 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -239,25 +239,25 @@ class EstimatorTrainTest(test.TestCase): def test_steps0_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) - with self.assertRaisesRegexp(ValueError, 'Must specify steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'): est.train(dummy_input_fn, steps=0) def test_steps_negative_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) - with self.assertRaisesRegexp(ValueError, 'Must specify steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'): est.train(dummy_input_fn, steps=-1) def test_max_steps0_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) - with self.assertRaisesRegexp(ValueError, 'Must specify max_steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify max_steps > 0'): est.train(dummy_input_fn, max_steps=0) def test_max_steps_negative_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) - with self.assertRaisesRegexp(ValueError, 'Must specify max_steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify max_steps > 0'): est.train(dummy_input_fn, max_steps=-1) def test_scaffold_is_used(self): @@ -475,14 +475,14 @@ class EstimatorEvaluateTest(test.TestCase): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) est.train(dummy_input_fn, steps=5) - with self.assertRaisesRegexp(ValueError, 'Must specify steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'): est.evaluate(dummy_input_fn, steps=0) def test_steps_negative_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) est.train(dummy_input_fn, steps=5) - with self.assertRaisesRegexp(ValueError, 'Must specify steps >= 0'): + with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'): 
est.evaluate(dummy_input_fn, steps=-1) def test_global_step_metric_raises_error(self): From 8eae27dc1d27f6d03cec1245650bb284167dd0f6 Mon Sep 17 00:00:00 2001 From: Martin Wicke Date: Fri, 31 Mar 2017 12:26:16 -0800 Subject: [PATCH 21/81] Remove old doc generator. Change: 151860614 --- tensorflow/python/BUILD | 33 - tensorflow/python/framework/docs.py | 647 ----- .../python/framework/gen_docs_combined.py | 332 --- tensorflow/tools/docs/BUILD | 39 - tensorflow/tools/docs/gen_cc_md.py | 314 --- tensorflow/tools/docs/gen_docs.sh | 50 - tensorflow/tools/docs/gen_docs_test.sh | 41 - tensorflow/tools/docs/tf-doxy_for_md-config | 2280 ----------------- 8 files changed, 3736 deletions(-) delete mode 100644 tensorflow/python/framework/docs.py delete mode 100644 tensorflow/python/framework/gen_docs_combined.py delete mode 100644 tensorflow/tools/docs/gen_cc_md.py delete mode 100755 tensorflow/tools/docs/gen_docs.sh delete mode 100755 tensorflow/tools/docs/gen_docs_test.sh delete mode 100644 tensorflow/tools/docs/tf-doxy_for_md-config diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 9db763ce78f..4f7d2590456 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3307,39 +3307,6 @@ py_test( ], ) -py_library( - name = "docs", - srcs = ["framework/docs.py"], - srcs_version = "PY2AND3", -) - -py_library( - name = "gen_docs_combined_lib", - srcs = ["framework/gen_docs_combined.py"], - srcs_version = "PY2AND3", - deps = [ - ":docs", - "//tensorflow:tensorflow_py", - "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", - "//tensorflow/python/debug:debug_py", - ], -) - -py_binary( - name = "gen_docs_combined", - srcs = ["framework/gen_docs_combined.py"], - main = "framework/gen_docs_combined.py", - srcs_version = "PY2AND3", - deps = [ - ":client", - ":docs", - ":framework", - ":framework_for_generated_wrappers", - "//tensorflow:tensorflow_py", - "//tensorflow/python/debug:debug_py", - ], -) - # ----------------------------------------------------------------------------- # Quantization diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py deleted file mode 100644 index 4ae0046117b..00000000000 --- a/tensorflow/python/framework/docs.py +++ /dev/null @@ -1,647 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Updates generated docs from Python doc comments. - -Updates the documentation files. 
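-
-For illustration: this module is driven by gen_docs_combined.py, which
-builds the Library objects defined below and then calls write_libraries().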
-""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import inspect -import os -import re - - -_arg_re = re.compile(" *([*]{0,2}[a-zA-Z][a-zA-Z0-9_]*):") -_section_re = re.compile("([A-Z][a-zA-Z ]*):$") -_always_drop_symbol_re = re.compile("_[_a-zA-Z0-9]") -_anchor_re = re.compile(r"^[\w.]+$") -_member_mark = "@@" -_indiv_dir = "functions_and_classes" -_num_subdirs = 10 -_subdir_prefix = "shard" - - -class Document(object): - """Base class for an automatically generated document.""" - - def write_markdown_to_file(self, f): - """Writes a Markdown-formatted version of this document to file `f`. - - Args: - f: The output file. - """ - raise NotImplementedError("Document.WriteToFile") - - -class Index(Document): - """An automatically generated index for a collection of documents.""" - - def __init__(self, module_to_name, members, filename_to_library_map, - path_prefix): - """Creates a new Index. - - Args: - module_to_name: Dictionary mapping modules to short names. - members: Dictionary mapping member name to (fullname, member). - filename_to_library_map: A list of (filename, Library) pairs. The order - corresponds to the order in which the libraries appear in the index. - path_prefix: Prefix to add to links in the index. - """ - self._module_to_name = module_to_name - self._members = members - self._filename_to_library_map = filename_to_library_map - self._path_prefix = path_prefix - - def write_markdown_to_file(self, f): - """Writes this index to file `f`. - - The output is formatted as an unordered list. Each list element - contains the title of the library, followed by a list of symbols - in that library hyperlinked to the corresponding anchor in that - library. - - Args: - f: The output file. - """ - print("", file=f) - print("", file=f) - print("# TensorFlow Python reference documentation", file=f) - print("", file=f) - fullname_f = lambda name: self._members[name][0] - anchor_f = lambda name: get_anchor(self._module_to_name, fullname_f(name)) - - for filename, library in self._filename_to_library_map: - sorted_names = sorted(library.mentioned, key=lambda x: (str.lower(x), x)) - member_names = [n for n in sorted_names if n in self._members] - # TODO(wicke): This is a hack that should be removed as soon as the - # website code allows it. - full_filename = self._path_prefix + filename - links = ["[`%s`](%s#%s)" % (name, full_filename, anchor_f(name)) - for name in member_names] - if links: - print("* **[%s](%s)**:" % (library.title, full_filename), file=f) - for link in links: - print(" * %s" % link, file=f) - print("", file=f) - - -def collect_members(module_to_name, exclude=()): - """Collect all symbols from a list of modules. - - Args: - module_to_name: Dictionary mapping modules to short names. - exclude: Set of fully qualified names to exclude. - - Returns: - Dictionary mapping name to (fullname, member) pairs. - - Raises: - RuntimeError: if we can not resolve a name collision. 
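-
-  For illustration (a hypothetical call, not taken from the docs build):
-    collect_members({tf.nn: "tf.nn"}) might return
-    {"relu": ("tf.nn.relu", tf.nn.relu), ...}; when two modules export the
-    same object under the same short name, the shorter fully qualified
-    name wins.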
- """ - members = {} - for module, module_name in module_to_name.items(): - all_names = getattr(module, "__all__", None) - for name, member in inspect.getmembers(module): - if ((inspect.isfunction(member) - or inspect.isclass(member) - or isinstance(member, functools.partial)) - and not _always_drop_symbol_re.match(name) and - (all_names is None or name in all_names)): - fullname = "%s.%s" % (module_name, name) - if fullname in exclude: - continue - if name in members: - other_fullname, other_member = members[name] - if member is not other_member: - raise RuntimeError("Short name collision between %s and %s" % - (fullname, other_fullname)) - if len(fullname) == len(other_fullname): - raise RuntimeError("Can't decide whether to use %s or %s for %s: " - "both full names have length %d" % - (fullname, other_fullname, name, len(fullname))) - if len(fullname) > len(other_fullname): - continue # Use the shorter full name - members[name] = fullname, member - return members - - -def get_anchor(module_to_name, fullname): - """Turn a full member name into an anchor. - - Args: - module_to_name: Dictionary mapping modules to short names. - fullname: Fully qualified name of symbol. - - Returns: - HTML anchor string. The longest module name prefix of fullname is - removed to make the anchor. - - Raises: - ValueError: If fullname uses characters invalid in an anchor. - """ - if not _anchor_re.match(fullname): - raise ValueError("'%s' is not a valid anchor" % fullname) - anchor = fullname - for module_name in module_to_name.values(): - if fullname.startswith(module_name + "."): - rest = fullname[len(module_name)+1:] - # Use this prefix iff it is longer than any found before - if len(anchor) > len(rest): - anchor = rest - return anchor - - -def _stable_hash(s): - """A simple string hash that won't change from run to run.""" - ret = 0 - for c in s: - ret = ret * 97 + ord(c) - return ret - - -class Library(Document): - """An automatically generated document for a set of functions and classes.""" - - def __init__(self, - title, - module, - module_to_name, - members, - documented, - exclude_symbols=(), - prefix=None): - """Creates a new Library. - - Args: - title: A human-readable title for the library. - module: Module to pull high level docstring from (for table of contents, - list of Ops to document, etc.). - module_to_name: Dictionary mapping modules to short names. - members: Dictionary mapping member name to (fullname, member). - documented: Set of documented names to update. - exclude_symbols: A list of specific symbols to exclude. - prefix: A string to include at the beginning of the page. 
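-
-    For illustration, gen_docs_combined.py constructs these roughly as
-      Library(title="Math", module=math_ops, module_to_name=module_to_name,
-              members=members, documented=documented)
-    with one Library per generated markdown page.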
- """ - self._title = title - self._module = module - self._module_to_name = module_to_name - self._members = dict(members) # Copy since we mutate it below - self._exclude_symbols = frozenset(exclude_symbols) - documented.update(exclude_symbols) - self._documented = documented - self._mentioned = set() - self._prefix = prefix or "" - - @property - def title(self): - """The human-readable title for this library.""" - return self._title - - @property - def mentioned(self): - """Set of names mentioned in this library.""" - return self._mentioned - - @property - def exclude_symbols(self): - """Set of excluded symbols.""" - return self._exclude_symbols - - def _should_include_member(self, name): - """Returns True if this member should be included in the document.""" - # __x__ should be documented always - name_is_operator = name.startswith("__") and name.endswith("__") - name_is_private = name.startswith("_") and not name_is_operator - name_is_excluded = name in self._exclude_symbols - return not (name_is_private or name_is_excluded) - - def get_imported_modules(self, module): - """Returns the list of modules imported from `module`.""" - for name, member in inspect.getmembers(module): - if inspect.ismodule(member): - yield name, member - - def get_class_members(self, cls_name, cls): - """Returns the list of class members to document in `cls`. - - This function filters the class member to ONLY return those - defined by the class. It drops the inherited ones. - - Args: - cls_name: Qualified name of `cls`. - cls: An inspect object of type 'class'. - - Yields: - name, member tuples. - """ - for name, member in inspect.getmembers(cls): - # Only show methods and properties presently. In Python 3, - # methods register as isfunction. - is_method = (inspect.ismethod(member) or inspect.isfunction(member) - or isinstance(member, functools.partial)) - if not (is_method or isinstance(member, property)): - continue - if self._should_include_member(name): - yield name, ("%s.%s" % (cls_name, name), member) - - def shard_dir(self, name): - """Returns the path of the doc subdirectory for member `name`. - - When generating individual files for each function and class, we shard - the files across several directories to avoid hitting the limit for - files per directory. This function determines the subdirectory for - a member based on a stable hash of its name. - - Args: - name: string. The name of a function or class. - - Returns: - The path to a subdirectory of the api docs directory. - """ - index = _stable_hash(name) % _num_subdirs - return os.path.join(self.functions_and_classes_dir, - _subdir_prefix + str(index)) - - def set_functions_and_classes_dir(self, dirname): - """Sets the name of the directory for function and class markdown files. - - Args: - dirname: string. The name of the directory in which to store function - and class markdown files. - """ - self.functions_and_classes_dir = dirname - - def _generate_signature_for_function(self, func): - """Given a function, returns a string representing its args.""" - args_list = [] - if isinstance(func, functools.partial): - argspec = inspect.getargspec(func.func) - # Remove the args from the original function that have been used up. 
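-      # For instance (an illustrative case, not from the docs build):
-      # given def f(a, b, c=3) and functools.partial(f, 1, c=4),
-      # argspec.args is ['a', 'b', 'c'] and partial_args is 1, so the
-      # code below drops 'a' and 'c' and produces the signature (b).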
- first_default_arg = ( - len(argspec.args or []) - len(argspec.defaults or [])) - partial_args = len(func.args) - if argspec.args: - argspec_args = list(argspec.args[partial_args:]) - else: - argspec_args = [] - if argspec.defaults: - argspec_defaults = list(argspec.defaults[ - max(0, partial_args-first_default_arg):]) - else: - argspec_defaults = [] - first_default_arg = max(0, first_default_arg - partial_args) - for kwarg in func.keywords: - if kwarg in argspec_args: - i = argspec_args.index(kwarg) - argspec_args.pop(i) - if i >= first_default_arg: - argspec_defaults.pop(i-first_default_arg) - else: - first_default_arg -= 1 - argspec_varargs = None - argspec_keywords = None - - else: - argspec = inspect.getargspec(func) - argspec_args = argspec.args - argspec_defaults = argspec.defaults - argspec_varargs = argspec.varargs - argspec_keywords = argspec.keywords - - first_arg_with_default = ( - len(argspec_args or []) - len(argspec_defaults or [])) - for arg in argspec_args[:first_arg_with_default]: - if arg == "self": - # Python documentation typically skips `self` when printing method - # signatures. - continue - args_list.append(arg) - - # TODO(mrry): This is a workaround for documenting signature of - # functions that have the @contextlib.contextmanager decorator. - # TODO(aselle): This workaround is brittle on TestCase.__call__ - # so we need to wrap this in a try/catch - # We should do something better. - if argspec_varargs == "args" and argspec_keywords == "kwds": - try: - original_func = func.__closure__[0].cell_contents - return self._generate_signature_for_function(original_func) - except TypeError: - pass - - if argspec_defaults: - for arg, default in zip( - argspec_args[first_arg_with_default:], argspec_defaults): - if callable(default): - if hasattr(default, "__name__"): - args_list.append("%s=%s" % (arg, default.__name__)) - else: - # A callable may be a class instance. - # TODO(fchollet): handle case with non-default constructor - # arguments (currently not present in the TF codebase). - args_list.append("%s=%s()" % (arg, default.__class__.__name__)) - else: - args_list.append("%s=%r" % (arg, default)) - if argspec_varargs: - args_list.append("*" + argspec_varargs) - if argspec_keywords: - args_list.append("**" + argspec_keywords) - return "(" + ", ".join(args_list) + ")" - - def _remove_docstring_indent(self, docstring): - """Remove indenting. - - We follow Python's convention and remove the minimum indent of the lines - after the first, see: - https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation - preserving relative indentation. - - Args: - docstring: A docstring. - - Returns: - A list of strings, one per line, with the minimum indent stripped. - """ - docstring = docstring or "" - lines = docstring.strip().split("\n") - - min_indent = len(docstring) - for l in lines[1:]: - l = l.rstrip() - if l: - i = 0 - while i < len(l) and l[i] == " ": - i += 1 - if i < min_indent: min_indent = i - for i in range(1, len(lines)): - l = lines[i].rstrip() - if len(l) >= min_indent: - l = l[min_indent:] - lines[i] = l - return lines - - def _print_formatted_docstring(self, docstring, f): - """Formats the given `docstring` as Markdown and prints it to `f`.""" - lines = self._remove_docstring_indent(docstring) - - # Output the lines, identifying "Args" and other section blocks. 
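-    # For instance, a docstring block such as
-    #   Args:
-    #     x: an input tensor
-    # is emitted as "##### Args:" followed by the bullet
-    # "* `x`: an input tensor".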
- i = 0 - - def _at_start_of_section(): - """Returns the header if lines[i] is at start of a docstring section.""" - l = lines[i] - match = _section_re.match(l) - if match and i + 1 < len( - lines) and lines[i + 1].startswith(" "): - return match.group(1) - else: - return None - - while i < len(lines): - l = lines[i] - - section_header = _at_start_of_section() - if section_header: - if i == 0 or lines[i-1]: - print("", file=f) - # Use at least H4 to keep these out of the TOC. - print("##### " + section_header + ":", file=f) - print("", file=f) - i += 1 - outputting_list = False - while i < len(lines): - l = lines[i] - # A new section header terminates the section. - if _at_start_of_section(): - break - match = _arg_re.match(l) - if match: - if not outputting_list: - # We need to start a list. In Markdown, a blank line needs to - # precede a list. - print("", file=f) - outputting_list = True - suffix = l[len(match.group()):].lstrip() - print("* `" + match.group(1) + "`: " + suffix, file=f) - else: - # For lines that don't start with _arg_re, continue the list if it - # has enough indentation. - outputting_list &= l.startswith(" ") - print(l, file=f) - i += 1 - else: - print(l, file=f) - i += 1 - - def _print_function(self, f, prefix, fullname, func): - """Prints the given function to `f`.""" - heading = prefix + " `" + fullname - if not isinstance(func, property): - heading += self._generate_signature_for_function(func) - heading += "` {#%s}" % get_anchor(self._module_to_name, fullname) - print(heading, file=f) - print("", file=f) - self._print_formatted_docstring(inspect.getdoc(func), f) - print("", file=f) - - def _write_member_markdown_to_file(self, f, prefix, name, member): - """Print `member` to `f`.""" - if (inspect.isfunction(member) or inspect.ismethod(member) - or (isinstance(member, functools.partial) - and inspect.isfunction(member.func)) - or isinstance(member, property)): - print("- - -", file=f) - print("", file=f) - self._print_function(f, prefix, name, member) - print("", file=f) - - # Write an individual file for each function. - if inspect.isfunction(member): - indivf = open( - os.path.join(self.shard_dir(name), name + ".md"), "w+") - self._print_function(indivf, prefix, name, member) - elif (inspect.isclass(member) - or (isinstance(member, functools.partial) - and inspect.isclass(member.func))): - print("- - -", file=f) - print("", file=f) - print("%s `class %s` {#%s}" % (prefix, name, - get_anchor(self._module_to_name, name)), - file=f) - print("", file=f) - self._write_class_markdown_to_file(f, name, member) - print("", file=f) - - # Write an individual file for each class. - indivf = open( - os.path.join(self.shard_dir(name), name + ".md"), "w+") - self._write_class_markdown_to_file(indivf, name, member) - else: - raise RuntimeError("Member %s has unknown type %s" % (name, type(member))) - - def _write_docstring_markdown_to_file(self, f, prefix, docstring, members, - imports): - for l in self._remove_docstring_indent(docstring): - if l.startswith(_member_mark): - name = l[len(_member_mark):].strip(" \t") - if name in members: - self._documented.add(name) - self._mentioned.add(name) - self._write_member_markdown_to_file(f, prefix, *members[name]) - del members[name] - elif name in imports: - self._write_module_markdown_to_file(f, imports[name]) - else: - raise ValueError("%s: unknown member `%s`, markdown=`%s`." % ( - self._title, name, l)) - else: - print(l, file=f) - - def _write_class_markdown_to_file(self, f, name, cls): - """Write the class doc to `f`. 
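-
-    Methods called out with @@ in the class docstring are documented in
-    place; leftover methods defined by the class itself are appended under
-    an "Other Methods" heading (or documented directly when the docstring
-    calls out none of them).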
- - Args: - f: File to write to. - name: name to use. - cls: class object. - """ - # Build the list of class methods to document. - methods = dict(self.get_class_members(name, cls)) - # Used later to check if any methods were called out in the class - # docstring. - num_methods = len(methods) - try: - self._write_docstring_markdown_to_file(f, "####", inspect.getdoc(cls), - methods, {}) - except ValueError as e: - raise ValueError(str(e) + " in class `%s`" % cls.__name__) - - # If some methods were not described, describe them now if they are - # defined by the class itself (not inherited). If NO methods were - # described, describe all methods. - # - # TODO(touts): when all methods have been categorized make it an error - # if some methods are not categorized. - any_method_called_out = (len(methods) != num_methods) - if any_method_called_out: - other_methods = {n: m for n, m in methods.items() if n in cls.__dict__} - if other_methods: - print("\n#### Other Methods", file=f) - else: - other_methods = methods - for name in sorted(other_methods): - self._write_member_markdown_to_file(f, "####", *other_methods[name]) - - def _write_module_markdown_to_file(self, f, module): - imports = dict(self.get_imported_modules(module)) - self._write_docstring_markdown_to_file(f, "###", inspect.getdoc(module), - self._members, imports) - - def write_markdown_to_file(self, f): - """Prints this library to file `f`. - - Args: - f: File to write to. - - Returns: - Dictionary of documented members. - """ - print("", file=f) - print("", file=f) - # TODO(touts): Do not insert these. Let the doc writer put them in - # the module docstring explicitly. - print("#", self._title, file=f) - if self._prefix: - print(self._prefix, file=f) - print("[TOC]", file=f) - print("", file=f) - if self._module is not None: - self._write_module_markdown_to_file(f, self._module) - - def write_other_members(self, f, catch_all=False): - """Writes the leftover members to `f`. - - Args: - f: File to write to. - catch_all: If true, document all missing symbols from any module. - Otherwise, document missing symbols from just this module. - """ - if catch_all: - names = self._members.items() - else: - names = inspect.getmembers(self._module) - all_names = getattr(self._module, "__all__", None) - if all_names is not None: - names = [(n, m) for n, m in names if n in all_names] - leftovers = [] - for name, _ in names: - if name in self._members and name not in self._documented: - leftovers.append(name) - if leftovers: - print("%s: undocumented members: %d" % (self._title, len(leftovers))) - print("\n## Other Functions and Classes", file=f) - for name in sorted(leftovers): - print(" %s" % name) - self._documented.add(name) - self._mentioned.add(name) - self._write_member_markdown_to_file(f, "###", *self._members[name]) - - def assert_no_leftovers(self): - """Generate an error if there are leftover members.""" - leftovers = [] - for name in self._members: - if name in self._members and name not in self._documented: - leftovers.append(name) - if leftovers: - raise RuntimeError("%s: undocumented members: %s" % - (self._title, ", ".join(leftovers))) - - -def write_libraries(output_dir, libraries): - """Write a list of libraries to disk. - - Args: - output_dir: Output directory. - libraries: List of (filename, library) pairs. - """ - files = [open(os.path.join(output_dir, k), "w") for k, _ in libraries] - - # Set the directory in which to save individual class and function md files, - # creating it if it doesn't exist. 
Create subdirectories to avoid hitting - # the limit for number of files in a directory. - indiv_dir = os.path.join(output_dir, _indiv_dir) - if not os.path.exists(indiv_dir): - os.makedirs(indiv_dir) - - for i in range(0, _num_subdirs): - subdir = os.path.join(indiv_dir, _subdir_prefix + str(i)) - if not os.path.exists(subdir): - os.makedirs(subdir) - - # Document mentioned symbols for all libraries - for f, (_, v) in zip(files, libraries): - v.set_functions_and_classes_dir(indiv_dir) - v.write_markdown_to_file(f) - # Document symbols that no library mentioned. We do this after writing - # out all libraries so that earlier libraries know what later libraries - # documented. - for f, (_, v) in zip(files, libraries): - v.write_other_members(f) - f.close() diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py deleted file mode 100644 index 65379dda209..00000000000 --- a/tensorflow/python/framework/gen_docs_combined.py +++ /dev/null @@ -1,332 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Updates generated docs from Python doc comments.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import collections -import os.path -import sys - -import tensorflow as tf - -from tensorflow.contrib import ffmpeg -from tensorflow.python import debug as tf_debug -from tensorflow.python.client import client_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import docs -from tensorflow.python.framework import framework_lib - -FLAGS = None - - -PREFIX_TEXT = """ -Note: Functions taking `Tensor` arguments can also take anything accepted by -@{tf.convert_to_tensor}. -""" - - -def module_names(): - return [ - "tf", - "tf.errors", - "tf.image", - "tf.nn", - "tf.train", - "tf.python_io", - "tf.saved_model", - "tf.summary", - "tf.test", - "tf.contrib.bayesflow.entropy", - "tf.contrib.bayesflow.monte_carlo", - "tf.contrib.bayesflow.stochastic_graph", - "tf.contrib.bayesflow.stochastic_tensor", - "tf.contrib.bayesflow.variational_inference", - "tf.contrib.copy_graph", - "tf.contrib.crf", - "tf.contrib.distributions", - "tf.contrib.distributions.bijector", - "tf.contrib.ffmpeg", - "tf.contrib.framework", - "tf.contrib.graph_editor", - "tf.contrib.integrate", - "tf.contrib.layers", - "tf.contrib.learn", - "tf.contrib.learn.monitors", - "tf.contrib.legacy_seq2seq", - "tf.contrib.linalg", - "tf.contrib.losses", - "tf.contrib.metrics", - "tf.contrib.opt", - "tf.contrib.rnn", - "tf.contrib.solvers", - "tf.contrib.training", - "tf.contrib.util", - "tf_debug", - ] - - -def find_module(base_module, name): - if name == "tf": - return base_module - # Special case for ffmpeg is needed since it's not linked in by default due - # to size concerns. 
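-  # (For other dotted names, the final branch below walks the path with
-  # getattr, e.g. find_module(tf, "tf.contrib.layers") resolves
-  # tf.contrib and then tf.contrib.layers.)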
- elif name == "tf.contrib.ffmpeg": - return ffmpeg - elif name == "tf_debug": - return tf_debug - elif name.startswith("tf."): - subname = name[3:] - subnames = subname.split(".") - parent_module = base_module - for s in subnames: - if not hasattr(parent_module, s): - raise ValueError( - "Module not found: {}. Submodule {} not found in parent module {}." - " Possible candidates are {}".format( - name, s, parent_module.__name__, dir(parent_module))) - parent_module = getattr(parent_module, s) - return parent_module - else: - raise ValueError( - "Invalid module name: {}. Module names must start with 'tf.'".format( - name)) - - -def get_module_to_name(names): - return collections.OrderedDict([(find_module(tf, x), x) for x in names]) - - -def all_libraries(module_to_name, members, documented): - """Make a list of the individual files that we want to create. - - Args: - module_to_name: Dictionary mapping modules to short names. - members: Dictionary mapping member name to (fullname, member). - documented: Set of documented names to update. - - Returns: - List of (filename, docs.Library) pairs. - """ - def library(name, title, module=None, **args): - if module is None: - module = sys.modules["tensorflow.python.ops." + name] - return (name + ".md", docs.Library(title=title, - module_to_name=module_to_name, - members=members, - documented=documented, - module=module, - **args)) - return collections.OrderedDict([ - # Splits of module 'tf'. - library("framework", "Building Graphs", framework_lib), - library("check_ops", "Asserts and boolean checks."), - library("constant_op", "Constants, Sequences, and Random Values", - constant_op, prefix=PREFIX_TEXT), - library("state_ops", - "Variables", - exclude_symbols=["create_partitioned_variables"], - prefix=PREFIX_TEXT), - library("array_ops", - "Tensor Transformations", - exclude_symbols=["list_diff"], - prefix=PREFIX_TEXT), - library("math_ops", - "Math", - exclude_symbols=["sparse_matmul", "arg_min", "arg_max", - "lin_space", "sparse_segment_mean_grad"], - prefix=PREFIX_TEXT), - library("string_ops", "Strings", - prefix=PREFIX_TEXT), - library("histogram_ops", "Histograms"), - library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT), - library("functional_ops", "Higher Order Functions", prefix=PREFIX_TEXT), - library("tensor_array_ops", "TensorArray Operations", prefix=PREFIX_TEXT), - library("session_ops", "Tensor Handle Operations", prefix=PREFIX_TEXT), - library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"], - prefix=PREFIX_TEXT), - library("sparse_ops", - "Sparse Tensors", - exclude_symbols=["serialize_sparse", "serialize_many_sparse", - "deserialize_many_sparse"], - prefix=PREFIX_TEXT), - library("io_ops", - "Inputs and Readers", - exclude_symbols=["LookupTableBase", "HashTable", - "initialize_all_tables", - "tables_initializer", - "parse_single_sequence_example", - "string_to_hash_bucket"], - prefix=PREFIX_TEXT), - library("python_io", "Data IO (Python functions)", tf.python_io), - library("nn", - "Neural Network", - tf.nn, - exclude_symbols=["conv2d_backprop_input", - "conv2d_backprop_filter", "avg_pool_grad", - "max_pool_grad", "max_pool_grad_with_argmax", - "batch_norm_with_global_normalization_grad", - "lrn_grad", "relu6_grad", "softplus_grad", - "softsign_grad", "xw_plus_b", "relu_layer", - "lrn", "batch_norm_with_global_normalization", - "batch_norm_with_global_normalization_grad", - "all_candidate_sampler", "seq2seq"], - prefix=PREFIX_TEXT), - library("client", "Running Graphs", client_lib), - library("train", - 
"Training", - tf.train, - exclude_symbols=["Feature", "Features", "BytesList", "FloatList", - "Int64List", "Example", "InferenceExample", - "FeatureList", "FeatureLists", "RankingExample", - "SequenceExample"]), - library("script_ops", - "Wraps python functions", - prefix=PREFIX_TEXT), - library("summary", "Summary Operations", tf.summary), - library("test", "Testing", tf.test), - library("contrib.bayesflow.entropy", - "BayesFlow Entropy (contrib)", - tf.contrib.bayesflow.entropy), - library("contrib.bayesflow.monte_carlo", - "BayesFlow Monte Carlo (contrib)", - tf.contrib.bayesflow.monte_carlo), - library("contrib.bayesflow.stochastic_graph", - "BayesFlow Stochastic Graph (contrib)", - tf.contrib.bayesflow.stochastic_graph), - library("contrib.bayesflow.stochastic_tensor", - "BayesFlow Stochastic Tensors (contrib)", - tf.contrib.bayesflow.stochastic_tensor), - library("contrib.bayesflow.variational_inference", - "BayesFlow Variational Inference (contrib)", - tf.contrib.bayesflow.variational_inference), - library("contrib.crf", "CRF (contrib)", tf.contrib.crf), - library("contrib.distributions", "Statistical Distributions (contrib)", - tf.contrib.distributions), - library("contrib.distributions.bijector", - "Random variable transformations (contrib)", - tf.contrib.distributions.bijector), - library("contrib.ffmpeg", "FFmpeg (contrib)", ffmpeg), - library("contrib.framework", "Framework (contrib)", tf.contrib.framework), - library("contrib.graph_editor", "Graph Editor (contrib)", - tf.contrib.graph_editor), - library("contrib.integrate", "Integrate (contrib)", tf.contrib.integrate), - library("contrib.layers", "Layers (contrib)", tf.contrib.layers), - library("contrib.learn", "Learn (contrib)", tf.contrib.learn), - library("contrib.learn.monitors", "Monitors (contrib)", - tf.contrib.learn.monitors), - library("contrib.legacy_seq2seq", "Sequence to Sequence (contrib)", - tf.contrib.legacy_seq2seq), - library("contrib.linalg", "Linear Algebra (contrib)", - tf.contrib.linalg), - library("contrib.losses", "Losses (contrib)", tf.contrib.losses), - library("contrib.opt", "Optimization (contrib)", tf.contrib.opt), - library("contrib.rnn", "RNN and Cells (contrib)", tf.contrib.rnn), - library("contrib.metrics", "Metrics (contrib)", tf.contrib.metrics), - library("contrib.training", "Training (contrib)", tf.contrib.training), - library("contrib.util", "Utilities (contrib)", tf.contrib.util), - library("contrib.copy_graph", "Copying Graph Elements (contrib)", - tf.contrib.copy_graph), - library("tf_debug", "TensorFlow Debugger", tf_debug), - ]) - -_hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange", - "HistogramProto", "ConfigProto", "NodeDef", "GraphDef", - "GPUOptions", "GraphOptions", "RunOptions", "RunMetadata", - "SessionInterface", "BaseSession", "NameAttrList", - "AttrValue", "OptimizerOptions", - "CollectionDef", "MetaGraphDef", "QueueRunnerDef", - "SaverDef", "VariableDef", "TestCase", "GrpcServer", - "ClusterDef", "JobDef", "ServerDef", "TensorInfo"] - -# TODO(skleinfeld, deannarubin) Address shortname -# conflict between tf.contrib.learn.NanLossDuringTrainingError and -# tf.contrib.learn.monitors.NanLossDuringTrainingError, arising due -# to imports in learn/python/learn/__init__.py -# TODO(wicke): Remove contrib.layers.relu* after shortnames are -# disabled. 
These conflict with tf.nn.relu* -EXCLUDE = frozenset(["tf.contrib.learn.monitors.NanLossDuringTrainingError", - "tf.contrib.layers.dropout", - "tf.contrib.layers.bias_add", - "tf.contrib.layers.conv2d", - "tf.contrib.layers.conv2d_transpose", - "tf.contrib.layers.separable_conv2d", - "tf.contrib.layers.softmax", - "tf.contrib.layers.relu", "tf.contrib.layers.relu6", - "tf.contrib.framework.assert_global_step", - "tf.contrib.framework.get_global_step", - "tf.contrib.learn.NanLossDuringTrainingError", - "tf.contrib.layers.stack", - "tf.contrib.layers.ProblemType", - "tf.confusion_matrix"]) - - -def main(unused_argv): - if not FLAGS.out_dir: - tf.logging.error("out_dir not specified") - return -1 - - # Document libraries - documented = set() - module_to_name = get_module_to_name(module_names()) - members = docs.collect_members(module_to_name, exclude=EXCLUDE) - libraries = all_libraries(module_to_name, members, documented).items() - - # Define catch_all library before calling write_libraries to avoid complaining - # about generically hidden symbols. - catch_all = docs.Library(title="Catch All", module=None, - exclude_symbols=_hidden_symbols, - module_to_name=module_to_name, members=members, - documented=documented) - - # Write docs to files - docs.write_libraries(FLAGS.out_dir, libraries) - - # Make it easy to search for hidden symbols - if FLAGS.print_hidden_regex: - hidden = set(_hidden_symbols) - for _, lib in libraries: - hidden.update(lib.exclude_symbols) - print(r"hidden symbols regex = r'\b(%s)\b'" % "|".join(sorted(hidden))) - - # Verify that all symbols are mentioned in some library doc. - catch_all.assert_no_leftovers() - - # Generate index - with open(os.path.join(FLAGS.out_dir, "index.md"), "w") as f: - docs.Index(module_to_name, members, libraries, - "../../api_docs/python/").write_markdown_to_file(f) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.register("type", "bool", lambda v: v.lower() == "true") - parser.add_argument( - "--out_dir", - type=str, - default=None, - help="Directory to which docs should be written.") - parser.add_argument( - "--print_hidden_regex", - type="bool", - nargs="?", - const=True, - default=False, - help="Dump a regular expression matching any hidden symbol") - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index ed2626efabb..b425e93aa10 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -11,13 +11,6 @@ package( load("//tensorflow:tensorflow.bzl", "py_test") -py_binary( - name = "gen_cc_md", - srcs = ["gen_cc_md.py"], - srcs_version = "PY2AND3", - deps = ["//tensorflow:tensorflow_py"], -) - py_library( name = "doc_generator_visitor", srcs = [ @@ -134,38 +127,6 @@ py_test( ], ) -filegroup( - name = "doxy_config", - srcs = ["tf-doxy_for_md-config"], -) - -sh_binary( - name = "gen_docs", - srcs = ["gen_docs.sh"], - data = [ - ":doxy_config", - ":gen_cc_md", - "//tensorflow/python:gen_docs_combined", - ], -) - -sh_test( - name = "gen_docs_test", - size = "small", - srcs = [ - "gen_docs_test.sh", - ], - data = [ - ":gen_docs", - "//tensorflow/core:all_files", - "//tensorflow/python:all_files", - ], - tags = [ - "manual", - "notap", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/tools/docs/gen_cc_md.py b/tensorflow/tools/docs/gen_cc_md.py deleted file mode 100644 index 931df3230b4..00000000000 --- a/tensorflow/tools/docs/gen_cc_md.py +++ /dev/null @@ 
-1,314 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Convert Doxygen .xml files to MarkDown (.md files).""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import os -import re - -from BeautifulSoup import BeautifulStoneSoup -import tensorflow as tf - -ANCHOR_RE = re.compile(r'\W+') - -PAGE_TEMPLATE = '''# `{0} {1}` - -{2} - -###Member Details - -{3}''' - -INDEX_TEMPLATE = '''# TensorFlow C++ Session API reference documentation - -TensorFlow's public C++ API includes only the API for executing graphs, as of -version 0.5. To control the execution of a graph from C++: - -1. Build the computation graph using the [Python API](../python/). -1. Use [`tf.train.write_graph()`](../python/train.md#write_graph) to -write the graph to a file. -1. Load the graph using the C++ Session API. For example: - - ```c++ - // Reads a model graph definition from disk, and creates a session object you - // can use to run it. - Status LoadGraph(string graph_file_name, Session** session) { - GraphDef graph_def; - TF_RETURN_IF_ERROR( - ReadBinaryProto(Env::Default(), graph_file_name, &graph_def)); - TF_RETURN_IF_ERROR(NewSession(SessionOptions(), session)); - TF_RETURN_IF_ERROR((*session)->Create(graph_def)); - return Status::OK(); - } -``` - -1. 
Run the graph with a call to `session->Run()` - -## Env - -@@Env -@@RandomAccessFile -@@WritableFile -@@EnvWrapper - -## Session - -@@Session -@@SessionOptions - -## Status - -@@Status -@@Status::State - -## Tensor - -@@Tensor -@@TensorShape -@@TensorShapeDim -@@TensorShapeUtils -@@PartialTensorShape -@@PartialTensorShapeUtils - -## Thread - -@@Thread -@@ThreadOptions -''' - -FLAGS = None - - -def member_definition(member_elt): - def_text = '' - - def_elt = member_elt.find('definition') - if def_elt: - def_text = def_elt.text - - return def_text - - -def member_sig(member_elt): - def_text = member_definition(member_elt) - - argstring_text = '' - argstring = member_elt.find('argsstring') - if argstring: - argstring_text = argstring.text - - sig = def_text + argstring_text - return sig - - -def anchorize(name): - return ANCHOR_RE.sub('_', name) - - -def element_text(member_elt, elt_name): - """Extract all `para` text from (`elt_name` in) `member_elt`.""" - text = [] - if elt_name: - elt = member_elt.find(elt_name) - else: - elt = member_elt - - if elt: - paras = elt.findAll('para') - for p in paras: - text.append(p.getText(separator=u' ').strip()) - return '\n\n'.join(text) - - -def full_member_entry(member_elt): - """Generate the description of `member_elt` for "Member Details".""" - anchor = '{#' + anchorize(member_definition(member_elt)) + '}' - full_entry = '#### `%s` %s' % (member_sig(member_elt), anchor) - - complete_descr = element_text(member_elt, 'briefdescription') + '\n\n' - complete_descr += element_text(member_elt, 'detaileddescription') - - if complete_descr: - full_entry += '\n\n' + complete_descr - - return full_entry - - -def brief_member_entry(member_elt): - """Generate the description of `member_elt` for the "Member Summary".""" - brief_item = '' - brief_descr = element_text(member_elt, 'briefdescription') - if brief_descr: - brief_item = '\n * ' + brief_descr - sig = member_sig(member_elt) - memdef = member_definition(member_elt) - linkified_sig = '[`{0}`](#{1})'.format(sig, anchorize(memdef)) - - return '* ' + linkified_sig + brief_item - - -def all_briefs(members): - briefs = [brief_member_entry(member_elt) for member_elt in members] - return '\n'.join(briefs) - - -def all_fulls(members): - fulls = [full_member_entry(member_elt) for member_elt in members] - return '\n\n'.join(fulls) - - -def page_overview(class_elt): - """Returns the contents of the .md file for `class_elt`.""" - overview_brief = '' - overview_details = '' - - briefs = class_elt.findAll('briefdescription', recursive=False) - if briefs: - overview_brief = element_text(briefs[0], None) - - details = class_elt.findAll('detaileddescription', recursive=False) - if details: - overview_details = element_text(details[0], None) - - return overview_brief + '\n\n' + overview_details - - -def page_with_name(pages, name): - def match(n): - for i in xrange(len(pages)): - if pages[i].get_name() == n: - return i - return None - return match(name) or match('tensorflow::' + name) - - -def get_all_indexed_pages(): - all_pages = set() - lines = INDEX_TEMPLATE.split('\n') - for i in range(len(lines)): - if lines[i].startswith('@@'): - name = lines[i][2:] - all_pages.add(name) - return all_pages - - -def index_page(pages): - """Create the index page linking to `pages` using INDEX_TEMPLATE.""" - pages = pages[:] - lines = INDEX_TEMPLATE.split('\n') - all_md_files = [] - for i in range(len(lines)): - if lines[i].startswith('@@'): - name = lines[i][2:] - page_index = page_with_name(pages, name) - if page_index is None: - raise 
ValueError('Missing page with name: ' + name) - lines[i] = '* [{0}]({1})'.format( - pages[page_index].get_name(), pages[page_index].get_md_filename()) - all_md_files.append(pages[page_index].get_md_filename()) - pages.pop(page_index) - - return '\n'.join(lines) - - -def page_in_name_list(page, names): - for name in names: - if page.get_name() == name or page.get_name() == 'tensorflow::' + name: - return True - return False - - -class Page(object): - """Holds the MarkDown converted contents of a .xml page.""" - - def __init__(self, xml_path, deftype): - self.type = deftype - xml_file = open(xml_path) - xml = xml_file.read() - xml = xml.replace('', '`').replace('', '`') - # TODO(josh11b): Should not use HTML entities inside ```...```. - soup = BeautifulStoneSoup( - xml, convertEntities=BeautifulStoneSoup.HTML_ENTITIES) - self.name = soup.find('compoundname').text - print('Making page with name ' + self.name + ' (from ' + xml_path + ')') - members = soup('memberdef', prot='public') - fulls = all_fulls(members) - self.overview = page_overview(soup.find('compounddef')) - self.page_text = PAGE_TEMPLATE.format( - self.type, self.name, self.overview, fulls) - - def get_text(self): - return self.page_text - - def get_name(self): - return self.name - - def get_short_name(self): - parse = self.get_name().split('::') - return parse[len(parse)-1] - - def get_type(self): - return self.type - - def get_md_filename(self): - capitalized_type = self.get_type()[0].upper() + self.get_type()[1:] - return capitalized_type + anchorize(self.get_short_name()) + '.md' - - -def main(unused_argv): - print('Converting in ' + FLAGS.src_dir) - pages = [] - all_pages = get_all_indexed_pages() - xml_files = os.listdir(FLAGS.src_dir) - for fname in xml_files: - if len(fname) < 6: continue - newpage = None - if fname[0:5] == 'class': - newpage = Page(os.path.join(FLAGS.src_dir, fname), 'class') - elif fname[0:6] == 'struct': - newpage = Page(os.path.join(FLAGS.src_dir, fname), 'struct') - if newpage is not None and page_in_name_list(newpage, all_pages): - pages.append(newpage) - md_filename = newpage.get_md_filename() - print('Writing ' + md_filename) - md_file = open(os.path.join(FLAGS.out_dir, md_filename), 'w') - print(newpage.get_text(), file=md_file) - - index_text = index_page(pages) - index_md_file = open(os.path.join(FLAGS.out_dir, 'index.md'), 'w') - print(index_text, file=index_md_file) - return 0 - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--src_dir', - type=str, - default=None, - help='Directory containing the doxygen output.' - ) - parser.add_argument( - '--out_dir', - type=str, - default=None, - help='Directory to which docs should be written.' - ) - FLAGS = parser.parse_args() - - tf.app.run() diff --git a/tensorflow/tools/docs/gen_docs.sh b/tensorflow/tools/docs/gen_docs.sh deleted file mode 100755 index 4f529270ab4..00000000000 --- a/tensorflow/tools/docs/gen_docs.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# This script needs to be run from the tensorflow/tools/docs directory -# Pass -a to also rebuild C++ docs. This requires doxygen. - -set -e - -DOC_DIR="g3doc/api_docs" -DOXYGEN_BIN=${DOXYGEN:-doxygen} -DOXYGEN_CONFIG="tools/docs/tf-doxy_for_md-config" -# The TMP_DIR is set inside DOXYGEN_CONFIG and cannot be changed independently -TMP_DIR=/tmp/tensorflow-docs/xml - -if [ ! -f gen_docs.sh ]; then - echo "This script must be run from inside the tensorflow/tools/docs directory." - exit 1 -fi - -# go to the tensorflow/ directory -pushd ../.. -BASE=$(pwd) - -# Make Python docs -bazel run -- //tensorflow/python:gen_docs_combined \ - --out_dir=$BASE/$DOC_DIR/python - -# Check if we should build c++ docs (if -a is given) -if [ x$1 == x-a ]; then - mkdir -p $TMP_DIR - $DOXYGEN_BIN "$BASE/$DOXYGEN_CONFIG" - bazel run -- //tensorflow/tools/docs:gen_cc_md \ - --out_dir=$BASE/$DOC_DIR/cc \ - --src_dir=$TMP_DIR -fi - -popd diff --git a/tensorflow/tools/docs/gen_docs_test.sh b/tensorflow/tools/docs/gen_docs_test.sh deleted file mode 100755 index c8c1955aa06..00000000000 --- a/tensorflow/tools/docs/gen_docs_test.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -set -eux - -if [ -d $TEST_SRCDIR/org_tensorflow ]; then - TFDIR=$TEST_SRCDIR/org_tensorflow/tensorflow -else - # Support 0.2.1- runfiles. - TFDIR=$TEST_SRCDIR/tensorflow -fi -DOXYGEN=doxygen -DOXYGEN_CONFIG="tf-doxy_for_md-config" -TMP_DIR=/tmp/tensorflow-docs -mkdir -p $TMP_DIR/python -mkdir -p $TMP_DIR/xml -mkdir -p $TMP_DIR/cc - -pushd $TFDIR -python/gen_docs_combined --out_dir=$TMP_DIR/python - -# TODO(wicke): this does not work well inside the build/test jail -#$DOXYGEN "tools/docs/$DOXYGEN_CONFIG" -#tools/docs/gen_cc_md \ -# --out_dir=$TMP_DIR/cc \ -# --src_dir=$TMP_DIR/xml -popd -echo "PASS" diff --git a/tensorflow/tools/docs/tf-doxy_for_md-config b/tensorflow/tools/docs/tf-doxy_for_md-config deleted file mode 100644 index b7fd6e95076..00000000000 --- a/tensorflow/tools/docs/tf-doxy_for_md-config +++ /dev/null @@ -1,2280 +0,0 @@ -# Doxyfile 1.8.5 - -# This file describes the settings to be used by the documentation system -# doxygen (www.doxygen.org) for a project. -# -# All text after a double hash (##) is considered a comment and is placed in -# front of the TAG it is preceding. -# -# All text after a single hash (#) is considered a comment and will be ignored. -# The format is: -# TAG = value [value, ...] -# For lists, items can also be appended using: -# TAG += value [value, ...] -# Values that contain spaces should be placed between quotes (\" \"). 
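-#
-# For example (an illustrative entry, not a setting used in this file):
-#   FILE_PATTERNS  = *.h *.cc
-#   FILE_PATTERNS += "pattern with spaces"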
- -#--------------------------------------------------------------------------- -# Project related configuration options -#--------------------------------------------------------------------------- - -# This tag specifies the encoding used for all characters in the config file -# that follow. The default is UTF-8 which is also the encoding used for all text -# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv -# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv -# for the list of possible encodings. -# The default value is: UTF-8. - -DOXYFILE_ENCODING = UTF-8 - -# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by -# double-quotes, unless you are using Doxywizard) that should identify the -# project for which the documentation is generated. This name is used in the -# title of most generated pages and in a few other places. -# The default value is: My Project. - -PROJECT_NAME = "TensorFlow" - -# The PROJECT_NUMBER tag can be used to enter a project or revision number. This -# could be handy for archiving the generated documentation or if some version -# control system is used. - -PROJECT_NUMBER = 0.0.0 - -# Using the PROJECT_BRIEF tag one can provide an optional one line description -# for a project that appears at the top of each page and should give viewer a -# quick idea about the purpose of the project. Keep the description short. - -PROJECT_BRIEF = - -# With the PROJECT_LOGO tag one can specify an logo or icon that is included in -# the documentation. The maximum height of the logo should not exceed 55 pixels -# and the maximum width should not exceed 200 pixels. Doxygen will copy the logo -# to the output directory. - -PROJECT_LOGO = - -# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path -# into which the generated documentation will be written. If a relative path is -# entered, it will be relative to the location where doxygen was started. If -# left blank the current directory will be used. - -OUTPUT_DIRECTORY = /tmp/tensorflow-docs/ - -# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub- -# directories (in 2 levels) under the output directory of each output format and -# will distribute the generated files over these directories. Enabling this -# option can be useful when feeding doxygen a huge amount of source files, where -# putting all generated files in the same directory would otherwise causes -# performance problems for the file system. -# The default value is: NO. - -CREATE_SUBDIRS = NO - -# The OUTPUT_LANGUAGE tag is used to specify the language in which all -# documentation generated by doxygen is written. Doxygen will use this -# information to generate all constant output in the proper language. -# Possible values are: Afrikaans, Arabic, Brazilian, Catalan, Chinese, Chinese- -# Traditional, Croatian, Czech, Danish, Dutch, English, Esperanto, Farsi, -# Finnish, French, German, Greek, Hungarian, Italian, Japanese, Japanese-en, -# Korean, Korean-en, Latvian, Norwegian, Macedonian, Persian, Polish, -# Portuguese, Romanian, Russian, Serbian, Slovak, Slovene, Spanish, Swedish, -# Turkish, Ukrainian and Vietnamese. -# The default value is: English. - -OUTPUT_LANGUAGE = English - -# If the BRIEF_MEMBER_DESC tag is set to YES doxygen will include brief member -# descriptions after the members that are listed in the file and class -# documentation (similar to Javadoc). Set to NO to disable this. -# The default value is: YES. 
- -BRIEF_MEMBER_DESC = YES - -# If the REPEAT_BRIEF tag is set to YES doxygen will prepend the brief -# description of a member or function before the detailed description -# -# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the -# brief descriptions will be completely suppressed. -# The default value is: YES. - -REPEAT_BRIEF = YES - -# This tag implements a quasi-intelligent brief description abbreviator that is -# used to form the text in various listings. Each string in this list, if found -# as the leading text of the brief description, will be stripped from the text -# and the result, after processing the whole list, is used as the annotated -# text. Otherwise, the brief description is used as-is. If left blank, the -# following values are used ($name is automatically replaced with the name of -# the entity):The $name class, The $name widget, The $name file, is, provides, -# specifies, contains, represents, a, an and the. - -ABBREVIATE_BRIEF = - -# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then -# doxygen will generate a detailed section even if there is only a brief -# description. -# The default value is: NO. - -ALWAYS_DETAILED_SEC = NO - -# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all -# inherited members of a class in the documentation of that class as if those -# members were ordinary class members. Constructors, destructors and assignment -# operators of the base classes will not be shown. -# The default value is: NO. - -INLINE_INHERITED_MEMB = NO - -# If the FULL_PATH_NAMES tag is set to YES doxygen will prepend the full path -# before files name in the file list and in the header files. If set to NO the -# shortest path that makes the file name unique will be used -# The default value is: YES. - -FULL_PATH_NAMES = YES - -# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. -# Stripping is only done if one of the specified strings matches the left-hand -# part of the path. The tag can be used to show relative paths in the file list. -# If left blank the directory from which doxygen is run is used as the path to -# strip. -# -# Note that you can specify absolute paths here, but also relative paths, which -# will be relative from the directory where doxygen is started. -# This tag requires that the tag FULL_PATH_NAMES is set to YES. - -STRIP_FROM_PATH = - -# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the -# path mentioned in the documentation of a class, which tells the reader which -# header file to include in order to use a class. If left blank only the name of -# the header file containing the class definition is used. Otherwise one should -# specify the list of include paths that are normally passed to the compiler -# using the -I flag. - -STRIP_FROM_INC_PATH = - -# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but -# less readable) file names. This can be useful is your file systems doesn't -# support long names like on DOS, Mac, or CD-ROM. -# The default value is: NO. - -SHORT_NAMES = NO - -# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the -# first line (until the first dot) of a Javadoc-style comment as the brief -# description. If set to NO, the Javadoc-style will behave just like regular Qt- -# style comments (thus requiring an explicit @brief command for a brief -# description.) -# The default value is: NO. 
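-# For instance, with JAVADOC_AUTOBRIEF = YES the comment
-#   /** Returns the rank. Further details follow. */
-# would use "Returns the rank." as the brief description.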
- -JAVADOC_AUTOBRIEF = NO - -# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first -# line (until the first dot) of a Qt-style comment as the brief description. If -# set to NO, the Qt-style will behave just like regular Qt-style comments (thus -# requiring an explicit \brief command for a brief description.) -# The default value is: NO. - -QT_AUTOBRIEF = NO - -# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a -# multi-line C++ special comment block (i.e. a block of //! or /// comments) as -# a brief description. This used to be the default behavior. The new default is -# to treat a multi-line C++ comment block as a detailed description. Set this -# tag to YES if you prefer the old behavior instead. -# -# Note that setting this tag to YES also means that rational rose comments are -# not recognized any more. -# The default value is: NO. - -MULTILINE_CPP_IS_BRIEF = NO - -# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the -# documentation from any documented member that it re-implements. -# The default value is: YES. - -INHERIT_DOCS = YES - -# If the SEPARATE_MEMBER_PAGES tag is set to YES, then doxygen will produce a -# new page for each member. If set to NO, the documentation of a member will be -# part of the file/class/namespace that contains it. -# The default value is: NO. - -SEPARATE_MEMBER_PAGES = NO - -# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen -# uses this value to replace tabs by spaces in code fragments. -# Minimum value: 1, maximum value: 16, default value: 4. - -TAB_SIZE = 4 - -# This tag can be used to specify a number of aliases that act as commands in -# the documentation. An alias has the form: -# name=value -# For example adding -# "sideeffect=@par Side Effects:\n" -# will allow you to put the command \sideeffect (or @sideeffect) in the -# documentation, which will result in a user-defined paragraph with heading -# "Side Effects:". You can put \n's in the value part of an alias to insert -# newlines. - -ALIASES = - -# This tag can be used to specify a number of word-keyword mappings (TCL only). -# A mapping has the form "name=value". For example adding "class=itcl::class" -# will allow you to use the command class in the itcl::class meaning. - -TCL_SUBST = - -# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources -# only. Doxygen will then generate output that is more tailored for C. For -# instance, some of the names that are used will be different. The list of all -# members will be omitted, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_FOR_C = NO - -# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or -# Python sources only. Doxygen will then generate output that is more tailored -# for that language. For instance, namespaces will be presented as packages, -# qualified scopes will look different, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_JAVA = NO - -# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran -# sources. Doxygen will then generate output that is tailored for Fortran. -# The default value is: NO. - -OPTIMIZE_FOR_FORTRAN = NO - -# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL -# sources. Doxygen will then generate output that is tailored for VHDL. -# The default value is: NO. - -OPTIMIZE_OUTPUT_VHDL = NO - -# Doxygen selects the parser to use depending on the extension of the files it -# parses. 
-
-# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments
-# according to the Markdown format, which allows for more readable
-# documentation. See http://daringfireball.net/projects/markdown/ for details.
-# The output of markdown processing is further processed by doxygen, so you can
-# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in
-# case of backward compatibility issues.
-# The default value is: YES.
-
-MARKDOWN_SUPPORT = YES
-
-# When enabled doxygen tries to link words that correspond to documented
-# classes, or namespaces to their corresponding documentation. Such a link can
-# be prevented in individual cases by putting a % sign in front of the word
-# or globally by setting AUTOLINK_SUPPORT to NO.
-# The default value is: YES.
-
-AUTOLINK_SUPPORT = YES
-
-# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want
-# to include (a tag file for) the STL sources as input, then you should set this
-# tag to YES in order to let doxygen match function declarations and
-# definitions whose arguments contain STL classes (e.g. func(std::string);
-# versus func(std::string) {}). This also makes the inheritance and
-# collaboration diagrams that involve STL classes more complete and accurate.
-# The default value is: NO.
-
-BUILTIN_STL_SUPPORT = NO
-
-# If you use Microsoft's C++/CLI language, you should set this option to YES to
-# enable parsing support.
-# The default value is: NO.
-
-CPP_CLI_SUPPORT = NO
-
-# Set the SIP_SUPPORT tag to YES if your project consists of sip (see:
-# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen
-# will parse them like normal C++ but will assume all classes use public instead
-# of private inheritance when no explicit protection keyword is present.
-# The default value is: NO.
-
-SIP_SUPPORT = NO
-
-# For Microsoft's IDL there are propget and propput attributes to indicate
-# getter and setter methods for a property. Setting this option to YES will make
-# doxygen replace the get and set methods by a property in the documentation.
-# This will only work if the methods are indeed getting or setting a simple
-# type. If this is not the case, or you want to show the methods anyway, you
-# should set this option to NO.
-# The default value is: YES.
-
-IDL_PROPERTY_SUPPORT = YES
-
-# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC
-# tag is set to YES, then doxygen will reuse the documentation of the first
-# member in the group (if any) for the other members of the group. By default
-# all members of a group must be documented explicitly.
-# The default value is: NO.
-
-DISTRIBUTE_GROUP_DOC = NO
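-# Illustration (hypothetical, not from this file): with
-# DISTRIBUTE_GROUP_DOC = YES, documenting only the first member of a group
-# shares that text with the rest of the group:
-#   //@{
-#   /** Thread-safe accessors. */
-#   int rows() const;
-#   int cols() const;
-#   //@}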
-
-# Set the SUBGROUPING tag to YES to allow class member groups of the same type
-# (for instance a group of public functions) to be put as a subgroup of that
-# type (e.g. under the Public Functions section). Set it to NO to prevent
-# subgrouping. Alternatively, this can be done per class using the
-# \nosubgrouping command.
-# The default value is: YES.
-
-SUBGROUPING = YES
-
-# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions
-# are shown inside the group in which they are included (e.g. using \ingroup)
-# instead of on a separate page (for HTML and Man pages) or section (for LaTeX
-# and RTF).
-#
-# Note that this feature does not work in combination with
-# SEPARATE_MEMBER_PAGES.
-# The default value is: NO.
-
-INLINE_GROUPED_CLASSES = NO
-
-# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions
-# with only public data fields or simple typedef fields will be shown inline in
-# the documentation of the scope in which they are defined (i.e. file,
-# namespace, or group documentation), provided this scope is documented. If set
-# to NO, structs, classes, and unions are shown on a separate page (for HTML and
-# Man pages) or section (for LaTeX and RTF).
-# The default value is: NO.
-
-INLINE_SIMPLE_STRUCTS = NO
-
-# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or
-# enum is documented as struct, union, or enum with the name of the typedef. So
-# typedef struct TypeS {} TypeT, will appear in the documentation as a struct
-# with name TypeT. When disabled the typedef will appear as a member of a file,
-# namespace, or class. And the struct will be named TypeS. This can typically be
-# useful for C code in case the coding convention dictates that all compound
-# types are typedef'ed and only the typedef is referenced, never the tag name.
-# The default value is: NO.
-
-TYPEDEF_HIDES_STRUCT = NO
-
-# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This
-# cache is used to resolve symbols given their name and scope. Since this can be
-# an expensive process and often the same symbol appears multiple times in the
-# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small
-# doxygen will become slower. If the cache is too large, memory is wasted. The
-# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range
-# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536
-# symbols. At the end of a run doxygen will report the cache usage and suggest
-# the optimal cache size from a speed point of view.
-# Minimum value: 0, maximum value: 9, default value: 0.
-
-LOOKUP_CACHE_SIZE = 0
-
-#---------------------------------------------------------------------------
-# Build related configuration options
-#---------------------------------------------------------------------------
-
-# If the EXTRACT_ALL tag is set to YES doxygen will assume all entities in
-# documentation are documented, even if no documentation was available. Private
-# class members and static file members will be hidden unless the
-# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES.
-# Note: This will also disable the warnings about undocumented members that are
-# normally produced when WARNINGS is set to YES.
-# The default value is: NO.
-
-EXTRACT_ALL = NO
-
-# If the EXTRACT_PRIVATE tag is set to YES all private members of a class will
-# be included in the documentation.
-# The default value is: NO.
-
-EXTRACT_PRIVATE = NO
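-# Side note (editorial illustration): to document everything, including
-# private class members, a configuration would combine
-#   EXTRACT_ALL     = YES
-#   EXTRACT_PRIVATE = YES
-# whereas this file keeps both at NO and only extracts static file members.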
-
-# If the EXTRACT_PACKAGE tag is set to YES all members with package or internal
-# scope will be included in the documentation.
-# The default value is: NO.
-
-EXTRACT_PACKAGE = NO
-
-# If the EXTRACT_STATIC tag is set to YES all static members of a file will be
-# included in the documentation.
-# The default value is: NO.
-
-EXTRACT_STATIC = YES
-
-# If the EXTRACT_LOCAL_CLASSES tag is set to YES classes (and structs) defined
-# locally in source files will be included in the documentation. If set to NO
-# only classes defined in header files are included. Does not have any effect
-# for Java sources.
-# The default value is: YES.
-
-EXTRACT_LOCAL_CLASSES = YES
-
-# This flag is only useful for Objective-C code. When set to YES local methods,
-# which are defined in the implementation section but not in the interface, are
-# included in the documentation. If set to NO only methods in the interface are
-# included.
-# The default value is: NO.
-
-EXTRACT_LOCAL_METHODS = NO
-
-# If this flag is set to YES, the members of anonymous namespaces will be
-# extracted and appear in the documentation as a namespace called
-# 'anonymous_namespace{file}', where file will be replaced with the base name of
-# the file that contains the anonymous namespace. By default anonymous
-# namespaces are hidden.
-# The default value is: NO.
-
-EXTRACT_ANON_NSPACES = NO
-
-# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all
-# undocumented members inside documented classes or files. If set to NO these
-# members will be included in the various overviews, but no documentation
-# section is generated. This option has no effect if EXTRACT_ALL is enabled.
-# The default value is: NO.
-
-HIDE_UNDOC_MEMBERS = NO
-
-# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all
-# undocumented classes that are normally visible in the class hierarchy. If set
-# to NO these classes will be included in the various overviews. This option has
-# no effect if EXTRACT_ALL is enabled.
-# The default value is: NO.
-
-HIDE_UNDOC_CLASSES = NO
-
-# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend
-# (class|struct|union) declarations. If set to NO these declarations will be
-# included in the documentation.
-# The default value is: NO.
-
-HIDE_FRIEND_COMPOUNDS = NO
-
-# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any
-# documentation blocks found inside the body of a function. If set to NO these
-# blocks will be appended to the function's detailed documentation block.
-# The default value is: NO.
-
-HIDE_IN_BODY_DOCS = NO
-
-# The INTERNAL_DOCS tag determines if documentation that is typed after a
-# \internal command is included. If the tag is set to NO then the documentation
-# will be excluded. Set it to YES to include the internal documentation.
-# The default value is: NO.
-
-INTERNAL_DOCS = NO
-
-# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file
-# names in lower-case letters. If set to YES upper-case letters are also
-# allowed. This is useful if you have classes or files whose names only differ
-# in case and if your file system supports case sensitive file names. Windows
-# and Mac users are advised to set this option to NO.
-# The default value is: system dependent.
-
-CASE_SENSE_NAMES = YES
-
-# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with
-# their full class and namespace scopes in the documentation. If set to YES the
-# scope will be hidden.
-# The default value is: NO.
-
-HIDE_SCOPE_NAMES = NO
-
-# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of
-# the files that are included by a file in the documentation of that file.
-# The default value is: YES.
-
-SHOW_INCLUDE_FILES = YES
-
-# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include
-# files with double quotes in the documentation rather than with sharp brackets.
-# The default value is: NO.
-
-FORCE_LOCAL_INCLUDES = NO
-
-# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the
-# documentation for inline members.
-# The default value is: YES.
-
-INLINE_INFO = YES
-
-# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the
-# (detailed) documentation of file and class members alphabetically by member
-# name. If set to NO the members will appear in declaration order.
-# The default value is: YES.
-
-SORT_MEMBER_DOCS = YES
-
-# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief
-# descriptions of file, namespace and class members alphabetically by member
-# name. If set to NO the members will appear in declaration order.
-# The default value is: NO.
-
-SORT_BRIEF_DOCS = NO
-
-# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the
-# (brief and detailed) documentation of class members so that constructors and
-# destructors are listed first. If set to NO the constructors will appear in the
-# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS.
-# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief
-# member documentation.
-# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting
-# detailed member documentation.
-# The default value is: NO.
-
-SORT_MEMBERS_CTORS_1ST = NO
-
-# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy
-# of group names into alphabetical order. If set to NO the group names will
-# appear in their defined order.
-# The default value is: NO.
-
-SORT_GROUP_NAMES = NO
-
-# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by
-# fully-qualified names, including namespaces. If set to NO, the class list will
-# be sorted only by class name, not including the namespace part.
-# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES.
-# Note: This option applies only to the class list, not to the alphabetical
-# list.
-# The default value is: NO.
-
-SORT_BY_SCOPE_NAME = NO
-
-# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper
-# type resolution of all parameters of a function it will reject a match between
-# the prototype and the implementation of a member function even if there is
-# only one candidate or it is obvious which candidate to choose by doing a
-# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still
-# accept a match between prototype and implementation in such cases.
-# The default value is: NO.
-
-STRICT_PROTO_MATCHING = NO
-
-# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the
-# todo list. This list is created by putting \todo commands in the
-# documentation.
-# The default value is: YES.
-
-GENERATE_TODOLIST = YES
-
-# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the
-# test list. This list is created by putting \test commands in the
-# documentation.
-# The default value is: YES.
-
-GENERATE_TESTLIST = YES
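-# Illustration (hypothetical): entries land on these lists via the
-# corresponding commands in a documentation block, e.g.
-#   /// \todo Tighten the error bounds.
-#   /// \test Covered by the unit test suite.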
-
-# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug
-# list. This list is created by putting \bug commands in the documentation.
-# The default value is: YES.
-
-GENERATE_BUGLIST = YES
-
-# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO)
-# the deprecated list. This list is created by putting \deprecated commands in
-# the documentation.
-# The default value is: YES.
-
-GENERATE_DEPRECATEDLIST= YES
-
-# The ENABLED_SECTIONS tag can be used to enable conditional documentation
-# sections, marked by \if <section_label> ... \endif and \cond <section_label>
-# ... \endcond blocks.
-
-ENABLED_SECTIONS =
-
-# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the
-# initial value of a variable or macro / define can have for it to appear in the
-# documentation. If the initializer consists of more lines than specified here
-# it will be hidden. Use a value of 0 to hide initializers completely. The
-# appearance of the value of individual variables and macros / defines can be
-# controlled using \showinitializer or \hideinitializer command in the
-# documentation regardless of this setting.
-# Minimum value: 0, maximum value: 10000, default value: 30.
-
-MAX_INITIALIZER_LINES = 30
-
-# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at
-# the bottom of the documentation of classes and structs. If set to YES the list
-# will mention the files that were used to generate the documentation.
-# The default value is: YES.
-
-SHOW_USED_FILES = YES
-
-# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This
-# will remove the Files entry from the Quick Index and from the Folder Tree View
-# (if specified).
-# The default value is: YES.
-
-SHOW_FILES = YES
-
-# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces
-# page. This will remove the Namespaces entry from the Quick Index and from the
-# Folder Tree View (if specified).
-# The default value is: YES.
-
-SHOW_NAMESPACES = YES
-
-# The FILE_VERSION_FILTER tag can be used to specify a program or script that
-# doxygen should invoke to get the current version for each file (typically from
-# the version control system). Doxygen will invoke the program by executing (via
-# popen()) the command <command> <input-file>, where <command> is the value of
-# the FILE_VERSION_FILTER tag, and <input-file> is the name of an input file
-# provided by doxygen. Whatever the program writes to standard output is used as
-# the file version. For an example see the documentation.
-
-FILE_VERSION_FILTER =
-
-# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed
-# by doxygen. The layout file controls the global structure of the generated
-# output files in an output format independent way. To create the layout file
-# that represents doxygen's defaults, run doxygen with the -l option. You can
-# optionally specify a file name after the option, if omitted DoxygenLayout.xml
-# will be used as the name of the layout file.
-#
-# Note that if you run doxygen from a directory containing a file called
-# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE
-# tag is left empty.
-
-LAYOUT_FILE =
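-# Usage note (editorial illustration): the default layout file described above
-# can be generated and then edited with
-#   doxygen -l DoxygenLayout.xml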
-
-# The CITE_BIB_FILES tag can be used to specify one or more bib files containing
-# the reference definitions. This must be a list of .bib files. The .bib
-# extension is automatically appended if omitted. This requires the bibtex tool
-# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info.
-# For LaTeX the style of the bibliography can be controlled using
-# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the
-# search path. Do not use file names with spaces, bibtex cannot handle them. See
-# also \cite for info how to create references.
-
-CITE_BIB_FILES =
-
-#---------------------------------------------------------------------------
-# Configuration options related to warning and progress messages
-#---------------------------------------------------------------------------
-
-# The QUIET tag can be used to turn on/off the messages that are generated to
-# standard output by doxygen. If QUIET is set to YES this implies that the
-# messages are off.
-# The default value is: NO.
-
-QUIET = NO
-
-# The WARNINGS tag can be used to turn on/off the warning messages that are
-# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES
-# this implies that the warnings are on.
-#
-# Tip: Turn warnings on while writing the documentation.
-# The default value is: YES.
-
-WARNINGS = YES
-
-# If the WARN_IF_UNDOCUMENTED tag is set to YES, then doxygen will generate
-# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag
-# will automatically be disabled.
-# The default value is: YES.
-
-WARN_IF_UNDOCUMENTED = YES
-
-# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for
-# potential errors in the documentation, such as not documenting some parameters
-# in a documented function, or documenting parameters that don't exist or using
-# markup commands wrongly.
-# The default value is: YES.
-
-WARN_IF_DOC_ERROR = YES
-
-# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that
-# are documented, but have no documentation for their parameters or return
-# value. If set to NO doxygen will only warn about wrong or incomplete parameter
-# documentation, but not about the absence of documentation.
-# The default value is: NO.
-
-WARN_NO_PARAMDOC = NO
-
-# The WARN_FORMAT tag determines the format of the warning messages that doxygen
-# can produce. The string should contain the $file, $line, and $text tags, which
-# will be replaced by the file and line number from which the warning originated
-# and the warning text. Optionally the format may contain $version, which will
-# be replaced by the version of the file (if it could be obtained via
-# FILE_VERSION_FILTER).
-# The default value is: $file:$line: $text.
-
-WARN_FORMAT = "$file:$line: $text"
-
-# The WARN_LOGFILE tag can be used to specify a file to which warning and error
-# messages should be written. If left blank the output is written to standard
-# error (stderr).
-
-WARN_LOGFILE =
-
-#---------------------------------------------------------------------------
-# Configuration options related to the input files
-#---------------------------------------------------------------------------
-
-# The INPUT tag is used to specify the files and/or directories that contain
-# documented source files. You may enter file names like myfile.cpp or
-# directories like /usr/src/myproject. Separate the files or directories with
-# spaces.
-# Note: If this tag is empty the current directory is searched.
-
-INPUT = core/framework core/lib/core core/platform core/public
-
-# This tag can be used to specify the character encoding of the source files
-# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
-# libiconv (or the iconv built into libc) for the transcoding. See the libiconv
-# documentation (see: http://www.gnu.org/software/libiconv) for the list of
-# possible encodings.
-# The default value is: UTF-8.
-
-INPUT_ENCODING = UTF-8
-
-# If the value of the INPUT tag contains directories, you can use the
-# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
-# *.h) to filter out the source-files in the directories. If left blank the
-# following patterns are tested: *.c, *.cc, *.cxx, *.cpp, *.c++, *.java, *.ii,
-# *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, *.hh, *.hxx, *.hpp,
-# *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, *.m, *.markdown,
-# *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, *.vhdl, *.ucf,
-# *.qsf, *.as and *.js.
-
-FILE_PATTERNS =
-
-# The RECURSIVE tag can be used to specify whether or not subdirectories should
-# be searched for input files as well.
-# The default value is: NO.
-
-RECURSIVE = NO
-
-# The EXCLUDE tag can be used to specify files and/or directories that should be
-# excluded from the INPUT source files. This way you can easily exclude a
-# subdirectory from a directory tree whose root is specified with the INPUT tag.
-#
-# Note that relative paths are relative to the directory from which doxygen is
-# run.
-
-EXCLUDE =
-
-# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or
-# directories that are symbolic links (a Unix file system feature) are excluded
-# from the input.
-# The default value is: NO.
-
-EXCLUDE_SYMLINKS = NO
-
-# If the value of the INPUT tag contains directories, you can use the
-# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude
-# certain files from those directories.
-#
-# Note that the wildcards are matched against the file with absolute path, so to
-# exclude all test directories for example use the pattern */test/*
-
-EXCLUDE_PATTERNS =
-
-# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
-# (namespaces, classes, functions, etc.) that should be excluded from the
-# output. The symbol name can be a fully qualified name, a word, or if the
-# wildcard * is used, a substring. Examples: ANamespace, AClass,
-# AClass::ANamespace, ANamespace::*Test
-#
-# Note that the wildcards are matched against the file with absolute path, so to
-# exclude all test directories use the pattern */test/*
-
-EXCLUDE_SYMBOLS =
-
-# The EXAMPLE_PATH tag can be used to specify one or more files or directories
-# that contain example code fragments that are included (see the \include
-# command).
-
-EXAMPLE_PATH =
-
-# If the value of the EXAMPLE_PATH tag contains directories, you can use the
-# EXAMPLE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
-# *.h) to filter out the source-files in the directories. If left blank all
-# files are included.
-
-EXAMPLE_PATTERNS =
-
-# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be
-# searched for input files to be used with the \include or \dontinclude commands
-# irrespective of the value of the RECURSIVE tag.
-# The default value is: NO.
-
-EXAMPLE_RECURSIVE = NO
-
-# The IMAGE_PATH tag can be used to specify one or more files or directories
-# that contain images that are to be included in the documentation (see the
-# \image command).
-
-IMAGE_PATH =
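-# Illustration (hypothetical): with IMAGE_PATH = doc/images, a comment such as
-#   /// \image html op_graph.png "Operator graph"
-# pulls the image from that directory into the generated documentation.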
-
-# The INPUT_FILTER tag can be used to specify a program that doxygen should
-# invoke to filter for each input file. Doxygen will invoke the filter program
-# by executing (via popen()) the command:
-#
-#   <filter> <input-file>
-#
-# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the
-# name of an input file. Doxygen will then use the output that the filter
-# program writes to standard output. If FILTER_PATTERNS is specified, this tag
-# will be ignored.
-#
-# Note that the filter must not add or remove lines; it is applied before the
-# code is scanned, but not when the output code is generated. If lines are added
-# or removed, the anchors will not be placed correctly.
-
-INPUT_FILTER =
-
-# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
-# basis. Doxygen will compare the file name with each pattern and apply the
-# filter if there is a match. The filters are a list of the form: pattern=filter
-# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
-# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
-# patterns match the file name, INPUT_FILTER is applied.
-
-FILTER_PATTERNS =
-
-# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
-# INPUT_FILTER) will also be used to filter the input files that are used for
-# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES).
-# The default value is: NO.
-
-FILTER_SOURCE_FILES = NO
-
-# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file
-# pattern. A pattern will override the setting for FILTER_PATTERNS (if any) and
-# it is also possible to disable source filtering for a specific pattern using
-# *.ext= (so without naming a filter).
-# This tag requires that the tag FILTER_SOURCE_FILES is set to YES.
-
-FILTER_SOURCE_PATTERNS =
-
-# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that
-# is part of the input, its contents will be placed on the main page
-# (index.html). This can be useful if you have a project on for instance GitHub
-# and want to reuse the introduction page also for the doxygen output.
-
-USE_MDFILE_AS_MAINPAGE =
-
-#---------------------------------------------------------------------------
-# Configuration options related to source browsing
-#---------------------------------------------------------------------------
-
-# If the SOURCE_BROWSER tag is set to YES then a list of source files will be
-# generated. Documented entities will be cross-referenced with these sources.
-#
-# Note: To get rid of all source code in the generated output, make sure that
-# also VERBATIM_HEADERS is set to NO.
-# The default value is: NO.
-
-SOURCE_BROWSER = NO
-
-# Setting the INLINE_SOURCES tag to YES will include the body of functions,
-# classes and enums directly into the documentation.
-# The default value is: NO.
-
-INLINE_SOURCES = NO
-
-# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any
-# special comment blocks from generated source code fragments. Normal C, C++ and
-# Fortran comments will always remain visible.
-# The default value is: YES.
-
-STRIP_CODE_COMMENTS = YES
-
-# If the REFERENCED_BY_RELATION tag is set to YES then for each documented
-# function all documented functions referencing it will be listed.
-# The default value is: NO.
-
-REFERENCED_BY_RELATION = NO
-
-# If the REFERENCES_RELATION tag is set to YES then for each documented function
-# all documented entities called/used by that function will be listed.
-# The default value is: NO.
-
-REFERENCES_RELATION = NO
-
-# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set
-# to YES, then the hyperlinks from functions in REFERENCES_RELATION and
-# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will
-# link to the documentation.
-# The default value is: YES.
-
-REFERENCES_LINK_SOURCE = YES
-
-# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the
-# source code will show a tooltip with additional information such as prototype,
-# brief description and links to the definition and documentation. Since this
-# will make the HTML file larger and loading of large files a bit slower, you
-# can opt to disable this feature.
-# The default value is: YES.
-# This tag requires that the tag SOURCE_BROWSER is set to YES.
-
-SOURCE_TOOLTIPS = YES
-
-# If the USE_HTAGS tag is set to YES then the references to source code will
-# point to the HTML generated by the htags(1) tool instead of doxygen built-in
-# source browser. The htags tool is part of GNU's global source tagging system
-# (see http://www.gnu.org/software/global/global.html). You will need version
-# 4.8.6 or higher.
-#
-# To use it do the following:
-# - Install the latest version of global
-# - Enable SOURCE_BROWSER and USE_HTAGS in the config file
-# - Make sure the INPUT points to the root of the source tree
-# - Run doxygen as normal
-#
-# Doxygen will invoke htags (and that will in turn invoke gtags), so these
-# tools must be available from the command line (i.e. in the search path).
-#
-# The result: instead of the source browser generated by doxygen, the links to
-# source code will now point to the output of htags.
-# The default value is: NO.
-# This tag requires that the tag SOURCE_BROWSER is set to YES.
-
-USE_HTAGS = NO
-
-# If the VERBATIM_HEADERS tag is set to YES then doxygen will generate a
-# verbatim copy of the header file for each class for which an include is
-# specified. Set to NO to disable this.
-# See also: Section \class.
-# The default value is: YES.
-
-VERBATIM_HEADERS = YES
-
-#---------------------------------------------------------------------------
-# Configuration options related to the alphabetical class index
-#---------------------------------------------------------------------------
-
-# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all
-# compounds will be generated. Enable this if the project contains a lot of
-# classes, structs, unions or interfaces.
-# The default value is: YES.
-
-ALPHABETICAL_INDEX = YES
-
-# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
-# which the alphabetical index list will be split.
-# Minimum value: 1, maximum value: 20, default value: 5.
-# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
-
-COLS_IN_ALPHA_INDEX = 5
-
-# In case all classes in a project start with a common prefix, all classes will
-# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag
-# can be used to specify a prefix (or a list of prefixes) that should be ignored
-# while generating the index headers.
-# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
-
-IGNORE_PREFIX =
-
-#---------------------------------------------------------------------------
-# Configuration options related to the HTML output
-#---------------------------------------------------------------------------
-
-# If the GENERATE_HTML tag is set to YES doxygen will generate HTML output.
-# The default value is: YES.
-
-GENERATE_HTML = NO
-
-# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a
-# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
-# it.
-# The default directory is: html.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_OUTPUT = html
-
-# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
-# generated HTML page (for example: .htm, .php, .asp).
-# The default value is: .html.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_FILE_EXTENSION = .html
-
-# The HTML_HEADER tag can be used to specify a user-defined HTML header file for
-# each generated HTML page. If the tag is left blank doxygen will generate a
-# standard header.
-#
-# To get valid HTML you need a header file that includes any scripts and style
-# sheets that doxygen needs, which depend on the configuration options used
-# (e.g. the setting GENERATE_TREEVIEW). It is highly recommended to start with a
-# default header using
-# doxygen -w html new_header.html new_footer.html new_stylesheet.css
-# YourConfigFile
-# and then modify the file new_header.html. See also section "Doxygen usage"
-# for information on how to generate the default header that doxygen normally
-# uses.
-# Note: The header is subject to change so you typically have to regenerate the
-# default header when upgrading to a newer version of doxygen. For a description
-# of the possible markers and block names see the documentation.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_HEADER =
-
-# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each
-# generated HTML page. If the tag is left blank doxygen will generate a standard
-# footer. See HTML_HEADER for more information on how to generate a default
-# footer and what special commands can be used inside the footer. See also
-# section "Doxygen usage" for information on how to generate the default footer
-# that doxygen normally uses.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_FOOTER =
-
-# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style
-# sheet that is used by each HTML page. It can be used to fine-tune the look of
-# the HTML output. If left blank doxygen will generate a default style sheet.
-# See also section "Doxygen usage" for information on how to generate the style
-# sheet that doxygen normally uses.
-# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as
-# it is more robust and this tag (HTML_STYLESHEET) will in the future become
-# obsolete.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_STYLESHEET =
-
-# The HTML_EXTRA_STYLESHEET tag can be used to specify an additional user-
-# defined cascading style sheet that is included after the standard style sheets
-# created by doxygen. Using this option one can overrule certain style aspects.
-# This is preferred over using HTML_STYLESHEET since it does not replace the
-# standard style sheet and is therefore more robust against future updates.
-# Doxygen will copy the style sheet file to the output directory. For an example
-# see the documentation.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_EXTRA_STYLESHEET =
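-# Illustration (hypothetical): a project-specific tweak could be layered on top
-# of the default style with
-#   HTML_EXTRA_STYLESHEET = doxygen_custom.css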
-
-# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or
-# other source files which should be copied to the HTML output directory. Note
-# that these files will be copied to the base HTML output directory. Use the
-# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these
-# files. In the HTML_STYLESHEET file, use the file name only. Also note that the
-# files will be copied as-is; there are no commands or markers available.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_EXTRA_FILES =
-
-# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
-# will adjust the colors in the stylesheet and background images according to
-# this color. Hue is specified as an angle on a colorwheel, see
-# http://en.wikipedia.org/wiki/Hue for more information. For instance the value
-# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300
-# purple, and 360 is red again.
-# Minimum value: 0, maximum value: 359, default value: 220.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_COLORSTYLE_HUE = 220
-
-# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
-# in the HTML output. For a value of 0 the output will use grayscales only. A
-# value of 255 will produce the most vivid colors.
-# Minimum value: 0, maximum value: 255, default value: 100.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_COLORSTYLE_SAT = 100
-
-# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
-# luminance component of the colors in the HTML output. Values below 100
-# gradually make the output lighter, whereas values above 100 make the output
-# darker. The value divided by 100 is the actual gamma applied, so 80 represents
-# a gamma of 0.8, the value 220 represents a gamma of 2.2, and 100 does not
-# change the gamma.
-# Minimum value: 40, maximum value: 240, default value: 80.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_COLORSTYLE_GAMMA = 80
-
-# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML
-# page will contain the date and time when the page was generated. Setting this
-# to NO can help when comparing the output of multiple runs.
-# The default value is: YES.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_TIMESTAMP = NO
-
-# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
-# documentation will contain sections that can be hidden and shown after the
-# page has loaded.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_DYNAMIC_SECTIONS = NO
-
-# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries
-# shown in the various tree structured indices initially; the user can expand
-# and collapse entries dynamically later on. Doxygen will expand the tree to
-# such a level that at most the specified number of entries are visible (unless
-# a fully collapsed tree already exceeds this amount). So setting the number of
-# entries to 1 will produce a fully collapsed tree by default. 0 is a special
-# value representing an infinite number of entries and will result in a fully
-# expanded tree by default.
-# Minimum value: 0, maximum value: 9999, default value: 100.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_INDEX_NUM_ENTRIES = 100
-
-# If the GENERATE_DOCSET tag is set to YES, additional index files will be
-# generated that can be used as input for Apple's Xcode 3 integrated development
-# environment (see: http://developer.apple.com/tools/xcode/), introduced with
-# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a
-# Makefile in the HTML output directory. Running make will produce the docset in
-# that directory and running make install will install the docset in
-# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at
-# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html
-# for more information.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_DOCSET = NO
-
-# This tag determines the name of the docset feed. A documentation feed provides
-# an umbrella under which multiple documentation sets from a single provider
-# (such as a company or product suite) can be grouped.
-# The default value is: Doxygen generated docs.
-# This tag requires that the tag GENERATE_DOCSET is set to YES.
-
-DOCSET_FEEDNAME = "Doxygen generated docs"
-
-# This tag specifies a string that should uniquely identify the documentation
-# set bundle. This should be a reverse domain-name style string, e.g.
-# com.mycompany.MyDocSet. Doxygen will append .docset to the name.
-# The default value is: org.doxygen.Project.
-# This tag requires that the tag GENERATE_DOCSET is set to YES.
-
-DOCSET_BUNDLE_ID = org.doxygen.Project
-
-# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify
-# the documentation publisher. This should be a reverse domain-name style
-# string, e.g. com.mycompany.MyDocSet.documentation.
-# The default value is: org.doxygen.Publisher.
-# This tag requires that the tag GENERATE_DOCSET is set to YES.
-
-DOCSET_PUBLISHER_ID = org.doxygen.Publisher
-
-# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher.
-# The default value is: Publisher.
-# This tag requires that the tag GENERATE_DOCSET is set to YES.
-
-DOCSET_PUBLISHER_NAME = Publisher
-
-# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three
-# additional HTML index files: index.hhp, index.hhc, and index.hhk. The
-# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop
-# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on
-# Windows.
-#
-# The HTML Help Workshop contains a compiler that can convert all HTML output
-# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML
-# files are now used as the Windows 98 help format, and will replace the old
-# Windows help format (.hlp) on all Windows platforms in the future. Compressed
-# HTML files also contain an index, a table of contents, and you can search for
-# words in the documentation. The HTML workshop also contains a viewer for
-# compressed HTML files.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_HTMLHELP = NO
-
-# The CHM_FILE tag can be used to specify the file name of the resulting .chm
-# file. You can add a path in front of the file if the result should not be
-# written to the html output directory.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-CHM_FILE =
-
-# The HHC_LOCATION tag can be used to specify the location (absolute path
-# including file name) of the HTML help compiler (hhc.exe). If non-empty
-# doxygen will try to run the HTML help compiler on the generated index.hhp.
-# The file has to be specified with full path.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-HHC_LOCATION =
-
-# The GENERATE_CHI flag controls if a separate .chi index file is generated
-# (YES) or whether it should be included in the master .chm file (NO).
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-GENERATE_CHI = NO
-
-# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc)
-# and project file content.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-CHM_INDEX_ENCODING =
-
-# The BINARY_TOC flag controls whether a binary table of contents is generated
-# (YES) or a normal table of contents (NO) in the .chm file.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-BINARY_TOC = NO
-
-# The TOC_EXPAND flag can be set to YES to add extra items for group members to
-# the table of contents of the HTML help documentation and to the tree view.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-TOC_EXPAND = NO
-
-# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and
-# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that
-# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help
-# (.qch) of the generated HTML documentation.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_QHP = NO
-
-# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify
-# the file name of the resulting .qch file. The path specified is relative to
-# the HTML output folder.
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QCH_FILE =
-
-# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help
-# Project output. For more information please see Qt Help Project / Namespace
-# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace).
-# The default value is: org.doxygen.Project.
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_NAMESPACE = org.doxygen.Project
-
-# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt
-# Help Project output. For more information please see Qt Help Project / Virtual
-# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual-
-# folders).
-# The default value is: doc.
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_VIRTUAL_FOLDER = doc
-
-# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom
-# filter to add. For more information please see Qt Help Project / Custom
-# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom-
-# filters).
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_CUST_FILTER_NAME =
-
-# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
-# custom filter to add. For more information please see Qt Help Project / Custom
-# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom-
-# filters).
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_CUST_FILTER_ATTRS =
-
-# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
-# project's filter section matches; see Qt Help Project / Filter Attributes
-# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes).
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_SECT_FILTER_ATTRS =
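-# Usage note (editorial illustration): with GENERATE_QHP enabled, the
-# generated .qhp file is typically compiled into a .qch archive with Qt's
-# tool, e.g.
-#   qhelpgenerator index.qhp -o index.qch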
-
-# The QHG_LOCATION tag can be used to specify the location of Qt's
-# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the
-# generated .qhp file.
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHG_LOCATION =
-
-# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be
-# generated, together with the HTML files, they form an Eclipse help plugin. To
-# install this plugin and make it available under the help contents menu in
-# Eclipse, the contents of the directory containing the HTML and XML files needs
-# to be copied into the plugins directory of eclipse. The name of the directory
-# within the plugins directory should be the same as the ECLIPSE_DOC_ID value.
-# After copying, Eclipse needs to be restarted before the help appears.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_ECLIPSEHELP = NO
-
-# A unique identifier for the Eclipse help plugin. When installing the plugin
-# the directory name containing the HTML and XML files should also have this
-# name. Each documentation set should have its own identifier.
-# The default value is: org.doxygen.Project.
-# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES.
-
-ECLIPSE_DOC_ID = org.doxygen.Project
-
-# If you want full control over the layout of the generated HTML pages it might
-# be necessary to disable the index and replace it with your own. The
-# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top
-# of each HTML page. A value of NO enables the index and the value YES disables
-# it. Since the tabs in the index contain the same information as the navigation
-# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-DISABLE_INDEX = NO
-
-# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index
-# structure should be generated to display hierarchical information. If the tag
-# value is set to YES, a side panel will be generated containing a tree-like
-# index structure (just like the one that is generated for HTML Help). For this
-# to work a browser that supports JavaScript, DHTML, CSS and frames is required
-# (i.e. any modern browser). Windows users are probably better off using the
-# HTML help feature. Via custom stylesheets (see HTML_EXTRA_STYLESHEET) one can
-# further fine-tune the look of the index. As an example, the default style
-# sheet generated by doxygen has an example that shows how to put an image at
-# the root of the tree instead of the PROJECT_NAME. Since the tree basically has
-# the same information as the tab index, you could consider setting
-# DISABLE_INDEX to YES when enabling this option.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_TREEVIEW = NO
-
-# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that
-# doxygen will group on one line in the generated HTML documentation.
-#
-# Note that a value of 0 will completely suppress the enum values from appearing
-# in the overview section.
-# Minimum value: 0, maximum value: 20, default value: 4.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-ENUM_VALUES_PER_LINE = 4
-
-# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used
-# to set the initial width (in pixels) of the frame in which the tree is shown.
-# Minimum value: 0, maximum value: 1500, default value: 250.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-TREEVIEW_WIDTH = 250
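-# Side note (editorial illustration): since the tree view and the tab index
-# carry the same information, a tree-based layout is commonly configured as
-#   GENERATE_TREEVIEW = YES
-#   DISABLE_INDEX     = YES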
-
-# When the EXT_LINKS_IN_WINDOW option is set to YES doxygen will open links to
-# external symbols imported via tag files in a separate window.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-EXT_LINKS_IN_WINDOW = NO
-
-# Use this tag to change the font size of LaTeX formulas included as images in
-# the HTML documentation. When you change the font size after a successful
-# doxygen run you need to manually remove any form_*.png images from the HTML
-# output directory to force them to be regenerated.
-# Minimum value: 8, maximum value: 50, default value: 10.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-FORMULA_FONTSIZE = 10
-
-# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
-# generated for formulas are transparent PNGs. Transparent PNGs are not
-# supported properly for IE 6.0, but are supported on all modern browsers.
-#
-# Note that when changing this option you need to delete any form_*.png files in
-# the HTML output directory before the changes have effect.
-# The default value is: YES.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-FORMULA_TRANSPARENT = YES
-
-# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
-# http://www.mathjax.org) which uses client-side JavaScript for the rendering
-# instead of using prerendered bitmaps. Use this if you do not have LaTeX
-# installed or if you want formulas to look prettier in the HTML output. When
-# enabled you may also need to install MathJax separately and configure the path
-# to it using the MATHJAX_RELPATH option.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-USE_MATHJAX = NO
-
-# When MathJax is enabled you can set the default output format to be used for
-# the MathJax output. See the MathJax site (see:
-# http://docs.mathjax.org/en/latest/output.html) for more details.
-# Possible values are: HTML-CSS (which is slower, but has the best
-# compatibility), NativeMML (i.e. MathML) and SVG.
-# The default value is: HTML-CSS.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_FORMAT = HTML-CSS
-
-# When MathJax is enabled you need to specify the location relative to the HTML
-# output directory using the MATHJAX_RELPATH option. The destination directory
-# should contain the MathJax.js script. For instance, if the mathjax directory
-# is located at the same level as the HTML output directory, then
-# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
-# Content Delivery Network so you can quickly see the result without installing
-# MathJax. However, it is strongly recommended to install a local copy of
-# MathJax from http://www.mathjax.org before deployment.
-# The default value is: http://cdn.mathjax.org/mathjax/latest.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest
-
-# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
-# extension names that should be enabled during MathJax rendering. For example
-# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_EXTENSIONS =
-
-# The MATHJAX_CODEFILE tag can be used to specify a file with JavaScript pieces
-# of code that will be used on startup of the MathJax code. See the MathJax site
-# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
-# example see the documentation.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_CODEFILE =
-
-# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
-# the HTML output. The underlying search engine uses JavaScript and DHTML and
-# should work on any modern browser. Note that when using HTML help
-# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
-# there is already a search function so this one should typically be disabled.
-# For large projects the JavaScript based search engine can be slow, then
-# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to
-# search using the keyboard; to jump to the search box use <access key> + S
-# (what the <access key> is depends on the OS and browser, but it is typically
-# <CTRL>, <ALT>/