Merge pull request #39990 from tfeher:trt_shape_op
PiperOrigin-RevId: 325245809 Change-Id: I7039b08edc3f5a639e8580cb63ff3c0ef1c95b36
This commit is contained in:
commit
3f24d131d7
@ -2410,6 +2410,40 @@ Status ConvertTranspose(OpConverterParams* params) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Converts a TF "Shape" node.
//
// Behavior, in order:
//  - Rejects implicit batch mode with errors::Unimplemented.
//  - If the input's TRT dims are fully static, emits a constant layer holding
//    the input dims as a 1-D int32 tensor of length `nbDims`.
//  - Otherwise, when built against TensorRT >= 6, adds an IShapeLayer on the
//    input tensor; with older TensorRT returns errors::Unavailable.
//
// In validation-only mode no layers are created; the function returns OK as
// soon as the checks relevant to the chosen path have passed.
Status ConvertShape(OpConverterParams* params) {
  const auto& inputs = params->inputs;
  TF_RETURN_IF_ERROR(
      CheckInputsWeights(*params, {{"input", TrtInputArg::kBoth}}));
  if (params->use_implicit_batch) {
    return errors::Unimplemented(
        "Shape is only supported for explicit batch mode.");
  }
  const nvinfer1::Dims input_dims = inputs.at(0).GetTrtDims();
  if (HasStaticShape(input_dims)) {
    if (params->validation_only) return Status::OK();
    // The output is a rank-1 tensor with one element per input dimension.
    nvinfer1::Dims output_dims{1, {input_dims.nbDims}};
    // Create a const node with the values of input_dims.
    TRT_ShapedWeights weight = params->weight_store->GetTempWeights(
        nvinfer1::DataType::kINT32, output_dims);
    int32* values_ptr = static_cast<int32*>(weight.GetValues());
    std::copy(input_dims.d, input_dims.d + input_dims.nbDims, values_ptr);
    auto output = params->converter->CreateConstantLayer(weight, output_dims);
    // Mirror the null check done for the shape layer below, so a failed layer
    // creation surfaces as an error instead of a null output tensor.
    TFTRT_RETURN_ERROR_IF_NULLPTR(output, params->node_def.name());
    params->outputs->push_back(TRT_TensorOrWeights(output));
    return Status::OK();
  }
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  if (params->validation_only) return Status::OK();
  // Dynamic shape: let TensorRT compute the shape at runtime.
  nvinfer1::IShapeLayer* shape_layer =
      params->converter->network()->addShape(*inputs.at(0).tensor());
  TFTRT_RETURN_ERROR_IF_NULLPTR(shape_layer, params->node_def.name());
  params->outputs->push_back(TRT_TensorOrWeights(shape_layer->getOutput(0)));
  return Status::OK();
#else
  return errors::Unavailable(
      "Shape op conversion requires TensorRT 6 or above");
#endif
}
|
||||
|
||||
Status ConvertReshape(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -5974,6 +6008,7 @@ static void RegisterValidatableOpConverters(
|
||||
(*registration)[pool_op_type] = ConvertPool3D;
|
||||
}
|
||||
#endif
|
||||
(*registration)["Shape"] = ConvertShape;
|
||||
(*registration)["Rsqrt"] = ConvertRsqrt;
|
||||
(*registration)["Slice"] = ConvertSlice;
|
||||
(*registration)["Softmax"] = ConvertSoftmax;
|
||||
|
@ -1785,7 +1785,8 @@ class ParameterizedOpConverterTestBase
|
||||
void BuildAndRun(const string& name,
|
||||
const std::vector<std::vector<int>>& expected_output_dims,
|
||||
const Status& expected_runtime_status,
|
||||
const std::vector<Matcher<std::vector<float>>>& matcher) {
|
||||
const std::vector<Matcher<std::vector<float>>>& matcher,
|
||||
const std::vector<DataType>& out_tf_types = {}) {
|
||||
TensorShape shape;
|
||||
const int n_output = expected_output_dims.size();
|
||||
ASSERT_EQ(n_output, matcher.size());
|
||||
@ -1794,12 +1795,14 @@ class ParameterizedOpConverterTestBase
|
||||
TF_EXPECT_OK(
|
||||
TensorShapeUtils::MakeShape(expected_output_dims[i], &shape));
|
||||
string out_name = (n_output == 1) ? name : StrCat(name, ":", i);
|
||||
InputOutputData data{out_name,
|
||||
ConstructTensor(shape.num_elements(), 0, tf_type)};
|
||||
DataType out_tf_type =
|
||||
out_tf_types.size() > i ? out_tf_types[i] : tf_type;
|
||||
InputOutputData data{
|
||||
out_name, ConstructTensor(shape.num_elements(), 0, out_tf_type)};
|
||||
output_data.push_back(data);
|
||||
}
|
||||
ASSERT_FALSE(input_data_.empty());
|
||||
const int batch_size = input_data_[0].tensor.shape().dim_size(0);
|
||||
const int batch_size =
|
||||
input_data_.empty() ? 1 : input_data_[0].tensor.shape().dim_size(0);
|
||||
Status stat =
|
||||
OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size);
|
||||
ASSERT_EQ(expected_runtime_status.ok(), stat.ok())
|
||||
@ -1824,13 +1827,15 @@ class ParameterizedOpConverterTestBase
|
||||
const std::vector<int>& expected_output_dims,
|
||||
const Status& expected_conversion_status,
|
||||
const Status& expected_runtime_status,
|
||||
const Matcher<std::vector<float>>& matcher) {
|
||||
const Matcher<std::vector<float>>& matcher,
|
||||
const std::vector<DataType>& out_tf_types = {}) {
|
||||
RunValidationAndConversion(node_def, expected_conversion_status,
|
||||
name.c_str(), expected_output_dims);
|
||||
if (expected_conversion_status.ok()) {
|
||||
BuildAndRun(name, std::vector<std::vector<int>>({expected_output_dims}),
|
||||
expected_runtime_status,
|
||||
std::vector<Matcher<std::vector<float>>>({matcher}));
|
||||
std::vector<Matcher<std::vector<float>>>({matcher}),
|
||||
out_tf_types);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2309,6 +2314,52 @@ TEST_F(OpConverterTest, ConvertReshape) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies conversion of the Shape op for both a tensor input and a weight
// input. In implicit batch mode the converter is expected to report
// Unimplemented; otherwise the output must equal the input dims.
TEST_P(OpConverterTest1, ConvertShape) {
  // Get the NodeDef for Shape op.
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
  auto shape = ops::Shape(s.WithOpName("my_shape"), input);
  const NodeDef& node_def = shape.operation.node()->def();

  Status conversion_status =
      (trt_mode == TrtTestMode::kImplicitBatch)
          ? errors::Unimplemented(
                "Shape is only supported for explicit batch mode.")
          : Status::OK();
  std::vector<TestParamBase> test_params = {
      TestParamBase{{1, 2, 3}, {}, {3}, {}, conversion_status},
      // Add input as weight (we use non empty param ({1}) to trigger this).
      TestParamBase{{1, 2, 3}, {}, {3}, {1}, conversion_status},
  };

  // Take the parameter by const reference: it is only inspected, and copying
  // it would duplicate every member for each call.
  auto input_is_weight = [](const TestParamBase& p) {
    return !p.param.empty();
  };
  for (const auto& p : test_params) {
    SCOPED_TRACE(p);
    Reset();
    // The number of elements of the input tensor. We leave it 0 when we do
    // not need to add an input tensor. This happens in explicit batch mode:
    // the shape is known at conversion time and is therefore added to the
    // network as a constant layer. In this case the single-node network used
    // for the unit test has no actual input tensor once it is converted to a
    // TensorRT network.
    int n_elements = 0;
    if (input_is_weight(p) || trt_mode != TrtTestMode::kExplicitBatch) {
      // Calculate the number of elements for adding input data.
      n_elements = std::accumulate(p.input_dims.begin(), p.input_dims.end(), 1,
                                   std::multiplies<int>());
    }
    std::vector<float> input_val(n_elements, 1);
    if (!input_is_weight(p)) {
      AddTestTensor("input", p.input_dims, input_val);
    } else {
      AddTestWeights("input", p.input_dims, input_val, tf_type);
    }
    // The expected output values are the input dims; the output type is
    // passed explicitly as DT_INT32 rather than defaulting to tf_type.
    TestOpConverter("my_shape", node_def, p.expected_output_dims, p.status,
                    p.runtime_status, ElementsAreArray(p.input_dims),
                    {DT_INT32});
  }
}
|
||||
|
||||
// Helper function for testing MatMul and BatchMatMul
|
||||
// get_matmul corresponds to the function used to generate the node. It should
|
||||
// accept (DataType, transpose_a, transpose_b) as parameters.
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
@ -35,14 +36,16 @@ void TrtShapeOptimizationProfile::InitProfiles() {
|
||||
<< "for each input (min=opt=max).";
|
||||
}
|
||||
for (auto& shape_vec : input_shapes_) {
|
||||
std::vector<nvinfer1::Dims> dimvec;
|
||||
for (auto& shape : shape_vec) {
|
||||
dimvec.push_back(TensorShapeToTrtDims(shape, false));
|
||||
if (!shape_vec.empty()) {
|
||||
std::vector<nvinfer1::Dims> dimvec(shape_vec.size());
|
||||
absl::c_transform(shape_vec, dimvec.begin(), [](TensorShape shape) {
|
||||
return TensorShapeToTrtDims(shape, false);
|
||||
});
|
||||
// Set min=opt=max.
|
||||
OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec};
|
||||
profiles_.push_back(std::move(profConfig));
|
||||
VLOG(1) << "Created profile " << profiles_.back().DebugString();
|
||||
}
|
||||
// We set min=opt=max.
|
||||
OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec};
|
||||
profiles_.push_back(std::move(profConfig));
|
||||
VLOG(1) << "Created profile " << profiles_.back().DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user