[TF:TRT] Add support for per cluster maximum batch size.

Previously, with implicit batch mode, all the nodes inside the same graph uses the same max batch size. With this CL, users will be able to configure part of the nodes to use a different max batch size with an optional `_tftrt_op_max_batch_size` attribute on the node. Besides, all the static batch size will be treated in the same way as `_tftrt_op_max_batch_size` attribute annotation.

During segmentation, TF-TRT will avoid putting nodes with different annotated max batch size into the same cluster.
- For static engines, if any nodes inside the cluster are annotated with a customized max batch size, TF-TRT will use the customized max batch size to build a static engine. Otherwise, TF-TRT will use the default max batch size in convert parameters to build a static engine;
- For dynamic engines, TF-TRT will still use the batch size detected at runtime as the max batch size.

PiperOrigin-RevId: 338162727
Change-Id: I6aaf1157353676850dfb6c75b344149c55a8bc11
This commit is contained in:
A. Unique TensorFlower 2020-10-20 16:42:16 -07:00 committed by TensorFlower Gardener
parent e643795254
commit 2638bb9920
12 changed files with 468 additions and 109 deletions

View File

@ -60,8 +60,12 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
namespace convert {
using absl::StrAppend;
using absl::StrCat;
using ::tensorflow::tensorrt::segment::ClusterProperty;
using ::tensorflow::tensorrt::segment::NodePtrCompare;
using ::tensorflow::tensorrt::segment::Segment;
namespace {
@ -125,15 +129,21 @@ bool ShallKeepControlEdgeFrom(const Node* input_node) {
// Function to get subsegment information structure.
Status GetEngineInfo(const Graph* g,
const grappler::GraphProperties& graph_properties,
const std::set<const Node*>& segment_nodes,
const Segment& segment,
const std::unordered_map<string, Node*>& node_map,
const std::vector<Node*>& reverse_topo_order,
EngineInfo* info) {
std::vector<const Node*> subgraph_nodes; // Topologically sorted nodes.
std::set<const Node*> added_const_nodes; // Used to prevent double insertion.
const ClusterProperty& segment_property = segment.property;
const std::set<const Node*, NodePtrCompare>& segment_nodes = segment.nodes;
// The device assignment accumulated from the compatible device assignments
// for the nodes in the segment.
DeviceNameUtils::ParsedName segment_device;
const DeviceNameUtils::ParsedName segment_device =
segment_property.DeviceName();
info->max_batch_size = segment_property.BatchSize().GetOptionalMaxBatchSize();
// Map from src_node_name+port to the unique port numbers of the TRT op, where
// the src_node_name is the name of the source node of the input/output
@ -146,18 +156,6 @@ Status GetEngineInfo(const Graph* g,
++it) {
const Node* node = *it;
if (segment_nodes.count(node) == 0) continue;
absl::optional<DeviceNameUtils::ParsedName> new_segment_device =
MergeIfCompatible(segment_device, GetDeviceName(node));
if (!new_segment_device.has_value()) {
// The segmenter should guarantee that nodes in the same segment have
// compatible device assignments.
return errors::Internal(
"segment nodes have incompatible device assignments: ",
DeviceNameUtils::ParsedNameToString(segment_device), " vs ",
GetDeviceName(node), " to node ", node->name());
}
segment_device = *new_segment_device;
subgraph_nodes.push_back(node);
const int node_id = node->id();
@ -332,7 +330,7 @@ void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
// invocation of CreateTRTNode().
Status CreateTRTNode(const ConversionParams& params,
const std::vector<EngineInfo>& infos, int pos,
int max_batch_size, Graph* graph,
int default_max_batch_size, Graph* graph,
std::vector<Node*>* engine_nodes) {
const auto& info = infos.at(pos);
std::vector<tensorflow::TensorShapeProto> input_shape_protos;
@ -427,6 +425,11 @@ Status CreateTRTNode(const ConversionParams& params,
(info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
// Build the engine and get its serialized representation.
string segment_string;
int max_batch_size = info.max_batch_size.has_value()
? info.max_batch_size.value()
: default_max_batch_size;
if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
std::pair<int, Allocator*> device_allocator =
GetDeviceAndAllocator(params, info);
@ -443,6 +446,7 @@ Status CreateTRTNode(const ConversionParams& params,
cudaSetDevice(cuda_device_id);
auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
// Create static engines with precision_mode fp32/fp16.
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
@ -486,6 +490,7 @@ Status CreateTRTNode(const ConversionParams& params,
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", info.maximum_cached_engines)
.Attr("workspace_size_bytes", info.max_workspace_size_bytes)
.Attr("max_batch_size", max_batch_size)
.Attr("precision_mode", prec_string)
.Attr("use_calibration", info.use_calibration)
.Attr("_use_implicit_batch", params.use_implicit_batch)
@ -738,7 +743,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
segment_options.allow_dynamic_non_batch_dim =
AllowDynamicNonBatchDimension(params);
segment::SegmentNodesVector initial_segments;
segment::SegmentVector initial_segments;
TrtNodeValidator validator(static_graph_properties, params.precision_mode,
params.use_calibration, params.use_implicit_batch);
TF_RETURN_IF_ERROR(segment::SegmentGraph(
@ -759,7 +764,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
engine_segments.reserve(initial_segments.size());
std::vector<Node*> reverse_topo_order;
GetPostOrder(graph, &reverse_topo_order);
segment::SegmentNodesVector converted_segments;
segment::SegmentVector converted_segments;
converted_segments.reserve(initial_segments.size());
string engine_name_prefix =
StrCat("TRTEngineOp_", GetNextGraphSequenceNumber(), "_");
@ -779,6 +784,9 @@ Status ConvertAfterShapes(const ConversionParams& params) {
curr_engine.use_calibration = params.use_calibration;
curr_engine.maximum_cached_engines = params.max_cached_engines;
curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
if (!curr_engine.max_batch_size.has_value()) {
curr_engine.max_batch_size = params.max_batch_size;
}
status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
&graph, curr_engine.engine_name);
@ -837,7 +845,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
params.max_batch_size, &graph, &engine_nodes);
string msg = StrCat("segment ", i, " consisting of ",
converted_segments.at(i).size(), " nodes by ",
converted_segments.at(i).nodes.size(), " nodes by ",
engine.engine_name);
if (status.ok()) {
LOG(INFO) << "Replaced " << msg << ".";
@ -849,7 +857,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
}
if (VLOG_IS_ON(1)) {
msg = "Segment consists of nodes: ";
for (const Node* node : converted_segments.at(i)) {
for (const Node* node : converted_segments.at(i).nodes) {
StrAppend(&msg, node->name(), ", ");
}
VLOG(1) << msg;
@ -858,7 +866,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
// If status is ok, we successfully added the node to the graph and can
// remove segment ops. Otherwise graph is not modified.
if (status.ok()) {
for (const Node* node : converted_segments.at(i)) {
for (const Node* node : converted_segments.at(i).nodes) {
graph.RemoveNode(const_cast<Node*>(node));
}
}

View File

@ -92,6 +92,8 @@ struct EngineInfo {
EngineInfo()
: engine_type(EngineType::TRTStatic),
max_workspace_size_bytes(0),
max_batch_size(absl::nullopt),
maximum_cached_engines(0),
precision_mode(TrtPrecisionMode::FP32),
use_calibration(true),
allow_build_at_runtime(true) {}
@ -108,6 +110,7 @@ struct EngineInfo {
enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
EngineType engine_type;
int64 max_workspace_size_bytes;
absl::optional<int> max_batch_size;
int maximum_cached_engines;
TrtPrecisionMode precision_mode;
bool use_calibration;

View File

@ -37,6 +37,7 @@ REGISTER_OP("TRTEngineOp")
.Attr("OutT: list({int8,float16,float32,int32})")
.Attr("input_shapes: list(shape) = []")
.Attr("max_cached_engines_count: int = 1")
.Attr("max_batch_size: int = 1")
.Attr("workspace_size_bytes: int")
.Attr("precision_mode: {'FP32', 'FP16', 'INT8'}")
.Attr("calibration_data: string = ''")

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
@ -242,12 +241,6 @@ struct SimpleEdgePtrCompare {
}
};
struct NodePtrCompare {
bool operator()(const Node* lhs, const Node* rhs) const {
return lhs->name() < rhs->name();
}
};
// Copied from TF ReverseDFS, which only works for Graph.
void StableDFS(const SimpleGraph& g, bool reverse,
const std::vector<const SimpleNode*>& start,
@ -646,6 +639,12 @@ ClusterBatchSize GetClusterBatchSizeForNode(
return cluster_batch_size;
}
const NodeDef& node_def = node->def();
if (node_def.attr().count(kTftrtOpMaxBatchSizeAttr)) {
cluster_batch_size.SetMaxBatchSize(
node_def.attr().at(kTftrtOpMaxBatchSizeAttr).i());
}
// As shape inference cannot provide any useful information about the batch
// size, we keep it as missing.
if (!graph_properties ||
@ -660,9 +659,8 @@ ClusterBatchSize GetClusterBatchSizeForNode(
FindLeadingShape(GetInputsToDeterminateBatchSize(node, input_properties));
DCHECK(optional_leading_shape.has_value());
const TensorShapeProto* leading_shape = optional_leading_shape.value();
DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2);
VLOG(3) << "has batch size " << leading_shape->dim(0).size();
VLOG(3) << "set batch size as " << leading_shape->dim(0).size();
return cluster_batch_size.SetBatchSize(leading_shape->dim(0).size());
}
@ -679,22 +677,6 @@ void AddSegmentForNode(const grappler::GraphProperties* graph_properties,
segments->emplace_back(node, std::move(property));
}
bool OpBatchSizeExceedMaximumBatchSize(
const grappler::GraphProperties* graph_properties, const Node* node,
bool use_implicit_batch, absl::optional<int> maximum_batch_size) {
ClusterBatchSize cluster_batch_size =
GetClusterBatchSizeForNode(graph_properties, node, use_implicit_batch);
// If the batch size is dynamic, then the negative dynamic batch size
// identifier shall never be larger than the positive max batch size.
if (cluster_batch_size.HasBatchSize() && maximum_batch_size.has_value() &&
cluster_batch_size.GetBatchSize() > maximum_batch_size.value()) {
VLOG(2) << "OP batch size " << cluster_batch_size.GetBatchSize()
<< " max_batch_size " << maximum_batch_size.value();
return true;
}
return false;
}
} // namespace
Status SegmentGraph(const Graph* tf_graph,
@ -702,8 +684,7 @@ Status SegmentGraph(const Graph* tf_graph,
const std::function<Status(const Node*)>& candidate_fn,
const std::function<bool(const Edge*)>& input_candidate_fn,
const std::function<bool(const Edge*)>& output_candidate_fn,
const SegmentOptions& options,
SegmentNodesVector* segments) {
const SegmentOptions& options, SegmentVector* segments) {
if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
return errors::Internal(
"Explicit batch mode should allow dynamic non-batch dimensions");
@ -791,14 +772,6 @@ Status SegmentGraph(const Graph* tf_graph,
<< "(Op type: " << node->tf_node()->type_string() << "), "
<< "(Op name: " << node->name() << ")";
exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
} else if (OpBatchSizeExceedMaximumBatchSize(
graph_properties, node->tf_node(),
options.use_implicit_batch, options.maximum_batch_size)) {
LOG_WARNING_WITH_PREFIX
<< "Implicit batch mode requires OP batch size not larger than "
<< "the converter maximum batch size: "
<< "(Op name: " << node->name() << ")";
exclude_node("OP batch size too large");
} else {
VLOG(2) << "Accepted as a TF-TRT candidate, "
<< "(Op type: " << node->tf_node()->type_string() << "), "
@ -947,18 +920,21 @@ Status SegmentGraph(const Graph* tf_graph,
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the segment nodes set.
std::map<string, std::set<const Node*, NodePtrCompare>> sg_map;
std::map<string, Segment> sg_map;
for (auto& u : node_segments) {
if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node());
sg_map[u.ParentValue()->name()].nodes.insert(u.Value()->tf_node());
}
if ((u.Value() != nullptr) && (u.ParentValue() == u.Value())) {
sg_map[u.Value()->name()].property = u.Property();
}
}
// --------------------------------- Step 2 ---------------------------------
// Remove ineligible input/output nodes.
for (auto& itr : sg_map) {
std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second;
std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second.nodes;
VLOG(1) << "Segment original size: " << segment_nodes.size();
while (true) {
std::deque<const Node*> in_nodes_que, out_nodes_que;
@ -1046,7 +1022,8 @@ Status SegmentGraph(const Graph* tf_graph,
for (const auto& itr : sg_map) {
const string& segment_root = itr.first;
// Return format does not require set comparator.
std::set<const Node*> segment_nodes(itr.second.begin(), itr.second.end());
std::set<const Node*, NodePtrCompare> segment_nodes(
itr.second.nodes.begin(), itr.second.nodes.end());
if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
string s;
for (auto node : segment_nodes) {
@ -1070,8 +1047,7 @@ Status SegmentGraph(const Graph* tf_graph,
<< num_effective_nodes << " effective nodes, dropping";
continue;
}
segments->emplace_back(segment_nodes);
segments->emplace_back(itr.second.property, segment_nodes);
}
return Status::OK();

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@ -32,10 +33,10 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
// Vector of segments, each entry contains a set of node pointers.
using SegmentNodesVector = std::vector<std::set<const Node*>>;
constexpr char kTftrtOpMaxBatchSizeAttr[] = "_tftrt_op_max_batch_size";
struct SegmentOptions {
// This struct holds per graph segmenting parameters.
// Segment must contain at least this many nodes.
int minimum_segment_size = 2;
bool use_implicit_batch = true;
@ -45,9 +46,28 @@ struct SegmentOptions {
// When use_implicit_batch is false or when we are building dynamic engines,
// we allow dynamic non-batch dimensions.
bool allow_dynamic_non_batch_dim = false;
// The name of the device to put the segment on.
std::set<string> exclude_node_list;
};
struct NodePtrCompare {
bool operator()(const Node* lhs, const Node* rhs) const {
return lhs->name() < rhs->name();
}
};
struct Segment {
Segment() {}
Segment(const ClusterProperty& property,
const std::set<const Node*, NodePtrCompare>& nodes)
: property(property), nodes(nodes) {}
ClusterProperty property;
std::set<const Node*, NodePtrCompare> nodes;
};
// Vector of segments, each entry contains a set of node pointers.
using SegmentVector = std::vector<Segment>;
// Get the subgraphs of a graph that can be handled by TensorRT.
//
// @param tf_graph Graph of the network.
@ -63,8 +83,7 @@ Status SegmentGraph(const Graph* tf_graph,
const std::function<Status(const Node*)>& candidate_fn,
const std::function<bool(const Edge*)>& input_candidate_fn,
const std::function<bool(const Edge*)>& output_candidate_fn,
const SegmentOptions& options,
SegmentNodesVector* segments);
const SegmentOptions& options, SegmentVector* segments);
} // namespace segment
} // namespace tensorrt

View File

@ -65,7 +65,7 @@ class SegmentTest : public ::testing::Test {
const std::set<string>& input_candidates,
const std::set<string>& output_candidates,
const std::vector<std::set<string>>& expected_segments) {
SegmentNodesVector segments;
SegmentVector segments;
TF_EXPECT_OK(SegmentGraph(graph, graph_properties,
MakeCandidateFn(candidates),
MakeInputEdgeCandidateFn(input_candidates),
@ -82,12 +82,12 @@ class SegmentTest : public ::testing::Test {
expected_segments);
}
void ValidateSegment(const SegmentNodesVector& segments,
void ValidateSegment(const SegmentVector& segments,
const std::vector<std::set<string>>& expected_segments) {
EXPECT_EQ(expected_segments.size(), segments.size());
for (int i = 0; i < segments.size(); ++i) {
std::set<string> segment_node_names;
for (const Node* node : segments[i]) {
for (const Node* node : segments[i].nodes) {
segment_node_names.insert(node->name());
}
const auto& expected = expected_segments[i];
@ -490,9 +490,10 @@ TEST_F(SegmentTest, TwoChainsDiffBatchSizes) {
RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes,
/*expected_segments=*/{{"output-0", "const-scalar"}});
// Converter will create engines based on the static batch size
EnableImplicitBatchModeForStaticEngine(1);
RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes,
/*expected_segments=*/{});
/*expected_segments=*/{{"output-0", "const-scalar"}});
}
TEST_F(SegmentTest, SameRankImplicitBroadcastingStaticBatchSize) {

View File

@ -17,7 +17,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#if GOOGLE_CUDA && GOOGLE_TENSORRT
@ -54,10 +53,12 @@ inline absl::optional<T> MergeCompatible(const absl::optional<T>& a,
} // namespace
ClusterBatchSize::ClusterBatchSize() : batch_size_(absl::nullopt) {}
ClusterBatchSize::ClusterBatchSize()
: batch_size_(absl::nullopt), max_batch_size_(absl::nullopt) {}
bool ClusterBatchSize::operator==(const ClusterBatchSize& other) {
return batch_size_ == other.batch_size_;
return batch_size_ == other.batch_size_ &&
max_batch_size_ == other.max_batch_size_;
}
ClusterBatchSize& ClusterBatchSize::SetBatchSize(int batch_size) {
@ -68,6 +69,9 @@ ClusterBatchSize& ClusterBatchSize::SetBatchSize(int batch_size) {
ClusterBatchSize& ClusterBatchSize::SetBatchSize(
const absl::optional<int>& batch_size) {
batch_size_ = MergeCompatible<int>(batch_size_, batch_size);
if (batch_size_.has_value() && batch_size_.value() >= 0) {
SetMaxBatchSize(batch_size_);
}
return *this;
}
@ -78,23 +82,45 @@ int ClusterBatchSize::GetBatchSize() const {
return batch_size_.value();
}
ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(int max_batch_size) {
SetBatchSize(static_cast<absl::optional<int>>(max_batch_size));
return *this;
}
ClusterBatchSize& ClusterBatchSize::SetMaxBatchSize(
const absl::optional<int>& max_batch_size) {
max_batch_size_ = MergeCompatible<int>(max_batch_size_, max_batch_size);
return *this;
}
absl::optional<int> ClusterBatchSize::GetOptionalMaxBatchSize() const {
return max_batch_size_;
}
bool ClusterBatchSize::MergeIfCompatible(const ClusterBatchSize& other) {
if (!CheckIfCompatible(batch_size_, other.batch_size_)) {
if (!CheckIfCompatible(batch_size_, other.batch_size_) ||
!CheckIfCompatible(max_batch_size_, other.max_batch_size_)) {
return false;
}
SetBatchSize(other.batch_size_);
SetMaxBatchSize(other.max_batch_size_);
return true;
}
string ClusterBatchSize::ToString() const {
string s;
absl::StrAppendFormat(&s, "batch_size=(");
if (HasBatchSize()) {
absl::StrAppendFormat(&s, "%d", GetBatchSize());
} else {
absl::StrAppendFormat(&s, "?");
}
absl::StrAppend(&s, ")");
const auto append_optional_num = [&](const absl::optional<int>& num) {
if (num.has_value()) {
absl::StrAppendFormat(&s, "%d", num.value());
} else {
absl::StrAppendFormat(&s, "?");
}
};
absl::StrAppendFormat(&s, "batch_size=");
append_optional_num(batch_size_);
absl::StrAppendFormat(&s, ", max_batch_size=");
append_optional_num(max_batch_size_);
return s;
}

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_
#include "absl/types/optional.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"
#if GOOGLE_CUDA && GOOGLE_TENSORRT
@ -45,14 +46,24 @@ namespace segment {
// cluster to either have the same dynamic batch size equivalent class or the
// same static batch size value.
//
// Besides, all the nodes with an annotated max batch size inside the same
// cluster shall have the same annotated max batch size. (It is allowed if
// part or all the nodes inside the cluster doesn't have annotated max batch
// size). Static batch sizes are treated as max batch size annotations. The
// converter max batch size is used for an OP with a dynamic batch size and no
// annotated max batch size.
//
// cluster: a = a1[1,3] + a1[1,3]
// ClusterBatchSize: batch_size_ = 1
// max_batch_size_ = 1
//
// cluster: b = b1[-1,3] + b2[-1, 3]
// ClusterBatchSize: batch_size_ = -1
// max_batch_size_ = null
//
// cluster: c = c1[-2,3] + c2[-2, 3]
// cluster: c = c1[-2,3] + c2[-2, 3](max_batch_size=100)
// ClusterBatchSize: batch_size_ = -2
// max_batch_size_ = 100
//
// When constructing cluster for explicit batch mode, all ClusterBatchSize is
// irrelevant.
@ -72,23 +83,35 @@ class ClusterBatchSize {
bool HasBatchSize() const;
int GetBatchSize() const;
// Sets the max batch size assuming that the object doesn't have a max batch
// size yet.
ClusterBatchSize& SetMaxBatchSize(int max_batch_size);
absl::optional<int> GetOptionalMaxBatchSize() const;
// Merge `other` into the current ClusterBatchSize if the two are not
// conflicting. Two ClusterBatchSizes are conflicting iff they both have a
// value and their values are different.
bool MergeIfCompatible(const ClusterBatchSize& other);
// Returns a string for the batch size.
// Returns a string for the batch size and the annotated max batch size.
// For the batch size:
// If the object has a static batch size, return a string representing a
// non-negative integer.
// If the object has a dynamic batch size, return a string representing a
// negative integer as an equivalent class.
// If the object doesn't have a batch size yet, return a "?" symbol string.
// If the object doesn't have a batch size yet, return "?".
// For the annotated max batch size:
// If the cluster has annotated max batch size in at least one of the nodes,
// return a string representing the annotated max batch size. Otherwise,
// return "?".
std::string ToString() const;
private:
ClusterBatchSize& SetBatchSize(const absl::optional<int>& batch_size);
ClusterBatchSize& SetMaxBatchSize(const absl::optional<int>& batch_size);
absl::optional<int> batch_size_;
absl::optional<int> max_batch_size_;
};
inline std::ostream& operator<<(std::ostream& os,

View File

@ -124,6 +124,7 @@ cuda_py_test(
cuda_py_tests(
name = "tf_trt_integration_test",
srcs = [
"test/annotate_max_batch_sizes_test.py",
"test/base_test.py",
"test/batch_matmul_test.py",
"test/biasadd_matmul_test.py",

View File

@ -0,0 +1,147 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Testing the impact of graph node _tftrt_op_max_batch_size annotation on TRTEngineOp attributes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class MaxBatchSizesTestBase(trt_test.TfTrtIntegrationTestBase):
@classmethod
def setUpClass(cls):
if cls is MaxBatchSizesTestBase:
raise unittest.SkipTest(
'MaxBatchSizesTestBase defines base class for other tests.')
super(MaxBatchSizesTestBase, cls).setUpClass()
@property
def tensor_shapes(self):
return [[1, 512, 1, 1], [64, 2, 2, 2], [32, 4, 2, 2], [16, 8, 2, 2]]
@property
def max_batch_sizes(self):
return [shape[0] for shape in self.tensor_shapes]
def GetParams(self):
"""Gets the build parameters for the test."""
return self.BuildParams(
self.GraphFn,
dtype=dtypes.float32,
input_shapes=[self.tensor_shapes[0]],
output_shapes=[self.tensor_shapes[-1]])
def ShouldRunTest(self, run_params):
# The maximum batch size for dynamic engines will be the actual batch size
# detected at runtime. Therefore, we don't run the test with dynamic
# engines.
return (not run_params.dynamic_engine, 'test static engine only.')
def GetConversionParams(self, run_params):
"""Returns a ConversionParams for test."""
conversion_params = super(MaxBatchSizesTestBase,
self).GetConversionParams(run_params)
conversion_params._replace(
max_batch_size=min(self.max_batch_sizes), maximum_cached_engines=1)
rewrite_config_with_trt = self.GetTrtRewriterConfig(
run_params=run_params,
conversion_params=conversion_params,
use_implicit_batch=True,
disable_non_trt_optimizers=True)
return conversion_params._replace(
rewriter_config_template=rewrite_config_with_trt)
def ExpectedEnginesToBuild(self, run_params):
"""Checks that the expected engine is built.
Args:
run_params: the run parameters.
Returns:
the expected engines to build.
There shall be engines generated for each maximum batch size.
"""
return [
'TRTEngineOp_{}'.format(seq_id)
for seq_id in range(len(self.max_batch_sizes))
]
def ExpectedMaxBatchSizes(self, run_params):
"""Checks that the expected maximum batch sizes for the generated engines.
Args:
run_params: the run parameters.
Returns:
the expected maximum batch sizes for the generated engines.
There shall be engines generated for each maximum batch size.
"""
return self.max_batch_sizes
class AnnotateMaxBatchSizesTest(MaxBatchSizesTestBase):
def GraphFn(self, inp):
"""Builds a tf.Graph for the test."""
tensor = inp * 2.0
tensor = array_ops.reshape(tensor, [-1] + self.tensor_shapes[1][1:])
with ops.get_default_graph()._attr_scope({
'_tftrt_op_max_batch_size':
attr_value_pb2.AttrValue(i=self.max_batch_sizes[1])
}):
tensor = tensor + 3.0
tensor = array_ops.reshape(tensor, [-1] + self.tensor_shapes[2][1:])
with ops.get_default_graph()._attr_scope({
'_tftrt_op_max_batch_size':
attr_value_pb2.AttrValue(i=self.max_batch_sizes[2])
}):
tensor = tensor * 4.0
tensor = array_ops.reshape(tensor, [-1] + self.tensor_shapes[3][1:])
with ops.get_default_graph()._attr_scope({
'_tftrt_op_max_batch_size':
attr_value_pb2.AttrValue(i=self.max_batch_sizes[3])
}):
tensor += tensor + 5.0
return array_ops.identity(tensor, name='output_0')
class StaticBatchSizeTest(MaxBatchSizesTestBase):
def GraphFn(self, inp):
"""Builds a tf.Graph for the test."""
tensor = inp * 2.0
tensor = array_ops.reshape(tensor, self.tensor_shapes[1])
tensor = tensor + 3.0
tensor = array_ops.reshape(tensor, self.tensor_shapes[2])
tensor = tensor * 4.0
tensor = array_ops.reshape(tensor, self.tensor_shapes[3])
tensor += tensor + 5.0
return array_ops.identity(tensor, name='output_0')
if __name__ == '__main__':
test.main()

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import collections
import errno
import gc
import itertools
@ -57,7 +57,7 @@ from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
TfTrtIntegrationTestParams = namedtuple(
TfTrtIntegrationTestParams = collections.namedtuple(
"TfTrtIntegrationTestParams",
[
# A function that creates the TF graph for testing.
@ -74,7 +74,7 @@ TfTrtIntegrationTestParams = namedtuple(
"expected_output_dims"
])
RunParams = namedtuple(
RunParams = collections.namedtuple(
"RunParams",
[
# Whether to run the conversion online with RewriterConfig, or offline
@ -305,9 +305,13 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
run_params.precision_mode)), "test either calibration or non-INT8"
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build, implemented by subclass."""
"""Returns the expected engines to build, implemented by subclass."""
raise NotImplementedError()
def ExpectedMaxBatchSizes(self, run_params):
"""Returns the expected maximum batch sizes of the build engines."""
return None
def ExpectedAbsoluteTolerance(self, run_params):
"""The absolute tolerance to compare floating point results."""
return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
@ -537,18 +541,39 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
logging.info("Writing graph to %s/%s", temp_dir, graph_name)
graph_io.write_graph(gdef, temp_dir, graph_name)
# Remove the graph sequence number prefix from the name only if the name has
# a prefix TRTEngineOp_n_. When expecting_prefix is true, assert such a
# prefix exists.
def _RemoveGraphSequenceNumberImpl(self, name, expecting_prefix):
match = re.search(r"TRTEngineOp_\d+_", name)
has_prefix = match and name.startswith(match.group(0))
assert (not expecting_prefix) or has_prefix
if has_prefix:
parts = name.split("_", maxsplit=2)
assert len(parts) == 3
return parts[0] + "_" + parts[2]
return name
# Removes the prefix(s) of function name(s).
# The input value can be a string or a sequence of string.
def _Canonicalize(self, value):
if isinstance(value, str):
return self._ToString(value.split("/")[-1])
elif isinstance(value, collections.abc.Iterable):
return set(self._Canonicalize(nm) for nm in value)
else:
raise TypeError(
"'_Canonicalize' can only be used on strings or sequence of strings!")
# Removes the graph sequence number prefix from the name(s) only if the
# name(s) has a prefix TRTEngineOp_n_. When expecting_prefix is true, asserts
# such a prefix exists.
# The input value can be a string or a sequence of string.
def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix):
if isinstance(value, str):
match = re.search(r"TRTEngineOp_\d+_", value)
has_prefix = match and value.startswith(match.group(0))
assert (not expecting_prefix) or has_prefix
if has_prefix:
parts = value.split("_", maxsplit=2)
assert len(parts) == 3
return parts[0] + "_" + parts[2]
return value
elif isinstance(value, collections.abc.Iterable):
return set(
self._RemoveGraphSequenceNumberImpl(nm, expecting_prefix)
for nm in value)
else:
raise TypeError(
"'_RemoveGraphSequenceNumberImpl' can only be used on strings "
"or sequence of strings!")
def _RemoveGraphSequenceNumber(self, name):
return self._RemoveGraphSequenceNumberImpl(name, True)
@ -644,6 +669,124 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
msg="\nexpected:\n%s\nvs actual:\n%s" %
(sorted(expected_input_map.items()), sorted(actual_input_map.items())))
def _VerifyMaxBatchSizeAnnotations(
self,
expected_engines,
original_gdef,
converted_gdef,
default_max_batch_size,
expected_max_batch_sizes=None,
):
"""Verifies the max batch size annotations in the original and converted GraphDef.
Args:
expected_engines: A sequence of engines names.
original_gdef: GraphDef. The graph def before TensorRT conversion.
converted_gdef: GraphDef. The graph def after TensorRT conversion.
default_max_batch_size: The default maximum batch size to use if no node
inside a segment is annoted with a customized max batch size.
expected_max_batch_sizes: Optional. A sequence of max batch sizes for all
the engines. `None` if does not check enforce max batch sizes.
"""
if isinstance(expected_max_batch_sizes, collections.abc.Collection):
self.assertEqual(len(expected_max_batch_sizes), len(expected_engines))
else:
self.assertIsNone(
expected_max_batch_sizes,
"'expected_max_batch_sizes' shall only be a sequence "
"of integers or `None`.")
def _ChainAllNodes(graph_def):
return itertools.chain(
graph_def.node,
itertools.chain(
*[func.node_def for func in graph_def.library.function]))
old_name_to_node_map = {
self._ToString(node.name): node
for node in _ChainAllNodes(original_gdef)
}
new_name_to_func_map = {
self._ToString(func.signature.name): func
for func in converted_gdef.library.function
}
def _DetectStaticBatchSize(node_def):
"""Returns the static batch size of an operation or None.
It is incorrect to use the output shapes to find the batch size of an
operation, as the segmenter actually uses the input shapes. However, it is
a simplication and works for most of the cases for the test purposes.
Args:
node_def: `tf.NodeDef`. The target node for analysis.
Returns:
If all the outputs of the node have the same static batch size, returns
the int value for the batch size. Otherwise returns None.
"""
shapes = node_def.attr["_output_shapes"].list.shape
batch_size = set(
list(s.dim)[0].size if len(s.dim) >= 2 else None for s in shapes)
if len(batch_size) == 1 and list(batch_size)[0] >= 1:
return list(batch_size)[0]
return None
name_to_engines_map = {}
actual_max_batch_sizes = []
for node in _ChainAllNodes(converted_gdef):
if node.op == "TRTEngineOp":
engine = node
engine_name = self._RemoveGraphSequenceNumber(
self._Canonicalize(self._ToString(engine.name)))
self.assertIn(engine_name, expected_engines)
name_to_engines_map[engine_name] = engine
# The input nodes shall not have the conflicting annotation (no
# annotation or the same annotation) with the maximum batch size
# annotation. If the engine has maximum batch size annotation as the
# non-default maximum batch size, then at least one input node shall
# have the same annotation to be the source.
self.assertIn("max_batch_size", node.attr)
engine_max_batch_size = node.attr["max_batch_size"].i
self.assertIsInstance(engine_max_batch_size, int)
actual_max_batch_sizes.append(engine_max_batch_size)
seg_func = node.attr["segment_func"].func
self.assertIsNotNone(seg_func)
self.assertIn(seg_func.name, new_name_to_func_map)
seg_func_def = new_name_to_func_map[seg_func.name]
logging.info("Segment function name: %s. Including %d nodes.",
seg_func.name, len(seg_func_def.node_def))
node_max_batch_size_all_none = True
# Use the native segment to search for replaced nodes
for alternative_node in seg_func_def.node_def:
node_name = self._Canonicalize(self._ToString(alternative_node.name))
if node_name not in old_name_to_node_map:
continue
original_node = old_name_to_node_map[node_name]
node_max_batch_size = None
if "_tftrt_op_max_batch_size" in original_node.attr:
node_max_batch_size = original_node.attr[
"_tftrt_op_max_batch_size"].i
elif (original_node.op != "Const" and
alternative_node.op != "Const" and
"_output_shapes" in original_node.attr):
node_max_batch_size = _DetectStaticBatchSize(original_node)
logging.info(
"'{%s}(%s)'s max batch size annotation is %s. "
"'{%s}'s max batch size is %s.", node_name, original_node.op,
str(node_max_batch_size), engine_name, str(engine_max_batch_size))
node_max_batch_size_all_none &= node_max_batch_size is None
self.assertTrue(engine_max_batch_size == node_max_batch_size or
node_max_batch_size is None)
logging.info("'{%s}'s max batch size is %d.", engine_name,
engine_max_batch_size)
self.assertTrue(engine_max_batch_size == default_max_batch_size or
not node_max_batch_size_all_none)
self.assertCountEqual(expected_engines, tuple(name_to_engines_map.keys()))
if expected_max_batch_sizes is not None:
self.assertCountEqual(expected_max_batch_sizes, actual_max_batch_sizes)
def _GetGraphDef(self, run_params, gdef_or_saved_model_dir):
if isinstance(gdef_or_saved_model_dir, str):
if run_params.is_v2:
@ -703,7 +846,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
self.assertEqual(num_engines, len(expected_engines))
if isinstance(expected_engines, dict):
self._VerifyConnections(expected_engines, original_gdef, gdef_to_verify)
# TODO(aaroey): consider verifying the corresponding TF function.
self._VerifyMaxBatchSizeAnnotations(
expected_engines=expected_engines,
original_gdef=original_gdef,
converted_gdef=gdef_to_verify,
expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params),
default_max_batch_size=self.GetConversionParams(
run_params).max_batch_size,
)
def _VerifyGraphDefV2(self, run_params, original_gdef, gdef_to_verify,
graph_state):
@ -721,15 +871,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
all_op_names.append(node.name)
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
# Remove the function name prefix.
def _Canonicalize(names):
return set(self._ToString(name.split("/")[-1]) for name in names)
# Remove the graph sequence number prefix from all the names.
def _RemoveGraphSequenceNumber(names):
return set(self._RemoveGraphSequenceNumber(name) for name in names)
all_op_names = _Canonicalize(all_op_names)
trt_op_names = _RemoveGraphSequenceNumber(_Canonicalize(trt_op_names))
all_op_names = self._Canonicalize(all_op_names)
trt_op_names = self._RemoveGraphSequenceNumber(
self._Canonicalize(trt_op_names))
if isinstance(expected_engines, dict):
# For simplicity we don't verify the connections inside the engine in
@ -741,6 +886,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
expected_engines = set(expected_engines.keys())
self.assertEqual(set(expected_engines), trt_op_names)
self._VerifyMaxBatchSizeAnnotations(
expected_engines=expected_engines,
original_gdef=original_gdef,
converted_gdef=gdef_to_verify,
expected_max_batch_sizes=self.ExpectedMaxBatchSizes(run_params),
default_max_batch_size=self.GetConversionParams(
run_params).max_batch_size,
)
def _VerifyGraphDef(self, run_params, original_gdef_or_saved_model_dir,
gdef_or_saved_model_dir_to_verify, graph_state):

View File

@ -125,7 +125,8 @@ class ExplicitBatchTest(TrtModeTestBase):
def GetConversionParams(self, run_params):
"""Return a TrtConversionParams for test that enables explicit batch."""
return super(ExplicitBatchTest, self).GetConversionParams(run_params, False)
return super(ExplicitBatchTest, self).GetConversionParams(
run_params, implicit_batch=False)
def ExpectedEnginesToBuild(self, run_params):
"""Check that the expected engine is built.