diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 4bc61240a7e..a72702ea8ca 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -214,7 +214,8 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, return trt_dims; } -Status TensorShapeArrayToTrtDims(const std::vector& shape, +template +Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out, bool ignore_first_dim = false) { PartialTensorShape tensor_shape; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index ab5dadcd6b1..4599d0c168a 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -2427,7 +2427,7 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4)); EXPECT_THAT(GetSpanForData(output_data[1]), ElementsAre(0.7, 0.4)); EXPECT_THAT(GetSpanForData(output_data[2]), ElementsAre(1, 0)); - EXPECT_THAT(GetSpanForData(output_data[3]), ElementsAre(2)); + EXPECT_THAT(GetSpanForData(output_data[3]), ElementsAre(2)); } }