Adding TF_TRT_BLACKLIST_OP
This commit is contained in:
parent
7c5934dfd9
commit
47431b6178
@ -27,8 +27,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/env_var.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#if GOOGLE_TENSORRT
|
#if GOOGLE_TENSORRT
|
||||||
@ -401,6 +404,13 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
static void UpdateList(gtl::FlatSet<string>* list, const string& values_to_add){
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
Status SegmentGraph(const Graph* tf_graph,
|
Status SegmentGraph(const Graph* tf_graph,
|
||||||
const std::function<Status(const Node*)>& candidate_fn,
|
const std::function<Status(const Node*)>& candidate_fn,
|
||||||
const std::function<bool(const Edge*)>& input_candidate_fn,
|
const std::function<bool(const Edge*)>& input_candidate_fn,
|
||||||
@ -422,6 +432,20 @@ Status SegmentGraph(const Graph* tf_graph,
|
|||||||
// for TRT.
|
// for TRT.
|
||||||
std::unordered_set<string> unsupported_ops;
|
std::unordered_set<string> unsupported_ops;
|
||||||
int num_unsupported_ops = 0;
|
int num_unsupported_ops = 0;
|
||||||
|
|
||||||
|
// Getting the nodes blacklisted for conversion
|
||||||
|
string tftrt_node_blacklist_str;
|
||||||
|
TF_CHECK_OK(ReadStringFromEnvVar(
|
||||||
|
"TF_TRT_BLACKLIST_OP", "", &tftrt_node_blacklist_str
|
||||||
|
));
|
||||||
|
|
||||||
|
auto tftrt_node_blacklist = gtl::FlatSet<string>{};
|
||||||
|
|
||||||
|
for (const auto& x : str_util::Split(tftrt_node_blacklist_str, ",")) {
|
||||||
|
tftrt_node_blacklist.insert(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parsing each node of the graph
|
||||||
std::vector<UnionFind<SimpleNode*>> node_segments;
|
std::vector<UnionFind<SimpleNode*>> node_segments;
|
||||||
for (int i = 0; i < graph->num_node_ids(); ++i) {
|
for (int i = 0; i < graph->num_node_ids(); ++i) {
|
||||||
SimpleNode* node = graph->FindNodeId(i);
|
SimpleNode* node = graph->FindNodeId(i);
|
||||||
@ -443,6 +467,15 @@ Status SegmentGraph(const Graph* tf_graph,
|
|||||||
unsupported_ops.emplace(node->tf_node()->type_string());
|
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||||
num_unsupported_ops++;
|
num_unsupported_ops++;
|
||||||
node = nullptr;
|
node = nullptr;
|
||||||
|
} else if (tftrt_node_blacklist.count(node->tf_node()->type_string())) {
|
||||||
|
// WARNING verbosity since the user explicitly requests this behavior.
|
||||||
|
LOG(WARNING) << "Blacklisted as TF-TRT candidate, "
|
||||||
|
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||||
|
<< "(Op name: " << node->name() << "), "
|
||||||
|
<< "(Reason: Blacklisted with the env var TF_TRT_BLACKLIST_OP)";
|
||||||
|
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||||
|
num_unsupported_ops++;
|
||||||
|
node = nullptr;
|
||||||
} else {
|
} else {
|
||||||
VLOG(2) << "Accepted as a TF-TRT candidate, "
|
VLOG(2) << "Accepted as a TF-TRT candidate, "
|
||||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||||
|
Loading…
Reference in New Issue
Block a user