Add IsBiasAddV2 to check BiasAdd op

This commit is contained in:
Kaixi Hou 2020-11-22 22:02:51 -08:00
parent 09640bc683
commit 14bbfaaf7c
4 changed files with 8 additions and 5 deletions

View File

@ -113,6 +113,8 @@ bool IsBiasAdd(const NodeDef& node) {
return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
}
bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd"; }
bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }

View File

@ -45,6 +45,7 @@ bool IsAtan2(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node);
bool IsBetainc(const NodeDef& node);
bool IsBiasAdd(const NodeDef& node);
bool IsBiasAddV2(const NodeDef& node);
bool IsBiasAddGrad(const NodeDef& node);
bool IsBitcast(const NodeDef& node);
bool IsBroadcastTo(const NodeDef& node);

View File

@ -771,7 +771,7 @@ Status BiasAddTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
// This TransposeNode allows for BiasAdd but not BiasAddV1, since BiasAdd
// supports different data format.
DCHECK(node->GetOp() == "BiasAdd");
DCHECK(IsBiasAddV2(*node->node()));
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
@ -1969,7 +1969,7 @@ bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
bool IsLayoutSensitiveOp(const NodeDef& node) {
return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
IsBiasAdd(node) || IsBiasAddGrad(node) ||
IsBiasAddV2(node) || IsBiasAddGrad(node) ||
IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
IsDepthwiseConv2dNativeBackpropFilter(node) ||
IsDepthwiseConv2dNativeBackpropInput(node) ||

View File

@ -27,12 +27,12 @@ std::shared_ptr<Transposer> TransposerFactory::GetTransposer(
return GetOrCreateIfNotFound<DefaultLayoutSensitiveOpTransposer>(
"DefaultLayoutSensitiveOp");
}
if (IsBiasAdd(node)) {
return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
}
if (IsAvgPoolGrad(node)) {
return GetOrCreateIfNotFound<AvgPoolGradTransposer>("AvgPoolGrad");
}
if (IsBiasAddV2(node)) {
return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
}
if (IsBiasAddGrad(node)) {
return GetOrCreateIfNotFound<BiasAddGradTransposer>("BiasAddGrad");
}