Merge pull request #30789 from phillip-kravtsov:trt_remove_serialized_from_op
PiperOrigin-RevId: 260966558
Commit: 68edd47a23
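In brief, per the hunks below: TRTEngineOp no longer receives the segment GraphDef through `serialized_segment` in dynamic mode, and the string attr `segment_funcdef_name` (together with the `use_function_backup` option) is removed. The segment is now always registered as a function in the graph's function library and referenced from the op via a `segment_func` NameAttrList; the op rebuilds the GraphDef from the instantiated function when it needs one. Segment I/O markers switch from Placeholder/Identity to _Arg/_Retval nodes, and the `TensorRTInputPH_`/`TensorRTOutputPH_` prefixes move into an `IONamePrefixes` class in utils.h. The net effect on the op's attrs, as exercised in the updated test:

    before:  .Attr("segment_funcdef_name", "")  // "" disabled native fallback
             .Attr("serialized_segment", graph_def.SerializeAsString())
    after:   .Attr("segment_func", function)    // NameAttrList; must be non-empty
             .Attr("serialized_segment", "")    // still carries static engines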
@@ -97,10 +97,17 @@ cc_library(
         ":utils",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
+        "@local_config_cuda//cuda:cuda_headers",
+        "//tensorflow/core:core_cpu_lib_no_ops",
+        "//tensorflow/core:framework",
         "//tensorflow/core:gpu_headers_lib",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/core:stream_executor",
         "//tensorflow/core:stream_executor_headers_lib",
         "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/stream_executor/lib",
     ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
     alwayslink = 1,
 )

@@ -168,8 +175,12 @@ tf_cuda_cc_test(
         ":trt_op_kernels",
         ":trt_op_libs",
         ":trt_resources",
+        ":trt_conversion",
+        ":utils",
         "@com_google_googletest//:gtest",
+        "@com_google_absl//absl/strings",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:function_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/cc:scope",
         "//tensorflow/core:framework",

@@ -248,6 +259,9 @@ tf_cuda_library(
         ":utils",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:framework_lite",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:lib_proto_parsing",
     ] + if_tensorrt([":tensorrt_lib"]),
 )

@@ -318,11 +332,13 @@ tf_cuda_library(
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_lite",
         "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:devices",
@@ -324,8 +324,6 @@ Status CreateTRTNode(const ConversionParams& params,
                      nvinfer1::IGpuAllocator* alloc,
                      std::vector<Node*>* engine_nodes) {
   const auto& info = infos.at(pos);
-  std::vector<TensorShapeProto> output_shape_protos;
-  std::vector<TensorShapeProto> input_shape_protos;
   std::vector<PartialTensorShape> input_shapes;
   std::vector<NodeDefBuilder::NodeOut> inputs;
   std::vector<Node*> input_nodes;

@@ -359,25 +357,16 @@ Status CreateTRTNode(const ConversionParams& params,
     } else {
       // Data edges
       if (!conn.is_input_edge) {
-        // Set the shapes and data types of output edge.
-        TensorShapeProto out_shape;
-        // shape of the output node inside segment
-        conn.inside_shape.AsProto(&out_shape);
-        if (output_shape_protos.size() <= conn.port_number) {
-          output_shape_protos.resize(conn.port_number + 1);
+        // Set the data types of output edge.
+        if (out_types.size() <= conn.port_number) {
           out_types.resize(conn.port_number + 1);
         }
-        output_shape_protos.at(conn.port_number) = out_shape;
         out_types.at(conn.port_number) = conn.connection_type;
       } else {
         // Set the shapes and data types of input edge.
-        TensorShapeProto in_shape;
-        conn.outside_shape.AsProto(&in_shape);
-        if (input_shape_protos.size() <= conn.port_number) {
-          input_shape_protos.resize(conn.port_number + 1);
+        if (input_shapes.size() <= conn.port_number) {
           input_shapes.resize(conn.port_number + 1);
         }
-        input_shape_protos.at(conn.port_number) = in_shape;
         input_shapes.at(conn.port_number) = conn.outside_shape;
         // Shape must be fully defined (excluding batch dimension) for static
         // mode.

@@ -439,8 +428,6 @@ Status CreateTRTNode(const ConversionParams& params,
     TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
     segment_string = string(static_cast<const char*>(engine_data->data()),
                             engine_data->size());
-  } else {
-    segment_string = info.segment_graph_def.SerializeAsString();
   }

   string prec_string;

@@ -460,15 +447,13 @@ Status CreateTRTNode(const ConversionParams& params,
   }

   NodeDef trt_node;
+  NameAttrList function;
+  function.set_name(StrCat(info.engine_name, "_native_segment"));
   Status status =
-      node_builder.Attr("input_shapes", input_shape_protos)
-          .Attr("output_shapes", output_shape_protos)
+      node_builder
           .Attr("static_engine",
                 info.engine_type == EngineInfo::EngineType::TRTStatic)
-          .Attr("segment_funcdef_name",
-                params.use_function_backup
-                    ? StrCat(info.engine_name, "_native_segment")
-                    : "")
+          .Attr("segment_func", function)
          .Attr("serialized_segment", segment_string)
          .Attr("calibration_data", "")
          .Attr("max_cached_engines_count", info.maximum_cached_engines)

@@ -537,103 +522,27 @@ Status CreateTRTNode(const ConversionParams& params,
   return Status::OK();
 }

-// Function to construct a funcdef from the segment and add it to the graph.
-Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph,
-                                                const GraphDef& segment,
-                                                const string& engine_name) {
-  Graph sgraph(graph->flib_def());
-  GraphConstructorOptions gcopts;
-  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(gcopts, segment, &sgraph));
-  std::map<string, Node*> io_nodes;
-  int num_inputs = 0;
-  for (auto n : sgraph.op_nodes()) {
-    if (absl::StartsWith(n->name(), kInputPHName)) {
-      num_inputs++;
-      io_nodes.insert({n->name(), n});
-    } else if (absl::StartsWith(n->name(), kOutputPHName)) {
-      io_nodes.insert({n->name(), n});
-    }
-  }
-
-  for (int i = 0; i < num_inputs; ++i) {
-    auto name = StrCat(kInputPHName, i);
-    auto node = io_nodes[name];
-    NodeDef nd;
-    NodeDefBuilder node_builder(StrCat(name, "_Arg"),
-                                FunctionLibraryDefinition::kArgOp);
-    VLOG(1) << "Adding " << StrCat(name, "_Arg");
-    TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
-                           .Attr("index", i)
-                           .Finalize(&nd));
-    Status s;
-    auto node_arg = sgraph.AddNode(nd, &s);
-    if (!s.ok()) {
-      LOG(ERROR) << "Couldn't add _Arg node for " << name;
-    }
-    for (auto edge : node->out_edges()) {
-      sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input());
-      VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0
-              << " - > " << edge->dst()->name() << ":" << edge->dst_input();
-      if (!s.ok()) {
-        LOG(ERROR) << "Failed to update edge from " << node_arg->name()
-                   << " to " << edge->dst()->name() << ":" << edge->dst_input();
-      }
-    }
-    sgraph.RemoveNode(node);
-  }
-
-  for (int i = 0; i < io_nodes.size() - num_inputs; ++i) {
-    auto name = StrCat(kOutputPHName, i);
-    auto node = io_nodes[name];
-    NodeDef nd;
-    NodeDefBuilder node_builder(StrCat(name, "_Ret"),
-                                FunctionLibraryDefinition::kRetOp);
-    auto edge = *(node->in_edges().begin());
-    NodeDefBuilder::NodeOut nout(edge->src()->name(), edge->src_output(),
-                                 edge->src()->output_type(edge->src_output()));
-    VLOG(1) << " input " << nout.node << ":" << nout.index
-            << " dtype=" << DataTypeString(nout.data_type);
-    // nvcc complains that Input(<brace-enclosed initializer list>) is
-    // ambiguous, so do not use Input({nout}).
-    node_builder.Input(nout);
-    TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
-                           .Attr("index", i)
-                           .Finalize(&nd));
-    if (VLOG_IS_ON(3)) {
-      VLOG(3) << nd.DebugString();
-    }
-    Status s;
-    auto node_ret = sgraph.AddNode(nd, &s);
-    if (!s.ok()) {
-      LOG(ERROR) << "Couldn't add _Ret node for " << name;
-    }
-    VLOG(1) << "Update edge from " << edge->src()->name() << ":"
-            << edge->src_output() << " - > " << node_ret->name() << ":" << 0;
-    sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0);
-    s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0);
-    if (!s.ok()) {
-      LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":"
-                 << edge->src_output() << " - > " << node_ret->name() << ":"
-                 << 0;
-    }
-    sgraph.RemoveNode(node);
-  }
-  FunctionDefLibrary fdeflib;
-  auto native_segment = fdeflib.add_function();
+Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
+                                      Graph* graph, const string& engine_name) {
+  Graph segment_graph(graph->flib_def());
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
+                                            segment_graph_def, &segment_graph));
+  FunctionDefLibrary library;
+  auto segment_func = library.add_function();
   TF_RETURN_IF_ERROR(GraphToFunctionDef(
-      sgraph, StrCat(engine_name, "_native_segment"), native_segment));
+      segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
   // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
   // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
   // would be on host if the op generating the tensor has host memory tag set.
-  (*native_segment
-       ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
+  (*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
       .set_b(true);
   if (VLOG_IS_ON(7)) {
     VLOG(7) << engine_name << " Function_Def ";
-    VLOG(7) << native_segment->DebugString();
+    VLOG(7) << segment_func->DebugString();
   }
-  VLOG(1) << "Adding funcdef to graphlib";
-  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib));
+  VLOG(1) << "Adding funcdef " << segment_func->signature().name()
+          << " to graphlib";
+  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
   return Status::OK();
 }

@@ -690,16 +599,10 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
 // Entry function from optimization pass.
 Status ConvertAfterShapes(const ConversionParams& params) {
   // Sanity checks.
-  if (params.precision_mode == TrtPrecisionMode::INT8) {
-    if (params.use_calibration && !params.use_function_backup) {
-      return errors::InvalidArgument(
-          "Calibration requires enabling fallback to TF function execution.");
-    }
-  } else {
-    if (params.use_calibration) {
-      return errors::InvalidArgument(
-          "Calibration with FP32 or FP16 is not supported.");
-    }
+  if (params.precision_mode != TrtPrecisionMode::INT8 &&
+      params.use_calibration) {
+    return errors::InvalidArgument(
+        "Calibration with FP32 or FP16 is not supported.");
   }

   // Convert graphdef to graph.

@@ -760,14 +663,14 @@ Status ConvertAfterShapes(const ConversionParams& params) {
             : EngineInfo::EngineType::TRTStatic);
     curr_engine.use_calibration = params.use_calibration;
     curr_engine.maximum_cached_engines = params.max_cached_engines;
-    if (params.use_function_backup) {
-      status = RegisterSegmentFunctionToFunctionLibrary(
-          &graph, curr_engine.segment_graph_def, curr_engine.engine_name);
-      if (!status.ok()) {
-        LOG(WARNING) << "Failed to register segment graphdef as a function "
-                     << t << ": " << status;
-        continue;
-      }
+    status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
+                                            &graph, curr_engine.engine_name);
+    if (!status.ok()) {
+      LOG(WARNING) << "Failed to register segment graphdef to the library " << t
+                   << ": " << status;
+      continue;
     }

     engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
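For orientation, a minimal sketch of how the new helper is meant to be called, mirroring the ConvertAfterShapes call site above (the engine name here is hypothetical):

    // Register the TensorRT segment as "my_engine_native_segment" in the
    // main graph's function library; the function's signature comes from
    // the segment's _Arg/_Retval nodes.
    TF_RETURN_IF_ERROR(RegisterGraphToFunctionLibrary(
        curr_engine.segment_graph_def, &graph, /*engine_name=*/"my_engine"));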
@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>

 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
+#include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"

@@ -46,8 +47,6 @@ struct ConversionParams {
   // maximum number of cached engines
   int max_cached_engines = 1;
   bool use_calibration = true;
-  // Whether to use function fallback for TRTEngineOp
-  bool use_function_backup = true;
 };

 // Method to call from optimization pass

@@ -57,6 +56,11 @@ Status ConvertAfterShapes(const ConversionParams& params);
 std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
                                                  const EngineInfo& engine);

+// Helper method that registers `segment_graph` as a function to the function
+// library in `graph`.
+Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
+                                      Graph* graph, const string& engine_name);
+
 }  // namespace convert
 }  // namespace tensorrt
 }  // namespace tensorflow
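A sketch of configuring the slimmed-down struct (field values hypothetical; other fields keep their defaults). Note there is no longer a use_function_backup knob to set:

    ConversionParams params;
    params.precision_mode = TrtPrecisionMode::INT8;
    params.use_calibration = true;
    params.max_cached_engines = 1;
    TF_RETURN_IF_ERROR(ConvertAfterShapes(params));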
@@ -39,6 +39,7 @@ limitations under the License.
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/numbers.h"

@@ -75,18 +76,15 @@ limitations under the License.

 namespace tensorflow {
 namespace tensorrt {
-// TODO(aaroey): put these constants into some class.
-const char* const kInputPHName = "TensorRTInputPH_";
-const char* const kOutputPHName = "TensorRTOutputPH_";
+namespace convert {

 bool IsEngineInput(absl::string_view name) {
-  return absl::StartsWith(name, kInputPHName);
+  return absl::StartsWith(name, IONamePrefixes::kInputPHName);
 }
 bool IsEngineOutput(absl::string_view name) {
-  return absl::StartsWith(name, kOutputPHName);
+  return absl::StartsWith(name, IONamePrefixes::kOutputPHName);
 }

-namespace convert {
 using absl::StrAppend;
 using absl::StrCat;

@@ -5200,19 +5198,33 @@ Status ConvertGraphDefToEngine(
   for (const auto& node_def : gdef.node()) {
     string node_name = node_def.name();
     VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
-    if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) {
+    if (IsEngineInput(node_name)) {
       int32 slot_number = -1;
-      if (!strings::safe_strto32(  // non-absl ok
-              node_name.c_str() + strlen(kInputPHName), &slot_number)) {
-        return errors::InvalidArgument("Failed to parse slot number from ",
-                                       node_name);
+      string type_key;
+      if (node_def.op() == "Placeholder") {
+        if (!strings::safe_strto32(  // non-absl ok
+                node_name.c_str() + strlen(IONamePrefixes::kInputPHName),
+                &slot_number)) {
+          return errors::InvalidArgument("Failed to parse slot number from ",
+                                         node_name);
+        }
+        type_key = "dtype";
+      } else if (tensorflow::grappler::IsArg(node_def)) {
+        // Maybe remove the dependence on grappler and re-implement IsArg,
+        // which is pretty simple (but could change if new Arg nodes are added)
+        slot_number = node_def.attr().at("index").i();
+        type_key = "T";
+      } else {
+        return errors::InvalidArgument(
+            "Node ", node_name,
+            " with is neither Placeholder nor Arg, instead ", node_def.op());
       }
       nvinfer1::DataType trt_dtype;
       nvinfer1::Dims trt_dims;
       int batch_size = -1;
       auto shape = input_shapes.at(slot_number);
       auto status = ValidateTensorProperties(
-          node_def.op(), node_def.attr().at("dtype").type(), shape,
+          node_def.op(), node_def.attr().at(type_key).type(), shape,
           /*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size);
       if (!status.ok()) {
         const string error_message =

@@ -5228,12 +5240,23 @@ Status ConvertGraphDefToEngine(
       // engines offline, by calling sess.run() and cache/serialize the engines.
       TF_RETURN_IF_ERROR(
           converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size));
-    } else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) {
+    } else if (IsEngineOutput(node_name)) {
       int32 slot_number = -1;
-      if (!strings::safe_strto32(  // non-absl ok
-              node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
-        return errors::InvalidArgument("Failed to parse slot number from ",
-                                       node_name);
+      if (node_def.op() == "Identity") {
+        if (!strings::safe_strto32(  // non-absl ok
+                node_name.c_str() + strlen(IONamePrefixes::kOutputPHName),
+                &slot_number)) {
+          return errors::InvalidArgument("Failed to parse slot number from ",
+                                         node_name);
+        }
+      } else if (tensorflow::grappler::IsRetval(node_def)) {
+        slot_number = node_def.attr().at("index").i();
+      } else {
+        return errors::InvalidArgument(
+            "Node with name ", node_name,
+            " starting with IONamePrefixes::kOutputPHName is "
+            "neither Identity nor Retval, instead ",
+            node_def.op());
       }
       // Get output type that TensorFlow expects
       TFAttrs attrs(node_def);

@@ -5302,7 +5325,8 @@ Status ConvertSegmentToGraphDef(

     // Add dummy input/output nodes to the segment graphdef.
     if (connection.is_input_edge) {
-      const string node_name = StrCat(kInputPHName, connection.port_number);
+      const string node_name =
+          StrCat(IONamePrefixes::kInputPHName, connection.port_number);
       if (marker_nodes.count(node_name)) {
         VLOG(1) << "Reusing input " << node_name << " for the edge "
                 << connection.outside_node_name << ":"

@@ -5312,16 +5336,18 @@ Status ConvertSegmentToGraphDef(
       }
       marker_nodes.insert(node_name);
       auto seg_node = segment_def->add_node();
-      NodeDefBuilder builder(node_name, "Placeholder");
+      NodeDefBuilder builder(node_name, "_Arg");
       auto status = builder.Attr("shape", partial_shape)
-                        .Attr("dtype", dtype)
+                        .Attr("T", dtype)
+                        .Attr("index", connection.port_number)
                         .Finalize(seg_node);
       VLOG(1) << "Constructing input " << node_name << " for the edge "
               << connection.outside_node_name << ":" << connection.outside_port
               << " -> " << connection.inside_node_name << ":"
               << connection.inside_port;
     } else {
-      const string node_name = StrCat(kOutputPHName, connection.port_number);
+      const string node_name =
+          StrCat(IONamePrefixes::kOutputPHName, connection.port_number);
       if (marker_nodes.count(node_name)) {
         VLOG(1) << "Reusing output " << node_name << " for the edge "
                 << connection.inside_node_name << ":" << connection.inside_port

@@ -5331,9 +5357,10 @@ Status ConvertSegmentToGraphDef(
       }
       marker_nodes.insert(node_name);
       auto seg_node = segment_def->add_node();
-      NodeDefBuilder builder(node_name, "Identity");
+      NodeDefBuilder builder(node_name, "_Retval");
       auto status =
-          builder
+          builder.Attr("T", dtype)
+              .Attr("index", connection.port_number)
              .Input(connection.inside_node_name, connection.inside_port, dtype)
              .Finalize(seg_node);
       VLOG(1) << "Constructing output " << node_name << " for the edge "

@@ -5359,12 +5386,12 @@ Status ConvertSegmentToGraphDef(
     if (connection.is_control_edge() || !connection.is_input_edge) continue;
     auto snode =
         segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
-    const string placeholder_name =
-        StrCat(kInputPHName, connection.port_number);
+    const string arg_name =
+        StrCat(IONamePrefixes::kInputPHName, connection.port_number);
     VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
             << " from " << snode->input(connection.inside_port) << " to "
-            << placeholder_name;
-    snode->set_input(connection.inside_port, placeholder_name);
+            << arg_name;
+    snode->set_input(connection.inside_port, arg_name);
   }
   std::set<string> subgraph_node_names;
   for (const Node* node : subgraph_nodes) {
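After this change an input marker emitted by ConvertSegmentToGraphDef is an _Arg rather than a Placeholder. Roughly, in illustrative textproto form (a float input on port 0 is assumed):

    node {
      name: "TensorRTInputPH_0"
      op: "_Arg"
      attr { key: "shape" value { shape { /* partial shape */ } } }
      attr { key: "T" value { type: DT_FLOAT } }
      attr { key: "index" value { i: 0 } }
    }

Output markers become the symmetric _Retval nodes, carrying "T" and "index" attrs and a single data input.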
@@ -37,8 +37,6 @@ limitations under the License.

 namespace tensorflow {
 namespace tensorrt {
-extern const char* const kInputPHName;
-extern const char* const kOutputPHName;

 namespace convert {

@@ -119,8 +117,8 @@ struct EngineInfo {
   bool use_calibration;
 };

-// Constructs a graphdef from the segment in the given graph. Adds placeholder
-// nodes for input edges (InputPH_*) and identity nodes for output edges
+// Constructs a graphdef from the segment in the given graph. Adds _Arg
+// nodes for input edges (InputPH_*) and _Retval nodes for output edges
 // (OutputPH_*). This function needs to be called before TensorRT nodes
 // inserted in order to correctly get sizes from the original graph.
 //
@@ -1158,7 +1158,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {
     int batch_size = -1;
     for (const NodeDef& node : gdef.node()) {
       absl::string_view node_name(node.name());
-      if (absl::ConsumePrefix(&node_name, kInputPHName)) {
+      if (absl::ConsumePrefix(&node_name, IONamePrefixes::kInputPHName)) {
         int port = -1;
         EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
         if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);

@@ -1188,11 +1188,13 @@ class ConvertGraphDefToEngineTest : public ::testing::Test {

 TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
   Scope s = Scope::NewRootScope();
-  auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT,
-                                ops::Placeholder::Shape({1, 1}));
+  auto input =
+      ops::Placeholder(s.WithOpName(StrCat(IONamePrefixes::kInputPHName, 0)),
+                       DT_FLOAT, ops::Placeholder::Shape({1, 1}));
   auto output = ops::Identity(s.WithOpName("identity1"), input);
   output = ops::Identity(s.WithOpName("identity2"), output);
-  output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output);
+  output = ops::Identity(s.WithOpName(StrCat(IONamePrefixes::kOutputPHName, 0)),
+                         output);
   // If the converter marks the input tensor as output tensor, the conversion
   // below will fail with:
   // > TensorRTOutputPH_0 cannot be both input and output
@@ -67,9 +67,6 @@ Status TRTOptimizationPass::Init(
   if (params.count("use_calibration")) {
     use_calibration_ = params.at("use_calibration").b();
   }
-  if (params.count("use_function_backup")) {
-    use_function_backup_ = params.at("use_function_backup").b();
-  }
   return Status::OK();
 }

@@ -258,7 +255,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
   cp.is_dyn_op = is_dynamic_op_;
   cp.max_cached_engines = max_cached_batches_;
   cp.use_calibration = use_calibration_;
-  cp.use_function_backup = use_function_backup_;
   auto status = ConvertAfterShapes(cp);
   VLOG(1) << "Returning from " << name_;
   return status;
@@ -40,8 +40,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
         is_dynamic_op_(false),
         max_cached_batches_(1),
         max_workspace_size_bytes_(256LL << 20),
-        use_calibration_(true),
-        use_function_backup_(true) {
+        use_calibration_(true) {
     VLOG(1) << "Constructing " << name_;
   }

@@ -71,8 +70,6 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
   int64_t max_workspace_size_bytes_;
   bool use_calibration_;

-  // Whether to allow TF function fallback path in TRTEngineOp.
-  bool use_function_backup_;
 };

 }  // namespace convert
@@ -23,6 +23,12 @@ limitations under the License.
 namespace tensorflow {
 namespace tensorrt {

+class IONamePrefixes {
+ public:
+  static constexpr const char* const kInputPHName = "TensorRTInputPH_";
+  static constexpr const char* const kOutputPHName = "TensorRTOutputPH_";
+};
+
 template <typename T>
 struct TrtDestroyer {
   void operator()(T* t) {
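Usage stays string-based; the class only namespaces the constants that used to live in convert_nodes.cc. With absl::StrCat, as elsewhere in these files:

    StrCat(IONamePrefixes::kInputPHName, 0);   // "TensorRTInputPH_0"
    StrCat(IONamePrefixes::kOutputPHName, 2);  // "TensorRTOutputPH_2"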
@@ -17,6 +17,7 @@ limitations under the License.
 #include <vector>

 #include "absl/memory/memory.h"
+#include "absl/strings/ascii.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"

@@ -24,10 +25,15 @@ limitations under the License.
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"

@@ -53,6 +59,7 @@ using ::stream_executor::port::StatusOr;

 // A helper class to call done() when destructed for asynchronous execution.
 // Helps simultaneous execution of native and TRT engines.
+
 class AsyncHelper : public core::RefCounted {
  public:
   AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}

@@ -89,7 +96,10 @@ class TRTEngineOp : public AsyncOpKernel {
   void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);

   // Construct a function handle for executing native funcdef graph
-  Status ConstructFunctionHandle(OpKernelContext* ctx);
+  // These are the exact same function.
+
+  Status ConstructFunctionHandle(FunctionLibraryRuntime* lib,
+                                 const string& device_name);

   // Execute replaced native segment as function Op.
   void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);

@@ -125,10 +135,8 @@ class TRTEngineOp : public AsyncOpKernel {
   // serialized protobuf segment or trt engine depending on static_engine_ flag.
   string serialized_segment_;

-  // Name of the function for TF native execution of the segment. If empty, it
-  // means TF native execution is not allowed, and if TRT engine fails to run
-  // an error will be returned.
-  string funcdef_name_;
+  // The function for TF native execution of the segment.
+  NameAttrList func_;

   // GraphDef representation of the segment.
   GraphDef segment_graph_;

@@ -148,7 +156,7 @@ class TRTEngineOp : public AsyncOpKernel {

   int64 workspace_size_;
   mutex engine_mutex_;
-  FunctionLibraryRuntime::Handle native_func_;
+  FunctionLibraryRuntime::Handle func_handle_;

   // The finalized calibrator for inference.
   std::unique_ptr<TRTInt8Calibrator> calibrator_;

@@ -177,23 +185,61 @@ void* GetTensorAddress(const Tensor* tensor_ptr) {
   }
 }

-Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
+static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
+                                    FunctionLibraryRuntime* flib_runtime,
+                                    GraphDef* graph_def) {
+  const FunctionLibraryDefinition* flib_def =
+      flib_runtime->GetFunctionLibraryDefinition();
+  const FunctionBody* fbody;
+  fbody = flib_runtime->GetFunctionBody(handle);
+  if (!fbody) {
+    return errors::Internal(
+        "Function body is null when converting from FuncDef to GraphDef.");
+  }
+  std::unique_ptr<Graph> graph(new Graph(flib_def));
+  CopyGraph(*fbody->graph, graph.get());
+
+  auto replace_name = [](const char* const prefix, string* name) {
+    if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
+      name->replace(0, strlen(prefix), prefix);
+      return true;
+    }
+    return false;
+  };
+  graph->ToGraphDef(graph_def);
+  // GraphToFunctionDef() will convert all the node names to lowercase.
+  for (auto& node : *graph_def->mutable_node()) {
+    if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) {
+      if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) {
+        // Instantiation of the function will append _RetVal to the node name,
+        // need to remove it for backward compatibility.
+        const char* const suffix_to_remove = "_RetVal";
+        if (absl::EndsWith(node.name(), suffix_to_remove)) {
+          node.mutable_name()->erase(node.name().size() -
+                                     strlen(suffix_to_remove));
+        }
+      }
+    }
+    for (auto& input : *node.mutable_input()) {
+      if (!replace_name(IONamePrefixes::kInputPHName, &input)) {
+        replace_name(IONamePrefixes::kOutputPHName, &input);
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib,
+                                            const string& device_name) {
   VLOG(1) << "Constructing function handle";
-  auto lib = ctx->function_library();
   if (lib == nullptr) {
     return errors::Internal("Context function library is null");
   }
-  auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
-  if (fdef == nullptr) {
-    return errors::Internal("Native FunctionDef ", funcdef_name_,
-                            " can't be found in function library");
-  }
   FunctionLibraryRuntime::InstantiateOptions inst_ops;
   inst_ops.state_handle = "";
-  inst_ops.target = ctx->device()->name();
-  native_func_ = 0;
-  return lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), inst_ops,
-                          &native_func_);
+  inst_ops.target = device_name;
+  return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops,
+                          &func_handle_);
 }

 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)

@@ -204,15 +250,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
   OP_REQUIRES_OK(context,
                  context->GetAttr("workspace_size_bytes", &workspace_size_));
   OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
-  if (!static_engine_) {
-    OP_REQUIRES(context, segment_graph_.ParseFromString(serialized_segment_),
-                errors::InvalidArgument("Failed to parse segment graphdef!"));
-    VLOG(1) << "Size of serialized GraphDef: "
-            << serialized_segment_.capacity();
-    string tmp;
-    // Swap with temporary empty string to deallocate the CPU memory.
-    serialized_segment_.swap(tmp);
-  }
   VLOG(1) << "Constructing " << name();
   string precision_string;
   OP_REQUIRES_OK(context,

@@ -220,12 +258,22 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
   string calibration_data;
   OP_REQUIRES_OK(context,
                  context->GetAttr("calibration_data", &calibration_data));
-  OP_REQUIRES_OK(context,
-                 context->GetAttr("segment_funcdef_name", &funcdef_name_));
+  OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
+  OP_REQUIRES(context, !func_.name().empty(),
+              errors::InvalidArgument(
+                  "The TF function for the TRT segment could not be empty"));
   OP_REQUIRES_OK(context,
                  TrtPrecisionModeFromName(precision_string, &precision_mode_));
   OP_REQUIRES_OK(context,
                  context->GetAttr("use_calibration", &use_calibration_));
+  func_handle_ = kInvalidHandle;
+  if (!static_engine_) {
+    FunctionLibraryRuntime* lib = context->function_library();
+    OP_REQUIRES_OK(context,
+                   ConstructFunctionHandle(lib, context->device()->name()));
+    OP_REQUIRES_OK(context,
+                   FunctionDefToGraphDef(func_handle_, lib, &segment_graph_));
+  }
   calibration_mode_ =
       (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
        calibration_data.empty());

@@ -233,20 +281,19 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
     calibrator_.reset(new TRTInt8Calibrator(calibration_data));
     calibration_data.resize(0);
   }
-  native_func_ = kInvalidHandle;
   OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                            &max_cached_engines_));
 }

 void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                        AsyncHelper* helper) {
-  OP_REQUIRES_ASYNC(ctx, !funcdef_name_.empty(),
-                    errors::Internal("Fallback path is disabled, for ", name()),
-                    *helper);
   std::vector<Tensor> inputs;
   std::vector<Tensor>* outputs = new std::vector<Tensor>();
-  if (native_func_ == kInvalidHandle) {
-    OP_REQUIRES_OK_ASYNC(ctx, ConstructFunctionHandle(ctx), *helper);
+  if (func_handle_ == kInvalidHandle) {
+    OP_REQUIRES_OK_ASYNC(
+        ctx,
+        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()),
+        *helper);
   }
   auto lib = ctx->function_library();
   FunctionLibraryRuntime::Options opts;

@@ -259,7 +306,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
   }
   helper->Ref();  // Increment count for calculating native graph
   VLOG(1) << "Executing native segment: " << name();
-  lib->Run(opts, native_func_, inputs, outputs,
+  lib->Run(opts, func_handle_, inputs, outputs,
            [this, ctx, outputs, helper](const Status& s) {
              core::ScopedUnref sc(helper);
              OP_REQUIRES_OK_ASYNC(ctx, s, *helper);

@@ -298,7 +345,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
     const auto device_tensor =
         calib_ctx->device_tensors_.at(i).AccessTensor(ctx);
     CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
-    input_data.emplace(StrCat(kInputPHName, i), data_address);
+    input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
   }
   VLOG(2) << "Filled map for sending";
   // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files

@@ -435,9 +482,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
   // input.
   const int num_batch = ctx->input(0).shape().dim_size(0);
   const int num_binding = ctx->num_inputs() + ctx->num_outputs();
+
   std::vector<void*> buffers(num_binding);
+
   for (int i = 0; i < ctx->num_inputs(); i++) {
-    const string input_name = StrCat(kInputPHName, i);
+    const string input_name = StrCat(IONamePrefixes::kInputPHName, i);
     const int binding_index = cuda_engine->getBindingIndex(input_name.c_str());
     if (binding_index == -1) {
       const string msg =

@@ -479,7 +528,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,

   for (int i = 0; i < ctx->num_outputs(); i++) {
     // Create an output tensor
-    const string output_name = StrCat(kOutputPHName, i);
+    const string output_name = StrCat(IONamePrefixes::kOutputPHName, i);
     const int binding_index = cuda_engine->getBindingIndex(output_name.c_str());
     Tensor* output_tensor = nullptr;

@@ -713,7 +762,7 @@ Status TRTEngineOp::AllocateCalibrationResources(
           "Unsupported data type encountered in input ", i);
     }
     cres->device_buffers_.emplace(
-        StrCat(kInputPHName, i),
+        StrCat(IONamePrefixes::kInputPHName, i),
         std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
   }
   cres->calibrator_.reset(

@@ -727,56 +776,55 @@ Status TRTEngineOp::AllocateCalibrationResources(
   }

   cache_res->Ref();
-  cres->thr_.reset(
-      new std::thread([this, cres, shapes, platform_gpu_id, cache_res]() {
+  cres->thr_.reset(new std::thread([this, cres, shapes, platform_gpu_id,
+                                    cache_res]() {
    core::ScopedUnref sc(cache_res);

    LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
              << ", Calibration Resource @ " << cres;
    auto err = cudaSetDevice(platform_gpu_id);
    if (err != cudaSuccess) {
      // TODO(aaroey): should return error here.
      LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
                 << " in calibration thread";
    }
    std::vector<PartialTensorShape> partial_shapes(shapes.begin(),
                                                   shapes.end());
    // ConvertGraphDefToEngine() will try to build the engine. This thread
    // will loop inside buildCudaEngine() consuming the calibration data
    // that is set by the TF op, and drive the builder until calibrator
    // returns false. Engine is discarded after calibration table is
    // generated
    //
    // TODO(aaroey): maybe setting the max batch size using the python
    // calibration wrapper class.
    auto s = convert::ConvertGraphDefToEngine(
        this->segment_graph_, TrtPrecisionMode::INT8,
        cres->calibrator_->getBatchSize(), this->workspace_size_,
-        partial_shapes, &cache_res->GetLogger(),
-        cache_res->allocator_.get(), cres->calibrator_.get(),
-        &cres->engine_,
-        /*use_calibration=*/true,
-        /*convert_successfully=*/nullptr);
-    if (!s.ok()) {
-      LOG(ERROR) << "Calibration failed: " << s;
-      cres->calibrator_->setDone();  // Ignore further pushes
-    }
+        partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(),
+        cres->calibrator_.get(), &cres->engine_,
+        /*use_calibration=*/true,
+        /*convert_successfully=*/nullptr);
+    if (!s.ok()) {
+      LOG(ERROR) << "Calibration failed: " << s;
+      cres->calibrator_->setDone();  // Ignore further pushes
+    }

    // Transfer the ownership of the engine to the engine cache, so we can
    // dump it out during conversion for TF 2.0.
    if (cache_res) {
      mutex_lock lock(this->engine_mutex_);
      cres->SetCalibrationTable();
      this->calibrator_ = std::move(cres->calibrator_);
      TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
          cres->engine_->createExecutionContext());
      cache_res->cache_.emplace(
-          shapes, absl::make_unique<EngineContext>(
-                      std::move(cres->engine_), std::move(exec_context)));
+          shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
+                                                   std::move(exec_context)));
    }

    VLOG(1) << "Calibration loop terminated " << this->name();
  }));
  VLOG(1) << "initialized calibrator resource";
  return Status::OK();
 }
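A worked example of the prefix restoration in FunctionDefToGraphDef above, assuming the instantiated function came back with lowercased names as the in-code comment says:

    // "tensorrtinputph_0"          -> "TensorRTInputPH_0"
    // "tensorrtoutputph_0_RetVal"  -> "TensorRTOutputPH_0"  (suffix trimmed)
    // Names without a matching lowercase prefix are left untouched, so the
    // TensorRT engine bindings keep the exact names the converter emitted.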
@@ -22,10 +22,17 @@ limitations under the License.

 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
+#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
 #include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
@@ -38,6 +45,7 @@ limitations under the License.

 namespace tensorflow {
 namespace tensorrt {
+using ::absl::StrCat;
 using ::testing::ElementsAre;

 class TRTEngineOpTestBase : public OpsTestBase {
@@ -49,25 +57,32 @@ class TRTEngineOpTestBase : public OpsTestBase {

     // Create simple TF graph.
     Scope s = Scope::NewRootScope();
-    auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
-                                 ops::Placeholder::Shape({-1, -1}));
+    auto feed = ops::_Arg(s.WithOpName("TensorRTInputPH_0"), dtype, 0);
     auto add = ops::Add(s.WithOpName("add"), feed, feed);
-    ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
+    ops::_Retval(s.WithOpName("TensorRTOutputPH_0"), add, 0);

     // Serialize the graph. TRTEngineOp will convert it using dynamic mode.
     GraphDef graph_def;
     TF_ASSERT_OK(s.ToGraphDef(&graph_def));
+    Graph* graph = s.graph();
+    const char* op_name = "myop";
+    TF_ASSERT_OK(
+        convert::RegisterGraphToFunctionLibrary(graph_def, graph, op_name));
+    TF_ASSERT_OK(flib_def_->AddLibrary(graph->flib_def()));
+
     PartialTensorShape shape({-1, -1});

     // Create the op.
     OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
-    TF_ASSERT_OK(NodeDefBuilder("myop", "TRTEngineOp")
+    NameAttrList function;
+    function.set_name(StrCat(op_name, "_native_segment"));
+    TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp")
                      .Input(FakeInput(1, dtype))
                      .Attr("input_shapes", {shape})
                      .Attr("output_shapes", {shape})
                      .Attr("static_engine", false)
-                     .Attr("segment_funcdef_name", "")  // no native fallback
-                     .Attr("serialized_segment", graph_def.SerializeAsString())
+                     .Attr("segment_func", function)
+                     .Attr("serialized_segment", "")
                      .Attr("calibration_data", "")
                      .Attr("max_cached_engines_count", max_cached_engines_count)
                      .Attr("workspace_size_bytes", 1 << 20)
@@ -75,7 +90,7 @@ class TRTEngineOpTestBase : public OpsTestBase {
                      .Attr("use_calibration", false)
                      .Attr("OutT", {dtype})
                      .Finalize(OpsTestBase::node_def()));
-    TF_ASSERT_OK(OpsTestBase::InitOp());
+    TF_ASSERT_OK(InitOpWithFunctionLibrary());
   }

   template <typename T>
@@ -89,9 +104,20 @@ class TRTEngineOpTestBase : public OpsTestBase {
     inputs_.clear();
     gtl::STLDeleteElements(&tensors_);
   }
+
+ private:
+  Status InitOpWithFunctionLibrary() {
+    OpKernel* kernel = nullptr;
+    Status status = CreateOpKernel(device_type_, device_, allocator(),
+                                   pflr_->GetFLR(device_->name()), node_def_,
+                                   TF_GRAPH_DEF_VERSION, &kernel);
+    kernel_ = std::unique_ptr<OpKernel>(kernel);
+    if (kernel_ != nullptr) input_types_ = kernel_->input_types();
+    return status;
+  }
 };

-TEST_F(TRTEngineOpTestBase, dynamic_shapes) {
+TEST_F(TRTEngineOpTestBase, DynamicShapes) {
   TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4);

   // Execute the op with batch size > 1.
@@ -33,7 +33,7 @@ namespace tensorflow {
 // key to cache the instantiated functions for different executor subgraphs.
 REGISTER_OP("TRTEngineOp")
     .Attr("serialized_segment: string")
-    .Attr("segment_funcdef_name: string")
+    .Attr("segment_func: func = {}")
     .Attr("InT: list({int8,float16,float32,int32})")
     .Attr("OutT: list({int8,float16,float32,int32})")
     .Attr("max_cached_engines_count: int = 1")
@@ -51,10 +51,11 @@ REGISTER_OP("TRTEngineOp")
     // inference function as a workaround.
     .SetShapeFn(shape_inference::UnknownShape)
     // Deprecated attributes.
+    .Attr("segment_funcdef_name: string = ''")
     .Attr("cached_engine_batches: list(int) >= 0 = []")
     .Attr("fixed_input_size: bool = true")
-    .Attr("input_shapes: list(shape)")
-    .Attr("output_shapes: list(shape)")
+    .Attr("input_shapes: list(shape) = []")
+    .Attr("output_shapes: list(shape) = []")
     .Attr("static_engine: bool = true");
 }  // namespace tensorflow

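Note: with the two hunks above, the native segment travels as a function attr (segment_func, a NameAttrList naming a FunctionDef in the graph's function library) instead of the string attr segment_funcdef_name, which is kept only as a defaulted, deprecated attr so old GraphDefs still parse. A hedged sketch of how client code might read the new attr; native_segment_names is a hypothetical helper:

    def native_segment_names(graph_def):
        """Yields (node_name, segment_function_name) for each TRTEngineOp."""
        for node in graph_def.node:
            if node.op == "TRTEngineOp":
                # segment_func is a NameAttrList; .func.name points at a
                # FunctionDef in graph_def.library.
                yield node.name, node.attr["segment_func"].func.name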
@@ -153,8 +153,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
         # runtime to allocate GPU memory.
         max_workspace_size_bytes=1 << 28,
         minimum_segment_size=2,
-        use_calibration=False,
-        use_function_backup=False)
+        use_calibration=False)

     graph_def = converter.convert()
     logging.info('Number of nodes after TF-TRT conversion: %d',
                  len(graph_def.node))
@@ -23,6 +23,7 @@ import errno
 import gc
 import itertools
 import os
+import re
 import shutil
 import tempfile
 import warnings
@@ -234,10 +235,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=1,
         use_calibration=run_params.use_calibration,
-        use_function_backup=False,
         max_batch_size=min(batch_list))
-    return conversion_params._replace(
-        use_function_backup=IsQuantizationWithCalibration(conversion_params))
+    return conversion_params

   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
@@ -388,8 +387,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         minimum_segment_size=conversion_params.minimum_segment_size,
         is_dynamic_op=conversion_params.is_dynamic_op,
         maximum_cached_engines=conversion_params.maximum_cached_engines,
-        use_calibration=conversion_params.use_calibration,
-        use_function_backup=conversion_params.use_function_backup)
+        use_calibration=conversion_params.use_calibration)

   def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
     """Return trt converted graphdef in INT8 mode."""
@@ -558,21 +556,18 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
       if node.op == "TRTEngineOp":
         logging.info("Found TRTEngineOp: " + node.name)
         num_engines += 1
-        segment_funcdef_name = node.attr["segment_funcdef_name"].s
+        segment_funcdef_name = node.attr["segment_func"].func.name
         function_name = node.name + "_native_segment"
-        if IsQuantizationWithCalibration(run_params):
-          self.assertNotEmpty(segment_funcdef_name, node.name)
-          self.assertIn(function_name, functions)
-        else:
-          self.assertEmpty(segment_funcdef_name, node.name)
-          self.assertNotIn(function_name, functions)
+        is_dynamic_engine = not node.attr["static_engine"].b
+        self.assertNotEmpty(segment_funcdef_name, node.name)
+        self.assertIn(function_name, functions)
+        if not IsQuantizationWithCalibration and not is_dynamic_engine:
+          self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
         self.assertIn(node.name, expected_engines)
-        self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
         self.assertEqual(
             self._ToBytes(run_params.precision_mode),
             node.attr["precision_mode"].s, node.name)

-        is_dynamic_engine = not node.attr["static_engine"].b
         self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
                          node.name)
         self.assertEqual(node.attr["use_calibration"].b,
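Note: the rewritten check encodes the new invariant: every TRTEngineOp always names a native segment function, and a serialized engine is expected only for static (pre-built) engines. Beware that `not IsQuantizationWithCalibration` in the new code negates the function object rather than calling it, so that half of the condition is always False and the serialized-segment assertion can never fire; `not IsQuantizationWithCalibration(run_params)` was presumably intended. A condensed sketch of the intended invariant, with check_engine_node as a hypothetical helper:

    def check_engine_node(node, functions, is_calibration_run):
        """Asserts the per-node invariants the updated test encodes."""
        segment_func_name = node.attr["segment_func"].func.name
        is_dynamic_engine = not node.attr["static_engine"].b
        assert segment_func_name, node.name
        assert node.name + "_native_segment" in functions
        if not is_calibration_run and not is_dynamic_engine:
            # Only pre-built static engines embed a serialized TRT engine.
            assert node.attr["serialized_segment"].s, node.name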
@@ -602,10 +597,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         node.name for node in gdef_to_verify.node if node.op == "TRTEngineOp"
     ]
     for func in gdef_to_verify.library.function:
-      for node in func.node_def:
-        all_op_names.append(node.name)
-        if node.op == "TRTEngineOp":
-          trt_op_names.append(node.name)
+      if not re.search(r"TRTEngineOp_\d+_native_segment", func.signature.name):
+        for node in func.node_def:
+          all_op_names.append(node.name)
+          if node.op == "TRTEngineOp":
+            trt_op_names.append(node.name)

     # Remove the function name prefix.
     def _Canonicalize(names):
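Note: native segment functions replicate the ops of the original subgraph, so walking their bodies would double-count nodes; the new re.search guard skips them. A quick sketch of the predicate (is_native_segment is a hypothetical name):

    import re

    def is_native_segment(func_name):
        """True for TF-TRT native segment functions in a graph library."""
        return re.search(r"TRTEngineOp_\d+_native_segment", func_name) is not None

    assert is_native_segment("import/TRTEngineOp_0_native_segment")
    assert not is_native_segment("some_user_function")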
@@ -147,11 +147,6 @@ TrtConversionParams = collections.namedtuple(
     # trained with fake quantization.
     "use_calibration",

-    # If set to True, it will create a FunctionDef for each subgraph that is
-    # converted to TRT op, and if TRT ops fail to execute at runtime, it'll
-    # invoke that function as a fallback.
-    "use_function_backup",
-
     # Max size for the input batch.
     # This option is deprecated in TF 2.0.
     "max_batch_size",
@@ -165,7 +160,6 @@ DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams(
     is_dynamic_op=False,
     maximum_cached_engines=1,
     use_calibration=True,
-    use_function_backup=True,
     max_batch_size=1)

 _TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
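Note: since use_function_backup is gone from the namedtuple, callers that still pass it through _replace() will now get a ValueError ("Got unexpected field names"). A hedged usage sketch against the updated defaults, assuming the import path used on this branch:

    from tensorflow.python.compiler.tensorrt import trt_convert

    params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=trt_convert.TrtPrecisionMode.FP16,
        is_dynamic_op=True,
        maximum_cached_engines=2)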
@@ -272,8 +266,6 @@ def get_tensorrt_rewriter_config(
       "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
   optimizer.parameter_map[
       "use_calibration"].b = conversion_params.use_calibration
-  optimizer.parameter_map[
-      "use_function_backup"].b = conversion_params.use_function_backup

   if is_v2:
     # Static mode (a.k.a pre-generating TRT engines and make them node
@@ -344,8 +336,7 @@ class TrtGraphConverter(object):
                minimum_segment_size=3,
                is_dynamic_op=False,
                maximum_cached_engines=1,
-               use_calibration=True,
-               use_function_backup=True):
+               use_calibration=True):
     """Initialize the converter.

     Args:
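Note: this constructor change is breaking for callers that passed use_function_backup explicitly (it now raises a TypeError as an unexpected keyword argument); per the hunks above, a fallback function is always generated, so calibration no longer needs the flag. A hedged sketch of the updated call; the saved-model path is hypothetical:

    from tensorflow.python.compiler.tensorrt import trt_convert

    converter = trt_convert.TrtGraphConverter(
        input_saved_model_dir="/tmp/my_saved_model",  # hypothetical path
        precision_mode=trt_convert.TrtPrecisionMode.INT8,
        is_dynamic_op=True,
        use_calibration=True)
    frozen_graph_def = converter.convert()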
@@ -384,9 +375,6 @@ class TrtGraphConverter(object):
         will occur. Please note that accuracy may be negatively affected if
         there is a mismatch between which tensors TRT quantizes and which
         tensors were trained with fake quantization.
-      use_function_backup: if set to True, it will create a FunctionDef for each
-        subgraph that is converted to TRT op, and if TRT ops fail to execute at
-        runtime, it'll invoke that function as a fallback.

     Raises:
       ValueError: if the combination of the parameters is invalid.
@@ -424,12 +412,6 @@ class TrtGraphConverter(object):
           "dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
       is_dynamic_op = True

-    # TODO(laigd): consider provide a mechanism to remove the fallback path
-    # after calibration is done.
-    if self._need_calibration and not use_function_backup:
-      raise ValueError(
-          "Calibration requires enabling fallback to TF function execution.")
-
     # TODO(laigd):
     # - Verify in int8 mode that maximum_cached_engines is set properly.
     # - If it fails to build the int8 engine it should return error.
@@ -446,7 +428,6 @@ class TrtGraphConverter(object):
         is_dynamic_op=is_dynamic_op,
         maximum_cached_engines=maximum_cached_engines,
         use_calibration=use_calibration,
-        use_function_backup=use_function_backup,
         max_batch_size=max_batch_size)
     _check_conversion_params(self._conversion_params)

@@ -205,8 +205,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
                     max_batch_size=1,
                     minimum_segment_size=3,
                     is_dynamic_op=False,
-                    maximum_cached_engines=1,
-                    use_function_backup=False):
+                    maximum_cached_engines=1):
     """Helper method to convert a GraphDef or SavedModel using TF-TRT."""
     converter = trt_convert.TrtGraphConverter(
         input_saved_model_dir=input_saved_model_dir,
@@ -220,8 +219,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         else trt_convert.TrtPrecisionMode.FP32),
         minimum_segment_size=minimum_segment_size,
         is_dynamic_op=is_dynamic_op,
-        maximum_cached_engines=maximum_cached_engines,
-        use_function_backup=use_function_backup)
+        maximum_cached_engines=maximum_cached_engines)
     output_graph_def = converter.convert()

     if need_calibration:
@@ -254,8 +252,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         input_saved_model_dir=input_saved_model_dir,
         output_saved_model_dir=output_saved_model_dir,
         need_calibration=need_calibration,
-        is_dynamic_op=is_dynamic_op,
-        use_function_backup=need_calibration)
+        is_dynamic_op=is_dynamic_op)
     graph_defs_to_verify = [output_graph_def]

     if output_saved_model_dir:
@@ -316,8 +313,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
             precision_mode=trt_convert.TrtPrecisionMode.FP32,
             is_dynamic_op=True,
-            maximum_cached_engines=2,
-            use_function_backup=False))
+            maximum_cached_engines=2))

   @test_util.run_v2_only
   def testTrtGraphConverter_BasicConversion_v2(self):
@@ -564,17 +560,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
   def _TestRun(self,
                sess,
                batch_size,
-               use_function_backup=False,
                expect_engine_is_run=True):
-    try:
-      result = sess.run(
-          "output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
-      self.assertAllEqual([[[4.0]]] * batch_size, result)
-    except errors.OpError as e:
-      # This should happen only when fallback path is disabled and TRT engine
-      # fails to run.
-      self.assertTrue(not use_function_backup and not expect_engine_is_run)
-      self.assertIn("Fallback path is disabled, for TRTEngineOp_0", str(e))
+    result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
+    self.assertAllEqual([[[4.0]]] * batch_size, result)

   @test_util.deprecated_graph_mode_only
   def testTrtGraphConverter_MinimumSegmentSize(self):
@@ -604,8 +592,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         input_saved_model_dir=input_saved_model_dir,
         output_saved_model_dir=output_saved_model_dir,
         is_dynamic_op=True,
-        maximum_cached_engines=2,
-        use_function_backup=False)  # Disallow fallback.
+        maximum_cached_engines=2)

     # Test the output GraphDef.
     with ops.Graph().as_default():
@@ -631,7 +618,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
       # the max, it should evict an old engine and create a new one.
       self._TestRun(sess, 3)

-  def _TestStaticOp(self, use_function_backup):
+  def _TestStaticOp(self):
     if not is_tensorrt_enabled():
       return

@@ -641,8 +628,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     output_graph_def = self._ConvertGraph(
         input_saved_model_dir=input_saved_model_dir,
         output_saved_model_dir=output_saved_model_dir,
-        maximum_cached_engines=2,  # This is noop, added just for testing.
-        use_function_backup=use_function_backup)
+        maximum_cached_engines=2)

     # Test the output GraphDef.
     with ops.Graph().as_default():
@@ -653,14 +639,12 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
       self._TestRun(
           sess,
           1,
-          use_function_backup=use_function_backup,
           expect_engine_is_run=True)
       # Run with batch size 2, which exceed the max_batch_size, it should try
       # to fall back to TF function.
       self._TestRun(
           sess,
           2,
-          use_function_backup=use_function_backup,
           expect_engine_is_run=False)

     # Test the output SavedModel
@@ -672,23 +656,17 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
       self._TestRun(
           sess,
           1,
-          use_function_backup=use_function_backup,
           expect_engine_is_run=True)
       # Run with batch size 2, which exceed the max_batch_size, it should try
       # to fall back to TF function.
       self._TestRun(
           sess,
           2,
-          use_function_backup=use_function_backup,
           expect_engine_is_run=False)

   @test_util.deprecated_graph_mode_only
-  def testTrtGraphConverter_StaticOp_NoFallback(self):
-    self._TestStaticOp(use_function_backup=False)
-
-  @test_util.deprecated_graph_mode_only
-  def testTrtGraphConverter_StaticOp_WithFallback(self):
-    self._TestStaticOp(use_function_backup=True)
+  def testTrtGraphConverter_StaticOp(self):
+    self._TestStaticOp()


 if __name__ == "__main__":