From 6ff8dd7b5b07b039b3af025b447aa22ed055c507 Mon Sep 17 00:00:00 2001 From: Guangda Lai Date: Mon, 1 Apr 2019 00:13:46 -0700 Subject: [PATCH] Use TensorRT DataType in TRT_ShapedWeights, so that we know the weights are always valid. Also fix a bug in where it converts DT_INT8 (non-quantized type) to nvinfer1::kINT8 (quantized type). PiperOrigin-RevId: 241269981 --- .../tf2tensorrt/convert/convert_nodes.cc | 222 ++++++++++-------- .../tf2tensorrt/convert/convert_nodes.h | 16 +- .../tf2tensorrt/convert/convert_nodes_test.cc | 29 ++- 3 files changed, 156 insertions(+), 111 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 931b3d9d7f6..1b7e47701f5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -98,15 +98,12 @@ namespace convert { using absl::StrAppend; using absl::StrCat; -inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) { +inline Status TfDataTypeToTrt(DataType tf_dtype, + nvinfer1::DataType* trt_dtype) { switch (tf_dtype) { case DataType::DT_FLOAT: *trt_dtype = nvinfer1::DataType::kFLOAT; break; - // TODO(aaroey): this should be DT_QINT8 which is not a well supported type. - case DataType::DT_INT8: - *trt_dtype = nvinfer1::DataType::kINT8; - break; case DataType::DT_HALF: *trt_dtype = nvinfer1::DataType::kHALF; break; @@ -120,6 +117,25 @@ inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) { return Status::OK(); } +inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype, + DataType* tf_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + *tf_dtype = DataType::DT_FLOAT; + break; + case nvinfer1::DataType::kHALF: + *tf_dtype = DataType::DT_HALF; + break; + case nvinfer1::DataType::kINT32: + *tf_dtype = DataType::DT_INT32; + break; + default: + return errors::InvalidArgument("Unsupported data type ", + DebugString(trt_dtype)); + } + return Status::OK(); +} + class TFAttrs { public: explicit TFAttrs(const NodeDef& tf_node) { @@ -178,7 +194,7 @@ std::vector TFAttrs::get>(const string& key) const { template <> nvinfer1::DataType TFAttrs::get(const string& key) const { nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); - TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); + TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype)); return trt_dtype; } @@ -268,7 +284,7 @@ Status ValidateTensorProperties(const string& producer_node_type, nvinfer1::DataType* trt_dtype, nvinfer1::Dims* trt_dims, int* batch_size) { // Convert data type. - TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); + TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype)); // Convert shape. if (shape.dims() < 0) { @@ -472,12 +488,12 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value, const nvinfer1::Dims& dims, nvinfer1::ITensor** tensor, const char* dtype_attr_name = "T") { + nvinfer1::DataType trt_dtype = + nvinfer1::DataType::kFLOAT; // Default to FP32. TFAttrs attrs(params->node_def); - DataType dtype; if (attrs.count(dtype_attr_name)) { - dtype = attrs.get(dtype_attr_name); - } else { - dtype = DT_FLOAT; // Default to FP32. + DataType dtype = attrs.get(dtype_attr_name); + TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_dtype)); } // In order to be broadcastable, the number of dims has to match. @@ -486,18 +502,18 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value, broadcastable_dims.d[i] = 1; } TRT_ShapedWeights weights = - params->weight_store->GetTempWeights(dtype, broadcastable_dims); + params->weight_store->GetTempWeights(trt_dtype, broadcastable_dims); void* raw_ptr = weights.GetValues(); - switch (dtype) { - case DataType::DT_FLOAT: + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: static_cast(raw_ptr)[0] = value; break; - case DataType::DT_HALF: + case nvinfer1::DataType::kHALF: static_cast(raw_ptr)[0] = Eigen::half(value); break; default: return errors::InvalidArgument("Unsupported data type ", - DataTypeString(dtype)); + DebugString(trt_dtype)); } *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims); TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name()); @@ -676,12 +692,12 @@ Status VerifyShapesMatch(absl::Span inputs, return Status::OK(); } -TRT_ShapedWeights::TRT_ShapedWeights(DataType type) : type_(type) { +TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type) : type_(type) { shape_.nbDims = 0; } -TRT_ShapedWeights::TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, - Tensor tensor) +TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type, + nvinfer1::Dims dims, Tensor tensor) : shape_(dims), type_(type), tensor_(tensor) {} TRT_ShapedWeights::TRT_ShapedWeights(const TRT_ShapedWeights& rhs) @@ -692,18 +708,29 @@ int64_t TRT_ShapedWeights::count() const { } nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const { - nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT); - TF_CHECK_OK(ConvertDType(type_, &trt_type)); - return nvinfer1::Weights{trt_type, GetValues(), count()}; + return nvinfer1::Weights{type_, GetValues(), count()}; } size_t TRT_ShapedWeights::size_bytes() const { - return this->count() * DataTypeSize(this->type_); + size_t data_type_size = -1; + switch (type_) { + case nvinfer1::DataType::kFLOAT: + case nvinfer1::DataType::kINT32: + data_type_size = 4; + break; + case nvinfer1::DataType::kHALF: + data_type_size = 2; + break; + case nvinfer1::DataType::kINT8: + data_type_size = 1; + break; + } + return this->count() * data_type_size; } string TRT_ShapedWeights::DebugString() const { return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_), - ", type=", DataTypeString(type_), + ", type=", convert::DebugString(type_), ", values=", reinterpret_cast(GetValues()), ")"); } @@ -867,13 +894,13 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, oweights->shape_.d[1] = c; const nvinfer1::DimsHW istrides = {1, k}; const nvinfer1::DimsHW ostrides = {c, 1}; - switch (iweights.type_) { - case DataType::DT_FLOAT: { + switch (iweights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { Reorder2({k, c}, static_cast(iweights.GetValues()), istrides, static_cast(oweights->GetValues()), ostrides); break; } - case DataType::DT_HALF: { + case nvinfer1::DataType::kHALF: { Reorder2({k, c}, static_cast(iweights.GetValues()), istrides, static_cast(oweights->GetValues()), ostrides); @@ -881,13 +908,13 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, } default: LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got " - << DataTypeString(iweights.type_); + << DebugString(iweights.TrtDType()); } } void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, const int num_groups) { - CHECK_EQ(iweights.type_, oweights->type_); + CHECK(iweights.TrtDType() == oweights->TrtDType()); CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); // K indexes over output channels, C over input channels, and R and S over the // height and width of the convolution @@ -906,13 +933,13 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, oweights->shape_.d[3] = s; const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; - switch (iweights.type_) { - case DataType::DT_FLOAT: { + switch (iweights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { Reorder4({k, c, r, s}, static_cast(iweights.GetValues()), istrides, static_cast(oweights->GetValues()), ostrides); break; } - case DataType::DT_HALF: { + case nvinfer1::DataType::kHALF: { Reorder4({k, c, r, s}, static_cast(iweights.GetValues()), istrides, static_cast(oweights->GetValues()), ostrides); @@ -921,18 +948,20 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, default: LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got " - << DataTypeString(iweights.type_); + << DebugString(iweights.TrtDType()); } } -TRT_ShapedWeights TrtWeightStore::GetTempWeights(DataType type, +TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& dims) { TensorShape shape; + DataType tf_dtype; // TODO(laigd): make it return a status. TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape)); + TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype)); // TODO(jie): check weights size_bytes. 0 means type error - Tensor tensor(type, shape); - TRT_ShapedWeights weights(type, dims, tensor); + Tensor tensor(tf_dtype, shape); + TRT_ShapedWeights weights(trt_dtype, dims, tensor); store_.emplace_back(std::move(tensor)); return weights; } @@ -1282,22 +1311,22 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, float* out_max) const { - switch (weights.type_) { - case DataType::DT_FLOAT: { + switch (weights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { auto inp = static_cast(weights.GetValues()); auto result = std::minmax_element(inp, inp + weights.count()); *out_min = *result.first; *out_max = *result.second; break; } - case DataType::DT_HALF: { + case nvinfer1::DataType::kHALF: { auto inp = static_cast(weights.GetValues()); auto result = std::minmax_element(inp, inp + weights.count()); *out_min = Eigen::half_impl::half_to_float(*result.first); *out_max = Eigen::half_impl::half_to_float(*result.second); break; } - case DataType::DT_INT32: { + case nvinfer1::DataType::kINT32: { auto inp = static_cast(weights.GetValues()); auto result = std::minmax_element(inp, inp + weights.count()); *out_min = static_cast(*result.first); @@ -1307,7 +1336,7 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, default: return errors::Unimplemented( "Data type not supported for GetWeightRange: ", - DataTypeString(weights.type_)); + DebugString(weights.TrtDType())); } return Status::OK(); } @@ -1562,9 +1591,8 @@ Status AllowDataTypes(const OpConverterParams& params, TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, const TRT_ShapedWeights& weights_src) { - auto dtype_new = DataType::DT_HALF; TRT_ShapedWeights weights = - store->GetTempWeights(dtype_new, weights_src.shape_); + store->GetTempWeights(nvinfer1::DataType::kHALF, weights_src.shape_); const float* src = static_cast(weights_src.GetValues()); Eigen::half* dst = static_cast(weights.GetValues()); for (int64_t i = 0; i < weights_src.count(); i++) { @@ -1622,15 +1650,15 @@ std::function LambdaFactory::unary() { Status UnaryCompute(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, LambdaFactory unary_op) { - CHECK_EQ(iweights.type_, oweights->type_); - switch (iweights.type_) { - case DataType::DT_FLOAT: { + CHECK(iweights.TrtDType() == oweights->TrtDType()); + switch (iweights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { auto inp = static_cast(iweights.GetValues()); auto oup = static_cast(oweights->GetValues()); std::transform(inp, inp + iweights.count(), oup, unary_op.unary()); break; } - case DataType::DT_HALF: { + case nvinfer1::DataType::kHALF: { auto inp = static_cast(iweights.GetValues()); auto oup = static_cast(oweights->GetValues()); std::transform(inp, inp + iweights.count(), oup, @@ -1638,8 +1666,8 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights, break; } default: - return errors::Unimplemented("Data type not supported: " + - DataTypeString(iweights.type_)); + return errors::Unimplemented("Data type not supported: ", + DebugString(iweights.TrtDType())); } return Status::OK(); } @@ -1660,10 +1688,6 @@ Status BinaryTensorOpWeight(OpConverterParams* params, node_def.name()); } - // Check type consistency. - nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype)); - // Check scale mode. auto dims_w = weights.shape_; const auto dims_t = tensor->getDimensions(); @@ -1753,9 +1777,9 @@ Status BinaryTensorOpWeight(OpConverterParams* params, } // Prepare weights - TRT_ShapedWeights shift_weights(weights.type_); - TRT_ShapedWeights scale_weights(weights.type_); - TRT_ShapedWeights power_weights(weights.type_); + TRT_ShapedWeights shift_weights(weights.TrtDType()); + TRT_ShapedWeights scale_weights(weights.TrtDType()); + TRT_ShapedWeights power_weights(weights.TrtDType()); if (node_def.op() == "Sub") { if (swapped_inputs) { @@ -1922,7 +1946,7 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, TRT_ShapedWeights weights = params->weight_store->GetTempWeights(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); - TRT_ShapedWeights biases(weights.type_); + TRT_ShapedWeights biases(weights.TrtDType()); const int output_axis = is_conv2d_backprop_input ? 1 : 0; const int noutput = weights.shape_.d[output_axis] * num_groups; nvinfer1::DimsHW kernel_size; @@ -3022,7 +3046,7 @@ Status ConvertBiasAdd(OpConverterParams* params) { mode = nvinfer1::ScaleMode::kUNIFORM; } - TRT_ShapedWeights empty_weights(weights.type_); + TRT_ShapedWeights empty_weights(weights.TrtDType()); nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(), empty_weights.GetTrtWeights()); @@ -3072,33 +3096,41 @@ void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) { } } +template +void CopyToTrtInt32Array(const Tensor& tensor, int32* dst) { + typedef typename EnumToDataType::Type CType; + const CType* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); +} + Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, TRT_ShapedWeights* weights) { const DataType dtype = tensor.dtype(); - // We always convert the integer constants to INT32, since TRT INT8 is for - // quantized inference. + // We always convert the integer constants to INT32. // // TODO(aaroey): FP16 will remain in half format and is not converted to // FP32, but the converter currently uses all float weights as FP32. Fix // this. - const DataType converted_dtype = - (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 - : dtype); + DataType converted_dtype = dtype; + if (dtype == DataType::DT_INT8 || dtype == DataType::DT_UINT8 || + dtype == DataType::DT_INT16 || dtype == DataType::DT_UINT16) { + converted_dtype = DT_INT32; + } // Verify that the dtype is supported by TensorRT. Otherwise, return an error. nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); + TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype)); if (tensor.NumElements() == 0) { - // Return empty weights having converted dtype. - *weights = TRT_ShapedWeights(converted_dtype); + // Return empty weights. + *weights = TRT_ShapedWeights(trt_dtype); return Status::OK(); } nvinfer1::Dims weight_dims; GetTensorDimsWithProtoShape(tensor, &weight_dims); - *weights = weight_store->GetTempWeights(converted_dtype, weight_dims); + *weights = weight_store->GetTempWeights(trt_dtype, weight_dims); // Copy the tensor directly if the tensor does not require cast to the // supported type. @@ -3110,17 +3142,21 @@ Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, // Copy tensor elements after casting them to the converted DataType. int32* dst = static_cast(weights->GetValues()); - if (dtype == DT_INT16) { - const int16* src = tensor.flat().data(); - std::copy(src, src + tensor.NumElements(), dst); - } else if (dtype == DT_INT8) { - const int8* src = tensor.flat().data(); - std::copy(src, src + tensor.NumElements(), dst); - } else { - // dtype can only be DT_UINT8 at this point. - TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8); - const uint8* src = tensor.flat().data(); - std::copy(src, src + tensor.NumElements(), dst); + switch (dtype) { + case DT_INT8: + CopyToTrtInt32Array(tensor, dst); + break; + case DT_UINT8: + CopyToTrtInt32Array(tensor, dst); + break; + case DT_INT16: + CopyToTrtInt32Array(tensor, dst); + break; + case DT_UINT16: + CopyToTrtInt32Array(tensor, dst); + break; + default: + return errors::Internal("Unexpected DataType: ", DataTypeString(dtype)); } return Status::OK(); } @@ -3782,15 +3818,15 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { nvinfer1::ITensor* tensor = inputs.at(0).tensor(); // Check parameter types - auto parameter_type = inputs.at(1).weights().type_; - if ((parameter_type != DataType::DT_FLOAT) && - (parameter_type != DataType::DT_HALF)) { + auto parameter_type = inputs.at(1).weights().TrtDType(); + if ((parameter_type != nvinfer1::DataType::kFLOAT) && + (parameter_type != nvinfer1::DataType::kHALF)) { return errors::Unimplemented( - "only float32 or float16 weight data type is supported, for node " + - node_def.name() + " got " + DataTypeString(parameter_type)); + "Only float32 or float16 weight data type is supported, for node ", + node_def.name(), " got ", DebugString(parameter_type)); } for (int i = 1; i < 5; i++) { - if (inputs.at(i).weights().type_ != parameter_type) { + if (inputs.at(i).weights().TrtDType() != parameter_type) { return errors::Unimplemented( "Inconsistent parameter type for batchnorm is not supported, at: " + node_def.name()); @@ -3841,16 +3877,16 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { float batchnorm_data[4]; for (int j = 0; j < 4; j++) { if (inputs.at(j + 1).weights().count() != 1) { - if (parameter_type == DT_FLOAT) { + if (parameter_type == nvinfer1::DataType::kFLOAT) { batchnorm_data[j] = vals_array[j][i]; - } else if (parameter_type == DT_HALF) { + } else if (parameter_type == nvinfer1::DataType::kHALF) { batchnorm_data[j] = Eigen::half_impl::half_to_float(cast_vals_array[j][i]); } } else { - if (parameter_type == DT_FLOAT) { + if (parameter_type == nvinfer1::DataType::kFLOAT) { batchnorm_data[j] = vals_array[j][0]; - } else if (parameter_type == DT_HALF) { + } else if (parameter_type == nvinfer1::DataType::kHALF) { batchnorm_data[j] = Eigen::half_impl::half_to_float(cast_vals_array[j][0]); } @@ -3862,10 +3898,10 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { float variance = batchnorm_data[3]; float combined_scale_val = scale / sqrtf(variance + epsilon); float combined_offset_val = offset - mean * combined_scale_val; - if (parameter_type == DT_FLOAT) { + if (parameter_type == nvinfer1::DataType::kFLOAT) { combined_scale_vals[i] = combined_scale_val; combined_offset_vals[i] = combined_offset_val; - } else if (parameter_type == DT_HALF) { + } else if (parameter_type == nvinfer1::DataType::kHALF) { cast_combined_scale_vals[i] = Eigen::half(combined_scale_val); cast_combined_offset_vals[i] = Eigen::half(combined_offset_val); } @@ -3962,14 +3998,14 @@ Status ConvertMatMulHelper(OpConverterParams* params, } nvinfer1::ITensor* tensor = tensor_input.tensor(); - TRT_ShapedWeights weights(weights_raw.type_); + TRT_ShapedWeights weights(weights_raw.TrtDType()); if (transpose_weight) { weights = weights_raw; } else { weights = params->weight_store->GetTempWeights(weights_raw); ReorderCKtoKC(weights_raw, &weights); } - TRT_ShapedWeights biases(weights.type_); + TRT_ShapedWeights biases(weights.TrtDType()); int noutput = weights.shape_.d[0]; @@ -4472,7 +4508,7 @@ Status ConvertGraphDefToEngine( TFAttrs attrs(node_def); DataType tf_dtype = attrs.get("T"); nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); + TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype)); if (output_tensors.size() <= slot_number) { output_tensors.resize(slot_number + 1); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 96293c119a4..dc32834203c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -176,7 +176,8 @@ int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims); // Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight. class TRT_ShapedWeights { public: - explicit TRT_ShapedWeights(DataType type = DT_FLOAT); + explicit TRT_ShapedWeights( + nvinfer1::DataType type = nvinfer1::DataType::kFLOAT); // Copy from another weights. // @@ -211,14 +212,18 @@ class TRT_ShapedWeights { return std::vector(span.data(), span.data() + span.size()); } + nvinfer1::DataType TrtDType() const { return type_; } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; // Note: shape.type[] is not used. - DataType type_; private: // This constructor is only used by TrtWeightStore, which creates the // underlying buffer. - TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor); + TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims, + Tensor tensor); + + nvinfer1::DataType type_; // All weights should be stored inside TrtWeightStore to make sure lifetime of // all the underlying tensors are available until the engine is built. For @@ -239,12 +244,13 @@ class TRT_ShapedWeights { class TrtWeightStore { public: // Get a TRT_ShapedWeights with 'type' and 'dims'. - TRT_ShapedWeights GetTempWeights(DataType type, const nvinfer1::Dims& dims); + TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type, + const nvinfer1::Dims& dims); // Get a TRT_ShapedWeights with the same data type and dimensions as // 'weights'. TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) { - return GetTempWeights(weights.type_, weights.shape_); + return GetTempWeights(weights.TrtDType(), weights.shape_); } private: diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 1d9910bff63..885995e0fe6 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -186,8 +186,8 @@ void ExpectArrayNear(const std::vector& lhs, bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, const TRT_ShapedWeights& rhs) { - return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ && - lhs.GetValues() == rhs.GetValues(); + return TrtDimsEquals(lhs.shape_, rhs.shape_) && + lhs.TrtDType() == rhs.TrtDType() && lhs.GetValues() == rhs.GetValues(); } template @@ -293,7 +293,7 @@ TEST(TRT_ShapedWeights_Test, Basic) { } // Test constructor with DataType argument. { - TRT_ShapedWeights weights(DT_FLOAT); + TRT_ShapedWeights weights(nvinfer1::DataType::kFLOAT); TRT_ShapedWeights copy(weights); for (auto ptr : {&weights, ©}) { nvinfer1::Weights trt_weights = ptr->GetTrtWeights(); @@ -310,7 +310,7 @@ TEST(TRT_ShapedWeights_Test, Basic) { { TrtWeightStore store; TRT_ShapedWeights weights = - store.GetTempWeights(DT_FLOAT, GetTestDims({2, 5})); + store.GetTempWeights(nvinfer1::DataType::kFLOAT, GetTestDims({2, 5})); TRT_ShapedWeights copy(weights); for (auto ptr : {&weights, ©}) { nvinfer1::Weights trt_weights = ptr->GetTrtWeights(); @@ -671,7 +671,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { params->outputs->emplace_back(output_tensor); output_tensors.push_back(output_tensor); } - TRT_ShapedWeights output_weights(DT_FLOAT); + TRT_ShapedWeights output_weights(nvinfer1::DataType::kFLOAT); params->outputs->emplace_back(output_weights); return Status::OK(); }; @@ -778,8 +778,8 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { } TEST_F(ConverterTest, PrepareTensorForShape_Weights) { - TRT_ShapedWeights weights = - weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5})); + TRT_ShapedWeights weights = weight_store_->GetTempWeights( + nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5})); nvinfer1::ITensor* output_tensor = nullptr; for (bool validation_only : {false, true}) { TF_EXPECT_OK(converter_->PrepareTensorForShape( @@ -832,8 +832,8 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) { template void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) { - TRT_ShapedWeights weights = - weight_store->GetTempWeights(DataTypeToEnum::v(), GetTestDims({2, 3})); + TRT_ShapedWeights weights = weight_store->GetTempWeights( + TfDataTypeToTrt(DataTypeToEnum::v()), GetTestDims({2, 3})); const std::vector values = {T(3), T(1), T(2), T(6), T(5), T(4)}; memcpy(weights.GetValues(), values.data(), weights.size_bytes()); @@ -1002,14 +1002,14 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { } TEST_F(ConverterTest, CreateConstantLayer) { - for (auto dtype : {DT_FLOAT, DT_INT32}) { + for (auto dtype : {nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT32}) { TRT_ShapedWeights weights = weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5})); nvinfer1::ITensor* tensor = converter_->CreateConstantLayer(weights, GetTestDims({3, 10})); ASSERT_NE(nullptr, tensor); - EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType()) - << "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. actual " + EXPECT_EQ(dtype, tensor->getType()) + << "Expected " << DebugString(dtype) << " vs. actual " << DebugString(tensor->getType()); ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions()); } @@ -1246,7 +1246,7 @@ class OpConverterTest : public ::testing::Test { template void AddTestWeights(const string& name, const std::vector& dims, const std::vector& values) { - const DataType dtype = DataTypeToEnum::v(); + const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum::v()); const nvinfer1::Dims trt_dims = GetTestDims(dims); const int64_t num_elements = TrtWeightDimsNumElements(trt_dims); QCHECK_EQ(num_elements, values.size()) @@ -1452,6 +1452,9 @@ TEST_F(OpConverterTest, ConvertConst) { TestConvertConst(this); TestConvertConst(this); + TestConvertConst(this); + TestConvertConst(this); + TestConvertConst(this); TestConvertConst(this); }