[Grappler] Add initial support for DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8, and DT_QUINT8 to ConstantFolding.

PiperOrigin-RevId: 234072895
This commit is contained in:
Andy Ly 2019-02-14 19:32:14 -08:00 committed by TensorFlower Gardener
parent 68b686af89
commit 43f47645a8
11 changed files with 609 additions and 317 deletions

View File

@ -1,7 +1,6 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library")
cc_library(
name = "op_types",
@ -45,6 +44,7 @@ tf_cc_test(
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
@ -71,7 +71,6 @@ cc_library(
deps = [
":graph_view",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",

View File

@ -481,6 +481,7 @@ bool IsNumericType(const DataType dtype) {
DT_QINT8,
DT_QUINT8,
DT_QINT16,
DT_QUINT16,
DT_QINT32,
// Bool.
DT_BOOL,

View File

@ -279,8 +279,8 @@ bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
// Returns true if `node` is any flavor of matrix multiplication:
// dense (MatMul), batched (BatchMatMul), sparse (SparseMatMul), or
// quantized (QuantizedMatMul / QuantizedMatMulV2, via IsQuantizedMatMul).
// NOTE: the rendered source contained both the pre- and post-change return
// statements (diff residue); this keeps only the post-change version.
bool IsMatMul(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
         IsQuantizedMatMul(node);
}
bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
@ -350,6 +350,10 @@ bool IsPrint(const NodeDef& node) {
bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
// Returns true for both versions of the quantized MatMul op.
bool IsQuantizedMatMul(const NodeDef& node) {
  const auto& op = node.op();
  return op == "QuantizedMatMul" || op == "QuantizedMatMulV2";
}
// Returns true if `node` is a V2 queue op (op name ends in "QueueV2",
// e.g. FIFOQueueV2, RandomShuffleQueueV2).
bool IsQueue(const NodeDef& node) {
  const auto& op = node.op();
  return str_util::EndsWith(op, "QueueV2");
}

View File

@ -106,6 +106,7 @@ bool IsPack(const NodeDef& node);
bool IsPad(const NodeDef& node);
bool IsPack(const NodeDef& node);
bool IsPartitionedCall(const NodeDef& node);
bool IsQuantizedMatMul(const NodeDef& node);
bool IsNeg(const NodeDef& node);
bool IsNoOp(const NodeDef& node);
bool IsNotEqual(const NodeDef& node);

View File

@ -3,7 +3,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# Platform specific build config
load(
@ -274,13 +273,29 @@ cc_library(
],
)
cc_library(
name = "arithmetic_optimizer_test_utils",
testonly = 1,
hdrs = [
"arithmetic_optimizer_test_utils.h",
],
visibility = ["//visibility:public"],
deps = [
":arithmetic_optimizer",
":constant_folding",
":model_pruner",
"//tensorflow/core:test",
"//tensorflow/core/grappler/utils:grappler_test",
],
)
tf_cuda_cc_test(
name = "arithmetic_optimizer_test",
size = "small",
srcs = ["arithmetic_optimizer_test.cc"],
deps = [
":arithmetic_optimizer",
":constant_folding",
":arithmetic_optimizer_test_utils",
":model_pruner",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
@ -295,7 +310,6 @@ tf_cuda_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"//tensorflow/core/grappler/utils:grappler_test",
],
)

View File

@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -92,211 +91,6 @@ void VerifyGraphsMatch(const GraphDef& original_graph,
}
} // namespace
class ArithmeticOptimizerTest : public GrapplerTest {
protected:
// Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
// longer have any output consumers.
void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
GraphDef* output) {
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
item->graph.Swap(output);
output->Clear();
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
}
// Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
GraphDef* output) {
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
item->graph.Swap(output);
output->Clear();
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
}
// Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
// Optionally run a constant folding pass before pruning.
void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
GraphDef* output, bool const_folding = false) {
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
item->graph.Swap(output);
output->Clear();
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
if (const_folding) {
item->graph.Swap(output);
output->Clear();
TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
.Optimize(nullptr, *item, output));
}
item->graph.Swap(output);
output->Clear();
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
}
// TODO(ezhulenev): Make private. After migration to stages each test
// should explicitly enable required optimization for tests isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
options.dedup_computations = false;
options.combine_add_to_addn = false;
options.convert_sqrt_div_to_rsqrt_mul = false;
options.convert_pow = false;
options.convert_log1p = false;
options.optimize_max_or_min_of_monotonic = false;
options.fold_conjugate_into_transpose = false;
options.fold_multiply_into_conv = false;
options.fold_transpose_into_matmul = false;
options.hoist_common_factor_out_of_aggregation = false;
options.hoist_cwise_unary_chains = false;
options.minimize_broadcasts = false;
options.remove_identity_transpose = false;
options.remove_involution = false;
options.remove_idempotent = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
options.remove_redundant_reshape = false;
options.remove_negation = false;
options.remove_logical_not = false;
options.reorder_cast_like_and_value_preserving = false;
options.replace_mul_with_square = false;
options.simplify_aggregation = false;
options.unary_ops_composition = false;
optimizer->options_ = options;
}
void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
optimizer->options_.combine_add_to_addn = false;
}
void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.combine_add_to_addn = true;
}
void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.fold_conjugate_into_transpose = true;
}
void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.fold_multiply_into_conv = true;
}
void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.fold_transpose_into_matmul = true;
}
void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.hoist_common_factor_out_of_aggregation = true;
}
void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.minimize_broadcasts = true;
}
void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_identity_transpose = true;
}
void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_involution = true;
}
void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_redundant_bitcast = true;
}
void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_redundant_cast = true;
}
void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_redundant_reshape = true;
}
void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_negation = true;
}
void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.reorder_cast_like_and_value_preserving = true;
}
void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.replace_mul_with_square = true;
}
void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.hoist_cwise_unary_chains = true;
}
void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
}
void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.convert_pow = true;
}
void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_idempotent = true;
}
void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_logical_not = true;
}
void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.simplify_aggregation = true;
}
void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.convert_log1p = true;
}
void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.optimize_max_or_min_of_monotonic = true;
}
void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.convert_expm1 = true;
}
void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.unary_ops_composition = true;
}
void EnableOnlyRemoveStackStridedSliceSameAxis(
ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_stack_strided_slice_same_axis = true;
}
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
// This trivial graph is so basic there's nothing to optimize.
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});

View File

@ -0,0 +1,236 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace grappler {
// Shared test fixture for ArithmeticOptimizer tests. Provides helpers to run
// the optimizer (optionally followed by constant folding and/or pruning) and
// to toggle individual rewrite stages via the optimizer's private `options_`
// (this fixture is presumably a friend of ArithmeticOptimizer — confirm).
class ArithmeticOptimizerTest : public GrapplerTest {
 protected:
  // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
  // longer have any output consumers.
  void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                        GraphDef* output) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }

  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
  void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                     GraphDef* output) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
  }

  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
  // Optionally run a constant folding pass before pruning.
  void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
                             GraphDef* output, bool const_folding = false) {
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
    if (const_folding) {
      item->graph.Swap(output);
      output->Clear();
      TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
                       .Optimize(nullptr, *item, output));
    }
    item->graph.Swap(output);
    output->Clear();
    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
  }

  // TODO(ezhulenev): Make private. After migration to stages each test
  // should explicitly enable required optimization for tests isolation
  // Turns off every rewrite stage so a single stage can be enabled in
  // isolation by the EnableOnly* helpers below.
  // NOTE(review): convert_expm1 and remove_stack_strided_slice_same_axis are
  // not reset here, so they keep their defaults after DisableAllStages —
  // confirm this is intended.
  void DisableAllStages(ArithmeticOptimizer* optimizer) {
    ArithmeticOptimizer::ArithmeticOptimizerOptions options;
    options.dedup_computations = false;
    options.combine_add_to_addn = false;
    options.convert_sqrt_div_to_rsqrt_mul = false;
    options.convert_pow = false;
    options.convert_log1p = false;
    options.optimize_max_or_min_of_monotonic = false;
    options.fold_conjugate_into_transpose = false;
    options.fold_multiply_into_conv = false;
    options.fold_transpose_into_matmul = false;
    options.hoist_common_factor_out_of_aggregation = false;
    options.hoist_cwise_unary_chains = false;
    options.minimize_broadcasts = false;
    options.remove_identity_transpose = false;
    options.remove_involution = false;
    options.remove_idempotent = false;
    options.remove_redundant_bitcast = false;
    options.remove_redundant_cast = false;
    options.remove_redundant_reshape = false;
    options.remove_negation = false;
    options.remove_logical_not = false;
    options.reorder_cast_like_and_value_preserving = false;
    options.replace_mul_with_square = false;
    options.simplify_aggregation = false;
    options.unary_ops_composition = false;
    optimizer->options_ = options;
  }

  // Turns off only the Add->AddN combining stage, leaving the rest untouched.
  void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
    optimizer->options_.combine_add_to_addn = false;
  }

  // Each EnableOnly* helper below disables every stage and then re-enables
  // exactly one, so a test exercises a single rewrite in isolation.
  void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.combine_add_to_addn = true;
  }

  void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_conjugate_into_transpose = true;
  }

  void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_multiply_into_conv = true;
  }

  void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.fold_transpose_into_matmul = true;
  }

  void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.hoist_common_factor_out_of_aggregation = true;
  }

  void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.minimize_broadcasts = true;
  }

  void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_identity_transpose = true;
  }

  void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_involution = true;
  }

  void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_bitcast = true;
  }

  void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_cast = true;
  }

  void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_redundant_reshape = true;
  }

  void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_negation = true;
  }

  void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.reorder_cast_like_and_value_preserving = true;
  }

  void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.replace_mul_with_square = true;
  }

  void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.hoist_cwise_unary_chains = true;
  }

  void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
  }

  void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_pow = true;
  }

  void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_idempotent = true;
  }

  void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_logical_not = true;
  }

  void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.simplify_aggregation = true;
  }

  void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_log1p = true;
  }

  void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.optimize_max_or_min_of_monotonic = true;
  }

  void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.convert_expm1 = true;
  }

  void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.unary_ops_composition = true;
  }

  void EnableOnlyRemoveStackStridedSliceSameAxis(
      ArithmeticOptimizer* optimizer) {
    DisableAllStages(optimizer);
    optimizer->options_.remove_stack_strided_slice_same_axis = true;
  }
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <cmath>
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
@ -37,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@ -185,6 +187,40 @@ bool IsDenormal(double x) {
return !std::isnormal(x);
}
// Lowest representable value of the given quantized type, as a float.
// Returns 0.0f for any non-quantized data type.
float QuantizedTypeMinAsFloat(DataType data_type) {
  if (data_type == DT_QINT8) return Eigen::NumTraits<qint8>::lowest();
  if (data_type == DT_QUINT8) return Eigen::NumTraits<quint8>::lowest();
  if (data_type == DT_QINT16) return Eigen::NumTraits<qint16>::lowest();
  if (data_type == DT_QUINT16) return Eigen::NumTraits<quint16>::lowest();
  if (data_type == DT_QINT32) return Eigen::NumTraits<qint32>::lowest();
  return 0.0f;
}
// Highest representable value of the given quantized type, as a float.
// Returns 0.0f for any non-quantized data type.
float QuantizedTypeMaxAsFloat(DataType data_type) {
  if (data_type == DT_QINT8) return Eigen::NumTraits<qint8>::highest();
  if (data_type == DT_QUINT8) return Eigen::NumTraits<quint8>::highest();
  if (data_type == DT_QINT16) return Eigen::NumTraits<qint16>::highest();
  if (data_type == DT_QUINT16) return Eigen::NumTraits<quint16>::highest();
  if (data_type == DT_QINT32) return Eigen::NumTraits<qint32>::highest();
  return 0.0f;
}
} // namespace
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
@ -945,6 +981,11 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
default:
return errors::InvalidArgument("Unsupported type: ", type);
@ -1085,6 +1126,8 @@ Status ConstantFolding::CreateNodeDef(const string& name,
t->set_dtype(tensor->dtype());
tensor->shape().AsProto(t->mutable_tensor_shape());
} else {
// DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
// DT_QUINT8
tensor->AsProtoTensorContent(t);
encoded_size = t->tensor_content().size();
}
@ -1533,6 +1576,11 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
IS_ONES_CASE(DT_INT16);
IS_ONES_CASE(DT_INT32);
IS_ONES_CASE(DT_INT64);
IS_ONES_CASE(DT_QINT32);
IS_ONES_CASE(DT_QINT16);
IS_ONES_CASE(DT_QUINT16);
IS_ONES_CASE(DT_QINT8);
IS_ONES_CASE(DT_QUINT8);
default:
VLOG(1) << "Unsupported type " << DataTypeString(dtype);
return false;
@ -1567,6 +1615,11 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
IS_ZEROS_CASE(DT_INT16);
IS_ZEROS_CASE(DT_INT32);
IS_ZEROS_CASE(DT_INT64);
IS_ZEROS_CASE(DT_QINT32);
IS_ZEROS_CASE(DT_QINT16);
IS_ZEROS_CASE(DT_QUINT16);
IS_ZEROS_CASE(DT_QINT8);
IS_ZEROS_CASE(DT_QUINT8);
default:
VLOG(1) << "Unsupported type " << DataTypeString(dtype);
return false;
@ -2576,6 +2629,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
*success = false;
const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsMatMul(*node);
const bool is_quantized_matmul = IsQuantizedMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
const bool is_sub = IsSub(*node);
const bool is_any_div = IsAnyDiv(*node);
@ -2670,6 +2724,10 @@ Status ConstantFolding::SimplifyArithmeticOperations(
if (!replace_op_status.ok()) {
return replace_op_status;
} else if (replace_succeed) {
if (is_quantized_matmul) {
TF_RETURN_IF_ERROR(
AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
}
*success = true;
return Status::OK();
}
@ -3237,6 +3295,65 @@ bool ConstantFolding::MergeConcat(const GraphProperties& properties,
return true;
}
// After a QuantizedMatMul node has been replaced (e.g. folded to an Identity
// over output 0), materialize its min/max range outputs (outputs 1 and 2) as
// scalar float Const nodes holding the lowest/highest representable value of
// the node's "dtype" attr, and rewire any consumers of node:1 / node:2 to the
// new Consts. Fails if the generated Const names already exist in the graph.
Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
    NodeDef* node, GraphDef* optimized_graph) {
  // Creates one scalar float Const for output `index` of `node` and rewires
  // consumers of that output to it. index == 1 -> min value, otherwise max.
  auto add_quantized_out = [this, node, optimized_graph](
                               const string& out_const_name, int index) {
    NodeDef* out_node = optimized_graph->add_node();
    // Scalar float tensor holding the type's min or max.
    Tensor value(DT_FLOAT, TensorShape({}));
    const bool is_min = index == 1;
    const DataType type_attr = node->attr().at("dtype").type();
    value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
                                    : QuantizedTypeMaxAsFloat(type_attr);
    TF_RETURN_IF_ERROR(
        CreateNodeDef(out_const_name, TensorValue(&value), out_node));
    node_map_->AddNode(out_const_name, out_node);
    out_node->set_device(node->device());
    // Copy all inputs from node.
    out_node->mutable_input()->CopyFrom(node->input());
    for (const string& input : out_node->input()) {
      node_map_->AddOutput(NodeName(input), out_const_name);
    }
    // Update output nodes consuming node:index to new const node.
    string old_input = absl::StrCat(node->name(), ":", index);
    // NOTE(review): old_node_count is accumulated across ALL outputs rather
    // than reset per output, so once any consumer still references `node`,
    // no later consumer is removed from the node map — confirm intended.
    int old_node_count = 0;
    auto outputs = node_map_->GetOutputs(node->name());
    for (const auto& output : outputs) {
      for (int i = 0; i < output->input_size(); ++i) {
        if (output->input(i) == old_input) {
          output->set_input(i, out_const_name);
          node_map_->AddOutput(out_const_name, output->name());
        } else if (NodeName(output->input(i)) == node->name()) {
          // Consumer still references some other output of `node`.
          ++old_node_count;
        }
      }
      if (old_node_count == 0) {
        node_map_->RemoveOutput(node->name(), output->name());
      }
    }
    return Status::OK();
  };
  const string min_out_const_name =
      OptimizedNodeName(*node, "-quantized_matmul_min_out");
  const string max_out_const_name =
      OptimizedNodeName(*node, "-quantized_matmul_max_out");
  // Only proceed if neither Const exists yet; a name collision means some
  // other pass (or a previous run) already created conflicting nodes.
  if (node_map_->GetNode(min_out_const_name) == nullptr &&
      node_map_->GetNode(max_out_const_name) == nullptr) {
    TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
    TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
  } else {
    return errors::Internal(absl::Substitute(
        "Can't create Const for QuantizedMatMul min_out/max_out of "
        "node '$0' because of node name conflict",
        node->name()));
  }
  return Status::OK();
}
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* optimized_graph) {

View File

@ -236,6 +236,9 @@ class ConstantFolding : public GraphOptimizer {
bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node);
Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
GraphDef* optimized_graph);
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;

View File

@ -40,7 +40,7 @@ namespace tensorflow {
namespace grappler {
namespace {
template <typename T>
bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
using RealType = typename Eigen::NumTraits<T>::Real;
if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
@ -50,6 +50,17 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
return true;
}
// Stores `value` into the scalar `tensor` of element type T, but only when it
// fits in T's real-value range; returns false (leaving the tensor untouched)
// on overflow/underflow.
template <typename T>
bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
  using RealType = typename Eigen::NumTraits<T>::Real;
  const int lowest = static_cast<int>(Eigen::NumTraits<RealType>::lowest());
  const int highest = static_cast<int>(Eigen::NumTraits<RealType>::highest());
  if (value < lowest || value > highest) {
    return false;
  }
  tensor->flat<T>()(0) = static_cast<T>(value);
  return true;
}
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
// TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
@ -410,35 +421,50 @@ void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
}
#define HANDLE_CASE(DTYPE) \
#define HANDLE_DOUBLE_CASE(DTYPE) \
case DTYPE: \
if (!SafeSetScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
static_cast<double>(value), tensor)) { \
return errors::InvalidArgument("Cannot store value ", value, \
" in tensor of type " #DTYPE); \
} \
break
#define HANDLE_INT_CASE(DTYPE) \
case DTYPE: \
if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value, \
tensor)) { \
return errors::InvalidArgument("Cannot store value ", value, \
" in tensor of type " #DTYPE); \
} \
break
Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
// TODO(rmlarsen): Support more general shapes.
// TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
if (tensor->NumElements() != 1) {
return errors::InvalidArgument(
"Expected scalar tensor, got num_elements = ", tensor->NumElements());
}
switch (dtype) {
HANDLE_CASE(DT_HALF);
HANDLE_CASE(DT_BFLOAT16);
HANDLE_CASE(DT_BOOL);
HANDLE_CASE(DT_FLOAT);
HANDLE_CASE(DT_DOUBLE);
HANDLE_CASE(DT_UINT8);
HANDLE_CASE(DT_INT8);
HANDLE_CASE(DT_UINT16);
HANDLE_CASE(DT_INT16);
HANDLE_CASE(DT_INT32);
HANDLE_CASE(DT_INT64);
HANDLE_CASE(DT_COMPLEX64);
HANDLE_CASE(DT_COMPLEX128);
HANDLE_DOUBLE_CASE(DT_HALF);
HANDLE_DOUBLE_CASE(DT_BFLOAT16);
HANDLE_DOUBLE_CASE(DT_BOOL);
HANDLE_DOUBLE_CASE(DT_FLOAT);
HANDLE_DOUBLE_CASE(DT_DOUBLE);
HANDLE_DOUBLE_CASE(DT_UINT8);
HANDLE_DOUBLE_CASE(DT_INT8);
HANDLE_DOUBLE_CASE(DT_UINT16);
HANDLE_DOUBLE_CASE(DT_INT16);
HANDLE_DOUBLE_CASE(DT_INT32);
HANDLE_DOUBLE_CASE(DT_INT64);
HANDLE_DOUBLE_CASE(DT_COMPLEX64);
HANDLE_DOUBLE_CASE(DT_COMPLEX128);
HANDLE_INT_CASE(DT_QINT8);
HANDLE_INT_CASE(DT_QUINT8);
HANDLE_INT_CASE(DT_QINT16);
HANDLE_INT_CASE(DT_QUINT16);
HANDLE_INT_CASE(DT_QINT32);
default:
return errors::InvalidArgument("Unsupported type ",
DataTypeString(dtype));

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <unistd.h>
#include <limits>
#include <memory>
#include "absl/strings/substitute.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@ -124,56 +126,56 @@ class UtilsTest : public ::testing::Test {
};
TEST_F(UtilsTest, NodeName) {
EXPECT_EQ("abc", NodeName("abc"));
EXPECT_EQ("abc", NodeName("^abc"));
EXPECT_EQ("abc", NodeName("abc:0"));
EXPECT_EQ("abc", NodeName("^abc:0"));
EXPECT_EQ(NodeName("abc"), "abc");
EXPECT_EQ(NodeName("^abc"), "abc");
EXPECT_EQ(NodeName("abc:0"), "abc");
EXPECT_EQ(NodeName("^abc:0"), "abc");
EXPECT_EQ("abc/def", NodeName("abc/def"));
EXPECT_EQ("abc/def", NodeName("^abc/def"));
EXPECT_EQ("abc/def", NodeName("abc/def:1"));
EXPECT_EQ("abc/def", NodeName("^abc/def:1"));
EXPECT_EQ(NodeName("abc/def"), "abc/def");
EXPECT_EQ(NodeName("^abc/def"), "abc/def");
EXPECT_EQ(NodeName("abc/def:1"), "abc/def");
EXPECT_EQ(NodeName("^abc/def:1"), "abc/def");
EXPECT_EQ("abc/def0", NodeName("abc/def0"));
EXPECT_EQ("abc/def0", NodeName("^abc/def0"));
EXPECT_EQ("abc/def0", NodeName("abc/def0:0"));
EXPECT_EQ("abc/def0", NodeName("^abc/def0:0"));
EXPECT_EQ(NodeName("abc/def0"), "abc/def0");
EXPECT_EQ(NodeName("^abc/def0"), "abc/def0");
EXPECT_EQ(NodeName("abc/def0:0"), "abc/def0");
EXPECT_EQ(NodeName("^abc/def0:0"), "abc/def0");
EXPECT_EQ("abc/def_0", NodeName("abc/def_0"));
EXPECT_EQ("abc/def_0", NodeName("^abc/def_0"));
EXPECT_EQ("abc/def_0", NodeName("abc/def_0:3"));
EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3"));
EXPECT_EQ(NodeName("abc/def_0"), "abc/def_0");
EXPECT_EQ(NodeName("^abc/def_0"), "abc/def_0");
EXPECT_EQ(NodeName("abc/def_0:3"), "abc/def_0");
EXPECT_EQ(NodeName("^abc/def_0:3"), "abc/def_0");
EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3214"));
EXPECT_EQ(NodeName("^abc/def_0:3214"), "abc/def_0");
}
TEST_F(UtilsTest, NodePosition) {
EXPECT_EQ(2, NodePosition("abc:2"));
EXPECT_EQ(123, NodePosition("abc:123"));
EXPECT_EQ(-1, NodePosition("^abc:123"));
EXPECT_EQ(-1, NodePosition("^abc"));
EXPECT_EQ(0, NodePosition(""));
EXPECT_EQ(NodePosition("abc:2"), 2);
EXPECT_EQ(NodePosition("abc:123"), 123);
EXPECT_EQ(NodePosition("^abc:123"), -1);
EXPECT_EQ(NodePosition("^abc"), -1);
EXPECT_EQ(NodePosition(""), 0);
}
TEST_F(UtilsTest, NodePositionIfSameNode) {
EXPECT_EQ(-2, NodePositionIfSameNode(":123", ""));
EXPECT_EQ(-2, NodePositionIfSameNode(":", ""));
EXPECT_EQ(-2, NodePositionIfSameNode("", ""));
EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc"));
EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc"));
EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc"));
EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz"));
EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz"));
EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc"));
EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz"));
EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz"));
EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz"));
EXPECT_EQ(NodePositionIfSameNode(":123", ""), -2);
EXPECT_EQ(NodePositionIfSameNode(":", ""), -2);
EXPECT_EQ(NodePositionIfSameNode("", ""), -2);
EXPECT_EQ(NodePositionIfSameNode("abc:123", "abc"), 123);
EXPECT_EQ(NodePositionIfSameNode("^abc", "abc"), -1);
EXPECT_EQ(NodePositionIfSameNode("^abc:123", "abc"), -1);
EXPECT_EQ(NodePositionIfSameNode("abc", "xyz"), -2);
EXPECT_EQ(NodePositionIfSameNode("abc", "abc/xyz"), -2);
EXPECT_EQ(NodePositionIfSameNode("abc/xyz", "abc"), -2);
EXPECT_EQ(NodePositionIfSameNode("abc:123", "xyz"), -2);
EXPECT_EQ(NodePositionIfSameNode("^abc", "xyz"), -2);
EXPECT_EQ(NodePositionIfSameNode("^abc:123", "xyz"), -2);
}
TEST_F(UtilsTest, AddNodeNamePrefix) {
EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED"));
EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED"));
EXPECT_EQ("OPTIMIZED/", AddPrefixToNodeName("", "OPTIMIZED"));
EXPECT_EQ(AddPrefixToNodeName("abc", "OPTIMIZED"), "OPTIMIZED/abc");
EXPECT_EQ(AddPrefixToNodeName("^abc", "OPTIMIZED"), "^OPTIMIZED/abc");
EXPECT_EQ(AddPrefixToNodeName("", "OPTIMIZED"), "OPTIMIZED/");
}
TEST_F(UtilsTest, ExecuteWithTimeout) {
@ -204,17 +206,17 @@ TEST_F(UtilsTest, ExecuteWithTimeout) {
// NumOutputs: resolves the number of outputs a node produces from its op
// registration (variadic for ConcatOffset, fixed for FusedBatchNorm/Dequeue).
// Fix: each assertion was present twice with swapped argument orders (a merge
// artifact); keep one copy in the (actual, expected) order used by this file.
TEST_F(UtilsTest, NumOutputs) {
  GraphDef graph;
  EXPECT_EQ(NumOutputs(CreateConcatOffsetNode(), &graph), 2);
  EXPECT_EQ(NumOutputs(CreateFusedBatchNormNode(), &graph), 5);
  EXPECT_EQ(NumOutputs(CreateDequeueNode(), &graph), 1);
}
// AsControlDependency: turns a node (or node name) into its control-input
// form "^name"; an input that is already a control dependency is unchanged.
// Fix: each assertion was present twice with swapped argument orders (a merge
// artifact); keep one copy in the (actual, expected) order used by this file.
TEST_F(UtilsTest, AsControlDependency) {
  NodeDef node;
  node.set_name("foo");
  EXPECT_EQ(AsControlDependency(node), "^foo");
  EXPECT_EQ(AsControlDependency(node.name()), "^foo");
  // Idempotent on an already-prefixed name.
  EXPECT_EQ(AsControlDependency("^foo"), "^foo");
}
TEST_F(UtilsTest, GetTailOfChain) {
@ -233,22 +235,23 @@ TEST_F(UtilsTest, GetTailOfChain) {
GraphDef graph;
TF_CHECK_OK(s.ToGraphDef(&graph));
ASSERT_EQ("c0", graph.node(0).name());
ASSERT_EQ("c1", graph.node(1).name());
ASSERT_EQ("neg0", graph.node(2).name());
ASSERT_EQ("neg1", graph.node(3).name());
ASSERT_EQ("neg2", graph.node(4).name());
ASSERT_EQ("id1", graph.node(5).name());
ASSERT_EQ("id2", graph.node(6).name());
ASSERT_EQ("noop", graph.node(7).name());
ASSERT_EQ(graph.node_size(), 8);
ASSERT_EQ(graph.node(0).name(), "c0");
ASSERT_EQ(graph.node(1).name(), "c1");
ASSERT_EQ(graph.node(2).name(), "neg0");
ASSERT_EQ(graph.node(3).name(), "neg1");
ASSERT_EQ(graph.node(4).name(), "neg2");
ASSERT_EQ(graph.node(5).name(), "id1");
ASSERT_EQ(graph.node(6).name(), "id2");
ASSERT_EQ(graph.node(7).name(), "noop");
NodeMap node_map(&graph);
auto is_neg = [&](const NodeDef& node) { return node.op() == "Neg"; };
// We walk backwards, starting as "id1", so tail should be "neg1".
NodeDef* tail = GetTailOfChain(graph.node(5), node_map,
/*follow_control_input=*/false, is_neg);
EXPECT_NE(tail, nullptr);
EXPECT_EQ("neg1", tail->name());
ASSERT_NE(tail, nullptr);
EXPECT_EQ(tail->name(), "neg1");
// We stop at branching nodes, so tail should be "neg2".
auto is_neg_and_non_branching = [&](const NodeDef& node) {
@ -257,22 +260,22 @@ TEST_F(UtilsTest, GetTailOfChain) {
tail =
GetTailOfChain(graph.node(5), node_map,
/*follow_control_input=*/false, is_neg_and_non_branching);
EXPECT_NE(tail, nullptr);
EXPECT_EQ("neg2", tail->name());
ASSERT_NE(tail, nullptr);
EXPECT_EQ(tail->name(), "neg2");
// We walk backwards, starting from "noop", also following control inputs,
// so tail should be "neg0".
tail = GetTailOfChain(graph.node(7), node_map,
/*follow_control_input=*/true, is_neg);
EXPECT_NE(tail, nullptr);
EXPECT_EQ("neg0", tail->name());
ASSERT_NE(tail, nullptr);
EXPECT_EQ(tail->name(), "neg0");
// We walk backwards, starting from "noop", not following control inputs,
// so tail should be "noop" itself.
tail = GetTailOfChain(graph.node(7), node_map,
/*follow_control_input=*/false, is_neg);
EXPECT_NE(tail, nullptr);
EXPECT_EQ("noop", tail->name());
ASSERT_NE(tail, nullptr);
EXPECT_EQ(tail->name(), "noop");
}
TEST_F(UtilsTest, DedupControlInputs) {
@ -280,40 +283,40 @@ TEST_F(UtilsTest, DedupControlInputs) {
foo.set_name("foo");
foo.add_input("bar");
DedupControlInputs(&foo);
EXPECT_EQ(1, foo.input_size());
EXPECT_EQ("bar", foo.input(0));
ASSERT_EQ(foo.input_size(), 1);
EXPECT_EQ(foo.input(0), "bar");
foo.set_input(0, "^bar");
DedupControlInputs(&foo);
EXPECT_EQ(1, foo.input_size());
EXPECT_EQ("^bar", foo.input(0));
ASSERT_EQ(foo.input_size(), 1);
EXPECT_EQ(foo.input(0), "^bar");
foo.set_input(0, "bar");
foo.add_input("bar");
DedupControlInputs(&foo);
EXPECT_EQ(2, foo.input_size());
EXPECT_EQ("bar", foo.input(0));
EXPECT_EQ("bar", foo.input(1));
ASSERT_EQ(foo.input_size(), 2);
EXPECT_EQ(foo.input(0), "bar");
EXPECT_EQ(foo.input(1), "bar");
foo.set_input(1, "^bar");
DedupControlInputs(&foo);
EXPECT_EQ(1, foo.input_size());
EXPECT_EQ("bar", foo.input(0));
ASSERT_EQ(foo.input_size(), 1);
EXPECT_EQ(foo.input(0), "bar");
foo.set_input(0, "^bar");
foo.add_input("^bar");
DedupControlInputs(&foo);
EXPECT_EQ(1, foo.input_size());
EXPECT_EQ("^bar", foo.input(0));
ASSERT_EQ(foo.input_size(), 1);
EXPECT_EQ(foo.input(0), "^bar");
foo.set_input(0, "bar");
foo.add_input("gnu");
foo.add_input("^bar");
foo.add_input("^gnu");
DedupControlInputs(&foo);
EXPECT_EQ(2, foo.input_size());
EXPECT_EQ("bar", foo.input(0));
EXPECT_EQ("gnu", foo.input(1));
ASSERT_EQ(foo.input_size(), 2);
EXPECT_EQ(foo.input(0), "bar");
EXPECT_EQ(foo.input(1), "gnu");
}
TEST_F(UtilsTest, NumNonControlOutputs) {
@ -347,14 +350,14 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
NodeMap node_map(&graph);
const NodeDef* add_node = node_map.GetNode("add");
ASSERT_TRUE(add_node != nullptr);
ASSERT_NE(add_node, nullptr);
// [a, b] are only non-control inputs
EXPECT_EQ(2, NumNonControlInputs(*add_node));
EXPECT_EQ(NumNonControlInputs(*add_node), 2);
// [sqrt, shape] are non control outputs
EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
EXPECT_EQ(NumNonControlOutputs(*add_node, node_map), 2);
// sqrt is the only data output
EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
EXPECT_EQ(NumNonControlDataOutputs(*add_node, node_map), 1);
}
TEST(CheckAttrExists, All) {
@ -465,10 +468,104 @@ TEST_F(UtilsTest, SetTensorValueBFloat16IntMin) {
}
// TensorIdToString: position -1 renders as a control input "^name", position 0
// as the bare name, and positive positions as "name:pos".
// Fix: each assertion was present twice with swapped argument orders (a merge
// artifact); keep one copy in the (actual, expected) order used by this file.
TEST_F(UtilsTest, TensorIdToString) {
  EXPECT_EQ(TensorIdToString({"foo", -1}), "^foo");
  EXPECT_EQ(TensorIdToString({"foo", 0}), "foo");
  EXPECT_EQ(TensorIdToString({"foo", 1}), "foo:1");
  EXPECT_EQ(TensorIdToString({"foo", 2}), "foo:2");
}
// Helper: invokes SetTensorValue on a scalar tensor of `type` and checks that
// it either succeeds and stores `val` (when `success` is true) or fails with
// exactly `error_msg`.
template <typename T>
void TestSetTensorValue(DataType type, int val, bool success,
                        absl::string_view error_msg) {
  Tensor tensor(type, TensorShape({}));
  const Status status = SetTensorValue(tensor.dtype(), val, &tensor);
  EXPECT_EQ(status.ok(), success);
  if (!status.ok()) {
    // Failure path: the error message must match verbatim.
    EXPECT_EQ(status.error_message(), error_msg);
    return;
  }
  // Success path: the stored scalar must equal `val` cast to the element type.
  test::ExpectTensorEqual<T>(Tensor(static_cast<T>(val)), tensor);
}
// Exercises SetTensorValue for the quantized dtypes (qint8/quint8/qint16/
// quint16/qint32): in-range values round-trip, out-of-range values fail with
// a "Cannot store value ..." error naming the dtype.
TEST(SetTensorValueTest, Quantized) {
  // Expected error text when the full-int extremes do not fit in the dtype.
  auto int_min_error = [](DataType type) {
    return absl::Substitute(
        "Cannot store value -2147483648 in tensor of type $0",
        DataType_Name(type));
  };
  auto int_max_error = [](DataType type) {
    return absl::Substitute(
        "Cannot store value 2147483647 in tensor of type $0",
        DataType_Name(type));
  };
  const int kMinInt = std::numeric_limits<int>::min();
  const int kMaxInt = std::numeric_limits<int>::max();
  // DT_QINT8: small values and the type's own limits succeed; the full-int
  // extremes are expected to overflow and fail.
  TestSetTensorValue<qint8>(DT_QINT8, -8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, 0, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, 8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, std::numeric_limits<qint8>::min(),
                            /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, std::numeric_limits<qint8>::max(),
                            /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint8>(DT_QINT8, kMinInt, /*success=*/false,
                            int_min_error(DT_QINT8));
  TestSetTensorValue<qint8>(DT_QINT8, kMaxInt, /*success=*/false,
                            int_max_error(DT_QINT8));
  // DT_QUINT8: unsigned, so a negative value must fail too.
  TestSetTensorValue<quint8>(
      DT_QUINT8, -8, /*success=*/false,
      /*error_msg=*/"Cannot store value -8 in tensor of type DT_QUINT8");
  TestSetTensorValue<quint8>(DT_QUINT8, 0, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint8>(DT_QUINT8, 8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint8>(DT_QUINT8, std::numeric_limits<quint8>::min(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint8>(DT_QUINT8, std::numeric_limits<quint8>::max(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint8>(DT_QUINT8, kMinInt, /*success=*/false,
                             int_min_error(DT_QUINT8));
  TestSetTensorValue<quint8>(DT_QUINT8, kMaxInt, /*success=*/false,
                             int_max_error(DT_QUINT8));
  // DT_QINT16: same pattern as DT_QINT8 with 16-bit limits.
  TestSetTensorValue<qint16>(DT_QINT16, -8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, 0, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, 8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, std::numeric_limits<qint16>::min(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, std::numeric_limits<qint16>::max(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint16>(DT_QINT16, kMinInt, /*success=*/false,
                             int_min_error(DT_QINT16));
  TestSetTensorValue<qint16>(DT_QINT16, kMaxInt, /*success=*/false,
                             int_max_error(DT_QINT16));
  // DT_QUINT16: unsigned, so a negative value must fail.
  TestSetTensorValue<quint16>(
      DT_QUINT16, -8, /*success=*/false,
      /*error_msg=*/"Cannot store value -8 in tensor of type DT_QUINT16");
  TestSetTensorValue<quint16>(DT_QUINT16, 0, /*success=*/true,
                              /*error_msg=*/"");
  TestSetTensorValue<quint16>(DT_QUINT16, 8, /*success=*/true,
                              /*error_msg=*/"");
  TestSetTensorValue<quint16>(DT_QUINT16, std::numeric_limits<quint16>::min(),
                              /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint16>(DT_QUINT16, std::numeric_limits<quint16>::max(),
                              /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<quint16>(DT_QUINT16, kMinInt, /*success=*/false,
                              int_min_error(DT_QUINT16));
  TestSetTensorValue<quint16>(DT_QUINT16, kMaxInt, /*success=*/false,
                              int_max_error(DT_QUINT16));
  // DT_QINT32: even the full-int extremes are storable (success expected),
  // so no error cases here.
  TestSetTensorValue<qint32>(DT_QINT32, -8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, 0, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, 8, /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, std::numeric_limits<qint32>::min(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, std::numeric_limits<qint32>::max(),
                             /*success=*/true, /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, kMinInt, /*success=*/true,
                             /*error_msg=*/"");
  TestSetTensorValue<qint32>(DT_QINT32, kMaxInt, /*success=*/true,
                             /*error_msg=*/"");
}
} // namespace