[TF:TRT] Implement cast from fp16 to fp32 with IIdentityLayer.
This is the first CL to implement the request in b/150285802. Add Cast op test to convert_nodes_test. PiperOrigin-RevId: 312093049 Change-Id: I77215cf6da104f51acc93de1b03e9a179db54f0a
This commit is contained in:
parent
46f7108d78
commit
32165792a3
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||
@ -795,6 +796,19 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
|
||||
}
|
||||
}
|
||||
|
||||
// Reports the TF data type of the wrapped object, whichever variant it holds.
// Returns Internal error when the object holds neither a tensor nor weights.
Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const {
  if (is_tensor()) {
    // Tensors carry a TRT type; translate it back to the TF enum.
    return TrtTypeToTfType(tensor()->getType(), tf_type);
  }
  if (is_weights()) {
    // Weights wrap a TF tensor directly, so its dtype is authoritative.
    *tf_type = weights().GetTensor().dtype();
    return Status::OK();
  }
  return errors::Internal("The object is probably not initialized");
}
|
||||
|
||||
string TRT_TensorOrWeights::DebugString() const {
|
||||
string output = "TRT_TensorOrWeights(type=";
|
||||
if (is_tensor()) {
|
||||
@ -1900,27 +1914,48 @@ Status CheckInputsWeights(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reads the DataType stored in attribute `type_attr_name` of `node_def` into
// `*tf_type`. Returns InvalidArgument when the attribute is absent.
//
// NOTE(review): this span of the diff view interleaved lines of the deleted
// pre-commit AllowDataTypes (referencing `dtype_attr_name`/`ss`) with the new
// function; the body below is the coherent post-commit helper.
Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type,
                        const char* type_attr_name) {
  TFAttrs attrs(node_def);
  if (!attrs.count(type_attr_name)) {
    return errors::InvalidArgument("Attribute with name ", type_attr_name,
                                   " not found.");
  }
  *tf_type = attrs.get<DataType>(type_attr_name);
  return Status::OK();
}
|
||||
|
||||
// Looks up the TF data type of the converter input at position `pos`.
// Returns Internal error when `pos` is out of range.
Status GetInputTfType(const OpConverterParams& params, DataType* tf_type,
                      int pos) {
  const std::vector<TRT_TensorOrWeights>& inputs = params.inputs;
  // Guard negative positions explicitly and avoid the signed/unsigned
  // comparison `inputs.size() <= pos` (size_t vs int, -Wsign-compare).
  if (pos < 0 || inputs.size() <= static_cast<size_t>(pos)) {
    return errors::Internal("Invalid input position");
  }
  return inputs[pos].GetTfType(tf_type);
}
|
||||
|
||||
constexpr const char kOutputTypeAttrName[] = "T";
|
||||
|
||||
Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) {
|
||||
return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName);
|
||||
}
|
||||
|
||||
// Verifies that the node's type attribute (default: the output-type attribute
// "T") is one of `allowed_types`. Returns Unimplemented — listing the allowed
// set — when it is not, and propagates errors from reading the attribute.
//
// NOTE(review): this span of the diff view still contained residue lines of
// the deleted pre-commit version (referencing `ss.str()`); the body below is
// the coherent post-commit function using absl::StrJoin.
Status AllowDataTypes(const OpConverterParams& params,
                      const std::set<DataType>& allowed_types,
                      const char* type_attr_name = kOutputTypeAttrName) {
  const auto& node_def = params.node_def;
  DataType tf_type;
  TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name));
  if (!allowed_types.count(tf_type)) {
    // Build a human-readable "a, b, c" list of the allowed types.
    string allowed_types_string = absl::StrJoin(
        allowed_types, ", ", [](string* out, const DataType& type) {
          absl::StrAppendFormat(out, "%s", DataTypeString(type));
        });
    return errors::Unimplemented("Data type ", DataTypeString(tf_type),
                                 " is not supported for ", node_def.op(),
                                 ", must be one of [", allowed_types_string,
                                 "], at ", node_def.name());
  }
  return Status::OK();
}
|
||||
@ -4598,6 +4633,42 @@ Status ConvertUnpack(OpConverterParams* params) {
|
||||
return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true);
|
||||
}
|
||||
|
||||
// Supports cast fp16=>fp32 through IIdentityLayer.
|
||||
Status ConvertCast(OpConverterParams* params) {
|
||||
const NodeDef& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
|
||||
auto unsupport_cast_error = [&]() {
|
||||
return errors::Unimplemented("Cast op: ", node_def.op(),
|
||||
" not supported at: ", node_def.name());
|
||||
};
|
||||
|
||||
DataType input_type;
|
||||
TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0));
|
||||
if (input_type != DataType::DT_HALF) {
|
||||
return unsupport_cast_error();
|
||||
}
|
||||
|
||||
DataType output_type;
|
||||
TF_RETURN_IF_ERROR(GetOutputTfType(*params, &output_type));
|
||||
if (output_type != DataType::DT_FLOAT) {
|
||||
return unsupport_cast_error();
|
||||
}
|
||||
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
nvinfer1::ITensor* input = params->inputs.at(0).tensor();
|
||||
nvinfer1::IIdentityLayer* layer =
|
||||
params->converter->network()->addIdentity(*input);
|
||||
layer->setPrecision(nvinfer1::DataType::kFLOAT);
|
||||
|
||||
if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) {
|
||||
return errors::Internal("IIdentityLayer doesn't work as expected");
|
||||
}
|
||||
|
||||
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertConcat(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
@ -5675,6 +5746,7 @@ static void RegisterValidatableOpConverters(
|
||||
(*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS;
|
||||
#endif
|
||||
(*registration)["AddN"] = ConvertAddN;
|
||||
(*registration)["Cast"] = ConvertCast;
|
||||
(*registration)["ConcatV2"] = ConvertConcat;
|
||||
(*registration)["Const"] = ConvertConst;
|
||||
(*registration)["Conv2D"] = ConvertConv2D;
|
||||
|
@ -294,6 +294,8 @@ class TRT_TensorOrWeights {
|
||||
|
||||
nvinfer1::Dims GetTrtDims() const;
|
||||
|
||||
Status GetTfType(DataType* tf_type) const;
|
||||
|
||||
int batch_size() const { return batch_size_; }
|
||||
|
||||
string DebugString() const;
|
||||
|
@ -5147,6 +5147,14 @@ NodeDef CreateUnaryOp() {
|
||||
return T(s.WithOpName("my_unary"), input).operation.node()->def();
|
||||
}
|
||||
|
||||
// Builds a NodeDef for Cast(DT_HALF -> DT_FLOAT). The node is named
// "my_unary" so the shared unary-op test harness can drive it.
NodeDef CreateCastOp() {
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF);
  auto cast = ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT);
  return cast.operation.node()->def();
}
|
||||
|
||||
TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
||||
const auto& spec = GetParam();
|
||||
const TrtTestMode trt_mode = std::get<0>(spec);
|
||||
@ -5174,6 +5182,7 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
||||
ADD_OP("Asinh", ops::Asinh, std::asinh);
|
||||
ADD_OP("Atan", ops::Atan, std::atan);
|
||||
ADD_OP("Atanh", ops::Atanh, std::atanh);
|
||||
op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; });
|
||||
ADD_OP("Ceil", ops::Ceil, std::ceil);
|
||||
ADD_OP("Cos", ops::Cos, std::cos);
|
||||
ADD_OP("Cosh", ops::Cosh, std::cosh);
|
||||
@ -5212,7 +5221,13 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
||||
}
|
||||
NodeDef node_def = op_map[op_name].first();
|
||||
|
||||
AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode);
|
||||
// TODO(bixia): we assume this test is only instantiated for DT_FLOAT for
|
||||
// now. Need to find a better way to express input and output types.
|
||||
DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype;
|
||||
DataType output_tf_dtype = tf_dtype;
|
||||
|
||||
AddTestTensor("input", p.input_dims, TfDataTypeToTrt(input_tf_dtype),
|
||||
trt_mode);
|
||||
RunValidationAndConversion(node_def, Status::OK(), "my_unary",
|
||||
p.expected_output_dims);
|
||||
|
||||
@ -5220,8 +5235,8 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
||||
std::vector<float> output;
|
||||
std::transform(input_values.begin(), input_values.end(),
|
||||
std::back_inserter(output), op_map[op_name].second);
|
||||
InstantiateBuildAndRun(tf_dtype, "my_unary", this, p, input_values,
|
||||
ArrayFloatNear(output, 0.0001, true));
|
||||
InstantiateBuildAndRun(input_tf_dtype, output_tf_dtype, "my_unary", this, p,
|
||||
input_values, ArrayFloatNear(output, 0.0001, true));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user