[Grappler] Add initial support for DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8, and DT_QUINT8 to ConstantFolding.
PiperOrigin-RevId: 234072895
commit 43f47645a8 (parent 68b686af89)
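The five quantized dtypes named in the commit message wrap plain integer storage (qint8 over int8, quint8 over uint8, and so on), so the type ranges that the new QuantizedTypeMinAsFloat/QuantizedTypeMaxAsFloat helpers in this diff return reduce to ordinary integer limits. A minimal standalone C++ sketch, assuming that qint8 maps to int8_t etc. (an assumption for illustration; TF defines these types in framework/numeric_types.h):

#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // DT_QINT8 / DT_QUINT8 / DT_QINT16 / DT_QUINT16 / DT_QINT32 value ranges.
  std::cout << "qint8:   " << +std::numeric_limits<std::int8_t>::lowest()
            << " .. " << +std::numeric_limits<std::int8_t>::max() << "\n"
            << "quint8:  " << +std::numeric_limits<std::uint8_t>::lowest()
            << " .. " << +std::numeric_limits<std::uint8_t>::max() << "\n"
            << "qint16:  " << std::numeric_limits<std::int16_t>::lowest()
            << " .. " << std::numeric_limits<std::int16_t>::max() << "\n"
            << "quint16: " << std::numeric_limits<std::uint16_t>::lowest()
            << " .. " << std::numeric_limits<std::uint16_t>::max() << "\n"
            << "qint32:  " << std::numeric_limits<std::int32_t>::lowest()
            << " .. " << std::numeric_limits<std::int32_t>::max() << "\n";
  return 0;
}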
@@ -1,7 +1,6 @@
licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library")

cc_library(
    name = "op_types",
@@ -45,6 +44,7 @@ tf_cc_test(
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@com_google_absl//absl/strings",
    ],
)

@@ -71,7 +71,6 @@ cc_library(
    deps = [
        ":graph_view",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",

@@ -481,6 +481,7 @@ bool IsNumericType(const DataType dtype) {
      DT_QINT8,
      DT_QUINT8,
      DT_QINT16,
      DT_QUINT16,
      DT_QINT32,
      // Bool.
      DT_BOOL,

@@ -279,8 +279,8 @@ bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }

bool IsMatMul(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" ||
         op == "SparseMatMul";
  return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
         IsQuantizedMatMul(node);
}

bool IsMax(const NodeDef& node) { return node.op() == "Max"; }

@@ -350,6 +350,10 @@ bool IsPrint(const NodeDef& node) {

bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }

bool IsQuantizedMatMul(const NodeDef& node) {
  return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
}

bool IsQueue(const NodeDef& node) {
  return str_util::EndsWith(node.op(), "QueueV2");
}

@@ -106,6 +106,7 @@ bool IsPack(const NodeDef& node);
bool IsPad(const NodeDef& node);
bool IsPack(const NodeDef& node);
bool IsPartitionedCall(const NodeDef& node);
bool IsQuantizedMatMul(const NodeDef& node);
bool IsNeg(const NodeDef& node);
bool IsNoOp(const NodeDef& node);
bool IsNotEqual(const NodeDef& node);

@@ -3,7 +3,6 @@ licenses(["notice"])  # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

# Platform specific build config
load(
@@ -274,13 +273,29 @@ cc_library(
    ],
)

cc_library(
    name = "arithmetic_optimizer_test_utils",
    testonly = 1,
    hdrs = [
        "arithmetic_optimizer_test_utils.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":arithmetic_optimizer",
        ":constant_folding",
        ":model_pruner",
        "//tensorflow/core:test",
        "//tensorflow/core/grappler/utils:grappler_test",
    ],
)

tf_cuda_cc_test(
    name = "arithmetic_optimizer_test",
    size = "small",
    srcs = ["arithmetic_optimizer_test.cc"],
    deps = [
        ":arithmetic_optimizer",
        ":constant_folding",
        ":arithmetic_optimizer_test_utils",
        ":model_pruner",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
@@ -295,7 +310,6 @@ tf_cuda_cc_test(
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
        "//tensorflow/core/grappler/utils:grappler_test",
    ],
)

@@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

@@ -92,211 +91,6 @@ void VerifyGraphsMatch(const GraphDef& original_graph,
}
}  // namespace

class ArithmeticOptimizerTest : public GrapplerTest {
 protected:
  // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
  // longer have any output consumers.
  void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                        GraphDef* output) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }

  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
  void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                     GraphDef* output) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
  }

  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
  // Optionally run a constant folding pass before pruning.
  void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                             GraphDef* output, bool const_folding = false) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));

    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));

    if (const_folding) {
      item->graph.Swap(output);
      output->Clear();
      TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
                       .Optimize(nullptr, *item, output));
    }

    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }

  // TODO(ezhulenev): Make private. After migration to stages each test
  // should explicitly enable required optimization for tests isolation
  void DisableAllStages(ArithmeticOptimizer* optimizer) {
    ArithmeticOptimizer::ArithmeticOptimizerOptions options;
    options.dedup_computations = false;
    options.combine_add_to_addn = false;
    options.convert_sqrt_div_to_rsqrt_mul = false;
    options.convert_pow = false;
    options.convert_log1p = false;
    options.optimize_max_or_min_of_monotonic = false;
    options.fold_conjugate_into_transpose = false;
    options.fold_multiply_into_conv = false;
    options.fold_transpose_into_matmul = false;
    options.hoist_common_factor_out_of_aggregation = false;
    options.hoist_cwise_unary_chains = false;
    options.minimize_broadcasts = false;
    options.remove_identity_transpose = false;
    options.remove_involution = false;
    options.remove_idempotent = false;
    options.remove_redundant_bitcast = false;
    options.remove_redundant_cast = false;
    options.remove_redundant_reshape = false;
    options.remove_negation = false;
    options.remove_logical_not = false;
    options.reorder_cast_like_and_value_preserving = false;
    options.replace_mul_with_square = false;
    options.simplify_aggregation = false;
    options.unary_ops_composition = false;
    optimizer->options_ = options;
  }

  void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
    optimizer->options_.combine_add_to_addn = false;
  }

  void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.combine_add_to_addn = true;
  }

  void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_conjugate_into_transpose = true;
  }

  void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_multiply_into_conv = true;
  }

  void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_transpose_into_matmul = true;
  }

  void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.hoist_common_factor_out_of_aggregation = true;
  }

  void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.minimize_broadcasts = true;
  }

  void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_identity_transpose = true;
  }

  void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_involution = true;
  }

  void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_bitcast = true;
  }

  void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_cast = true;
  }

  void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_reshape = true;
  }

  void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_negation = true;
  }

  void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.reorder_cast_like_and_value_preserving = true;
  }

  void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.replace_mul_with_square = true;
  }

  void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.hoist_cwise_unary_chains = true;
  }

  void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
  }

  void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_pow = true;
  }

  void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_idempotent = true;
  }

  void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_logical_not = true;
  }

  void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.simplify_aggregation = true;
  }

  void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_log1p = true;
  }

  void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.optimize_max_or_min_of_monotonic = true;
  }

  void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_expm1 = true;
  }

  void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.unary_ops_composition = true;
  }

  void EnableOnlyRemoveStackStridedSliceSameAxis(
      ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_stack_strided_slice_same_axis = true;
  }
};

TEST_F(ArithmeticOptimizerTest, NoOp) {
  // This trivial graph is so basic there's nothing to optimize.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});

@@ -0,0 +1,236 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_

#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {
namespace grappler {

class ArithmeticOptimizerTest : public GrapplerTest {
 protected:
  // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
  // longer have any output consumers.
  void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                        GraphDef* output) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }

  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
  void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                     GraphDef* output) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
  }

  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
  // Optionally run a constant folding pass before pruning.
  void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                             GraphDef* output, bool const_folding = false) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));

    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));

    if (const_folding) {
      item->graph.Swap(output);
      output->Clear();
      TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
                       .Optimize(nullptr, *item, output));
    }

    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }

  // TODO(ezhulenev): Make private. After migration to stages each test
  // should explicitly enable required optimization for tests isolation
  void DisableAllStages(ArithmeticOptimizer* optimizer) {
    ArithmeticOptimizer::ArithmeticOptimizerOptions options;
    options.dedup_computations = false;
    options.combine_add_to_addn = false;
    options.convert_sqrt_div_to_rsqrt_mul = false;
    options.convert_pow = false;
    options.convert_log1p = false;
    options.optimize_max_or_min_of_monotonic = false;
    options.fold_conjugate_into_transpose = false;
    options.fold_multiply_into_conv = false;
    options.fold_transpose_into_matmul = false;
    options.hoist_common_factor_out_of_aggregation = false;
    options.hoist_cwise_unary_chains = false;
    options.minimize_broadcasts = false;
    options.remove_identity_transpose = false;
    options.remove_involution = false;
    options.remove_idempotent = false;
    options.remove_redundant_bitcast = false;
    options.remove_redundant_cast = false;
    options.remove_redundant_reshape = false;
    options.remove_negation = false;
    options.remove_logical_not = false;
    options.reorder_cast_like_and_value_preserving = false;
    options.replace_mul_with_square = false;
    options.simplify_aggregation = false;
    options.unary_ops_composition = false;
    optimizer->options_ = options;
  }

  void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
    optimizer->options_.combine_add_to_addn = false;
  }

  void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.combine_add_to_addn = true;
  }

  void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_conjugate_into_transpose = true;
  }

  void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_multiply_into_conv = true;
  }

  void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_transpose_into_matmul = true;
  }

  void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.hoist_common_factor_out_of_aggregation = true;
  }

  void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.minimize_broadcasts = true;
  }

  void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_identity_transpose = true;
  }

  void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_involution = true;
  }

  void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_bitcast = true;
  }

  void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_cast = true;
  }

  void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_reshape = true;
  }

  void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_negation = true;
  }

  void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.reorder_cast_like_and_value_preserving = true;
  }

  void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.replace_mul_with_square = true;
  }

  void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.hoist_cwise_unary_chains = true;
  }

  void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
  }

  void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_pow = true;
  }

  void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_idempotent = true;
  }

  void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_logical_not = true;
  }

  void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.simplify_aggregation = true;
  }

  void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_log1p = true;
  }

  void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.optimize_max_or_min_of_monotonic = true;
  }

  void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_expm1 = true;
  }

  void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.unary_ops_composition = true;
  }

  void EnableOnlyRemoveStackStridedSliceSameAxis(
      ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_stack_strided_slice_same_axis = true;
  }
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_

@@ -20,6 +20,7 @@ limitations under the License.
#include <cmath>

#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -37,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

@@ -185,6 +187,40 @@ bool IsDenormal(double x) {
  return !std::isnormal(x);
}

float QuantizedTypeMinAsFloat(DataType data_type) {
  switch (data_type) {
    case DT_QINT8:
      return Eigen::NumTraits<qint8>::lowest();
    case DT_QUINT8:
      return Eigen::NumTraits<quint8>::lowest();
    case DT_QINT16:
      return Eigen::NumTraits<qint16>::lowest();
    case DT_QUINT16:
      return Eigen::NumTraits<quint16>::lowest();
    case DT_QINT32:
      return Eigen::NumTraits<qint32>::lowest();
    default:
      return 0.0f;
  }
}

float QuantizedTypeMaxAsFloat(DataType data_type) {
  switch (data_type) {
    case DT_QINT8:
      return Eigen::NumTraits<qint8>::highest();
    case DT_QUINT8:
      return Eigen::NumTraits<quint8>::highest();
    case DT_QINT16:
      return Eigen::NumTraits<qint16>::highest();
    case DT_QUINT16:
      return Eigen::NumTraits<quint16>::highest();
    case DT_QINT32:
      return Eigen::NumTraits<qint32>::highest();
    default:
      return 0.0f;
  }
}

}  // namespace

ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
@@ -945,6 +981,11 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
    SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
    SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
    SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
    SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
    SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
    SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
    SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
    SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
    SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
    default:
      return errors::InvalidArgument("Unsupported type: ", type);

@@ -1085,6 +1126,8 @@ Status ConstantFolding::CreateNodeDef(const string& name,
      t->set_dtype(tensor->dtype());
      tensor->shape().AsProto(t->mutable_tensor_shape());
    } else {
      // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
      // DT_QUINT8
      tensor->AsProtoTensorContent(t);
      encoded_size = t->tensor_content().size();
    }

@@ -1533,6 +1576,11 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
    IS_ONES_CASE(DT_INT16);
    IS_ONES_CASE(DT_INT32);
    IS_ONES_CASE(DT_INT64);
    IS_ONES_CASE(DT_QINT32);
    IS_ONES_CASE(DT_QINT16);
    IS_ONES_CASE(DT_QUINT16);
    IS_ONES_CASE(DT_QINT8);
    IS_ONES_CASE(DT_QUINT8);
    default:
      VLOG(1) << "Unsupported type " << DataTypeString(dtype);
      return false;

@@ -1567,6 +1615,11 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
    IS_ZEROS_CASE(DT_INT16);
    IS_ZEROS_CASE(DT_INT32);
    IS_ZEROS_CASE(DT_INT64);
    IS_ZEROS_CASE(DT_QINT32);
    IS_ZEROS_CASE(DT_QINT16);
    IS_ZEROS_CASE(DT_QUINT16);
    IS_ZEROS_CASE(DT_QINT8);
    IS_ZEROS_CASE(DT_QUINT8);
    default:
      VLOG(1) << "Unsupported type " << DataTypeString(dtype);
      return false;

@@ -2576,6 +2629,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
  *success = false;
  const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
  const bool is_matmul = IsMatMul(*node);
  const bool is_quantized_matmul = IsQuantizedMatMul(*node);
  const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
  const bool is_sub = IsSub(*node);
  const bool is_any_div = IsAnyDiv(*node);
@@ -2670,6 +2724,10 @@ Status ConstantFolding::SimplifyArithmeticOperations(
    if (!replace_op_status.ok()) {
      return replace_op_status;
    } else if (replace_succeed) {
      if (is_quantized_matmul) {
        TF_RETURN_IF_ERROR(
            AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
      }
      *success = true;
      return Status::OK();
    }
@@ -3237,6 +3295,65 @@ bool ConstantFolding::MergeConcat(const GraphProperties& properties,
  return true;
}

Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
    NodeDef* node, GraphDef* optimized_graph) {
  auto add_quantized_out = [this, node, optimized_graph](
                               const string& out_const_name, int index) {
    NodeDef* out_node = optimized_graph->add_node();
    Tensor value(DT_FLOAT, TensorShape({}));
    const bool is_min = index == 1;
    const DataType type_attr = node->attr().at("dtype").type();

    value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
                                    : QuantizedTypeMaxAsFloat(type_attr);
    TF_RETURN_IF_ERROR(
        CreateNodeDef(out_const_name, TensorValue(&value), out_node));
    node_map_->AddNode(out_const_name, out_node);
    out_node->set_device(node->device());

    // Copy all inputs from node.
    out_node->mutable_input()->CopyFrom(node->input());
    for (const string& input : out_node->input()) {
      node_map_->AddOutput(NodeName(input), out_const_name);
    }

    // Update output nodes consuming node:index to new const node.
    string old_input = absl::StrCat(node->name(), ":", index);
    int old_node_count = 0;
    auto outputs = node_map_->GetOutputs(node->name());
    for (const auto& output : outputs) {
      for (int i = 0; i < output->input_size(); ++i) {
        if (output->input(i) == old_input) {
          output->set_input(i, out_const_name);
          node_map_->AddOutput(out_const_name, output->name());
        } else if (NodeName(output->input(i)) == node->name()) {
          ++old_node_count;
        }
      }
      if (old_node_count == 0) {
        node_map_->RemoveOutput(node->name(), output->name());
      }
    }

    return Status::OK();
  };
  const string min_out_const_name =
      OptimizedNodeName(*node, "-quantized_matmul_min_out");
  const string max_out_const_name =
      OptimizedNodeName(*node, "-quantized_matmul_max_out");
  if (node_map_->GetNode(min_out_const_name) == nullptr &&
      node_map_->GetNode(max_out_const_name) == nullptr) {
    TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
    TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
  } else {
    return errors::Internal(absl::Substitute(
        "Can't create Const for QuantizedMatMul min_out/max_out of "
        "node '$0' because of node name conflict",
        node->name()));
  }
  return Status::OK();
}

Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
                                            const GrapplerItem& item,
                                            GraphDef* optimized_graph) {

@@ -236,6 +236,9 @@ class ConstantFolding : public GraphOptimizer {
  bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
                   GraphDef* optimized_graph, NodeDef* node);

  Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
                                               GraphDef* optimized_graph);

  // Points to an externally provided device or to owned_device_;
  RewriterConfig::Toggle opt_level_;
  DeviceBase* cpu_device_;

@@ -40,7 +40,7 @@ namespace tensorflow {
namespace grappler {
namespace {
template <typename T>
bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
@@ -50,6 +50,17 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
  return true;
}

template <typename T>
bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
      value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}

// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
// TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
@@ -410,35 +421,50 @@ void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
  EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
}

#define HANDLE_CASE(DTYPE)                                          \
  case DTYPE:                                                       \
    if (!SafeSetScalarTensorValue<EnumToDataType<DTYPE>::Type>(     \
            static_cast<double>(value), tensor)) {                  \
      return errors::InvalidArgument("Cannot store value ", value,  \
                                     " in tensor of type " #DTYPE); \
    }                                                               \
#define HANDLE_DOUBLE_CASE(DTYPE)                                     \
  case DTYPE:                                                         \
    if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
            static_cast<double>(value), tensor)) {                    \
      return errors::InvalidArgument("Cannot store value ", value,    \
                                     " in tensor of type " #DTYPE);   \
    }                                                                 \
    break

#define HANDLE_INT_CASE(DTYPE)                                               \
  case DTYPE:                                                                \
    if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
                                                                  tensor)) { \
      return errors::InvalidArgument("Cannot store value ", value,           \
                                     " in tensor of type " #DTYPE);          \
    }                                                                        \
    break

Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
  // TODO(rmlarsen): Support more general shapes.
  // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
  if (tensor->NumElements() != 1) {
    return errors::InvalidArgument(
        "Expected scalar tensor, got num_elements = ", tensor->NumElements());
  }
  switch (dtype) {
    HANDLE_CASE(DT_HALF);
    HANDLE_CASE(DT_BFLOAT16);
    HANDLE_CASE(DT_BOOL);
    HANDLE_CASE(DT_FLOAT);
    HANDLE_CASE(DT_DOUBLE);
    HANDLE_CASE(DT_UINT8);
    HANDLE_CASE(DT_INT8);
    HANDLE_CASE(DT_UINT16);
    HANDLE_CASE(DT_INT16);
    HANDLE_CASE(DT_INT32);
    HANDLE_CASE(DT_INT64);
    HANDLE_CASE(DT_COMPLEX64);
    HANDLE_CASE(DT_COMPLEX128);
    HANDLE_DOUBLE_CASE(DT_HALF);
    HANDLE_DOUBLE_CASE(DT_BFLOAT16);
    HANDLE_DOUBLE_CASE(DT_BOOL);
    HANDLE_DOUBLE_CASE(DT_FLOAT);
    HANDLE_DOUBLE_CASE(DT_DOUBLE);
    HANDLE_DOUBLE_CASE(DT_UINT8);
    HANDLE_DOUBLE_CASE(DT_INT8);
    HANDLE_DOUBLE_CASE(DT_UINT16);
    HANDLE_DOUBLE_CASE(DT_INT16);
    HANDLE_DOUBLE_CASE(DT_INT32);
    HANDLE_DOUBLE_CASE(DT_INT64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX128);
    HANDLE_INT_CASE(DT_QINT8);
    HANDLE_INT_CASE(DT_QUINT8);
    HANDLE_INT_CASE(DT_QINT16);
    HANDLE_INT_CASE(DT_QUINT16);
    HANDLE_INT_CASE(DT_QINT32);
    default:
      return errors::InvalidArgument("Unsupported type ",
                                     DataTypeString(dtype));

@@ -18,6 +18,8 @@ limitations under the License.
#include <unistd.h>
#include <limits>
#include <memory>

#include "absl/strings/substitute.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -124,56 +126,56 @@ class UtilsTest : public ::testing::Test {
};

TEST_F(UtilsTest, NodeName) {
  EXPECT_EQ("abc", NodeName("abc"));
  EXPECT_EQ("abc", NodeName("^abc"));
  EXPECT_EQ("abc", NodeName("abc:0"));
  EXPECT_EQ("abc", NodeName("^abc:0"));
  EXPECT_EQ(NodeName("abc"), "abc");
  EXPECT_EQ(NodeName("^abc"), "abc");
  EXPECT_EQ(NodeName("abc:0"), "abc");
  EXPECT_EQ(NodeName("^abc:0"), "abc");

  EXPECT_EQ("abc/def", NodeName("abc/def"));
  EXPECT_EQ("abc/def", NodeName("^abc/def"));
  EXPECT_EQ("abc/def", NodeName("abc/def:1"));
  EXPECT_EQ("abc/def", NodeName("^abc/def:1"));
  EXPECT_EQ(NodeName("abc/def"), "abc/def");
  EXPECT_EQ(NodeName("^abc/def"), "abc/def");
  EXPECT_EQ(NodeName("abc/def:1"), "abc/def");
  EXPECT_EQ(NodeName("^abc/def:1"), "abc/def");

  EXPECT_EQ("abc/def0", NodeName("abc/def0"));
  EXPECT_EQ("abc/def0", NodeName("^abc/def0"));
  EXPECT_EQ("abc/def0", NodeName("abc/def0:0"));
  EXPECT_EQ("abc/def0", NodeName("^abc/def0:0"));
  EXPECT_EQ(NodeName("abc/def0"), "abc/def0");
  EXPECT_EQ(NodeName("^abc/def0"), "abc/def0");
  EXPECT_EQ(NodeName("abc/def0:0"), "abc/def0");
  EXPECT_EQ(NodeName("^abc/def0:0"), "abc/def0");

  EXPECT_EQ("abc/def_0", NodeName("abc/def_0"));
  EXPECT_EQ("abc/def_0", NodeName("^abc/def_0"));
  EXPECT_EQ("abc/def_0", NodeName("abc/def_0:3"));
  EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3"));
  EXPECT_EQ(NodeName("abc/def_0"), "abc/def_0");
  EXPECT_EQ(NodeName("^abc/def_0"), "abc/def_0");
  EXPECT_EQ(NodeName("abc/def_0:3"), "abc/def_0");
  EXPECT_EQ(NodeName("^abc/def_0:3"), "abc/def_0");

  EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3214"));
  EXPECT_EQ(NodeName("^abc/def_0:3214"), "abc/def_0");
}

TEST_F(UtilsTest, NodePosition) {
  EXPECT_EQ(2, NodePosition("abc:2"));
  EXPECT_EQ(123, NodePosition("abc:123"));
  EXPECT_EQ(-1, NodePosition("^abc:123"));
  EXPECT_EQ(-1, NodePosition("^abc"));
  EXPECT_EQ(0, NodePosition(""));
  EXPECT_EQ(NodePosition("abc:2"), 2);
  EXPECT_EQ(NodePosition("abc:123"), 123);
  EXPECT_EQ(NodePosition("^abc:123"), -1);
  EXPECT_EQ(NodePosition("^abc"), -1);
  EXPECT_EQ(NodePosition(""), 0);
}

TEST_F(UtilsTest, NodePositionIfSameNode) {
  EXPECT_EQ(-2, NodePositionIfSameNode(":123", ""));
  EXPECT_EQ(-2, NodePositionIfSameNode(":", ""));
  EXPECT_EQ(-2, NodePositionIfSameNode("", ""));
  EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc"));
  EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc"));
  EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc"));
  EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz"));
  EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz"));
  EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc"));
  EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz"));
  EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz"));
  EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz"));
  EXPECT_EQ(NodePositionIfSameNode(":123", ""), -2);
  EXPECT_EQ(NodePositionIfSameNode(":", ""), -2);
  EXPECT_EQ(NodePositionIfSameNode("", ""), -2);
  EXPECT_EQ(NodePositionIfSameNode("abc:123", "abc"), 123);
  EXPECT_EQ(NodePositionIfSameNode("^abc", "abc"), -1);
  EXPECT_EQ(NodePositionIfSameNode("^abc:123", "abc"), -1);
  EXPECT_EQ(NodePositionIfSameNode("abc", "xyz"), -2);
  EXPECT_EQ(NodePositionIfSameNode("abc", "abc/xyz"), -2);
  EXPECT_EQ(NodePositionIfSameNode("abc/xyz", "abc"), -2);
  EXPECT_EQ(NodePositionIfSameNode("abc:123", "xyz"), -2);
  EXPECT_EQ(NodePositionIfSameNode("^abc", "xyz"), -2);
  EXPECT_EQ(NodePositionIfSameNode("^abc:123", "xyz"), -2);
}

TEST_F(UtilsTest, AddNodeNamePrefix) {
  EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED"));
  EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED"));
  EXPECT_EQ("OPTIMIZED/", AddPrefixToNodeName("", "OPTIMIZED"));
  EXPECT_EQ(AddPrefixToNodeName("abc", "OPTIMIZED"), "OPTIMIZED/abc");
  EXPECT_EQ(AddPrefixToNodeName("^abc", "OPTIMIZED"), "^OPTIMIZED/abc");
  EXPECT_EQ(AddPrefixToNodeName("", "OPTIMIZED"), "OPTIMIZED/");
}

TEST_F(UtilsTest, ExecuteWithTimeout) {

@@ -204,17 +206,17 @@ TEST_F(UtilsTest, ExecuteWithTimeout) {

TEST_F(UtilsTest, NumOutputs) {
  GraphDef graph;
  EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph));
  EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph));
  EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph));
  EXPECT_EQ(NumOutputs(CreateConcatOffsetNode(), &graph), 2);
  EXPECT_EQ(NumOutputs(CreateFusedBatchNormNode(), &graph), 5);
  EXPECT_EQ(NumOutputs(CreateDequeueNode(), &graph), 1);
}

TEST_F(UtilsTest, AsControlDependency) {
  NodeDef node;
  node.set_name("foo");
  EXPECT_EQ("^foo", AsControlDependency(node));
  EXPECT_EQ("^foo", AsControlDependency(node.name()));
  EXPECT_EQ("^foo", AsControlDependency("^foo"));
  EXPECT_EQ(AsControlDependency(node), "^foo");
  EXPECT_EQ(AsControlDependency(node.name()), "^foo");
  EXPECT_EQ(AsControlDependency("^foo"), "^foo");
}

TEST_F(UtilsTest, GetTailOfChain) {
@@ -233,22 +235,23 @@ TEST_F(UtilsTest, GetTailOfChain) {
  GraphDef graph;
  TF_CHECK_OK(s.ToGraphDef(&graph));

  ASSERT_EQ("c0", graph.node(0).name());
  ASSERT_EQ("c1", graph.node(1).name());
  ASSERT_EQ("neg0", graph.node(2).name());
  ASSERT_EQ("neg1", graph.node(3).name());
  ASSERT_EQ("neg2", graph.node(4).name());
  ASSERT_EQ("id1", graph.node(5).name());
  ASSERT_EQ("id2", graph.node(6).name());
  ASSERT_EQ("noop", graph.node(7).name());
  ASSERT_EQ(graph.node_size(), 8);
  ASSERT_EQ(graph.node(0).name(), "c0");
  ASSERT_EQ(graph.node(1).name(), "c1");
  ASSERT_EQ(graph.node(2).name(), "neg0");
  ASSERT_EQ(graph.node(3).name(), "neg1");
  ASSERT_EQ(graph.node(4).name(), "neg2");
  ASSERT_EQ(graph.node(5).name(), "id1");
  ASSERT_EQ(graph.node(6).name(), "id2");
  ASSERT_EQ(graph.node(7).name(), "noop");

  NodeMap node_map(&graph);
  auto is_neg = [&](const NodeDef& node) { return node.op() == "Neg"; };
  // We walk backwards, starting as "id1", so tail should be "neg1".
  NodeDef* tail = GetTailOfChain(graph.node(5), node_map,
                                 /*follow_control_input=*/false, is_neg);
  EXPECT_NE(tail, nullptr);
  EXPECT_EQ("neg1", tail->name());
  ASSERT_NE(tail, nullptr);
  EXPECT_EQ(tail->name(), "neg1");

  // We stop at branching nodes, so tail should be "neg2".
  auto is_neg_and_non_branching = [&](const NodeDef& node) {
@@ -257,22 +260,22 @@ TEST_F(UtilsTest, GetTailOfChain) {
  tail =
      GetTailOfChain(graph.node(5), node_map,
                     /*follow_control_input=*/false, is_neg_and_non_branching);
  EXPECT_NE(tail, nullptr);
  EXPECT_EQ("neg2", tail->name());
  ASSERT_NE(tail, nullptr);
  EXPECT_EQ(tail->name(), "neg2");

  // We walk backwards, starting from "noop", also following control inputs,
  // so tail should be "neg0".
  tail = GetTailOfChain(graph.node(7), node_map,
                        /*follow_control_input=*/true, is_neg);
  EXPECT_NE(tail, nullptr);
  EXPECT_EQ("neg0", tail->name());
  ASSERT_NE(tail, nullptr);
  EXPECT_EQ(tail->name(), "neg0");

  // We walk backwards, starting from "noop", not following control inputs,
  // so tail should be "noop" itself.
  tail = GetTailOfChain(graph.node(7), node_map,
                        /*follow_control_input=*/false, is_neg);
  EXPECT_NE(tail, nullptr);
  EXPECT_EQ("noop", tail->name());
  ASSERT_NE(tail, nullptr);
  EXPECT_EQ(tail->name(), "noop");
}

TEST_F(UtilsTest, DedupControlInputs) {
@@ -280,40 +283,40 @@ TEST_F(UtilsTest, DedupControlInputs) {
  foo.set_name("foo");
  foo.add_input("bar");
  DedupControlInputs(&foo);
  EXPECT_EQ(1, foo.input_size());
  EXPECT_EQ("bar", foo.input(0));
  ASSERT_EQ(foo.input_size(), 1);
  EXPECT_EQ(foo.input(0), "bar");

  foo.set_input(0, "^bar");
  DedupControlInputs(&foo);
  EXPECT_EQ(1, foo.input_size());
  EXPECT_EQ("^bar", foo.input(0));
  ASSERT_EQ(foo.input_size(), 1);
  EXPECT_EQ(foo.input(0), "^bar");

  foo.set_input(0, "bar");
  foo.add_input("bar");
  DedupControlInputs(&foo);
  EXPECT_EQ(2, foo.input_size());
  EXPECT_EQ("bar", foo.input(0));
  EXPECT_EQ("bar", foo.input(1));
  ASSERT_EQ(foo.input_size(), 2);
  EXPECT_EQ(foo.input(0), "bar");
  EXPECT_EQ(foo.input(1), "bar");

  foo.set_input(1, "^bar");
  DedupControlInputs(&foo);
  EXPECT_EQ(1, foo.input_size());
  EXPECT_EQ("bar", foo.input(0));
  ASSERT_EQ(foo.input_size(), 1);
  EXPECT_EQ(foo.input(0), "bar");

  foo.set_input(0, "^bar");
  foo.add_input("^bar");
  DedupControlInputs(&foo);
  EXPECT_EQ(1, foo.input_size());
  EXPECT_EQ("^bar", foo.input(0));
  ASSERT_EQ(foo.input_size(), 1);
  EXPECT_EQ(foo.input(0), "^bar");

  foo.set_input(0, "bar");
  foo.add_input("gnu");
  foo.add_input("^bar");
  foo.add_input("^gnu");
  DedupControlInputs(&foo);
  EXPECT_EQ(2, foo.input_size());
  EXPECT_EQ("bar", foo.input(0));
  EXPECT_EQ("gnu", foo.input(1));
  ASSERT_EQ(foo.input_size(), 2);
  EXPECT_EQ(foo.input(0), "bar");
  EXPECT_EQ(foo.input(1), "gnu");
}

TEST_F(UtilsTest, NumNonControlOutputs) {
@@ -347,14 +350,14 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
  NodeMap node_map(&graph);

  const NodeDef* add_node = node_map.GetNode("add");
  ASSERT_TRUE(add_node != nullptr);
  ASSERT_NE(add_node, nullptr);

  // [a, b] are only non-control inputs
  EXPECT_EQ(2, NumNonControlInputs(*add_node));
  EXPECT_EQ(NumNonControlInputs(*add_node), 2);
  // [sqrt, shape] are non control outputs
  EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
  EXPECT_EQ(NumNonControlOutputs(*add_node, node_map), 2);
  // sqrt is the only data output
  EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
  EXPECT_EQ(NumNonControlDataOutputs(*add_node, node_map), 1);
}

TEST(CheckAttrExists, All) {
@@ -465,10 +468,104 @@ TEST_F(UtilsTest, SetTensorValueBFloat16IntMin) {
}

TEST_F(UtilsTest, TensorIdToString) {
  EXPECT_EQ("^foo", TensorIdToString({"foo", -1}));
  EXPECT_EQ("foo", TensorIdToString({"foo", 0}));
  EXPECT_EQ("foo:1", TensorIdToString({"foo", 1}));
  EXPECT_EQ("foo:2", TensorIdToString({"foo", 2}));
  EXPECT_EQ(TensorIdToString({"foo", -1}), "^foo");
  EXPECT_EQ(TensorIdToString({"foo", 0}), "foo");
  EXPECT_EQ(TensorIdToString({"foo", 1}), "foo:1");
  EXPECT_EQ(TensorIdToString({"foo", 2}), "foo:2");
}

template <typename T>
void TestSetTensorValue(DataType type, int val, bool success,
                        absl::string_view error_msg) {
  Tensor t(type, TensorShape({}));
  Status s = SetTensorValue(t.dtype(), val, &t);
  EXPECT_EQ(s.ok(), success);
  if (s.ok()) {
    test::ExpectTensorEqual<T>(Tensor(static_cast<T>(val)), t);
  } else {
    EXPECT_EQ(s.error_message(), error_msg);
  }
}

TEST(SetTensorValueTest, Quantized) {
  auto int_min_error = [](DataType type) {
    return absl::Substitute(
        "Cannot store value -2147483648 in tensor of type $0",
        DataType_Name(type));
  };
  auto int_max_error = [](DataType type) {
    return absl::Substitute(
        "Cannot store value 2147483647 in tensor of type $0",
        DataType_Name(type));
  };
  const int kMinInt = std::numeric_limits<int>::min();
  const int kMaxInt = std::numeric_limits<int>::max();

  TestSetTensorValue<qint8>(DT_QINT8, -8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, 0, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, 8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, std::numeric_limits<qint8>::min(),
                            /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, std::numeric_limits<qint8>::max(),
                            /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, kMinInt, /*success=*/false,
                            int_min_error(DT_QINT8));
  TestSetTensorValue<qint8>(DT_QINT8, kMaxInt, /*success=*/false,
                            int_max_error(DT_QINT8));

  TestSetTensorValue<quint8>(
      DT_QUINT8, -8, /*success=*/false,
      /*error_msg=*/"Cannot store value -8 in tensor of type DT_QUINT8");
  TestSetTensorValue<quint8>(DT_QUINT8, 0, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint8>(DT_QUINT8, 8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint8>(DT_QUINT8, std::numeric_limits<quint8>::min(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint8>(DT_QUINT8, std::numeric_limits<quint8>::max(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint8>(DT_QUINT8, kMinInt, /*success=*/false,
                             int_min_error(DT_QUINT8));
  TestSetTensorValue<quint8>(DT_QUINT8, kMaxInt, /*success=*/false,
                             int_max_error(DT_QUINT8));

  TestSetTensorValue<qint16>(DT_QINT16, -8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, 0, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, 8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, std::numeric_limits<qint16>::min(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, std::numeric_limits<qint16>::max(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, kMinInt, /*success=*/false,
                             int_min_error(DT_QINT16));
  TestSetTensorValue<qint16>(DT_QINT16, kMaxInt, /*success=*/false,
                             int_max_error(DT_QINT16));

  TestSetTensorValue<quint16>(
      DT_QUINT16, -8, /*success=*/false,
      /*error_msg=*/"Cannot store value -8 in tensor of type DT_QUINT16");
  TestSetTensorValue<quint16>(DT_QUINT16, 0, /*success=*/true,
                              /*error_msg=*/"");
  TestSetTensorValue<quint16>(DT_QUINT16, 8, /*success=*/true,
                              /*error_msg=*/"");
  TestSetTensorValue<quint16>(DT_QUINT16, std::numeric_limits<quint16>::min(),
                              /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint16>(DT_QUINT16, std::numeric_limits<quint16>::max(),
                              /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint16>(DT_QUINT16, kMinInt, /*success=*/false,
                              int_min_error(DT_QUINT16));
  TestSetTensorValue<quint16>(DT_QUINT16, kMaxInt, /*success=*/false,
                              int_max_error(DT_QUINT16));

  TestSetTensorValue<qint32>(DT_QINT32, -8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, 0, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, 8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, std::numeric_limits<qint32>::min(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, std::numeric_limits<qint32>::max(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, kMinInt, /*success=*/true,
                             /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, kMaxInt, /*success=*/true,
                             /*error_msg=*/"");
}

}  // namespace
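The new SetTensorValue int path rejects values that do not fit the target quantized type, which is exactly what the DT_QUINT8/-8 and kMinInt/kMaxInt expectations in the final test hunk check, while qint32 accepts the full int range. A minimal standalone sketch of that range check, using std::int8_t and std::int32_t in place of TF's qint8 and qint32 (an assumption for illustration):

#include <cstdint>
#include <iostream>
#include <limits>

// Sketch of the SafeSetIntScalarTensorValue check: reject ints that do not
// fit in the (possibly narrower) storage type before casting.
template <typename T>
bool SafeSetIntScalar(int value, T* out) {
  if (value > static_cast<int>(std::numeric_limits<T>::max()) ||
      value < static_cast<int>(std::numeric_limits<T>::lowest())) {
    return false;  // The caller turns this into errors::InvalidArgument.
  }
  *out = static_cast<T>(value);
  return true;
}

int main() {
  std::int8_t q8 = 0;
  std::int32_t q32 = 0;
  std::cout << SafeSetIntScalar(8, &q8) << "\n";    // 1: fits in an 8-bit type
  std::cout << SafeSetIntScalar(std::numeric_limits<int>::min(), &q8)
            << "\n";                                // 0: rejected, out of range
  std::cout << SafeSetIntScalar(std::numeric_limits<int>::max(), &q32)
            << "\n";                                // 1: the full int range fits
  return 0;
}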