Merge pull request from phillip-kravtsov:trt_remove_serialized_from_op

PiperOrigin-RevId: 260966558
TensorFlower Gardener 2019-07-31 11:47:01 -07:00
commit 68edd47a23
16 changed files with 322 additions and 344 deletions

View File

@ -97,10 +97,17 @@ cc_library(
":utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
"//tensorflow/core:core_cpu_lib_no_ops",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:stream_executor",
"//tensorflow/core:stream_executor_headers_lib",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/stream_executor/lib",
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
alwayslink = 1,
)
@ -168,8 +175,12 @@ tf_cuda_cc_test(
":trt_op_kernels",
":trt_op_libs",
":trt_resources",
":trt_conversion",
":utils",
"@com_google_googletest//:gtest",
"@com_google_absl//absl/strings",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
@ -248,6 +259,9 @@ tf_cuda_library(
":utils",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:framework_lite",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core:graph",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([":tensorrt_lib"]),
)
@ -318,11 +332,13 @@ tf_cuda_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices",

View File

@ -324,8 +324,6 @@ Status CreateTRTNode(const ConversionParams& params,
nvinfer1::IGpuAllocator* alloc,
std::vector<Node*>* engine_nodes) {
const auto& info = infos.at(pos);
std::vector<TensorShapeProto> output_shape_protos;
std::vector<TensorShapeProto> input_shape_protos;
std::vector<PartialTensorShape> input_shapes;
std::vector<NodeDefBuilder::NodeOut> inputs;
std::vector<Node*> input_nodes;
@ -359,25 +357,16 @@ Status CreateTRTNode(const ConversionParams& params,
} else {
// Data edges
if (!conn.is_input_edge) {
// Set the shapes and data types of output edge.
TensorShapeProto out_shape;
// shape of the output node inside segment
conn.inside_shape.AsProto(&out_shape);
if (output_shape_protos.size() <= conn.port_number) {
output_shape_protos.resize(conn.port_number + 1);
// Set the data types of output edge.
if (out_types.size() <= conn.port_number) {
out_types.resize(conn.port_number + 1);
}
output_shape_protos.at(conn.port_number) = out_shape;
out_types.at(conn.port_number) = conn.connection_type;
} else {
// Set the shapes and data types of input edge.
TensorShapeProto in_shape;
conn.outside_shape.AsProto(&in_shape);
if (input_shape_protos.size() <= conn.port_number) {
input_shape_protos.resize(conn.port_number + 1);
if (input_shapes.size() <= conn.port_number) {
input_shapes.resize(conn.port_number + 1);
}
input_shape_protos.at(conn.port_number) = in_shape;
input_shapes.at(conn.port_number) = conn.outside_shape;
// Shape must be fully defined (excluding batch dimension) for static
// mode.
@ -439,8 +428,6 @@ Status CreateTRTNode(const ConversionParams& params,
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string = string(static_cast<const char*>(engine_data->data()),
engine_data->size());
} else {
segment_string = info.segment_graph_def.SerializeAsString();
}
string prec_string;
@ -460,15 +447,13 @@ Status CreateTRTNode(const ConversionParams& params,
}
NodeDef trt_node;
NameAttrList function;
function.set_name(StrCat(info.engine_name, "_native_segment"));
Status status =
node_builder.Attr("input_shapes", input_shape_protos)
.Attr("output_shapes", output_shape_protos)
node_builder
.Attr("static_engine",
info.engine_type == EngineInfo::EngineType::TRTStatic)
.Attr("segment_funcdef_name",
params.use_function_backup
? StrCat(info.engine_name, "_native_segment")
: "")
.Attr("segment_func", function)
.Attr("serialized_segment", segment_string)
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", info.maximum_cached_engines)
@ -537,103 +522,27 @@ Status CreateTRTNode(const ConversionParams& params,
return Status::OK();
}
// Function to construct a funcdef from the segment and add it to the graph.
Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
const GraphDef& segment,
const string& engine_name) {
Graph sgraph(graph->flib_def());
GraphConstructorOptions gcopts;
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(gcopts, segment, &sgraph));
std::map<string, Node*> io_nodes;
int num_inputs = 0;
for (auto n : sgraph.op_nodes()) {
if (absl::StartsWith(n->name(), kInputPHName)) {
num_inputs++;
io_nodes.insert({n->name(), n});
} else if (absl::StartsWith(n->name(), kOutputPHName)) {
io_nodes.insert({n->name(), n});
}
}
for (int i = 0; i < num_inputs; ++i) {
auto name = StrCat(kInputPHName, i);
auto node = io_nodes[name];
NodeDef nd;
NodeDefBuilder node_builder(StrCat(name, "_Arg"),
FunctionLibraryDefinition::kArgOp);
VLOG(1) << "Adding " << StrCat(name, "_Arg");
TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
.Attr("index", i)
.Finalize(&nd));
Status s;
auto node_arg = sgraph.AddNode(nd, &s);
if (!s.ok()) {
LOG(ERROR) << "Couldn't add _Arg node for " << name;
}
for (auto edge : node->out_edges()) {
sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input());
VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0
<< " - > " << edge->dst()->name() << ":" << edge->dst_input();
if (!s.ok()) {
LOG(ERROR) << "Failed to update edge from " << node_arg->name()
<< " to " << edge->dst()->name() << ":" << edge->dst_input();
}
}
sgraph.RemoveNode(node);
}
for (int i = 0; i < io_nodes.size() - num_inputs; ++i) {
auto name = StrCat(kOutputPHName, i);
auto node = io_nodes[name];
NodeDef nd;
NodeDefBuilder node_builder(StrCat(name, "_Ret"),
FunctionLibraryDefinition::kRetOp);
auto edge = *(node->in_edges().begin());
NodeDefBuilder::NodeOut nout(edge->src()->name(), edge->src_output(),
edge->src()->output_type(edge->src_output()));
VLOG(1) << " input " << nout.node << ":" << nout.index
<< " dtype=" << DataTypeString(nout.data_type);
// nvcc complains that Input(<brace-enclosed initializer list>) is
// ambiguous, so do not use Input({nout}).
node_builder.Input(nout);
TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
.Attr("index", i)
.Finalize(&nd));
if (VLOG_IS_ON(3)) {
VLOG(3) << nd.DebugString();
}
Status s;
auto node_ret = sgraph.AddNode(nd, &s);
if (!s.ok()) {
LOG(ERROR) << "Couldn't add _Ret node for " << name;
}
VLOG(1) << "Update edge from " << edge->src()->name() << ":"
<< edge->src_output() << " - > " << node_ret->name() << ":" << 0;
sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0);
s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0);
if (!s.ok()) {
LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":"
<< edge->src_output() << " - > " << node_ret->name() << ":"
<< 0;
}
sgraph.RemoveNode(node);
}
FunctionDefLibrary fdeflib;
auto native_segment = fdeflib.add_function();
Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
Graph* graph, const string& engine_name) {
Graph segment_graph(graph->flib_def());
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
segment_graph_def, &segment_graph));
FunctionDefLibrary library;
auto segment_func = library.add_function();
TF_RETURN_IF_ERROR(GraphToFunctionDef(
sgraph, StrCat(engine_name, "_native_segment"), native_segment));
segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
// Set kIntsOnDeviceAttr to true so that all TRTEngineOp outputs are always on
// a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
// would be on host if the op generating the tensor has host memory tag set.
(*native_segment
->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
(*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
.set_b(true);
if (VLOG_IS_ON(7)) {
VLOG(7) << engine_name << " Function_Def ";
VLOG(7) << native_segment->DebugString();
VLOG(7) << segment_func->DebugString();
}
VLOG(1) << "Adding funcdef to graphlib";
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib));
VLOG(1) << "Adding funcdef " << segment_func->signature().name()
<< " to graphlib";
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
return Status::OK();
}
@ -690,16 +599,10 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
// Entry function from optimization pass.
Status ConvertAfterShapes(const ConversionParams& params) {
// Sanity checks.
if (params.precision_mode == TrtPrecisionMode::INT8) {
if (params.use_calibration && !params.use_function_backup) {
return errors::InvalidArgument(
"Calibration requires enabling fallback to TF function execution.");
}
} else {
if (params.use_calibration) {
return errors::InvalidArgument(
"Calibration with FP32 or FP16 is not supported.");
}
if (params.precision_mode != TrtPrecisionMode::INT8 &&
params.use_calibration) {
return errors::InvalidArgument(
"Calibration with FP32 or FP16 is not supported.");
}
// Convert graphdef to graph.
@ -760,14 +663,14 @@ Status ConvertAfterShapes(const ConversionParams& params) {
: EngineInfo::EngineType::TRTStatic);
curr_engine.use_calibration = params.use_calibration;
curr_engine.maximum_cached_engines = params.max_cached_engines;
if (params.use_function_backup) {
status = RegisterSegmentFunctionToFunctionLibrary(
&graph, curr_engine.segment_graph_def, curr_engine.engine_name);
if (!status.ok()) {
LOG(WARNING) << "Failed to register segment graphdef as a function "
<< t << ": " << status;
continue;
}
status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
&graph, curr_engine.engine_name);
if (!status.ok()) {
LOG(WARNING) << "Failed to register segment graphdef to the library " << t
<< ": " << status;
continue;
}
engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@ -46,8 +47,6 @@ struct ConversionParams {
// maximum number of cached engines
int max_cached_engines = 1;
bool use_calibration = true;
// Whether to use function fallback for TRTEngineOp
bool use_function_backup = true;
};
// Method to call from optimization pass
@ -57,6 +56,11 @@ Status ConvertAfterShapes(const ConversionParams& params);
std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
const EngineInfo& engine);
// Helper method that registers `segment_graph` as a function to the function
// library in `graph`.
Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
Graph* graph, const string& engine_name);
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
@ -75,18 +76,15 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
// TODO(aaroey): put these constants into some class.
const char* const kInputPHName = "TensorRTInputPH_";
const char* const kOutputPHName = "TensorRTOutputPH_";
namespace convert {
bool IsEngineInput(absl::string_view name) {
return absl::StartsWith(name, kInputPHName);
return absl::StartsWith(name, IONamePrefixes::kInputPHName);
}
bool IsEngineOutput(absl::string_view name) {
return absl::StartsWith(name, kOutputPHName);
return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
}
namespace convert {
using absl::StrAppend;
using absl::StrCat;
@ -5200,19 +5198,33 @@ Status ConvertGraphDefToEngine(
for (const auto& node_def : gdef.node()) {
string node_name = node_def.name();
VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) {
if (IsEngineInput(node_name)) {
int32 slot_number = -1;
if (!strings::safe_strto32( // non-absl ok
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
return errors::InvalidArgument("Failed to parse slot number from ",
node_name);
string type_key;
if (node_def.op() == "Placeholder") {
if (!strings::safe_strto32( // non-absl ok
node_name.c_str() + strlen(IONamePrefixes::kInputPHName),
&slot_number)) {
return errors::InvalidArgument("Failed to parse slot number from ",
node_name);
}
type_key = "dtype";
} else if (tensorflow::grappler::IsArg(node_def)) {
// Maybe remove the dependence on grappler and re-implement IsArg,
// which is pretty simple (but could change if new Arg nodes are added)
slot_number = node_def.attr().at("index").i();
type_key = "T";
} else {
return errors::InvalidArgument(
"Node ", node_name,
" with is neither Placeholder nor Arg, instead ", node_def.op());
}
nvinfer1::DataType trt_dtype;
nvinfer1::Dims trt_dims;
int batch_size = -1;
auto shape = input_shapes.at(slot_number);
auto status = ValidateTensorProperties(
node_def.op(), node_def.attr().at("dtype").type(), shape,
node_def.op(), node_def.attr().at(type_key).type(), shape,
/*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size);
if (!status.ok()) {
const string error_message =
@ -5228,12 +5240,23 @@ Status ConvertGraphDefToEngine(
// engines offline, by calling sess.run() and cache/serialize the engines.
TF_RETURN_IF_ERROR(
converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size));
} else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) {
} else if (IsEngineOutput(node_name)) {
int32 slot_number = -1;
if (!strings::safe_strto32( // non-absl ok
node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
return errors::InvalidArgument("Failed to parse slot number from ",
node_name);
if (node_def.op() == "Identity") {
if (!strings::safe_strto32( // non-absl ok
node_name.c_str() + strlen(IONamePrefixes::kOutputPHName),
&slot_number)) {
return errors::InvalidArgument("Failed to parse slot number from ",
node_name);
}
} else if (tensorflow::grappler::IsRetval(node_def)) {
slot_number = node_def.attr().at("index").i();
} else {
return errors::InvalidArgument(
"Node with name ", node_name,
" starting with IONamePrefixes::kOutputPHName is "
"neither Identity nor Retval, instead ",
node_def.op());
}
// Get output type that TensorFlow expects
TFAttrs attrs(node_def);
@ -5302,7 +5325,8 @@ Status ConvertSegmentToGraphDef(
// Add dummy input/output nodes to the segment graphdef.
if (connection.is_input_edge) {
const string node_name = StrCat(kInputPHName, connection.port_number);
const string node_name =
StrCat(IONamePrefixes::kInputPHName, connection.port_number);
if (marker_nodes.count(node_name)) {
VLOG(1) << "Reusing input " << node_name << " for the edge "
<< connection.outside_node_name << ":"
@ -5312,16 +5336,18 @@ Status ConvertSegmentToGraphDef(
}
marker_nodes.insert(node_name);
auto seg_node = segment_def->add_node();
NodeDefBuilder builder(node_name, "Placeholder");
NodeDefBuilder builder(node_name, "_Arg");
auto status = builder.Attr("shape", partial_shape)
.Attr("dtype", dtype)
.Attr("T", dtype)
.Attr("index", connection.port_number)
.Finalize(seg_node);
VLOG(1) << "Constructing input " << node_name << " for the edge "
<< connection.outside_node_name << ":" << connection.outside_port
<< " -> " << connection.inside_node_name << ":"
<< connection.inside_port;
} else {
const string node_name = StrCat(kOutputPHName, connection.port_number);
const string node_name =
StrCat(IONamePrefixes::kOutputPHName, connection.port_number);
if (marker_nodes.count(node_name)) {
VLOG(1) << "Reusing output " << node_name << " for the edge "
<< connection.inside_node_name << ":" << connection.inside_port
@ -5331,9 +5357,10 @@ Status ConvertSegmentToGraphDef(
}
marker_nodes.insert(node_name);
auto seg_node = segment_def->add_node();
NodeDefBuilder builder(node_name, "Identity");
NodeDefBuilder builder(node_name, "_Retval");
auto status =
builder
builder.Attr("T", dtype)
.Attr("index", connection.port_number)
.Input(connection.inside_node_name, connection.inside_port, dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing output " << node_name << " for the edge "
@ -5359,12 +5386,12 @@ Status ConvertSegmentToGraphDef(
if (connection.is_control_edge() || !connection.is_input_edge) continue;
auto snode =
segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
const string placeholder_name =
StrCat(kInputPHName, connection.port_number);
const string arg_name =
StrCat(IONamePrefixes::kInputPHName, connection.port_number);
VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
<< " from " << snode->input(connection.inside_port) << " to "
<< placeholder_name;
snode->set_input(connection.inside_port, placeholder_name);
<< arg_name;
snode->set_input(connection.inside_port, arg_name);
}
std::set<string> subgraph_node_names;
for (const Node* node : subgraph_nodes) {

View File

@ -37,8 +37,6 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
extern const char* const kInputPHName;
extern const char* const kOutputPHName;
namespace convert {
@ -119,8 +117,8 @@ struct EngineInfo {
bool use_calibration;
};
// Constructs a graphdef from the segment in the given graph. Adds placeholder
// nodes for input edges (InputPH_*) and identity nodes for output edges
// Constructs a graphdef from the segment in the given graph. Adds _Arg
// nodes for input edges (InputPH_*) and _Retval nodes for output edges
// (OutputPH_*). This function needs to be called before TensorRT nodes are
// inserted in order to correctly get sizes from the original graph.
//

View File

@ -1158,7 +1158,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
int batch_size = -1;
for (const NodeDef& node : gdef.node()) {
absl::string_view node_name(node.name());
if (absl::ConsumePrefix(&node_name, kInputPHName)) {
if (absl::ConsumePrefix(&node_name, IONamePrefixes::kInputPHName)) {
int port = -1;
EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
@ -1188,11 +1188,13 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT,
ops::Placeholder::Shape({1, 1}));
auto input =
ops::Placeholder(s.WithOpName(StrCat(IONamePrefixes::kInputPHName, 0)),
DT_FLOAT, ops::Placeholder::Shape({1, 1}));
auto output = ops::Identity(s.WithOpName("identity1"), input);
output = ops::Identity(s.WithOpName("identity2"), output);
output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output);
output = ops::Identity(s.WithOpName(StrCat(IONamePrefixes::kOutputPHName, 0)),
output);
// If the converter marks the input tensor as output tensor, the conversion
// below will fail with:
// > TensorRTOutputPH_0 cannot be both input and output

View File

@ -67,9 +67,6 @@ Status TRTOptimizationPass::Init(
if (params.count("use_calibration")) {
use_calibration_ = params.at("use_calibration").b();
}
if (params.count("use_function_backup")) {
use_function_backup_ = params.at("use_function_backup").b();
}
return Status::OK();
}
@ -258,7 +255,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
cp.is_dyn_op = is_dynamic_op_;
cp.max_cached_engines = max_cached_batches_;
cp.use_calibration = use_calibration_;
cp.use_function_backup = use_function_backup_;
auto status = ConvertAfterShapes(cp);
VLOG(1) << "Returning from " << name_;
return status;

View File

@ -40,8 +40,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
is_dynamic_op_(false),
max_cached_batches_(1),
max_workspace_size_bytes_(256LL << 20),
use_calibration_(true),
use_function_backup_(true) {
use_calibration_(true) {
VLOG(1) << "Constructing " << name_;
}
@ -71,8 +70,6 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
int64_t max_workspace_size_bytes_;
bool use_calibration_;
// Whether to allow TF function fallback path in TRTEngineOp.
bool use_function_backup_;
};
} // namespace convert

View File

@ -23,6 +23,12 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
class IONamePrefixes {
public:
static constexpr const char* const kInputPHName = "TensorRTInputPH_";
static constexpr const char* const kOutputPHName = "TensorRTOutputPH_";
};
template <typename T>
struct TrtDestroyer {
void operator()(T* t) {

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
@ -24,10 +25,15 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -53,6 +59,7 @@ using ::stream_executor::port::StatusOr;
// A helper class to call done() when destructed for asynchronous execution.
// Helps simultaneous execution of native and TRT engines.
class AsyncHelper : public core::RefCounted {
public:
AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}
@ -89,7 +96,10 @@ class TRTEngineOp : public AsyncOpKernel {
void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);
// Construct a function handle for executing native funcdef graph
Status ConstructFunctionHandle(OpKernelContext* ctx);
// These are the exact same function.
Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
const string& device_name);
// Execute replaced native segment as function Op.
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
@ -125,10 +135,8 @@ class TRTEngineOp : public AsyncOpKernel {
// serialized protobuf segment or trt engine depending on static_engine_ flag.
string serialized_segment_;
// Name of the function for TF native execution of the segment. If empty, it
// means TF native execution is not allowed, and if TRT engine fails to run
// an error will be returned.
string funcdef_name_;
// The function for TF native execution of the segment.
NameAttrList func_;
// GraphDef representation of the segment.
GraphDef segment_graph_;
@ -148,7 +156,7 @@ class TRTEngineOp : public AsyncOpKernel {
int64 workspace_size_;
mutex engine_mutex_;
FunctionLibraryRuntime::Handle native_func_;
FunctionLibraryRuntime::Handle func_handle_;
// The finalized calibrator for inference.
std::unique_ptr<TRTInt8Calibrator> calibrator_;
@ -177,23 +185,61 @@ void* GetTensorAddress(const Tensor* tensor_ptr) {
}
}
Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime* flib_runtime,
GraphDef* graph_def) {
const FunctionLibraryDefinition* flib_def =
flib_runtime->GetFunctionLibraryDefinition();
const FunctionBody* fbody;
fbody = flib_runtime->GetFunctionBody(handle);
if (!fbody) {
return errors::Internal(
"Function body is null when converting from FuncDef to GraphDef.");
}
std::unique_ptr<Graph> graph(new Graph(flib_def));
CopyGraph(*fbody->graph, graph.get());
auto replace_name = [](const char* const prefix, string* name) {
if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
name->replace(0, strlen(prefix), prefix);
return true;
}
return false;
};
graph->ToGraphDef(graph_def);
// GraphToFunctionDef() will convert all the node names to lowercase.
for (auto& node : *graph_def->mutable_node()) {
if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) {
if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) {
// Instantiation of the function will append _RetVal to the node name,
// need to remove it for backward compatibility.
const char* const suffix_to_remove = "_RetVal";
if (absl::EndsWith(node.name(), suffix_to_remove)) {
node.mutable_name()->erase(node.name().size() -
strlen(suffix_to_remove));
}
}
}
for (auto& input : *node.mutable_input()) {
if (!replace_name(IONamePrefixes::kInputPHName, &input)) {
replace_name(IONamePrefixes::kOutputPHName, &input);
}
}
}
return Status::OK();
}
Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
const string& device_name) {
VLOG(1) << "Constructing function handle";
auto lib = ctx->function_library();
if (lib == nullptr) {
return errors::Internal("Context function library is null");
}
auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
if (fdef == nullptr) {
return errors::Internal("Native FunctionDef ", funcdef_name_,
" can't be found in function library");
}
FunctionLibraryRuntime::InstantiateOptions inst_ops;
inst_ops.state_handle = "";
inst_ops.target = ctx->device()->name();
native_func_ = 0;
return lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), inst_ops,
&native_func_);
inst_ops.target = device_name;
return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
&func_handle_);
}
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
@ -204,15 +250,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
OP_REQUIRES_OK(context,
context->GetAttr("workspace_size_bytes", &workspace_size_));
OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
if (!static_engine_) {
OP_REQUIRES(context, segment_graph_.ParseFromString(serialized_segment_),
errors::InvalidArgument("Failed to parse segment graphdef!"));
VLOG(1) << "Size of serialized GraphDef: "
<< serialized_segment_.capacity();
string tmp;
// Swap with temporary empty string to deallocate the CPU memory.
serialized_segment_.swap(tmp);
}
VLOG(1) << "Constructing " << name();
string precision_string;
OP_REQUIRES_OK(context,
@ -220,12 +258,22 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
string calibration_data;
OP_REQUIRES_OK(context,
context->GetAttr("calibration_data", &calibration_data));
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
OP_REQUIRES(context, !func_.name().empty(),
errors::InvalidArgument(
"The TF function for the TRT segment could not be empty"));
OP_REQUIRES_OK(context,
TrtPrecisionModeFromName(precision_string, &precision_mode_));
OP_REQUIRES_OK(context,
context->GetAttr("use_calibration", &use_calibration_));
func_handle_ = kInvalidHandle;
if (!static_engine_) {
FunctionLibraryRuntime* lib = context->function_library();
OP_REQUIRES_OK(context,
ConstructFunctionHandle(lib, context->device()->name()));
OP_REQUIRES_OK(context,
FunctionDefToGraphDef(func_handle_, lib, &segment_graph_));
}
calibration_mode_ =
(use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
calibration_data.empty());
@ -233,20 +281,19 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
calibration_data.resize(0);
}
native_func_ = kInvalidHandle;
OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
&max_cached_engines_));
}
void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
AsyncHelper* helper) {
OP_REQUIRES_ASYNC(ctx, !funcdef_name_.empty(),
errors::Internal("Fallback path is disabled, for ", name()),
*helper);
std::vector<Tensor> inputs;
std::vector<Tensor>* outputs = new std::vector<Tensor>();
if (native_func_ == kInvalidHandle) {
OP_REQUIRES_OK_ASYNC(ctx, ConstructFunctionHandle(ctx), *helper);
if (func_handle_ == kInvalidHandle) {
OP_REQUIRES_OK_ASYNC(
ctx,
ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()),
*helper);
}
auto lib = ctx->function_library();
FunctionLibraryRuntime::Options opts;
@ -259,7 +306,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
}
helper->Ref(); // Increment count for calculating native graph
VLOG(1) << "Executing native segment: " << name();
lib->Run(opts, native_func_, inputs, outputs,
lib->Run(opts, func_handle_, inputs, outputs,
[this, ctx, outputs, helper](const Status& s) {
core::ScopedUnref sc(helper);
OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
@ -298,7 +345,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
const auto device_tensor =
calib_ctx->device_tensors_.at(i).AccessTensor(ctx);
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
input_data.emplace(StrCat(kInputPHName, i), data_address);
input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
}
VLOG(2) << "Filled map for sending";
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
@ -435,9 +482,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
// input.
const int num_batch = ctx->input(0).shape().dim_size(0);
const int num_binding = ctx->num_inputs() + ctx->num_outputs();
std::vector<void*> buffers(num_binding);
for (int i = 0; i < ctx->num_inputs(); i++) {
const string input_name = StrCat(kInputPHName, i);
const string input_name = StrCat(IONamePrefixes::kInputPHName, i);
const int binding_index = cuda_engine->getBindingIndex(input_name.c_str());
if (binding_index == -1) {
const string msg =
@ -479,7 +528,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
for (int i = 0; i < ctx->num_outputs(); i++) {
// Create an output tensor
const string output_name = StrCat(kOutputPHName, i);
const string output_name = StrCat(IONamePrefixes::kOutputPHName, i);
const int binding_index = cuda_engine->getBindingIndex(output_name.c_str());
Tensor* output_tensor = nullptr;
@ -713,7 +762,7 @@ Status TRTEngineOp::AllocateCalibrationResources(
"Unsupported data type encountered in input ", i);
}
cres->device_buffers_.emplace(
StrCat(kInputPHName, i),
StrCat(IONamePrefixes::kInputPHName, i),
std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
}
cres->calibrator_.reset(
@ -727,56 +776,55 @@ Status TRTEngineOp::AllocateCalibrationResources(
}
cache_res->Ref();
cres->thr_.reset(
new std::thread([this, cres, shapes, platform_gpu_id, cache_res]() {
core::ScopedUnref sc(cache_res);
cres->thr_.reset(new std::thread([this, cres, shapes, platform_gpu_id,
cache_res]() {
core::ScopedUnref sc(cache_res);
LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
<< ", Calibration Resource @ " << cres;
auto err = cudaSetDevice(platform_gpu_id);
if (err != cudaSuccess) {
// TODO(aaroey): should return error here.
LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
<< " in calibration thread";
}
std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
shapes.end());
// ConvertGraphDefToEngine() will try to build the engine. This thread
// will loop inside buildCudaEngine() consuming the calibration data
// that is set by the TF op, and drive the builder until calibrator
// returns false. Engine is discarded after calibration table is
// generated
//
// TODO(aaroey): maybe setting the max batch size using the python
// calibration wrapper class.
auto s = convert::ConvertGraphDefToEngine(
this->segment_graph_, TrtPrecisionMode::INT8,
cres->calibrator_->getBatchSize(), this->workspace_size_,
partial_shapes, &cache_res->GetLogger(),
cache_res->allocator_.get(), cres->calibrator_.get(),
&cres->engine_,
/*use_calibration=*/true,
/*convert_successfully=*/nullptr);
if (!s.ok()) {
LOG(ERROR) << "Calibration failed: " << s;
cres->calibrator_->setDone(); // Ignore further pushes
}
LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
<< ", Calibration Resource @ " << cres;
auto err = cudaSetDevice(platform_gpu_id);
if (err != cudaSuccess) {
// TODO(aaroey): should return error here.
LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
<< " in calibration thread";
}
std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
shapes.end());
// ConvertGraphDefToEngine() will try to build the engine. This thread
// will loop inside buildCudaEngine() consuming the calibration data
// that is set by the TF op, and drive the builder until calibrator
// returns false. Engine is discarded after calibration table is
// generated
//
// TODO(aaroey): maybe setting the max batch size using the python
// calibration wrapper class.
auto s = convert::ConvertGraphDefToEngine(
this->segment_graph_, TrtPrecisionMode::INT8,
cres->calibrator_->getBatchSize(), this->workspace_size_,
partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(),
cres->calibrator_.get(), &cres->engine_,
/*use_calibration=*/true,
/*convert_successfully=*/nullptr);
if (!s.ok()) {
LOG(ERROR) << "Calibration failed: " << s;
cres->calibrator_->setDone(); // Ignore further pushes
}
// Transfer the ownership of the engine to the engine cache, so we can
// dump it out during conversion for TF 2.0.
if (cache_res) {
mutex_lock lock(this->engine_mutex_);
cres->SetCalibrationTable();
this->calibrator_ = std::move(cres->calibrator_);
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
cres->engine_->createExecutionContext());
cache_res->cache_.emplace(
shapes, absl::make_unique<EngineContext>(
std::move(cres->engine_), std::move(exec_context)));
}
// Transfer the ownership of the engine to the engine cache, so we can
// dump it out during conversion for TF 2.0.
if (cache_res) {
mutex_lock lock(this->engine_mutex_);
cres->SetCalibrationTable();
this->calibrator_ = std::move(cres->calibrator_);
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
cres->engine_->createExecutionContext());
cache_res->cache_.emplace(
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
std::move(exec_context)));
}
VLOG(1) << "Calibration loop terminated " << this->name();
}));
VLOG(1) << "Calibration loop terminated " << this->name();
}));
VLOG(1) << "initialized calibrator resource";
return Status::OK();
}

View File

@ -22,10 +22,17 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
@ -38,6 +45,7 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
using ::absl::StrCat;
using ::testing::ElementsAre;
class TRTEngineOpTestBase : public OpsTestBase {
@ -49,25 +57,32 @@ class TRTEngineOpTestBase : public OpsTestBase {
// Create simple TF graph.
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
ops::Placeholder::Shape({-1, -1}));
auto feed = ops::_Arg(s.WithOpName("TensorRTInputPH_0"), dtype, 0);
auto add = ops::Add(s.WithOpName("add"), feed, feed);
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
ops::_Retval(s.WithOpName("TensorRTOutputPH_0"), add, 0);
// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
GraphDef graph_def;
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
Graph* graph = s.graph();
const char* op_name = "myop";
TF_ASSERT_OK(
convert::RegisterGraphToFunctionLibrary(graph_def, graph, op_name));
TF_ASSERT_OK(flib_def_->AddLibrary(graph->flib_def()));
PartialTensorShape shape({-1, -1});
// Create the op.
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
TF_ASSERT_OK(NodeDefBuilder("myop", "TRTEngineOp")
NameAttrList function;
function.set_name(StrCat(op_name, "_native_segment"));
TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
.Input(FakeInput(1, dtype))
.Attr("input_shapes", {shape})
.Attr("output_shapes", {shape})
.Attr("static_engine", false)
.Attr("segment_funcdef_name", "") // no native fallback
.Attr("serialized_segment", graph_def.SerializeAsString())
.Attr("segment_func", function)
.Attr("serialized_segment", "")
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", max_cached_engines_count)
.Attr("workspace_size_bytes", 1 << 20)
@ -75,7 +90,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
.Attr("use_calibration", false)
.Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(OpsTestBase::InitOp());
TF_ASSERT_OK(InitOpWithFunctionLibrary());
}
template <typename T>
@ -89,9 +104,20 @@ class TRTEngineOpTestBase : public OpsTestBase {
inputs_.clear();
gtl::STLDeleteElements(&tensors_);
}
private:
Status InitOpWithFunctionLibrary() {
OpKernel* kernel = nullptr;
Status status = CreateOpKernel(device_type_, device_, allocator(),
pflr_->GetFLR(device_->name()), node_def_,
TF_GRAPH_DEF_VERSION, &kernel);
kernel_ = std::unique_ptr<OpKernel>(kernel);
if (kernel_ != nullptr) input_types_ = kernel_->input_types();
return status;
}
};
TEST_F(TRTEngineOpTestBase, dynamic_shapes) {
TEST_F(TRTEngineOpTestBase, DynamicShapes) {
TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4);
// Execute the op with batch size > 1.

View File

@ -33,7 +33,7 @@ namespace tensorflow {
// key to cache the instantiated functions for different executor subgraphs.
REGISTER_OP("TRTEngineOp")
.Attr("serialized_segment: string")
.Attr("segment_funcdef_name: string")
.Attr("segment_func: func = {}")
.Attr("InT: list({int8,float16,float32,int32})")
.Attr("OutT: list({int8,float16,float32,int32})")
.Attr("max_cached_engines_count: int = 1")
@ -51,10 +51,11 @@ REGISTER_OP("TRTEngineOp")
// inference function as a workaround.
.SetShapeFn(shape_inference::UnknownShape)
// Deprecated attributes.
.Attr("segment_funcdef_name: string = ''")
.Attr("cached_engine_batches: list(int) >= 0 = []")
.Attr("fixed_input_size: bool = true")
.Attr("input_shapes: list(shape)")
.Attr("output_shapes: list(shape)")
.Attr("input_shapes: list(shape) = []")
.Attr("output_shapes: list(shape) = []")
.Attr("static_engine: bool = true");
} // namespace tensorflow
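The op registration above makes `segment_func` the canonical way to reference the native segment, while `segment_funcdef_name`, `input_shapes`, and `output_shapes` become deprecated attrs with defaults. As a minimal sketch (assuming `graph_def` is a GraphDef that already went through TF-TRT conversion; the variable name is illustrative), the new attr can be read back like this:

# Assumes graph_def is a converted GraphDef containing TRTEngineOp nodes.
for node in graph_def.node:
  if node.op == "TRTEngineOp":
    # The native segment is now referenced through the function-valued attr
    # "segment_func"; its FunctionDef lives in graph_def.library.
    segment_name = node.attr["segment_func"].func.name
    print(node.name + " -> " + segment_name)  # e.g. "<engine>_native_segment"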

View File

@ -153,8 +153,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
# runtime to allocate GPU memory.
max_workspace_size_bytes=1 << 28,
minimum_segment_size=2,
use_calibration=False,
use_function_backup=False)
use_calibration=False)
graph_def = converter.convert()
logging.info('Number of nodes after TF-TRT conversion: %d',
len(graph_def.node))

View File

@ -23,6 +23,7 @@ import errno
import gc
import itertools
import os
import re
import shutil
import tempfile
import warnings
@ -234,10 +235,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
is_dynamic_op=run_params.dynamic_engine,
maximum_cached_engines=1,
use_calibration=run_params.use_calibration,
use_function_backup=False,
max_batch_size=min(batch_list))
return conversion_params._replace(
use_function_backup=IsQuantizationWithCalibration(conversion_params))
return conversion_params
def ShouldRunTest(self, run_params):
"""Whether to run the test."""
@ -388,8 +387,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
minimum_segment_size=conversion_params.minimum_segment_size,
is_dynamic_op=conversion_params.is_dynamic_op,
maximum_cached_engines=conversion_params.maximum_cached_engines,
use_calibration=conversion_params.use_calibration,
use_function_backup=conversion_params.use_function_backup)
use_calibration=conversion_params.use_calibration)
def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
"""Return trt converted graphdef in INT8 mode."""
@ -558,21 +556,18 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
if node.op == "TRTEngineOp":
logging.info("Found TRTEngineOp: " + node.name)
num_engines += 1
segment_funcdef_name = node.attr["segment_funcdef_name"].s
segment_funcdef_name = node.attr["segment_func"].func.name
function_name = node.name + "_native_segment"
if IsQuantizationWithCalibration(run_params):
self.assertNotEmpty(segment_funcdef_name, node.name)
self.assertIn(function_name, functions)
else:
self.assertEmpty(segment_funcdef_name, node.name)
self.assertNotIn(function_name, functions)
is_dynamic_engine = not node.attr["static_engine"].b
self.assertNotEmpty(segment_funcdef_name, node.name)
self.assertIn(function_name, functions)
if not IsQuantizationWithCalibration(run_params) and not is_dynamic_engine:
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
self.assertIn(node.name, expected_engines)
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
self.assertEqual(
self._ToBytes(run_params.precision_mode),
node.attr["precision_mode"].s, node.name)
is_dynamic_engine = not node.attr["static_engine"].b
self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
node.name)
self.assertEqual(node.attr["use_calibration"].b,
@ -602,10 +597,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
]
for func in gdef_to_verify.library.function:
for node in func.node_def:
all_op_names.append(node.name)
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name):
for node in func.node_def:
all_op_names.append(node.name)
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
# Remove the function name prefix.
def _Canonicalize(names):
return set([self._ToString(name.split("/")[-1]) for name in names])

View File

@ -147,11 +147,6 @@ TrtConversionParams = collections.namedtuple(
# trained with fake quantization.
"use_calibration",
# If set to True, it will create a FunctionDef for each subgraph that is
# converted to TRT op, and if TRT ops fail to execute at runtime, it'll
# invoke that function as a fallback.
"use_function_backup",
# Max size for the input batch.
# This option is deprecated in TF 2.0.
"max_batch_size",
@ -165,7 +160,6 @@ DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams(
is_dynamic_op=False,
maximum_cached_engines=1,
use_calibration=True,
use_function_backup=True,
max_batch_size=1)
_TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
@ -272,8 +266,6 @@ def get_tensorrt_rewriter_config(
"maximum_cached_engines"].i = conversion_params.maximum_cached_engines
optimizer.parameter_map[
"use_calibration"].b = conversion_params.use_calibration
optimizer.parameter_map[
"use_function_backup"].b = conversion_params.use_function_backup
if is_v2:
# Static mode (a.k.a pre-generating TRT engines and make them node
@ -344,8 +336,7 @@ class TrtGraphConverter(object):
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
use_calibration=True,
use_function_backup=True):
use_calibration=True):
"""Initialize the converter.
Args:
@ -384,9 +375,6 @@ class TrtGraphConverter(object):
will occur. Please note that accuracy may be negatively affected if
there is a mismatch between which tensors TRT quantizes and which
tensors were trained with fake quantization.
use_function_backup: if set to True, it will create a FunctionDef for each
subgraph that is converted to TRT op, and if TRT ops fail to execute at
runtime, it'll invoke that function as a fallback.
Raises:
ValueError: if the combination of the parameters is invalid.
@ -424,12 +412,6 @@ class TrtGraphConverter(object):
"dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
is_dynamic_op = True
# TODO(laigd): consider provide a mechanism to remove the fallback path
# after calibration is done.
if self._need_calibration and not use_function_backup:
raise ValueError(
"Calibration requires enabling fallback to TF function execution.")
# TODO(laigd):
# - Verify in int8 mode that maximum_cached_engines is set properly.
# - If it fails to build the int8 engine it should return error.
@ -446,7 +428,6 @@ class TrtGraphConverter(object):
is_dynamic_op=is_dynamic_op,
maximum_cached_engines=maximum_cached_engines,
use_calibration=use_calibration,
use_function_backup=use_function_backup,
max_batch_size=max_batch_size)
_check_conversion_params(self._conversion_params)
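For users of the Python API, the visible change is that `use_function_backup` disappears from `TrtGraphConverter`; a fallback FunctionDef is now always registered for each engine. A minimal usage sketch under that assumption (the SavedModel path is hypothetical):

from tensorflow.python.compiler.tensorrt import trt_convert

converter = trt_convert.TrtGraphConverter(
    input_saved_model_dir="/tmp/my_saved_model",  # hypothetical path
    precision_mode=trt_convert.TrtPrecisionMode.FP16,
    minimum_segment_size=3,
    is_dynamic_op=True,
    maximum_cached_engines=1,
    use_calibration=False)  # note: no use_function_backup argument anymore
graph_def = converter.convert()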

View File

@ -205,8 +205,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
max_batch_size=1,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
use_function_backup=False):
maximum_cached_engines=1):
"""Helper method to convert a GraphDef or SavedModel using TF-TRT."""
converter = trt_convert.TrtGraphConverter(
input_saved_model_dir=input_saved_model_dir,
@ -220,8 +219,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
else trt_convert.TrtPrecisionMode.FP32),
minimum_segment_size=minimum_segment_size,
is_dynamic_op=is_dynamic_op,
maximum_cached_engines=maximum_cached_engines,
use_function_backup=use_function_backup)
maximum_cached_engines=maximum_cached_engines)
output_graph_def = converter.convert()
if need_calibration:
@ -254,8 +252,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir,
need_calibration=need_calibration,
is_dynamic_op=is_dynamic_op,
use_function_backup=need_calibration)
is_dynamic_op=is_dynamic_op)
graph_defs_to_verify = [output_graph_def]
if output_saved_model_dir:
@ -316,8 +313,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
precision_mode=trt_convert.TrtPrecisionMode.FP32,
is_dynamic_op=True,
maximum_cached_engines=2,
use_function_backup=False))
maximum_cached_engines=2))
@test_util.run_v2_only
def testTrtGraphConverter_BasicConversion_v2(self):
@ -564,17 +560,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
def _TestRun(self,
sess,
batch_size,
use_function_backup=False,
expect_engine_is_run=True):
try:
result = sess.run(
"output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
self.assertAllEqual([[[4.0]]] * batch_size, result)
except errors.OpError as e:
# This should happen only when fallback path is disabled and TRT engine
# fails to run.
self.assertTrue(not use_function_backup and not expect_engine_is_run)
self.assertIn("Fallback path is disabled, for TRTEngineOp_0", str(e))
result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
self.assertAllEqual([[[4.0]]] * batch_size, result)
@test_util.deprecated_graph_mode_only
def testTrtGraphConverter_MinimumSegmentSize(self):
@ -604,8 +592,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir,
is_dynamic_op=True,
maximum_cached_engines=2,
use_function_backup=False) # Disallow fallback.
maximum_cached_engines=2)
# Test the output GraphDef.
with ops.Graph().as_default():
@ -631,7 +618,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
# the max, it should evict an old engine and create a new one.
self._TestRun(sess, 3)
def _TestStaticOp(self, use_function_backup):
def _TestStaticOp(self):
if not is_tensorrt_enabled():
return
@ -641,8 +628,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
output_graph_def = self._ConvertGraph(
input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir,
maximum_cached_engines=2, # This is noop, added just for testing.
use_function_backup=use_function_backup)
maximum_cached_engines=2)
# Test the output GraphDef.
with ops.Graph().as_default():
@ -653,14 +639,12 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
self._TestRun(
sess,
1,
use_function_backup=use_function_backup,
expect_engine_is_run=True)
# Run with batch size 2, which exceed the max_batch_size, it should try
# to fall back to TF function.
self._TestRun(
sess,
2,
use_function_backup=use_function_backup,
expect_engine_is_run=False)
# Test the output SavedModel
@ -672,23 +656,17 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
self._TestRun(
sess,
1,
use_function_backup=use_function_backup,
expect_engine_is_run=True)
# Run with batch size 2, which exceed the max_batch_size, it should try
# to fall back to TF function.
self._TestRun(
sess,
2,
use_function_backup=use_function_backup,
expect_engine_is_run=False)
@test_util.deprecated_graph_mode_only
def testTrtGraphConverter_StaticOp_NoFallback(self):
self._TestStaticOp(use_function_backup=False)
@test_util.deprecated_graph_mode_only
def testTrtGraphConverter_StaticOp_WithFallback(self):
self._TestStaticOp(use_function_backup=True)
def testTrtGraphConverter_StaticOp(self):
self._TestStaticOp()
if __name__ == "__main__":