[Grappler] Conditionally fold quantization emulation ops.

The comment on
RewriterConfig.experimental_disable_folding_quantization_emulation explains the
rationale. Even more details can be found at b/174138564.

PiperOrigin-RevId: 352949451
Change-Id: I7f7bdccda83a4b8d1c0834099292d79ca6b6e6bb
This commit is contained in:
Jingyue Wu 2021-01-20 22:53:55 -08:00 committed by TensorFlower Gardener
parent 310b42a801
commit d9613ebb70
7 changed files with 73 additions and 8 deletions

View File

@@ -251,6 +251,12 @@ bool IsElu(const NodeDef& node) { return node.op() == "Elu"; }
// Returns true iff `node` is the gradient op of the Elu activation.
bool IsEluGrad(const NodeDef& node) {
  return node.op() == "EluGrad";
}
bool IsQuantizationEmulation(const NodeDef& node) {
const auto& op = node.op();
return absl::StartsWith(op, "QuantizeAndDequantize") ||
absl::StartsWith(op, "FakeQuantWithMinMax");
}
bool IsEnter(const NodeDef& node) {
const auto& op = node.op();
return op == "Enter" || op == "RefEnter";

View File

@@ -75,6 +75,7 @@ bool IsDivNoNan(const NodeDef& node);
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
bool IsElu(const NodeDef& node);
bool IsEluGrad(const NodeDef& node);
bool IsQuantizationEmulation(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
bool IsExit(const NodeDef& node);

View File

@@ -188,18 +188,22 @@ float QuantizedTypeMaxAsFloat(DataType data_type) {
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
DeviceBase* cpu_device,
bool disable_compressed_tensor_optimization)
bool disable_compressed_tensor_optimization,
bool fold_quantization_emulation)
: opt_level_(opt_level),
cpu_device_(cpu_device),
disable_compressed_tensor_optimization_(
disable_compressed_tensor_optimization) {
disable_compressed_tensor_optimization),
fold_quantization_emulation_(fold_quantization_emulation) {
resource_mgr_.reset(new ResourceMgr());
}
ConstantFolding::ConstantFolding(DeviceBase* cpu_device,
bool disable_compressed_tensor_optimization)
bool disable_compressed_tensor_optimization,
bool fold_quantization_ops)
: ConstantFolding(RewriterConfig::ON, cpu_device,
disable_compressed_tensor_optimization) {}
disable_compressed_tensor_optimization,
fold_quantization_ops) {}
// static
string ConstantFolding::AddControlDependency(const string& input_name,
@@ -1061,6 +1065,10 @@ bool ConstantFolding::MaybeFoldable(const NodeDef& node,
return false;
}
if (!fold_quantization_emulation_ && IsQuantizationEmulation(node)) {
return false;
}
const string& op = node.op();
if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
op.find("Reader") != string::npos) {

View File

@@ -46,9 +46,11 @@ class ConstantFolding : public GraphOptimizer {
NodeMap* node_map);
explicit ConstantFolding(DeviceBase* cpu_device,
bool disable_compressed_tensor_optimization = false);
bool disable_compressed_tensor_optimization = false,
bool fold_quantization_emulation = true);
ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device,
bool disable_compressed_tensor_optimization = false);
bool disable_compressed_tensor_optimization = false,
bool fold_quantization_emulation = true);
~ConstantFolding() override {}
@@ -340,6 +342,7 @@ class ConstantFolding : public GraphOptimizer {
bool graph_modified_;
bool graph_contains_assign_or_inplace_op_;
bool disable_compressed_tensor_optimization_;
bool fold_quantization_emulation_;
};
} // end namespace grappler

View File

@@ -4368,6 +4368,43 @@ TEST_F(ConstantFoldingTest, SimplifySelect_BroadcastTo) {
}
}
// Verifies that folding of quantization emulation ops
// (QuantizeAndDequantize*, FakeQuantWithMinMax*) is controlled by the
// `fold_quantization_emulation` constructor flag, and that the optimized
// graph computes the same values either way.
TEST_F(ConstantFoldingTest, QuantizationEmulation) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // All inputs are constants, so the QuantizeAndDequantizeV2 node is foldable
  // unless quantization-emulation folding is disabled.
  Output x = ops::Const(scope.WithOpName("x"), {0.0f, 1.0f, 2.0f, 3.0f}, {4});
  Output min_range = ops::Const(scope.WithOpName("min_range"), 0.0f, {});
  Output max_range = ops::Const(scope.WithOpName("max_range"), 3.0f, {});
  Output y = ops::QuantizeAndDequantizeV2(scope.WithOpName("y"), x, min_range,
                                          max_range);
  Output id = ops::Identity(scope.WithOpName("id"), y);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"id"};
  // Reference values computed from the unoptimized graph.
  std::vector<Tensor> expected_tensors = EvaluateNodes(item.graph, item.fetch);

  for (const bool fold_quantization_emulation : {false, true}) {
    ConstantFolding optimizer(/*cpu_device=*/nullptr,
                              /*disable_compressed_tensor_optimization=*/false,
                              fold_quantization_emulation);
    GraphDef output;
    Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
    // Fix: the optimizer status was previously assigned but never checked, so
    // a failed Optimize() call would have gone unnoticed.
    TF_EXPECT_OK(status);

    // The emulation op must be folded away exactly when folding is enabled.
    int num_quantization_emulation_ops = 0;
    for (const NodeDef& node : output.node()) {
      if (node.op() == "QuantizeAndDequantizeV2") {
        num_quantization_emulation_ops++;
      }
    }
    EXPECT_EQ(fold_quantization_emulation ? 0 : 1,
              num_quantization_emulation_ops);

    // Regardless of the flag, the optimized graph must produce the same
    // results as the original graph.
    std::vector<Tensor> actual_tensors = EvaluateNodes(output, item.fetch);
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectTensorEqual<float>(expected_tensors[i], actual_tensors[i]);
    }
  }
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@@ -189,7 +189,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("constfold",
new ConstantFolding(
cpu_device_,
cfg_.experimental_disable_compressed_tensor_optimization()));
cfg_.experimental_disable_compressed_tensor_optimization(),
!cfg_.experimental_disable_folding_quantization_emulation()));
MK_OPT("shape", new ShapeOptimizer());
MK_OPT("remap", new Remapper(cfg_.remapping()));
MK_OPT("layout", new GenericLayoutOptimizer(
@@ -253,7 +254,8 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.constant_folding() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<ConstantFolding>(
cfg_.constant_folding(), cpu_device_,
cfg_.experimental_disable_compressed_tensor_optimization()));
cfg_.experimental_disable_compressed_tensor_optimization(),
!cfg_.experimental_disable_folding_quantization_emulation()));
}
if (cfg_.shape_optimization() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<ShapeOptimizer>());

View File

@@ -121,6 +121,14 @@ message RewriterConfig {
// is experimental and may be removed in the future.
bool experimental_disable_compressed_tensor_optimization = 26;
// Disable folding quantization emulation ops such as FakeQuantWithMinMax* and
// QuantizeAndDequantize*. Some compilers (e.g. the TF-to-tflite converter)
// have to extract quantization configs (e.g. min/max range, number of bits,
// and per-channel) from the quantization emulation ops. Note that this flag
// is experimental and may be removed in the future. See b/174138564 for more
// details.
bool experimental_disable_folding_quantization_emulation = 27;
enum MemOptType {
// The default setting (SCHEDULING and SWAPPING HEURISTICS only)
DEFAULT_MEM_OPT = 0;