Merge pull request #39758 from tfeher:trt_squared_diff_dynamic_shape
PiperOrigin-RevId: 316897906 Change-Id: Iedea280cdf4b24e058c03b2e6d52997c87e3fd77
This commit is contained in:
commit
78ecbb0481
@ -1811,7 +1811,9 @@ class ParameterizedOpConverterTestBase
|
|||||||
const int batch_size = input_data_[0].tensor.shape().dim_size(0);
|
const int batch_size = input_data_[0].tensor.shape().dim_size(0);
|
||||||
Status stat =
|
Status stat =
|
||||||
OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size);
|
OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size);
|
||||||
ASSERT_EQ(expected_runtime_status, stat);
|
ASSERT_EQ(expected_runtime_status.ok(), stat.ok())
|
||||||
|
<< "expected status: " << expected_runtime_status
|
||||||
|
<< ", actual status: " << stat;
|
||||||
if (expected_runtime_status.ok() && stat.ok()) {
|
if (expected_runtime_status.ok() && stat.ok()) {
|
||||||
for (int i = 0; i < n_output; i++) {
|
for (int i = 0; i < n_output; i++) {
|
||||||
// Check the shape of the actual output tensors
|
// Check the shape of the actual output tensors
|
||||||
@ -6359,87 +6361,70 @@ NodeDef GetSquaredDifferenceNodeDef(DataType dtype) {
|
|||||||
return squared_diff.operation.node()->def();
|
return squared_diff.operation.node()->def();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <DataType dtype>
|
TEST_P(OpConverterTest2, ConvertSquaredDifference) {
|
||||||
void TestConvertSquaredDifference(OpConverterTest* test) {
|
|
||||||
typedef typename EnumToDataType<dtype>::Type CType;
|
|
||||||
|
|
||||||
struct TestParams {
|
|
||||||
std::vector<int> dims_x;
|
|
||||||
std::vector<int> dims_y;
|
|
||||||
std::vector<CType> value_x;
|
|
||||||
std::vector<CType> value_y;
|
|
||||||
std::vector<int> expected_output_dims;
|
|
||||||
std::vector<CType> expected_output;
|
|
||||||
};
|
|
||||||
|
|
||||||
const std::vector<CType> common_input = InitTestVector<CType>(6);
|
|
||||||
std::vector<TestParams> params = {
|
|
||||||
{
|
|
||||||
/*dims_x=*/{1, 2, 3},
|
|
||||||
/*dims_y=*/{1, 2, 3},
|
|
||||||
/*value_x=*/common_input,
|
|
||||||
/*value_y=*/CastTestVector<int, CType>({0, -1, 3, 0, 10, -7}),
|
|
||||||
/*expected_output_dims=*/{1, 2, 3},
|
|
||||||
/*expected_output=*/CastTestVector<int, CType>({0, 4, 1, 9, 36, 144}),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
/*dims_x=*/{1, 2, 3},
|
|
||||||
/*dims_y=*/{1, 1, 3},
|
|
||||||
/*value_x=*/common_input,
|
|
||||||
/*value_y=*/CastTestVector<int, CType>({0, 1, 2}),
|
|
||||||
/*expected_output_dims=*/{1, 2, 3},
|
|
||||||
/*expected_output=*/CastTestVector<int, CType>({0, 0, 0, 9, 9, 9}),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
for (int i = 0; i < params.size(); ++i) {
|
|
||||||
test->Reset();
|
|
||||||
|
|
||||||
NodeDef node_def = GetSquaredDifferenceNodeDef(dtype);
|
|
||||||
test->AddTestTensor("x", params[i].dims_x, 1, TfDataTypeToTrt(dtype));
|
|
||||||
test->AddTestTensor("y", params[i].dims_y, 1, TfDataTypeToTrt(dtype));
|
|
||||||
test->RunValidationAndConversion(node_def);
|
|
||||||
|
|
||||||
TRT_TensorOrWeights output;
|
|
||||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_squared_diff", &output));
|
|
||||||
EXPECT_TRUE(output.is_tensor());
|
|
||||||
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
|
|
||||||
output.tensor()->getDimensions());
|
|
||||||
|
|
||||||
DataVec input_data{{"x", test->AsTensor<CType>(params[i].value_x)},
|
|
||||||
{"y", test->AsTensor<CType>(params[i].value_y)}};
|
|
||||||
DataVec output_data{
|
|
||||||
{"my_squared_diff",
|
|
||||||
test->ConstructTensor<CType>(params[i].expected_output.size())}};
|
|
||||||
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
|
|
||||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
|
||||||
ElementsAreArray(params[i].expected_output));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(OpConverterTest, ConvertSquaredDifference) {
|
|
||||||
{
|
{
|
||||||
// Input is a weight, should fail.
|
// Input is a weight, should fail.
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT);
|
NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type);
|
||||||
AddTestWeights<float>("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
|
AddTestWeights<float>("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
|
||||||
AddTestTensor("y", {1, 2, 3});
|
AddTestTensor("y", {1, 1, 2, 3});
|
||||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||||
"The input \"x\" for SquaredDifference must be "
|
"The input \"x\" for SquaredDifference must be "
|
||||||
"a tensor, at my_squared_diff");
|
"a tensor, at my_squared_diff");
|
||||||
}
|
}
|
||||||
{
|
|
||||||
// Shapes are not broadcastable, should fail.
|
|
||||||
Reset();
|
|
||||||
NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT);
|
|
||||||
AddTestTensor("x", {2, 3});
|
|
||||||
AddTestTensor("y", {7, 5});
|
|
||||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
|
||||||
"Infeasible broadcast scheme");
|
|
||||||
}
|
|
||||||
|
|
||||||
TestConvertSquaredDifference<DT_FLOAT>(this);
|
struct TestParams {
|
||||||
TestConvertSquaredDifference<DT_HALF>(this);
|
std::vector<int> dims_x;
|
||||||
|
std::vector<int> dims_y;
|
||||||
|
std::vector<float> value_x;
|
||||||
|
std::vector<float> value_y;
|
||||||
|
std::vector<int> expected_output_dims;
|
||||||
|
std::vector<float> expected_output;
|
||||||
|
Status status;
|
||||||
|
Status runtime_status;
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<float> common_input = InitTestVector<float>(6);
|
||||||
|
std::vector<TestParams> params = {
|
||||||
|
{/*dims_x=*/{1, 2, 3},
|
||||||
|
/*dims_y=*/{1, 7, 5},
|
||||||
|
/*value_x=*/common_input,
|
||||||
|
/*value_y=*/std::vector<float>(7 * 5, 0),
|
||||||
|
/*expected_output_dims=*/{1, 1, 2, 3},
|
||||||
|
/*expected_output=*/common_input,
|
||||||
|
trt_mode == TrtTestMode::kDynamicShape
|
||||||
|
? Status::OK()
|
||||||
|
: errors::InvalidArgument("Infeasible broadcast scheme"),
|
||||||
|
errors::Internal(
|
||||||
|
"Binding index out of range. This can happen if profile is not set, "
|
||||||
|
"or the network is invalid for the current profile.")},
|
||||||
|
{
|
||||||
|
/*dims_x=*/{1, 1, 2, 3},
|
||||||
|
/*dims_y=*/{1, 1, 2, 3},
|
||||||
|
/*value_x=*/common_input,
|
||||||
|
/*value_y=*/{0, -1, 3, 0, 10, -7},
|
||||||
|
/*expected_output_dims=*/{1, 1, 2, 3},
|
||||||
|
/*expected_output=*/{0, 4, 1, 9, 36, 144},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
/*dims_x=*/{1, 1, 2, 3},
|
||||||
|
/*dims_y=*/{1, 1, 1, 3},
|
||||||
|
/*value_x=*/common_input,
|
||||||
|
/*value_y=*/{0, 1, 2},
|
||||||
|
/*expected_output_dims=*/{1, 1, 2, 3},
|
||||||
|
/*expected_output=*/{0, 0, 0, 9, 9, 9},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto p : params) {
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type);
|
||||||
|
AddTestTensor("x", p.dims_x, p.value_x);
|
||||||
|
AddTestTensor("y", p.dims_y, p.value_y);
|
||||||
|
TestOpConverter("my_squared_diff", node_def, p.expected_output_dims,
|
||||||
|
p.status, p.runtime_status,
|
||||||
|
ElementsAreArray(p.expected_output));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
|
||||||
|
@ -46,6 +46,14 @@ Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine,
|
|||||||
// Get dims from context instead of engine in explicit batch mode because
|
// Get dims from context instead of engine in explicit batch mode because
|
||||||
// the engine might have dynamic shapes.
|
// the engine might have dynamic shapes.
|
||||||
dims = execution_context->getBindingDimensions(binding_index);
|
dims = execution_context->getBindingDimensions(binding_index);
|
||||||
|
if (dims.nbDims == -1) {
|
||||||
|
// Invalid dimensions. There can be multiple reasons for this. If we have
|
||||||
|
// incompatible input shapes (network invalid for the current profile)
|
||||||
|
// that can trigger this error.
|
||||||
|
return errors::Internal(
|
||||||
|
"Binding index out of range. This can happen if profile is not set, "
|
||||||
|
"or the network is invalid for the current profile.");
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
return errors::Internal(
|
return errors::Internal(
|
||||||
"Explicit batch mode is only supported with TensorRT 6 and above.");
|
"Explicit batch mode is only supported with TensorRT 6 and above.");
|
||||||
|
Loading…
x
Reference in New Issue
Block a user