diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index eae86e7c18c..83fe6781108 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -1119,12 +1119,12 @@ bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node, } bool BinaryOpTransposer::IsFaninShapeSupported( - const utils::MutableNodeView& node) { - return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) || - IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) || - IsNDOperateWithMD(node, 1, 4) || IsNDOperateWithMD(node, 5, 0) || - IsNDOperateWithMD(node, 5, 1) || IsNDOperateWithMD(node, 5, 5) || - IsNDOperateWithMD(node, 0, 5) || IsNDOperateWithMD(node, 1, 5)); + const utils::MutableNodeView& node, int rank) { + return (IsNDOperateWithMD(node, rank, 0) || + IsNDOperateWithMD(node, rank, 1) || + IsNDOperateWithMD(node, rank, rank) || + IsNDOperateWithMD(node, 0, rank) || + IsNDOperateWithMD(node, 1, rank)); } std::vector BinaryOpTransposer::GetNDDataFaninPorts( @@ -1258,7 +1258,7 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context, context->AssignDeviceAndDataFormats(context->target_device, src_format_3d, dst_format_3d); } - if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) || + if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) || !IsAfterDstToSrcTransform(*context, *node)) { if (allow_5d) { context->AssignDeviceAndDataFormats(context->target_device, src_format, diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h index 8db9ff0e70f..11a223ee097 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h @@ -347,7 +347,7 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer { private: bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m); - bool IsFaninShapeSupported(const utils::MutableNodeView& node); + bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank); std::vector GetNDDataFaninPorts(const utils::MutableNodeView& node, int rank); Status AddNodeShapeConst(utils::Mutation* mutation,