Adding TF_TRT_BLACKLIST_OP

This commit is contained in:
DEKHTIARJonathan 2020-04-20 18:48:58 -07:00
parent 7c5934dfd9
commit 47431b6178

View File

@ -27,8 +27,11 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@ -401,6 +404,13 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
}
}
namespace {
static void UpdateList(gtl::FlatSet<string>* list, const string& values_to_add){
}
} // namespace
Status SegmentGraph(const Graph* tf_graph,
const std::function<Status(const Node*)>& candidate_fn,
const std::function<bool(const Edge*)>& input_candidate_fn,
@ -422,6 +432,20 @@ Status SegmentGraph(const Graph* tf_graph,
// for TRT.
std::unordered_set<string> unsupported_ops;
int num_unsupported_ops = 0;
// Getting the nodes blacklisted for conversion
string tftrt_node_blacklist_str;
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_TRT_BLACKLIST_OP", "", &tftrt_node_blacklist_str
));
auto tftrt_node_blacklist = gtl::FlatSet<string>{};
for (const auto& x : str_util::Split(tftrt_node_blacklist_str, ",")) {
tftrt_node_blacklist.insert(x);
}
// Parsing each node of the graph
std::vector<UnionFind<SimpleNode*>> node_segments;
for (int i = 0; i < graph->num_node_ids(); ++i) {
SimpleNode* node = graph->FindNodeId(i);
@ -443,6 +467,15 @@ Status SegmentGraph(const Graph* tf_graph,
unsupported_ops.emplace(node->tf_node()->type_string());
num_unsupported_ops++;
node = nullptr;
} else if (tftrt_node_blacklist.count(node->tf_node()->type_string())) {
// WARNING verbosity since the user explicitly requests this behavior.
LOG(WARNING) << "Blacklisted as TF-TRT candidate, "
<< "(Op type: " << node->tf_node()->type_string() << "), "
<< "(Op name: " << node->name() << "), "
<< "(Reason: Blacklisted with the env var TF_TRT_BLACKLIST_OP)";
unsupported_ops.emplace(node->tf_node()->type_string());
num_unsupported_ops++;
node = nullptr;
} else {
VLOG(2) << "Accepted as a TF-TRT candidate, "
<< "(Op type: " << node->tf_node()->type_string() << "), "