From 43f47645a8dcae81f5fa626848c61b4464765531 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Thu, 14 Feb 2019 19:32:14 -0800 Subject: [PATCH] [Grappler] Add initial support for DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8, and DT_QUINT8 to ConstantFolding. PiperOrigin-RevId: 234072895 --- tensorflow/core/grappler/BUILD | 5 +- .../core/grappler/costs/graph_properties.cc | 1 + tensorflow/core/grappler/op_types.cc | 8 +- tensorflow/core/grappler/op_types.h | 1 + tensorflow/core/grappler/optimizers/BUILD | 20 +- .../optimizers/arithmetic_optimizer_test.cc | 208 +------------- .../arithmetic_optimizer_test_utils.h | 236 ++++++++++++++++ .../grappler/optimizers/constant_folding.cc | 117 ++++++++ .../grappler/optimizers/constant_folding.h | 3 + tensorflow/core/grappler/utils.cc | 68 +++-- tensorflow/core/grappler/utils_test.cc | 259 ++++++++++++------ 11 files changed, 609 insertions(+), 317 deletions(-) create mode 100644 tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 9fe699360fe..77307708fab 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -1,7 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cuda_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library") cc_library( name = "op_types", @@ -45,6 +44,7 @@ tf_cc_test( "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -71,7 +71,6 @@ cc_library( deps = [ ":graph_view", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 8ec558be7d7..6907988d08f 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -481,6 +481,7 @@ bool IsNumericType(const DataType dtype) { DT_QINT8, DT_QUINT8, DT_QINT16, + DT_QUINT16, DT_QINT32, // Bool. 
DT_BOOL, diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 5d2fa4a45bb..59400dc479b 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -279,8 +279,8 @@ bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; } bool IsMatMul(const NodeDef& node) { const auto& op = node.op(); - return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" || - op == "SparseMatMul"; + return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" || + IsQuantizedMatMul(node); } bool IsMax(const NodeDef& node) { return node.op() == "Max"; } @@ -350,6 +350,10 @@ bool IsPrint(const NodeDef& node) { bool IsProd(const NodeDef& node) { return node.op() == "Prod"; } +bool IsQuantizedMatMul(const NodeDef& node) { + return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2"; +} + bool IsQueue(const NodeDef& node) { return str_util::EndsWith(node.op(), "QueueV2"); } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index bc1d8c15acc..bc1bb33772d 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -106,6 +106,7 @@ bool IsPack(const NodeDef& node); bool IsPad(const NodeDef& node); bool IsPack(const NodeDef& node); bool IsPartitionedCall(const NodeDef& node); +bool IsQuantizedMatMul(const NodeDef& node); bool IsNeg(const NodeDef& node); bool IsNoOp(const NodeDef& node); bool IsNotEqual(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 9bb63a5f4ed..af6fb137617 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -3,7 +3,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # Platform specific build config load( @@ -274,13 +273,29 @@ cc_library( ], ) +cc_library( + name = "arithmetic_optimizer_test_utils", + testonly = 1, + hdrs = [ + "arithmetic_optimizer_test_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":arithmetic_optimizer", + ":constant_folding", + ":model_pruner", + "//tensorflow/core:test", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) + tf_cuda_cc_test( name = "arithmetic_optimizer_test", size = "small", srcs = ["arithmetic_optimizer_test.cc"], deps = [ ":arithmetic_optimizer", - ":constant_folding", + ":arithmetic_optimizer_test_utils", ":model_pruner", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", @@ -295,7 +310,6 @@ tf_cuda_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", - "//tensorflow/core/grappler/utils:grappler_test", ], ) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 1220aefecf0..27783346229 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -20,10 +20,9 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" -#include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -92,211 +91,6 @@ void VerifyGraphsMatch(const GraphDef& original_graph, } } // namespace -class ArithmeticOptimizerTest : public GrapplerTest { - protected: - // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no - // longer have any output consumers. - void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, - GraphDef* output) { - TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); - item->graph.Swap(output); - output->Clear(); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); - } - - // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. - void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item, - GraphDef* output) { - TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); - item->graph.Swap(output); - output->Clear(); - TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); - } - - // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. - // Optionally run a constant folding pass before pruning. - void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, - GraphDef* output, bool const_folding = false) { - TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); - - item->graph.Swap(output); - output->Clear(); - TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); - - if (const_folding) { - item->graph.Swap(output); - output->Clear(); - TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr) - .Optimize(nullptr, *item, output)); - } - - item->graph.Swap(output); - output->Clear(); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); - } - - // TODO(ezhulenev): Make private. 
After migration to stages each test - // should explicitly enable required optimization for tests isolation - void DisableAllStages(ArithmeticOptimizer* optimizer) { - ArithmeticOptimizer::ArithmeticOptimizerOptions options; - options.dedup_computations = false; - options.combine_add_to_addn = false; - options.convert_sqrt_div_to_rsqrt_mul = false; - options.convert_pow = false; - options.convert_log1p = false; - options.optimize_max_or_min_of_monotonic = false; - options.fold_conjugate_into_transpose = false; - options.fold_multiply_into_conv = false; - options.fold_transpose_into_matmul = false; - options.hoist_common_factor_out_of_aggregation = false; - options.hoist_cwise_unary_chains = false; - options.minimize_broadcasts = false; - options.remove_identity_transpose = false; - options.remove_involution = false; - options.remove_idempotent = false; - options.remove_redundant_bitcast = false; - options.remove_redundant_cast = false; - options.remove_redundant_reshape = false; - options.remove_negation = false; - options.remove_logical_not = false; - options.reorder_cast_like_and_value_preserving = false; - options.replace_mul_with_square = false; - options.simplify_aggregation = false; - options.unary_ops_composition = false; - optimizer->options_ = options; - } - - void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) { - optimizer->options_.combine_add_to_addn = false; - } - - void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.combine_add_to_addn = true; - } - - void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.fold_conjugate_into_transpose = true; - } - - void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.fold_multiply_into_conv = true; - } - - void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.fold_transpose_into_matmul = true; - } - - void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.hoist_common_factor_out_of_aggregation = true; - } - - void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.minimize_broadcasts = true; - } - - void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_identity_transpose = true; - } - - void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_involution = true; - } - - void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_redundant_bitcast = true; - } - - void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_redundant_cast = true; - } - - void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_redundant_reshape = true; - } - - void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_negation = true; - } - - void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.reorder_cast_like_and_value_preserving = true; - } - - 
void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.replace_mul_with_square = true; - } - - void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.hoist_cwise_unary_chains = true; - } - - void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true; - } - - void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.convert_pow = true; - } - - void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_idempotent = true; - } - - void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_logical_not = true; - } - - void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.simplify_aggregation = true; - } - - void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.convert_log1p = true; - } - - void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.optimize_max_or_min_of_monotonic = true; - } - - void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.convert_expm1 = true; - } - - void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.unary_ops_composition = true; - } - - void EnableOnlyRemoveStackStridedSliceSameAxis( - ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.remove_stack_strided_slice_same_axis = true; - } -}; - TEST_F(ArithmeticOptimizerTest, NoOp) { // This trivial graph is so basic there's nothing to optimize. TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h new file mode 100644 index 00000000000..94d0adc6092 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h @@ -0,0 +1,236 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_ + +#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" +#include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace grappler { + +class ArithmeticOptimizerTest : public GrapplerTest { + protected: + // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no + // longer have any output consumers. + void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, + GraphDef* output) { + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); + } + + // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. + void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item, + GraphDef* output) { + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + } + + // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. + // Optionally run a constant folding pass before pruning. + void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, + GraphDef* output, bool const_folding = false) { + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + + if (const_folding) { + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr) + .Optimize(nullptr, *item, output)); + } + + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); + } + + // TODO(ezhulenev): Make private. 
After migration to stages each test + // should explicitly enable required optimization for tests isolation + void DisableAllStages(ArithmeticOptimizer* optimizer) { + ArithmeticOptimizer::ArithmeticOptimizerOptions options; + options.dedup_computations = false; + options.combine_add_to_addn = false; + options.convert_sqrt_div_to_rsqrt_mul = false; + options.convert_pow = false; + options.convert_log1p = false; + options.optimize_max_or_min_of_monotonic = false; + options.fold_conjugate_into_transpose = false; + options.fold_multiply_into_conv = false; + options.fold_transpose_into_matmul = false; + options.hoist_common_factor_out_of_aggregation = false; + options.hoist_cwise_unary_chains = false; + options.minimize_broadcasts = false; + options.remove_identity_transpose = false; + options.remove_involution = false; + options.remove_idempotent = false; + options.remove_redundant_bitcast = false; + options.remove_redundant_cast = false; + options.remove_redundant_reshape = false; + options.remove_negation = false; + options.remove_logical_not = false; + options.reorder_cast_like_and_value_preserving = false; + options.replace_mul_with_square = false; + options.simplify_aggregation = false; + options.unary_ops_composition = false; + optimizer->options_ = options; + } + + void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) { + optimizer->options_.combine_add_to_addn = false; + } + + void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.combine_add_to_addn = true; + } + + void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.fold_conjugate_into_transpose = true; + } + + void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.fold_multiply_into_conv = true; + } + + void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.fold_transpose_into_matmul = true; + } + + void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.hoist_common_factor_out_of_aggregation = true; + } + + void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.minimize_broadcasts = true; + } + + void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_identity_transpose = true; + } + + void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_involution = true; + } + + void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_bitcast = true; + } + + void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_cast = true; + } + + void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_reshape = true; + } + + void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_negation = true; + } + + void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.reorder_cast_like_and_value_preserving = true; + } + + 
void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.replace_mul_with_square = true; + } + + void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.hoist_cwise_unary_chains = true; + } + + void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true; + } + + void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_pow = true; + } + + void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_idempotent = true; + } + + void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_logical_not = true; + } + + void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.simplify_aggregation = true; + } + + void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_log1p = true; + } + + void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.optimize_max_or_min_of_monotonic = true; + } + + void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_expm1 = true; + } + + void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.unary_ops_composition = true; + } + + void EnableOnlyRemoveStackStridedSliceSameAxis( + ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_stack_strided_slice_same_axis = true; + } +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index e626943ee63..cf495eecf53 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" @@ -37,6 +38,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/symbolic_shapes.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -185,6 +187,40 @@ bool IsDenormal(double x) { return !std::isnormal(x); } +float QuantizedTypeMinAsFloat(DataType data_type) { + switch (data_type) { + case DT_QINT8: + return Eigen::NumTraits::lowest(); + case DT_QUINT8: + return Eigen::NumTraits::lowest(); + case DT_QINT16: + return Eigen::NumTraits::lowest(); + case DT_QUINT16: + return Eigen::NumTraits::lowest(); + case DT_QINT32: + return Eigen::NumTraits::lowest(); + default: + return 0.0f; + } +} + +float QuantizedTypeMaxAsFloat(DataType data_type) { + switch (data_type) { + case DT_QINT8: + return Eigen::NumTraits::highest(); + case DT_QUINT8: + return Eigen::NumTraits::highest(); + case DT_QINT16: + return Eigen::NumTraits::highest(); + case DT_QUINT16: + return Eigen::NumTraits::highest(); + case DT_QINT32: + return Eigen::NumTraits::highest(); + default: + return 0.0f; + } +} + } // namespace ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level, @@ -945,6 +981,11 @@ Status CreateConstantTensorAttrValue(DataType type, double value, SET_TENSOR_VAL_CASE(DT_UINT16, int32, int); SET_TENSOR_VAL_CASE(DT_INT8, int32, int); SET_TENSOR_VAL_CASE(DT_UINT8, int32, int); + SET_TENSOR_VAL_CASE(DT_QINT32, int32, int); + SET_TENSOR_VAL_CASE(DT_QINT16, int32, int); + SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int); + SET_TENSOR_VAL_CASE(DT_QINT8, int32, int); + SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int); SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool); default: return errors::InvalidArgument("Unsupported type: ", type); @@ -1085,6 +1126,8 @@ Status ConstantFolding::CreateNodeDef(const string& name, t->set_dtype(tensor->dtype()); tensor->shape().AsProto(t->mutable_tensor_shape()); } else { + // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8, + // DT_QUINT8 tensor->AsProtoTensorContent(t); encoded_size = t->tensor_content().size(); } @@ -1533,6 +1576,11 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const { IS_ONES_CASE(DT_INT16); IS_ONES_CASE(DT_INT32); IS_ONES_CASE(DT_INT64); + IS_ONES_CASE(DT_QINT32); + IS_ONES_CASE(DT_QINT16); + IS_ONES_CASE(DT_QUINT16); + IS_ONES_CASE(DT_QINT8); + IS_ONES_CASE(DT_QUINT8); default: VLOG(1) << "Unsupported type " << DataTypeString(dtype); return false; @@ -1567,6 +1615,11 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { IS_ZEROS_CASE(DT_INT16); IS_ZEROS_CASE(DT_INT32); IS_ZEROS_CASE(DT_INT64); + IS_ZEROS_CASE(DT_QINT32); + IS_ZEROS_CASE(DT_QINT16); + IS_ZEROS_CASE(DT_QUINT16); + IS_ZEROS_CASE(DT_QINT8); + IS_ZEROS_CASE(DT_QUINT8); default: VLOG(1) << "Unsupported type " << DataTypeString(dtype); return false; @@ -2576,6 +2629,7 @@ Status ConstantFolding::SimplifyArithmeticOperations( *success = false; const bool is_mul = IsMul(*node) || IsLogicalAnd(*node); const bool is_matmul = IsMatMul(*node); + const bool is_quantized_matmul = IsQuantizedMatMul(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node); const bool is_sub = IsSub(*node); const bool is_any_div = IsAnyDiv(*node); @@ -2670,6 +2724,10 @@ Status ConstantFolding::SimplifyArithmeticOperations( if (!replace_op_status.ok()) { return replace_op_status; } else if (replace_succeed) { + if (is_quantized_matmul) { + TF_RETURN_IF_ERROR( 
+            AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
+      }
       *success = true;
       return Status::OK();
     }
@@ -3237,6 +3295,65 @@ bool ConstantFolding::MergeConcat(const GraphProperties& properties,
   return true;
 }
 
+Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
+    NodeDef* node, GraphDef* optimized_graph) {
+  auto add_quantized_out = [this, node, optimized_graph](
+                               const string& out_const_name, int index) {
+    NodeDef* out_node = optimized_graph->add_node();
+    Tensor value(DT_FLOAT, TensorShape({}));
+    const bool is_min = index == 1;
+    const DataType type_attr = node->attr().at("dtype").type();
+
+    value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
+                                    : QuantizedTypeMaxAsFloat(type_attr);
+    TF_RETURN_IF_ERROR(
+        CreateNodeDef(out_const_name, TensorValue(&value), out_node));
+    node_map_->AddNode(out_const_name, out_node);
+    out_node->set_device(node->device());
+
+    // Copy all inputs from node.
+    out_node->mutable_input()->CopyFrom(node->input());
+    for (const string& input : out_node->input()) {
+      node_map_->AddOutput(NodeName(input), out_const_name);
+    }
+
+    // Update output nodes consuming node:index to new const node.
+    string old_input = absl::StrCat(node->name(), ":", index);
+    int old_node_count = 0;
+    auto outputs = node_map_->GetOutputs(node->name());
+    for (const auto& output : outputs) {
+      for (int i = 0; i < output->input_size(); ++i) {
+        if (output->input(i) == old_input) {
+          output->set_input(i, out_const_name);
+          node_map_->AddOutput(out_const_name, output->name());
+        } else if (NodeName(output->input(i)) == node->name()) {
+          ++old_node_count;
+        }
+      }
+      if (old_node_count == 0) {
+        node_map_->RemoveOutput(node->name(), output->name());
+      }
+    }
+
+    return Status::OK();
+  };
+  const string min_out_const_name =
+      OptimizedNodeName(*node, "-quantized_matmul_min_out");
+  const string max_out_const_name =
+      OptimizedNodeName(*node, "-quantized_matmul_max_out");
+  if (node_map_->GetNode(min_out_const_name) == nullptr &&
+      node_map_->GetNode(max_out_const_name) == nullptr) {
+    TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
+    TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
+  } else {
+    return errors::Internal(absl::Substitute(
+        "Can't create Const for QuantizedMatMul min_out/max_out of "
+        "node '$0' because of node name conflict",
+        node->name()));
+  }
+  return Status::OK();
+}
+
 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
                                             const GrapplerItem& item,
                                             GraphDef* optimized_graph) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 7cf01b4b62c..418176c8932 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -236,6 +236,9 @@ class ConstantFolding : public GraphOptimizer {
   bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
                    GraphDef* optimized_graph, NodeDef* node);
 
+  Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
+                                               GraphDef* optimized_graph);
+
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 375c3e56c80..7d4dfb05207 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -40,7 +40,7 @@ namespace tensorflow {
 namespace grappler {
 namespace {
 template <typename T>
-bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
+bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
   using RealType = typename Eigen::NumTraits<T>::Real;
   if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
       value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
@@ -50,6 +50,17 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
   return true;
 }
 
+template <typename T>
+bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
+  using RealType = typename Eigen::NumTraits<T>::Real;
+  if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
+      value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
+    return false;
+  }
+  tensor->flat<T>()(0) = static_cast<T>(value);
+  return true;
+}
+
 // Is 'node' an operator that consumes only the shape of its input, not the
 // data itself?
 // TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
@@ -410,35 +421,50 @@ void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
   EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
 }
 
-#define HANDLE_CASE(DTYPE)                                          \
-  case DTYPE:                                                       \
-    if (!SafeSetScalarTensorValue<EnumToDataType<DTYPE>::Type>(     \
-            static_cast<double>(value), tensor)) {                 \
-      return errors::InvalidArgument("Cannot store value ", value, \
-                                     " in tensor of type " #DTYPE); \
-    }                                                               \
+#define HANDLE_DOUBLE_CASE(DTYPE)                                    \
+  case DTYPE:                                                        \
+    if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
+            static_cast<double>(value), tensor)) {                   \
+      return errors::InvalidArgument("Cannot store value ", value,   \
+                                     " in tensor of type " #DTYPE);  \
+    }                                                                 \
+    break
+
+#define HANDLE_INT_CASE(DTYPE)                                           \
+  case DTYPE:                                                            \
+    if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value, \
+                                                                  tensor)) { \
+      return errors::InvalidArgument("Cannot store value ", value,       \
+                                     " in tensor of type " #DTYPE);      \
+    }                                                                     \
     break
 
 Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
   // TODO(rmlarsen): Support more general shapes.
+  // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
   if (tensor->NumElements() != 1) {
     return errors::InvalidArgument(
         "Expected scalar tensor, got num_elements = ", tensor->NumElements());
   }
   switch (dtype) {
-    HANDLE_CASE(DT_HALF);
-    HANDLE_CASE(DT_BFLOAT16);
-    HANDLE_CASE(DT_BOOL);
-    HANDLE_CASE(DT_FLOAT);
-    HANDLE_CASE(DT_DOUBLE);
-    HANDLE_CASE(DT_UINT8);
-    HANDLE_CASE(DT_INT8);
-    HANDLE_CASE(DT_UINT16);
-    HANDLE_CASE(DT_INT16);
-    HANDLE_CASE(DT_INT32);
-    HANDLE_CASE(DT_INT64);
-    HANDLE_CASE(DT_COMPLEX64);
-    HANDLE_CASE(DT_COMPLEX128);
+    HANDLE_DOUBLE_CASE(DT_HALF);
+    HANDLE_DOUBLE_CASE(DT_BFLOAT16);
+    HANDLE_DOUBLE_CASE(DT_BOOL);
+    HANDLE_DOUBLE_CASE(DT_FLOAT);
+    HANDLE_DOUBLE_CASE(DT_DOUBLE);
+    HANDLE_DOUBLE_CASE(DT_UINT8);
+    HANDLE_DOUBLE_CASE(DT_INT8);
+    HANDLE_DOUBLE_CASE(DT_UINT16);
+    HANDLE_DOUBLE_CASE(DT_INT16);
+    HANDLE_DOUBLE_CASE(DT_INT32);
+    HANDLE_DOUBLE_CASE(DT_INT64);
+    HANDLE_DOUBLE_CASE(DT_COMPLEX64);
+    HANDLE_DOUBLE_CASE(DT_COMPLEX128);
+    HANDLE_INT_CASE(DT_QINT8);
+    HANDLE_INT_CASE(DT_QUINT8);
+    HANDLE_INT_CASE(DT_QINT16);
+    HANDLE_INT_CASE(DT_QUINT16);
+    HANDLE_INT_CASE(DT_QINT32);
     default:
       return errors::InvalidArgument("Unsupported type ",
                                      DataTypeString(dtype));
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index f5ae39867ac..e30b1c5b730 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include #include #include + +#include "absl/strings/substitute.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -124,56 +126,56 @@ class UtilsTest : public ::testing::Test { }; TEST_F(UtilsTest, NodeName) { - EXPECT_EQ("abc", NodeName("abc")); - EXPECT_EQ("abc", NodeName("^abc")); - EXPECT_EQ("abc", NodeName("abc:0")); - EXPECT_EQ("abc", NodeName("^abc:0")); + EXPECT_EQ(NodeName("abc"), "abc"); + EXPECT_EQ(NodeName("^abc"), "abc"); + EXPECT_EQ(NodeName("abc:0"), "abc"); + EXPECT_EQ(NodeName("^abc:0"), "abc"); - EXPECT_EQ("abc/def", NodeName("abc/def")); - EXPECT_EQ("abc/def", NodeName("^abc/def")); - EXPECT_EQ("abc/def", NodeName("abc/def:1")); - EXPECT_EQ("abc/def", NodeName("^abc/def:1")); + EXPECT_EQ(NodeName("abc/def"), "abc/def"); + EXPECT_EQ(NodeName("^abc/def"), "abc/def"); + EXPECT_EQ(NodeName("abc/def:1"), "abc/def"); + EXPECT_EQ(NodeName("^abc/def:1"), "abc/def"); - EXPECT_EQ("abc/def0", NodeName("abc/def0")); - EXPECT_EQ("abc/def0", NodeName("^abc/def0")); - EXPECT_EQ("abc/def0", NodeName("abc/def0:0")); - EXPECT_EQ("abc/def0", NodeName("^abc/def0:0")); + EXPECT_EQ(NodeName("abc/def0"), "abc/def0"); + EXPECT_EQ(NodeName("^abc/def0"), "abc/def0"); + EXPECT_EQ(NodeName("abc/def0:0"), "abc/def0"); + EXPECT_EQ(NodeName("^abc/def0:0"), "abc/def0"); - EXPECT_EQ("abc/def_0", NodeName("abc/def_0")); - EXPECT_EQ("abc/def_0", NodeName("^abc/def_0")); - EXPECT_EQ("abc/def_0", NodeName("abc/def_0:3")); - EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3")); + EXPECT_EQ(NodeName("abc/def_0"), "abc/def_0"); + EXPECT_EQ(NodeName("^abc/def_0"), "abc/def_0"); + EXPECT_EQ(NodeName("abc/def_0:3"), "abc/def_0"); + EXPECT_EQ(NodeName("^abc/def_0:3"), "abc/def_0"); - EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3214")); + EXPECT_EQ(NodeName("^abc/def_0:3214"), "abc/def_0"); } TEST_F(UtilsTest, NodePosition) { - EXPECT_EQ(2, NodePosition("abc:2")); - EXPECT_EQ(123, NodePosition("abc:123")); - EXPECT_EQ(-1, NodePosition("^abc:123")); - EXPECT_EQ(-1, NodePosition("^abc")); - EXPECT_EQ(0, NodePosition("")); + EXPECT_EQ(NodePosition("abc:2"), 2); + EXPECT_EQ(NodePosition("abc:123"), 123); + EXPECT_EQ(NodePosition("^abc:123"), -1); + EXPECT_EQ(NodePosition("^abc"), -1); + EXPECT_EQ(NodePosition(""), 0); } TEST_F(UtilsTest, NodePositionIfSameNode) { - EXPECT_EQ(-2, NodePositionIfSameNode(":123", "")); - EXPECT_EQ(-2, NodePositionIfSameNode(":", "")); - EXPECT_EQ(-2, NodePositionIfSameNode("", "")); - EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc")); - EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc")); - EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc")); - EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz")); - EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz")); - EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc")); - EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz")); - EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz")); - EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz")); + EXPECT_EQ(NodePositionIfSameNode(":123", ""), -2); + EXPECT_EQ(NodePositionIfSameNode(":", ""), -2); + EXPECT_EQ(NodePositionIfSameNode("", ""), -2); + EXPECT_EQ(NodePositionIfSameNode("abc:123", "abc"), 123); + EXPECT_EQ(NodePositionIfSameNode("^abc", "abc"), -1); + EXPECT_EQ(NodePositionIfSameNode("^abc:123", "abc"), -1); + EXPECT_EQ(NodePositionIfSameNode("abc", "xyz"), -2); + EXPECT_EQ(NodePositionIfSameNode("abc", "abc/xyz"), -2); + EXPECT_EQ(NodePositionIfSameNode("abc/xyz", "abc"), -2); + 
EXPECT_EQ(NodePositionIfSameNode("abc:123", "xyz"), -2); + EXPECT_EQ(NodePositionIfSameNode("^abc", "xyz"), -2); + EXPECT_EQ(NodePositionIfSameNode("^abc:123", "xyz"), -2); } TEST_F(UtilsTest, AddNodeNamePrefix) { - EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED")); - EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED")); - EXPECT_EQ("OPTIMIZED/", AddPrefixToNodeName("", "OPTIMIZED")); + EXPECT_EQ(AddPrefixToNodeName("abc", "OPTIMIZED"), "OPTIMIZED/abc"); + EXPECT_EQ(AddPrefixToNodeName("^abc", "OPTIMIZED"), "^OPTIMIZED/abc"); + EXPECT_EQ(AddPrefixToNodeName("", "OPTIMIZED"), "OPTIMIZED/"); } TEST_F(UtilsTest, ExecuteWithTimeout) { @@ -204,17 +206,17 @@ TEST_F(UtilsTest, ExecuteWithTimeout) { TEST_F(UtilsTest, NumOutputs) { GraphDef graph; - EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph)); - EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph)); - EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph)); + EXPECT_EQ(NumOutputs(CreateConcatOffsetNode(), &graph), 2); + EXPECT_EQ(NumOutputs(CreateFusedBatchNormNode(), &graph), 5); + EXPECT_EQ(NumOutputs(CreateDequeueNode(), &graph), 1); } TEST_F(UtilsTest, AsControlDependency) { NodeDef node; node.set_name("foo"); - EXPECT_EQ("^foo", AsControlDependency(node)); - EXPECT_EQ("^foo", AsControlDependency(node.name())); - EXPECT_EQ("^foo", AsControlDependency("^foo")); + EXPECT_EQ(AsControlDependency(node), "^foo"); + EXPECT_EQ(AsControlDependency(node.name()), "^foo"); + EXPECT_EQ(AsControlDependency("^foo"), "^foo"); } TEST_F(UtilsTest, GetTailOfChain) { @@ -233,22 +235,23 @@ TEST_F(UtilsTest, GetTailOfChain) { GraphDef graph; TF_CHECK_OK(s.ToGraphDef(&graph)); - ASSERT_EQ("c0", graph.node(0).name()); - ASSERT_EQ("c1", graph.node(1).name()); - ASSERT_EQ("neg0", graph.node(2).name()); - ASSERT_EQ("neg1", graph.node(3).name()); - ASSERT_EQ("neg2", graph.node(4).name()); - ASSERT_EQ("id1", graph.node(5).name()); - ASSERT_EQ("id2", graph.node(6).name()); - ASSERT_EQ("noop", graph.node(7).name()); + ASSERT_EQ(graph.node_size(), 8); + ASSERT_EQ(graph.node(0).name(), "c0"); + ASSERT_EQ(graph.node(1).name(), "c1"); + ASSERT_EQ(graph.node(2).name(), "neg0"); + ASSERT_EQ(graph.node(3).name(), "neg1"); + ASSERT_EQ(graph.node(4).name(), "neg2"); + ASSERT_EQ(graph.node(5).name(), "id1"); + ASSERT_EQ(graph.node(6).name(), "id2"); + ASSERT_EQ(graph.node(7).name(), "noop"); NodeMap node_map(&graph); auto is_neg = [&](const NodeDef& node) { return node.op() == "Neg"; }; // We walk backwards, starting as "id1", so tail should be "neg1". NodeDef* tail = GetTailOfChain(graph.node(5), node_map, /*follow_control_input=*/false, is_neg); - EXPECT_NE(tail, nullptr); - EXPECT_EQ("neg1", tail->name()); + ASSERT_NE(tail, nullptr); + EXPECT_EQ(tail->name(), "neg1"); // We stop at branching nodes, so tail should be "neg2". auto is_neg_and_non_branching = [&](const NodeDef& node) { @@ -257,22 +260,22 @@ TEST_F(UtilsTest, GetTailOfChain) { tail = GetTailOfChain(graph.node(5), node_map, /*follow_control_input=*/false, is_neg_and_non_branching); - EXPECT_NE(tail, nullptr); - EXPECT_EQ("neg2", tail->name()); + ASSERT_NE(tail, nullptr); + EXPECT_EQ(tail->name(), "neg2"); // We walk backwards, starting from "noop", also following control inputs, // so tail should be "neg0". 
   tail = GetTailOfChain(graph.node(7), node_map,
                         /*follow_control_input=*/true, is_neg);
-  EXPECT_NE(tail, nullptr);
-  EXPECT_EQ("neg0", tail->name());
+  ASSERT_NE(tail, nullptr);
+  EXPECT_EQ(tail->name(), "neg0");
 
   // We walk backwards, starting from "noop", not following control inputs,
   // so tail should be "noop" itself.
   tail = GetTailOfChain(graph.node(7), node_map,
                         /*follow_control_input=*/false, is_neg);
-  EXPECT_NE(tail, nullptr);
-  EXPECT_EQ("noop", tail->name());
+  ASSERT_NE(tail, nullptr);
+  EXPECT_EQ(tail->name(), "noop");
 }
 
 TEST_F(UtilsTest, DedupControlInputs) {
@@ -280,40 +283,40 @@
   foo.set_name("foo");
   foo.add_input("bar");
   DedupControlInputs(&foo);
-  EXPECT_EQ(1, foo.input_size());
-  EXPECT_EQ("bar", foo.input(0));
+  ASSERT_EQ(foo.input_size(), 1);
+  EXPECT_EQ(foo.input(0), "bar");
 
   foo.set_input(0, "^bar");
   DedupControlInputs(&foo);
-  EXPECT_EQ(1, foo.input_size());
-  EXPECT_EQ("^bar", foo.input(0));
+  ASSERT_EQ(foo.input_size(), 1);
+  EXPECT_EQ(foo.input(0), "^bar");
 
   foo.set_input(0, "bar");
   foo.add_input("bar");
   DedupControlInputs(&foo);
-  EXPECT_EQ(2, foo.input_size());
-  EXPECT_EQ("bar", foo.input(0));
-  EXPECT_EQ("bar", foo.input(1));
+  ASSERT_EQ(foo.input_size(), 2);
+  EXPECT_EQ(foo.input(0), "bar");
+  EXPECT_EQ(foo.input(1), "bar");
 
   foo.set_input(1, "^bar");
   DedupControlInputs(&foo);
-  EXPECT_EQ(1, foo.input_size());
-  EXPECT_EQ("bar", foo.input(0));
+  ASSERT_EQ(foo.input_size(), 1);
+  EXPECT_EQ(foo.input(0), "bar");
 
   foo.set_input(0, "^bar");
   foo.add_input("^bar");
   DedupControlInputs(&foo);
-  EXPECT_EQ(1, foo.input_size());
-  EXPECT_EQ("^bar", foo.input(0));
+  ASSERT_EQ(foo.input_size(), 1);
+  EXPECT_EQ(foo.input(0), "^bar");
 
   foo.set_input(0, "bar");
   foo.add_input("gnu");
   foo.add_input("^bar");
   foo.add_input("^gnu");
   DedupControlInputs(&foo);
-  EXPECT_EQ(2, foo.input_size());
-  EXPECT_EQ("bar", foo.input(0));
-  EXPECT_EQ("gnu", foo.input(1));
+  ASSERT_EQ(foo.input_size(), 2);
+  EXPECT_EQ(foo.input(0), "bar");
+  EXPECT_EQ(foo.input(1), "gnu");
 }
 
 TEST_F(UtilsTest, NumNonControlOutputs) {
@@ -347,14 +350,14 @@
   NodeMap node_map(&graph);
 
   const NodeDef* add_node = node_map.GetNode("add");
-  ASSERT_TRUE(add_node != nullptr);
+  ASSERT_NE(add_node, nullptr);
 
   // [a, b] are only non-control inputs
-  EXPECT_EQ(2, NumNonControlInputs(*add_node));
+  EXPECT_EQ(NumNonControlInputs(*add_node), 2);
   // [sqrt, shape] are non control outputs
-  EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
+  EXPECT_EQ(NumNonControlOutputs(*add_node, node_map), 2);
   // sqrt is the only data output
-  EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
+  EXPECT_EQ(NumNonControlDataOutputs(*add_node, node_map), 1);
 }
 
 TEST(CheckAttrExists, All) {
@@ -465,10 +468,104 @@ TEST_F(UtilsTest, SetTensorValueBFloat16IntMin) {
 }
 
 TEST_F(UtilsTest, TensorIdToString) {
-  EXPECT_EQ("^foo", TensorIdToString({"foo", -1}));
-  EXPECT_EQ("foo", TensorIdToString({"foo", 0}));
-  EXPECT_EQ("foo:1", TensorIdToString({"foo", 1}));
-  EXPECT_EQ("foo:2", TensorIdToString({"foo", 2}));
+  EXPECT_EQ(TensorIdToString({"foo", -1}), "^foo");
+  EXPECT_EQ(TensorIdToString({"foo", 0}), "foo");
+  EXPECT_EQ(TensorIdToString({"foo", 1}), "foo:1");
+  EXPECT_EQ(TensorIdToString({"foo", 2}), "foo:2");
+}
+
+template <typename T>
+void TestSetTensorValue(DataType type, int val, bool success,
+                        absl::string_view error_msg) {
+  Tensor t(type, TensorShape({}));
+  Status s = SetTensorValue(t.dtype(), val, &t);
+  EXPECT_EQ(s.ok(), success);
+  if (s.ok()) {
+    test::ExpectTensorEqual<T>(Tensor(static_cast<T>(val)), t);
+  } else {
+    EXPECT_EQ(s.error_message(), error_msg);
+  }
+}
+
+TEST(SetTensorValueTest, Quantized) {
+  auto int_min_error = [](DataType type) {
+    return absl::Substitute(
+        "Cannot store value -2147483648 in tensor of type $0",
+        DataType_Name(type));
+  };
+  auto int_max_error = [](DataType type) {
+    return absl::Substitute(
+        "Cannot store value 2147483647 in tensor of type $0",
+        DataType_Name(type));
+  };
+  const int kMinInt = std::numeric_limits<int>::min();
+  const int kMaxInt = std::numeric_limits<int>::max();
+
+  TestSetTensorValue<qint8>(DT_QINT8, -8, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint8>(DT_QINT8, 0, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint8>(DT_QINT8, 8, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint8>(DT_QINT8, std::numeric_limits<int8>::min(),
+                            /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint8>(DT_QINT8, std::numeric_limits<int8>::max(),
+                            /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint8>(DT_QINT8, kMinInt, /*success=*/false,
+                            int_min_error(DT_QINT8));
+  TestSetTensorValue<qint8>(DT_QINT8, kMaxInt, /*success=*/false,
+                            int_max_error(DT_QINT8));
+
+  TestSetTensorValue<quint8>(
+      DT_QUINT8, -8, /*success=*/false,
+      /*error_msg=*/"Cannot store value -8 in tensor of type DT_QUINT8");
+  TestSetTensorValue<quint8>(DT_QUINT8, 0, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<quint8>(DT_QUINT8, 8, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<quint8>(DT_QUINT8, std::numeric_limits<uint8>::min(),
+                             /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<quint8>(DT_QUINT8, std::numeric_limits<uint8>::max(),
+                             /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<quint8>(DT_QUINT8, kMinInt, /*success=*/false,
+                             int_min_error(DT_QUINT8));
+  TestSetTensorValue<quint8>(DT_QUINT8, kMaxInt, /*success=*/false,
+                             int_max_error(DT_QUINT8));
+
+  TestSetTensorValue<qint16>(DT_QINT16, -8, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint16>(DT_QINT16, 0, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint16>(DT_QINT16, 8, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint16>(DT_QINT16, std::numeric_limits<int16>::min(),
+                             /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint16>(DT_QINT16, std::numeric_limits<int16>::max(),
+                             /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint16>(DT_QINT16, kMinInt, /*success=*/false,
+                             int_min_error(DT_QINT16));
+  TestSetTensorValue<qint16>(DT_QINT16, kMaxInt, /*success=*/false,
+                             int_max_error(DT_QINT16));
+
+  TestSetTensorValue<quint16>(
+      DT_QUINT16, -8, /*success=*/false,
+      /*error_msg=*/"Cannot store value -8 in tensor of type DT_QUINT16");
+  TestSetTensorValue<quint16>(DT_QUINT16, 0, /*success=*/true,
+                              /*error_msg=*/"");
+  TestSetTensorValue<quint16>(DT_QUINT16, 8, /*success=*/true,
+                              /*error_msg=*/"");
+  TestSetTensorValue<quint16>(DT_QUINT16, std::numeric_limits<uint16>::min(),
+                              /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<quint16>(DT_QUINT16, std::numeric_limits<uint16>::max(),
+                              /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<quint16>(DT_QUINT16, kMinInt, /*success=*/false,
+                              int_min_error(DT_QUINT16));
+  TestSetTensorValue<quint16>(DT_QUINT16, kMaxInt, /*success=*/false,
+                              int_max_error(DT_QUINT16));
+
+  TestSetTensorValue<qint32>(DT_QINT32, -8, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint32>(DT_QINT32, 0, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint32>(DT_QINT32, 8, /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint32>(DT_QINT32, std::numeric_limits<int32>::min(),
+                             /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint32>(DT_QINT32, std::numeric_limits<int32>::max(),
+                             /*success=*/true, /*error_msg=*/"");
+  TestSetTensorValue<qint32>(DT_QINT32, kMinInt, /*success=*/true,
+                             /*error_msg=*/"");
+  TestSetTensorValue<qint32>(DT_QINT32, kMaxInt, /*success=*/true,
+                             /*error_msg=*/"");
+}
 
 }  // namespace