Merge pull request #38730 from DEKHTIARJonathan:tf_trt_op_blacklist

PiperOrigin-RevId: 307827337
Change-Id: If52426b608ad98d311790bbc765b1555855f396b
This commit is contained in:
TensorFlower Gardener 2020-04-22 09:03:19 -07:00
commit ef4e3be946
2 changed files with 28 additions and 0 deletions

View File

@ -487,6 +487,8 @@ cc_library(
copts = tf_copts(),
deps = [
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",

View File

@ -27,8 +27,11 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@ -422,6 +425,19 @@ Status SegmentGraph(const Graph* tf_graph,
// for TRT.
std::unordered_set<string> unsupported_ops;
int num_unsupported_ops = 0;
// Getting the operations blacklisted for conversion
string tftrt_op_blacklist_str;
TF_CHECK_OK(
ReadStringFromEnvVar("TF_TRT_OP_BLACKLIST", "", &tftrt_op_blacklist_str));
auto tftrt_op_blacklist = gtl::FlatSet<string>{}; // non-absl ok
for (const auto& x : str_util::Split(tftrt_op_blacklist_str, ",")) {
tftrt_op_blacklist.insert(x);
}
// Parsing each node of the graph
std::vector<UnionFind<SimpleNode*>> node_segments;
for (int i = 0; i < graph->num_node_ids(); ++i) {
SimpleNode* node = graph->FindNodeId(i);
@ -443,6 +459,16 @@ Status SegmentGraph(const Graph* tf_graph,
unsupported_ops.emplace(node->tf_node()->type_string());
num_unsupported_ops++;
node = nullptr;
} else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) {
// WARNING verbosity since the user explicitly requests this behavior.
LOG(WARNING)
<< "Blacklisted as TF-TRT candidate, "
<< "(Op type: " << node->tf_node()->type_string() << "), "
<< "(Op name: " << node->name() << "), "
<< "(Reason: Blacklisted with the env var TF_TRT_OP_BLACKLIST)";
unsupported_ops.emplace(node->tf_node()->type_string());
num_unsupported_ops++;
node = nullptr;
} else {
VLOG(2) << "Accepted as a TF-TRT candidate, "
<< "(Op type: " << node->tf_node()->type_string() << "), "