Check that a node's input and output types are all float before quantizing it.
Change: 143799698
This commit is contained in:
parent
088a5df5bb
commit
4b3d59a771
@ -696,6 +696,31 @@ Status QuantizeNodes(const GraphDef& input_graph_def,
|
||||
const NodeDef& float_node = match.node;
|
||||
const QuantizedOpInfo& op_info = op_map[float_node.op()];
|
||||
|
||||
DataTypeVector input_types;
|
||||
DataTypeVector output_types;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetInOutTypes(float_node, &input_types, &output_types));
|
||||
bool are_all_float = true;
|
||||
for (int i = 0; i < float_node.input_size(); ++i) {
|
||||
// Skip any known non-float inputs.
|
||||
if (op_info.unquantized_inputs.count(i)) {
|
||||
continue;
|
||||
}
|
||||
if (input_types[i] != DT_FLOAT) {
|
||||
are_all_float = false;
|
||||
}
|
||||
}
|
||||
for (const DataType& output_type : output_types) {
|
||||
if (output_type != DT_FLOAT) {
|
||||
are_all_float = false;
|
||||
}
|
||||
}
|
||||
// This isn't a float op, so don't quantize it.
|
||||
if (!are_all_float) {
|
||||
CopyOriginalMatch(match, new_nodes);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string namespace_prefix = float_node.name() + "_eightbit";
|
||||
|
||||
// Quantize all of the inputs.
|
||||
|
@ -1232,6 +1232,52 @@ class QuantizeNodesTest : public ::testing::Test {
|
||||
EXPECT_EQ("add_op", node_map["mul_op1"]->input(0));
|
||||
EXPECT_EQ("c_op", node_map["mul_op1"]->input(1));
|
||||
}
|
||||
|
||||
void TestExcludeNonFloat() {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
Tensor int_constant_tensor(DT_INT32, TensorShape({4, 5}));
|
||||
test::FillIota<int32>(&int_constant_tensor, 1);
|
||||
Output int_constant = Const(root.WithOpName("int_constant"),
|
||||
Input::Initializer(int_constant_tensor));
|
||||
|
||||
Tensor float_constant_tensor(DT_FLOAT, TensorShape({4, 5}));
|
||||
test::FillIota<float>(&float_constant_tensor, 2.0f);
|
||||
Output float_constant = Const(root.WithOpName("float_constant"),
|
||||
Input::Initializer(float_constant_tensor));
|
||||
|
||||
Output excluded_reshape_op =
|
||||
Reshape(root.WithOpName("excluded_reshape_op"), int_constant, {10, 2});
|
||||
|
||||
Output included_reshape_op = Reshape(root.WithOpName("included_reshape_op"),
|
||||
float_constant, {10, 2});
|
||||
|
||||
Output excluded_relu_op =
|
||||
Relu(root.WithOpName("excluded_relu_op"), excluded_reshape_op);
|
||||
|
||||
Output excluded_float_caster = Cast(
|
||||
root.WithOpName("excluded_float_caster"), excluded_relu_op, DT_FLOAT);
|
||||
|
||||
Output included_relu_op =
|
||||
Relu(root.WithOpName("included_relu_op"), included_reshape_op);
|
||||
|
||||
GraphDef float_graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
|
||||
|
||||
GraphDef quantized_graph_def;
|
||||
TestTransformedVersusFloatGraph(
|
||||
QuantizeNodes, float_graph_def, {}, {},
|
||||
{"excluded_float_caster", "included_relu_op"}, {}, 1.0,
|
||||
&quantized_graph_def);
|
||||
|
||||
std::map<string, const NodeDef*> node_map;
|
||||
MapNamesToNodes(quantized_graph_def, &node_map);
|
||||
ASSERT_EQ(1, node_map.count("excluded_reshape_op"));
|
||||
EXPECT_EQ("Reshape", node_map.at("excluded_reshape_op")->op());
|
||||
ASSERT_EQ(1, node_map.count("included_reshape_op"));
|
||||
EXPECT_EQ("Dequantize", node_map.at("included_reshape_op")->op());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(QuantizeNodesTest, TestQuantizeMatMulTiny) { TestQuantizeMatMulTiny(); }
|
||||
@ -1317,5 +1363,7 @@ TEST_F(QuantizeNodesTest, TestMergeDuplicateInOut) {
|
||||
TestMergeDuplicatesInOut();
|
||||
}
|
||||
|
||||
TEST_F(QuantizeNodesTest, TestExcludeNonFloat) { TestExcludeNonFloat(); }
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
@ -573,6 +575,14 @@ Status IsGraphValid(const GraphDef& graph_def) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reports the declared input and output data types for the given node, by
// looking up its op in the global registry and resolving any polymorphic
// type attrs against this particular NodeDef.
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
                     DataTypeVector* outputs) {
  const OpDef* registered_op_def;
  TF_RETURN_IF_ERROR(
      OpRegistry::Global()->LookUpOpDef(node_def.op(), &registered_op_def));
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(node_def, *registered_op_def, inputs, outputs));
  return Status::OK();
}
|
||||
|
||||
int CountParameters(const TransformFuncContext& context, const string& name) {
|
||||
if (context.params.count(name)) {
|
||||
return context.params.at(name).size();
|
||||
|
@ -124,6 +124,10 @@ void FindInvalidInputs(const GraphDef& graph_def,
|
||||
// graph.
|
||||
Status IsGraphValid(const GraphDef& graph_def);
|
||||
|
||||
// Returns input and output types for a particular NodeDef.
|
||||
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
||||
DataTypeVector* outputs);
|
||||
|
||||
// This is used to spot particular subgraphs in a larger model. To use it,
|
||||
// create a pattern like:
|
||||
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
|
||||
|
@ -747,6 +747,71 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
EXPECT_TRUE(IsGraphValid(valid_graph_def).ok());
|
||||
}
|
||||
|
||||
void TestGetInOutTypes() {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
const int width = 20;
|
||||
|
||||
Tensor float_data(DT_FLOAT, TensorShape({width}));
|
||||
test::FillIota<float>(&float_data, 1.0f);
|
||||
Output float_const =
|
||||
Const(root.WithOpName("float_const"), Input::Initializer(float_data));
|
||||
|
||||
Tensor int_data(DT_INT32, TensorShape({width}));
|
||||
test::FillIota<int32>(&int_data, 1);
|
||||
Output int_const =
|
||||
Const(root.WithOpName("int_const"), Input::Initializer(int_data));
|
||||
|
||||
Output float_relu = Relu(root.WithOpName("float_relu"), float_const);
|
||||
|
||||
Output int_relu = Relu(root.WithOpName("int_relu"), int_const);
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
std::map<string, const NodeDef*> node_map;
|
||||
MapNamesToNodes(graph_def, &node_map);
|
||||
|
||||
const NodeDef* float_const_def = node_map.at("float_const");
|
||||
DataTypeVector float_const_inputs;
|
||||
DataTypeVector float_const_outputs;
|
||||
TF_EXPECT_OK(GetInOutTypes(*float_const_def, &float_const_inputs,
|
||||
&float_const_outputs));
|
||||
ASSERT_EQ(0, float_const_inputs.size());
|
||||
ASSERT_EQ(1, float_const_outputs.size());
|
||||
EXPECT_EQ(DT_FLOAT, float_const_outputs[0]);
|
||||
|
||||
const NodeDef* int_const_def = node_map.at("int_const");
|
||||
DataTypeVector int_const_inputs;
|
||||
DataTypeVector int_const_outputs;
|
||||
TF_EXPECT_OK(
|
||||
GetInOutTypes(*int_const_def, &int_const_inputs, &int_const_outputs));
|
||||
ASSERT_EQ(0, int_const_inputs.size());
|
||||
ASSERT_EQ(1, int_const_outputs.size());
|
||||
EXPECT_EQ(DT_INT32, int_const_outputs[0]);
|
||||
|
||||
const NodeDef* float_relu_def = node_map.at("float_relu");
|
||||
DataTypeVector float_relu_inputs;
|
||||
DataTypeVector float_relu_outputs;
|
||||
TF_EXPECT_OK(GetInOutTypes(*float_relu_def, &float_relu_inputs,
|
||||
&float_relu_outputs));
|
||||
ASSERT_EQ(1, float_relu_inputs.size());
|
||||
EXPECT_EQ(DT_FLOAT, float_relu_inputs[0]);
|
||||
ASSERT_EQ(1, float_relu_outputs.size());
|
||||
EXPECT_EQ(DT_FLOAT, float_relu_outputs[0]);
|
||||
|
||||
const NodeDef* int_relu_def = node_map.at("int_relu");
|
||||
DataTypeVector int_relu_inputs;
|
||||
DataTypeVector int_relu_outputs;
|
||||
TF_EXPECT_OK(
|
||||
GetInOutTypes(*int_relu_def, &int_relu_inputs, &int_relu_outputs));
|
||||
ASSERT_EQ(1, int_relu_inputs.size());
|
||||
EXPECT_EQ(DT_INT32, int_relu_inputs[0]);
|
||||
ASSERT_EQ(1, int_relu_outputs.size());
|
||||
EXPECT_EQ(DT_INT32, int_relu_outputs[0]);
|
||||
}
|
||||
|
||||
void TestCopyOriginalMatch() {
|
||||
NodeDef a;
|
||||
a.set_op("Relu");
|
||||
@ -939,6 +1004,8 @@ TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
|
||||
|
||||
TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
|
||||
|
||||
TEST_F(TransformUtilsTest, TestGetInOutTypes) { TestGetInOutTypes(); }
|
||||
|
||||
TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); }
|
||||
|
||||
TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }
|
||||
|
Loading…
Reference in New Issue
Block a user