diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 450831910f6..253351546fa 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1811,7 +1811,9 @@ class ParameterizedOpConverterTestBase const int batch_size = input_data_[0].tensor.shape().dim_size(0); Status stat = OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size); - ASSERT_EQ(expected_runtime_status, stat); + ASSERT_EQ(expected_runtime_status.ok(), stat.ok()) + << "expected status: " << expected_runtime_status + << ", actual status: " << stat; if (expected_runtime_status.ok() && stat.ok()) { for (int i = 0; i < n_output; i++) { // Check the shape of the actual output tensors @@ -6359,87 +6361,70 @@ NodeDef GetSquaredDifferenceNodeDef(DataType dtype) { return squared_diff.operation.node()->def(); } -template <DataType dtype> -void TestConvertSquaredDifference(OpConverterTest* test) { - typedef typename EnumToDataType<dtype>::Type CType; - - struct TestParams { - std::vector<int> dims_x; - std::vector<int> dims_y; - std::vector<CType> value_x; - std::vector<CType> value_y; - std::vector<int> expected_output_dims; - std::vector<CType> expected_output; - }; - - const std::vector<CType> common_input = InitTestVector<CType>(6); - std::vector<TestParams> params = { - { - /*dims_x=*/{1, 2, 3}, - /*dims_y=*/{1, 2, 3}, - /*value_x=*/common_input, - /*value_y=*/CastTestVector<int, CType>({0, -1, 3, 0, 10, -7}), - /*expected_output_dims=*/{1, 2, 3}, - /*expected_output=*/CastTestVector<int, CType>({0, 4, 1, 9, 36, 144}), - }, - { - /*dims_x=*/{1, 2, 3}, - /*dims_y=*/{1, 1, 3}, - /*value_x=*/common_input, - /*value_y=*/CastTestVector<int, CType>({0, 1, 2}), - /*expected_output_dims=*/{1, 2, 3}, - /*expected_output=*/CastTestVector<int, CType>({0, 0, 0, 9, 9, 9}), - }, - }; - - for (int i = 0; i < params.size(); ++i) { - test->Reset(); - - NodeDef node_def = GetSquaredDifferenceNodeDef(dtype); - test->AddTestTensor("x", params[i].dims_x, 1, 
TfDataTypeToTrt(dtype)); - test->AddTestTensor("y", params[i].dims_y, 1, TfDataTypeToTrt(dtype)); - test->RunValidationAndConversion(node_def); - - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_squared_diff", &output)); - EXPECT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(params[i].expected_output_dims, - output.tensor()->getDimensions()); - - DataVec input_data{{"x", test->AsTensor<CType>(params[i].value_x)}, - {"y", test->AsTensor<CType>(params[i].value_y)}}; - DataVec output_data{ - {"my_squared_diff", - test->ConstructTensor<CType>(params[i].expected_output.size())}}; - TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); - EXPECT_THAT(GetSpanForData<CType>(output_data[0]), - ElementsAreArray(params[i].expected_output)); - } -} - -TEST_F(OpConverterTest, ConvertSquaredDifference) { +TEST_P(OpConverterTest2, ConvertSquaredDifference) { { // Input is a weight, should fail. Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); AddTestWeights<float>("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); - AddTestTensor("y", {1, 2, 3}); + AddTestTensor("y", {1, 1, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "The input \"x\" for SquaredDifference must be " "a tensor, at my_squared_diff"); } - { - // Shapes are not broadcastable, should fail. 
- Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT); - AddTestTensor("x", {2, 3}); - AddTestTensor("y", {7, 5}); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Infeasible broadcast scheme"); - } - TestConvertSquaredDifference<DT_FLOAT>(this); - TestConvertSquaredDifference<DT_HALF>(this); + struct TestParams { + std::vector<int> dims_x; + std::vector<int> dims_y; + std::vector<float> value_x; + std::vector<float> value_y; + std::vector<int> expected_output_dims; + std::vector<float> expected_output; + Status status; + Status runtime_status; + }; + + const std::vector<float> common_input = InitTestVector<float>(6); + std::vector<TestParams> params = { + {/*dims_x=*/{1, 2, 3}, + /*dims_y=*/{1, 7, 5}, + /*value_x=*/common_input, + /*value_y=*/std::vector<float>(7 * 5, 0), + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/common_input, + trt_mode == TrtTestMode::kDynamicShape + ? Status::OK() + : errors::InvalidArgument("Infeasible broadcast scheme"), + errors::Internal( + "Binding index out of range. This can happen if profile is not set, " + "or the network is invalid for the current profile.")}, + { + /*dims_x=*/{1, 1, 2, 3}, + /*dims_y=*/{1, 1, 2, 3}, + /*value_x=*/common_input, + /*value_y=*/{0, -1, 3, 0, 10, -7}, + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/{0, 4, 1, 9, 36, 144}, + }, + { + /*dims_x=*/{1, 1, 2, 3}, + /*dims_y=*/{1, 1, 1, 3}, + /*value_x=*/common_input, + /*value_y=*/{0, 1, 2}, + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/{0, 0, 0, 9, 9, 9}, + }, + }; + + for (auto p : params) { + Reset(); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); + AddTestTensor("x", p.dims_x, p.value_x); + AddTestTensor("y", p.dims_y, p.value_y); + TestOpConverter("my_squared_diff", node_def, p.expected_output_dims, + p.status, p.runtime_status, + ElementsAreArray(p.expected_output)); + } } #if IS_TRT_VERSION_GE(6, 0, 0, 0) diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc index 
213c1732e59..ed997b267b1 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc @@ -46,6 +46,14 @@ Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine, // Get dims from context instead of engine in explicit batch mode because // the engine might have dynamic shapes. dims = execution_context->getBindingDimensions(binding_index); + if (dims.nbDims == -1) { + // Invalid dimensions. There can be multiple reasons for this. If we have + // incompatible input shapes (network invalid for the current profile) + // that can trigger this error. + return errors::Internal( + "Binding index out of range. This can happen if profile is not set, " + "or the network is invalid for the current profile."); + } #else return errors::Internal( "Explicit batch mode is only supported with TensorRT 6 and above.");