From 762093140b412f3d1559f012eb0ec1fc63c69da3 Mon Sep 17 00:00:00 2001
From: Guangda Lai
Date: Thu, 14 Mar 2019 11:40:11 -0700
Subject: [PATCH] Support fp16 for conversion of Square, Relu and LeakyRelu.

This also fixes a bug in the GatherV2 converter where it used the wrong data
type attribute name. Will enable it for more converters later.

PiperOrigin-RevId: 238485945
---
 tensorflow/compiler/tf2tensorrt/BUILD         |   1 +
 .../tf2tensorrt/convert/convert_nodes.cc      | 234 ++++++++++--------
 .../tf2tensorrt/convert/convert_nodes_test.cc |  80 ++++--
 3 files changed, 187 insertions(+), 128 deletions(-)

diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index a7fb772ec25..38958d4ce69 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -326,6 +326,7 @@ tf_cuda_cc_test(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core/grappler/costs:graph_properties",
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 45c58d2259e..8aeecaff925 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -116,6 +116,88 @@ inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
   return Status::OK();
 }
 
+class TFAttrs {
+ public:
+  explicit TFAttrs(const NodeDef& tf_node) {
+    for (const auto& attr : tf_node.attr()) {
+      attrs_.insert({attr.first, &attr.second});
+    }
+  }
+
+  bool count(const string& key) const { return attrs_.count(key); }
+
+  AttrValue const* at(const string& key) const {
+    if (!attrs_.count(key)) {
+      LOG(FATAL) << "Attribute not found: " << key;
+    }
+    return attrs_.at(key);
+  }
+
+  template <typename T>
+  T get(const string& key) const;
+
+  template <typename T>
+  T get(const string& key, const T& default_value) const {
+    return attrs_.count(key) ? this->get<T>(key) : default_value;
+  }
+
+  std::vector<string> GetAllAttrKeys() const {
+    std::vector<string> attr_list;
+    for (const auto& attr_item : attrs_) {
+      attr_list.emplace_back(attr_item.first);
+    }
+    return attr_list;
+  }
+
+ private:
+  typedef std::map<string, AttrValue const*> AttrMap;
+  AttrMap attrs_;
+};
+
+template <>
+string TFAttrs::get<string>(const string& key) const {
+  return this->at(key)->s();
+}
+
+template <>
+std::vector<int64> TFAttrs::get<std::vector<int64>>(const string& key) const {
+  auto attr = this->at(key)->list().i();
+  return std::vector<int64>(attr.begin(), attr.end());
+}
+
+template <>
+std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
+  auto attr = this->at(key)->list().f();
+  return std::vector<float>(attr.begin(), attr.end());
+}
+
+template <>
+nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
+  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
+  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
+  return trt_dtype;
+}
+
+template <>
+DataType TFAttrs::get<DataType>(const string& key) const {
+  return this->at(key)->type();
+}
+
+template <>
+float TFAttrs::get<float>(const string& key) const {
+  return this->at(key)->f();
+}
+
+template <>
+bool TFAttrs::get<bool>(const string& key) const {
+  return this->at(key)->b();
+}
+
+template <>
+int64 TFAttrs::get<int64>(const string& key) const {
+  return this->at(key)->i();
+}
+
 template <typename TensorShapeType>
 inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
                                            bool ignore_first_dim) {
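Note on the hunk above: TFAttrs is moved earlier in the file so that helpers defined before the converters (such as CreateBroadcastableScalarConstant in the next hunk) can read node attributes. A minimal usage sketch, not part of the patch; `node_def` stands for any NodeDef carrying a "T" attribute, and the "alpha" lookup is purely illustrative:

    TFAttrs attrs(node_def);
    // An explicit specialization picks the right AttrValue field:
    DataType tf_dtype = attrs.get<DataType>("T");
    // The default-value overload avoids LOG(FATAL) for an absent key:
    float alpha = attrs.get<float>("alpha", 0.2f);
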
@@ -379,17 +461,35 @@ nvinfer1::ITensor* Converter::CreateConstantLayer(
 Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
                                          const nvinfer1::Dims& dims,
-                                         const nvinfer1::ITensor** tensor) {
+                                         const nvinfer1::ITensor** tensor,
+                                         const char* dtype_attr_name = "T") {
+  TFAttrs attrs(params->node_def);
+  DataType dtype;
+  if (attrs.count(dtype_attr_name)) {
+    dtype = attrs.get<DataType>(dtype_attr_name);
+  } else {
+    dtype = DT_FLOAT;  // Default to FP32.
+  }
+
   // In order to be broadcastable, the number of dims has to match.
   nvinfer1::Dims broadcastable_dims(dims);
   for (int i = 0; i < broadcastable_dims.nbDims; i++) {
     broadcastable_dims.d[i] = 1;
   }
-  TRT_ShapedWeights weights = params->weight_store->GetTempWeights(
-      DataType::DT_FLOAT, broadcastable_dims);
-  auto weights_ptr =
-      static_cast<float*>(const_cast<void*>(weights.GetValues()));
-  weights_ptr[0] = value;
+  TRT_ShapedWeights weights =
+      params->weight_store->GetTempWeights(dtype, broadcastable_dims);
+  void* raw_ptr = const_cast<void*>(weights.GetValues());
+  switch (dtype) {
+    case DataType::DT_FLOAT:
+      static_cast<float*>(raw_ptr)[0] = value;
+      break;
+    case DataType::DT_HALF:
+      static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);
+      break;
+    default:
+      return errors::InvalidArgument("Unsupported data type ",
+                                     DataTypeString(dtype));
+  }
   *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
   TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
   params->converter->ProvideQuantizationRange(
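Why the switch on dtype above is needed: TRT_ShapedWeights exposes its storage as an untyped pointer, so the scalar must be written with the in-memory representation of the target type; writing a float into a DT_HALF buffer would produce garbage weights. A small sketch of the idea; WriteScalar is a hypothetical helper, not part of the patch, and only Eigen is assumed:

    #include <Eigen/Core>  // Eigen::half: 2-byte IEEE-754 half float

    // Write `value` into an untyped weight buffer as fp32 or fp16.
    void WriteScalar(void* raw_ptr, bool as_half, float value) {
      if (as_half) {
        static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);  // 2 bytes
      } else {
        static_cast<float*>(raw_ptr)[0] = value;  // 4 bytes
      }
    }
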
@@ -662,88 +762,6 @@ string TRT_TensorOrWeights::DebugString() const {
   return output;
 }
 
-class TFAttrs {
- public:
-  explicit TFAttrs(const NodeDef& tf_node) {
-    for (const auto& attr : tf_node.attr()) {
-      attrs_.insert({attr.first, &attr.second});
-    }
-  }
-
-  bool count(const string& key) const { return attrs_.count(key); }
-
-  AttrValue const* at(const string& key) const {
-    if (!attrs_.count(key)) {
-      LOG(FATAL) << "Attribute not found: " << key;
-    }
-    return attrs_.at(key);
-  }
-
-  template <typename T>
-  T get(const string& key) const;
-
-  template <typename T>
-  T get(const string& key, const T& default_value) const {
-    return attrs_.count(key) ? this->get<T>(key) : default_value;
-  }
-
-  std::vector<string> GetAllAttrKeys() const {
-    std::vector<string> attr_list;
-    for (const auto& attr_item : attrs_) {
-      attr_list.emplace_back(attr_item.first);
-    }
-    return attr_list;
-  }
-
- private:
-  typedef std::map<string, AttrValue const*> AttrMap;
-  AttrMap attrs_;
-};
-
-template <>
-string TFAttrs::get<string>(const string& key) const {
-  return this->at(key)->s();
-}
-
-template <>
-std::vector<int64> TFAttrs::get<std::vector<int64>>(const string& key) const {
-  auto attr = this->at(key)->list().i();
-  return std::vector<int64>(attr.begin(), attr.end());
-}
-
-template <>
-std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
-  auto attr = this->at(key)->list().f();
-  return std::vector<float>(attr.begin(), attr.end());
-}
-
-template <>
-nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
-  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
-  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
-  return trt_dtype;
-}
-
-template <>
-DataType TFAttrs::get<DataType>(const string& key) const {
-  return this->at(key)->type();
-}
-
-template <>
-float TFAttrs::get<float>(const string& key) const {
-  return this->at(key)->f();
-}
-
-template <>
-bool TFAttrs::get<bool>(const string& key) const {
-  return this->at(key)->b();
-}
-
-template <>
-int64 TFAttrs::get<int64>(const string& key) const {
-  return this->at(key)->i();
-}
-
 // TODO(jie): reorder4 & reorder2 should be merged?
 // TODO(aaroey): fix the order of parameters.
 template <typename T>
@@ -1435,26 +1453,27 @@ Status CheckInputsWeights(
 }
 
 Status AllowDataTypes(const OpConverterParams& params,
-                      const std::set<DataType>& allowed_dtypes) {
+                      const std::set<DataType>& allowed_dtypes,
+                      const char* dtype_attr_name = "T") {
   const auto& node_def = params.node_def;
-  TFAttrs attrs(params.node_def);
-  if (attrs.count("T")) {
-    const auto op_dtype = attrs.get<DataType>("T");
-    if (!allowed_dtypes.count(op_dtype)) {
-      // Build string list of allowed types.
-      std::ostringstream ss;
-      for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) {
-        if (it != allowed_dtypes.begin()) ss << ", ";
-        ss << DataTypeString(*it);
-      }
-      return errors::Unimplemented("Data type ", DataTypeString(op_dtype),
-                                   " is not supported for ", node_def.op(),
-                                   ", must be one of [", ss.str(), "], at ",
-                                   node_def.name());
-    }
+  TFAttrs attrs(node_def);
+  if (!attrs.count(dtype_attr_name)) {
+    return errors::InvalidArgument("Attribute with name ", dtype_attr_name,
+                                   " not found.");
+  }
+  const auto op_dtype = attrs.get<DataType>(dtype_attr_name);
+  if (!allowed_dtypes.count(op_dtype)) {
+    // Build string list of allowed types.
+    std::ostringstream ss;
+    for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) {
+      if (it != allowed_dtypes.begin()) ss << ", ";
+      ss << DataTypeString(*it);
+    }
+    return errors::Unimplemented("Data type ", DataTypeString(op_dtype),
+                                 " is not supported for ", node_def.op(),
+                                 ", must be one of [", ss.str(), "], at ",
+                                 node_def.name());
   }
-  // If there is no T attribute, we can't determine the type of the op. We will
-  // allow it to convert for now.
   return Status::OK();
 }
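Two behavioral changes in the hunk above are worth calling out: the dtype attribute name is now a parameter (GatherV2 stores its params dtype under "Tparams", not "T", which is the bug this patch fixes), and a node with no dtype attribute is now rejected with InvalidArgument instead of being silently allowed through. A converter opting into a non-default name calls it exactly as the ConvertGather hunk below does:

    TF_RETURN_IF_ERROR(AllowDataTypes(
        *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
        /*dtype_attr_name=*/"Tparams"));
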
@@ -3696,7 +3715,8 @@ Status ConvertGather(OpConverterParams* params) {
   TF_RETURN_IF_ERROR(CheckInputsWeights(
       *params, {{"params", false}, {"indices", false}, {"axis", true}}));
   TF_RETURN_IF_ERROR(AllowDataTypes(
-      *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
+      *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
+      /*dtype_attr_name=*/"Tparams"));
   absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
   if (axis.size() != 1) {
     return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index bd656b0e836..853b313367c 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
@@ -109,13 +110,17 @@ DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
 }
 
 NodeDef MakeNodeDef(const string& name, const string& op,
-                    const std::vector<string>& inputs) {
+                    const std::vector<string>& inputs,
+                    const std::map<string, AttrValue> attrs = {}) {
   NodeDef node_def;
   node_def.set_name(name);
   node_def.set_op(op);
   for (const string& input : inputs) {
     node_def.add_input(input);
   }
+  for (const auto& attr : attrs) {
+    (*node_def.mutable_attr())[attr.first] = attr.second;
+  }
   return node_def;
 }
 
@@ -1094,8 +1099,22 @@ class OpConverterTest : public ::testing::Test {
     validator_inputs_.clear();
   }
 
-  // TODO(laigd): test fp16 and int8 support.
-  void BuildAndRun(const DataVec& input_data, DataVec* output_data) {
+  void CheckDataTypeMatches(const DataVec& datas) {
+    for (const auto& data : datas) {
+      const int input_index = engine_->getBindingIndex(data.name);
+      ASSERT_NE(-1, input_index);
+      const nvinfer1::DataType trt_dtype =
+          engine_->getBindingDataType(input_index);
+      const DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
+      ASSERT_EQ(data.tensor.dtype(), tf_dtype)
+          << DataTypeString(data.tensor.dtype()) << " vs. "
+          << DataTypeString(tf_dtype);
+    }
+  }
+
+  // TODO(laigd): test fp16 and int8 support for more converters.
+  void BuildAndRun(const DataVec& input_data, DataVec* output_data,
+                   TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32) {
     // Mark the output tensor as TRT engine output.
     std::vector<string> output_info;
     for (const auto& data : *output_data) {
@@ -1105,9 +1124,20 @@
     TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
 
     // Build the TRT engine.
+    if (precision_mode == TrtPrecisionMode::FP16) {
+      builder_->setFp16Mode(true);
+    } else if (precision_mode == TrtPrecisionMode::INT8) {
+      // Setting FP16 mode as well allows TRT to also consider FP16 kernels and
+      // use them in situations where they are faster than INT8 or where INT8 is
+      // not supported for a given layer.
+      builder_->setFp16Mode(true);
+      builder_->setInt8Mode(true);
+    }
     ASSERT_EQ(nullptr, engine_.get());
     engine_.reset(builder_->buildCudaEngine(*converter_->network()));
     CHECK_NOTNULL(engine_.get());
+    CheckDataTypeMatches(input_data);
+    CheckDataTypeMatches(*output_data);
 
     // Execute the TRT engine.
     const int num_bindings = input_data.size() + output_data->size();
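A note on the builder calls added above: setFp16Mode only permits fp16 kernels, it does not force them, so TRT may still pick fp32 implementations per layer. A condensed sketch of the build sequence under the TensorRT 5.x API used here (before IBuilderConfig existed); the BuildEngine helper and its parameters are illustrative, not from the patch:

    #include "NvInfer.h"

    // Enable reduced precision on a builder, then build the engine.
    nvinfer1::ICudaEngine* BuildEngine(nvinfer1::IBuilder* builder,
                                       nvinfer1::INetworkDefinition* network,
                                       bool fp16, bool int8) {
      // These flags only *permit* lower-precision kernels; TRT still picks
      // whichever implementation it measures as fastest per layer.
      if (fp16 || int8) builder->setFp16Mode(true);
      if (int8) builder->setInt8Mode(true);  // also requires int8 ranges
      return builder->buildCudaEngine(*network);
    }
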
@@ -1761,7 +1791,9 @@ void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
   const DataVec input_data{
       {"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
   DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
-  test->BuildAndRun(input_data, &output_data);
+  test->BuildAndRun(
+      input_data, &output_data,
+      dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
   if (node_def.op() == "Add") {
     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
                 ElementsAre(CType(5), CType(10.5)));
@@ -1942,7 +1974,9 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) {
   DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
   // After broadcasting first input becomes {3, 6, 3, 6} and second input
   // becomes {2, 3, 2, 3}.
-  test->BuildAndRun(input_data, &output_data);
+  test->BuildAndRun(
+      input_data, &output_data,
+      dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
   if (node_def.op() == "Add") {
     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
                 ElementsAre(CType(5), CType(8), CType(6), CType(9)));
@@ -1974,10 +2008,13 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) {
 }
 
 TEST_F(OpConverterTest, ConvertBinary) {
+  AttrValue dtype;
+  dtype.set_type(DT_FLOAT);
   // Input size doesn't match, should fail.
   for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) {
     Reset();
-    NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"});
+    NodeDef node_def =
+        MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}});
     AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT);
     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
                                StrCat("Add got ", std::to_string(num_inputs),
@@ -1987,7 +2024,8 @@ TEST_F(OpConverterTest, ConvertBinary) {
   {
     // Both inputs are weights.
     Reset();
-    NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"});
+    NodeDef node_def =
+        MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}});
     AddTestWeights<float>("weights1", {1}, {1});
     AddTestWeights<float>("weights2", {1}, {1});
     RunValidationAndConversion(
@@ -2002,15 +2040,12 @@ TEST_F(OpConverterTest, ConvertBinary) {
   TestBinaryTensorOpWeightNoBroadcast(this);
   TestBinaryTensorOpWeightNoBroadcast(this);
   TestBinaryTensorOpWeightNoBroadcast(this);
-#if 0
-  // TODO(b/119560144): it doesn't support FP16 constants and the following test
-  // will fail.
+
   TestBinaryTensorOpWeightNoBroadcast(this);
   TestBinaryTensorOpWeightNoBroadcast(this);
   TestBinaryTensorOpWeightNoBroadcast(this);
   TestBinaryTensorOpWeightNoBroadcast(this);
   TestBinaryTensorOpWeightNoBroadcast(this);
-#endif
 
   // Test BinaryTensorOpWeight() with channel-wise broadcasting.
   TestBinaryTensorOpWeightWithChannelWiseBroadcast(this);
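Why these tests now attach {{"T", dtype}}: with the stricter AllowDataTypes earlier in this patch, a hand-built NodeDef without a dtype attribute fails validation with InvalidArgument, so the negative-path tests must carry one to exercise the error they actually target. The pattern, mirroring the MakeNodeDef change above (the input names "a" and "b" here are illustrative):

    AttrValue dtype;
    dtype.set_type(DT_FLOAT);
    NodeDef node_def =
        MakeNodeDef("my_add", "Add", {"a", "b"}, {{"T", dtype}});
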
@@ -2192,7 +2227,8 @@ void TestConvertSquare(OpConverterTest* test) {
   auto square = ops::Square(s.WithOpName("my_square"), input);
   NodeDef node_def = square.operation.node()->def();
 
-  test->AddTestTensor("input", {1, 20});
+  test->AddTestTensor("input", {1, 20}, /*batch_size=*/1,
+                      TfDataTypeToTrt(dtype));
   test->RunValidationAndConversion(node_def);
   TRT_TensorOrWeights output;
   TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output));
@@ -2202,14 +2238,18 @@ void TestConvertSquare(OpConverterTest* test) {
   const int num_inputs = 20;
   std::vector<CType> inputs(num_inputs);
   std::vector<CType> expected_outputs(num_inputs);
-  for (int i = 0; i < 20; i++) {
+  for (int i = 0; i < num_inputs; ++i) {
     const CType value = CType(i - 9);
     inputs[i] = value;
     expected_outputs[i] = value * value;
   }
   const DataVec input_data{{"input", test::AsTensor<CType>(inputs)}};
+  // Engine outputs are converted to FP16 automatically if we set FP16 mode in
+  // the builder.
   DataVec output_data{{"my_square", ConstructTensor<CType>(num_inputs)}};
-  test->BuildAndRun(input_data, &output_data);
+  test->BuildAndRun(
+      input_data, &output_data,
+      dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
   ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0]));
 }
 
@@ -2237,9 +2277,7 @@ TEST_F(OpConverterTest, ConvertSquare) {
   // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't
   // test DT_INT32 type here.
   TestConvertSquare<DT_FLOAT>(this);
-  // TODO(tmorris): Looks like there may be a bug with this layer for FP16
-  // inputs. Disabling for now.
-  // TestConvertSquare<DT_HALF>(this);
+  TestConvertSquare<DT_HALF>(this);
 }
 
 TEST_F(OpConverterTest, ConvertActivation) {
@@ -2269,10 +2307,10 @@
     Scope s = Scope::NewRootScope();
     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
     if (op_name == "LeakyRelu") {
-      // LeakyRelu does not have a C++ API
-      NodeDef node_def = MakeNodeDef("my_act", "LeakyRelu", {"input"});
-      (*node_def.mutable_attr())["alpha"].set_f(kAlpha);
-      return node_def;
+      auto act =
+          ops::internal::LeakyRelu(s.WithOpName("my_act"), input,
+                                   ops::internal::LeakyRelu::Alpha(kAlpha));
+      return act.operation.node()->def();
     } else if (op_name == "Relu") {
       auto act = ops::Relu(s.WithOpName("my_act"), input);
       return act.operation.node()->def();