Check node input and output types are float before quantizing
Change: 143799698
This commit is contained in:
		
							parent
							
								
									088a5df5bb
								
							
						
					
					
						commit
						4b3d59a771
					
				| @ -696,6 +696,31 @@ Status QuantizeNodes(const GraphDef& input_graph_def, | |||||||
|         const NodeDef& float_node = match.node; |         const NodeDef& float_node = match.node; | ||||||
|         const QuantizedOpInfo& op_info = op_map[float_node.op()]; |         const QuantizedOpInfo& op_info = op_map[float_node.op()]; | ||||||
| 
 | 
 | ||||||
|  |         DataTypeVector input_types; | ||||||
|  |         DataTypeVector output_types; | ||||||
|  |         TF_RETURN_IF_ERROR( | ||||||
|  |             GetInOutTypes(float_node, &input_types, &output_types)); | ||||||
|  |         bool are_all_float = true; | ||||||
|  |         for (int i = 0; i < float_node.input_size(); ++i) { | ||||||
|  |           // Skip any known non-float inputs.
 | ||||||
|  |           if (op_info.unquantized_inputs.count(i)) { | ||||||
|  |             continue; | ||||||
|  |           } | ||||||
|  |           if (input_types[i] != DT_FLOAT) { | ||||||
|  |             are_all_float = false; | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |         for (const DataType& output_type : output_types) { | ||||||
|  |           if (output_type != DT_FLOAT) { | ||||||
|  |             are_all_float = false; | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |         // This isn't a float op, so don't quantize it.
 | ||||||
|  |         if (!are_all_float) { | ||||||
|  |           CopyOriginalMatch(match, new_nodes); | ||||||
|  |           return Status::OK(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         string namespace_prefix = float_node.name() + "_eightbit"; |         string namespace_prefix = float_node.name() + "_eightbit"; | ||||||
| 
 | 
 | ||||||
|         // Quantize all of the inputs.
 |         // Quantize all of the inputs.
 | ||||||
|  | |||||||
| @ -1232,6 +1232,52 @@ class QuantizeNodesTest : public ::testing::Test { | |||||||
|     EXPECT_EQ("add_op", node_map["mul_op1"]->input(0)); |     EXPECT_EQ("add_op", node_map["mul_op1"]->input(0)); | ||||||
|     EXPECT_EQ("c_op", node_map["mul_op1"]->input(1)); |     EXPECT_EQ("c_op", node_map["mul_op1"]->input(1)); | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   void TestExcludeNonFloat() { | ||||||
|  |     auto root = tensorflow::Scope::NewRootScope(); | ||||||
|  |     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
 | ||||||
|  | 
 | ||||||
|  |     Tensor int_constant_tensor(DT_INT32, TensorShape({4, 5})); | ||||||
|  |     test::FillIota<int32>(&int_constant_tensor, 1); | ||||||
|  |     Output int_constant = Const(root.WithOpName("int_constant"), | ||||||
|  |                                 Input::Initializer(int_constant_tensor)); | ||||||
|  | 
 | ||||||
|  |     Tensor float_constant_tensor(DT_FLOAT, TensorShape({4, 5})); | ||||||
|  |     test::FillIota<float>(&float_constant_tensor, 2.0f); | ||||||
|  |     Output float_constant = Const(root.WithOpName("float_constant"), | ||||||
|  |                                   Input::Initializer(float_constant_tensor)); | ||||||
|  | 
 | ||||||
|  |     Output excluded_reshape_op = | ||||||
|  |         Reshape(root.WithOpName("excluded_reshape_op"), int_constant, {10, 2}); | ||||||
|  | 
 | ||||||
|  |     Output included_reshape_op = Reshape(root.WithOpName("included_reshape_op"), | ||||||
|  |                                          float_constant, {10, 2}); | ||||||
|  | 
 | ||||||
|  |     Output excluded_relu_op = | ||||||
|  |         Relu(root.WithOpName("excluded_relu_op"), excluded_reshape_op); | ||||||
|  | 
 | ||||||
|  |     Output excluded_float_caster = Cast( | ||||||
|  |         root.WithOpName("excluded_float_caster"), excluded_relu_op, DT_FLOAT); | ||||||
|  | 
 | ||||||
|  |     Output included_relu_op = | ||||||
|  |         Relu(root.WithOpName("included_relu_op"), included_reshape_op); | ||||||
|  | 
 | ||||||
|  |     GraphDef float_graph_def; | ||||||
|  |     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); | ||||||
|  | 
 | ||||||
|  |     GraphDef quantized_graph_def; | ||||||
|  |     TestTransformedVersusFloatGraph( | ||||||
|  |         QuantizeNodes, float_graph_def, {}, {}, | ||||||
|  |         {"excluded_float_caster", "included_relu_op"}, {}, 1.0, | ||||||
|  |         &quantized_graph_def); | ||||||
|  | 
 | ||||||
|  |     std::map<string, const NodeDef*> node_map; | ||||||
|  |     MapNamesToNodes(quantized_graph_def, &node_map); | ||||||
|  |     ASSERT_EQ(1, node_map.count("excluded_reshape_op")); | ||||||
|  |     EXPECT_EQ("Reshape", node_map.at("excluded_reshape_op")->op()); | ||||||
|  |     ASSERT_EQ(1, node_map.count("included_reshape_op")); | ||||||
|  |     EXPECT_EQ("Dequantize", node_map.at("included_reshape_op")->op()); | ||||||
|  |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| TEST_F(QuantizeNodesTest, TestQuantizeMatMulTiny) { TestQuantizeMatMulTiny(); } | TEST_F(QuantizeNodesTest, TestQuantizeMatMulTiny) { TestQuantizeMatMulTiny(); } | ||||||
| @ -1317,5 +1363,7 @@ TEST_F(QuantizeNodesTest, TestMergeDuplicateInOut) { | |||||||
|   TestMergeDuplicatesInOut(); |   TestMergeDuplicatesInOut(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(QuantizeNodesTest, TestExcludeNonFloat) { TestExcludeNonFloat(); } | ||||||
|  | 
 | ||||||
| }  // namespace graph_transforms
 | }  // namespace graph_transforms
 | ||||||
| }  // namespace tensorflow
 | }  // namespace tensorflow
 | ||||||
|  | |||||||
| @ -15,6 +15,8 @@ limitations under the License. | |||||||
| 
 | 
 | ||||||
| #include "tensorflow/tools/graph_transforms/transform_utils.h" | #include "tensorflow/tools/graph_transforms/transform_utils.h" | ||||||
| 
 | 
 | ||||||
|  | #include "tensorflow/core/framework/node_def_util.h" | ||||||
|  | #include "tensorflow/core/framework/op.h" | ||||||
| #include "tensorflow/core/lib/hash/hash.h" | #include "tensorflow/core/lib/hash/hash.h" | ||||||
| #include "tensorflow/core/lib/strings/str_util.h" | #include "tensorflow/core/lib/strings/str_util.h" | ||||||
| #include "tensorflow/core/public/session.h" | #include "tensorflow/core/public/session.h" | ||||||
| @ -573,6 +575,14 @@ Status IsGraphValid(const GraphDef& graph_def) { | |||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, | ||||||
|  |                      DataTypeVector* outputs) { | ||||||
|  |   const OpDef* op_def; | ||||||
|  |   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); | ||||||
|  |   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs)); | ||||||
|  |   return Status::OK(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| int CountParameters(const TransformFuncContext& context, const string& name) { | int CountParameters(const TransformFuncContext& context, const string& name) { | ||||||
|   if (context.params.count(name)) { |   if (context.params.count(name)) { | ||||||
|     return context.params.at(name).size(); |     return context.params.at(name).size(); | ||||||
|  | |||||||
| @ -124,6 +124,10 @@ void FindInvalidInputs(const GraphDef& graph_def, | |||||||
| // graph.
 | // graph.
 | ||||||
| Status IsGraphValid(const GraphDef& graph_def); | Status IsGraphValid(const GraphDef& graph_def); | ||||||
| 
 | 
 | ||||||
|  | // Returns input and output types for a particular NodeDef.
 | ||||||
|  | Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, | ||||||
|  |                      DataTypeVector* outputs); | ||||||
|  | 
 | ||||||
| // This is used to spot particular subgraphs in a larger model. To use it,
 | // This is used to spot particular subgraphs in a larger model. To use it,
 | ||||||
| // create a pattern like:
 | // create a pattern like:
 | ||||||
| // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
 | // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
 | ||||||
|  | |||||||
| @ -747,6 +747,71 @@ class TransformUtilsTest : public ::testing::Test { | |||||||
|     EXPECT_TRUE(IsGraphValid(valid_graph_def).ok()); |     EXPECT_TRUE(IsGraphValid(valid_graph_def).ok()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   void TestGetInOutTypes() { | ||||||
|  |     auto root = tensorflow::Scope::NewRootScope(); | ||||||
|  |     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
 | ||||||
|  | 
 | ||||||
|  |     const int width = 20; | ||||||
|  | 
 | ||||||
|  |     Tensor float_data(DT_FLOAT, TensorShape({width})); | ||||||
|  |     test::FillIota<float>(&float_data, 1.0f); | ||||||
|  |     Output float_const = | ||||||
|  |         Const(root.WithOpName("float_const"), Input::Initializer(float_data)); | ||||||
|  | 
 | ||||||
|  |     Tensor int_data(DT_INT32, TensorShape({width})); | ||||||
|  |     test::FillIota<int32>(&int_data, 1); | ||||||
|  |     Output int_const = | ||||||
|  |         Const(root.WithOpName("int_const"), Input::Initializer(int_data)); | ||||||
|  | 
 | ||||||
|  |     Output float_relu = Relu(root.WithOpName("float_relu"), float_const); | ||||||
|  | 
 | ||||||
|  |     Output int_relu = Relu(root.WithOpName("int_relu"), int_const); | ||||||
|  | 
 | ||||||
|  |     GraphDef graph_def; | ||||||
|  |     TF_ASSERT_OK(root.ToGraphDef(&graph_def)); | ||||||
|  | 
 | ||||||
|  |     std::map<string, const NodeDef*> node_map; | ||||||
|  |     MapNamesToNodes(graph_def, &node_map); | ||||||
|  | 
 | ||||||
|  |     const NodeDef* float_const_def = node_map.at("float_const"); | ||||||
|  |     DataTypeVector float_const_inputs; | ||||||
|  |     DataTypeVector float_const_outputs; | ||||||
|  |     TF_EXPECT_OK(GetInOutTypes(*float_const_def, &float_const_inputs, | ||||||
|  |                                &float_const_outputs)); | ||||||
|  |     ASSERT_EQ(0, float_const_inputs.size()); | ||||||
|  |     ASSERT_EQ(1, float_const_outputs.size()); | ||||||
|  |     EXPECT_EQ(DT_FLOAT, float_const_outputs[0]); | ||||||
|  | 
 | ||||||
|  |     const NodeDef* int_const_def = node_map.at("int_const"); | ||||||
|  |     DataTypeVector int_const_inputs; | ||||||
|  |     DataTypeVector int_const_outputs; | ||||||
|  |     TF_EXPECT_OK( | ||||||
|  |         GetInOutTypes(*int_const_def, &int_const_inputs, &int_const_outputs)); | ||||||
|  |     ASSERT_EQ(0, int_const_inputs.size()); | ||||||
|  |     ASSERT_EQ(1, int_const_outputs.size()); | ||||||
|  |     EXPECT_EQ(DT_INT32, int_const_outputs[0]); | ||||||
|  | 
 | ||||||
|  |     const NodeDef* float_relu_def = node_map.at("float_relu"); | ||||||
|  |     DataTypeVector float_relu_inputs; | ||||||
|  |     DataTypeVector float_relu_outputs; | ||||||
|  |     TF_EXPECT_OK(GetInOutTypes(*float_relu_def, &float_relu_inputs, | ||||||
|  |                                &float_relu_outputs)); | ||||||
|  |     ASSERT_EQ(1, float_relu_inputs.size()); | ||||||
|  |     EXPECT_EQ(DT_FLOAT, float_relu_inputs[0]); | ||||||
|  |     ASSERT_EQ(1, float_relu_outputs.size()); | ||||||
|  |     EXPECT_EQ(DT_FLOAT, float_relu_outputs[0]); | ||||||
|  | 
 | ||||||
|  |     const NodeDef* int_relu_def = node_map.at("int_relu"); | ||||||
|  |     DataTypeVector int_relu_inputs; | ||||||
|  |     DataTypeVector int_relu_outputs; | ||||||
|  |     TF_EXPECT_OK( | ||||||
|  |         GetInOutTypes(*int_relu_def, &int_relu_inputs, &int_relu_outputs)); | ||||||
|  |     ASSERT_EQ(1, int_relu_inputs.size()); | ||||||
|  |     EXPECT_EQ(DT_INT32, int_relu_inputs[0]); | ||||||
|  |     ASSERT_EQ(1, int_relu_outputs.size()); | ||||||
|  |     EXPECT_EQ(DT_INT32, int_relu_outputs[0]); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   void TestCopyOriginalMatch() { |   void TestCopyOriginalMatch() { | ||||||
|     NodeDef a; |     NodeDef a; | ||||||
|     a.set_op("Relu"); |     a.set_op("Relu"); | ||||||
| @ -939,6 +1004,8 @@ TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); } | |||||||
| 
 | 
 | ||||||
| TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); } | TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(TransformUtilsTest, TestGetInOutTypes) { TestGetInOutTypes(); } | ||||||
|  | 
 | ||||||
| TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); } | TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); } | ||||||
| 
 | 
 | ||||||
| TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); } | TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user