Merge pull request #30789 from phillip-kravtsov:trt_remove_serialized_from_op
PiperOrigin-RevId: 260966558
This commit is contained in:
commit
68edd47a23
tensorflow
compiler/tf2tensorrt
BUILD
convert
convert_graph.ccconvert_graph.hconvert_nodes.ccconvert_nodes.hconvert_nodes_test.cctrt_optimization_pass.cctrt_optimization_pass.hutils.h
kernels
ops
python/compiler/tensorrt
@ -97,10 +97,17 @@ cc_library(
|
||||
":utils",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/core:core_cpu_lib_no_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_headers_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/core:stream_executor_headers_lib",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -168,8 +175,12 @@ tf_cuda_cc_test(
|
||||
":trt_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_resources",
|
||||
":trt_conversion",
|
||||
":utils",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core:framework",
|
||||
@ -248,6 +259,9 @@ tf_cuda_library(
|
||||
":utils",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
@ -318,11 +332,13 @@ tf_cuda_library(
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:devices",
|
||||
|
@ -324,8 +324,6 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
nvinfer1::IGpuAllocator* alloc,
|
||||
std::vector<Node*>* engine_nodes) {
|
||||
const auto& info = infos.at(pos);
|
||||
std::vector<TensorShapeProto> output_shape_protos;
|
||||
std::vector<TensorShapeProto> input_shape_protos;
|
||||
std::vector<PartialTensorShape> input_shapes;
|
||||
std::vector<NodeDefBuilder::NodeOut> inputs;
|
||||
std::vector<Node*> input_nodes;
|
||||
@ -359,25 +357,16 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
} else {
|
||||
// Data edges
|
||||
if (!conn.is_input_edge) {
|
||||
// Set the shapes and data types of output edge.
|
||||
TensorShapeProto out_shape;
|
||||
// shape of the output node inside segment
|
||||
conn.inside_shape.AsProto(&out_shape);
|
||||
if (output_shape_protos.size() <= conn.port_number) {
|
||||
output_shape_protos.resize(conn.port_number + 1);
|
||||
// Set the data types of output edge.
|
||||
if (out_types.size() <= conn.port_number) {
|
||||
out_types.resize(conn.port_number + 1);
|
||||
}
|
||||
output_shape_protos.at(conn.port_number) = out_shape;
|
||||
out_types.at(conn.port_number) = conn.connection_type;
|
||||
} else {
|
||||
// Set the shapes and data types of input edge.
|
||||
TensorShapeProto in_shape;
|
||||
conn.outside_shape.AsProto(&in_shape);
|
||||
if (input_shape_protos.size() <= conn.port_number) {
|
||||
input_shape_protos.resize(conn.port_number + 1);
|
||||
if (input_shapes.size() <= conn.port_number) {
|
||||
input_shapes.resize(conn.port_number + 1);
|
||||
}
|
||||
input_shape_protos.at(conn.port_number) = in_shape;
|
||||
input_shapes.at(conn.port_number) = conn.outside_shape;
|
||||
// Shape must be fully defined (excluding batch dimension) for static
|
||||
// mode.
|
||||
@ -439,8 +428,6 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
|
||||
segment_string = string(static_cast<const char*>(engine_data->data()),
|
||||
engine_data->size());
|
||||
} else {
|
||||
segment_string = info.segment_graph_def.SerializeAsString();
|
||||
}
|
||||
|
||||
string prec_string;
|
||||
@ -460,15 +447,13 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
}
|
||||
|
||||
NodeDef trt_node;
|
||||
NameAttrList function;
|
||||
function.set_name(StrCat(info.engine_name, "_native_segment"));
|
||||
Status status =
|
||||
node_builder.Attr("input_shapes", input_shape_protos)
|
||||
.Attr("output_shapes", output_shape_protos)
|
||||
node_builder
|
||||
.Attr("static_engine",
|
||||
info.engine_type == EngineInfo::EngineType::TRTStatic)
|
||||
.Attr("segment_funcdef_name",
|
||||
params.use_function_backup
|
||||
? StrCat(info.engine_name, "_native_segment")
|
||||
: "")
|
||||
.Attr("segment_func", function)
|
||||
.Attr("serialized_segment", segment_string)
|
||||
.Attr("calibration_data", "")
|
||||
.Attr("max_cached_engines_count", info.maximum_cached_engines)
|
||||
@ -537,103 +522,27 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to construct a funcdef from the segment and add it to the graph.
|
||||
Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
|
||||
const GraphDef& segment,
|
||||
const string& engine_name) {
|
||||
Graph sgraph(graph->flib_def());
|
||||
GraphConstructorOptions gcopts;
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(gcopts, segment, &sgraph));
|
||||
std::map<string, Node*> io_nodes;
|
||||
int num_inputs = 0;
|
||||
for (auto n : sgraph.op_nodes()) {
|
||||
if (absl::StartsWith(n->name(), kInputPHName)) {
|
||||
num_inputs++;
|
||||
io_nodes.insert({n->name(), n});
|
||||
} else if (absl::StartsWith(n->name(), kOutputPHName)) {
|
||||
io_nodes.insert({n->name(), n});
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto name = StrCat(kInputPHName, i);
|
||||
auto node = io_nodes[name];
|
||||
NodeDef nd;
|
||||
NodeDefBuilder node_builder(StrCat(name, "_Arg"),
|
||||
FunctionLibraryDefinition::kArgOp);
|
||||
VLOG(1) << "Adding " << StrCat(name, "_Arg");
|
||||
TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
|
||||
.Attr("index", i)
|
||||
.Finalize(&nd));
|
||||
Status s;
|
||||
auto node_arg = sgraph.AddNode(nd, &s);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Couldn't add _Arg node for " << name;
|
||||
}
|
||||
for (auto edge : node->out_edges()) {
|
||||
sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input());
|
||||
VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0
|
||||
<< " - > " << edge->dst()->name() << ":" << edge->dst_input();
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Failed to update edge from " << node_arg->name()
|
||||
<< " to " << edge->dst()->name() << ":" << edge->dst_input();
|
||||
}
|
||||
}
|
||||
sgraph.RemoveNode(node);
|
||||
}
|
||||
|
||||
for (int i = 0; i < io_nodes.size() - num_inputs; ++i) {
|
||||
auto name = StrCat(kOutputPHName, i);
|
||||
auto node = io_nodes[name];
|
||||
NodeDef nd;
|
||||
NodeDefBuilder node_builder(StrCat(name, "_Ret"),
|
||||
FunctionLibraryDefinition::kRetOp);
|
||||
auto edge = *(node->in_edges().begin());
|
||||
NodeDefBuilder::NodeOut nout(edge->src()->name(), edge->src_output(),
|
||||
edge->src()->output_type(edge->src_output()));
|
||||
VLOG(1) << " input " << nout.node << ":" << nout.index
|
||||
<< " dtype=" << DataTypeString(nout.data_type);
|
||||
// nvcc complains that Input(<brace-enclosed initializer list>) is
|
||||
// ambiguous, so do not use Input({nout}).
|
||||
node_builder.Input(nout);
|
||||
TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
|
||||
.Attr("index", i)
|
||||
.Finalize(&nd));
|
||||
if (VLOG_IS_ON(3)) {
|
||||
VLOG(3) << nd.DebugString();
|
||||
}
|
||||
Status s;
|
||||
auto node_ret = sgraph.AddNode(nd, &s);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Couldn't add _Ret node for " << name;
|
||||
}
|
||||
VLOG(1) << "Update edge from " << edge->src()->name() << ":"
|
||||
<< edge->src_output() << " - > " << node_ret->name() << ":" << 0;
|
||||
sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0);
|
||||
s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":"
|
||||
<< edge->src_output() << " - > " << node_ret->name() << ":"
|
||||
<< 0;
|
||||
}
|
||||
sgraph.RemoveNode(node);
|
||||
}
|
||||
FunctionDefLibrary fdeflib;
|
||||
auto native_segment = fdeflib.add_function();
|
||||
Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
|
||||
Graph* graph, const string& engine_name) {
|
||||
Graph segment_graph(graph->flib_def());
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
|
||||
segment_graph_def, &segment_graph));
|
||||
FunctionDefLibrary library;
|
||||
auto segment_func = library.add_function();
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(
|
||||
sgraph, StrCat(engine_name, "_native_segment"), native_segment));
|
||||
segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
|
||||
// Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
|
||||
// a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
|
||||
// would be on host if the op generating the tensor has host memory tag set.
|
||||
(*native_segment
|
||||
->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
|
||||
(*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
|
||||
.set_b(true);
|
||||
if (VLOG_IS_ON(7)) {
|
||||
VLOG(7) << engine_name << " Function_Def ";
|
||||
VLOG(7) << native_segment->DebugString();
|
||||
VLOG(7) << segment_func->DebugString();
|
||||
}
|
||||
VLOG(1) << "Adding funcdef to graphlib";
|
||||
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib));
|
||||
VLOG(1) << "Adding funcdef " << segment_func->signature().name()
|
||||
<< " to graphlib";
|
||||
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -690,16 +599,10 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
|
||||
// Entry function from optimization pass.
|
||||
Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
// Sanity checks.
|
||||
if (params.precision_mode == TrtPrecisionMode::INT8) {
|
||||
if (params.use_calibration && !params.use_function_backup) {
|
||||
return errors::InvalidArgument(
|
||||
"Calibration requires enabling fallback to TF function execution.");
|
||||
}
|
||||
} else {
|
||||
if (params.use_calibration) {
|
||||
return errors::InvalidArgument(
|
||||
"Calibration with FP32 or FP16 is not supported.");
|
||||
}
|
||||
if (params.precision_mode != TrtPrecisionMode::INT8 &&
|
||||
params.use_calibration) {
|
||||
return errors::InvalidArgument(
|
||||
"Calibration with FP32 or FP16 is not supported.");
|
||||
}
|
||||
|
||||
// Convert graphdef to graph.
|
||||
@ -760,14 +663,14 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
: EngineInfo::EngineType::TRTStatic);
|
||||
curr_engine.use_calibration = params.use_calibration;
|
||||
curr_engine.maximum_cached_engines = params.max_cached_engines;
|
||||
if (params.use_function_backup) {
|
||||
status = RegisterSegmentFunctionToFunctionLibrary(
|
||||
&graph, curr_engine.segment_graph_def, curr_engine.engine_name);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "Failed to register segment graphdef as a function "
|
||||
<< t << ": " << status;
|
||||
continue;
|
||||
}
|
||||
|
||||
status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
|
||||
&graph, curr_engine.engine_name);
|
||||
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "Failed to register segment graphdef to the library " << t
|
||||
<< ": " << status;
|
||||
continue;
|
||||
}
|
||||
|
||||
engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
@ -46,8 +47,6 @@ struct ConversionParams {
|
||||
// maximum number of cached engines
|
||||
int max_cached_engines = 1;
|
||||
bool use_calibration = true;
|
||||
// Whether to use function fallback for TRTEngineOp
|
||||
bool use_function_backup = true;
|
||||
};
|
||||
|
||||
// Method to call from optimization pass
|
||||
@ -57,6 +56,11 @@ Status ConvertAfterShapes(const ConversionParams& params);
|
||||
std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
|
||||
const EngineInfo& engine);
|
||||
|
||||
// Helper method that registers `segment_graph` as a function to the function
|
||||
// library in `graph`.
|
||||
Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
|
||||
Graph* graph, const string& engine_name);
|
||||
|
||||
} // namespace convert
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
@ -75,18 +76,15 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
// TODO(aaroey): put these constants into some class.
|
||||
const char* const kInputPHName = "TensorRTInputPH_";
|
||||
const char* const kOutputPHName = "TensorRTOutputPH_";
|
||||
namespace convert {
|
||||
|
||||
bool IsEngineInput(absl::string_view name) {
|
||||
return absl::StartsWith(name, kInputPHName);
|
||||
return absl::StartsWith(name, IONamePrefixes::kInputPHName);
|
||||
}
|
||||
bool IsEngineOutput(absl::string_view name) {
|
||||
return absl::StartsWith(name, kOutputPHName);
|
||||
return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
|
||||
}
|
||||
|
||||
namespace convert {
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
||||
@ -5200,19 +5198,33 @@ Status ConvertGraphDefToEngine(
|
||||
for (const auto& node_def : gdef.node()) {
|
||||
string node_name = node_def.name();
|
||||
VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
|
||||
if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) {
|
||||
if (IsEngineInput(node_name)) {
|
||||
int32 slot_number = -1;
|
||||
if (!strings::safe_strto32( // non-absl ok
|
||||
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
|
||||
return errors::InvalidArgument("Failed to parse slot number from ",
|
||||
node_name);
|
||||
string type_key;
|
||||
if (node_def.op() == "Placeholder") {
|
||||
if (!strings::safe_strto32( // non-absl ok
|
||||
node_name.c_str() + strlen(IONamePrefixes::kInputPHName),
|
||||
&slot_number)) {
|
||||
return errors::InvalidArgument("Failed to parse slot number from ",
|
||||
node_name);
|
||||
}
|
||||
type_key = "dtype";
|
||||
} else if (tensorflow::grappler::IsArg(node_def)) {
|
||||
// Maybe remove the dependence on grappler and re-implement IsArg,
|
||||
// which is pretty simple (but could change if new Arg nodes are added)
|
||||
slot_number = node_def.attr().at("index").i();
|
||||
type_key = "T";
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"Node ", node_name,
|
||||
" with is neither Placeholder nor Arg, instead ", node_def.op());
|
||||
}
|
||||
nvinfer1::DataType trt_dtype;
|
||||
nvinfer1::Dims trt_dims;
|
||||
int batch_size = -1;
|
||||
auto shape = input_shapes.at(slot_number);
|
||||
auto status = ValidateTensorProperties(
|
||||
node_def.op(), node_def.attr().at("dtype").type(), shape,
|
||||
node_def.op(), node_def.attr().at(type_key).type(), shape,
|
||||
/*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size);
|
||||
if (!status.ok()) {
|
||||
const string error_message =
|
||||
@ -5228,12 +5240,23 @@ Status ConvertGraphDefToEngine(
|
||||
// engines offline, by calling sess.run() and cache/serialize the engines.
|
||||
TF_RETURN_IF_ERROR(
|
||||
converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size));
|
||||
} else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) {
|
||||
} else if (IsEngineOutput(node_name)) {
|
||||
int32 slot_number = -1;
|
||||
if (!strings::safe_strto32( // non-absl ok
|
||||
node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
|
||||
return errors::InvalidArgument("Failed to parse slot number from ",
|
||||
node_name);
|
||||
if (node_def.op() == "Identity") {
|
||||
if (!strings::safe_strto32( // non-absl ok
|
||||
node_name.c_str() + strlen(IONamePrefixes::kOutputPHName),
|
||||
&slot_number)) {
|
||||
return errors::InvalidArgument("Failed to parse slot number from ",
|
||||
node_name);
|
||||
}
|
||||
} else if (tensorflow::grappler::IsRetval(node_def)) {
|
||||
slot_number = node_def.attr().at("index").i();
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"Node with name ", node_name,
|
||||
" starting with IONamePrefixes::kOutputPHName is "
|
||||
"neither Identity nor Retval, instead ",
|
||||
node_def.op());
|
||||
}
|
||||
// Get output type that TensorFlow expects
|
||||
TFAttrs attrs(node_def);
|
||||
@ -5302,7 +5325,8 @@ Status ConvertSegmentToGraphDef(
|
||||
|
||||
// Add dummy input/output nodes to the segment graphdef.
|
||||
if (connection.is_input_edge) {
|
||||
const string node_name = StrCat(kInputPHName, connection.port_number);
|
||||
const string node_name =
|
||||
StrCat(IONamePrefixes::kInputPHName, connection.port_number);
|
||||
if (marker_nodes.count(node_name)) {
|
||||
VLOG(1) << "Reusing input " << node_name << " for the edge "
|
||||
<< connection.outside_node_name << ":"
|
||||
@ -5312,16 +5336,18 @@ Status ConvertSegmentToGraphDef(
|
||||
}
|
||||
marker_nodes.insert(node_name);
|
||||
auto seg_node = segment_def->add_node();
|
||||
NodeDefBuilder builder(node_name, "Placeholder");
|
||||
NodeDefBuilder builder(node_name, "_Arg");
|
||||
auto status = builder.Attr("shape", partial_shape)
|
||||
.Attr("dtype", dtype)
|
||||
.Attr("T", dtype)
|
||||
.Attr("index", connection.port_number)
|
||||
.Finalize(seg_node);
|
||||
VLOG(1) << "Constructing input " << node_name << " for the edge "
|
||||
<< connection.outside_node_name << ":" << connection.outside_port
|
||||
<< " -> " << connection.inside_node_name << ":"
|
||||
<< connection.inside_port;
|
||||
} else {
|
||||
const string node_name = StrCat(kOutputPHName, connection.port_number);
|
||||
const string node_name =
|
||||
StrCat(IONamePrefixes::kOutputPHName, connection.port_number);
|
||||
if (marker_nodes.count(node_name)) {
|
||||
VLOG(1) << "Reusing output " << node_name << " for the edge "
|
||||
<< connection.inside_node_name << ":" << connection.inside_port
|
||||
@ -5331,9 +5357,10 @@ Status ConvertSegmentToGraphDef(
|
||||
}
|
||||
marker_nodes.insert(node_name);
|
||||
auto seg_node = segment_def->add_node();
|
||||
NodeDefBuilder builder(node_name, "Identity");
|
||||
NodeDefBuilder builder(node_name, "_Retval");
|
||||
auto status =
|
||||
builder
|
||||
builder.Attr("T", dtype)
|
||||
.Attr("index", connection.port_number)
|
||||
.Input(connection.inside_node_name, connection.inside_port, dtype)
|
||||
.Finalize(seg_node);
|
||||
VLOG(1) << "Constructing output " << node_name << " for the edge "
|
||||
@ -5359,12 +5386,12 @@ Status ConvertSegmentToGraphDef(
|
||||
if (connection.is_control_edge() || !connection.is_input_edge) continue;
|
||||
auto snode =
|
||||
segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
|
||||
const string placeholder_name =
|
||||
StrCat(kInputPHName, connection.port_number);
|
||||
const string arg_name =
|
||||
StrCat(IONamePrefixes::kInputPHName, connection.port_number);
|
||||
VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
|
||||
<< " from " << snode->input(connection.inside_port) << " to "
|
||||
<< placeholder_name;
|
||||
snode->set_input(connection.inside_port, placeholder_name);
|
||||
<< arg_name;
|
||||
snode->set_input(connection.inside_port, arg_name);
|
||||
}
|
||||
std::set<string> subgraph_node_names;
|
||||
for (const Node* node : subgraph_nodes) {
|
||||
|
@ -37,8 +37,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
extern const char* const kInputPHName;
|
||||
extern const char* const kOutputPHName;
|
||||
|
||||
namespace convert {
|
||||
|
||||
@ -119,8 +117,8 @@ struct EngineInfo {
|
||||
bool use_calibration;
|
||||
};
|
||||
|
||||
// Constructs a graphdef from the segment in the given graph. Adds placeholder
|
||||
// nodes for input edges (InputPH_*) and identity nodes for output edges
|
||||
// Constructs a graphdef from the segment in the given graph. Adds _Arg
|
||||
// nodes for input edges (InputPH_*) and _Retval nodes for output edges
|
||||
// (OutputPH_*). This function needs to be called before TensorRT nodes
|
||||
// inserted in order to correctly get sizes from the original graph.
|
||||
//
|
||||
|
@ -1158,7 +1158,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
|
||||
int batch_size = -1;
|
||||
for (const NodeDef& node : gdef.node()) {
|
||||
absl::string_view node_name(node.name());
|
||||
if (absl::ConsumePrefix(&node_name, kInputPHName)) {
|
||||
if (absl::ConsumePrefix(&node_name, IONamePrefixes::kInputPHName)) {
|
||||
int port = -1;
|
||||
EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
|
||||
if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
|
||||
@ -1188,11 +1188,13 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
|
||||
|
||||
TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT,
|
||||
ops::Placeholder::Shape({1, 1}));
|
||||
auto input =
|
||||
ops::Placeholder(s.WithOpName(StrCat(IONamePrefixes::kInputPHName, 0)),
|
||||
DT_FLOAT, ops::Placeholder::Shape({1, 1}));
|
||||
auto output = ops::Identity(s.WithOpName("identity1"), input);
|
||||
output = ops::Identity(s.WithOpName("identity2"), output);
|
||||
output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output);
|
||||
output = ops::Identity(s.WithOpName(StrCat(IONamePrefixes::kOutputPHName, 0)),
|
||||
output);
|
||||
// If the converter marks the input tensor as output tensor, the conversion
|
||||
// below will fail with:
|
||||
// > TensorRTOutputPH_0 cannot be both input and output
|
||||
|
@ -67,9 +67,6 @@ Status TRTOptimizationPass::Init(
|
||||
if (params.count("use_calibration")) {
|
||||
use_calibration_ = params.at("use_calibration").b();
|
||||
}
|
||||
if (params.count("use_function_backup")) {
|
||||
use_function_backup_ = params.at("use_function_backup").b();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -258,7 +255,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
|
||||
cp.is_dyn_op = is_dynamic_op_;
|
||||
cp.max_cached_engines = max_cached_batches_;
|
||||
cp.use_calibration = use_calibration_;
|
||||
cp.use_function_backup = use_function_backup_;
|
||||
auto status = ConvertAfterShapes(cp);
|
||||
VLOG(1) << "Returning from " << name_;
|
||||
return status;
|
||||
|
@ -40,8 +40,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
|
||||
is_dynamic_op_(false),
|
||||
max_cached_batches_(1),
|
||||
max_workspace_size_bytes_(256LL << 20),
|
||||
use_calibration_(true),
|
||||
use_function_backup_(true) {
|
||||
use_calibration_(true) {
|
||||
VLOG(1) << "Constructing " << name_;
|
||||
}
|
||||
|
||||
@ -71,8 +70,6 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
|
||||
int64_t max_workspace_size_bytes_;
|
||||
bool use_calibration_;
|
||||
|
||||
// Whether to allow TF function fallback path in TRTEngineOp.
|
||||
bool use_function_backup_;
|
||||
};
|
||||
|
||||
} // namespace convert
|
||||
|
@ -23,6 +23,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
class IONamePrefixes {
|
||||
public:
|
||||
static constexpr const char* const kInputPHName = "TensorRTInputPH_";
|
||||
static constexpr const char* const kOutputPHName = "TensorRTOutputPH_";
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct TrtDestroyer {
|
||||
void operator()(T* t) {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
@ -24,10 +25,15 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -53,6 +59,7 @@ using ::stream_executor::port::StatusOr;
|
||||
|
||||
// A helper class to call done() when destructed for asynchronous execution.
|
||||
// Helps simultaneous execution of native and TRT engines.
|
||||
|
||||
class AsyncHelper : public core::RefCounted {
|
||||
public:
|
||||
AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}
|
||||
@ -89,7 +96,10 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);
|
||||
|
||||
// Construct a function handle for executing native funcdef graph
|
||||
Status ConstructFunctionHandle(OpKernelContext* ctx);
|
||||
// These are the exact same function.
|
||||
|
||||
Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
|
||||
const string& device_name);
|
||||
|
||||
// Execute replaced native segment as function Op.
|
||||
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
|
||||
@ -125,10 +135,8 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
// serialized protobuf segment or trt engine depending on static_engine_ flag.
|
||||
string serialized_segment_;
|
||||
|
||||
// Name of the function for TF native execution of the segment. If empty, it
|
||||
// means TF native execution is not allowed, and if TRT engine fails to run
|
||||
// an error will be returned.
|
||||
string funcdef_name_;
|
||||
// The function for TF native execution of the segment.
|
||||
NameAttrList func_;
|
||||
|
||||
// GraphDef representation of the segment.
|
||||
GraphDef segment_graph_;
|
||||
@ -148,7 +156,7 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
|
||||
int64 workspace_size_;
|
||||
mutex engine_mutex_;
|
||||
FunctionLibraryRuntime::Handle native_func_;
|
||||
FunctionLibraryRuntime::Handle func_handle_;
|
||||
|
||||
// The finalized calibrator for inference.
|
||||
std::unique_ptr<TRTInt8Calibrator> calibrator_;
|
||||
@ -177,23 +185,61 @@ void* GetTensorAddress(const Tensor* tensor_ptr) {
|
||||
}
|
||||
}
|
||||
|
||||
Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
|
||||
static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
|
||||
FunctionLibraryRuntime* flib_runtime,
|
||||
GraphDef* graph_def) {
|
||||
const FunctionLibraryDefinition* flib_def =
|
||||
flib_runtime->GetFunctionLibraryDefinition();
|
||||
const FunctionBody* fbody;
|
||||
fbody = flib_runtime->GetFunctionBody(handle);
|
||||
if (!fbody) {
|
||||
return errors::Internal(
|
||||
"Function body is null when converting from FuncDef to GraphDef.");
|
||||
}
|
||||
std::unique_ptr<Graph> graph(new Graph(flib_def));
|
||||
CopyGraph(*fbody->graph, graph.get());
|
||||
|
||||
auto replace_name = [](const char* const prefix, string* name) {
|
||||
if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
|
||||
name->replace(0, strlen(prefix), prefix);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
graph->ToGraphDef(graph_def);
|
||||
// GraphToFunctionDef() will convert all the node names to lowercase.
|
||||
for (auto& node : *graph_def->mutable_node()) {
|
||||
if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) {
|
||||
if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) {
|
||||
// Instantiation of the function will append _RetVal to the node name,
|
||||
// need to remove it for backward compatibility.
|
||||
const char* const suffix_to_remove = "_RetVal";
|
||||
if (absl::EndsWith(node.name(), suffix_to_remove)) {
|
||||
node.mutable_name()->erase(node.name().size() -
|
||||
strlen(suffix_to_remove));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& input : *node.mutable_input()) {
|
||||
if (!replace_name(IONamePrefixes::kInputPHName, &input)) {
|
||||
replace_name(IONamePrefixes::kOutputPHName, &input);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
|
||||
const string& device_name) {
|
||||
VLOG(1) << "Constructing function handle";
|
||||
auto lib = ctx->function_library();
|
||||
if (lib == nullptr) {
|
||||
return errors::Internal("Context function library is null");
|
||||
}
|
||||
auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
|
||||
if (fdef == nullptr) {
|
||||
return errors::Internal("Native FunctionDef ", funcdef_name_,
|
||||
" can't be found in function library");
|
||||
}
|
||||
FunctionLibraryRuntime::InstantiateOptions inst_ops;
|
||||
inst_ops.state_handle = "";
|
||||
inst_ops.target = ctx->device()->name();
|
||||
native_func_ = 0;
|
||||
return lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), inst_ops,
|
||||
&native_func_);
|
||||
inst_ops.target = device_name;
|
||||
return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
|
||||
&func_handle_);
|
||||
}
|
||||
|
||||
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
||||
@ -204,15 +250,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("workspace_size_bytes", &workspace_size_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
|
||||
if (!static_engine_) {
|
||||
OP_REQUIRES(context, segment_graph_.ParseFromString(serialized_segment_),
|
||||
errors::InvalidArgument("Failed to parse segment graphdef!"));
|
||||
VLOG(1) << "Size of serialized GraphDef: "
|
||||
<< serialized_segment_.capacity();
|
||||
string tmp;
|
||||
// Swap with temporary empty string to deallocate the CPU memory.
|
||||
serialized_segment_.swap(tmp);
|
||||
}
|
||||
|
||||
VLOG(1) << "Constructing " << name();
|
||||
string precision_string;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -220,12 +258,22 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
||||
string calibration_data;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("calibration_data", &calibration_data));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("segment_funcdef_name", &funcdef_name_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
|
||||
OP_REQUIRES(context, !func_.name().empty(),
|
||||
errors::InvalidArgument(
|
||||
"The TF function for the TRT segment could not be empty"));
|
||||
OP_REQUIRES_OK(context,
|
||||
TrtPrecisionModeFromName(precision_string, &precision_mode_));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("use_calibration", &use_calibration_));
|
||||
func_handle_ = kInvalidHandle;
|
||||
if (!static_engine_) {
|
||||
FunctionLibraryRuntime* lib = context->function_library();
|
||||
OP_REQUIRES_OK(context,
|
||||
ConstructFunctionHandle(lib, context->device()->name()));
|
||||
OP_REQUIRES_OK(context,
|
||||
FunctionDefToGraphDef(func_handle_, lib, &segment_graph_));
|
||||
}
|
||||
calibration_mode_ =
|
||||
(use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
|
||||
calibration_data.empty());
|
||||
@ -233,20 +281,19 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
||||
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
|
||||
calibration_data.resize(0);
|
||||
}
|
||||
native_func_ = kInvalidHandle;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
|
||||
&max_cached_engines_));
|
||||
}
|
||||
|
||||
void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
||||
AsyncHelper* helper) {
|
||||
OP_REQUIRES_ASYNC(ctx, !funcdef_name_.empty(),
|
||||
errors::Internal("Fallback path is disabled, for ", name()),
|
||||
*helper);
|
||||
std::vector<Tensor> inputs;
|
||||
std::vector<Tensor>* outputs = new std::vector<Tensor>();
|
||||
if (native_func_ == kInvalidHandle) {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, ConstructFunctionHandle(ctx), *helper);
|
||||
if (func_handle_ == kInvalidHandle) {
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()),
|
||||
*helper);
|
||||
}
|
||||
auto lib = ctx->function_library();
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
@ -259,7 +306,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
||||
}
|
||||
helper->Ref(); // Increment count for calculating native graph
|
||||
VLOG(1) << "Executing native segment: " << name();
|
||||
lib->Run(opts, native_func_, inputs, outputs,
|
||||
lib->Run(opts, func_handle_, inputs, outputs,
|
||||
[this, ctx, outputs, helper](const Status& s) {
|
||||
core::ScopedUnref sc(helper);
|
||||
OP_REQUIRES_OK_ASYNC(ctx, s, *helper);
|
||||
@ -298,7 +345,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
|
||||
const auto device_tensor =
|
||||
calib_ctx->device_tensors_.at(i).AccessTensor(ctx);
|
||||
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
|
||||
input_data.emplace(StrCat(kInputPHName, i), data_address);
|
||||
input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
|
||||
}
|
||||
VLOG(2) << "Filled map for sending";
|
||||
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
|
||||
@ -435,9 +482,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
// input.
|
||||
const int num_batch = ctx->input(0).shape().dim_size(0);
|
||||
const int num_binding = ctx->num_inputs() + ctx->num_outputs();
|
||||
|
||||
std::vector<void*> buffers(num_binding);
|
||||
|
||||
for (int i = 0; i < ctx->num_inputs(); i++) {
|
||||
const string input_name = StrCat(kInputPHName, i);
|
||||
const string input_name = StrCat(IONamePrefixes::kInputPHName, i);
|
||||
const int binding_index = cuda_engine->getBindingIndex(input_name.c_str());
|
||||
if (binding_index == -1) {
|
||||
const string msg =
|
||||
@ -479,7 +528,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
|
||||
for (int i = 0; i < ctx->num_outputs(); i++) {
|
||||
// Create an output tensor
|
||||
const string output_name = StrCat(kOutputPHName, i);
|
||||
const string output_name = StrCat(IONamePrefixes::kOutputPHName, i);
|
||||
const int binding_index = cuda_engine->getBindingIndex(output_name.c_str());
|
||||
Tensor* output_tensor = nullptr;
|
||||
|
||||
@ -713,7 +762,7 @@ Status TRTEngineOp::AllocateCalibrationResources(
|
||||
"Unsupported data type encountered in input ", i);
|
||||
}
|
||||
cres->device_buffers_.emplace(
|
||||
StrCat(kInputPHName, i),
|
||||
StrCat(IONamePrefixes::kInputPHName, i),
|
||||
std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
|
||||
}
|
||||
cres->calibrator_.reset(
|
||||
@ -727,56 +776,55 @@ Status TRTEngineOp::AllocateCalibrationResources(
|
||||
}
|
||||
|
||||
cache_res->Ref();
|
||||
cres->thr_.reset(
|
||||
new std::thread([this, cres, shapes, platform_gpu_id, cache_res]() {
|
||||
core::ScopedUnref sc(cache_res);
|
||||
cres->thr_.reset(new std::thread([this, cres, shapes, platform_gpu_id,
|
||||
cache_res]() {
|
||||
core::ScopedUnref sc(cache_res);
|
||||
|
||||
LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
|
||||
<< ", Calibration Resource @ " << cres;
|
||||
auto err = cudaSetDevice(platform_gpu_id);
|
||||
if (err != cudaSuccess) {
|
||||
// TODO(aaroey): should return error here.
|
||||
LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
|
||||
<< " in calibration thread";
|
||||
}
|
||||
std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
|
||||
shapes.end());
|
||||
// ConvertGraphDefToEngine() will try to build the engine. This thread
|
||||
// will loop inside buildCudaEngine() consuming the calibration data
|
||||
// that is set by the TF op, and drive the builder until calibrator
|
||||
// returns false. Engine is discarded after calibration table is
|
||||
// generated
|
||||
//
|
||||
// TODO(aaroey): maybe setting the max batch size using the python
|
||||
// calibration wrapper class.
|
||||
auto s = convert::ConvertGraphDefToEngine(
|
||||
this->segment_graph_, TrtPrecisionMode::INT8,
|
||||
cres->calibrator_->getBatchSize(), this->workspace_size_,
|
||||
partial_shapes, &cache_res->GetLogger(),
|
||||
cache_res->allocator_.get(), cres->calibrator_.get(),
|
||||
&cres->engine_,
|
||||
/*use_calibration=*/true,
|
||||
/*convert_successfully=*/nullptr);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Calibration failed: " << s;
|
||||
cres->calibrator_->setDone(); // Ignore further pushes
|
||||
}
|
||||
LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
|
||||
<< ", Calibration Resource @ " << cres;
|
||||
auto err = cudaSetDevice(platform_gpu_id);
|
||||
if (err != cudaSuccess) {
|
||||
// TODO(aaroey): should return error here.
|
||||
LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
|
||||
<< " in calibration thread";
|
||||
}
|
||||
std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
|
||||
shapes.end());
|
||||
// ConvertGraphDefToEngine() will try to build the engine. This thread
|
||||
// will loop inside buildCudaEngine() consuming the calibration data
|
||||
// that is set by the TF op, and drive the builder until calibrator
|
||||
// returns false. Engine is discarded after calibration table is
|
||||
// generated
|
||||
//
|
||||
// TODO(aaroey): maybe setting the max batch size using the python
|
||||
// calibration wrapper class.
|
||||
auto s = convert::ConvertGraphDefToEngine(
|
||||
this->segment_graph_, TrtPrecisionMode::INT8,
|
||||
cres->calibrator_->getBatchSize(), this->workspace_size_,
|
||||
partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(),
|
||||
cres->calibrator_.get(), &cres->engine_,
|
||||
/*use_calibration=*/true,
|
||||
/*convert_successfully=*/nullptr);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Calibration failed: " << s;
|
||||
cres->calibrator_->setDone(); // Ignore further pushes
|
||||
}
|
||||
|
||||
// Transfer the ownership of the engine to the engine cache, so we can
|
||||
// dump it out during conversion for TF 2.0.
|
||||
if (cache_res) {
|
||||
mutex_lock lock(this->engine_mutex_);
|
||||
cres->SetCalibrationTable();
|
||||
this->calibrator_ = std::move(cres->calibrator_);
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||
cres->engine_->createExecutionContext());
|
||||
cache_res->cache_.emplace(
|
||||
shapes, absl::make_unique<EngineContext>(
|
||||
std::move(cres->engine_), std::move(exec_context)));
|
||||
}
|
||||
// Transfer the ownership of the engine to the engine cache, so we can
|
||||
// dump it out during conversion for TF 2.0.
|
||||
if (cache_res) {
|
||||
mutex_lock lock(this->engine_mutex_);
|
||||
cres->SetCalibrationTable();
|
||||
this->calibrator_ = std::move(cres->calibrator_);
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||
cres->engine_->createExecutionContext());
|
||||
cache_res->cache_.emplace(
|
||||
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
|
||||
std::move(exec_context)));
|
||||
}
|
||||
|
||||
VLOG(1) << "Calibration loop terminated " << this->name();
|
||||
}));
|
||||
VLOG(1) << "Calibration loop terminated " << this->name();
|
||||
}));
|
||||
VLOG(1) << "initialized calibrator resource";
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -22,10 +22,17 @@ limitations under the License.
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
@ -38,6 +45,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
using ::absl::StrCat;
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
class TRTEngineOpTestBase : public OpsTestBase {
|
||||
@ -49,25 +57,32 @@ class TRTEngineOpTestBase : public OpsTestBase {
|
||||
|
||||
// Create simple TF graph.
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
|
||||
ops::Placeholder::Shape({-1, -1}));
|
||||
auto feed = ops::_Arg(s.WithOpName("TensorRTInputPH_0"), dtype, 0);
|
||||
auto add = ops::Add(s.WithOpName("add"), feed, feed);
|
||||
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
|
||||
ops::_Retval(s.WithOpName("TensorRTOutputPH_0"), add, 0);
|
||||
|
||||
// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
|
||||
Graph* graph = s.graph();
|
||||
const char* op_name = "myop";
|
||||
TF_ASSERT_OK(
|
||||
convert::RegisterGraphToFunctionLibrary(graph_def, graph, op_name));
|
||||
TF_ASSERT_OK(flib_def_->AddLibrary(graph->flib_def()));
|
||||
|
||||
PartialTensorShape shape({-1, -1});
|
||||
|
||||
// Create the op.
|
||||
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "TRTEngineOp")
|
||||
NameAttrList function;
|
||||
function.set_name(StrCat(op_name, "_native_segment"));
|
||||
TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
|
||||
.Input(FakeInput(1, dtype))
|
||||
.Attr("input_shapes", {shape})
|
||||
.Attr("output_shapes", {shape})
|
||||
.Attr("static_engine", false)
|
||||
.Attr("segment_funcdef_name", "") // no native fallback
|
||||
.Attr("serialized_segment", graph_def.SerializeAsString())
|
||||
.Attr("segment_func", function)
|
||||
.Attr("serialized_segment", "")
|
||||
.Attr("calibration_data", "")
|
||||
.Attr("max_cached_engines_count", max_cached_engines_count)
|
||||
.Attr("workspace_size_bytes", 1 << 20)
|
||||
@ -75,7 +90,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
|
||||
.Attr("use_calibration", false)
|
||||
.Attr("OutT", {dtype})
|
||||
.Finalize(OpsTestBase::node_def()));
|
||||
TF_ASSERT_OK(OpsTestBase::InitOp());
|
||||
TF_ASSERT_OK(InitOpWithFunctionLibrary());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -89,9 +104,20 @@ class TRTEngineOpTestBase : public OpsTestBase {
|
||||
inputs_.clear();
|
||||
gtl::STLDeleteElements(&tensors_);
|
||||
}
|
||||
|
||||
private:
|
||||
Status InitOpWithFunctionLibrary() {
|
||||
OpKernel* kernel = nullptr;
|
||||
Status status = CreateOpKernel(device_type_, device_, allocator(),
|
||||
pflr_->GetFLR(device_->name()), node_def_,
|
||||
TF_GRAPH_DEF_VERSION, &kernel);
|
||||
kernel_ = std::unique_ptr<OpKernel>(kernel);
|
||||
if (kernel_ != nullptr) input_types_ = kernel_->input_types();
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TRTEngineOpTestBase, dynamic_shapes) {
|
||||
TEST_F(TRTEngineOpTestBase, DynamicShapes) {
|
||||
TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4);
|
||||
|
||||
// Execute the op with batch size > 1.
|
||||
|
@ -33,7 +33,7 @@ namespace tensorflow {
|
||||
// key to cache the instantiated functions for different executor subgraphs.
|
||||
REGISTER_OP("TRTEngineOp")
|
||||
.Attr("serialized_segment: string")
|
||||
.Attr("segment_funcdef_name: string")
|
||||
.Attr("segment_func: func = {}")
|
||||
.Attr("InT: list({int8,float16,float32,int32})")
|
||||
.Attr("OutT: list({int8,float16,float32,int32})")
|
||||
.Attr("max_cached_engines_count: int = 1")
|
||||
@ -51,10 +51,11 @@ REGISTER_OP("TRTEngineOp")
|
||||
// inference function as a workaround.
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
// Deprecated attributes.
|
||||
.Attr("segment_funcdef_name: string = ''")
|
||||
.Attr("cached_engine_batches: list(int) >= 0 = []")
|
||||
.Attr("fixed_input_size: bool = true")
|
||||
.Attr("input_shapes: list(shape)")
|
||||
.Attr("output_shapes: list(shape)")
|
||||
.Attr("input_shapes: list(shape) = []")
|
||||
.Attr("output_shapes: list(shape) = []")
|
||||
.Attr("static_engine: bool = true");
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -153,8 +153,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
|
||||
# runtime to allocate GPU memory.
|
||||
max_workspace_size_bytes=1 << 28,
|
||||
minimum_segment_size=2,
|
||||
use_calibration=False,
|
||||
use_function_backup=False)
|
||||
use_calibration=False)
|
||||
graph_def = converter.convert()
|
||||
logging.info('Number of nodes after TF-TRT conversion: %d',
|
||||
len(graph_def.node))
|
||||
|
@ -23,6 +23,7 @@ import errno
|
||||
import gc
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
@ -234,10 +235,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
is_dynamic_op=run_params.dynamic_engine,
|
||||
maximum_cached_engines=1,
|
||||
use_calibration=run_params.use_calibration,
|
||||
use_function_backup=False,
|
||||
max_batch_size=min(batch_list))
|
||||
return conversion_params._replace(
|
||||
use_function_backup=IsQuantizationWithCalibration(conversion_params))
|
||||
return conversion_params
|
||||
|
||||
def ShouldRunTest(self, run_params):
|
||||
"""Whether to run the test."""
|
||||
@ -388,8 +387,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
minimum_segment_size=conversion_params.minimum_segment_size,
|
||||
is_dynamic_op=conversion_params.is_dynamic_op,
|
||||
maximum_cached_engines=conversion_params.maximum_cached_engines,
|
||||
use_calibration=conversion_params.use_calibration,
|
||||
use_function_backup=conversion_params.use_function_backup)
|
||||
use_calibration=conversion_params.use_calibration)
|
||||
|
||||
def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
|
||||
"""Return trt converted graphdef in INT8 mode."""
|
||||
@ -558,21 +556,18 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
if node.op == "TRTEngineOp":
|
||||
logging.info("Found TRTEngineOp: " + node.name)
|
||||
num_engines += 1
|
||||
segment_funcdef_name = node.attr["segment_funcdef_name"].s
|
||||
segment_funcdef_name = node.attr["segment_func"].func.name
|
||||
function_name = node.name + "_native_segment"
|
||||
if IsQuantizationWithCalibration(run_params):
|
||||
self.assertNotEmpty(segment_funcdef_name, node.name)
|
||||
self.assertIn(function_name, functions)
|
||||
else:
|
||||
self.assertEmpty(segment_funcdef_name, node.name)
|
||||
self.assertNotIn(function_name, functions)
|
||||
is_dynamic_engine = not node.attr["static_engine"].b
|
||||
self.assertNotEmpty(segment_funcdef_name, node.name)
|
||||
self.assertIn(function_name, functions)
|
||||
if not IsQuantizationWithCalibration and not is_dynamic_engine:
|
||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||
self.assertIn(node.name, expected_engines)
|
||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||
self.assertEqual(
|
||||
self._ToBytes(run_params.precision_mode),
|
||||
node.attr["precision_mode"].s, node.name)
|
||||
|
||||
is_dynamic_engine = not node.attr["static_engine"].b
|
||||
self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
|
||||
node.name)
|
||||
self.assertEqual(node.attr["use_calibration"].b,
|
||||
@ -602,10 +597,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
|
||||
]
|
||||
for func in gdef_to_verify.library.function:
|
||||
for node in func.node_def:
|
||||
all_op_names.append(node.name)
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name):
|
||||
for node in func.node_def:
|
||||
all_op_names.append(node.name)
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
# Remove the function name prefix.
|
||||
def _Canonicalize(names):
|
||||
return set([self._ToString(name.split("/")[-1]) for name in names])
|
||||
|
@ -147,11 +147,6 @@ TrtConversionParams = collections.namedtuple(
|
||||
# trained with fake quantization.
|
||||
"use_calibration",
|
||||
|
||||
# If set to True, it will create a FunctionDef for each subgraph that is
|
||||
# converted to TRT op, and if TRT ops fail to execute at runtime, it'll
|
||||
# invoke that function as a fallback.
|
||||
"use_function_backup",
|
||||
|
||||
# Max size for the input batch.
|
||||
# This option is deprecated in TF 2.0.
|
||||
"max_batch_size",
|
||||
@ -165,7 +160,6 @@ DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams(
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
use_calibration=True,
|
||||
use_function_backup=True,
|
||||
max_batch_size=1)
|
||||
|
||||
_TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
|
||||
@ -272,8 +266,6 @@ def get_tensorrt_rewriter_config(
|
||||
"maximum_cached_engines"].i = conversion_params.maximum_cached_engines
|
||||
optimizer.parameter_map[
|
||||
"use_calibration"].b = conversion_params.use_calibration
|
||||
optimizer.parameter_map[
|
||||
"use_function_backup"].b = conversion_params.use_function_backup
|
||||
|
||||
if is_v2:
|
||||
# Static mode (a.k.a pre-generating TRT engines and make them node
|
||||
@ -344,8 +336,7 @@ class TrtGraphConverter(object):
|
||||
minimum_segment_size=3,
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
use_calibration=True,
|
||||
use_function_backup=True):
|
||||
use_calibration=True):
|
||||
"""Initialize the converter.
|
||||
|
||||
Args:
|
||||
@ -384,9 +375,6 @@ class TrtGraphConverter(object):
|
||||
will occur. Please note that accuracy may be negatively affected if
|
||||
there is a mismatch between which tensors TRT quantizes and which
|
||||
tensors were trained with fake quantization.
|
||||
use_function_backup: if set to True, it will create a FunctionDef for each
|
||||
subgraph that is converted to TRT op, and if TRT ops fail to execute at
|
||||
runtime, it'll invoke that function as a fallback.
|
||||
|
||||
Raises:
|
||||
ValueError: if the combination of the parameters is invalid.
|
||||
@ -424,12 +412,6 @@ class TrtGraphConverter(object):
|
||||
"dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
|
||||
is_dynamic_op = True
|
||||
|
||||
# TODO(laigd): consider provide a mechanism to remove the fallback path
|
||||
# after calibration is done.
|
||||
if self._need_calibration and not use_function_backup:
|
||||
raise ValueError(
|
||||
"Calibration requires enabling fallback to TF function execution.")
|
||||
|
||||
# TODO(laigd):
|
||||
# - Verify in int8 mode that maximum_cached_engines is set properly.
|
||||
# - If it fails to build the int8 engine it should return error.
|
||||
@ -446,7 +428,6 @@ class TrtGraphConverter(object):
|
||||
is_dynamic_op=is_dynamic_op,
|
||||
maximum_cached_engines=maximum_cached_engines,
|
||||
use_calibration=use_calibration,
|
||||
use_function_backup=use_function_backup,
|
||||
max_batch_size=max_batch_size)
|
||||
_check_conversion_params(self._conversion_params)
|
||||
|
||||
|
@ -205,8 +205,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
max_batch_size=1,
|
||||
minimum_segment_size=3,
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
use_function_backup=False):
|
||||
maximum_cached_engines=1):
|
||||
"""Helper method to convert a GraphDef or SavedModel using TF-TRT."""
|
||||
converter = trt_convert.TrtGraphConverter(
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
@ -220,8 +219,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
else trt_convert.TrtPrecisionMode.FP32),
|
||||
minimum_segment_size=minimum_segment_size,
|
||||
is_dynamic_op=is_dynamic_op,
|
||||
maximum_cached_engines=maximum_cached_engines,
|
||||
use_function_backup=use_function_backup)
|
||||
maximum_cached_engines=maximum_cached_engines)
|
||||
output_graph_def = converter.convert()
|
||||
|
||||
if need_calibration:
|
||||
@ -254,8 +252,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
output_saved_model_dir=output_saved_model_dir,
|
||||
need_calibration=need_calibration,
|
||||
is_dynamic_op=is_dynamic_op,
|
||||
use_function_backup=need_calibration)
|
||||
is_dynamic_op=is_dynamic_op)
|
||||
graph_defs_to_verify = [output_graph_def]
|
||||
|
||||
if output_saved_model_dir:
|
||||
@ -316,8 +313,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
precision_mode=trt_convert.TrtPrecisionMode.FP32,
|
||||
is_dynamic_op=True,
|
||||
maximum_cached_engines=2,
|
||||
use_function_backup=False))
|
||||
maximum_cached_engines=2))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTrtGraphConverter_BasicConversion_v2(self):
|
||||
@ -564,17 +560,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
def _TestRun(self,
|
||||
sess,
|
||||
batch_size,
|
||||
use_function_backup=False,
|
||||
expect_engine_is_run=True):
|
||||
try:
|
||||
result = sess.run(
|
||||
"output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
|
||||
self.assertAllEqual([[[4.0]]] * batch_size, result)
|
||||
except errors.OpError as e:
|
||||
# This should happen only when fallback path is disabled and TRT engine
|
||||
# fails to run.
|
||||
self.assertTrue(not use_function_backup and not expect_engine_is_run)
|
||||
self.assertIn("Fallback path is disabled, for TRTEngineOp_0", str(e))
|
||||
result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
|
||||
self.assertAllEqual([[[4.0]]] * batch_size, result)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testTrtGraphConverter_MinimumSegmentSize(self):
|
||||
@ -604,8 +592,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
output_saved_model_dir=output_saved_model_dir,
|
||||
is_dynamic_op=True,
|
||||
maximum_cached_engines=2,
|
||||
use_function_backup=False) # Disallow fallback.
|
||||
maximum_cached_engines=2)
|
||||
|
||||
# Test the output GraphDef.
|
||||
with ops.Graph().as_default():
|
||||
@ -631,7 +618,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
# the max, it should evict an old engine and create a new one.
|
||||
self._TestRun(sess, 3)
|
||||
|
||||
def _TestStaticOp(self, use_function_backup):
|
||||
def _TestStaticOp(self):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
@ -641,8 +628,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
output_graph_def = self._ConvertGraph(
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
output_saved_model_dir=output_saved_model_dir,
|
||||
maximum_cached_engines=2, # This is noop, added just for testing.
|
||||
use_function_backup=use_function_backup)
|
||||
maximum_cached_engines=2)
|
||||
|
||||
# Test the output GraphDef.
|
||||
with ops.Graph().as_default():
|
||||
@ -653,14 +639,12 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self._TestRun(
|
||||
sess,
|
||||
1,
|
||||
use_function_backup=use_function_backup,
|
||||
expect_engine_is_run=True)
|
||||
# Run with batch size 2, which exceed the max_batch_size, it should try
|
||||
# to fall back to TF function.
|
||||
self._TestRun(
|
||||
sess,
|
||||
2,
|
||||
use_function_backup=use_function_backup,
|
||||
expect_engine_is_run=False)
|
||||
|
||||
# Test the output SavedModel
|
||||
@ -672,23 +656,17 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self._TestRun(
|
||||
sess,
|
||||
1,
|
||||
use_function_backup=use_function_backup,
|
||||
expect_engine_is_run=True)
|
||||
# Run with batch size 2, which exceed the max_batch_size, it should try
|
||||
# to fall back to TF function.
|
||||
self._TestRun(
|
||||
sess,
|
||||
2,
|
||||
use_function_backup=use_function_backup,
|
||||
expect_engine_is_run=False)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testTrtGraphConverter_StaticOp_NoFallback(self):
|
||||
self._TestStaticOp(use_function_backup=False)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testTrtGraphConverter_StaticOp_WithFallback(self):
|
||||
self._TestStaticOp(use_function_backup=True)
|
||||
def testTrtGraphConverter_StaticOp(self):
|
||||
self._TestStaticOp()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user