[TF:TRT] Enhance InstantiateBuildAndRun to support the case where the input
type and output type are not the same. This is to prepare for a change to enhance the TF-TRT bridge to support the Cast operations that can be represented via IIdentityLayer. PiperOrigin-RevId: 312077452 Change-Id: Iab6bfb54d6a346eef158785f61a1311559cee855
This commit is contained in:
parent
ea113ef6cd
commit
f40a063d84
@ -1712,7 +1712,7 @@ INSTANTIATE_TEST_CASE_P(
|
||||
|
||||
// Builds and runs the converted network. Checks output tensor shape. Tests
|
||||
// output values using a matcher.
|
||||
template <DataType dtype>
|
||||
template <DataType input_dtype, DataType output_dtype>
|
||||
void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
|
||||
const TestParamBase& p,
|
||||
const std::vector<float>& input_vec,
|
||||
@ -1731,12 +1731,14 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
|
||||
// runtime errors.
|
||||
return;
|
||||
}
|
||||
typedef typename EnumToDataType<dtype>::Type T;
|
||||
typedef typename EnumToDataType<input_dtype>::Type Tin;
|
||||
TensorShape shape;
|
||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape));
|
||||
const DataVec input_data{
|
||||
{"input", test->AsTensor<T>(CastTestVector<float, T>(input_vec), shape)}};
|
||||
DataVec output_data{{name, test->ConstructTensor<T>(6)}};
|
||||
{"input",
|
||||
test->AsTensor<Tin>(CastTestVector<float, Tin>(input_vec), shape)}};
|
||||
typedef typename EnumToDataType<output_dtype>::Type Tout;
|
||||
DataVec output_data{{name, test->ConstructTensor<Tout>(6)}};
|
||||
test->BuildAndRun(input_data, &output_data);
|
||||
// Check the shape of the actual output tensor
|
||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape));
|
||||
@ -1744,7 +1746,7 @@ void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test,
|
||||
<< "Expected shape: " << shape.DebugString() << ", actual shape"
|
||||
<< output_data[0].tensor.shape().DebugString();
|
||||
// Cast the output to float and compare to expected output
|
||||
auto out_span = GetSpanForData<T>(output_data[0]);
|
||||
auto out_span = GetSpanForData<Tout>(output_data[0]);
|
||||
std::vector<float> casted_output(out_span.begin(), out_span.end());
|
||||
EXPECT_THAT(casted_output, matcher);
|
||||
}
|
||||
@ -1754,16 +1756,35 @@ void InstantiateBuildAndRun(DataType tf_dtype, const string& name,
|
||||
const std::vector<float>& input_vec,
|
||||
const Matcher<std::vector<float>>& matcher) {
|
||||
if (tf_dtype == DT_FLOAT) {
|
||||
BuildAndRunConvertedNetwork<DT_FLOAT>(name, test, p, input_vec, matcher);
|
||||
BuildAndRunConvertedNetwork<DT_FLOAT, DT_FLOAT>(name, test, p, input_vec,
|
||||
matcher);
|
||||
} else if (tf_dtype == DT_HALF) {
|
||||
BuildAndRunConvertedNetwork<DT_HALF>(name, test, p, input_vec, matcher);
|
||||
BuildAndRunConvertedNetwork<DT_HALF, DT_HALF>(name, test, p, input_vec,
|
||||
matcher);
|
||||
} else if (tf_dtype == DT_INT32) {
|
||||
BuildAndRunConvertedNetwork<DT_INT32>(name, test, p, input_vec, matcher);
|
||||
BuildAndRunConvertedNetwork<DT_INT32, DT_INT32>(name, test, p, input_vec,
|
||||
matcher);
|
||||
} else {
|
||||
FAIL() << "Test not supported for " << tf_dtype;
|
||||
}
|
||||
}
|
||||
|
||||
void InstantiateBuildAndRun(DataType input_tf_dtype, DataType output_tf_dtype,
|
||||
const string& name, OpConverterTest* test,
|
||||
const TestParamBase& p,
|
||||
const std::vector<float>& input_vec,
|
||||
const Matcher<std::vector<float>>& matcher) {
|
||||
if (input_tf_dtype == output_tf_dtype) {
|
||||
InstantiateBuildAndRun(input_tf_dtype, name, test, p, input_vec, matcher);
|
||||
} else if (input_tf_dtype == DT_HALF && output_tf_dtype) {
|
||||
BuildAndRunConvertedNetwork<DT_HALF, DT_FLOAT>(name, test, p, input_vec,
|
||||
matcher);
|
||||
} else {
|
||||
FAIL() << "Test not supported for input " << input_tf_dtype << " output "
|
||||
<< output_tf_dtype;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
|
||||
out->Clear();
|
||||
|
Loading…
Reference in New Issue
Block a user