diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index a43b16e9e6a..e791ff9ff60 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" @@ -795,6 +796,19 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { } } +Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const { + if (is_tensor()) { + nvinfer1::DataType trt_type = tensor()->getType(); + return TrtTypeToTfType(trt_type, tf_type); + } + + if (is_weights()) { + *tf_type = weights().GetTensor().dtype(); + return Status::OK(); + } + return errors::Internal("The object is probably not initialized"); +} + string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { @@ -1900,27 +1914,48 @@ Status CheckInputsWeights( return Status::OK(); } -Status AllowDataTypes(const OpConverterParams& params, - const std::set& allowed_dtypes, - const char* dtype_attr_name = "T") { - const auto& node_def = params.node_def; +Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type, + const char* type_attr_name) { TFAttrs attrs(node_def); - if (!attrs.count(dtype_attr_name)) { - return errors::InvalidArgument("Attribute with name ", dtype_attr_name, + if (!attrs.count(type_attr_name)) { + return errors::InvalidArgument("Attribute with name ", type_attr_name, " not found."); } - const auto op_dtype = attrs.get(dtype_attr_name); - if (!allowed_dtypes.count(op_dtype)) { - // Build string list of allowed types. - std::ostringstream ss; - for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) { - if (it != allowed_dtypes.begin()) ss << ", "; - ss << DataTypeString(*it); - } - return errors::Unimplemented("Data type ", DataTypeString(op_dtype), + *tf_type = attrs.get(type_attr_name); + return Status::OK(); +} + +Status GetInputTfType(const OpConverterParams& params, DataType* tf_type, + int pos) { + const std::vector& inputs = params.inputs; + if (inputs.size() <= pos) { + return errors::Internal("Invalid input position"); + } + + return inputs[pos].GetTfType(tf_type); +} + +constexpr const char kOutputTypeAttrName[] = "T"; + +Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) { + return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName); +} + +Status AllowDataTypes(const OpConverterParams& params, + const std::set& allowed_types, + const char* type_attr_name = kOutputTypeAttrName) { + const auto& node_def = params.node_def; + DataType tf_type; + TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name)); + if (!allowed_types.count(tf_type)) { + string allowed_types_string = absl::StrJoin( + allowed_types, ", ", [](string* out, const DataType& type) { + absl::StrAppendFormat(out, "%s", DataTypeString(type)); + }); + return errors::Unimplemented("Data type ", DataTypeString(tf_type), " is not supported for ", node_def.op(), - ", must be one of [", ss.str(), "], at ", - node_def.name()); + ", must be one of [", allowed_types_string, + "], at ", node_def.name()); } return Status::OK(); } @@ -4598,6 +4633,42 @@ Status ConvertUnpack(OpConverterParams* params) { return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true); } +// Supports cast fp16=>fp32 through IIdentityLayer. +Status ConvertCast(OpConverterParams* params) { + const NodeDef& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + auto unsupport_cast_error = [&]() { + return errors::Unimplemented("Cast op: ", node_def.op(), + " not supported at: ", node_def.name()); + }; + + DataType input_type; + TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0)); + if (input_type != DataType::DT_HALF) { + return unsupport_cast_error(); + } + + DataType output_type; + TF_RETURN_IF_ERROR(GetOutputTfType(*params, &output_type)); + if (output_type != DataType::DT_FLOAT) { + return unsupport_cast_error(); + } + + if (params->validation_only) return Status::OK(); + + nvinfer1::ITensor* input = params->inputs.at(0).tensor(); + nvinfer1::IIdentityLayer* layer = + params->converter->network()->addIdentity(*input); + layer->setPrecision(nvinfer1::DataType::kFLOAT); + + if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) { + return errors::Internal("IIdentityLayer doesn't work as expected"); + } + + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + Status ConvertConcat(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -5675,6 +5746,7 @@ static void RegisterValidatableOpConverters( (*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS; #endif (*registration)["AddN"] = ConvertAddN; + (*registration)["Cast"] = ConvertCast; (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; (*registration)["Conv2D"] = ConvertConv2D; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 2092aecd657..2fe8eec9675 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -294,6 +294,8 @@ class TRT_TensorOrWeights { nvinfer1::Dims GetTrtDims() const; + Status GetTfType(DataType* tf_type) const; + int batch_size() const { return batch_size_; } string DebugString() const; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 964370af6be..1efc31f9e24 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -5147,6 +5147,14 @@ NodeDef CreateUnaryOp() { return T(s.WithOpName("my_unary"), input).operation.node()->def(); } +NodeDef CreateCastOp() { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF); + return ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT) + .operation.node() + ->def(); +} + TEST_P(ParameterizedOpConverterTest, ConvertUnary) { const auto& spec = GetParam(); const TrtTestMode trt_mode = std::get<0>(spec); @@ -5174,6 +5182,7 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) { ADD_OP("Asinh", ops::Asinh, std::asinh); ADD_OP("Atan", ops::Atan, std::atan); ADD_OP("Atanh", ops::Atanh, std::atanh); + op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; }); ADD_OP("Ceil", ops::Ceil, std::ceil); ADD_OP("Cos", ops::Cos, std::cos); ADD_OP("Cosh", ops::Cosh, std::cosh); @@ -5212,7 +5221,13 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) { } NodeDef node_def = op_map[op_name].first(); - AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode); + // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for + // now. Need to find a better way to express input and output types. + DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype; + DataType output_tf_dtype = tf_dtype; + + AddTestTensor("input", p.input_dims, TfDataTypeToTrt(input_tf_dtype), + trt_mode); RunValidationAndConversion(node_def, Status::OK(), "my_unary", p.expected_output_dims); @@ -5220,8 +5235,8 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) { std::vector output; std::transform(input_values.begin(), input_values.end(), std::back_inserter(output), op_map[op_name].second); - InstantiateBuildAndRun(tf_dtype, "my_unary", this, p, input_values, - ArrayFloatNear(output, 0.0001, true)); + InstantiateBuildAndRun(input_tf_dtype, output_tf_dtype, "my_unary", this, p, + input_values, ArrayFloatNear(output, 0.0001, true)); } }