diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index af1877a2394..337d198cd10 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -487,6 +487,8 @@ cc_library( copts = tf_copts(), deps = [ "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 4d9dd42a53a..9b151375c8d 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -27,8 +27,11 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -422,6 +425,19 @@ Status SegmentGraph(const Graph* tf_graph, // for TRT. std::unordered_set unsupported_ops; int num_unsupported_ops = 0; + + // Getting the operations blacklisted for conversion + string tftrt_op_blacklist_str; + TF_CHECK_OK( + ReadStringFromEnvVar("TF_TRT_OP_BLACKLIST", "", &tftrt_op_blacklist_str)); + + auto tftrt_op_blacklist = gtl::FlatSet{}; // non-absl ok + + for (const auto& x : str_util::Split(tftrt_op_blacklist_str, ",")) { + tftrt_op_blacklist.insert(x); + } + + // Parsing each node of the graph std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); @@ -443,6 +459,16 @@ Status SegmentGraph(const Graph* tf_graph, unsupported_ops.emplace(node->tf_node()->type_string()); num_unsupported_ops++; node = nullptr; + } else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) { + // WARNING verbosity since the user explicitly requests this behavior. + LOG(WARNING) + << "Blacklisted as TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << "), " + << "(Reason: Blacklisted with the env var TF_TRT_OP_BLACKLIST)"; + unsupported_ops.emplace(node->tf_node()->type_string()); + num_unsupported_ops++; + node = nullptr; } else { VLOG(2) << "Accepted as a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), "