Introduce a fake_quantize_training graph transform.

PiperOrigin-RevId: 162705909
This commit is contained in:
Suharsh Sivakumar 2017-07-20 22:13:58 -07:00 committed by TensorFlower Gardener
parent 78e7cffa71
commit 602632b5bc
6 changed files with 181 additions and 17 deletions

View File

@ -653,28 +653,38 @@ Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
return Status::OK();
}
// NOTE(review): This span is a unified-diff rendering — lines of the removed
// (pre-change) DoQuantizeTrainingOnSerializedGraphDef are interleaved with
// lines of the newly added DoQuantizeTrainingOnGraphDef, so it is not one
// coherent function. Code is kept byte-identical; comments only mark
// provenance.
//
// Old (pre-change) signature: took a serialized GraphDef string.
Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph,
int32 num_bits,
const string& quant_op_type,
string* result_graph) {
// First create the graph from the GraphDef.
// New entry point added by this change: operates on a GraphDef directly and
// writes the rewritten graph into result_graphdef.
Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
int32 num_bits, const string& quant_op_type,
GraphDef* result_graphdef) {
Graph graph(OpRegistry::Global());
GraphConstructorOptions opts;
// Old version: parsing of the serialized proto (removed by this change).
GraphDef input_graphdef;
if (!ParseProtoUnlimited(&input_graphdef, input_graph)) {
return errors::InvalidArgument("Invalid input graph");
}
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));
// Call the rewriter on the graph.
TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));
// Convert the result graph back to a GraphDef.
// Old version: wrote into a local output_graphdef (removed).
GraphDef output_graphdef;
graph.ToGraphDef(&output_graphdef);
// New version: writes straight into the caller-provided result_graphdef.
graph.ToGraphDef(result_graphdef);
return Status::OK();
}
// Old version's serialization of the result back to a string (removed):
if (!output_graphdef.SerializeToString(result_graph)) {
return errors::InvalidArgument("Invalid output graph");
// Quantization-training rewrite over a serialized GraphDef.
//
// Deserializes |input_graph_string|, applies DoQuantizeTrainingOnGraphDef
// with the given bit width and quantization op type, and serializes the
// rewritten graph into |result_graph_string|.
//
// Returns InvalidArgument if the input does not parse as a GraphDef, and
// Internal if the rewritten graph fails to serialize.
Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
                                              int32 num_bits,
                                              const string& quant_op_type,
                                              string* result_graph_string) {
  // Deserialize the incoming graph before handing it to the rewriter.
  GraphDef input_graphdef;
  const bool parsed =
      ParseProtoUnlimited(&input_graphdef, input_graph_string);
  if (!parsed) {
    return errors::InvalidArgument(
        "input_graph_string is not a serialized GraphDef protocol buffer");
  }

  // Run the GraphDef-level rewrite.
  GraphDef output_graphdef;
  TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
      input_graphdef, num_bits, quant_op_type, &output_graphdef));

  // Hand the rewritten graph back to the caller in serialized form.
  if (!output_graphdef.SerializeToString(result_graph_string)) {
    return errors::Internal(
        "quantize training transformation resulted in invalid GraphDef");
  }
  return Status::OK();
}

View File

@ -38,12 +38,19 @@ namespace tensorflow {
// Applies the quantization-for-training rewrite in place to |g|.
Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
Graph* g);
// (pre-change comment, superseded by the one below)
// Converts an input GraphDef and returns a rewritten GraphDef with the
// quantized training.
// Converts the input serialized GraphDef and returns a rewritten serialized
// GraphDef for quantized training.
Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph,
int32 num_bits,
const string& quant_op_type,
string* result_graph);
// Converts the input GraphDef and returns a rewritten GraphDef for quantized
// training.
Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
int32 num_bits, const string& quant_op_type,
GraphDef* result_graphdef);
} // namespace tensorflow
#endif // TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_

View File

@ -282,7 +282,7 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) {
g, strings::StrCat(c->name(), "/FakeQuantWithMinMaxVars"), &found_node));
}
// NOTE(review): Diff rendering — the old test name (first line below) was
// renamed to QuantizeSerializedGraphDef; the middle of the test body is
// elided by the hunk boundary, so this span is incomplete as shown.
TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
TEST_F(QuantizeTrainingTest, QuantizeSerializedGraphDef) {
// Construct a simple graph with 5 nodes.
Reset();
Graph* graph = g_.get();
@ -310,8 +310,40 @@ TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
// (continuation after elided lines) Parse the serialized rewrite result.
GraphDef result_graphdef;
EXPECT_TRUE(ParseProtoUnlimited(&result_graphdef, result_string));
// Ensure that quantizing the serialized graph_def results in a graph with the
// same number of nodes as quantizing the graph.
GraphConstructorOptions opts;
Graph result_graph(OpRegistry::Global());
TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));
// Rewrite the in-memory graph the same way and compare node counts.
TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", graph));
EXPECT_EQ(graph->num_nodes(), result_graph.num_nodes());
}
// New test added by this change: exercises the GraphDef-based entry point
// directly. NOTE(review): the closing brace falls outside this hunk, so the
// tail of the test is not visible here.
TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
// Construct a simple graph with 5 nodes.
Reset();
Graph* graph = g_.get();
Node* const_a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
Node* const_b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
graph->AddControlEdge(graph->source_node(), const_a);
graph->AddControlEdge(graph->source_node(), const_b);
Node* relu = test::graph::Relu(graph, const_a);
Node* identity = test::graph::Identity(graph, const_b);
Node* matmul = test::graph::Matmul(graph, relu, identity, false, false);
graph->AddControlEdge(matmul, graph->sink_node());
int num_bits = 8;
// Convert the graph to the graphdef string.
GraphDef input_graphdef;
graph->ToGraphDef(&input_graphdef);
// Run the quantize-training rewrite on the GraphDef.
GraphDef result_graphdef;
TF_ASSERT_OK(DoQuantizeTrainingOnGraphDef(
input_graphdef, num_bits, "QuantizeAndDequantizeV2", &result_graphdef));
// Ensure that quantizing the graph_def results in a graph with the same
// number of nodes.
// number of nodes as the graph_def.
GraphConstructorOptions opts;
Graph result_graph(OpRegistry::Global());
TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));

View File

@ -60,6 +60,7 @@ cc_library(
srcs = [
"add_default_attributes.cc",
"backports.cc",
"fake_quantize_training.cc",
"fold_batch_norms.cc",
"fold_constants_lib.cc",
"fold_old_batch_norms.cc",
@ -109,6 +110,7 @@ tf_cc_test(
srcs = [
"add_default_attributes_test.cc",
"backports_test.cc",
"fake_quantize_training_test.cc",
"fold_batch_norms_test.cc",
"fold_constants_test.cc",
"fold_old_batch_norms_test.cc",

View File

@ -0,0 +1,50 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/graph/quantize_training.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Graph transform that rewrites a GraphDef for quantization-aware training.
//
// The forward pass is augmented so that training observes the precision loss
// introduced by quantization, letting the model adapt and reach better
// accuracy when it is later quantized for inference. Quantization range
// information is collected in FakeQuantWithMinMaxVars ops.
//
// The |context| argument is currently unused; it is part of the transform
// registry's required signature.
//
// TODO(suharshs): Provide instructions on converting the resulting graph for
// inference.
// TODO(suharshs): Implement this using the GTT rather than calling the old
// prototype function.
Status FakeQuantizeTraining(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  // TODO(suharshs): Make num_bits a parameter.
  const int32 kNumBits = 8;
  // TODO(suharshs): Make quantization op a parameter?
  const string kQuantOpType = "FakeQuantWithMinMaxVars";
  // Delegate to the prototype rewrite in core/graph/quantize_training.h.
  return DoQuantizeTrainingOnGraphDef(input_graph_def, kNumBits, kQuantOpType,
                                      output_graph_def);
}

REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
// (Defined in fake_quantize_training.cc, where it is registered under the
// transform name "fake_quantize_training".)
Status FakeQuantizeTraining(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
// Empty fixture; exists so the tests use the standard TEST_F naming scheme.
class FakeQuantizeTrainingTest : public ::testing::Test {};
// The fake_quantize_training transform currently just delegates to the
// quantize_training rewrite from tensorflow/core/graph/quantize_training.h,
// so for now we only verify that the transform changed the graph.
// TODO(suharshs): Once we implement the fake_quantize_training transform
// using the GTT, write proper tests of the transform here.
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
  auto root = tensorflow::Scope::NewRootScope();
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  // Build a minimal graph: matmul of two scalar float constants.
  Tensor lhs_data(DT_FLOAT, TensorShape());
  test::FillIota<float>(&lhs_data, 1.0f);
  Output lhs = Const(root.WithOpName("a"), Input::Initializer(lhs_data));

  Tensor rhs_data(DT_FLOAT, TensorShape());
  test::FillIota<float>(&rhs_data, 1.0f);
  Output rhs = Const(root.WithOpName("b"), Input::Initializer(rhs_data));

  Output product = MatMul(root.WithOpName("matmul"), lhs, rhs);

  GraphDef original_def;
  TF_ASSERT_OK(root.ToGraphDef(&original_def));

  // Apply the transform with an empty context.
  TransformFuncContext context;
  GraphDef transformed_def;
  TF_ASSERT_OK(FakeQuantizeTraining(original_def, context, &transformed_def));

  // Test that the transformation resulted in a graph with more nodes.
  EXPECT_GT(transformed_def.node_size(), original_def.node_size());
}
} // namespace graph_transforms
} // namespace tensorflow