Merge pull request #27208 from trevor-m:tmorris_tftrt_nms_test_bug_and_compile_error
PiperOrigin-RevId: 240638262
This commit is contained in:
commit
bb44f3384f
@ -214,8 +214,8 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
|
||||
return trt_dims;
|
||||
}
|
||||
|
||||
Status TensorShapeArrayToTrtDims(const std::vector<int>& shape,
|
||||
nvinfer1::Dims* out,
|
||||
template <typename Container>
|
||||
Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
|
||||
bool ignore_first_dim = false) {
|
||||
PartialTensorShape tensor_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape));
|
||||
@ -2354,7 +2354,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params,
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice(
|
||||
input.tensor(), begin_dims, size_dims, stride_dims);
|
||||
*input.tensor(), begin_dims, size_dims, stride_dims);
|
||||
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
||||
return Status::OK();
|
||||
#else
|
||||
|
@ -2427,7 +2427,7 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) {
|
||||
ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4));
|
||||
EXPECT_THAT(GetSpanForData<float>(output_data[1]), ElementsAre(0.7, 0.4));
|
||||
EXPECT_THAT(GetSpanForData<float>(output_data[2]), ElementsAre(1, 0));
|
||||
EXPECT_THAT(GetSpanForData<float>(output_data[3]), ElementsAre(2));
|
||||
EXPECT_THAT(GetSpanForData<int32>(output_data[3]), ElementsAre(2));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user