From 4b3d59a771252506cc34e66ebf2cd93be2564229 Mon Sep 17 00:00:00 2001
From: Pete Warden
Date: Fri, 6 Jan 2017 12:07:05 -0800
Subject: [PATCH] Check node input and output types are float before
 quantizing

Change: 143799698
---
 .../tools/graph_transforms/quantize_nodes.cc  | 25 +++++++
 .../graph_transforms/quantize_nodes_test.cc   | 48 +++++++++++++
 .../tools/graph_transforms/transform_utils.cc | 10 +++
 .../tools/graph_transforms/transform_utils.h  |  4 ++
 .../graph_transforms/transform_utils_test.cc  | 67 +++++++++++++++++++
 5 files changed, 154 insertions(+)

diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc
index 8b0393049ac..22ed2b669e7 100644
--- a/tensorflow/tools/graph_transforms/quantize_nodes.cc
+++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc
@@ -696,6 +696,31 @@ Status QuantizeNodes(const GraphDef& input_graph_def,
             const NodeDef& float_node = match.node;
             const QuantizedOpInfo& op_info = op_map[float_node.op()];
 
+            DataTypeVector input_types;
+            DataTypeVector output_types;
+            TF_RETURN_IF_ERROR(
+                GetInOutTypes(float_node, &input_types, &output_types));
+            bool are_all_float = true;
+            for (int i = 0; i < float_node.input_size(); ++i) {
+              // Skip any known non-float inputs.
+              if (op_info.unquantized_inputs.count(i)) {
+                continue;
+              }
+              if (input_types[i] != DT_FLOAT) {
+                are_all_float = false;
+              }
+            }
+            for (const DataType& output_type : output_types) {
+              if (output_type != DT_FLOAT) {
+                are_all_float = false;
+              }
+            }
+            // This isn't a float op, so don't quantize it.
+            if (!are_all_float) {
+              CopyOriginalMatch(match, new_nodes);
+              return Status::OK();
+            }
+
             string namespace_prefix = float_node.name() + "_eightbit";
 
             // Quantize all of the inputs.
diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc
index a82bf781fc6..222aa0b99dc 100644
--- a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc
+++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc
@@ -1232,6 +1232,52 @@ class QuantizeNodesTest : public ::testing::Test {
     EXPECT_EQ("add_op", node_map["mul_op1"]->input(0));
     EXPECT_EQ("c_op", node_map["mul_op1"]->input(1));
   }
+
+  void TestExcludeNonFloat() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+
+    Tensor int_constant_tensor(DT_INT32, TensorShape({4, 5}));
+    test::FillIota<int32>(&int_constant_tensor, 1);
+    Output int_constant = Const(root.WithOpName("int_constant"),
+                                Input::Initializer(int_constant_tensor));
+
+    Tensor float_constant_tensor(DT_FLOAT, TensorShape({4, 5}));
+    test::FillIota<float>(&float_constant_tensor, 2.0f);
+    Output float_constant = Const(root.WithOpName("float_constant"),
+                                  Input::Initializer(float_constant_tensor));
+
+    Output excluded_reshape_op =
+        Reshape(root.WithOpName("excluded_reshape_op"), int_constant, {10, 2});
+
+    Output included_reshape_op = Reshape(root.WithOpName("included_reshape_op"),
+                                         float_constant, {10, 2});
+
+    Output excluded_relu_op =
+        Relu(root.WithOpName("excluded_relu_op"), excluded_reshape_op);
+
+    Output excluded_float_caster = Cast(
+        root.WithOpName("excluded_float_caster"), excluded_relu_op, DT_FLOAT);
+
+    Output included_relu_op =
+        Relu(root.WithOpName("included_relu_op"), included_reshape_op);
+
+    GraphDef float_graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
+
+    GraphDef quantized_graph_def;
+    TestTransformedVersusFloatGraph(
+        QuantizeNodes, float_graph_def, {}, {},
+        {"excluded_float_caster", "included_relu_op"}, {}, 1.0,
+        &quantized_graph_def);
+
+    std::map<string, const NodeDef*> node_map;
+    MapNamesToNodes(quantized_graph_def, &node_map);
+    ASSERT_EQ(1, node_map.count("excluded_reshape_op"));
+    EXPECT_EQ("Reshape", node_map.at("excluded_reshape_op")->op());
+    ASSERT_EQ(1, node_map.count("included_reshape_op"));
+    EXPECT_EQ("Dequantize", node_map.at("included_reshape_op")->op());
+  }
 };
 
 TEST_F(QuantizeNodesTest, TestQuantizeMatMulTiny) { TestQuantizeMatMulTiny(); }
@@ -1317,5 +1363,7 @@ TEST_F(QuantizeNodesTest, TestMergeDuplicateInOut) {
   TestMergeDuplicatesInOut();
 }
 
+TEST_F(QuantizeNodesTest, TestExcludeNonFloat) { TestExcludeNonFloat(); }
+
 }  // namespace graph_transforms
 }  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index 0a0b0f01a5f..72bd7f03836 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/tools/graph_transforms/transform_utils.h"
 
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/public/session.h"
@@ -573,6 +575,14 @@ Status IsGraphValid(const GraphDef& graph_def) {
   return Status::OK();
 }
 
+Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
+                     DataTypeVector* outputs) {
+  const OpDef* op_def;
+  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
+  TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs));
+  return Status::OK();
+}
+
 int CountParameters(const TransformFuncContext& context, const string& name) {
   if (context.params.count(name)) {
     return context.params.at(name).size();
diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h
index 7672011d6cf..f87d8326ef5 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.h
+++ b/tensorflow/tools/graph_transforms/transform_utils.h
@@ -124,6 +124,10 @@ void FindInvalidInputs(const GraphDef& graph_def,
 // graph.
 Status IsGraphValid(const GraphDef& graph_def);
 
+// Returns input and output types for a particular NodeDef.
+Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
+                     DataTypeVector* outputs);
+
 // This is used to spot particular subgraphs in a larger model. To use it,
 // create a pattern like:
 // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc
index 3e9f661f672..85074b9bc33 100644
--- a/tensorflow/tools/graph_transforms/transform_utils_test.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc
@@ -747,6 +747,71 @@ class TransformUtilsTest : public ::testing::Test {
     EXPECT_TRUE(IsGraphValid(valid_graph_def).ok());
   }
 
+  void TestGetInOutTypes() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+
+    const int width = 20;
+
+    Tensor float_data(DT_FLOAT, TensorShape({width}));
+    test::FillIota<float>(&float_data, 1.0f);
+    Output float_const =
+        Const(root.WithOpName("float_const"), Input::Initializer(float_data));
+
+    Tensor int_data(DT_INT32, TensorShape({width}));
+    test::FillIota<int32>(&int_data, 1);
+    Output int_const =
+        Const(root.WithOpName("int_const"), Input::Initializer(int_data));
+
+    Output float_relu = Relu(root.WithOpName("float_relu"), float_const);
+
+    Output int_relu = Relu(root.WithOpName("int_relu"), int_const);
+
+    GraphDef graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+    std::map<string, const NodeDef*> node_map;
+    MapNamesToNodes(graph_def, &node_map);
+
+    const NodeDef* float_const_def = node_map.at("float_const");
+    DataTypeVector float_const_inputs;
+    DataTypeVector float_const_outputs;
+    TF_EXPECT_OK(GetInOutTypes(*float_const_def, &float_const_inputs,
+                               &float_const_outputs));
+    ASSERT_EQ(0, float_const_inputs.size());
+    ASSERT_EQ(1, float_const_outputs.size());
+    EXPECT_EQ(DT_FLOAT, float_const_outputs[0]);
+
+    const NodeDef* int_const_def = node_map.at("int_const");
+    DataTypeVector int_const_inputs;
+    DataTypeVector int_const_outputs;
+    TF_EXPECT_OK(
+        GetInOutTypes(*int_const_def, &int_const_inputs, &int_const_outputs));
+    ASSERT_EQ(0, int_const_inputs.size());
+    ASSERT_EQ(1, int_const_outputs.size());
+    EXPECT_EQ(DT_INT32, int_const_outputs[0]);
+
+    const NodeDef* float_relu_def = node_map.at("float_relu");
+    DataTypeVector float_relu_inputs;
+    DataTypeVector float_relu_outputs;
+    TF_EXPECT_OK(GetInOutTypes(*float_relu_def, &float_relu_inputs,
+                               &float_relu_outputs));
+    ASSERT_EQ(1, float_relu_inputs.size());
+    EXPECT_EQ(DT_FLOAT, float_relu_inputs[0]);
+    ASSERT_EQ(1, float_relu_outputs.size());
+    EXPECT_EQ(DT_FLOAT, float_relu_outputs[0]);
+
+    const NodeDef* int_relu_def = node_map.at("int_relu");
+    DataTypeVector int_relu_inputs;
+    DataTypeVector int_relu_outputs;
+    TF_EXPECT_OK(
+        GetInOutTypes(*int_relu_def, &int_relu_inputs, &int_relu_outputs));
+    ASSERT_EQ(1, int_relu_inputs.size());
+    EXPECT_EQ(DT_INT32, int_relu_inputs[0]);
+    ASSERT_EQ(1, int_relu_outputs.size());
+    EXPECT_EQ(DT_INT32, int_relu_outputs[0]);
+  }
+
   void TestCopyOriginalMatch() {
     NodeDef a;
     a.set_op("Relu");
@@ -939,6 +1004,8 @@ TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
 
 TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
 
+TEST_F(TransformUtilsTest, TestGetInOutTypes) { TestGetInOutTypes(); }
+
 TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); }
 
 TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }
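
For context on how the new helper composes: the check added in quantize_nodes.cc boils down to a small predicate over GetInOutTypes. The following is a minimal sketch, not part of the patch; AllPortsAreFloat is a hypothetical name, and unlike the patch it does not exempt the op's unquantized_inputs:

    // Sketch: returns true only if every input and output of `node` is
    // DT_FLOAT, i.e. the node is a candidate for eight-bit quantization.
    // Assumes transform_utils.h is included and this lives inside
    // namespace tensorflow { namespace graph_transforms {.
    bool AllPortsAreFloat(const NodeDef& node) {
      DataTypeVector input_types;
      DataTypeVector output_types;
      // GetInOutTypes looks up the node's OpDef in the global op registry and
      // expands the concrete DataType of each input and output port.
      if (!GetInOutTypes(node, &input_types, &output_types).ok()) {
        return false;
      }
      for (const DataType& type : input_types) {
        if (type != DT_FLOAT) return false;
      }
      for (const DataType& type : output_types) {
        if (type != DT_FLOAT) return false;
      }
      return true;
    }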
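
The end-to-end behavior that TestExcludeNonFloat exercises can also be driven directly. A minimal sketch, assuming a frozen GraphDef has already been loaded, that "input" and "softmax" stand in for the graph's real input/output node names, and that TransformFuncContext exposes the input_names/output_names fields declared in transform_utils.h:

    // Sketch: run the QuantizeNodes transform over a float graph. With this
    // change, any matched node with a non-DT_FLOAT input or output (such as
    // an integer Reshape) is copied through unchanged rather than quantized.
    GraphDef float_graph_def;    // assumed loaded elsewhere, e.g. from a frozen model
    GraphDef quantized_graph_def;
    TransformFuncContext context;
    context.input_names = {"input"};     // placeholder input node name
    context.output_names = {"softmax"};  // placeholder output node name
    TF_RETURN_IF_ERROR(
        QuantizeNodes(float_graph_def, context, &quantized_graph_def));

TF_RETURN_IF_ERROR assumes a Status-returning caller; otherwise inspect the returned Status directly.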