Merge of b75a6222b82bb556f63f7a5a04cab45212ed30c6

PiperOrigin-RevId: 294781398
Change-Id: I3cc915dd058a9b1414a7885794e4b95522ea910c
TensorFlower Gardener 2020-02-12 16:16:18 -08:00
commit d74394747a
7 changed files with 349 additions and 68 deletions

View File

@@ -328,6 +328,7 @@ Status CreateTRTNode(const ConversionParams& params,
nvinfer1::IGpuAllocator* alloc,
std::vector<Node*>* engine_nodes) {
const auto& info = infos.at(pos);
std::vector<tensorflow::TensorShapeProto> input_shape_protos;
std::vector<PartialTensorShape> input_shapes;
std::vector<NodeDefBuilder::NodeOut> inputs;
std::vector<Node*> input_nodes;
@@ -369,8 +370,10 @@ Status CreateTRTNode(const ConversionParams& params,
} else {
// Set the shapes and data types of the input edge.
if (input_shapes.size() <= conn.port_number) {
input_shape_protos.resize(conn.port_number + 1);
input_shapes.resize(conn.port_number + 1);
}
conn.outside_shape.AsProto(&input_shape_protos.at(conn.port_number));
input_shapes.at(conn.port_number) = conn.outside_shape;
// Shape must be fully defined (excluding batch dimension) for static
// mode.
@@ -454,7 +457,7 @@ Status CreateTRTNode(const ConversionParams& params,
NameAttrList function;
function.set_name(StrCat(info.engine_name, "_native_segment"));
Status status =
node_builder
node_builder.Attr("input_shapes", input_shape_protos)
.Attr("static_engine",
info.engine_type == EngineInfo::EngineType::TRTStatic)
.Attr("segment_func", function)

View File

@@ -292,13 +292,19 @@ Status ValidateTensorProperties(const string& producer_node_type,
}
if (validation_only) return Status::OK();
// Following are validations at runtime.
for (int d = first_trt_dim; d < shape.dims(); ++d) {
if (shape.dim_size(d) < 0) {
return errors::InvalidArgument(
"Input tensor with shape ", shape.DebugString(),
" has an unknown non-batch dimension at dim ", d);
// The following checks are only used at TRT engine creation time. In implicit
// batch mode we check that all inputs to the network have static shapes (as
// required by TensorRT). The only exception is the batch size, which may be
// unknown. In explicit batch mode this check is not necessary, since there
// any dimension may be unknown.
if (use_implicit_batch) {
for (int d = first_trt_dim; d < shape.dims(); ++d) {
if (shape.dim_size(d) < 0) {
return errors::InvalidArgument(
"Input tensor with shape ", shape.DebugString(),
" has an unknown non-batch dimension at dim ", d);
}
}
}
return Status::OK();
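A hedged distillation of the rule enforced above, as a standalone helper (name hypothetical, for illustration): in implicit batch mode only the batch dimension may stay unknown, while explicit batch mode accepts -1 anywhere.

#include <cstdint>
#include <vector>

// Returns true iff `dims` is acceptable at TRT engine creation time.
bool ShapeOkForEngineCreation(const std::vector<int64_t>& dims,
                              bool use_implicit_batch) {
  if (!use_implicit_batch) return true;  // any dimension may be unknown (-1)
  // dims[0], the batch size, may stay unknown in implicit batch mode.
  for (size_t d = 1; d < dims.size(); ++d) {
    if (dims[d] < 0) return false;  // unknown non-batch dimension
  }
  return true;
}

For example, {-1, 224, 224} passes in implicit batch mode, while {8, -1, 224} is rejected.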
@@ -2405,10 +2411,32 @@ Status Converter::SqueezeTensor(nvinfer1::ITensor* input,
}
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
// For dynamic input shapes, we need to use TRT ops to build the new shape.
// If the remaining dimensions of a squeeze operation have dynamic sizes, we
// need to use TRT ops to build the result shape for the squeeze operation.
// This is because IShuffleLayer::setReshapeDimensions treats -1 as a special
// value.
if (absl::c_any_of(input_dims, [](int i) { return i == -1; })) {
return errors::Unimplemented(
"Squeeze is not implemented for dynamic input shapes");
nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0);
std::vector<nvinfer1::ITensor const*> concat_inputs;
for (int i = 0; i < input_dims.size(); i++) {
// If the input dim wasn't set to 0 earlier, we include it in the new shape.
if (input_dims[i] != 0) {
concat_inputs.push_back(
network()
->addSlice(*shape, {1, {i}}, {1, {1}}, {1, {1}})
->getOutput(0));
}
}
nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation(
const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()),
concat_inputs.size());
concat_layer->setAxis(0);
nvinfer1::ITensor* new_shape = concat_layer->getOutput(0);
// Reshape the input using the new shape.
nvinfer1::IShuffleLayer* shuffle = network()->addShuffle(*input);
shuffle->setInput(1, *new_shape);
*output = shuffle->getOutput(0);
return Status::OK();
}
#endif
// Remove all dims which are equal to 0.
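The added code is a dynamic reshape: slice each surviving entry out of the runtime shape tensor, concatenate the slices into the target shape, and feed that shape tensor into an IShuffleLayer as its second input, which sidesteps the -1 special-casing of setReshapeDimensions. The same pattern as a free-standing sketch (TensorRT 6 API; the helper name is illustrative):

#include <vector>
#include "NvInfer.h"

// Reshape `input` so that only the dimensions listed in `keep` survive,
// computing the target shape at runtime from the input's shape tensor.
nvinfer1::ITensor* DynamicKeepDims(nvinfer1::INetworkDefinition* net,
                                   nvinfer1::ITensor* input,
                                   const std::vector<int>& keep) {
  nvinfer1::ITensor* shape = net->addShape(*input)->getOutput(0);
  std::vector<nvinfer1::ITensor*> pieces;
  for (int i : keep) {
    // Slice element i out of the 1D shape tensor: start {i}, size {1},
    // stride {1}.
    pieces.push_back(
        net->addSlice(*shape, {1, {i}}, {1, {1}}, {1, {1}})->getOutput(0));
  }
  nvinfer1::IConcatenationLayer* concat = net->addConcatenation(
      pieces.data(), static_cast<int>(pieces.size()));
  concat->setAxis(0);
  nvinfer1::IShuffleLayer* shuffle = net->addShuffle(*input);
  shuffle->setInput(1, *concat->getOutput(0));  // new shape arrives as a tensor
  return shuffle->getOutput(0);
}

For an input of shape [-1, 1, 5] squeezed on axis 1, keep = {0, 2} yields the runtime shape [batch, 5].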

View File

@@ -170,6 +170,13 @@ class TRTEngineOp : public AsyncOpKernel {
// If true, create calibration graph for INT8 mode. Otherwise, we are using
// user-provided quantization ranges.
bool use_calibration_;
// Array of all input shapes, collected from the input_shapes attribute when
// constructing the TRTEngineOp. The input_shapes attribute is set during
// graph conversion time. This data is used to determine which input
// dimensions can be unknown. During inference this information is not
// available otherwise (when we run inference, all shapes are concrete).
std::vector<PartialTensorShape> input_partial_shapes_;
};
#define TYPECASE(dt, X, Y) \
@@ -272,6 +279,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
TrtPrecisionModeFromName(precision_string, &precision_mode_));
OP_REQUIRES_OK(context,
context->GetAttr("use_calibration", &use_calibration_));
OP_REQUIRES_OK(context,
context->GetAttr("input_shapes", &input_partial_shapes_));
func_handle_ = kInvalidHandle;
if (!static_engine_) {
FunctionLibraryRuntime* lib = context->function_library();
@@ -306,7 +315,25 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
use_implicit_batch_ = true;
}
#endif
if (!use_implicit_batch_) {
if (use_implicit_batch_) {
if (input_partial_shapes_.empty()) {
VLOG(1) << "Attribute input_shapes is not set. This happens probably "
<< "because you are using a model that is already converted "
<< "to TensorRT with a previous version of TF-TRT (i.e. includes "
<< "TRTEngineOp in graph). This is not an error. If you convert "
<< "the original model again to TensorRT, the attributes "
<< "input_shapes will be set automatically.";
}
} else {
OP_REQUIRES(
context, !input_partial_shapes_.empty(),
errors::InvalidArgument(
"Explicit batch mode requires attribute input_shapes to be set."
"If you are using a model that was converted to TensorRT by a "
"previous version of TF-TRT, (i.e. includes TRTEngineOp in graph "
"without the input_shapes attribute), then you need to convert the "
"original model again to TensorRT in order to set the attribute "
"input_shapes."));
OP_REQUIRES(context, !calibration_mode_,
errors::InvalidArgument(
"Explicit batch mode does not support calibration"));
@@ -393,28 +420,68 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
ExecuteNativeSegment(ctx, helper);
}
Status TRTEngineOp::VerifyInputShapes(const std::vector<TensorShape>& shapes) {
if (shapes.empty()) {
Status TRTEngineOp::VerifyInputShapes(
const std::vector<TensorShape>& input_concrete_shapes) {
if (input_concrete_shapes.empty()) {
return errors::InvalidArgument("Input shapes are empty, for ", name());
}
if (shapes[0].dims() < 1) {
return errors::InvalidArgument("Input shapes contain scalar, for ", name(),
": ",
TensorShapeUtils::ShapeListString(shapes));
if (input_partial_shapes_.empty()) {
if (!use_implicit_batch_) {
return errors::InvalidArgument(
"Explicit batch mode requires input_partial_shapes_ ",
"to contain the dynamic input shapes to TRTEngineOp");
}
// If the graph was converted with an earlier version of TF-TRT, it can
// happen that the input_partial_shapes_ vector is not set (see
// input_shapes attribute handling in the TRTEngineOp constructor).
// In implicit batch mode it is allowed to have an empty input_partial_shapes_,
// since it is only required in explicit batch mode (see the input_shapes
// attribute of ConvertGraphDefToEngine in TRTEngineOp::GetEngine).
} else {
// Additional consistency checks if input_partial_shapes_ is present.
const string error_msg = StrCat(
"Input shapes do not match input partial shapes stored in graph, for ",
name(), ": ", DebugString(input_concrete_shapes),
" != ", DebugString(input_partial_shapes_));
if (input_concrete_shapes.size() != input_partial_shapes_.size()) {
return errors::InvalidArgument(error_msg);
}
for (int i = 0; i < input_concrete_shapes.size(); i++) {
if (input_concrete_shapes[i].dims() != input_partial_shapes_[i].dims()) {
return errors::InvalidArgument(error_msg);
}
}
for (int i = 0; i < input_concrete_shapes.size(); i++) {
for (int d = 0; d < input_concrete_shapes[i].dims(); d++) {
if (input_partial_shapes_[i].dim_size(d) != -1) {
if (input_concrete_shapes[i].dim_size(d) !=
input_partial_shapes_[i].dim_size(d)) {
return errors::InvalidArgument(error_msg);
}
}
}
}
}
if (use_implicit_batch_) {
const int batch_size = shapes[0].dim_size(0);
if (batch_size < 1) {
return errors::InvalidArgument("Incorrect batch dimension, for ", name(),
": ",
TensorShapeUtils::ShapeListString(shapes));
if (input_concrete_shapes[0].dims() < 1) {
return errors::InvalidArgument(
"Input shapes contain scalar, for ", name(), ": ",
TensorShapeUtils::ShapeListString(input_concrete_shapes));
}
for (const TensorShape& shape : shapes) {
const int batch_size = input_concrete_shapes[0].dim_size(0);
if (batch_size < 1) {
return errors::InvalidArgument(
"Incorrect batch dimension, for ", name(), ": ",
TensorShapeUtils::ShapeListString(input_concrete_shapes));
}
for (const TensorShape& shape : input_concrete_shapes) {
if (batch_size != shape.dim_size(0)) {
return errors::InvalidArgument(
"Input shapes are inconsistent on the batch dimension, for ",
name(), ": ", TensorShapeUtils::ShapeListString(shapes));
name(), ": ",
TensorShapeUtils::ShapeListString(input_concrete_shapes));
}
}
}
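The consistency check added above reduces to dim-by-dim compatibility between the stored partial shapes and the concrete shapes seen at run time. A minimal sketch of that predicate (hypothetical helper; PartialTensorShape::IsCompatibleWith covers similar ground):

#include "tensorflow/core/framework/tensor_shape.h"

// True iff `concrete` could be obtained by filling in the unknown (-1)
// dimensions of `partial`.
bool MatchesPartial(const tensorflow::PartialTensorShape& partial,
                    const tensorflow::TensorShape& concrete) {
  if (partial.dims() != concrete.dims()) return false;
  for (int d = 0; d < concrete.dims(); ++d) {
    if (partial.dim_size(d) != -1 &&
        partial.dim_size(d) != concrete.dim_size(d)) {
      return false;
    }
  }
  return true;
}

For example, a stored [?, 28, 28] accepts a run-time [5, 28, 28] but rejects [5, 28, 32] and [5, 28].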
@@ -443,11 +510,17 @@ bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
return true;
}
// This routine finds the engines whose input shapes are compatible with
// actual_input_shapes, and returns the input shapes of the one among them
// that has the smallest batch size.
Status TRTEngineOp::GetEngineInputShapes(
const CacheType& cache, const std::vector<TensorShape>& actual_input_shapes,
std::vector<TensorShape>* engine_input_shapes) {
// VerifyInputShapes() already ensured that all input shapes have same
// batch size, and are not scalars.
// batch size, and are not scalars, if we are in implicit batch mode.
//
// In explicit batch mode we plan to have a single engine in the cache, and we
// return its shape if it is compatible.
*engine_input_shapes = actual_input_shapes;
int64 min_matched_batch_size = kint64max;
for (const auto& pair : cache) {
@@ -513,19 +586,22 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
}
// Get shapes of inputs to engine.
std::vector<TensorShape> input_shapes;
input_shapes.reserve(ctx->num_inputs());
std::vector<TensorShape> input_concrete_shapes;
input_concrete_shapes.reserve(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
input_shapes.push_back(ctx->input(i).shape());
input_concrete_shapes.push_back(ctx->input(i).shape());
}
OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_shapes), *helper);
StatusOr<EngineContext*> status = GetEngine(input_shapes, ctx, cache_res);
OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_concrete_shapes), *helper);
StatusOr<EngineContext*> status =
GetEngine(input_concrete_shapes, ctx, cache_res);
OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper);
EngineContext* engine_context = status.ValueOrDie();
if (!engine_context->cuda_engine) {
VLOG(1) << "Engine retrieval for input shapes: "
<< TensorShapeUtils::ShapeListString(input_shapes)
<< TensorShapeUtils::ShapeListString(input_concrete_shapes)
<< " failed. Running native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
@@ -807,14 +883,17 @@ Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx,
}
StatusOr<EngineContext*> TRTEngineOp::GetEngine(
const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx,
const std::vector<TensorShape>& input_concrete_shapes, OpKernelContext* ctx,
TRTEngineCacheResource* cache_res) {
static EngineContext empty_context;
mutex_lock lock(engine_mutex_);
// Using first input to get batch size is reliable - VerifyInputShapes() has
// verified that.
const int batch_size = input_shapes[0].dim_size(0);
// Using the first input to get the batch size is reliable -
// VerifyInputShapes() guarantees that the first input is not a scalar. Thus
// we can always use the first input to get the batch size in implicit batch
// mode. In explicit batch mode, this value is not used.
const int batch_size = input_concrete_shapes[0].dim_size(0);
// TODO(Tamas): remove the need for batch_size in explicit_batch mode
auto& cache = cache_res->cache_;
auto allocator = cache_res->allocator_.get();
if (allocator == nullptr) {
@@ -828,7 +907,7 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
// TODO(laigd): need a better shape compatibility check for the case where
// implicit batch is disabled.
if (!use_implicit_batch_ ||
AreShapesCompatible(input_shapes, cache.begin()->first)) {
AreShapesCompatible(input_concrete_shapes, cache.begin()->first)) {
return cache.begin()->second.get();
}
return &empty_context;
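In the cache-hit test above, "compatible" in implicit batch mode roughly means: same rank and non-batch dimensions, with an actual batch size no larger than the one the cached engine was built for. A hedged sketch of that rule (simplified; the authoritative version is AreShapesCompatible in this file):

#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"

bool CompatibleImplicitBatch(
    const std::vector<tensorflow::TensorShape>& actual,
    const std::vector<tensorflow::TensorShape>& cached) {
  if (actual.size() != cached.size()) return false;
  for (size_t i = 0; i < actual.size(); ++i) {
    if (actual[i].dims() != cached[i].dims()) return false;
    // An engine built for batch N can serve any batch size up to N.
    if (actual[i].dim_size(0) > cached[i].dim_size(0)) return false;
    for (int d = 1; d < actual[i].dims(); ++d) {
      if (actual[i].dim_size(d) != cached[i].dim_size(d)) return false;
    }
  }
  return true;
}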
@@ -846,11 +925,9 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
const auto max_batch_size = raw_static_engine->getMaxBatchSize();
// Static engine will have max_batch_size for batch size so that all inputs
// will map to this single engine.
std::vector<TensorShape> engine_input_shapes(input_shapes);
if (use_implicit_batch_) {
for (int i = 0; i < engine_input_shapes.size(); i++) {
engine_input_shapes[i].set_dim(0, max_batch_size);
}
std::vector<TensorShape> engine_input_shapes(input_concrete_shapes);
for (int i = 0; i < engine_input_shapes.size(); i++) {
engine_input_shapes[i].set_dim(0, max_batch_size);
}
// TODO(laigd): here we assume engine_input_shapes matches the actual input
// shapes of the engine, we should verify that.
@@ -874,7 +951,7 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
// Handle the dynamic engine case. See if there is a compatible engine cached.
std::vector<TensorShape> engine_input_shapes;
TF_RETURN_IF_ERROR(
GetEngineInputShapes(cache, input_shapes, &engine_input_shapes));
GetEngineInputShapes(cache, input_concrete_shapes, &engine_input_shapes));
// If matched, use that engine. Otherwise, we will look in cache for that
// exact shape and possibly create a new engine if it is not in cache.
@@ -883,17 +960,21 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
bool convert_successfully = false;
LOG(INFO) << "Building a new TensorRT engine for " << name()
<< " with input shapes: "
<< TensorShapeUtils::ShapeListString(engine_input_shapes);
<< TensorShapeUtils::ShapeListString(input_concrete_shapes);
// Convert to partial shapes
std::vector<PartialTensorShape> partial_shapes(engine_input_shapes.begin(),
engine_input_shapes.end());
// Use concrete shapes for implicit batch mode and partial shapes for
// explicit batch mode.
const std::vector<PartialTensorShape>& conversion_input_shapes =
use_implicit_batch_
? std::vector<PartialTensorShape>(input_concrete_shapes.begin(),
input_concrete_shapes.end())
: input_partial_shapes_;
// Up to this point, calibrator_ can never be empty, since otherwise it
// means calibration_mode_ is true and this path won't get executed.
auto status = convert::ConvertGraphDefToEngine(
segment_graph_def_, precision_mode_, batch_size, workspace_size_,
partial_shapes, &logger, allocator, calibrator_.get(), &engine,
conversion_input_shapes, &logger, allocator, calibrator_.get(), &engine,
use_calibration_, use_implicit_batch_, &convert_successfully);
if (!status.ok()) {
LOG(WARNING) << "Engine creation for " << name() << " failed. "
@@ -901,12 +982,12 @@ StatusOr<EngineContext*> TRTEngineOp::GetEngine(
<< "Reason: " << status;
// Store an empty engine in the cache for these input shapes so we don't
// try to build the same failing engine again.
cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
cache.emplace(input_concrete_shapes, absl::make_unique<EngineContext>());
return &empty_context;
}
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
engine->createExecutionContext());
cache.emplace(engine_input_shapes,
cache.emplace(input_concrete_shapes,
absl::make_unique<EngineContext>(std::move(engine),
std::move(exec_context)));
VLOG(1) << "Added new engine to cache of " << name()

View File

@@ -60,7 +60,9 @@ using ::testing::ElementsAre;
class TRTEngineOpTestBase : public OpsTestBase {
public:
void AddSimpleTrtOp(DataType dtype, int max_cached_engines_count = 1) {
void AddSimpleTrtOp(DataType dtype, int max_cached_engines_count = 1,
PartialTensorShape shape = PartialTensorShape({-1, -1}),
bool use_implicit_batch = true) {
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
@@ -80,9 +82,12 @@ class TRTEngineOpTestBase : public OpsTestBase {
convert::RegisterGraphToFunctionLibrary(graph_def, graph, op_name));
TF_ASSERT_OK(flib_def_->AddLibrary(graph->flib_def()));
PartialTensorShape shape({-1, -1});
// Create the op.
// In implicit batch mode, the input shapes that we specify here are not
// used for engine creation; we use the concrete shapes seen at inference
// time to create the engine.
// In explicit batch mode, the input shapes attribute is used to define
// the network for the TensorRT engine.
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
NameAttrList function;
function.set_name(StrCat(op_name, "_native_segment"));
@@ -98,7 +103,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
.Attr("workspace_size_bytes", 1 << 20)
.Attr("precision_mode", "FP32")
.Attr("use_calibration", false)
.Attr("_use_implicit_batch", true)
.Attr("_use_implicit_batch", use_implicit_batch)
.Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(InitOpWithFunctionLibrary());
@@ -131,7 +136,8 @@ class TRTEngineOpTestBase : public OpsTestBase {
}
};
TEST_F(TRTEngineOpTestBase, DynamicShapes) {
TEST_F(TRTEngineOpTestBase, DynamicEngines) {
// Test dynamic engine creation during inference time
TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4);
// Execute the op with batch size > 1.
@@ -180,6 +186,72 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {
EXPECT_EQ(1, cache->count({TensorShape({10, 10})}));
}
TEST_F(TRTEngineOpTestBase, ExplicitBatch) {
// Test inference in explicit batch mode with static input shapes. Static
// shapes in this context mean that TensorRT knows all the input shapes
// at engine creation time.
TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/1,
/*shape=*/PartialTensorShape({1, 2}),
/*use_implicit_batch=*/false);
TensorShape input_shape({1, 2});
TRTEngineOpTestBase::AddSimpleInput<float>(input_shape);
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Get the engine cache.
TRTEngineCacheResource* cache_resource = nullptr;
TF_ASSERT_OK(
device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
core::ScopedUnref sc(cache_resource);
// The cache should contain only one EngineContext, with a valid cuda_engine.
auto cache = &cache_resource->cache_;
EXPECT_EQ(1, cache->size());
ASSERT_EQ(1, cache->count({input_shape}));
EngineContext* ectx = cache->at({input_shape}).get();
EXPECT_NE(ectx->cuda_engine, nullptr);
}
TEST_F(TRTEngineOpTestBase, DynamicShapes) {
// Test inference in explicit batch mode with dynamic input shapes. Dynamic
// shapes in this context mean that some input shapes for TensorRT are
// unknown at engine creation time. When we create the network, the
// unknown shapes are represented as -1. Before we run inference, these
// shapes have to be specified by calling setBindingDimensions.
TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/1,
/*shape=*/PartialTensorShape({-1, -1}),
/*use_implicit_batch=*/false);
TensorShape input_shape({1, 2});
TRTEngineOpTestBase::AddSimpleInput<float>(input_shape);
// We expect TensorRT engine creation to fail: we would need to configure the
// engine with optimization profiles to use dynamic input shapes, but that
// feature is not yet implemented.
//
// Since TRT engine creation has failed, we fall back to the native segment.
// Calling the native segment fails for the same reason that is investigated
// in https://github.com/tensorflow/tensorflow/pull/34919. This is irrelevant
// for the current test; here we just want to check whether TRT engine
// creation has failed.
OpsTestBase::RunOpKernel();
// Get the engine cache.
TRTEngineCacheResource* cache_resource = nullptr;
TF_ASSERT_OK(
device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
core::ScopedUnref sc(cache_resource);
// The cache should contain only one EngineContext.
auto cache = &cache_resource->cache_;
EXPECT_EQ(1, cache->size());
ASSERT_EQ(1, cache->count({input_shape}));
EngineContext* ectx = cache->at({input_shape}).get();
// Since engine creation failed, we expect to find nullptr. Finding a nullptr
// indicates that unknown shapes were used to define the TensorRT network.
EXPECT_EQ(ectx->cuda_engine, nullptr);
}
template <typename T>
class TRTEngineOpTest : public TRTEngineOpTestBase {};

View File

@@ -36,6 +36,7 @@ REGISTER_OP("TRTEngineOp")
.Attr("segment_func: func = {}")
.Attr("InT: list({int8,float16,float32,int32})")
.Attr("OutT: list({int8,float16,float32,int32})")
.Attr("input_shapes: list(shape) = []")
.Attr("max_cached_engines_count: int = 1")
.Attr("workspace_size_bytes: int")
.Attr("precision_mode: {'FP32', 'FP16', 'INT8'}")
@@ -54,7 +55,6 @@ REGISTER_OP("TRTEngineOp")
.Attr("segment_funcdef_name: string = ''")
.Attr("cached_engine_batches: list(int) >= 0 = []")
.Attr("fixed_input_size: bool = true")
.Attr("input_shapes: list(shape) = []")
.Attr("output_shapes: list(shape) = []")
.Attr("static_engine: bool = true");
} // namespace tensorflow

View File

@@ -157,28 +157,88 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
super(TfTrtIntegrationTestBase, self).setUp()
warnings.simplefilter("always")
def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes):
"""Build test parameters when not considering dynamic shapes."""
def _GetTensorSpec(self, shape, mask, dtype, name):
# Set dimension i to None if mask[i] == False
assert len(shape) == len(mask)
new_shape = [s if m else None for s, m in zip(shape, mask)]
return tensor_spec.TensorSpec(new_shape, dtype, name)
def _Validate(shapes):
def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes):
"""Build test parameters.
The input_shapes and output_shapes arguments are known (static) shapes that
can be used to generate test data. To define the model, we also specify
corresponding input/output TensorSpecs. These are defined using the shape
arguments. For each input tensor we define:
input_spec = [None] + input_shape[1:]
and similarly for output shapes. This means that we leave the first (batch)
dimension unknown; the rest is copied from the shapes arg.
Args:
graph_fn: The function to build the graph.
dtype: The element type.
input_shapes: The input shapes.
output_shapes: The output shapes.
Returns:
The test parameters.
"""
input_mask = [[False] + [True] * (len(shape) - 1) for shape in input_shapes]
output_mask = [
[False] + [True] * (len(shape) - 1) for shape in output_shapes
]
return self.BuildParamsWithMask(graph_fn, dtype, input_shapes,
output_shapes, input_mask, output_mask)
def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes,
input_mask, output_mask):
"""Build test parameters with static or dynamic input shapes.
To define dynamic shapes, give a boolean mask that describes which
dimensions to treat as known. The values in input_mask are interpreted the
following way:
- True: known dim (use the corresponding value from input_shapes)
- False: unknown dim (replace the corresponding value from input_shapes
with None)
For example, to define the first two dimensions with unknown size use
input_shapes=[[1,2,1,8]], input_mask=[[False, False, True, True]].
Args:
graph_fn: The function to build the graph.
dtype: The element type.
input_shapes: The input shapes.
output_shapes: The output shapes.
input_mask: The input shape masks.
output_mask: The output shape masks.
Returns:
The test parameters.
"""
def _ValidateShapes(shapes):
# Make sure all the shapes are fully specified.
for shape in shapes:
assert all(shape)
_Validate(input_shapes)
_Validate(output_shapes)
_ValidateShapes(input_shapes)
_ValidateShapes(output_shapes)
assert len(input_mask) == len(input_shapes)
assert len(output_mask) == len(output_shapes)
return TfTrtIntegrationTestParams(
graph_fn=graph_fn,
# Unset the batch dim of the specs to make sure TRT can tolerate changes
# on that.
input_specs=[
tensor_spec.TensorSpec([None] + shape[1:], dtype, "input_%d" % i)
for i, shape in enumerate(input_shapes)
self._GetTensorSpec(shape, mask, dtype, "input_%d" % i)
for i, (shape, mask) in enumerate(zip(input_shapes, input_mask))
],
output_specs=[
tensor_spec.TensorSpec([None] + shape[1:], dtype, "output_%d" % i)
for i, shape in enumerate(output_shapes)
self._GetTensorSpec(shape, mask, dtype, "output_%d" % i)
for i, (shape, mask) in enumerate(zip(output_shapes, output_mask))
],
input_dims=[input_shapes],
expected_output_dims=[output_shapes])

View File

@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from unittest import skip # pylint: disable=g-importing-member
from unittest import SkipTest # pylint: disable=g-importing-member
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import dtypes
@@ -27,7 +27,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@skip("TrtModeTestBase defines a common base class for other tests")
class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase):
"""Test squeeze on batch dim and some unary operations in TF-TRT."""
@@ -66,6 +65,12 @@ class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase):
use_implicit_batch=implicit_batch)
return conversion_params._replace(rewriter_config_template=rewriter_config)
@classmethod
def setUpClass(cls):
if cls is TrtModeTestBase:
raise SkipTest("TrtModeTestBase defines a base class for other tests.")
super(TrtModeTestBase, cls).setUpClass()
class ImplicitBatchTest(TrtModeTestBase):
@@ -92,6 +97,14 @@ class ImplicitBatchTest(TrtModeTestBase):
class ExplicitBatchTest(TrtModeTestBase):
def GetParams(self):
"""We specify input/output masks with static (known) shapes."""
return self.BuildParamsWithMask(
self.GraphFn,
dtypes.float32, [[1, 12, 5]], [[12, 5]],
input_mask=[[True, True, True]],
output_mask=[[True, True]])
def GetConversionParams(self, run_params):
"""Return a TrtConversionParams for test that enables explicit batch."""
return super(ExplicitBatchTest, self).GetConversionParams(run_params, False)
@@ -110,5 +123,29 @@ return ["TRTEngineOp_0"]
return ["TRTEngineOp_0"]
class DynamicShapesTest(TrtModeTestBase):
"""Test with dynamic input shapes.
DynamicShapesTest is different from ExplicitBatchTest in that it uses input
and output masks to change the input and output shapes to unknown shapes.
"""
def GetParams(self):
"""We specify input/output mask with dynamic (unknown) shapes."""
return self.BuildParamsWithMask(
self.GraphFn,
dtypes.float32, [[1, 12, 5]], [[12, 5]],
input_mask=[[False, False, False]],
output_mask=[[False, False]])
def GetConversionParams(self, run_params):
"""Return a TrtConversionParams for test that enables explicit batch."""
return super(DynamicShapesTest, self).GetConversionParams(run_params, False)
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build."""
return ["TRTEngineOp_0"]
if __name__ == "__main__":
test.main()