Change IsFaninShapeSupported to support rank

This commit is contained in:
Kaixi Hou 2020-09-29 11:54:22 -07:00
parent 94461ab056
commit 42aafd91e9
2 changed files with 8 additions and 8 deletions

View File

@ -1119,12 +1119,12 @@ bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
} }
bool BinaryOpTransposer::IsFaninShapeSupported( bool BinaryOpTransposer::IsFaninShapeSupported(
const utils::MutableNodeView& node) { const utils::MutableNodeView& node, int rank) {
return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) || return (IsNDOperateWithMD(node, rank, 0) ||
IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) || IsNDOperateWithMD(node, rank, 1) ||
IsNDOperateWithMD(node, 1, 4) || IsNDOperateWithMD(node, 5, 0) || IsNDOperateWithMD(node, rank, rank) ||
IsNDOperateWithMD(node, 5, 1) || IsNDOperateWithMD(node, 5, 5) || IsNDOperateWithMD(node, 0, rank) ||
IsNDOperateWithMD(node, 0, 5) || IsNDOperateWithMD(node, 1, 5)); IsNDOperateWithMD(node, 1, rank));
} }
std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts( std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
@ -1258,7 +1258,7 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d, context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
dst_format_3d); dst_format_3d);
} }
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) || if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
!IsAfterDstToSrcTransform(*context, *node)) { !IsAfterDstToSrcTransform(*context, *node)) {
if (allow_5d) { if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format, context->AssignDeviceAndDataFormats(context->target_device, src_format,

View File

@ -347,7 +347,7 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
private: private:
bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m); bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
bool IsFaninShapeSupported(const utils::MutableNodeView& node); bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node, std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
int rank); int rank);
Status AddNodeShapeConst(utils::Mutation* mutation, Status AddNodeShapeConst(utils::Mutation* mutation,