From 47431b61782400188b16979eeea87f47760e3385 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Mon, 20 Apr 2020 18:48:58 -0700 Subject: [PATCH 1/4] Adding TF_TRT_BLACKLIST_OP --- .../compiler/tf2tensorrt/segment/segment.cc | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 4d9dd42a53a..2237ac20863 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -27,8 +27,11 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -401,6 +404,13 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, } } +namespace { +static void UpdateList(gtl::FlatSet* list, const string& values_to_add){ + +} + +} // namespace + Status SegmentGraph(const Graph* tf_graph, const std::function& candidate_fn, const std::function& input_candidate_fn, @@ -422,6 +432,20 @@ Status SegmentGraph(const Graph* tf_graph, // for TRT. std::unordered_set unsupported_ops; int num_unsupported_ops = 0; + + // Getting the nodes blacklisted for conversion + string tftrt_node_blacklist_str; + TF_CHECK_OK(ReadStringFromEnvVar( + "TF_TRT_BLACKLIST_OP", "", &tftrt_node_blacklist_str + )); + + auto tftrt_node_blacklist = gtl::FlatSet{}; + + for (const auto& x : str_util::Split(tftrt_node_blacklist_str, ",")) { + tftrt_node_blacklist.insert(x); + } + + // Parsing each node of the graph std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); @@ -443,6 +467,15 @@ Status SegmentGraph(const Graph* tf_graph, unsupported_ops.emplace(node->tf_node()->type_string()); num_unsupported_ops++; node = nullptr; + } else if (tftrt_node_blacklist.count(node->tf_node()->type_string())) { + // WARNING verbosity since the user explicitly requests this behavior. + LOG(WARNING) << "Blacklisted as TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << "), " + << "(Reason: Blacklisted with the env var TF_TRT_BLACKLIST_OP)"; + unsupported_ops.emplace(node->tf_node()->type_string()); + num_unsupported_ops++; + node = nullptr; } else { VLOG(2) << "Accepted as a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " From e97fdb5539becc7425e3e72c666c80454bd3a83a Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Mon, 20 Apr 2020 18:57:26 -0700 Subject: [PATCH 2/4] Changing TF_TRT_BLACKLIST_OP to TF_TRT_BLACKLIST_OPS --- tensorflow/compiler/tf2tensorrt/segment/segment.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 2237ac20863..c431c5b7412 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -436,7 +436,7 @@ Status SegmentGraph(const Graph* tf_graph, // Getting the nodes blacklisted for conversion string tftrt_node_blacklist_str; TF_CHECK_OK(ReadStringFromEnvVar( - "TF_TRT_BLACKLIST_OP", "", &tftrt_node_blacklist_str + "TF_TRT_BLACKLIST_OPS", "", &tftrt_node_blacklist_str )); auto tftrt_node_blacklist = gtl::FlatSet{}; @@ -472,7 +472,7 @@ Status SegmentGraph(const Graph* tf_graph, LOG(WARNING) << "Blacklisted as TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " << "(Op name: " << node->name() << "), " - << "(Reason: Blacklisted with the env var TF_TRT_BLACKLIST_OP)"; + << "(Reason: Blacklisted with the env var TF_TRT_BLACKLIST_OPS)"; unsupported_ops.emplace(node->tf_node()->type_string()); num_unsupported_ops++; node = nullptr; From 9be0eadeb292c1899017f680221cc48453796662 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Mon, 20 Apr 2020 19:18:41 -0700 Subject: [PATCH 3/4] Cleaning Leftovers --- tensorflow/compiler/tf2tensorrt/segment/segment.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index c431c5b7412..d65dc2cd0be 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -404,13 +404,6 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, } } -namespace { -static void UpdateList(gtl::FlatSet* list, const string& values_to_add){ - -} - -} // namespace - Status SegmentGraph(const Graph* tf_graph, const std::function& candidate_fn, const std::function& input_candidate_fn, From f56c62cdcd32af676770d93925c674a37667e8e4 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Tue, 21 Apr 2020 10:53:23 -0700 Subject: [PATCH 4/4] Env Var name change --- tensorflow/compiler/tf2tensorrt/segment/segment.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index d65dc2cd0be..f46fde009c9 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -429,7 +429,7 @@ Status SegmentGraph(const Graph* tf_graph, // Getting the nodes blacklisted for conversion string tftrt_node_blacklist_str; TF_CHECK_OK(ReadStringFromEnvVar( - "TF_TRT_BLACKLIST_OPS", "", &tftrt_node_blacklist_str + "TF_TRT_OP_BLACKLIST", "", &tftrt_node_blacklist_str )); auto tftrt_node_blacklist = gtl::FlatSet{}; @@ -465,7 +465,7 @@ Status SegmentGraph(const Graph* tf_graph, LOG(WARNING) << "Blacklisted as TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " << "(Op name: " << node->name() << "), " - << "(Reason: Blacklisted with the env var TF_TRT_BLACKLIST_OPS)"; + << "(Reason: Blacklisted with the env var TF_TRT_OP_BLACKLIST)"; unsupported_ops.emplace(node->tf_node()->type_string()); num_unsupported_ops++; node = nullptr;