Introduce a fake_quantize_training graph transform.
PiperOrigin-RevId: 162705909
This commit is contained in:
parent
78e7cffa71
commit
602632b5bc
@ -653,28 +653,38 @@ Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rewrites `input_graphdef` for quantized training and stores the rewritten
// graph in `result_graphdef`.
//
// The GraphDef is converted to a Graph, DoQuantizeTraining() is applied to that
// Graph with the given `num_bits` and `quant_op_type`, and the result is
// converted back to a GraphDef. Returns the first non-OK Status produced by the
// conversion or by the rewrite; returns OK otherwise.
Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
                                    int32 num_bits, const string& quant_op_type,
                                    GraphDef* result_graphdef) {
  // Build an in-memory Graph from the caller's GraphDef. The input is already
  // a parsed proto, so nothing is deserialized here.
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));

  // Call the rewriter on the graph.
  TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));

  // Convert the result graph back to a GraphDef, writing directly into the
  // caller's output proto.
  graph.ToGraphDef(result_graphdef);
  return Status::OK();
}
|
||||
|
||||
if (!output_graphdef.SerializeToString(result_graph)) {
|
||||
return errors::InvalidArgument("Invalid output graph");
|
||||
Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
|
||||
int32 num_bits,
|
||||
const string& quant_op_type,
|
||||
string* result_graph_string) {
|
||||
// First create the graph from the GraphDef.
|
||||
GraphDef input_graphdef;
|
||||
if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
|
||||
return errors::InvalidArgument(
|
||||
"input_graph_string is not a serialized GraphDef protocol buffer");
|
||||
}
|
||||
GraphDef output_graphdef;
|
||||
TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
|
||||
input_graphdef, num_bits, quant_op_type, &output_graphdef));
|
||||
|
||||
if (!output_graphdef.SerializeToString(result_graph_string)) {
|
||||
return errors::Internal(
|
||||
"quantize training transformation resulted in invalid GraphDef");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -38,12 +38,19 @@ namespace tensorflow {
|
||||
Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
|
||||
Graph* g);
|
||||
|
||||
// Converts a input GraphDef and returns a rewritten GraphDef with the
|
||||
// quantized training.
|
||||
// Converts the input serialized GraphDef and returns a rewritten serialized
|
||||
// GraphDef for quantized training.
|
||||
Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph,
|
||||
int32 num_bits,
|
||||
const string& quant_op_type,
|
||||
string* result_graph);
|
||||
|
||||
// Converts the input GraphDef and returns a rewritten GraphDef for quantized
// training.
|
||||
// `num_bits` and `quant_op_type` are forwarded to DoQuantizeTraining(); the
// rewritten graph is written to `result_graphdef`. A non-OK Status is returned
// if the GraphDef cannot be converted to a Graph or if the rewrite fails.
|
||||
Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
|
||||
int32 num_bits, const string& quant_op_type,
|
||||
GraphDef* result_graphdef);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
|
||||
|
@ -282,7 +282,7 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) {
|
||||
g, strings::StrCat(c->name(), "/FakeQuantWithMinMaxVars"), &found_node));
|
||||
}
|
||||
|
||||
TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
|
||||
TEST_F(QuantizeTrainingTest, QuantizeSerializedGraphDef) {
|
||||
// Construct a simple graph with 5 nodes.
|
||||
Reset();
|
||||
Graph* graph = g_.get();
|
||||
@ -310,8 +310,40 @@ TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
|
||||
GraphDef result_graphdef;
|
||||
EXPECT_TRUE(ParseProtoUnlimited(&result_graphdef, result_string));
|
||||
|
||||
// Ensure that quantizing the serialized graph_def results in a graph with the
|
||||
// same number of nodes as quantizing the graph.
|
||||
GraphConstructorOptions opts;
|
||||
Graph result_graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));
|
||||
TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", graph));
|
||||
EXPECT_EQ(graph->num_nodes(), result_graph.num_nodes());
|
||||
}
|
||||
|
||||
TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
|
||||
// Construct a simple graph with 5 nodes.
|
||||
Reset();
|
||||
Graph* graph = g_.get();
|
||||
Node* const_a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
|
||||
Node* const_b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
|
||||
graph->AddControlEdge(graph->source_node(), const_a);
|
||||
graph->AddControlEdge(graph->source_node(), const_b);
|
||||
Node* relu = test::graph::Relu(graph, const_a);
|
||||
Node* identity = test::graph::Identity(graph, const_b);
|
||||
Node* matmul = test::graph::Matmul(graph, relu, identity, false, false);
|
||||
graph->AddControlEdge(matmul, graph->sink_node());
|
||||
|
||||
int num_bits = 8;
|
||||
|
||||
// Convert the graph to the graphdef string.
|
||||
GraphDef input_graphdef;
|
||||
graph->ToGraphDef(&input_graphdef);
|
||||
|
||||
GraphDef result_graphdef;
|
||||
TF_ASSERT_OK(DoQuantizeTrainingOnGraphDef(
|
||||
input_graphdef, num_bits, "QuantizeAndDequantizeV2", &result_graphdef));
|
||||
|
||||
// Ensure that quantizing the graph_def results in a graph with the same
|
||||
// number of nodes.
|
||||
// number of nodes as the graph_def.
|
||||
GraphConstructorOptions opts;
|
||||
Graph result_graph(OpRegistry::Global());
|
||||
TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));
|
||||
|
@ -60,6 +60,7 @@ cc_library(
|
||||
srcs = [
|
||||
"add_default_attributes.cc",
|
||||
"backports.cc",
|
||||
"fake_quantize_training.cc",
|
||||
"fold_batch_norms.cc",
|
||||
"fold_constants_lib.cc",
|
||||
"fold_old_batch_norms.cc",
|
||||
@ -109,6 +110,7 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"add_default_attributes_test.cc",
|
||||
"backports_test.cc",
|
||||
"fake_quantize_training_test.cc",
|
||||
"fold_batch_norms_test.cc",
|
||||
"fold_constants_test.cc",
|
||||
"fold_old_batch_norms_test.cc",
|
||||
|
50
tensorflow/tools/graph_transforms/fake_quantize_training.cc
Normal file
50
tensorflow/tools/graph_transforms/fake_quantize_training.cc
Normal file
@ -0,0 +1,50 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/graph/quantize_training.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Rewrites the GraphDef for quantized training.
|
||||
// Rewrites the forward pass to include the precision loss with quantization so
|
||||
// the model can learn to deal with such loss and achieve better accuracy when
|
||||
// it is quantized later for inference.
|
||||
// Quantization range information is collected in FakeQuantWithMinMaxVars
|
||||
// ops.
|
||||
//
|
||||
// TODO(suharshs): Provide instructions on converting the resulting graph for
|
||||
// inference.
|
||||
// TODO(suharshs): Implement this using the GTT rather than calling the old
|
||||
// prototype function.
|
||||
Status FakeQuantizeTraining(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  // Both settings are fixed for now; `context` is not consulted.
  // TODO(suharshs): Make num_bits a parameter.
  // TODO(suharshs): Make quantization op a parameter?
  const int32 fixed_num_bits = 8;
  const string fixed_quant_op = "FakeQuantWithMinMaxVars";

  // Delegate to the GraphDef-level rewrite and hand its Status straight back
  // to the caller.
  Status rewrite_status = DoQuantizeTrainingOnGraphDef(
      input_graph_def, fixed_num_bits, fixed_quant_op, output_graph_def);
  return rewrite_status;
}
|
||||
|
||||
// Associates the FakeQuantizeTraining function with the transform name
// "fake_quantize_training".
// NOTE(review): REGISTER_GRAPH_TRANSFORM is defined outside this file
// (presumably in transform_utils.h) — confirm its registration semantics there.
REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -0,0 +1,63 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Forward declaration of the transform under test. Declare here, so we don't
// need a public header.
|
||||
Status FakeQuantizeTraining(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
// Empty fixture; it only provides the test-suite name for the TEST_F cases
// in this file.
class FakeQuantizeTrainingTest : public ::testing::Test {};
|
||||
|
||||
// For now, since the fake_quantize_training transform just calls the
|
||||
// quantize_training rewrite from tensorflow/core/graph/quantize_training.h,
|
||||
// we just test that the graph has been changed by the transform.
|
||||
// TODO(suharshs): Once we implement the fake_quantize_training transform
|
||||
// using the GTT, write proper tests of the transform here.
|
||||
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
Tensor a_data(DT_FLOAT, TensorShape());
|
||||
test::FillIota<float>(&a_data, 1.0f);
|
||||
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||
|
||||
Tensor b_data(DT_FLOAT, TensorShape());
|
||||
test::FillIota<float>(&b_data, 1.0f);
|
||||
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||
|
||||
Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const);
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
GraphDef result;
|
||||
TransformFuncContext context;
|
||||
TF_ASSERT_OK(FakeQuantizeTraining(graph_def, context, &result));
|
||||
|
||||
// Test that the transformation resulted in a graph with more nodes.
|
||||
EXPECT_GT(result.node_size(), graph_def.node_size());
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user