[TF:TRT] Enhance InstantiateBuildAndRun to support the case where the input
type and output type are not the same. This is to prepare for a change to enhance the TF-TRT bridge to support the Cast operations that can be represented via IIdentityLayer. PiperOrigin-RevId: 312077452 Change-Id: Iab6bfb54d6a346eef158785f61a1311559cee855
This commit is contained in:
parent
ea113ef6cd
commit
f40a063d84
@ -1712,7 +1712,7 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
|
|
||||||
// Builds and runs the converted network. Checks output tensor shape. Tests
|
// Builds and runs the converted network. Checks output tensor shape. Tests
|
||||||
// output values using a matcher.
|
// output values using a matcher.
|
||||||
template <DataType dtype>
|
template <DataType input_dtype, DataType output_dtype>
|
||||||
void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
|
void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
|
||||||
const TestParamBase& p,
|
const TestParamBase& p,
|
||||||
const std::vector<float>& input_vec,
|
const std::vector<float>& input_vec,
|
||||||
@ -1731,12 +1731,14 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
|
|||||||
// runtime errors.
|
// runtime errors.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
typedef typename EnumToDataType<dtype>::Type T;
|
typedef typename EnumToDataType<input_dtype>::Type Tin;
|
||||||
TensorShape shape;
|
TensorShape shape;
|
||||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape));
|
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape));
|
||||||
const DataVec input_data{
|
const DataVec input_data{
|
||||||
{"input", test->AsTensor<T>(CastTestVector<float, T>(input_vec), shape)}};
|
{"input",
|
||||||
DataVec output_data{{name, test->ConstructTensor<T>(6)}};
|
test->AsTensor<Tin>(CastTestVector<float, Tin>(input_vec), shape)}};
|
||||||
|
typedef typename EnumToDataType<output_dtype>::Type Tout;
|
||||||
|
DataVec output_data{{name, test->ConstructTensor<Tout>(6)}};
|
||||||
test->BuildAndRun(input_data, &output_data);
|
test->BuildAndRun(input_data, &output_data);
|
||||||
// Check the shape of the actual output tensor
|
// Check the shape of the actual output tensor
|
||||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape));
|
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape));
|
||||||
@ -1744,7 +1746,7 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
|
|||||||
<< "Expected shape: " << shape.DebugString() << ", actual shape"
|
<< "Expected shape: " << shape.DebugString() << ", actual shape"
|
||||||
<< output_data[0].tensor.shape().DebugString();
|
<< output_data[0].tensor.shape().DebugString();
|
||||||
// Cast the output to float and compare to expected output
|
// Cast the output to float and compare to expected output
|
||||||
auto out_span = GetSpanForData<T>(output_data[0]);
|
auto out_span = GetSpanForData<Tout>(output_data[0]);
|
||||||
std::vector<float> casted_output(out_span.begin(), out_span.end());
|
std::vector<float> casted_output(out_span.begin(), out_span.end());
|
||||||
EXPECT_THAT(casted_output, matcher);
|
EXPECT_THAT(casted_output, matcher);
|
||||||
}
|
}
|
||||||
@ -1754,16 +1756,35 @@ void InstantiateBuildAndRun(DataType tf_dtype, const string& name,
|
|||||||
const std::vector<float>& input_vec,
|
const std::vector<float>& input_vec,
|
||||||
const Matcher<std::vector<float>>& matcher) {
|
const Matcher<std::vector<float>>& matcher) {
|
||||||
if (tf_dtype == DT_FLOAT) {
|
if (tf_dtype == DT_FLOAT) {
|
||||||
BuildAndRunConvertedNetwork<DT_FLOAT>(name, test, p, input_vec, matcher);
|
BuildAndRunConvertedNetwork<DT_FLOAT, DT_FLOAT>(name, test, p, input_vec,
|
||||||
|
matcher);
|
||||||
} else if (tf_dtype == DT_HALF) {
|
} else if (tf_dtype == DT_HALF) {
|
||||||
BuildAndRunConvertedNetwork<DT_HALF>(name, test, p, input_vec, matcher);
|
BuildAndRunConvertedNetwork<DT_HALF, DT_HALF>(name, test, p, input_vec,
|
||||||
|
matcher);
|
||||||
} else if (tf_dtype == DT_INT32) {
|
} else if (tf_dtype == DT_INT32) {
|
||||||
BuildAndRunConvertedNetwork<DT_INT32>(name, test, p, input_vec, matcher);
|
BuildAndRunConvertedNetwork<DT_INT32, DT_INT32>(name, test, p, input_vec,
|
||||||
|
matcher);
|
||||||
} else {
|
} else {
|
||||||
FAIL() << "Test not supported for " << tf_dtype;
|
FAIL() << "Test not supported for " << tf_dtype;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void InstantiateBuildAndRun(DataType input_tf_dtype, DataType output_tf_dtype,
|
||||||
|
const string& name, OpConverterTest* test,
|
||||||
|
const TestParamBase& p,
|
||||||
|
const std::vector<float>& input_vec,
|
||||||
|
const Matcher<std::vector<float>>& matcher) {
|
||||||
|
if (input_tf_dtype == output_tf_dtype) {
|
||||||
|
InstantiateBuildAndRun(input_tf_dtype, name, test, p, input_vec, matcher);
|
||||||
|
} else if (input_tf_dtype == DT_HALF && output_tf_dtype) {
|
||||||
|
BuildAndRunConvertedNetwork<DT_HALF, DT_FLOAT>(name, test, p, input_vec,
|
||||||
|
matcher);
|
||||||
|
} else {
|
||||||
|
FAIL() << "Test not supported for input " << input_tf_dtype << " output "
|
||||||
|
<< output_tf_dtype;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
|
void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
|
||||||
out->Clear();
|
out->Clear();
|
||||||
|
Loading…
Reference in New Issue
Block a user