[TF:TRT] Implement cast from fp16 to fp32 with IIdentityLayer.
This is the first CL implementing the request in b/150285802. Adds a Cast op test to convert_nodes_test. PiperOrigin-RevId: 312093049 Change-Id: I77215cf6da104f51acc93de1b03e9a179db54f0a
This commit is contained in:
parent
46f7108d78
commit
32165792a3
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||||
@ -795,6 +796,19 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the TensorFlow data type of the value held by this object.
//
// For a tensor, the TensorRT element type reported by the engine is
// translated back to the corresponding TF type; for weights, the dtype of
// the underlying tensor is reported directly.
Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const {
  if (is_tensor()) {
    // Map the TensorRT element type back into a TF DataType.
    return TrtTypeToTfType(tensor()->getType(), tf_type);
  }
  if (is_weights()) {
    *tf_type = weights().GetTensor().dtype();
    return Status::OK();
  }
  // Neither tensor nor weights: the object was never assigned a value.
  return errors::Internal("The object is probably not initialized");
}
|
||||||
string TRT_TensorOrWeights::DebugString() const {
|
string TRT_TensorOrWeights::DebugString() const {
|
||||||
string output = "TRT_TensorOrWeights(type=";
|
string output = "TRT_TensorOrWeights(type=";
|
||||||
if (is_tensor()) {
|
if (is_tensor()) {
|
||||||
@ -1900,27 +1914,48 @@ Status CheckInputsWeights(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reads the TensorFlow data type stored in attribute `type_attr_name` of
// `node_def` into `*tf_type`.
//
// Returns InvalidArgument if the node carries no attribute with that name.
Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type,
                        const char* type_attr_name) {
  TFAttrs attrs(node_def);
  if (attrs.count(type_attr_name) == 0) {
    return errors::InvalidArgument("Attribute with name ", type_attr_name,
                                   " not found.");
  }
  *tf_type = attrs.get<DataType>(type_attr_name);
  return Status::OK();
}
|
// Looks up the TF data type of the input at position `pos` of the op being
// converted and stores it in `*tf_type`.
//
// Returns Internal if `pos` does not name a valid input position.
Status GetInputTfType(const OpConverterParams& params, DataType* tf_type,
                      int pos) {
  const std::vector<TRT_TensorOrWeights>& inputs = params.inputs;
  // Reject negative positions explicitly before the unsigned comparison: a
  // bare `inputs.size() <= pos` converts a negative `pos` to a huge size_t,
  // so the bounds check would pass and inputs[pos] would read out of range.
  if (pos < 0 || static_cast<size_t>(pos) >= inputs.size()) {
    return errors::Internal("Invalid input position");
  }
  return inputs[pos].GetTfType(tf_type);
}
|
||||||
|
|
||||||
|
constexpr const char kOutputTypeAttrName[] = "T";
|
||||||
|
|
||||||
|
Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) {
|
||||||
|
return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verifies that the data type of the op being converted (read from the
// attribute named `type_attr_name`) is one of `allowed_types`.
//
// Returns Unimplemented, listing the permitted types, when it is not.
Status AllowDataTypes(const OpConverterParams& params,
                      const std::set<DataType>& allowed_types,
                      const char* type_attr_name = kOutputTypeAttrName) {
  const auto& node_def = params.node_def;
  DataType tf_type;
  TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name));
  if (allowed_types.count(tf_type) == 0) {
    // Build a comma-separated, human-readable list of the permitted types
    // for the error message.
    const string allowed_types_string = absl::StrJoin(
        allowed_types, ", ", [](string* out, const DataType& type) {
          absl::StrAppendFormat(out, "%s", DataTypeString(type));
        });
    return errors::Unimplemented("Data type ", DataTypeString(tf_type),
                                 " is not supported for ", node_def.op(),
                                 ", must be one of [", allowed_types_string,
                                 "], at ", node_def.name());
  }
  return Status::OK();
}
|
||||||
@ -4598,6 +4633,42 @@ Status ConvertUnpack(OpConverterParams* params) {
|
|||||||
return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true);
|
return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Supports cast fp16=>fp32 through IIdentityLayer.
|
||||||
|
Status ConvertCast(OpConverterParams* params) {
|
||||||
|
const NodeDef& node_def = params->node_def;
|
||||||
|
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
|
||||||
|
auto unsupport_cast_error = [&]() {
|
||||||
|
return errors::Unimplemented("Cast op: ", node_def.op(),
|
||||||
|
" not supported at: ", node_def.name());
|
||||||
|
};
|
||||||
|
|
||||||
|
DataType input_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0));
|
||||||
|
if (input_type != DataType::DT_HALF) {
|
||||||
|
return unsupport_cast_error();
|
||||||
|
}
|
||||||
|
|
||||||
|
DataType output_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetOutputTfType(*params, &output_type));
|
||||||
|
if (output_type != DataType::DT_FLOAT) {
|
||||||
|
return unsupport_cast_error();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params->validation_only) return Status::OK();
|
||||||
|
|
||||||
|
nvinfer1::ITensor* input = params->inputs.at(0).tensor();
|
||||||
|
nvinfer1::IIdentityLayer* layer =
|
||||||
|
params->converter->network()->addIdentity(*input);
|
||||||
|
layer->setPrecision(nvinfer1::DataType::kFLOAT);
|
||||||
|
|
||||||
|
if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) {
|
||||||
|
return errors::Internal("IIdentityLayer doesn't work as expected");
|
||||||
|
}
|
||||||
|
|
||||||
|
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status ConvertConcat(OpConverterParams* params) {
|
Status ConvertConcat(OpConverterParams* params) {
|
||||||
const auto& inputs = params->inputs;
|
const auto& inputs = params->inputs;
|
||||||
const auto& node_def = params->node_def;
|
const auto& node_def = params->node_def;
|
||||||
@ -5675,6 +5746,7 @@ static void RegisterValidatableOpConverters(
|
|||||||
(*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS;
|
(*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS;
|
||||||
#endif
|
#endif
|
||||||
(*registration)["AddN"] = ConvertAddN;
|
(*registration)["AddN"] = ConvertAddN;
|
||||||
|
(*registration)["Cast"] = ConvertCast;
|
||||||
(*registration)["ConcatV2"] = ConvertConcat;
|
(*registration)["ConcatV2"] = ConvertConcat;
|
||||||
(*registration)["Const"] = ConvertConst;
|
(*registration)["Const"] = ConvertConst;
|
||||||
(*registration)["Conv2D"] = ConvertConv2D;
|
(*registration)["Conv2D"] = ConvertConv2D;
|
||||||
|
@ -294,6 +294,8 @@ class TRT_TensorOrWeights {
|
|||||||
|
|
||||||
nvinfer1::Dims GetTrtDims() const;
|
nvinfer1::Dims GetTrtDims() const;
|
||||||
|
|
||||||
|
Status GetTfType(DataType* tf_type) const;
|
||||||
|
|
||||||
int batch_size() const { return batch_size_; }
|
int batch_size() const { return batch_size_; }
|
||||||
|
|
||||||
string DebugString() const;
|
string DebugString() const;
|
||||||
|
@ -5147,6 +5147,14 @@ NodeDef CreateUnaryOp() {
|
|||||||
return T(s.WithOpName("my_unary"), input).operation.node()->def();
|
return T(s.WithOpName("my_unary"), input).operation.node()->def();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a NodeDef for a Cast op that converts a DT_HALF placeholder input
// to DT_FLOAT; used by the parameterized unary-op conversion test.
NodeDef CreateCastOp() {
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF);
  auto cast = ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT);
  return cast.operation.node()->def();
}
|
||||||
|
|
||||||
TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
||||||
const auto& spec = GetParam();
|
const auto& spec = GetParam();
|
||||||
const TrtTestMode trt_mode = std::get<0>(spec);
|
const TrtTestMode trt_mode = std::get<0>(spec);
|
||||||
@ -5174,6 +5182,7 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
|||||||
ADD_OP("Asinh", ops::Asinh, std::asinh);
|
ADD_OP("Asinh", ops::Asinh, std::asinh);
|
||||||
ADD_OP("Atan", ops::Atan, std::atan);
|
ADD_OP("Atan", ops::Atan, std::atan);
|
||||||
ADD_OP("Atanh", ops::Atanh, std::atanh);
|
ADD_OP("Atanh", ops::Atanh, std::atanh);
|
||||||
|
op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; });
|
||||||
ADD_OP("Ceil", ops::Ceil, std::ceil);
|
ADD_OP("Ceil", ops::Ceil, std::ceil);
|
||||||
ADD_OP("Cos", ops::Cos, std::cos);
|
ADD_OP("Cos", ops::Cos, std::cos);
|
||||||
ADD_OP("Cosh", ops::Cosh, std::cosh);
|
ADD_OP("Cosh", ops::Cosh, std::cosh);
|
||||||
@ -5212,7 +5221,13 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
|||||||
}
|
}
|
||||||
NodeDef node_def = op_map[op_name].first();
|
NodeDef node_def = op_map[op_name].first();
|
||||||
|
|
||||||
AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode);
|
// TODO(bixia): we assume this test is only instantiated for DT_FLOAT for
|
||||||
|
// now. Need to find a better way to express input and output types.
|
||||||
|
DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype;
|
||||||
|
DataType output_tf_dtype = tf_dtype;
|
||||||
|
|
||||||
|
AddTestTensor("input", p.input_dims, TfDataTypeToTrt(input_tf_dtype),
|
||||||
|
trt_mode);
|
||||||
RunValidationAndConversion(node_def, Status::OK(), "my_unary",
|
RunValidationAndConversion(node_def, Status::OK(), "my_unary",
|
||||||
p.expected_output_dims);
|
p.expected_output_dims);
|
||||||
|
|
||||||
@ -5220,8 +5235,8 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
|||||||
std::vector<float> output;
|
std::vector<float> output;
|
||||||
std::transform(input_values.begin(), input_values.end(),
|
std::transform(input_values.begin(), input_values.end(),
|
||||||
std::back_inserter(output), op_map[op_name].second);
|
std::back_inserter(output), op_map[op_name].second);
|
||||||
InstantiateBuildAndRun(tf_dtype, "my_unary", this, p, input_values,
|
InstantiateBuildAndRun(input_tf_dtype, output_tf_dtype, "my_unary", this, p,
|
||||||
ArrayFloatNear(output, 0.0001, true));
|
input_values, ArrayFloatNear(output, 0.0001, true));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user