Merge pull request #39758 from tfeher:trt_squared_diff_dynamic_shape
PiperOrigin-RevId: 316897906 Change-Id: Iedea280cdf4b24e058c03b2e6d52997c87e3fd77
This commit is contained in:
commit
78ecbb0481
@ -1811,7 +1811,9 @@ class ParameterizedOpConverterTestBase
|
||||
const int batch_size = input_data_[0].tensor.shape().dim_size(0);
|
||||
Status stat =
|
||||
OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size);
|
||||
ASSERT_EQ(expected_runtime_status, stat);
|
||||
ASSERT_EQ(expected_runtime_status.ok(), stat.ok())
|
||||
<< "expected status: " << expected_runtime_status
|
||||
<< ", actual status: " << stat;
|
||||
if (expected_runtime_status.ok() && stat.ok()) {
|
||||
for (int i = 0; i < n_output; i++) {
|
||||
// Check the shape of the actual output tensors
|
||||
@ -6359,87 +6361,70 @@ NodeDef GetSquaredDifferenceNodeDef(DataType dtype) {
|
||||
return squared_diff.operation.node()->def();
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestConvertSquaredDifference(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
|
||||
struct TestParams {
|
||||
std::vector<int> dims_x;
|
||||
std::vector<int> dims_y;
|
||||
std::vector<CType> value_x;
|
||||
std::vector<CType> value_y;
|
||||
std::vector<int> expected_output_dims;
|
||||
std::vector<CType> expected_output;
|
||||
};
|
||||
|
||||
const std::vector<CType> common_input = InitTestVector<CType>(6);
|
||||
std::vector<TestParams> params = {
|
||||
{
|
||||
/*dims_x=*/{1, 2, 3},
|
||||
/*dims_y=*/{1, 2, 3},
|
||||
/*value_x=*/common_input,
|
||||
/*value_y=*/CastTestVector<int, CType>({0, -1, 3, 0, 10, -7}),
|
||||
/*expected_output_dims=*/{1, 2, 3},
|
||||
/*expected_output=*/CastTestVector<int, CType>({0, 4, 1, 9, 36, 144}),
|
||||
},
|
||||
{
|
||||
/*dims_x=*/{1, 2, 3},
|
||||
/*dims_y=*/{1, 1, 3},
|
||||
/*value_x=*/common_input,
|
||||
/*value_y=*/CastTestVector<int, CType>({0, 1, 2}),
|
||||
/*expected_output_dims=*/{1, 2, 3},
|
||||
/*expected_output=*/CastTestVector<int, CType>({0, 0, 0, 9, 9, 9}),
|
||||
},
|
||||
};
|
||||
|
||||
for (int i = 0; i < params.size(); ++i) {
|
||||
test->Reset();
|
||||
|
||||
NodeDef node_def = GetSquaredDifferenceNodeDef(dtype);
|
||||
test->AddTestTensor("x", params[i].dims_x, 1, TfDataTypeToTrt(dtype));
|
||||
test->AddTestTensor("y", params[i].dims_y, 1, TfDataTypeToTrt(dtype));
|
||||
test->RunValidationAndConversion(node_def);
|
||||
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_squared_diff", &output));
|
||||
EXPECT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
|
||||
output.tensor()->getDimensions());
|
||||
|
||||
DataVec input_data{{"x", test->AsTensor<CType>(params[i].value_x)},
|
||||
{"y", test->AsTensor<CType>(params[i].value_y)}};
|
||||
DataVec output_data{
|
||||
{"my_squared_diff",
|
||||
test->ConstructTensor<CType>(params[i].expected_output.size())}};
|
||||
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(params[i].expected_output));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertSquaredDifference) {
|
||||
TEST_P(OpConverterTest2, ConvertSquaredDifference) {
|
||||
{
|
||||
// Input is a weight, should fail.
|
||||
Reset();
|
||||
NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT);
|
||||
NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type);
|
||||
AddTestWeights<float>("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
|
||||
AddTestTensor("y", {1, 2, 3});
|
||||
AddTestTensor("y", {1, 1, 2, 3});
|
||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||
"The input \"x\" for SquaredDifference must be "
|
||||
"a tensor, at my_squared_diff");
|
||||
}
|
||||
{
|
||||
// Shapes are not broadcastable, should fail.
|
||||
Reset();
|
||||
NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT);
|
||||
AddTestTensor("x", {2, 3});
|
||||
AddTestTensor("y", {7, 5});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Infeasible broadcast scheme");
|
||||
}
|
||||
|
||||
TestConvertSquaredDifference<DT_FLOAT>(this);
|
||||
TestConvertSquaredDifference<DT_HALF>(this);
|
||||
struct TestParams {
|
||||
std::vector<int> dims_x;
|
||||
std::vector<int> dims_y;
|
||||
std::vector<float> value_x;
|
||||
std::vector<float> value_y;
|
||||
std::vector<int> expected_output_dims;
|
||||
std::vector<float> expected_output;
|
||||
Status status;
|
||||
Status runtime_status;
|
||||
};
|
||||
|
||||
const std::vector<float> common_input = InitTestVector<float>(6);
|
||||
std::vector<TestParams> params = {
|
||||
{/*dims_x=*/{1, 2, 3},
|
||||
/*dims_y=*/{1, 7, 5},
|
||||
/*value_x=*/common_input,
|
||||
/*value_y=*/std::vector<float>(7 * 5, 0),
|
||||
/*expected_output_dims=*/{1, 1, 2, 3},
|
||||
/*expected_output=*/common_input,
|
||||
trt_mode == TrtTestMode::kDynamicShape
|
||||
? Status::OK()
|
||||
: errors::InvalidArgument("Infeasible broadcast scheme"),
|
||||
errors::Internal(
|
||||
"Binding index out of range. This can happen if profile is not set, "
|
||||
"or the network is invalid for the current profile.")},
|
||||
{
|
||||
/*dims_x=*/{1, 1, 2, 3},
|
||||
/*dims_y=*/{1, 1, 2, 3},
|
||||
/*value_x=*/common_input,
|
||||
/*value_y=*/{0, -1, 3, 0, 10, -7},
|
||||
/*expected_output_dims=*/{1, 1, 2, 3},
|
||||
/*expected_output=*/{0, 4, 1, 9, 36, 144},
|
||||
},
|
||||
{
|
||||
/*dims_x=*/{1, 1, 2, 3},
|
||||
/*dims_y=*/{1, 1, 1, 3},
|
||||
/*value_x=*/common_input,
|
||||
/*value_y=*/{0, 1, 2},
|
||||
/*expected_output_dims=*/{1, 1, 2, 3},
|
||||
/*expected_output=*/{0, 0, 0, 9, 9, 9},
|
||||
},
|
||||
};
|
||||
|
||||
for (auto p : params) {
|
||||
Reset();
|
||||
NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type);
|
||||
AddTestTensor("x", p.dims_x, p.value_x);
|
||||
AddTestTensor("y", p.dims_y, p.value_y);
|
||||
TestOpConverter("my_squared_diff", node_def, p.expected_output_dims,
|
||||
p.status, p.runtime_status,
|
||||
ElementsAreArray(p.expected_output));
|
||||
}
|
||||
}
|
||||
|
||||
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
||||
|
@ -46,6 +46,14 @@ Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine,
|
||||
// Get dims from context instead of engine in explicit batch mode because
|
||||
// the engine might have dynamic shapes.
|
||||
dims = execution_context->getBindingDimensions(binding_index);
|
||||
if (dims.nbDims == -1) {
|
||||
// Invalid dimensions. There can be multiple reasons for this. If we have
|
||||
// incompatible input shapes (network invalid for the current profile)
|
||||
// that can trigger this error.
|
||||
return errors::Internal(
|
||||
"Binding index out of range. This can happen if profile is not set, "
|
||||
"or the network is invalid for the current profile.");
|
||||
}
|
||||
#else
|
||||
return errors::Internal(
|
||||
"Explicit batch mode is only supported with TensorRT 6 and above.");
|
||||
|
Loading…
x
Reference in New Issue
Block a user