diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 884ed7a5771..82c02c17e93 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1712,7 +1712,7 @@ INSTANTIATE_TEST_CASE_P( // Builds and runs the converted network. Checks output tensor shape. Tests // output values using a matcher. -template +template void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test, const TestParamBase& p, const std::vector& input_vec, @@ -1731,12 +1731,14 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test, // runtime errors. return; } - typedef typename EnumToDataType::Type T; + typedef typename EnumToDataType::Type Tin; TensorShape shape; TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape)); const DataVec input_data{ - {"input", test->AsTensor(CastTestVector(input_vec), shape)}}; - DataVec output_data{{name, test->ConstructTensor(6)}}; + {"input", + test->AsTensor(CastTestVector(input_vec), shape)}}; + typedef typename EnumToDataType::Type Tout; + DataVec output_data{{name, test->ConstructTensor(6)}}; test->BuildAndRun(input_data, &output_data); // Check the shape of the actual output tensor TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape)); @@ -1744,7 +1746,7 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test, << "Expected shape: " << shape.DebugString() << ", actual shape" << output_data[0].tensor.shape().DebugString(); // Cast the output to float and compare to expected output - auto out_span = GetSpanForData(output_data[0]); + auto out_span = GetSpanForData(output_data[0]); std::vector casted_output(out_span.begin(), out_span.end()); EXPECT_THAT(casted_output, matcher); } @@ -1754,16 +1756,35 @@ void InstantiateBuildAndRun(DataType tf_dtype, const string& name, const std::vector& input_vec, const Matcher>& matcher) { if (tf_dtype == DT_FLOAT) { - BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); + BuildAndRunConvertedNetwork(name, test, p, input_vec, + matcher); } else if (tf_dtype == DT_HALF) { - BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); + BuildAndRunConvertedNetwork(name, test, p, input_vec, + matcher); } else if (tf_dtype == DT_INT32) { - BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); + BuildAndRunConvertedNetwork(name, test, p, input_vec, + matcher); } else { FAIL() << "Test not supported for " << tf_dtype; } } +void InstantiateBuildAndRun(DataType input_tf_dtype, DataType output_tf_dtype, + const string& name, OpConverterTest* test, + const TestParamBase& p, + const std::vector& input_vec, + const Matcher>& matcher) { + if (input_tf_dtype == output_tf_dtype) { + InstantiateBuildAndRun(input_tf_dtype, name, test, p, input_vec, matcher); + } else if (input_tf_dtype == DT_HALF && output_tf_dtype) { + BuildAndRunConvertedNetwork(name, test, p, input_vec, + matcher); + } else { + FAIL() << "Test not supported for input " << input_tf_dtype << " output " + << output_tf_dtype; + } +} + template void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField* out) { out->Clear();