Merge pull request #27208 from trevor-m:tmorris_tftrt_nms_test_bug_and_compile_error
PiperOrigin-RevId: 240638262
This commit is contained in:
commit
bb44f3384f
@ -214,8 +214,8 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
|
|||||||
return trt_dims;
|
return trt_dims;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TensorShapeArrayToTrtDims(const std::vector<int>& shape,
|
template <typename Container>
|
||||||
nvinfer1::Dims* out,
|
Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
|
||||||
bool ignore_first_dim = false) {
|
bool ignore_first_dim = false) {
|
||||||
PartialTensorShape tensor_shape;
|
PartialTensorShape tensor_shape;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape));
|
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape));
|
||||||
@ -2354,7 +2354,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params,
|
|||||||
if (params->validation_only) return Status::OK();
|
if (params->validation_only) return Status::OK();
|
||||||
|
|
||||||
nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice(
|
nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice(
|
||||||
input.tensor(), begin_dims, size_dims, stride_dims);
|
*input.tensor(), begin_dims, size_dims, stride_dims);
|
||||||
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
#else
|
#else
|
||||||
|
@ -2427,7 +2427,7 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) {
|
|||||||
ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4));
|
ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4));
|
||||||
EXPECT_THAT(GetSpanForData<float>(output_data[1]), ElementsAre(0.7, 0.4));
|
EXPECT_THAT(GetSpanForData<float>(output_data[1]), ElementsAre(0.7, 0.4));
|
||||||
EXPECT_THAT(GetSpanForData<float>(output_data[2]), ElementsAre(1, 0));
|
EXPECT_THAT(GetSpanForData<float>(output_data[2]), ElementsAre(1, 0));
|
||||||
EXPECT_THAT(GetSpanForData<float>(output_data[3]), ElementsAre(2));
|
EXPECT_THAT(GetSpanForData<int32>(output_data[3]), ElementsAre(2));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user