Adding TF_TRT_BLACKLIST_OP
This commit is contained in:
parent
7c5934dfd9
commit
47431b6178
@ -27,8 +27,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
@ -401,6 +404,13 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
static void UpdateList(gtl::FlatSet<string>* list, const string& values_to_add){
|
||||
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status SegmentGraph(const Graph* tf_graph,
|
||||
const std::function<Status(const Node*)>& candidate_fn,
|
||||
const std::function<bool(const Edge*)>& input_candidate_fn,
|
||||
@ -422,6 +432,20 @@ Status SegmentGraph(const Graph* tf_graph,
|
||||
// for TRT.
|
||||
std::unordered_set<string> unsupported_ops;
|
||||
int num_unsupported_ops = 0;
|
||||
|
||||
// Getting the nodes blacklisted for conversion
|
||||
string tftrt_node_blacklist_str;
|
||||
TF_CHECK_OK(ReadStringFromEnvVar(
|
||||
"TF_TRT_BLACKLIST_OP", "", &tftrt_node_blacklist_str
|
||||
));
|
||||
|
||||
auto tftrt_node_blacklist = gtl::FlatSet<string>{};
|
||||
|
||||
for (const auto& x : str_util::Split(tftrt_node_blacklist_str, ",")) {
|
||||
tftrt_node_blacklist.insert(x);
|
||||
}
|
||||
|
||||
// Parsing each node of the graph
|
||||
std::vector<UnionFind<SimpleNode*>> node_segments;
|
||||
for (int i = 0; i < graph->num_node_ids(); ++i) {
|
||||
SimpleNode* node = graph->FindNodeId(i);
|
||||
@ -443,6 +467,15 @@ Status SegmentGraph(const Graph* tf_graph,
|
||||
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||
num_unsupported_ops++;
|
||||
node = nullptr;
|
||||
} else if (tftrt_node_blacklist.count(node->tf_node()->type_string())) {
|
||||
// WARNING verbosity since the user explicitly requests this behavior.
|
||||
LOG(WARNING) << "Blacklisted as TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
<< "(Op name: " << node->name() << "), "
|
||||
<< "(Reason: Blacklisted with the env var TF_TRT_BLACKLIST_OP)";
|
||||
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||
num_unsupported_ops++;
|
||||
node = nullptr;
|
||||
} else {
|
||||
VLOG(2) << "Accepted as a TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
|
Loading…
Reference in New Issue
Block a user