diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 4d9dd42a53a..2237ac20863 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -27,8 +27,11 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -401,6 +404,13 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, } } +namespace { +static void UpdateList(gtl::FlatSet* list, const string& values_to_add){ + +} + +} // namespace + Status SegmentGraph(const Graph* tf_graph, const std::function& candidate_fn, const std::function& input_candidate_fn, @@ -422,6 +432,20 @@ Status SegmentGraph(const Graph* tf_graph, // for TRT. std::unordered_set unsupported_ops; int num_unsupported_ops = 0; + + // Getting the nodes blacklisted for conversion + string tftrt_node_blacklist_str; + TF_CHECK_OK(ReadStringFromEnvVar( + "TF_TRT_BLACKLIST_OP", "", &tftrt_node_blacklist_str + )); + + auto tftrt_node_blacklist = gtl::FlatSet{}; + + for (const auto& x : str_util::Split(tftrt_node_blacklist_str, ",")) { + tftrt_node_blacklist.insert(x); + } + + // Parsing each node of the graph std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); @@ -443,6 +467,15 @@ Status SegmentGraph(const Graph* tf_graph, unsupported_ops.emplace(node->tf_node()->type_string()); num_unsupported_ops++; node = nullptr; + } else if (tftrt_node_blacklist.count(node->tf_node()->type_string())) { + // WARNING verbosity since the user explicitly requests this behavior. + LOG(WARNING) << "Blacklisted as TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << "), " + << "(Reason: Blacklisted with the env var TF_TRT_BLACKLIST_OP)"; + unsupported_ops.emplace(node->tf_node()->type_string()); + num_unsupported_ops++; + node = nullptr; } else { VLOG(2) << "Accepted as a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), "