Merge pull request #27208 from trevor-m:tmorris_tftrt_nms_test_bug_and_compile_error

PiperOrigin-RevId: 240638262
This commit is contained in:
TensorFlower Gardener 2019-03-27 14:29:22 -07:00
commit bb44f3384f
2 changed files with 4 additions and 4 deletions

View File

@ -214,8 +214,8 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
return trt_dims; return trt_dims;
} }
Status TensorShapeArrayToTrtDims(const std::vector<int>& shape, template <typename Container>
nvinfer1::Dims* out, Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
bool ignore_first_dim = false) { bool ignore_first_dim = false) {
PartialTensorShape tensor_shape; PartialTensorShape tensor_shape;
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape)); TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape));
@ -2354,7 +2354,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params,
if (params->validation_only) return Status::OK(); if (params->validation_only) return Status::OK();
nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice( nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice(
input.tensor(), begin_dims, size_dims, stride_dims); *input.tensor(), begin_dims, size_dims, stride_dims);
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
return Status::OK(); return Status::OK();
#else #else

View File

@ -2427,7 +2427,7 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) {
ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4)); ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4));
EXPECT_THAT(GetSpanForData<float>(output_data[1]), ElementsAre(0.7, 0.4)); EXPECT_THAT(GetSpanForData<float>(output_data[1]), ElementsAre(0.7, 0.4));
EXPECT_THAT(GetSpanForData<float>(output_data[2]), ElementsAre(1, 0)); EXPECT_THAT(GetSpanForData<float>(output_data[2]), ElementsAre(1, 0));
EXPECT_THAT(GetSpanForData<float>(output_data[3]), ElementsAre(2)); EXPECT_THAT(GetSpanForData<int32>(output_data[3]), ElementsAre(2));
} }
} }