[TF:TRT] Implement cast from fp16 to fp32 with IIdentityLayer.

This is the first CL to implement the request in b/150285802.

Add Cast op test to convert_nodes_test.

PiperOrigin-RevId: 312093049
Change-Id: I77215cf6da104f51acc93de1b03e9a179db54f0a
Bixia Zheng 2020-05-18 09:23:09 -07:00 committed by TensorFlower Gardener
parent 46f7108d78
commit 32165792a3
3 changed files with 109 additions and 20 deletions
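
The core TensorRT pattern behind this change can be sketched standalone. The helper below is hypothetical (the actual converter is ConvertCast in the convert_nodes.cc diff below); it only mirrors the calls the commit relies on: an IIdentityLayer whose precision is pinned to fp32 acts as an fp16-to-fp32 cast.

#include "NvInfer.h"

// Sketch only: assumes an already-built INetworkDefinition and an fp16 input
// tensor, and mirrors the addIdentity/setPrecision calls made by ConvertCast.
nvinfer1::ITensor* CastHalfToFloat(nvinfer1::INetworkDefinition* network,
                                   nvinfer1::ITensor* fp16_input) {
  nvinfer1::IIdentityLayer* layer = network->addIdentity(*fp16_input);
  layer->setPrecision(nvinfer1::DataType::kFLOAT);  // request fp32 output
  return layer->getOutput(0);                       // fp32 tensor
}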

tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc

@@ -29,6 +29,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
@@ -795,6 +796,19 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
   }
 }
 
+Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const {
+  if (is_tensor()) {
+    nvinfer1::DataType trt_type = tensor()->getType();
+    return TrtTypeToTfType(trt_type, tf_type);
+  }
+  if (is_weights()) {
+    *tf_type = weights().GetTensor().dtype();
+    return Status::OK();
+  }
+  return errors::Internal("The object is probably not initialized");
+}
+
 string TRT_TensorOrWeights::DebugString() const {
   string output = "TRT_TensorOrWeights(type=";
   if (is_tensor()) {
@@ -1900,27 +1914,48 @@ Status CheckInputsWeights(
   return Status::OK();
 }
 
-Status AllowDataTypes(const OpConverterParams& params,
-                      const std::set<DataType>& allowed_dtypes,
-                      const char* dtype_attr_name = "T") {
-  const auto& node_def = params.node_def;
+Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type,
+                        const char* type_attr_name) {
   TFAttrs attrs(node_def);
-  if (!attrs.count(dtype_attr_name)) {
-    return errors::InvalidArgument("Attribute with name ", dtype_attr_name,
+  if (!attrs.count(type_attr_name)) {
+    return errors::InvalidArgument("Attribute with name ", type_attr_name,
                                    " not found.");
   }
-  const auto op_dtype = attrs.get<DataType>(dtype_attr_name);
-  if (!allowed_dtypes.count(op_dtype)) {
-    // Build string list of allowed types.
-    std::ostringstream ss;
-    for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) {
-      if (it != allowed_dtypes.begin()) ss << ", ";
-      ss << DataTypeString(*it);
-    }
-    return errors::Unimplemented("Data type ", DataTypeString(op_dtype),
+  *tf_type = attrs.get<DataType>(type_attr_name);
+  return Status::OK();
+}
+
+Status GetInputTfType(const OpConverterParams& params, DataType* tf_type,
+                      int pos) {
+  const std::vector<TRT_TensorOrWeights>& inputs = params.inputs;
+  if (inputs.size() <= pos) {
+    return errors::Internal("Invalid input position");
+  }
+
+  return inputs[pos].GetTfType(tf_type);
+}
+
+constexpr const char kOutputTypeAttrName[] = "T";
+
+Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) {
+  return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName);
+}
+
+Status AllowDataTypes(const OpConverterParams& params,
+                      const std::set<DataType>& allowed_types,
+                      const char* type_attr_name = kOutputTypeAttrName) {
+  const auto& node_def = params.node_def;
+  DataType tf_type;
+  TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name));
+  if (!allowed_types.count(tf_type)) {
+    string allowed_types_string = absl::StrJoin(
+        allowed_types, ", ", [](string* out, const DataType& type) {
+          absl::StrAppendFormat(out, "%s", DataTypeString(type));
+        });
+    return errors::Unimplemented("Data type ", DataTypeString(tf_type),
                                  " is not supported for ", node_def.op(),
-                                 ", must be one of [", ss.str(), "], at ",
-                                 node_def.name());
+                                 ", must be one of [", allowed_types_string,
+                                 "], at ", node_def.name());
   }
   return Status::OK();
 }
@@ -4598,6 +4633,42 @@ Status ConvertUnpack(OpConverterParams* params) {
   return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true);
 }
 
+// Supports cast fp16=>fp32 through IIdentityLayer.
+Status ConvertCast(OpConverterParams* params) {
+  const NodeDef& node_def = params->node_def;
+  TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
+  auto unsupport_cast_error = [&]() {
+    return errors::Unimplemented("Cast op: ", node_def.op(),
+                                 " not supported at: ", node_def.name());
+  };
+
+  DataType input_type;
+  TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0));
+  if (input_type != DataType::DT_HALF) {
+    return unsupport_cast_error();
+  }
+
+  DataType output_type;
+  TF_RETURN_IF_ERROR(GetOutputTfType(*params, &output_type));
+  if (output_type != DataType::DT_FLOAT) {
+    return unsupport_cast_error();
+  }
+
+  if (params->validation_only) return Status::OK();
+
+  nvinfer1::ITensor* input = params->inputs.at(0).tensor();
+  nvinfer1::IIdentityLayer* layer =
+      params->converter->network()->addIdentity(*input);
+  layer->setPrecision(nvinfer1::DataType::kFLOAT);
+
+  if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) {
+    return errors::Internal("IIdentityLayer doesn't work as expected");
+  }
+
+  params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
+  return Status::OK();
+}
+
 Status ConvertConcat(OpConverterParams* params) {
   const auto& inputs = params->inputs;
   const auto& node_def = params->node_def;
@@ -5675,6 +5746,7 @@ static void RegisterValidatableOpConverters(
   (*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS;
 #endif
   (*registration)["AddN"] = ConvertAddN;
+  (*registration)["Cast"] = ConvertCast;
   (*registration)["ConcatV2"] = ConvertConcat;
   (*registration)["Const"] = ConvertConst;
   (*registration)["Conv2D"] = ConvertConv2D;

tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h

@@ -294,6 +294,8 @@ class TRT_TensorOrWeights {
 
   nvinfer1::Dims GetTrtDims() const;
 
+  Status GetTfType(DataType* tf_type) const;
+
   int batch_size() const { return batch_size_; }
 
   string DebugString() const;
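
The new accessor lets a converter ask for the TensorFlow dtype without caring how the value is backed; a hypothetical call site (not part of this diff) might look like:

// Hedged sketch: query the TF dtype of input 0 inside a converter.
DataType dtype;
TF_RETURN_IF_ERROR(params->inputs.at(0).GetTfType(&dtype));
if (dtype == DataType::DT_HALF) {
  // fp16 input, regardless of ITensor vs. weights backing.
}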

tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc

@@ -5147,6 +5147,14 @@ NodeDef CreateUnaryOp() {
   return T(s.WithOpName("my_unary"), input).operation.node()->def();
 }
 
+NodeDef CreateCastOp() {
+  Scope s = Scope::NewRootScope();
+  auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF);
+  return ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT)
+      .operation.node()
+      ->def();
+}
+
 TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
   const auto& spec = GetParam();
   const TrtTestMode trt_mode = std::get<0>(spec);
@@ -5174,6 +5182,7 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
   ADD_OP("Asinh", ops::Asinh, std::asinh);
   ADD_OP("Atan", ops::Atan, std::atan);
   ADD_OP("Atanh", ops::Atanh, std::atanh);
+  op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; });
   ADD_OP("Ceil", ops::Ceil, std::ceil);
   ADD_OP("Cos", ops::Cos, std::cos);
   ADD_OP("Cosh", ops::Cosh, std::cosh);
@@ -5212,7 +5221,13 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
     }
     NodeDef node_def = op_map[op_name].first();
 
-    AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode);
+    // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for
+    // now. Need to find a better way to express input and output types.
+    DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype;
+    DataType output_tf_dtype = tf_dtype;
+    AddTestTensor("input", p.input_dims, TfDataTypeToTrt(input_tf_dtype),
+                  trt_mode);
+
     RunValidationAndConversion(node_def, Status::OK(), "my_unary",
                                p.expected_output_dims);
 
@@ -5220,8 +5235,8 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
 
     std::vector<float> output;
     std::transform(input_values.begin(), input_values.end(),
                    std::back_inserter(output), op_map[op_name].second);
-    InstantiateBuildAndRun(tf_dtype, "my_unary", this, p, input_values,
-                           ArrayFloatNear(output, 0.0001, true));
+    InstantiateBuildAndRun(input_tf_dtype, output_tf_dtype, "my_unary", this, p,
+                           input_values, ArrayFloatNear(output, 0.0001, true));
   }
 }