Merge pull request #38730 from DEKHTIARJonathan:tf_trt_op_blacklist
PiperOrigin-RevId: 307827337 Change-Id: If52426b608ad98d311790bbc765b1555855f396b
This commit is contained in:
commit
ef4e3be946
@ -487,6 +487,8 @@ cc_library(
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -27,8 +27,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
@ -422,6 +425,19 @@ Status SegmentGraph(const Graph* tf_graph,
|
||||
// for TRT.
|
||||
std::unordered_set<string> unsupported_ops;
|
||||
int num_unsupported_ops = 0;
|
||||
|
||||
// Getting the operations blacklisted for conversion
|
||||
string tftrt_op_blacklist_str;
|
||||
TF_CHECK_OK(
|
||||
ReadStringFromEnvVar("TF_TRT_OP_BLACKLIST", "", &tftrt_op_blacklist_str));
|
||||
|
||||
auto tftrt_op_blacklist = gtl::FlatSet<string>{}; // non-absl ok
|
||||
|
||||
for (const auto& x : str_util::Split(tftrt_op_blacklist_str, ",")) {
|
||||
tftrt_op_blacklist.insert(x);
|
||||
}
|
||||
|
||||
// Parsing each node of the graph
|
||||
std::vector<UnionFind<SimpleNode*>> node_segments;
|
||||
for (int i = 0; i < graph->num_node_ids(); ++i) {
|
||||
SimpleNode* node = graph->FindNodeId(i);
|
||||
@ -443,6 +459,16 @@ Status SegmentGraph(const Graph* tf_graph,
|
||||
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||
num_unsupported_ops++;
|
||||
node = nullptr;
|
||||
} else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) {
|
||||
// WARNING verbosity since the user explicitly requests this behavior.
|
||||
LOG(WARNING)
|
||||
<< "Blacklisted as TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
<< "(Op name: " << node->name() << "), "
|
||||
<< "(Reason: Blacklisted with the env var TF_TRT_OP_BLACKLIST)";
|
||||
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||
num_unsupported_ops++;
|
||||
node = nullptr;
|
||||
} else {
|
||||
VLOG(2) << "Accepted as a TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
|
Loading…
Reference in New Issue
Block a user