Changing context search for BiasAddGrad rewrite from BFS to stricter check
commit 059828d177
parent 4c0052dc4b
@@ -247,16 +247,10 @@ namespace tensorflow {
 //
 // P = Conv2DWithBiasBackpropBias(O, O_m)
 //
-// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is
-// the context matching depth. If _MklConv2DWithBias is not within the context
-// matching depth, then we do not rewrite BiasAddGrad.
-// How many hops do we search for matching node in the backward dataflow graph?
-// We use maxhop of 10 based on empirical observations. Also, these are
-// maxhops in backward data-flow graph. Since input of forward nodes (Conv2D)
-// directly goes to backward nodes, we do not expect the hop-distance
-// would be more than few nodes.
-static size_t kNodeMergeContextMaxDepth = 10;
+// Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending
+// on the matching 'context'. The term context is loosely related to which
+// forward op is _associated_ with BiasAddGrad. If it is _MklConv2DWithBias then
+// we consider it Conv2D context; if it is MatMul, then it is MatMul context.

 class MklLayoutRewritePass : public GraphOptimizationPass {
  public:
@@ -280,6 +274,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.max_pool = "MaxPool";
     csinfo_.max_pool_grad = "MaxPoolGrad";
     csinfo_.mkl_conv2d = "_MklConv2D";
+    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
+    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
     csinfo_.mkl_conv2d_with_bias_backprop_bias =
         "_MklConv2DWithBiasBackpropBias";
@@ -360,16 +356,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
                       csinfo_.mkl_conv2d_with_bias});

-    // We use maxhop of 10 based on empirical observations. Also, these are
-    // maxhops in backward data-flow graph. Since input of forward nodes
-    // (Conv2D) directly goes to backward nodes, we do not expect the
-    // hop-distance would be more than few nodes.
     biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
-                                   kNodeMergeContextMaxDepth};
+                                   IsBiasAddGradInMatMulContext};

     biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
                                            csinfo_.mkl_conv2d_with_bias,
-                                           kNodeMergeContextMaxDepth};
+                                           IsBiasAddGradInConv2DWithBiasContext};

     cinfo_.push_back(&biasaddgrad_matmul_context_);
     cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
@@ -392,9 +384,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string node;  // Name of the node to be rewritten
     string fwd;   // Name of the node in the forward pass that this node
                   // corresponds to
-    size_t max_hop;  // Maximum number of hops the fwd is located
-                     // from this node. If the fwd is farther than max_hop
-                     // then we do not rewrite the node.
+    std::function<bool(const Node*, const Node**, void* c)> context_match_fn;
   } ContextInfo;

   /// Structure to specify the name of an original node, its new name after
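The old max_hop field bounded a graph search; the new context_match_fn member makes each context rule carry its own matching predicate. Below is a minimal self-contained sketch of how such a predicate-driven rule table behaves; the ToyNode/ToyContextInfo types and the trivial predicate are illustrative stand-ins, not TensorFlow's actual Node or ContextInfo.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy model of the new design: each rule carries a matching predicate
// (like ContextInfo::context_match_fn) instead of a max-hop count.
struct ToyNode { std::string type; };

struct ToyContextInfo {
  std::string node;  // type of the node to be rewritten
  std::string fwd;   // type of the associated forward node
  std::function<bool(const ToyNode*)> match_fn;  // stricter check
};

int main() {
  std::vector<ToyContextInfo> rules = {
      {"BiasAddGrad", "_MklConv2DWithBias",
       // A real predicate would inspect graph wiring; this one only
       // checks the node type, for illustration.
       [](const ToyNode* n) { return n->type == "BiasAddGrad"; }},
  };
  ToyNode grad{"BiasAddGrad"};
  for (const auto& rule : rules) {
    if (grad.type == rule.node && rule.match_fn(&grad)) {
      std::cout << "matched context: " << rule.fwd << "\n";
    }
  }
  return 0;
}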
@@ -438,7 +428,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {

   /// Structure to store all constant strings
   /// NOTE: names are alphabetically sorted.
-  struct {
+  typedef struct {
     string avg_pool;
     string avg_pool_grad;
     string bias_add;
@@ -457,13 +447,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string max_pool;
     string max_pool_grad;
     string mkl_conv2d;
+    string mkl_conv2d_grad_input;
+    string mkl_conv2d_grad_filter;
     string mkl_conv2d_with_bias;
     string mkl_conv2d_with_bias_backprop_bias;
     string relu;
     string relu_grad;
     string reshape;
     string split;
-  } csinfo_;
+  } ConstStringsInfo;

  private:
   /// Maintain info about nodes to rewrite
@@ -478,6 +470,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   /// Maintain info about nodes to rewrite
   static std::vector<ContextInfo*> cinfo_;

+  /// Maintain structure of constant strings
+  static ConstStringsInfo csinfo_;
+
   /// Context variables used in referencing rules
   static ContextInfo biasaddgrad_matmul_context_;
   static ContextInfo biasaddgrad_conv2dwithbias_context_;
@@ -629,6 +624,173 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return false;
   }

+  // Check if the BiasAddGrad node 'n' is associated with the Conv2DWithBias
+  // node specified in contextinfo 'ci'. The function updates fwd_node to
+  // point to the Conv2DWithBias node if 'n' is associated with it.
+  //
+  // Association checks for one of the following graphs:
+  //
+  // Graph A:
+  //
+  // _ = Conv2DWithBias(F, I, _)
+  // ..
+  // _ = Conv2DBackpropFilter(F, _, G)
+  // _ = Conv2DBackpropInput(_, I, G)
+  // _ = BiasAddGrad(G)
+  //
+  // OR
+  //
+  // Graph B:
+  //
+  // _ = Conv2DWithBias(F, _, _)
+  // ..
+  // _ = Conv2DBackpropFilter(F, _, G)
+  // _ = BiasAddGrad(G)
+  //
+  // Here F, G, and I are graph nodes; _ represents graph nodes that we
+  // do not care about here.
+  //
+  // @return - true (if BiasAddGrad is associated with Conv2DWithBias);
+  //           false otherwise.
+  static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n,
+                                                   const Node** fwd_node,
+                                                   void* ci) {
+    CHECK_NOTNULL(n);
+    CHECK_NOTNULL(fwd_node);
+    CHECK_NOTNULL(ci);
+    *fwd_node = nullptr;
+
+    CHECK_EQ(n->type_string(), csinfo_.bias_add_grad);
+
+    // Get the only input of BiasAddGrad.
+    CHECK_EQ(n->num_inputs(), 1);
+    const Node* bias_add_grad_inp = nullptr;
+    TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp));
+    CHECK_NOTNULL(bias_add_grad_inp);
+
+    // Check if this input also goes to BackpropFilter and BackpropInput
+    // as their 3rd input.
+    bool found_backprop_input = false;
+    bool found_backprop_filter = false;
+    Node* backprop_filter_node = nullptr;
+    Node* backprop_input_node = nullptr;
+
+    for (const Edge* e : bias_add_grad_inp->out_edges()) {
+      Node* third_input = nullptr;
+      if (e->dst()->type_string() == csinfo_.conv2d_grad_input ||
+          e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) {
+        // Third input (index 2) of BackpropInput
+        TF_CHECK_OK(e->dst()->input_node(2, &third_input));
+        // Third input (index 2) of BackpropInput must be same as the input
+        // of BiasAddGrad.
+        if (third_input == bias_add_grad_inp) {
+          found_backprop_input = true;
+          backprop_input_node = e->dst();
+        }
+      }
+
+      if (e->dst()->type_string() == csinfo_.conv2d_grad_filter ||
+          e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) {
+        // Third input (index 2) of BackpropFilter
+        TF_CHECK_OK(e->dst()->input_node(2, &third_input));
+        // Third input (index 2) of BackpropFilter must be same as the input
+        // of BiasAddGrad.
+        if (third_input == bias_add_grad_inp) {
+          found_backprop_filter = true;
+          backprop_filter_node = e->dst();
+        }
+      }
+
+      // If we found both nodes, then we can stop the search.
+      if (found_backprop_input && found_backprop_filter) {
+        break;
+      }
+    }
+
+    // If the BackpropFilter node is not found, then this is not a
+    // Conv2DWithBias context. For the 2nd graph in the example above, only
+    // BackpropFilter would be present.
+    if (!found_backprop_filter) {
+      return false;
+    }
+
+    // Otherwise, we found the nodes.
+    CHECK_NOTNULL(backprop_filter_node);
+    if (found_backprop_input) {
+      CHECK_NOTNULL(backprop_input_node);
+    }
+
+    // Now that we confirmed that this is Conv2DWithBias context, we need to
+    // get access to the forward node (Conv2DWithBias). The 2nd input of
+    // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; the 1st
+    // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter
+    // (this comes from the definition of gradient computation for Conv2D).
+    if (found_backprop_input) {
+      // Graph A in the example.
+      Node* second_inp_of_input = nullptr;
+      Node* first_inp_of_filter = nullptr;
+      TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input));
+      TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
+      CHECK_NOTNULL(second_inp_of_input);
+      CHECK_NOTNULL(first_inp_of_filter);
+
+      // Now we need to find the Conv2DWithBias node from these input nodes.
+      // The Conv2DWithBias node is the node that accepts both
+      // second_inp_of_input and first_inp_of_filter in its 2nd and 1st
+      // input slots.
+      for (const Edge* fe : first_inp_of_filter->out_edges()) {
+        if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+            fe->dst_input() == 0) {
+          for (const Edge* ie : second_inp_of_input->out_edges()) {
+            if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+                ie->dst_input() == 1 && fe->dst() == ie->dst()) {
+              VLOG(1) << "MklLayoutRewritePass: found "
+                      << fe->dst()->DebugString()
+                      << " as the forward node for matching context, backward"
+                      << " node is: " << n->DebugString();
+              *fwd_node = fe->dst();
+              return true;
+            }
+          }
+        }
+      }
+    } else {
+      // We did not find BackpropInput, so we work with BackpropFilter only.
+      // Graph B in the example.
+      Node* first_inp_of_filter = nullptr;
+      TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
+      CHECK_NOTNULL(first_inp_of_filter);
+
+      // Now we need to find the Conv2DWithBias node from the first input of
+      // BackpropFilter. The Conv2DWithBias node is the node that accepts
+      // first_inp_of_filter in its 1st input slot.
+      for (const Edge* fe : first_inp_of_filter->out_edges()) {
+        if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+            fe->dst_input() == 0) {
+          VLOG(1) << "MklLayoutRewritePass: found "
+                  << fe->dst()->DebugString()
+                  << " as the forward node for matching context, backward"
+                  << " node is: " << n->DebugString();
+          *fwd_node = fe->dst();
+          return true;
+        }
+      }
+    }
+
+    return false;
+  }
+
+  // Check if the BiasAddGrad node 'n' is associated with the MatMul node
+  // specified in contextinfo 'ci'. The function does not update fwd_node.
+  //
+  // @return - true (if BiasAddGrad is associated with MatMul);
+  //           false otherwise.
+  static bool IsBiasAddGradInMatMulContext(const Node* n,
+                                           const Node** fwd_node,
+                                           void* ci) {
+    return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
+  }
+
   // Rewrite rule that uses context-information for matching,
   // used in scenario 2.
   //
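The heart of the stricter check is edge inspection rather than search: the tensor feeding BiasAddGrad must literally be the third input (index 2) of a Conv2DBackpropFilter, and the forward _MklConv2DWithBias is then recovered by matching input slots. A compact self-contained sketch of that slot-matching idea follows; the ToyEdge/ToyNode types and adjacency layout are illustrative, not TensorFlow's Node/Edge API.

#include <iostream>
#include <string>
#include <vector>

// Toy version of the core test: does the gradient tensor also arrive at a
// Conv2DBackpropFilter node as its 3rd input (slot index 2)?
struct ToyEdge { int dst; int dst_input; };
struct ToyNode { std::string type; std::vector<ToyEdge> out_edges; };

bool FeedsBackpropFilterAsGradient(const std::vector<ToyNode>& graph,
                                   int grad_producer) {
  for (const ToyEdge& e : graph[grad_producer].out_edges) {
    if (graph[e.dst].type == "Conv2DBackpropFilter" && e.dst_input == 2) {
      return true;  // same tensor is BackpropFilter's gradient input
    }
  }
  return false;
}

int main() {
  std::vector<ToyNode> graph(2);
  graph[1].type = "Conv2DBackpropFilter";
  graph[0].out_edges.push_back({1, 2});  // node 0 feeds node 1 at slot 2
  std::cout << FeedsBackpropFilterAsGradient(graph, 0) << "\n";  // prints 1
  return 0;
}

Unlike a hop-bounded search, this predicate cannot be fooled by an unrelated _MklConv2DWithBias that merely happens to sit near the BiasAddGrad in the graph.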
@@ -639,8 +801,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);

   // Helper function that searches the matching contextinfo for the node.
-  // Implements depth-first search in the data dependence graph for the
-  // gradient op in the backward direction.
   //
   // @input n - Node (gradient op) whose contextinfo is to be searched,
   //        fwd_node - pointer to node from the forward pass that this node
@@ -788,6 +948,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                               Node* orig_node);
 };

+MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
 MklLayoutRewritePass::ContextInfo
     MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
 MklLayoutRewritePass::ContextInfo
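Renaming the anonymous struct to ConstStringsInfo is what makes the static csinfo_ member definable out of line, as the line added above does. A minimal sketch of the pattern, with a hypothetical ToyPass class standing in for MklLayoutRewritePass:

#include <string>

// A static data member must be declared with a nameable type so it can be
// defined out of class; an anonymous struct type cannot be named again.
class ToyPass {
 public:
  typedef struct {
    std::string mkl_conv2d;
  } ConstStringsInfo;

 private:
  static ConstStringsInfo csinfo_;  // declaration inside the class
};

// Out-of-line definition, mirroring the added
// MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
ToyPass::ConstStringsInfo ToyPass::csinfo_;

int main() { return 0; }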
@@ -1667,12 +1828,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
   const ContextInfo* ci = nullptr;
   bool is_context_based_rewrite = false;
   if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) {
-    CHECK_NOTNULL(fwd_node);
     is_context_based_rewrite = true;

     // Sanity checks for context-based rewrite (if any)
     if (orig_node->type_string() == csinfo_.bias_add_grad &&
         ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
+      CHECK_NOTNULL(fwd_node);
       DataType orig_T, ctx_T;
       string orig_data_format, ctx_data_format;
       TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
@@ -1784,69 +1945,17 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
   CHECK_NOTNULL(fwd_node);
   *fwd_node = nullptr;

-  // Search for matching contextinfo based on node name.
-  // There could be more than one matching contextinfos.
-  bool is_matching_cinfo_found = false;
-  std::vector<const ContextInfo*> mci;
+  // Search for matching contextinfo based on node name and call the
+  // callback function using the matching contextinfo.
+  // There could be more than one matching contextinfo, but whichever
+  // matches first is returned.
   for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
-    if (n->type_string() == (*ci)->node) {
-      mci.push_back(*ci);
-      is_matching_cinfo_found = true;
+    if (n->type_string() == (*ci)->node &&
+        (*ci)->context_match_fn(n, fwd_node, *ci)) {
+      VLOG(1) << "Found context as matching: " << (*ci)->fwd;
+      return *ci;
     }
   }
-  // If no matching contextinfo is found, return immediately.
-  if (!is_matching_cinfo_found) {
-    return nullptr;
-  }
-
-  VLOG(1) << "MklLayoutRewritePass: Searching graph for: " << n->type_string()
-          << " in backwards.";
-
-  // Now we will check for forward op name for context info in data
-  // flow graph. Get the max hops we should search for the fwd node.
-  // We are now going to search (breadth-first) backwards in data
-  // dependence graph (for up to max hops) from n for the node
-  // specified in fwd.
-  // queue to maintain nodes to be visited and depth info for
-  // breadth-first search
-  std::queue<std::pair<const Node*, int>> nqueue;
-  const Node* curr_node = n;
-  size_t curr_depth = 0;
-  nqueue.push(std::make_pair(curr_node, curr_depth));
-
-  while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
-    std::pair<const Node*, int> curr_pair = nqueue.front();
-    nqueue.pop();
-
-    std::set<const Node*> visited_nodes;
-    curr_node = curr_pair.first;
-    curr_depth = curr_pair.second;
-    CHECK_NOTNULL(curr_node);
-
-    VLOG(1) << "MklLayoutRewritePass: Visiting node: "
-            << curr_node->type_string() << " at depth: " << curr_depth
-            << " for node: " << n->type_string();
-
-    // If we find a match, we return immediately.
-    for (const ContextInfo* ci : mci) {
-      if (curr_node->type_string() == ci->fwd) {
-        *fwd_node = curr_node;
-        return ci;
-      }
-    }
-
-    // Else we explore backward edges from current node.
-    // Add the source nodes of all incoming edges of the node to the queue.
-    for (const Edge* e : curr_node->in_edges()) {
-      // We do not visit already visited node.
-      if (visited_nodes.find(e->src()) == visited_nodes.end()) {
-        // Depth of these nodes is 1 more than the depth of current node.
-        nqueue.push(std::make_pair(e->src(), curr_depth + 1));
-        visited_nodes.insert(e->src());
-      }
-    }
-  } /* while */
-
   return nullptr;
 }

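For contrast, the deleted code above performed a depth-bounded breadth-first search for the forward op. Note that it declared visited_nodes inside the while loop, so the set was reset on every dequeue and did not actually prevent revisiting nodes across iterations. Below is a corrected, self-contained version of that BFS pattern for reference; the toy adjacency-list graph and the ReachableWithin name are illustrative, not TensorFlow code.

#include <iostream>
#include <queue>
#include <set>
#include <utility>
#include <vector>

// Depth-limited BFS over backward (incoming) edges, with the visited set
// hoisted outside the loop so each node is enqueued at most once.
bool ReachableWithin(const std::vector<std::vector<int>>& in_edges,
                     int start, int target, size_t max_depth) {
  std::queue<std::pair<int, size_t>> q;  // (node, depth)
  std::set<int> visited;
  q.push({start, 0});
  visited.insert(start);
  while (!q.empty()) {
    auto [node, depth] = q.front();
    q.pop();
    if (node == target) return true;
    if (depth >= max_depth) continue;
    for (int src : in_edges[node]) {  // walk backward edges
      if (visited.insert(src).second) {
        q.push({src, depth + 1});
      }
    }
  }
  return false;
}

int main() {
  // in_edges[i] lists the predecessors of node i: 2 -> 1 -> 0
  std::vector<std::vector<int>> in_edges = {{1}, {2}, {}};
  std::cout << ReachableWithin(in_edges, 0, 2, 10) << "\n";  // prints 1
  return 0;
}

The commit sidesteps the proximity heuristic altogether: SearchMatchingContext now delegates to each rule's context_match_fn, which verifies the exact producer/consumer wiring instead of asking whether some _MklConv2DWithBias lies within 10 hops.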
@@ -345,7 +345,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
 // Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
 // rewrite tests

-// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
@@ -364,16 +365,255 @@
       "node { name: 'E' op: 'Sub'"
       " attr {key: 'T' value { type: DT_FLOAT } }"
       " input: ['D', 'A']}"
-      "node { name: 'F' op: 'BiasAddGrad'"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'Int32Input'}"
+      "node { name: 'I' op: '_MklConv2DBackpropInput'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+      "node { name: 'J' op: 'BiasAddGrad'"
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " attr { key: 'data_format' value { s: 'NCHW' } }"
       " input: ['E'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
-            "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);"
-            "N(_MklInput);O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;"
-            "DMT/_0->F:1;E->F;E:control->DMT/_0:control;M->D:3;N->D:4;"
-            "O->D:5");
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+            "I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;"
+            "B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;"
+            "E:control->DMT/_0:control;F->G:1;H->I;M->D:3;M->G:3;M->I:3;"
+            "N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput, but the nodes do not match the criteria for rewrite,
+// so the rewrite should not happen.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'Int32Input'}"
+      "node { name: 'I' op: '_MklConv2DBackpropInput'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+      "node { name: 'J' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+            "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
+            "B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
+            "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput, but the nodes do not match the criteria for rewrite,
+// so the rewrite should not happen.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'Int32Input'}"
+      "node { name: 'I' op: '_MklConv2DBackpropInput'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+      "node { name: 'J' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+            "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
+            "B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
+            "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);"
+            "H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
+            "O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;"
+            "E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;"
+            "N->D:4;N->G:4;O->D:5;O->G:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only,
+// but the BackpropFilter node inputs do not satisfy the criteria for rewrite.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
+            "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
+            "O->G:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only,
+// but the BackpropFilter node inputs do not satisfy the criteria for rewrite.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
+            "C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
+            "O->G:5");
 }

 // No _MklConv2DWithBias in context, but _MklConv2D in context.