Use TensorRT DataType in TRT_ShapedWeights, so that the weights always carry a valid TensorRT type.

Also fix a bug where DT_INT8 (a non-quantized type) was converted to nvinfer1::DataType::kINT8 (a quantized type).

PiperOrigin-RevId: 241269981
Authored by Guangda Lai on 2019-04-01 00:13:46 -07:00; committed by TensorFlower Gardener.
parent 9fd90ad4c7
commit 6ff8dd7b5b
3 changed files with 156 additions and 111 deletions
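
For context, a minimal sketch of the corrected TF-to-TRT type mapping introduced by this change (names are taken from the diff below; the surrounding TensorFlow/TensorRT declarations are assumed and error handling is abbreviated). The key point is that DT_INT8 no longer maps to nvinfer1::DataType::kINT8, because kINT8 is reserved for quantized inference; small integer constants are widened to INT32 instead.

// Sketch only: simplified from the diff below, not the exact upstream code.
inline Status TfDataTypeToTrt(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
  switch (tf_dtype) {
    case DataType::DT_FLOAT: *trt_dtype = nvinfer1::DataType::kFLOAT; break;
    case DataType::DT_HALF:  *trt_dtype = nvinfer1::DataType::kHALF;  break;
    case DataType::DT_INT32: *trt_dtype = nvinfer1::DataType::kINT32; break;
    // Deliberately no DT_INT8 case: kINT8 is a quantized TRT type.
    default:
      return errors::InvalidArgument("Unsupported data type ",
                                     DataTypeString(tf_dtype));
  }
  return Status::OK();
}
// DT_INT8/DT_UINT8/DT_INT16/DT_UINT16 constants are widened to DT_INT32 in
// TfTensorToTrtWeights before this mapping is applied (see the later hunks of
// the first file below).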


@ -98,15 +98,12 @@ namespace convert {
using absl::StrAppend; using absl::StrAppend;
using absl::StrCat; using absl::StrCat;
inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) { inline Status TfDataTypeToTrt(DataType tf_dtype,
nvinfer1::DataType* trt_dtype) {
switch (tf_dtype) { switch (tf_dtype) {
case DataType::DT_FLOAT: case DataType::DT_FLOAT:
*trt_dtype = nvinfer1::DataType::kFLOAT; *trt_dtype = nvinfer1::DataType::kFLOAT;
break; break;
// TODO(aaroey): this should be DT_QINT8 which is not a well supported type.
case DataType::DT_INT8:
*trt_dtype = nvinfer1::DataType::kINT8;
break;
case DataType::DT_HALF: case DataType::DT_HALF:
*trt_dtype = nvinfer1::DataType::kHALF; *trt_dtype = nvinfer1::DataType::kHALF;
break; break;
@ -120,6 +117,25 @@ inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
return Status::OK(); return Status::OK();
} }
inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype,
DataType* tf_dtype) {
switch (trt_dtype) {
case nvinfer1::DataType::kFLOAT:
*tf_dtype = DataType::DT_FLOAT;
break;
case nvinfer1::DataType::kHALF:
*tf_dtype = DataType::DT_HALF;
break;
case nvinfer1::DataType::kINT32:
*tf_dtype = DataType::DT_INT32;
break;
default:
return errors::InvalidArgument("Unsupported data type ",
DebugString(trt_dtype));
}
return Status::OK();
}
class TFAttrs { class TFAttrs {
public: public:
explicit TFAttrs(const NodeDef& tf_node) { explicit TFAttrs(const NodeDef& tf_node) {
@ -178,7 +194,7 @@ std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
template <> template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const { nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype));
return trt_dtype; return trt_dtype;
} }
@ -268,7 +284,7 @@ Status ValidateTensorProperties(const string& producer_node_type,
nvinfer1::DataType* trt_dtype, nvinfer1::DataType* trt_dtype,
nvinfer1::Dims* trt_dims, int* batch_size) { nvinfer1::Dims* trt_dims, int* batch_size) {
// Convert data type. // Convert data type.
TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype));
// Convert shape. // Convert shape.
if (shape.dims() < 0) { if (shape.dims() < 0) {
@ -472,12 +488,12 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
const nvinfer1::Dims& dims, const nvinfer1::Dims& dims,
nvinfer1::ITensor** tensor, nvinfer1::ITensor** tensor,
const char* dtype_attr_name = "T") { const char* dtype_attr_name = "T") {
nvinfer1::DataType trt_dtype =
nvinfer1::DataType::kFLOAT; // Default to FP32.
TFAttrs attrs(params->node_def); TFAttrs attrs(params->node_def);
DataType dtype;
if (attrs.count(dtype_attr_name)) { if (attrs.count(dtype_attr_name)) {
dtype = attrs.get<DataType>(dtype_attr_name); DataType dtype = attrs.get<DataType>(dtype_attr_name);
} else { TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_dtype));
dtype = DT_FLOAT; // Default to FP32.
} }
// In order to be broadcastable, the number of dims has to match. // In order to be broadcastable, the number of dims has to match.
@ -486,18 +502,18 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
broadcastable_dims.d[i] = 1; broadcastable_dims.d[i] = 1;
} }
TRT_ShapedWeights weights = TRT_ShapedWeights weights =
params->weight_store->GetTempWeights(dtype, broadcastable_dims); params->weight_store->GetTempWeights(trt_dtype, broadcastable_dims);
void* raw_ptr = weights.GetValues(); void* raw_ptr = weights.GetValues();
switch (dtype) { switch (trt_dtype) {
case DataType::DT_FLOAT: case nvinfer1::DataType::kFLOAT:
static_cast<float*>(raw_ptr)[0] = value; static_cast<float*>(raw_ptr)[0] = value;
break; break;
case DataType::DT_HALF: case nvinfer1::DataType::kHALF:
static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value); static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);
break; break;
default: default:
return errors::InvalidArgument("Unsupported data type ", return errors::InvalidArgument("Unsupported data type ",
DataTypeString(dtype)); DebugString(trt_dtype));
} }
*tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims); *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name()); TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
@ -676,12 +692,12 @@ Status VerifyShapesMatch(absl::Span<const TRT_TensorOrWeights> inputs,
return Status::OK(); return Status::OK();
} }
TRT_ShapedWeights::TRT_ShapedWeights(DataType type) : type_(type) { TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type) : type_(type) {
shape_.nbDims = 0; shape_.nbDims = 0;
} }
TRT_ShapedWeights::TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type,
Tensor tensor) nvinfer1::Dims dims, Tensor tensor)
: shape_(dims), type_(type), tensor_(tensor) {} : shape_(dims), type_(type), tensor_(tensor) {}
TRT_ShapedWeights::TRT_ShapedWeights(const TRT_ShapedWeights& rhs) TRT_ShapedWeights::TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
@ -692,18 +708,29 @@ int64_t TRT_ShapedWeights::count() const {
} }
nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const { nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const {
nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT); return nvinfer1::Weights{type_, GetValues(), count()};
TF_CHECK_OK(ConvertDType(type_, &trt_type));
return nvinfer1::Weights{trt_type, GetValues(), count()};
} }
size_t TRT_ShapedWeights::size_bytes() const { size_t TRT_ShapedWeights::size_bytes() const {
return this->count() * DataTypeSize(this->type_); size_t data_type_size = -1;
switch (type_) {
case nvinfer1::DataType::kFLOAT:
case nvinfer1::DataType::kINT32:
data_type_size = 4;
break;
case nvinfer1::DataType::kHALF:
data_type_size = 2;
break;
case nvinfer1::DataType::kINT8:
data_type_size = 1;
break;
}
return this->count() * data_type_size;
} }
string TRT_ShapedWeights::DebugString() const { string TRT_ShapedWeights::DebugString() const {
return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_), return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
", type=", DataTypeString(type_), ", type=", convert::DebugString(type_),
", values=", reinterpret_cast<uintptr_t>(GetValues()), ")"); ", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
} }
@ -867,13 +894,13 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
oweights->shape_.d[1] = c; oweights->shape_.d[1] = c;
const nvinfer1::DimsHW istrides = {1, k}; const nvinfer1::DimsHW istrides = {1, k};
const nvinfer1::DimsHW ostrides = {c, 1}; const nvinfer1::DimsHW ostrides = {c, 1};
switch (iweights.type_) { switch (iweights.TrtDType()) {
case DataType::DT_FLOAT: { case nvinfer1::DataType::kFLOAT: {
Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()), Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
istrides, static_cast<float*>(oweights->GetValues()), ostrides); istrides, static_cast<float*>(oweights->GetValues()), ostrides);
break; break;
} }
case DataType::DT_HALF: { case nvinfer1::DataType::kHALF: {
Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()), Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
istrides, static_cast<Eigen::half*>(oweights->GetValues()), istrides, static_cast<Eigen::half*>(oweights->GetValues()),
ostrides); ostrides);
@ -881,13 +908,13 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
} }
default: default:
LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got " LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got "
<< DataTypeString(iweights.type_); << DebugString(iweights.TrtDType());
} }
} }
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
TRT_ShapedWeights* oweights, const int num_groups) { TRT_ShapedWeights* oweights, const int num_groups) {
CHECK_EQ(iweights.type_, oweights->type_); CHECK(iweights.TrtDType() == oweights->TrtDType());
CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
// K indexes over output channels, C over input channels, and R and S over the // K indexes over output channels, C over input channels, and R and S over the
// height and width of the convolution // height and width of the convolution
@ -906,13 +933,13 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
oweights->shape_.d[3] = s; oweights->shape_.d[3] = s;
const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
switch (iweights.type_) { switch (iweights.TrtDType()) {
case DataType::DT_FLOAT: { case nvinfer1::DataType::kFLOAT: {
Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()), Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
istrides, static_cast<float*>(oweights->GetValues()), ostrides); istrides, static_cast<float*>(oweights->GetValues()), ostrides);
break; break;
} }
case DataType::DT_HALF: { case nvinfer1::DataType::kHALF: {
Reorder4({k, c, r, s}, Reorder4({k, c, r, s},
static_cast<Eigen::half const*>(iweights.GetValues()), istrides, static_cast<Eigen::half const*>(iweights.GetValues()), istrides,
static_cast<Eigen::half*>(oweights->GetValues()), ostrides); static_cast<Eigen::half*>(oweights->GetValues()), ostrides);
@ -921,18 +948,20 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
default: default:
LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got " LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got "
<< DataTypeString(iweights.type_); << DebugString(iweights.TrtDType());
} }
} }
TRT_ShapedWeights TrtWeightStore::GetTempWeights(DataType type, TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
const nvinfer1::Dims& dims) { const nvinfer1::Dims& dims) {
TensorShape shape; TensorShape shape;
DataType tf_dtype;
// TODO(laigd): make it return a status. // TODO(laigd): make it return a status.
TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape)); TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape));
TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype));
// TODO(jie): check weights size_bytes. 0 means type error // TODO(jie): check weights size_bytes. 0 means type error
Tensor tensor(type, shape); Tensor tensor(tf_dtype, shape);
TRT_ShapedWeights weights(type, dims, tensor); TRT_ShapedWeights weights(trt_dtype, dims, tensor);
store_.emplace_back(std::move(tensor)); store_.emplace_back(std::move(tensor));
return weights; return weights;
} }
@ -1282,22 +1311,22 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor,
Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
float* out_min, float* out_max) const { float* out_min, float* out_max) const {
switch (weights.type_) { switch (weights.TrtDType()) {
case DataType::DT_FLOAT: { case nvinfer1::DataType::kFLOAT: {
auto inp = static_cast<float const*>(weights.GetValues()); auto inp = static_cast<float const*>(weights.GetValues());
auto result = std::minmax_element(inp, inp + weights.count()); auto result = std::minmax_element(inp, inp + weights.count());
*out_min = *result.first; *out_min = *result.first;
*out_max = *result.second; *out_max = *result.second;
break; break;
} }
case DataType::DT_HALF: { case nvinfer1::DataType::kHALF: {
auto inp = static_cast<Eigen::half const*>(weights.GetValues()); auto inp = static_cast<Eigen::half const*>(weights.GetValues());
auto result = std::minmax_element(inp, inp + weights.count()); auto result = std::minmax_element(inp, inp + weights.count());
*out_min = Eigen::half_impl::half_to_float(*result.first); *out_min = Eigen::half_impl::half_to_float(*result.first);
*out_max = Eigen::half_impl::half_to_float(*result.second); *out_max = Eigen::half_impl::half_to_float(*result.second);
break; break;
} }
case DataType::DT_INT32: { case nvinfer1::DataType::kINT32: {
auto inp = static_cast<int const*>(weights.GetValues()); auto inp = static_cast<int const*>(weights.GetValues());
auto result = std::minmax_element(inp, inp + weights.count()); auto result = std::minmax_element(inp, inp + weights.count());
*out_min = static_cast<float>(*result.first); *out_min = static_cast<float>(*result.first);
@ -1307,7 +1336,7 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
default: default:
return errors::Unimplemented( return errors::Unimplemented(
"Data type not supported for GetWeightRange: ", "Data type not supported for GetWeightRange: ",
DataTypeString(weights.type_)); DebugString(weights.TrtDType()));
} }
return Status::OK(); return Status::OK();
} }
@ -1562,9 +1591,8 @@ Status AllowDataTypes(const OpConverterParams& params,
TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
const TRT_ShapedWeights& weights_src) { const TRT_ShapedWeights& weights_src) {
auto dtype_new = DataType::DT_HALF;
TRT_ShapedWeights weights = TRT_ShapedWeights weights =
store->GetTempWeights(dtype_new, weights_src.shape_); store->GetTempWeights(nvinfer1::DataType::kHALF, weights_src.shape_);
const float* src = static_cast<const float*>(weights_src.GetValues()); const float* src = static_cast<const float*>(weights_src.GetValues());
Eigen::half* dst = static_cast<Eigen::half*>(weights.GetValues()); Eigen::half* dst = static_cast<Eigen::half*>(weights.GetValues());
for (int64_t i = 0; i < weights_src.count(); i++) { for (int64_t i = 0; i < weights_src.count(); i++) {
@ -1622,15 +1650,15 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
Status UnaryCompute(const TRT_ShapedWeights& iweights, Status UnaryCompute(const TRT_ShapedWeights& iweights,
TRT_ShapedWeights* oweights, LambdaFactory unary_op) { TRT_ShapedWeights* oweights, LambdaFactory unary_op) {
CHECK_EQ(iweights.type_, oweights->type_); CHECK(iweights.TrtDType() == oweights->TrtDType());
switch (iweights.type_) { switch (iweights.TrtDType()) {
case DataType::DT_FLOAT: { case nvinfer1::DataType::kFLOAT: {
auto inp = static_cast<float const*>(iweights.GetValues()); auto inp = static_cast<float const*>(iweights.GetValues());
auto oup = static_cast<float*>(oweights->GetValues()); auto oup = static_cast<float*>(oweights->GetValues());
std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>()); std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
break; break;
} }
case DataType::DT_HALF: { case nvinfer1::DataType::kHALF: {
auto inp = static_cast<Eigen::half const*>(iweights.GetValues()); auto inp = static_cast<Eigen::half const*>(iweights.GetValues());
auto oup = static_cast<Eigen::half*>(oweights->GetValues()); auto oup = static_cast<Eigen::half*>(oweights->GetValues());
std::transform(inp, inp + iweights.count(), oup, std::transform(inp, inp + iweights.count(), oup,
@ -1638,8 +1666,8 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights,
break; break;
} }
default: default:
return errors::Unimplemented("Data type not supported: " + return errors::Unimplemented("Data type not supported: ",
DataTypeString(iweights.type_)); DebugString(iweights.TrtDType()));
} }
return Status::OK(); return Status::OK();
} }
@ -1660,10 +1688,6 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
node_def.name()); node_def.name());
} }
// Check type consistency.
nvinfer1::DataType trt_dtype;
TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype));
// Check scale mode. // Check scale mode.
auto dims_w = weights.shape_; auto dims_w = weights.shape_;
const auto dims_t = tensor->getDimensions(); const auto dims_t = tensor->getDimensions();
@ -1753,9 +1777,9 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
} }
// Prepare weights // Prepare weights
TRT_ShapedWeights shift_weights(weights.type_); TRT_ShapedWeights shift_weights(weights.TrtDType());
TRT_ShapedWeights scale_weights(weights.type_); TRT_ShapedWeights scale_weights(weights.TrtDType());
TRT_ShapedWeights power_weights(weights.type_); TRT_ShapedWeights power_weights(weights.TrtDType());
if (node_def.op() == "Sub") { if (node_def.op() == "Sub") {
if (swapped_inputs) { if (swapped_inputs) {
@ -1922,7 +1946,7 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
TRT_ShapedWeights weights = TRT_ShapedWeights weights =
params->weight_store->GetTempWeights(weights_rsck); params->weight_store->GetTempWeights(weights_rsck);
ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
TRT_ShapedWeights biases(weights.type_); TRT_ShapedWeights biases(weights.TrtDType());
const int output_axis = is_conv2d_backprop_input ? 1 : 0; const int output_axis = is_conv2d_backprop_input ? 1 : 0;
const int noutput = weights.shape_.d[output_axis] * num_groups; const int noutput = weights.shape_.d[output_axis] * num_groups;
nvinfer1::DimsHW kernel_size; nvinfer1::DimsHW kernel_size;
@ -3022,7 +3046,7 @@ Status ConvertBiasAdd(OpConverterParams* params) {
mode = nvinfer1::ScaleMode::kUNIFORM; mode = nvinfer1::ScaleMode::kUNIFORM;
} }
TRT_ShapedWeights empty_weights(weights.type_); TRT_ShapedWeights empty_weights(weights.TrtDType());
nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
*tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(), *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(),
empty_weights.GetTrtWeights()); empty_weights.GetTrtWeights());
@ -3072,33 +3096,41 @@ void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) {
} }
} }
template <DataType dtype>
void CopyToTrtInt32Array(const Tensor& tensor, int32* dst) {
typedef typename EnumToDataType<dtype>::Type CType;
const CType* src = tensor.flat<CType>().data();
std::copy(src, src + tensor.NumElements(), dst);
}
Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
TRT_ShapedWeights* weights) { TRT_ShapedWeights* weights) {
const DataType dtype = tensor.dtype(); const DataType dtype = tensor.dtype();
// We always convert the integer constants to INT32, since TRT INT8 is for // We always convert the integer constants to INT32.
// quantized inference.
// //
// TODO(aaroey): FP16 will remain in half format and is not converted to // TODO(aaroey): FP16 will remain in half format and is not converted to
// FP32, but the converter currently uses all float weights as FP32. Fix // FP32, but the converter currently uses all float weights as FP32. Fix
// this. // this.
const DataType converted_dtype = DataType converted_dtype = dtype;
(dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 if (dtype == DataType::DT_INT8 || dtype == DataType::DT_UINT8 ||
: dtype); dtype == DataType::DT_INT16 || dtype == DataType::DT_UINT16) {
converted_dtype = DT_INT32;
}
// Verify that the dtype is supported by TensorRT. Otherwise, return an error. // Verify that the dtype is supported by TensorRT. Otherwise, return an error.
nvinfer1::DataType trt_dtype; nvinfer1::DataType trt_dtype;
TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype));
if (tensor.NumElements() == 0) { if (tensor.NumElements() == 0) {
// Return empty weights having converted dtype. // Return empty weights.
*weights = TRT_ShapedWeights(converted_dtype); *weights = TRT_ShapedWeights(trt_dtype);
return Status::OK(); return Status::OK();
} }
nvinfer1::Dims weight_dims; nvinfer1::Dims weight_dims;
GetTensorDimsWithProtoShape(tensor, &weight_dims); GetTensorDimsWithProtoShape(tensor, &weight_dims);
*weights = weight_store->GetTempWeights(converted_dtype, weight_dims); *weights = weight_store->GetTempWeights(trt_dtype, weight_dims);
// Copy the tensor directly if the tensor does not require cast to the // Copy the tensor directly if the tensor does not require cast to the
// supported type. // supported type.
@ -3110,17 +3142,21 @@ Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
// Copy tensor elements after casting them to the converted DataType. // Copy tensor elements after casting them to the converted DataType.
int32* dst = static_cast<int32*>(weights->GetValues()); int32* dst = static_cast<int32*>(weights->GetValues());
if (dtype == DT_INT16) { switch (dtype) {
const int16* src = tensor.flat<int16>().data(); case DT_INT8:
std::copy(src, src + tensor.NumElements(), dst); CopyToTrtInt32Array<DT_INT8>(tensor, dst);
} else if (dtype == DT_INT8) { break;
const int8* src = tensor.flat<int8>().data(); case DT_UINT8:
std::copy(src, src + tensor.NumElements(), dst); CopyToTrtInt32Array<DT_UINT8>(tensor, dst);
} else { break;
// dtype can only be DT_UINT8 at this point. case DT_INT16:
TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8); CopyToTrtInt32Array<DT_INT16>(tensor, dst);
const uint8* src = tensor.flat<uint8>().data(); break;
std::copy(src, src + tensor.NumElements(), dst); case DT_UINT16:
CopyToTrtInt32Array<DT_UINT16>(tensor, dst);
break;
default:
return errors::Internal("Unexpected DataType: ", DataTypeString(dtype));
} }
return Status::OK(); return Status::OK();
} }
@ -3782,15 +3818,15 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
nvinfer1::ITensor* tensor = inputs.at(0).tensor(); nvinfer1::ITensor* tensor = inputs.at(0).tensor();
// Check parameter types // Check parameter types
auto parameter_type = inputs.at(1).weights().type_; auto parameter_type = inputs.at(1).weights().TrtDType();
if ((parameter_type != DataType::DT_FLOAT) && if ((parameter_type != nvinfer1::DataType::kFLOAT) &&
(parameter_type != DataType::DT_HALF)) { (parameter_type != nvinfer1::DataType::kHALF)) {
return errors::Unimplemented( return errors::Unimplemented(
"only float32 or float16 weight data type is supported, for node " + "Only float32 or float16 weight data type is supported, for node ",
node_def.name() + " got " + DataTypeString(parameter_type)); node_def.name(), " got ", DebugString(parameter_type));
} }
for (int i = 1; i < 5; i++) { for (int i = 1; i < 5; i++) {
if (inputs.at(i).weights().type_ != parameter_type) { if (inputs.at(i).weights().TrtDType() != parameter_type) {
return errors::Unimplemented( return errors::Unimplemented(
"Inconsistent parameter type for batchnorm is not supported, at: " + "Inconsistent parameter type for batchnorm is not supported, at: " +
node_def.name()); node_def.name());
@ -3841,16 +3877,16 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
float batchnorm_data[4]; float batchnorm_data[4];
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
if (inputs.at(j + 1).weights().count() != 1) { if (inputs.at(j + 1).weights().count() != 1) {
if (parameter_type == DT_FLOAT) { if (parameter_type == nvinfer1::DataType::kFLOAT) {
batchnorm_data[j] = vals_array[j][i]; batchnorm_data[j] = vals_array[j][i];
} else if (parameter_type == DT_HALF) { } else if (parameter_type == nvinfer1::DataType::kHALF) {
batchnorm_data[j] = batchnorm_data[j] =
Eigen::half_impl::half_to_float(cast_vals_array[j][i]); Eigen::half_impl::half_to_float(cast_vals_array[j][i]);
} }
} else { } else {
if (parameter_type == DT_FLOAT) { if (parameter_type == nvinfer1::DataType::kFLOAT) {
batchnorm_data[j] = vals_array[j][0]; batchnorm_data[j] = vals_array[j][0];
} else if (parameter_type == DT_HALF) { } else if (parameter_type == nvinfer1::DataType::kHALF) {
batchnorm_data[j] = batchnorm_data[j] =
Eigen::half_impl::half_to_float(cast_vals_array[j][0]); Eigen::half_impl::half_to_float(cast_vals_array[j][0]);
} }
@ -3862,10 +3898,10 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
float variance = batchnorm_data[3]; float variance = batchnorm_data[3];
float combined_scale_val = scale / sqrtf(variance + epsilon); float combined_scale_val = scale / sqrtf(variance + epsilon);
float combined_offset_val = offset - mean * combined_scale_val; float combined_offset_val = offset - mean * combined_scale_val;
if (parameter_type == DT_FLOAT) { if (parameter_type == nvinfer1::DataType::kFLOAT) {
combined_scale_vals[i] = combined_scale_val; combined_scale_vals[i] = combined_scale_val;
combined_offset_vals[i] = combined_offset_val; combined_offset_vals[i] = combined_offset_val;
} else if (parameter_type == DT_HALF) { } else if (parameter_type == nvinfer1::DataType::kHALF) {
cast_combined_scale_vals[i] = Eigen::half(combined_scale_val); cast_combined_scale_vals[i] = Eigen::half(combined_scale_val);
cast_combined_offset_vals[i] = Eigen::half(combined_offset_val); cast_combined_offset_vals[i] = Eigen::half(combined_offset_val);
} }
@ -3962,14 +3998,14 @@ Status ConvertMatMulHelper(OpConverterParams* params,
} }
nvinfer1::ITensor* tensor = tensor_input.tensor(); nvinfer1::ITensor* tensor = tensor_input.tensor();
TRT_ShapedWeights weights(weights_raw.type_); TRT_ShapedWeights weights(weights_raw.TrtDType());
if (transpose_weight) { if (transpose_weight) {
weights = weights_raw; weights = weights_raw;
} else { } else {
weights = params->weight_store->GetTempWeights(weights_raw); weights = params->weight_store->GetTempWeights(weights_raw);
ReorderCKtoKC(weights_raw, &weights); ReorderCKtoKC(weights_raw, &weights);
} }
TRT_ShapedWeights biases(weights.type_); TRT_ShapedWeights biases(weights.TrtDType());
int noutput = weights.shape_.d[0]; int noutput = weights.shape_.d[0];
@ -4472,7 +4508,7 @@ Status ConvertGraphDefToEngine(
TFAttrs attrs(node_def); TFAttrs attrs(node_def);
DataType tf_dtype = attrs.get<DataType>("T"); DataType tf_dtype = attrs.get<DataType>("T");
nvinfer1::DataType trt_dtype; nvinfer1::DataType trt_dtype;
TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype));
if (output_tensors.size() <= slot_number) { if (output_tensors.size() <= slot_number) {
output_tensors.resize(slot_number + 1); output_tensors.resize(slot_number + 1);
} }


@ -176,7 +176,8 @@ int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
// Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight. // Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight.
class TRT_ShapedWeights { class TRT_ShapedWeights {
public: public:
explicit TRT_ShapedWeights(DataType type = DT_FLOAT); explicit TRT_ShapedWeights(
nvinfer1::DataType type = nvinfer1::DataType::kFLOAT);
// Copy from another weights. // Copy from another weights.
// //
@ -211,14 +212,18 @@ class TRT_ShapedWeights {
return std::vector<T>(span.data(), span.data() + span.size()); return std::vector<T>(span.data(), span.data() + span.size());
} }
nvinfer1::DataType TrtDType() const { return type_; }
// TODO(aaroey): make these private. // TODO(aaroey): make these private.
nvinfer1::Dims shape_; // Note: shape.type[] is not used. nvinfer1::Dims shape_; // Note: shape.type[] is not used.
DataType type_;
private: private:
// This constructor is only used by TrtWeightStore, which creates the // This constructor is only used by TrtWeightStore, which creates the
// underlying buffer. // underlying buffer.
TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor); TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims,
Tensor tensor);
nvinfer1::DataType type_;
// All weights should be stored inside TrtWeightStore to make sure lifetime of // All weights should be stored inside TrtWeightStore to make sure lifetime of
// all the underlying tensors are available until the engine is built. For // all the underlying tensors are available until the engine is built. For
@ -239,12 +244,13 @@ class TRT_ShapedWeights {
class TrtWeightStore { class TrtWeightStore {
public: public:
// Get a TRT_ShapedWeights with 'type' and 'dims'. // Get a TRT_ShapedWeights with 'type' and 'dims'.
TRT_ShapedWeights GetTempWeights(DataType type, const nvinfer1::Dims& dims); TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type,
const nvinfer1::Dims& dims);
// Get a TRT_ShapedWeights with the same data type and dimensions as // Get a TRT_ShapedWeights with the same data type and dimensions as
// 'weights'. // 'weights'.
TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) { TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) {
return GetTempWeights(weights.type_, weights.shape_); return GetTempWeights(weights.TrtDType(), weights.shape_);
} }
private: private:


@ -186,8 +186,8 @@ void ExpectArrayNear(const std::vector<Eigen::half>& lhs,
bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs,
const TRT_ShapedWeights& rhs) { const TRT_ShapedWeights& rhs) {
return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ && return TrtDimsEquals(lhs.shape_, rhs.shape_) &&
lhs.GetValues() == rhs.GetValues(); lhs.TrtDType() == rhs.TrtDType() && lhs.GetValues() == rhs.GetValues();
} }
template <typename T> template <typename T>
@ -293,7 +293,7 @@ TEST(TRT_ShapedWeights_Test, Basic) {
} }
// Test constructor with DataType argument. // Test constructor with DataType argument.
{ {
TRT_ShapedWeights weights(DT_FLOAT); TRT_ShapedWeights weights(nvinfer1::DataType::kFLOAT);
TRT_ShapedWeights copy(weights); TRT_ShapedWeights copy(weights);
for (auto ptr : {&weights, &copy}) { for (auto ptr : {&weights, &copy}) {
nvinfer1::Weights trt_weights = ptr->GetTrtWeights(); nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
@ -310,7 +310,7 @@ TEST(TRT_ShapedWeights_Test, Basic) {
{ {
TrtWeightStore store; TrtWeightStore store;
TRT_ShapedWeights weights = TRT_ShapedWeights weights =
store.GetTempWeights(DT_FLOAT, GetTestDims({2, 5})); store.GetTempWeights(nvinfer1::DataType::kFLOAT, GetTestDims({2, 5}));
TRT_ShapedWeights copy(weights); TRT_ShapedWeights copy(weights);
for (auto ptr : {&weights, &copy}) { for (auto ptr : {&weights, &copy}) {
nvinfer1::Weights trt_weights = ptr->GetTrtWeights(); nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
@ -671,7 +671,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) {
params->outputs->emplace_back(output_tensor); params->outputs->emplace_back(output_tensor);
output_tensors.push_back(output_tensor); output_tensors.push_back(output_tensor);
} }
TRT_ShapedWeights output_weights(DT_FLOAT); TRT_ShapedWeights output_weights(nvinfer1::DataType::kFLOAT);
params->outputs->emplace_back(output_weights); params->outputs->emplace_back(output_weights);
return Status::OK(); return Status::OK();
}; };
@ -778,8 +778,8 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
} }
TEST_F(ConverterTest, PrepareTensorForShape_Weights) { TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
TRT_ShapedWeights weights = TRT_ShapedWeights weights = weight_store_->GetTempWeights(
weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5})); nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
nvinfer1::ITensor* output_tensor = nullptr; nvinfer1::ITensor* output_tensor = nullptr;
for (bool validation_only : {false, true}) { for (bool validation_only : {false, true}) {
TF_EXPECT_OK(converter_->PrepareTensorForShape( TF_EXPECT_OK(converter_->PrepareTensorForShape(
@ -832,8 +832,8 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
template <typename T> template <typename T>
void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) { void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
TRT_ShapedWeights weights = TRT_ShapedWeights weights = weight_store->GetTempWeights(
weight_store->GetTempWeights(DataTypeToEnum<T>::v(), GetTestDims({2, 3})); TfDataTypeToTrt(DataTypeToEnum<T>::v()), GetTestDims({2, 3}));
const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)}; const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
memcpy(weights.GetValues(), values.data(), weights.size_bytes()); memcpy(weights.GetValues(), values.data(), weights.size_bytes());
@ -1002,14 +1002,14 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
} }
TEST_F(ConverterTest, CreateConstantLayer) { TEST_F(ConverterTest, CreateConstantLayer) {
for (auto dtype : {DT_FLOAT, DT_INT32}) { for (auto dtype : {nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT32}) {
TRT_ShapedWeights weights = TRT_ShapedWeights weights =
weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5})); weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5}));
nvinfer1::ITensor* tensor = nvinfer1::ITensor* tensor =
converter_->CreateConstantLayer(weights, GetTestDims({3, 10})); converter_->CreateConstantLayer(weights, GetTestDims({3, 10}));
ASSERT_NE(nullptr, tensor); ASSERT_NE(nullptr, tensor);
EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType()) EXPECT_EQ(dtype, tensor->getType())
<< "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. actual " << "Expected " << DebugString(dtype) << " vs. actual "
<< DebugString(tensor->getType()); << DebugString(tensor->getType());
ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions()); ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions());
} }
@ -1246,7 +1246,7 @@ class OpConverterTest : public ::testing::Test {
template <typename T> template <typename T>
void AddTestWeights(const string& name, const std::vector<int>& dims, void AddTestWeights(const string& name, const std::vector<int>& dims,
const std::vector<T>& values) { const std::vector<T>& values) {
const DataType dtype = DataTypeToEnum<T>::v(); const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum<T>::v());
const nvinfer1::Dims trt_dims = GetTestDims(dims); const nvinfer1::Dims trt_dims = GetTestDims(dims);
const int64_t num_elements = TrtWeightDimsNumElements(trt_dims); const int64_t num_elements = TrtWeightDimsNumElements(trt_dims);
QCHECK_EQ(num_elements, values.size()) QCHECK_EQ(num_elements, values.size())
@ -1452,6 +1452,9 @@ TEST_F(OpConverterTest, ConvertConst) {
TestConvertConst<DT_FLOAT, float, float>(this); TestConvertConst<DT_FLOAT, float, float>(this);
TestConvertConst<DT_INT8, int8, int32>(this); TestConvertConst<DT_INT8, int8, int32>(this);
TestConvertConst<DT_UINT8, uint8, int32>(this);
TestConvertConst<DT_INT16, int16, int32>(this);
TestConvertConst<DT_UINT16, uint16, int32>(this);
TestConvertConst<DT_INT32, int32, int32>(this); TestConvertConst<DT_INT32, int32, int32>(this);
} }
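
A usage sketch of the reworked weight store, mirroring the test changes above (names come from the diff; this is illustrative, not upstream test code): weights are now requested by TRT dtype, and the store converts back to a TF dtype only to allocate the backing Tensor.

// Sketch: allocate FP16 weights directly by TRT dtype.
TrtWeightStore store;
TRT_ShapedWeights w =
    store.GetTempWeights(nvinfer1::DataType::kHALF, GetTestDims({2, 3}));
// Internally, TrtDataTypeToTf(kHALF) yields DT_HALF for the Tensor buffer,
// while w.TrtDType() stays kHALF, so GetTrtWeights() can build
// nvinfer1::Weights{w.TrtDType(), w.GetValues(), w.count()} with no further
// conversion.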