Use TensorRT DataType in TRT_ShapedWeights, so that the weights are always known to have a TensorRT-supported type.

Also fix a bug where DT_INT8 (a non-quantized type) was converted to nvinfer1::DataType::kINT8 (a quantized type).
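
To illustrate the fix, here is a minimal standalone C++ sketch of the corrected mapping. The enums below are stand-ins for tensorflow::DataType and nvinfer1::DataType, and the bool return is a stand-in for Status; the real conversion is TfDataTypeToTrt in the diff below:

    #include <iostream>

    // Stand-in enums for tensorflow::DataType and nvinfer1::DataType.
    enum class TfType { DT_FLOAT, DT_HALF, DT_INT8, DT_INT32 };
    enum class TrtType { kFLOAT, kHALF, kINT8, kINT32 };

    // DT_INT8 is now rejected instead of being mapped to kINT8: TF's
    // DT_INT8 is a plain integer type, while TRT's kINT8 denotes
    // quantized values with an associated scale, so the old mapping
    // silently changed the meaning of the data.
    bool TfTypeToTrt(TfType tf, TrtType* trt) {
      switch (tf) {
        case TfType::DT_FLOAT: *trt = TrtType::kFLOAT; return true;
        case TfType::DT_HALF:  *trt = TrtType::kHALF;  return true;
        case TfType::DT_INT32: *trt = TrtType::kINT32; return true;
        default: return false;  // e.g. DT_INT8 -> "unsupported data type"
      }
    }

    int main() {
      TrtType trt;
      std::cout << TfTypeToTrt(TfType::DT_FLOAT, &trt) << "\n";  // 1
      std::cout << TfTypeToTrt(TfType::DT_INT8, &trt) << "\n";   // 0
    }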

PiperOrigin-RevId: 241269981
Guangda Lai 2019-04-01 00:13:46 -07:00 committed by TensorFlower Gardener
parent 9fd90ad4c7
commit 6ff8dd7b5b
3 changed files with 156 additions and 111 deletions

tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc

@@ -98,15 +98,12 @@ namespace convert {
 using absl::StrAppend;
 using absl::StrCat;
 
-inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
+inline Status TfDataTypeToTrt(DataType tf_dtype,
+                              nvinfer1::DataType* trt_dtype) {
   switch (tf_dtype) {
     case DataType::DT_FLOAT:
       *trt_dtype = nvinfer1::DataType::kFLOAT;
       break;
-    // TODO(aaroey): this should be DT_QINT8 which is not a well supported type.
-    case DataType::DT_INT8:
-      *trt_dtype = nvinfer1::DataType::kINT8;
-      break;
     case DataType::DT_HALF:
       *trt_dtype = nvinfer1::DataType::kHALF;
       break;
@@ -120,6 +117,25 @@ inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
   return Status::OK();
 }
 
+inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype,
+                              DataType* tf_dtype) {
+  switch (trt_dtype) {
+    case nvinfer1::DataType::kFLOAT:
+      *tf_dtype = DataType::DT_FLOAT;
+      break;
+    case nvinfer1::DataType::kHALF:
+      *tf_dtype = DataType::DT_HALF;
+      break;
+    case nvinfer1::DataType::kINT32:
+      *tf_dtype = DataType::DT_INT32;
+      break;
+    default:
+      return errors::InvalidArgument("Unsupported data type ",
+                                     DebugString(trt_dtype));
+  }
+  return Status::OK();
+}
+
 class TFAttrs {
  public:
   explicit TFAttrs(const NodeDef& tf_node) {
@@ -178,7 +194,7 @@ std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
 template <>
 nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
   nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
-  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
+  TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype));
   return trt_dtype;
 }
@@ -268,7 +284,7 @@ Status ValidateTensorProperties(const string& producer_node_type,
                                 nvinfer1::DataType* trt_dtype,
                                 nvinfer1::Dims* trt_dims, int* batch_size) {
   // Convert data type.
-  TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype));
+  TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype));
 
   // Convert shape.
   if (shape.dims() < 0) {
@@ -472,12 +488,12 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
                                          const nvinfer1::Dims& dims,
                                          nvinfer1::ITensor** tensor,
                                          const char* dtype_attr_name = "T") {
+  nvinfer1::DataType trt_dtype =
+      nvinfer1::DataType::kFLOAT;  // Default to FP32.
   TFAttrs attrs(params->node_def);
-  DataType dtype;
   if (attrs.count(dtype_attr_name)) {
-    dtype = attrs.get<DataType>(dtype_attr_name);
-  } else {
-    dtype = DT_FLOAT;  // Default to FP32.
+    DataType dtype = attrs.get<DataType>(dtype_attr_name);
+    TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_dtype));
   }
 
   // In order to be broadcastable, the number of dims has to match.
@@ -486,18 +502,18 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
     broadcastable_dims.d[i] = 1;
   }
   TRT_ShapedWeights weights =
-      params->weight_store->GetTempWeights(dtype, broadcastable_dims);
+      params->weight_store->GetTempWeights(trt_dtype, broadcastable_dims);
   void* raw_ptr = weights.GetValues();
-  switch (dtype) {
-    case DataType::DT_FLOAT:
+  switch (trt_dtype) {
+    case nvinfer1::DataType::kFLOAT:
       static_cast<float*>(raw_ptr)[0] = value;
       break;
-    case DataType::DT_HALF:
+    case nvinfer1::DataType::kHALF:
       static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);
       break;
     default:
       return errors::InvalidArgument("Unsupported data type ",
-                                     DataTypeString(dtype));
+                                     DebugString(trt_dtype));
   }
   *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
   TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
@@ -676,12 +692,12 @@ Status VerifyShapesMatch(absl::Span<const TRT_TensorOrWeights> inputs,
   return Status::OK();
 }
 
-TRT_ShapedWeights::TRT_ShapedWeights(DataType type) : type_(type) {
+TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type) : type_(type) {
   shape_.nbDims = 0;
 }
 
-TRT_ShapedWeights::TRT_ShapedWeights(DataType type, nvinfer1::Dims dims,
-                                     Tensor tensor)
+TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type,
+                                     nvinfer1::Dims dims, Tensor tensor)
     : shape_(dims), type_(type), tensor_(tensor) {}
 
 TRT_ShapedWeights::TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
@@ -692,18 +708,29 @@ int64_t TRT_ShapedWeights::count() const {
 }
 
 nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const {
-  nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
-  TF_CHECK_OK(ConvertDType(type_, &trt_type));
-  return nvinfer1::Weights{trt_type, GetValues(), count()};
+  return nvinfer1::Weights{type_, GetValues(), count()};
 }
 
 size_t TRT_ShapedWeights::size_bytes() const {
-  return this->count() * DataTypeSize(this->type_);
+  size_t data_type_size = -1;
+  switch (type_) {
+    case nvinfer1::DataType::kFLOAT:
+    case nvinfer1::DataType::kINT32:
+      data_type_size = 4;
+      break;
+    case nvinfer1::DataType::kHALF:
+      data_type_size = 2;
+      break;
+    case nvinfer1::DataType::kINT8:
+      data_type_size = 1;
+      break;
+  }
+  return this->count() * data_type_size;
 }
 
 string TRT_ShapedWeights::DebugString() const {
   return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
-                ", type=", DataTypeString(type_),
+                ", type=", convert::DebugString(type_),
                 ", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
 }
@@ -867,13 +894,13 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
   oweights->shape_.d[1] = c;
   const nvinfer1::DimsHW istrides = {1, k};
   const nvinfer1::DimsHW ostrides = {c, 1};
-  switch (iweights.type_) {
-    case DataType::DT_FLOAT: {
+  switch (iweights.TrtDType()) {
+    case nvinfer1::DataType::kFLOAT: {
       Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
                istrides, static_cast<float*>(oweights->GetValues()), ostrides);
       break;
     }
-    case DataType::DT_HALF: {
+    case nvinfer1::DataType::kHALF: {
       Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
               istrides, static_cast<Eigen::half*>(oweights->GetValues()),
               ostrides);
@@ -881,13 +908,13 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
     }
     default:
       LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got "
-                 << DataTypeString(iweights.type_);
+                 << DebugString(iweights.TrtDType());
   }
 }
 
 void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
                        TRT_ShapedWeights* oweights, const int num_groups) {
-  CHECK_EQ(iweights.type_, oweights->type_);
+  CHECK(iweights.TrtDType() == oweights->TrtDType());
   CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
   // K indexes over output channels, C over input channels, and R and S over the
   // height and width of the convolution
@@ -906,13 +933,13 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
   oweights->shape_.d[3] = s;
   const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
   const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
-  switch (iweights.type_) {
-    case DataType::DT_FLOAT: {
+  switch (iweights.TrtDType()) {
+    case nvinfer1::DataType::kFLOAT: {
      Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
               istrides, static_cast<float*>(oweights->GetValues()), ostrides);
      break;
    }
-    case DataType::DT_HALF: {
+    case nvinfer1::DataType::kHALF: {
      Reorder4({k, c, r, s},
               static_cast<Eigen::half const*>(iweights.GetValues()), istrides,
               static_cast<Eigen::half*>(oweights->GetValues()), ostrides);
@@ -921,18 +948,20 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
     default:
       LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got "
-                 << DataTypeString(iweights.type_);
+                 << DebugString(iweights.TrtDType());
   }
 }
 
-TRT_ShapedWeights TrtWeightStore::GetTempWeights(DataType type,
+TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
                                                  const nvinfer1::Dims& dims) {
   TensorShape shape;
+  DataType tf_dtype;
   // TODO(laigd): make it return a status.
   TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape));
+  TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype));
   // TODO(jie): check weights size_bytes. 0 means type error
-  Tensor tensor(type, shape);
-  TRT_ShapedWeights weights(type, dims, tensor);
+  Tensor tensor(tf_dtype, shape);
+  TRT_ShapedWeights weights(trt_dtype, dims, tensor);
   store_.emplace_back(std::move(tensor));
   return weights;
 }
@@ -1282,22 +1311,22 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor,
 
 Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
                                  float* out_min, float* out_max) const {
-  switch (weights.type_) {
-    case DataType::DT_FLOAT: {
+  switch (weights.TrtDType()) {
+    case nvinfer1::DataType::kFLOAT: {
       auto inp = static_cast<float const*>(weights.GetValues());
       auto result = std::minmax_element(inp, inp + weights.count());
       *out_min = *result.first;
       *out_max = *result.second;
       break;
     }
-    case DataType::DT_HALF: {
+    case nvinfer1::DataType::kHALF: {
       auto inp = static_cast<Eigen::half const*>(weights.GetValues());
       auto result = std::minmax_element(inp, inp + weights.count());
       *out_min = Eigen::half_impl::half_to_float(*result.first);
       *out_max = Eigen::half_impl::half_to_float(*result.second);
       break;
     }
-    case DataType::DT_INT32: {
+    case nvinfer1::DataType::kINT32: {
       auto inp = static_cast<int const*>(weights.GetValues());
       auto result = std::minmax_element(inp, inp + weights.count());
       *out_min = static_cast<float>(*result.first);
@@ -1307,7 +1336,7 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
     default:
       return errors::Unimplemented(
           "Data type not supported for GetWeightRange: ",
-          DataTypeString(weights.type_));
+          DebugString(weights.TrtDType()));
   }
   return Status::OK();
 }
@@ -1562,9 +1591,8 @@ Status AllowDataTypes(const OpConverterParams& params,
 
 TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
                                     const TRT_ShapedWeights& weights_src) {
-  auto dtype_new = DataType::DT_HALF;
   TRT_ShapedWeights weights =
-      store->GetTempWeights(dtype_new, weights_src.shape_);
+      store->GetTempWeights(nvinfer1::DataType::kHALF, weights_src.shape_);
   const float* src = static_cast<const float*>(weights_src.GetValues());
   Eigen::half* dst = static_cast<Eigen::half*>(weights.GetValues());
   for (int64_t i = 0; i < weights_src.count(); i++) {
@@ -1622,15 +1650,15 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
 
 Status UnaryCompute(const TRT_ShapedWeights& iweights,
                     TRT_ShapedWeights* oweights, LambdaFactory unary_op) {
-  CHECK_EQ(iweights.type_, oweights->type_);
-  switch (iweights.type_) {
-    case DataType::DT_FLOAT: {
+  CHECK(iweights.TrtDType() == oweights->TrtDType());
+  switch (iweights.TrtDType()) {
+    case nvinfer1::DataType::kFLOAT: {
       auto inp = static_cast<float const*>(iweights.GetValues());
       auto oup = static_cast<float*>(oweights->GetValues());
       std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
       break;
     }
-    case DataType::DT_HALF: {
+    case nvinfer1::DataType::kHALF: {
       auto inp = static_cast<Eigen::half const*>(iweights.GetValues());
       auto oup = static_cast<Eigen::half*>(oweights->GetValues());
       std::transform(inp, inp + iweights.count(), oup,
@@ -1638,8 +1666,8 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights,
       break;
     }
     default:
-      return errors::Unimplemented("Data type not supported: " +
-                                   DataTypeString(iweights.type_));
+      return errors::Unimplemented("Data type not supported: ",
+                                   DebugString(iweights.TrtDType()));
   }
   return Status::OK();
 }
@@ -1660,10 +1688,6 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
                 node_def.name());
   }
 
-  // Check type consistency.
-  nvinfer1::DataType trt_dtype;
-  TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype));
-
   // Check scale mode.
   auto dims_w = weights.shape_;
   const auto dims_t = tensor->getDimensions();
@@ -1753,9 +1777,9 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
   }
 
   // Prepare weights
-  TRT_ShapedWeights shift_weights(weights.type_);
-  TRT_ShapedWeights scale_weights(weights.type_);
-  TRT_ShapedWeights power_weights(weights.type_);
+  TRT_ShapedWeights shift_weights(weights.TrtDType());
+  TRT_ShapedWeights scale_weights(weights.TrtDType());
+  TRT_ShapedWeights power_weights(weights.TrtDType());
 
   if (node_def.op() == "Sub") {
     if (swapped_inputs) {
@@ -1922,7 +1946,7 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
   TRT_ShapedWeights weights =
       params->weight_store->GetTempWeights(weights_rsck);
   ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
-  TRT_ShapedWeights biases(weights.type_);
+  TRT_ShapedWeights biases(weights.TrtDType());
   const int output_axis = is_conv2d_backprop_input ? 1 : 0;
   const int noutput = weights.shape_.d[output_axis] * num_groups;
   nvinfer1::DimsHW kernel_size;
@@ -3022,7 +3046,7 @@ Status ConvertBiasAdd(OpConverterParams* params) {
     mode = nvinfer1::ScaleMode::kUNIFORM;
   }
 
-  TRT_ShapedWeights empty_weights(weights.type_);
+  TRT_ShapedWeights empty_weights(weights.TrtDType());
   nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
       *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(),
       empty_weights.GetTrtWeights());
@@ -3072,33 +3096,41 @@ void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) {
   }
 }
 
+template <DataType dtype>
+void CopyToTrtInt32Array(const Tensor& tensor, int32* dst) {
+  typedef typename EnumToDataType<dtype>::Type CType;
+  const CType* src = tensor.flat<CType>().data();
+  std::copy(src, src + tensor.NumElements(), dst);
+}
+
 Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
                             TRT_ShapedWeights* weights) {
   const DataType dtype = tensor.dtype();
-  // We always convert the integer constants to INT32, since TRT INT8 is for
-  // quantized inference.
+  // We always convert the integer constants to INT32.
+  //
   // TODO(aaroey): FP16 will remain in half format and is not converted to
   // FP32, but the converter currently uses all float weights as FP32. Fix
   // this.
-  const DataType converted_dtype =
-      (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32
-                                                                  : dtype);
+  DataType converted_dtype = dtype;
+  if (dtype == DataType::DT_INT8 || dtype == DataType::DT_UINT8 ||
+      dtype == DataType::DT_INT16 || dtype == DataType::DT_UINT16) {
+    converted_dtype = DT_INT32;
+  }
+
   // Verify that the dtype is supported by TensorRT. Otherwise, return an error.
   nvinfer1::DataType trt_dtype;
-  TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype));
+  TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype));
 
   if (tensor.NumElements() == 0) {
-    // Return empty weights having converted dtype.
-    *weights = TRT_ShapedWeights(converted_dtype);
+    // Return empty weights.
+    *weights = TRT_ShapedWeights(trt_dtype);
     return Status::OK();
   }
 
   nvinfer1::Dims weight_dims;
   GetTensorDimsWithProtoShape(tensor, &weight_dims);
-  *weights = weight_store->GetTempWeights(converted_dtype, weight_dims);
+  *weights = weight_store->GetTempWeights(trt_dtype, weight_dims);
 
   // Copy the tensor directly if the tensor does not require cast to the
   // supported type.
@@ -3110,17 +3142,21 @@ Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
 
   // Copy tensor elements after casting them to the converted DataType.
   int32* dst = static_cast<int32*>(weights->GetValues());
-  if (dtype == DT_INT16) {
-    const int16* src = tensor.flat<int16>().data();
-    std::copy(src, src + tensor.NumElements(), dst);
-  } else if (dtype == DT_INT8) {
-    const int8* src = tensor.flat<int8>().data();
-    std::copy(src, src + tensor.NumElements(), dst);
-  } else {
-    // dtype can only be DT_UINT8 at this point.
-    TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8);
-    const uint8* src = tensor.flat<uint8>().data();
-    std::copy(src, src + tensor.NumElements(), dst);
+  switch (dtype) {
+    case DT_INT8:
+      CopyToTrtInt32Array<DT_INT8>(tensor, dst);
+      break;
+    case DT_UINT8:
+      CopyToTrtInt32Array<DT_UINT8>(tensor, dst);
+      break;
+    case DT_INT16:
+      CopyToTrtInt32Array<DT_INT16>(tensor, dst);
+      break;
+    case DT_UINT16:
+      CopyToTrtInt32Array<DT_UINT16>(tensor, dst);
+      break;
+    default:
+      return errors::Internal("Unexpected DataType: ", DataTypeString(dtype));
   }
 
   return Status::OK();
 }
@@ -3782,15 +3818,15 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
   nvinfer1::ITensor* tensor = inputs.at(0).tensor();
 
   // Check parameter types
-  auto parameter_type = inputs.at(1).weights().type_;
-  if ((parameter_type != DataType::DT_FLOAT) &&
-      (parameter_type != DataType::DT_HALF)) {
+  auto parameter_type = inputs.at(1).weights().TrtDType();
+  if ((parameter_type != nvinfer1::DataType::kFLOAT) &&
+      (parameter_type != nvinfer1::DataType::kHALF)) {
     return errors::Unimplemented(
-        "only float32 or float16 weight data type is supported, for node " +
-        node_def.name() + " got " + DataTypeString(parameter_type));
+        "Only float32 or float16 weight data type is supported, for node ",
+        node_def.name(), " got ", DebugString(parameter_type));
   }
   for (int i = 1; i < 5; i++) {
-    if (inputs.at(i).weights().type_ != parameter_type) {
+    if (inputs.at(i).weights().TrtDType() != parameter_type) {
       return errors::Unimplemented(
           "Inconsistent parameter type for batchnorm is not supported, at: " +
           node_def.name());
@@ -3841,16 +3877,16 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
     float batchnorm_data[4];
     for (int j = 0; j < 4; j++) {
       if (inputs.at(j + 1).weights().count() != 1) {
-        if (parameter_type == DT_FLOAT) {
+        if (parameter_type == nvinfer1::DataType::kFLOAT) {
           batchnorm_data[j] = vals_array[j][i];
-        } else if (parameter_type == DT_HALF) {
+        } else if (parameter_type == nvinfer1::DataType::kHALF) {
           batchnorm_data[j] =
               Eigen::half_impl::half_to_float(cast_vals_array[j][i]);
         }
       } else {
-        if (parameter_type == DT_FLOAT) {
+        if (parameter_type == nvinfer1::DataType::kFLOAT) {
          batchnorm_data[j] = vals_array[j][0];
-        } else if (parameter_type == DT_HALF) {
+        } else if (parameter_type == nvinfer1::DataType::kHALF) {
          batchnorm_data[j] =
              Eigen::half_impl::half_to_float(cast_vals_array[j][0]);
        }
@@ -3862,10 +3898,10 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
     float variance = batchnorm_data[3];
     float combined_scale_val = scale / sqrtf(variance + epsilon);
     float combined_offset_val = offset - mean * combined_scale_val;
-    if (parameter_type == DT_FLOAT) {
+    if (parameter_type == nvinfer1::DataType::kFLOAT) {
       combined_scale_vals[i] = combined_scale_val;
       combined_offset_vals[i] = combined_offset_val;
-    } else if (parameter_type == DT_HALF) {
+    } else if (parameter_type == nvinfer1::DataType::kHALF) {
       cast_combined_scale_vals[i] = Eigen::half(combined_scale_val);
       cast_combined_offset_vals[i] = Eigen::half(combined_offset_val);
     }
@@ -3962,14 +3998,14 @@ Status ConvertMatMulHelper(OpConverterParams* params,
   }
   nvinfer1::ITensor* tensor = tensor_input.tensor();
 
-  TRT_ShapedWeights weights(weights_raw.type_);
+  TRT_ShapedWeights weights(weights_raw.TrtDType());
   if (transpose_weight) {
     weights = weights_raw;
   } else {
     weights = params->weight_store->GetTempWeights(weights_raw);
     ReorderCKtoKC(weights_raw, &weights);
   }
-  TRT_ShapedWeights biases(weights.type_);
+  TRT_ShapedWeights biases(weights.TrtDType());
 
   int noutput = weights.shape_.d[0];
@@ -4472,7 +4508,7 @@ Status ConvertGraphDefToEngine(
       TFAttrs attrs(node_def);
       DataType tf_dtype = attrs.get<DataType>("T");
       nvinfer1::DataType trt_dtype;
-      TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
+      TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype));
       if (output_tensors.size() <= slot_number) {
         output_tensors.resize(slot_number + 1);
       }

tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h

@@ -176,7 +176,8 @@ int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
 
 // Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight.
 class TRT_ShapedWeights {
  public:
-  explicit TRT_ShapedWeights(DataType type = DT_FLOAT);
+  explicit TRT_ShapedWeights(
+      nvinfer1::DataType type = nvinfer1::DataType::kFLOAT);
 
   // Copy from another weights.
   //
@@ -211,14 +212,18 @@ class TRT_ShapedWeights {
     return std::vector<T>(span.data(), span.data() + span.size());
   }
 
+  nvinfer1::DataType TrtDType() const { return type_; }
+
   // TODO(aaroey): make these private.
   nvinfer1::Dims shape_;  // Note: shape.type[] is not used.
-  DataType type_;
 
  private:
   // This constructor is only used by TrtWeightStore, which creates the
   // underlying buffer.
-  TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor);
+  TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims,
+                    Tensor tensor);
+
+  nvinfer1::DataType type_;
 
   // All weights should be stored inside TrtWeightStore to make sure lifetime of
   // all the underlying tensors are available until the engine is built. For
@@ -239,12 +244,13 @@ class TRT_ShapedWeights {
 class TrtWeightStore {
  public:
   // Get a TRT_ShapedWeights with 'type' and 'dims'.
-  TRT_ShapedWeights GetTempWeights(DataType type, const nvinfer1::Dims& dims);
+  TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type,
+                                   const nvinfer1::Dims& dims);
 
   // Get a TRT_ShapedWeights with the same data type and dimensions as
   // 'weights'.
   TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) {
-    return GetTempWeights(weights.type_, weights.shape_);
+    return GetTempWeights(weights.TrtDType(), weights.shape_);
   }
 
  private:

tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc

@@ -186,8 +186,8 @@ void ExpectArrayNear(const std::vector<Eigen::half>& lhs,
 
 bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs,
                             const TRT_ShapedWeights& rhs) {
-  return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ &&
-         lhs.GetValues() == rhs.GetValues();
+  return TrtDimsEquals(lhs.shape_, rhs.shape_) &&
+         lhs.TrtDType() == rhs.TrtDType() && lhs.GetValues() == rhs.GetValues();
 }
 
 template <typename T>
@@ -293,7 +293,7 @@ TEST(TRT_ShapedWeights_Test, Basic) {
   }
   // Test constructor with DataType argument.
   {
-    TRT_ShapedWeights weights(DT_FLOAT);
+    TRT_ShapedWeights weights(nvinfer1::DataType::kFLOAT);
     TRT_ShapedWeights copy(weights);
     for (auto ptr : {&weights, &copy}) {
       nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
@@ -310,7 +310,7 @@ TEST(TRT_ShapedWeights_Test, Basic) {
   {
     TrtWeightStore store;
     TRT_ShapedWeights weights =
-        store.GetTempWeights(DT_FLOAT, GetTestDims({2, 5}));
+        store.GetTempWeights(nvinfer1::DataType::kFLOAT, GetTestDims({2, 5}));
     TRT_ShapedWeights copy(weights);
     for (auto ptr : {&weights, &copy}) {
       nvinfer1::Weights trt_weights = ptr->GetTrtWeights();
@@ -671,7 +671,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) {
       params->outputs->emplace_back(output_tensor);
       output_tensors.push_back(output_tensor);
     }
-    TRT_ShapedWeights output_weights(DT_FLOAT);
+    TRT_ShapedWeights output_weights(nvinfer1::DataType::kFLOAT);
     params->outputs->emplace_back(output_weights);
     return Status::OK();
   };
@@ -778,8 +778,8 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
 }
 
 TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
-  TRT_ShapedWeights weights =
-      weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5}));
+  TRT_ShapedWeights weights = weight_store_->GetTempWeights(
+      nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
   nvinfer1::ITensor* output_tensor = nullptr;
   for (bool validation_only : {false, true}) {
     TF_EXPECT_OK(converter_->PrepareTensorForShape(
@@ -832,8 +832,8 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
 
 template <typename T>
 void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
-  TRT_ShapedWeights weights =
-      weight_store->GetTempWeights(DataTypeToEnum<T>::v(), GetTestDims({2, 3}));
+  TRT_ShapedWeights weights = weight_store->GetTempWeights(
+      TfDataTypeToTrt(DataTypeToEnum<T>::v()), GetTestDims({2, 3}));
   const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
   memcpy(weights.GetValues(), values.data(), weights.size_bytes());
@@ -1002,14 +1002,14 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) {
 }
 
 TEST_F(ConverterTest, CreateConstantLayer) {
-  for (auto dtype : {DT_FLOAT, DT_INT32}) {
+  for (auto dtype : {nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT32}) {
     TRT_ShapedWeights weights =
         weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5}));
     nvinfer1::ITensor* tensor =
         converter_->CreateConstantLayer(weights, GetTestDims({3, 10}));
     ASSERT_NE(nullptr, tensor);
-    EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType())
-        << "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. actual "
+    EXPECT_EQ(dtype, tensor->getType())
+        << "Expected " << DebugString(dtype) << " vs. actual "
        << DebugString(tensor->getType());
    ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions());
  }
@@ -1246,7 +1246,7 @@ class OpConverterTest : public ::testing::Test {
   template <typename T>
   void AddTestWeights(const string& name, const std::vector<int>& dims,
                       const std::vector<T>& values) {
-    const DataType dtype = DataTypeToEnum<T>::v();
+    const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum<T>::v());
     const nvinfer1::Dims trt_dims = GetTestDims(dims);
     const int64_t num_elements = TrtWeightDimsNumElements(trt_dims);
     QCHECK_EQ(num_elements, values.size())
@@ -1452,6 +1452,9 @@ TEST_F(OpConverterTest, ConvertConst) {
   TestConvertConst<DT_FLOAT, float, float>(this);
   TestConvertConst<DT_INT8, int8, int32>(this);
   TestConvertConst<DT_UINT8, uint8, int32>(this);
+  TestConvertConst<DT_INT16, int16, int32>(this);
+  TestConvertConst<DT_UINT16, uint16, int32>(this);
   TestConvertConst<DT_INT32, int32, int32>(this);
 }
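
As a reference for the widening these new ConvertConst cases exercise, here is a standalone sketch (plain C++, no TensorFlow dependencies; the variable names are illustrative) of the element-wise cast-and-copy that CopyToTrtInt32Array performs via Tensor::flat:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Narrow integer "weights", as a TF Const of DT_INT16 might hold.
      std::vector<int16_t> src = {-3, 0, 42};
      // Destination buffer with the TRT-supported kINT32 element type.
      std::vector<int32_t> dst(src.size());
      // std::copy applies the implicit int16 -> int32 widening cast per
      // element, the same cast-and-copy pattern as CopyToTrtInt32Array.
      std::copy(src.begin(), src.end(), dst.begin());
      for (int32_t v : dst) std::cout << v << " ";  // prints: -3 0 42
    }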