Check node input and output types are float before quantizing

Change: 143799698
Pete Warden authored on 2017-01-06 12:07:05 -08:00, committed by TensorFlower Gardener
parent 088a5df5bb
commit 4b3d59a771
5 changed files with 154 additions and 0 deletions
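The core of the change, distilled into a standalone sketch for readers skimming the diff below. This is not code from the commit: the function name AllTypesAreFloat and the simplified std::set<int> parameter are illustrative only (the real transform reads its skip list from a per-op QuantizedOpInfo table). The idea is to resolve a node's concrete input and output types from its registered OpDef and treat it as quantizable only when everything that is not an explicitly unquantized input is DT_FLOAT.

#include <set>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Illustrative helper: sets *is_float_op to true only when every input and
// output of `node` resolves to DT_FLOAT, ignoring the input indices listed in
// `unquantized_inputs` (for example the shape input of a Reshape).
Status AllTypesAreFloat(const NodeDef& node,
                        const std::set<int>& unquantized_inputs,
                        bool* is_float_op) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
  DataTypeVector input_types;
  DataTypeVector output_types;
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(node, *op_def, &input_types, &output_types));
  *is_float_op = true;
  for (int i = 0; i < node.input_size(); ++i) {
    if (unquantized_inputs.count(i)) {
      continue;  // Known non-float input (e.g. a shape), deliberately skipped.
    }
    if (input_types[i] != DT_FLOAT) {
      *is_float_op = false;
    }
  }
  for (const DataType& output_type : output_types) {
    if (output_type != DT_FLOAT) {
      *is_float_op = false;
    }
  }
  return Status::OK();
}

}  // namespace tensorflow

In the transform itself, nodes that fail this check are copied through unchanged via CopyOriginalMatch, as the first hunk below shows.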

@@ -696,6 +696,31 @@ Status QuantizeNodes(const GraphDef& input_graph_def,
        const NodeDef& float_node = match.node;
        const QuantizedOpInfo& op_info = op_map[float_node.op()];

        DataTypeVector input_types;
        DataTypeVector output_types;
        TF_RETURN_IF_ERROR(
            GetInOutTypes(float_node, &input_types, &output_types));
        bool are_all_float = true;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any known non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }
          if (input_types[i] != DT_FLOAT) {
            are_all_float = false;
          }
        }
        for (const DataType& output_type : output_types) {
          if (output_type != DT_FLOAT) {
            are_all_float = false;
          }
        }
        // This isn't a float op, so don't quantize it.
        if (!are_all_float) {
          CopyOriginalMatch(match, new_nodes);
          return Status::OK();
        }
        string namespace_prefix = float_node.name() + "_eightbit";
        // Quantize all of the inputs.

@@ -1232,6 +1232,52 @@ class QuantizeNodesTest : public ::testing::Test {
    EXPECT_EQ("add_op", node_map["mul_op1"]->input(0));
    EXPECT_EQ("c_op", node_map["mul_op1"]->input(1));
  }

  void TestExcludeNonFloat() {
    auto root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

    Tensor int_constant_tensor(DT_INT32, TensorShape({4, 5}));
    test::FillIota<int32>(&int_constant_tensor, 1);
    Output int_constant = Const(root.WithOpName("int_constant"),
                                Input::Initializer(int_constant_tensor));

    Tensor float_constant_tensor(DT_FLOAT, TensorShape({4, 5}));
    test::FillIota<float>(&float_constant_tensor, 2.0f);
    Output float_constant = Const(root.WithOpName("float_constant"),
                                  Input::Initializer(float_constant_tensor));

    Output excluded_reshape_op =
        Reshape(root.WithOpName("excluded_reshape_op"), int_constant, {10, 2});
    Output included_reshape_op = Reshape(root.WithOpName("included_reshape_op"),
                                         float_constant, {10, 2});
    Output excluded_relu_op =
        Relu(root.WithOpName("excluded_relu_op"), excluded_reshape_op);
    Output excluded_float_caster = Cast(
        root.WithOpName("excluded_float_caster"), excluded_relu_op, DT_FLOAT);
    Output included_relu_op =
        Relu(root.WithOpName("included_relu_op"), included_reshape_op);

    GraphDef float_graph_def;
    TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));

    GraphDef quantized_graph_def;
    TestTransformedVersusFloatGraph(
        QuantizeNodes, float_graph_def, {}, {},
        {"excluded_float_caster", "included_relu_op"}, {}, 1.0,
        &quantized_graph_def);

    std::map<string, const NodeDef*> node_map;
    MapNamesToNodes(quantized_graph_def, &node_map);
    ASSERT_EQ(1, node_map.count("excluded_reshape_op"));
    EXPECT_EQ("Reshape", node_map.at("excluded_reshape_op")->op());
    ASSERT_EQ(1, node_map.count("included_reshape_op"));
    EXPECT_EQ("Dequantize", node_map.at("included_reshape_op")->op());
  }
};
TEST_F(QuantizeNodesTest, TestQuantizeMatMulTiny) { TestQuantizeMatMulTiny(); }
@@ -1317,5 +1363,7 @@ TEST_F(QuantizeNodesTest, TestMergeDuplicateInOut) {
  TestMergeDuplicatesInOut();
}
TEST_F(QuantizeNodesTest, TestExcludeNonFloat) { TestExcludeNonFloat(); }
} // namespace graph_transforms
} // namespace tensorflow

@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/tools/graph_transforms/transform_utils.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/public/session.h"
@@ -573,6 +575,14 @@ Status IsGraphValid(const GraphDef& graph_def) {
  return Status::OK();
}

Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
                     DataTypeVector* outputs) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
  TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs));
  return Status::OK();
}

int CountParameters(const TransformFuncContext& context, const string& name) {
  if (context.params.count(name)) {
    return context.params.at(name).size();

@@ -124,6 +124,10 @@ void FindInvalidInputs(const GraphDef& graph_def,
// graph.
Status IsGraphValid(const GraphDef& graph_def);

// Returns input and output types for a particular NodeDef.
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
                     DataTypeVector* outputs);
// This is used to spot particular subgraphs in a larger model. To use it,
// create a pattern like:
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
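For reference, a minimal usage sketch of the GetInOutTypes helper declared in the hunk above. This is not part of the commit: LogNodeTypes is a hypothetical name, and the sketch simply walks a GraphDef and logs what each node's types resolve to.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// Hypothetical sweep: report how many inputs and outputs each node resolves
// to, and whether its first output is DT_FLOAT.
Status LogNodeTypes(const GraphDef& graph_def) {
  for (const NodeDef& node : graph_def.node()) {
    DataTypeVector input_types;
    DataTypeVector output_types;
    TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types));
    const bool float_output =
        !output_types.empty() && (output_types[0] == DT_FLOAT);
    LOG(INFO) << node.name() << ": " << input_types.size() << " inputs, "
              << output_types.size() << " outputs, float output: "
              << float_output;
  }
  return Status::OK();
}

}  // namespace graph_transforms
}  // namespace tensorflow

The test hunk that follows exercises the same helper directly, on Const and Relu nodes built with float and int32 data.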

@@ -747,6 +747,71 @@ class TransformUtilsTest : public ::testing::Test {
    EXPECT_TRUE(IsGraphValid(valid_graph_def).ok());
  }

  void TestGetInOutTypes() {
    auto root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

    const int width = 20;

    Tensor float_data(DT_FLOAT, TensorShape({width}));
    test::FillIota<float>(&float_data, 1.0f);
    Output float_const =
        Const(root.WithOpName("float_const"), Input::Initializer(float_data));

    Tensor int_data(DT_INT32, TensorShape({width}));
    test::FillIota<int32>(&int_data, 1);
    Output int_const =
        Const(root.WithOpName("int_const"), Input::Initializer(int_data));

    Output float_relu = Relu(root.WithOpName("float_relu"), float_const);
    Output int_relu = Relu(root.WithOpName("int_relu"), int_const);

    GraphDef graph_def;
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));

    std::map<string, const NodeDef*> node_map;
    MapNamesToNodes(graph_def, &node_map);

    const NodeDef* float_const_def = node_map.at("float_const");
    DataTypeVector float_const_inputs;
    DataTypeVector float_const_outputs;
    TF_EXPECT_OK(GetInOutTypes(*float_const_def, &float_const_inputs,
                               &float_const_outputs));
    ASSERT_EQ(0, float_const_inputs.size());
    ASSERT_EQ(1, float_const_outputs.size());
    EXPECT_EQ(DT_FLOAT, float_const_outputs[0]);

    const NodeDef* int_const_def = node_map.at("int_const");
    DataTypeVector int_const_inputs;
    DataTypeVector int_const_outputs;
    TF_EXPECT_OK(
        GetInOutTypes(*int_const_def, &int_const_inputs, &int_const_outputs));
    ASSERT_EQ(0, int_const_inputs.size());
    ASSERT_EQ(1, int_const_outputs.size());
    EXPECT_EQ(DT_INT32, int_const_outputs[0]);

    const NodeDef* float_relu_def = node_map.at("float_relu");
    DataTypeVector float_relu_inputs;
    DataTypeVector float_relu_outputs;
    TF_EXPECT_OK(GetInOutTypes(*float_relu_def, &float_relu_inputs,
                               &float_relu_outputs));
    ASSERT_EQ(1, float_relu_inputs.size());
    EXPECT_EQ(DT_FLOAT, float_relu_inputs[0]);
    ASSERT_EQ(1, float_relu_outputs.size());
    EXPECT_EQ(DT_FLOAT, float_relu_outputs[0]);

    const NodeDef* int_relu_def = node_map.at("int_relu");
    DataTypeVector int_relu_inputs;
    DataTypeVector int_relu_outputs;
    TF_EXPECT_OK(
        GetInOutTypes(*int_relu_def, &int_relu_inputs, &int_relu_outputs));
    ASSERT_EQ(1, int_relu_inputs.size());
    EXPECT_EQ(DT_INT32, int_relu_inputs[0]);
    ASSERT_EQ(1, int_relu_outputs.size());
    EXPECT_EQ(DT_INT32, int_relu_outputs[0]);
  }

  void TestCopyOriginalMatch() {
    NodeDef a;
    a.set_op("Relu");
@@ -939,6 +1004,8 @@ TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
TEST_F(TransformUtilsTest, TestGetInOutTypes) { TestGetInOutTypes(); }
TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); }
TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }