Use TensorRT DataType in TRT_ShapedWeights, so that we know the weights are always valid.
Also fix a bug where DT_INT8 (a non-quantized type) was converted to nvinfer1::DataType::kINT8 (a quantized type). PiperOrigin-RevId: 241269981
This commit is contained in:
parent
9fd90ad4c7
commit
6ff8dd7b5b
@ -98,15 +98,12 @@ namespace convert {
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
||||
inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
|
||||
inline Status TfDataTypeToTrt(DataType tf_dtype,
|
||||
nvinfer1::DataType* trt_dtype) {
|
||||
switch (tf_dtype) {
|
||||
case DataType::DT_FLOAT:
|
||||
*trt_dtype = nvinfer1::DataType::kFLOAT;
|
||||
break;
|
||||
// TODO(aaroey): this should be DT_QINT8 which is not a well supported type.
|
||||
case DataType::DT_INT8:
|
||||
*trt_dtype = nvinfer1::DataType::kINT8;
|
||||
break;
|
||||
case DataType::DT_HALF:
|
||||
*trt_dtype = nvinfer1::DataType::kHALF;
|
||||
break;
|
||||
@ -120,6 +117,25 @@ inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Maps a TensorRT data type onto the corresponding TensorFlow DataType.
//
// On success the result is written to *tf_dtype and Status::OK() is returned.
// Only kFLOAT, kHALF and kINT32 are mapped; for any other value *tf_dtype is
// left untouched and an InvalidArgument error naming the type is returned.
inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype,
                              DataType* tf_dtype) {
  if (trt_dtype == nvinfer1::DataType::kFLOAT) {
    *tf_dtype = DataType::DT_FLOAT;
    return Status::OK();
  }
  if (trt_dtype == nvinfer1::DataType::kHALF) {
    *tf_dtype = DataType::DT_HALF;
    return Status::OK();
  }
  if (trt_dtype == nvinfer1::DataType::kINT32) {
    *tf_dtype = DataType::DT_INT32;
    return Status::OK();
  }
  // Everything else has no TensorFlow counterpart handled here.
  return errors::InvalidArgument("Unsupported data type ",
                                 DebugString(trt_dtype));
}
|
||||
|
||||
class TFAttrs {
|
||||
public:
|
||||
explicit TFAttrs(const NodeDef& tf_node) {
|
||||
@ -178,7 +194,7 @@ std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
|
||||
template <>
|
||||
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
|
||||
nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
|
||||
TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
|
||||
TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype));
|
||||
return trt_dtype;
|
||||
}
|
||||
|
||||
@ -268,7 +284,7 @@ Status ValidateTensorProperties(const string& producer_node_type,
|
||||
nvinfer1::DataType* trt_dtype,
|
||||
nvinfer1::Dims* trt_dims, int* batch_size) {
|
||||
// Convert data type.
|
||||
TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype));
|
||||
TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype));
|
||||
|
||||
// Convert shape.
|
||||
if (shape.dims() < 0) {
|
||||
@ -472,12 +488,12 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
|
||||
const nvinfer1::Dims& dims,
|
||||
nvinfer1::ITensor** tensor,
|
||||
const char* dtype_attr_name = "T") {
|
||||
nvinfer1::DataType trt_dtype =
|
||||
nvinfer1::DataType::kFLOAT; // Default to FP32.
|
||||
TFAttrs attrs(params->node_def);
|
||||
DataType dtype;
|
||||
if (attrs.count(dtype_attr_name)) {
|
||||
dtype = attrs.get<DataType>(dtype_attr_name);
|
||||
} else {
|
||||
dtype = DT_FLOAT; // Default to FP32.
|
||||
DataType dtype = attrs.get<DataType>(dtype_attr_name);
|
||||
TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_dtype));
|
||||
}
|
||||
|
||||
// In order to be broadcastable, the number of dims has to match.
|
||||
@ -486,18 +502,18 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
|
||||
broadcastable_dims.d[i] = 1;
|
||||
}
|
||||
TRT_ShapedWeights weights =
|
||||
params->weight_store->GetTempWeights(dtype, broadcastable_dims);
|
||||
params->weight_store->GetTempWeights(trt_dtype, broadcastable_dims);
|
||||
void* raw_ptr = weights.GetValues();
|
||||
switch (dtype) {
|
||||
case DataType::DT_FLOAT:
|
||||
switch (trt_dtype) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
static_cast<float*>(raw_ptr)[0] = value;
|
||||
break;
|
||||
case DataType::DT_HALF:
|
||||
case nvinfer1::DataType::kHALF:
|
||||
static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);
|
||||
break;
|
||||
default:
|
||||
return errors::InvalidArgument("Unsupported data type ",
|
||||
DataTypeString(dtype));
|
||||
DebugString(trt_dtype));
|
||||
}
|
||||
*tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
|
||||
@ -676,12 +692,12 @@ Status VerifyShapesMatch(absl::Span<const TRT_TensorOrWeights> inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TRT_ShapedWeights::TRT_ShapedWeights(DataType type) : type_(type) {
|
||||
TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type) : type_(type) {
|
||||
shape_.nbDims = 0;
|
||||
}
|
||||
|
||||
TRT_ShapedWeights::TRT_ShapedWeights(DataType type, nvinfer1::Dims dims,
|
||||
Tensor tensor)
|
||||
TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type,
|
||||
nvinfer1::Dims dims, Tensor tensor)
|
||||
: shape_(dims), type_(type), tensor_(tensor) {}
|
||||
|
||||
TRT_ShapedWeights::TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
|
||||
@ -692,18 +708,29 @@ int64_t TRT_ShapedWeights::count() const {
|
||||
}
|
||||
|
||||
nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const {
|
||||
nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
|
||||
TF_CHECK_OK(ConvertDType(type_, &trt_type));
|
||||
return nvinfer1::Weights{trt_type, GetValues(), count()};
|
||||
return nvinfer1::Weights{type_, GetValues(), count()};
|
||||
}
|
||||
|
||||
size_t TRT_ShapedWeights::size_bytes() const {
|
||||
return this->count() * DataTypeSize(this->type_);
|
||||
size_t data_type_size = -1;
|
||||
switch (type_) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
case nvinfer1::DataType::kINT32:
|
||||
data_type_size = 4;
|
||||
break;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
data_type_size = 2;
|
||||
break;
|
||||
case nvinfer1::DataType::kINT8:
|
||||
data_type_size = 1;
|
||||
break;
|
||||
}
|
||||
return this->count() * data_type_size;
|
||||
}
|
||||
|
||||
string TRT_ShapedWeights::DebugString() const {
|
||||
return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
|
||||
", type=", DataTypeString(type_),
|
||||
", type=", convert::DebugString(type_),
|
||||
", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
|
||||
}
|
||||
|
||||
@ -867,13 +894,13 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
|
||||
oweights->shape_.d[1] = c;
|
||||
const nvinfer1::DimsHW istrides = {1, k};
|
||||
const nvinfer1::DimsHW ostrides = {c, 1};
|
||||
switch (iweights.type_) {
|
||||
case DataType::DT_FLOAT: {
|
||||
switch (iweights.TrtDType()) {
|
||||
case nvinfer1::DataType::kFLOAT: {
|
||||
Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
|
||||
istrides, static_cast<float*>(oweights->GetValues()), ostrides);
|
||||
break;
|
||||
}
|
||||
case DataType::DT_HALF: {
|
||||
case nvinfer1::DataType::kHALF: {
|
||||
Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
|
||||
istrides, static_cast<Eigen::half*>(oweights->GetValues()),
|
||||
ostrides);
|
||||
@ -881,13 +908,13 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got "
|
||||
<< DataTypeString(iweights.type_);
|
||||
<< DebugString(iweights.TrtDType());
|
||||
}
|
||||
}
|
||||
|
||||
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
|
||||
TRT_ShapedWeights* oweights, const int num_groups) {
|
||||
CHECK_EQ(iweights.type_, oweights->type_);
|
||||
CHECK(iweights.TrtDType() == oweights->TrtDType());
|
||||
CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
|
||||
// K indexes over output channels, C over input channels, and R and S over the
|
||||
// height and width of the convolution
|
||||
@ -906,13 +933,13 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
|
||||
oweights->shape_.d[3] = s;
|
||||
const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
|
||||
const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
|
||||
switch (iweights.type_) {
|
||||
case DataType::DT_FLOAT: {
|
||||
switch (iweights.TrtDType()) {
|
||||
case nvinfer1::DataType::kFLOAT: {
|
||||
Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
|
||||
istrides, static_cast<float*>(oweights->GetValues()), ostrides);
|
||||
break;
|
||||
}
|
||||
case DataType::DT_HALF: {
|
||||
case nvinfer1::DataType::kHALF: {
|
||||
Reorder4({k, c, r, s},
|
||||
static_cast<Eigen::half const*>(iweights.GetValues()), istrides,
|
||||
static_cast<Eigen::half*>(oweights->GetValues()), ostrides);
|
||||
@ -921,18 +948,20 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
|
||||
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got "
|
||||
<< DataTypeString(iweights.type_);
|
||||
<< DebugString(iweights.TrtDType());
|
||||
}
|
||||
}
|
||||
|
||||
TRT_ShapedWeights TrtWeightStore::GetTempWeights(DataType type,
|
||||
TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
|
||||
const nvinfer1::Dims& dims) {
|
||||
TensorShape shape;
|
||||
DataType tf_dtype;
|
||||
// TODO(laigd): make it return a status.
|
||||
TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape));
|
||||
TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype));
|
||||
// TODO(jie): check weights size_bytes. 0 means type error
|
||||
Tensor tensor(type, shape);
|
||||
TRT_ShapedWeights weights(type, dims, tensor);
|
||||
Tensor tensor(tf_dtype, shape);
|
||||
TRT_ShapedWeights weights(trt_dtype, dims, tensor);
|
||||
store_.emplace_back(std::move(tensor));
|
||||
return weights;
|
||||
}
|
||||
@ -1282,22 +1311,22 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor,
|
||||
|
||||
Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
|
||||
float* out_min, float* out_max) const {
|
||||
switch (weights.type_) {
|
||||
case DataType::DT_FLOAT: {
|
||||
switch (weights.TrtDType()) {
|
||||
case nvinfer1::DataType::kFLOAT: {
|
||||
auto inp = static_cast<float const*>(weights.GetValues());
|
||||
auto result = std::minmax_element(inp, inp + weights.count());
|
||||
*out_min = *result.first;
|
||||
*out_max = *result.second;
|
||||
break;
|
||||
}
|
||||
case DataType::DT_HALF: {
|
||||
case nvinfer1::DataType::kHALF: {
|
||||
auto inp = static_cast<Eigen::half const*>(weights.GetValues());
|
||||
auto result = std::minmax_element(inp, inp + weights.count());
|
||||
*out_min = Eigen::half_impl::half_to_float(*result.first);
|
||||
*out_max = Eigen::half_impl::half_to_float(*result.second);
|
||||
break;
|
||||
}
|
||||
case DataType::DT_INT32: {
|
||||
case nvinfer1::DataType::kINT32: {
|
||||
auto inp = static_cast<int const*>(weights.GetValues());
|
||||
auto result = std::minmax_element(inp, inp + weights.count());
|
||||
*out_min = static_cast<float>(*result.first);
|
||||
@ -1307,7 +1336,7 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
|
||||
default:
|
||||
return errors::Unimplemented(
|
||||
"Data type not supported for GetWeightRange: ",
|
||||
DataTypeString(weights.type_));
|
||||
DebugString(weights.TrtDType()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1562,9 +1591,8 @@ Status AllowDataTypes(const OpConverterParams& params,
|
||||
|
||||
TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
|
||||
const TRT_ShapedWeights& weights_src) {
|
||||
auto dtype_new = DataType::DT_HALF;
|
||||
TRT_ShapedWeights weights =
|
||||
store->GetTempWeights(dtype_new, weights_src.shape_);
|
||||
store->GetTempWeights(nvinfer1::DataType::kHALF, weights_src.shape_);
|
||||
const float* src = static_cast<const float*>(weights_src.GetValues());
|
||||
Eigen::half* dst = static_cast<Eigen::half*>(weights.GetValues());
|
||||
for (int64_t i = 0; i < weights_src.count(); i++) {
|
||||
@ -1622,15 +1650,15 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
|
||||
|
||||
Status UnaryCompute(const TRT_ShapedWeights& iweights,
|
||||
TRT_ShapedWeights* oweights, LambdaFactory unary_op) {
|
||||
CHECK_EQ(iweights.type_, oweights->type_);
|
||||
switch (iweights.type_) {
|
||||
case DataType::DT_FLOAT: {
|
||||
CHECK(iweights.TrtDType() == oweights->TrtDType());
|
||||
switch (iweights.TrtDType()) {
|
||||
case nvinfer1::DataType::kFLOAT: {
|
||||
auto inp = static_cast<float const*>(iweights.GetValues());
|
||||
auto oup = static_cast<float*>(oweights->GetValues());
|
||||
std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
|
||||
break;
|
||||
}
|
||||
case DataType::DT_HALF: {
|
||||
case nvinfer1::DataType::kHALF: {
|
||||
auto inp = static_cast<Eigen::half const*>(iweights.GetValues());
|
||||
auto oup = static_cast<Eigen::half*>(oweights->GetValues());
|
||||
std::transform(inp, inp + iweights.count(), oup,
|
||||
@ -1638,8 +1666,8 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights,
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return errors::Unimplemented("Data type not supported: " +
|
||||
DataTypeString(iweights.type_));
|
||||
return errors::Unimplemented("Data type not supported: ",
|
||||
DebugString(iweights.TrtDType()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1660,10 +1688,6 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
|
||||
node_def.name());
|
||||
}
|
||||
|
||||
// Check type consistency.
|
||||
nvinfer1::DataType trt_dtype;
|
||||
TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype));
|
||||
|
||||
// Check scale mode.
|
||||
auto dims_w = weights.shape_;
|
||||
const auto dims_t = tensor->getDimensions();
|
||||
@ -1753,9 +1777,9 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
|
||||
}
|
||||
|
||||
// Prepare weights
|
||||
TRT_ShapedWeights shift_weights(weights.type_);
|
||||
TRT_ShapedWeights scale_weights(weights.type_);
|
||||
TRT_ShapedWeights power_weights(weights.type_);
|
||||
TRT_ShapedWeights shift_weights(weights.TrtDType());
|
||||
TRT_ShapedWeights scale_weights(weights.TrtDType());
|
||||
TRT_ShapedWeights power_weights(weights.TrtDType());
|
||||
|
||||
if (node_def.op() == "Sub") {
|
||||
if (swapped_inputs) {
|
||||
@ -1922,7 +1946,7 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||
TRT_ShapedWeights weights =
|
||||
params->weight_store->GetTempWeights(weights_rsck);
|
||||
ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
|
||||
TRT_ShapedWeights biases(weights.type_);
|
||||
TRT_ShapedWeights biases(weights.TrtDType());
|
||||
const int output_axis = is_conv2d_backprop_input ? 1 : 0;
|
||||
const int noutput = weights.shape_.d[output_axis] * num_groups;
|
||||
nvinfer1::DimsHW kernel_size;
|
||||
@ -3022,7 +3046,7 @@ Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
mode = nvinfer1::ScaleMode::kUNIFORM;
|
||||
}
|
||||
|
||||
TRT_ShapedWeights empty_weights(weights.type_);
|
||||
TRT_ShapedWeights empty_weights(weights.TrtDType());
|
||||
nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
|
||||
*tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(),
|
||||
empty_weights.GetTrtWeights());
|
||||
@ -3072,33 +3096,41 @@ void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) {
|
||||
}
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void CopyToTrtInt32Array(const Tensor& tensor, int32* dst) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
const CType* src = tensor.flat<CType>().data();
|
||||
std::copy(src, src + tensor.NumElements(), dst);
|
||||
}
|
||||
|
||||
Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
|
||||
TRT_ShapedWeights* weights) {
|
||||
const DataType dtype = tensor.dtype();
|
||||
|
||||
// We always convert the integer constants to INT32, since TRT INT8 is for
|
||||
// quantized inference.
|
||||
// We always convert the integer constants to INT32.
|
||||
//
|
||||
// TODO(aaroey): FP16 will remain in half format and is not converted to
|
||||
// FP32, but the converter currently uses all float weights as FP32. Fix
|
||||
// this.
|
||||
const DataType converted_dtype =
|
||||
(dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32
|
||||
: dtype);
|
||||
DataType converted_dtype = dtype;
|
||||
if (dtype == DataType::DT_INT8 || dtype == DataType::DT_UINT8 ||
|
||||
dtype == DataType::DT_INT16 || dtype == DataType::DT_UINT16) {
|
||||
converted_dtype = DT_INT32;
|
||||
}
|
||||
|
||||
// Verify that the dtype is supported by TensorRT. Otherwise, return an error.
|
||||
nvinfer1::DataType trt_dtype;
|
||||
TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype));
|
||||
TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype));
|
||||
|
||||
if (tensor.NumElements() == 0) {
|
||||
// Return empty weights having converted dtype.
|
||||
*weights = TRT_ShapedWeights(converted_dtype);
|
||||
// Return empty weights.
|
||||
*weights = TRT_ShapedWeights(trt_dtype);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
nvinfer1::Dims weight_dims;
|
||||
GetTensorDimsWithProtoShape(tensor, &weight_dims);
|
||||
*weights = weight_store->GetTempWeights(converted_dtype, weight_dims);
|
||||
*weights = weight_store->GetTempWeights(trt_dtype, weight_dims);
|
||||
|
||||
// Copy the tensor directly if the tensor does not require cast to the
|
||||
// supported type.
|
||||
@ -3110,17 +3142,21 @@ Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
|
||||
|
||||
// Copy tensor elements after casting them to the converted DataType.
|
||||
int32* dst = static_cast<int32*>(weights->GetValues());
|
||||
if (dtype == DT_INT16) {
|
||||
const int16* src = tensor.flat<int16>().data();
|
||||
std::copy(src, src + tensor.NumElements(), dst);
|
||||
} else if (dtype == DT_INT8) {
|
||||
const int8* src = tensor.flat<int8>().data();
|
||||
std::copy(src, src + tensor.NumElements(), dst);
|
||||
} else {
|
||||
// dtype can only be DT_UINT8 at this point.
|
||||
TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8);
|
||||
const uint8* src = tensor.flat<uint8>().data();
|
||||
std::copy(src, src + tensor.NumElements(), dst);
|
||||
switch (dtype) {
|
||||
case DT_INT8:
|
||||
CopyToTrtInt32Array<DT_INT8>(tensor, dst);
|
||||
break;
|
||||
case DT_UINT8:
|
||||
CopyToTrtInt32Array<DT_UINT8>(tensor, dst);
|
||||
break;
|
||||
case DT_INT16:
|
||||
CopyToTrtInt32Array<DT_INT16>(tensor, dst);
|
||||
break;
|
||||
case DT_UINT16:
|
||||
CopyToTrtInt32Array<DT_UINT16>(tensor, dst);
|
||||
break;
|
||||
default:
|
||||
return errors::Internal("Unexpected DataType: ", DataTypeString(dtype));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -3782,15 +3818,15 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
|
||||
nvinfer1::ITensor* tensor = inputs.at(0).tensor();
|
||||
|
||||
// Check parameter types
|
||||
auto parameter_type = inputs.at(1).weights().type_;
|
||||
if ((parameter_type != DataType::DT_FLOAT) &&
|
||||
(parameter_type != DataType::DT_HALF)) {
|
||||
auto parameter_type = inputs.at(1).weights().TrtDType();
|
||||
if ((parameter_type != nvinfer1::DataType::kFLOAT) &&
|
||||
(parameter_type != nvinfer1::DataType::kHALF)) {
|
||||
return errors::Unimplemented(
|
||||
"only float32 or float16 weight data type is supported, for node " +
|
||||
node_def.name() + " got " + DataTypeString(parameter_type));
|
||||
"Only float32 or float16 weight data type is supported, for node ",
|
||||
node_def.name(), " got ", DebugString(parameter_type));
|
||||
}
|
||||
for (int i = 1; i < 5; i++) {
|
||||
if (inputs.at(i).weights().type_ != parameter_type) {
|
||||
if (inputs.at(i).weights().TrtDType() != parameter_type) {
|
||||
return errors::Unimplemented(
|
||||
"Inconsistent parameter type for batchnorm is not supported, at: " +
|
||||
node_def.name());
|
||||
@ -3841,16 +3877,16 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
|
||||
float batchnorm_data[4];
|
||||
for (int j = 0; j < 4; j++) {
|
||||
if (inputs.at(j + 1).weights().count() != 1) {
|
||||
if (parameter_type == DT_FLOAT) {
|
||||
if (parameter_type == nvinfer1::DataType::kFLOAT) {
|
||||
batchnorm_data[j] = vals_array[j][i];
|
||||
} else if (parameter_type == DT_HALF) {
|
||||
} else if (parameter_type == nvinfer1::DataType::kHALF) {
|
||||
batchnorm_data[j] =
|
||||
Eigen::half_impl::half_to_float(cast_vals_array[j][i]);
|
||||
}
|
||||
} else {
|
||||
if (parameter_type == DT_FLOAT) {
|
||||
if (parameter_type == nvinfer1::DataType::kFLOAT) {
|
||||
batchnorm_data[j] = vals_array[j][0];
|
||||
} else if (parameter_type == DT_HALF) {
|
||||
} else if (parameter_type == nvinfer1::DataType::kHALF) {
|
||||
batchnorm_data[j] =
|
||||
Eigen::half_impl::half_to_float(cast_vals_array[j][0]);
|
||||
}
|
||||
@ -3862,10 +3898,10 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
|
||||
float variance = batchnorm_data[3];
|
||||
float combined_scale_val = scale / sqrtf(variance + epsilon);
|
||||
float combined_offset_val = offset - mean * combined_scale_val;
|
||||
if (parameter_type == DT_FLOAT) {
|
||||
if (parameter_type == nvinfer1::DataType::kFLOAT) {
|
||||
combined_scale_vals[i] = combined_scale_val;
|
||||
combined_offset_vals[i] = combined_offset_val;
|
||||
} else if (parameter_type == DT_HALF) {
|
||||
} else if (parameter_type == nvinfer1::DataType::kHALF) {
|
||||
cast_combined_scale_vals[i] = Eigen::half(combined_scale_val);
|
||||
cast_combined_offset_vals[i] = Eigen::half(combined_offset_val);
|
||||
}
|
||||
@ -3962,14 +3998,14 @@ Status ConvertMatMulHelper(OpConverterParams* params,
|
||||
}
|
||||
nvinfer1::ITensor* tensor = tensor_input.tensor();
|
||||
|
||||
TRT_ShapedWeights weights(weights_raw.type_);
|
||||
TRT_ShapedWeights weights(weights_raw.TrtDType());
|
||||
if (transpose_weight) {
|
||||
weights = weights_raw;
|
||||
} else {
|
||||
weights = params->weight_store->GetTempWeights(weights_raw);
|
||||
ReorderCKtoKC(weights_raw, &weights);
|
||||
}
|
||||
TRT_ShapedWeights biases(weights.type_);
|
||||
TRT_ShapedWeights biases(weights.TrtDType());
|
||||
|
||||
int noutput = weights.shape_.d[0];
|
||||
|
||||
@ -4472,7 +4508,7 @@ Status ConvertGraphDefToEngine(
|
||||
TFAttrs attrs(node_def);
|
||||
DataType tf_dtype = attrs.get<DataType>("T");
|
||||
nvinfer1::DataType trt_dtype;
|
||||
TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
|
||||
TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype));
|
||||
if (output_tensors.size() <= slot_number) {
|
||||
output_tensors.resize(slot_number + 1);
|
||||
}
|
||||
|
@ -176,7 +176,8 @@ int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
|
||||
// Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight.
|
||||
class TRT_ShapedWeights {
|
||||
public:
|
||||
explicit TRT_ShapedWeights(DataType type = DT_FLOAT);
|
||||
explicit TRT_ShapedWeights(
|
||||
nvinfer1::DataType type = nvinfer1::DataType::kFLOAT);
|
||||
|
||||
// Copy from another weights.
|
||||
//
|
||||
@ -211,14 +212,18 @@ class TRT_ShapedWeights {
|
||||
return std::vector<T>(span.data(), span.data() + span.size());
|
||||
}
|
||||
|
||||
nvinfer1::DataType TrtDType() const { return type_; }
|
||||
|
||||
// TODO(aaroey): make these private.
|
||||
nvinfer1::Dims shape_; // Note: shape.type[] is not used.
|
||||
DataType type_;
|
||||
|
||||
private:
|
||||
// This constructor is only used by TrtWeightStore, which creates the
|
||||
// underlying buffer.
|
||||
TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor);
|
||||
TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims,
|
||||
Tensor tensor);
|
||||
|
||||
nvinfer1::DataType type_;
|
||||
|
||||
// All weights should be stored inside TrtWeightStore to make sure lifetime of
|
||||
// all the underlying tensors are available until the engine is built. For
|
||||
@ -239,12 +244,13 @@ class TRT_ShapedWeights {
|
||||
class TrtWeightStore {
|
||||
public:
|
||||
// Get a TRT_ShapedWeights with 'type' and 'dims'.
|
||||
TRT_ShapedWeights GetTempWeights(DataType type, const nvinfer1::Dims& dims);
|
||||
TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type,
|
||||
const nvinfer1::Dims& dims);
|
||||
|
||||
// Get a TRT_ShapedWeights with the same data type and dimensions as
|
||||
// 'weights'.
|
||||
TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) {
|
||||
return GetTempWeights(weights.type_, weights.shape_);
|
||||
return GetTempWeights(weights.TrtDType(), weights.shape_);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -186,8 +186,8 @@ void ExpectArrayNear(const std::vector<Eigen::half>& lhs,
|
||||
|
||||
bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs,
|
||||
const TRT_ShapedWeights& rhs) {
|
||||
return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ &&
|
||||
lhs.GetValues() == rhs.GetValues();
|
||||
return TrtDimsEquals(lhs.shape_, rhs.shape_) &&
|
||||
lhs.TrtDType() == rhs.TrtDType() && lhs.GetValues() == rhs.GetValues();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -293,7 +293,7 @@ TEST(TRT_ShapedWeights_Test, Basic) {
|
||||
}
|
||||
// Test constructor with DataType argument.
|
||||
{
|
||||
TRT_ShapedWeights weights(DT_FLOAT);
|
||||
TRT_ShapedWeights weights(nvinfer1::DataType::kFLOAT);
|
||||
TRT_ShapedWeights copy(weights);
|
||||
for (auto ptr : {&weights, ©}) {
|
||||
nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
|
||||
@ -310,7 +310,7 @@ TEST(TRT_ShapedWeights_Test, Basic) {
|
||||
{
|
||||
TrtWeightStore store;
|
||||
TRT_ShapedWeights weights =
|
||||
store.GetTempWeights(DT_FLOAT, GetTestDims({2, 5}));
|
||||
store.GetTempWeights(nvinfer1::DataType::kFLOAT, GetTestDims({2, 5}));
|
||||
TRT_ShapedWeights copy(weights);
|
||||
for (auto ptr : {&weights, ©}) {
|
||||
nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
|
||||
@ -671,7 +671,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) {
|
||||
params->outputs->emplace_back(output_tensor);
|
||||
output_tensors.push_back(output_tensor);
|
||||
}
|
||||
TRT_ShapedWeights output_weights(DT_FLOAT);
|
||||
TRT_ShapedWeights output_weights(nvinfer1::DataType::kFLOAT);
|
||||
params->outputs->emplace_back(output_weights);
|
||||
return Status::OK();
|
||||
};
|
||||
@ -778,8 +778,8 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
|
||||
TRT_ShapedWeights weights =
|
||||
weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5}));
|
||||
TRT_ShapedWeights weights = weight_store_->GetTempWeights(
|
||||
nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
|
||||
nvinfer1::ITensor* output_tensor = nullptr;
|
||||
for (bool validation_only : {false, true}) {
|
||||
TF_EXPECT_OK(converter_->PrepareTensorForShape(
|
||||
@ -832,8 +832,8 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
|
||||
|
||||
template <typename T>
|
||||
void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
|
||||
TRT_ShapedWeights weights =
|
||||
weight_store->GetTempWeights(DataTypeToEnum<T>::v(), GetTestDims({2, 3}));
|
||||
TRT_ShapedWeights weights = weight_store->GetTempWeights(
|
||||
TfDataTypeToTrt(DataTypeToEnum<T>::v()), GetTestDims({2, 3}));
|
||||
const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
|
||||
memcpy(weights.GetValues(), values.data(), weights.size_bytes());
|
||||
|
||||
@ -1002,14 +1002,14 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, CreateConstantLayer) {
|
||||
for (auto dtype : {DT_FLOAT, DT_INT32}) {
|
||||
for (auto dtype : {nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT32}) {
|
||||
TRT_ShapedWeights weights =
|
||||
weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5}));
|
||||
nvinfer1::ITensor* tensor =
|
||||
converter_->CreateConstantLayer(weights, GetTestDims({3, 10}));
|
||||
ASSERT_NE(nullptr, tensor);
|
||||
EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType())
|
||||
<< "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. actual "
|
||||
EXPECT_EQ(dtype, tensor->getType())
|
||||
<< "Expected " << DebugString(dtype) << " vs. actual "
|
||||
<< DebugString(tensor->getType());
|
||||
ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions());
|
||||
}
|
||||
@ -1246,7 +1246,7 @@ class OpConverterTest : public ::testing::Test {
|
||||
template <typename T>
|
||||
void AddTestWeights(const string& name, const std::vector<int>& dims,
|
||||
const std::vector<T>& values) {
|
||||
const DataType dtype = DataTypeToEnum<T>::v();
|
||||
const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum<T>::v());
|
||||
const nvinfer1::Dims trt_dims = GetTestDims(dims);
|
||||
const int64_t num_elements = TrtWeightDimsNumElements(trt_dims);
|
||||
QCHECK_EQ(num_elements, values.size())
|
||||
@ -1452,6 +1452,9 @@ TEST_F(OpConverterTest, ConvertConst) {
|
||||
|
||||
TestConvertConst<DT_FLOAT, float, float>(this);
|
||||
TestConvertConst<DT_INT8, int8, int32>(this);
|
||||
TestConvertConst<DT_UINT8, uint8, int32>(this);
|
||||
TestConvertConst<DT_INT16, int16, int32>(this);
|
||||
TestConvertConst<DT_UINT16, uint16, int32>(this);
|
||||
TestConvertConst<DT_INT32, int32, int32>(this);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user