[TF:TRT] Enhance InstantiateBuildAndRun to support the case where the input

type and output type are not the same.

This is to prepare for a change to enhance the TF-TRT bridge to support the Cast
operations that can be represented via IIdentityLayer.

PiperOrigin-RevId: 312077452
Change-Id: Iab6bfb54d6a346eef158785f61a1311559cee855
This commit is contained in:
Bixia Zheng 2020-05-18 07:46:08 -07:00 committed by TensorFlower Gardener
parent ea113ef6cd
commit f40a063d84

View File

@ -1712,7 +1712,7 @@ INSTANTIATE_TEST_CASE_P(
// Builds and runs the converted network. Checks output tensor shape. Tests // Builds and runs the converted network. Checks output tensor shape. Tests
// output values using a matcher. // output values using a matcher.
template <DataType dtype> template <DataType input_dtype, DataType output_dtype>
void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test, void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
const TestParamBase& p, const TestParamBase& p,
const std::vector<float>& input_vec, const std::vector<float>& input_vec,
@ -1731,12 +1731,14 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
// runtime errors. // runtime errors.
return; return;
} }
typedef typename EnumToDataType<dtype>::Type T; typedef typename EnumToDataType<input_dtype>::Type Tin;
TensorShape shape; TensorShape shape;
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape)); TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape));
const DataVec input_data{ const DataVec input_data{
{"input", test->AsTensor<T>(CastTestVector<float, T>(input_vec), shape)}}; {"input",
DataVec output_data{{name, test->ConstructTensor<T>(6)}}; test->AsTensor<Tin>(CastTestVector<float, Tin>(input_vec), shape)}};
typedef typename EnumToDataType<output_dtype>::Type Tout;
DataVec output_data{{name, test->ConstructTensor<Tout>(6)}};
test->BuildAndRun(input_data, &output_data); test->BuildAndRun(input_data, &output_data);
// Check the shape of the actual output tensor // Check the shape of the actual output tensor
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape)); TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape));
@ -1744,7 +1746,7 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
<< "Expected shape: " << shape.DebugString() << ", actual shape" << "Expected shape: " << shape.DebugString() << ", actual shape"
<< output_data[0].tensor.shape().DebugString(); << output_data[0].tensor.shape().DebugString();
// Cast the output to float and compare to expected output // Cast the output to float and compare to expected output
auto out_span = GetSpanForData<T>(output_data[0]); auto out_span = GetSpanForData<Tout>(output_data[0]);
std::vector<float> casted_output(out_span.begin(), out_span.end()); std::vector<float> casted_output(out_span.begin(), out_span.end());
EXPECT_THAT(casted_output, matcher); EXPECT_THAT(casted_output, matcher);
} }
@ -1754,16 +1756,35 @@ void InstantiateBuildAndRun(DataType tf_dtype, const string& name,
const std::vector<float>& input_vec, const std::vector<float>& input_vec,
const Matcher<std::vector<float>>& matcher) { const Matcher<std::vector<float>>& matcher) {
if (tf_dtype == DT_FLOAT) { if (tf_dtype == DT_FLOAT) {
BuildAndRunConvertedNetwork<DT_FLOAT>(name, test, p, input_vec, matcher); BuildAndRunConvertedNetwork<DT_FLOAT, DT_FLOAT>(name, test, p, input_vec,
matcher);
} else if (tf_dtype == DT_HALF) { } else if (tf_dtype == DT_HALF) {
BuildAndRunConvertedNetwork<DT_HALF>(name, test, p, input_vec, matcher); BuildAndRunConvertedNetwork<DT_HALF, DT_HALF>(name, test, p, input_vec,
matcher);
} else if (tf_dtype == DT_INT32) { } else if (tf_dtype == DT_INT32) {
BuildAndRunConvertedNetwork<DT_INT32>(name, test, p, input_vec, matcher); BuildAndRunConvertedNetwork<DT_INT32, DT_INT32>(name, test, p, input_vec,
matcher);
} else { } else {
FAIL() << "Test not supported for " << tf_dtype; FAIL() << "Test not supported for " << tf_dtype;
} }
} }
void InstantiateBuildAndRun(DataType input_tf_dtype, DataType output_tf_dtype,
const string& name, OpConverterTest* test,
const TestParamBase& p,
const std::vector<float>& input_vec,
const Matcher<std::vector<float>>& matcher) {
if (input_tf_dtype == output_tf_dtype) {
InstantiateBuildAndRun(input_tf_dtype, name, test, p, input_vec, matcher);
} else if (input_tf_dtype == DT_HALF && output_tf_dtype) {
BuildAndRunConvertedNetwork<DT_HALF, DT_FLOAT>(name, test, p, input_vec,
matcher);
} else {
FAIL() << "Test not supported for input " << input_tf_dtype << " output "
<< output_tf_dtype;
}
}
template <typename T> template <typename T>
void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) { void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
out->Clear(); out->Clear();