Add IsBiasAddV2 to check BiasAdd op
This commit is contained in:
parent
09640bc683
commit
14bbfaaf7c
tensorflow/core/grappler
@ -113,6 +113,8 @@ bool IsBiasAdd(const NodeDef& node) {
|
||||
return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
|
||||
}
|
||||
|
||||
bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd"; }
|
||||
|
||||
bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
|
||||
|
||||
bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
|
||||
|
@ -45,6 +45,7 @@ bool IsAtan2(const NodeDef& node);
|
||||
bool IsAvgPoolGrad(const NodeDef& node);
|
||||
bool IsBetainc(const NodeDef& node);
|
||||
bool IsBiasAdd(const NodeDef& node);
|
||||
bool IsBiasAddV2(const NodeDef& node);
|
||||
bool IsBiasAddGrad(const NodeDef& node);
|
||||
bool IsBitcast(const NodeDef& node);
|
||||
bool IsBroadcastTo(const NodeDef& node);
|
||||
|
@ -771,7 +771,7 @@ Status BiasAddTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
// This TransposeNode allows for BiasAdd but not BiasAddV1, since BiasAdd
|
||||
// supports different data format.
|
||||
DCHECK(node->GetOp() == "BiasAdd");
|
||||
DCHECK(IsBiasAddV2(*node->node()));
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
@ -1969,7 +1969,7 @@ bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
|
||||
|
||||
bool IsLayoutSensitiveOp(const NodeDef& node) {
|
||||
return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
|
||||
IsBiasAdd(node) || IsBiasAddGrad(node) ||
|
||||
IsBiasAddV2(node) || IsBiasAddGrad(node) ||
|
||||
IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
|
||||
IsDepthwiseConv2dNativeBackpropFilter(node) ||
|
||||
IsDepthwiseConv2dNativeBackpropInput(node) ||
|
||||
|
@ -27,12 +27,12 @@ std::shared_ptr<Transposer> TransposerFactory::GetTransposer(
|
||||
return GetOrCreateIfNotFound<DefaultLayoutSensitiveOpTransposer>(
|
||||
"DefaultLayoutSensitiveOp");
|
||||
}
|
||||
if (IsBiasAdd(node)) {
|
||||
return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
|
||||
}
|
||||
if (IsAvgPoolGrad(node)) {
|
||||
return GetOrCreateIfNotFound<AvgPoolGradTransposer>("AvgPoolGrad");
|
||||
}
|
||||
if (IsBiasAddV2(node)) {
|
||||
return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
|
||||
}
|
||||
if (IsBiasAddGrad(node)) {
|
||||
return GetOrCreateIfNotFound<BiasAddGradTransposer>("BiasAddGrad");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user