Change IsFaninShapeSupported to support rank
This commit is contained in:
parent
94461ab056
commit
42aafd91e9
@ -1119,12 +1119,12 @@ bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
|
||||
}
|
||||
|
||||
bool BinaryOpTransposer::IsFaninShapeSupported(
|
||||
const utils::MutableNodeView& node) {
|
||||
return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) ||
|
||||
IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) ||
|
||||
IsNDOperateWithMD(node, 1, 4) || IsNDOperateWithMD(node, 5, 0) ||
|
||||
IsNDOperateWithMD(node, 5, 1) || IsNDOperateWithMD(node, 5, 5) ||
|
||||
IsNDOperateWithMD(node, 0, 5) || IsNDOperateWithMD(node, 1, 5));
|
||||
const utils::MutableNodeView& node, int rank) {
|
||||
return (IsNDOperateWithMD(node, rank, 0) ||
|
||||
IsNDOperateWithMD(node, rank, 1) ||
|
||||
IsNDOperateWithMD(node, rank, rank) ||
|
||||
IsNDOperateWithMD(node, 0, rank) ||
|
||||
IsNDOperateWithMD(node, 1, rank));
|
||||
}
|
||||
|
||||
std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
|
||||
@ -1258,7 +1258,7 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
|
||||
dst_format_3d);
|
||||
}
|
||||
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) ||
|
||||
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
|
||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
|
@ -347,7 +347,7 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
|
||||
|
||||
private:
|
||||
bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
|
||||
bool IsFaninShapeSupported(const utils::MutableNodeView& node);
|
||||
bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
|
||||
std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
|
||||
int rank);
|
||||
Status AddNodeShapeConst(utils::Mutation* mutation,
|
||||
|
Loading…
Reference in New Issue
Block a user