Change IsFaninShapeSupported to support rank
This commit is contained in:
parent
94461ab056
commit
42aafd91e9
@ -1119,12 +1119,12 @@ bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool BinaryOpTransposer::IsFaninShapeSupported(
|
bool BinaryOpTransposer::IsFaninShapeSupported(
|
||||||
const utils::MutableNodeView& node) {
|
const utils::MutableNodeView& node, int rank) {
|
||||||
return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) ||
|
return (IsNDOperateWithMD(node, rank, 0) ||
|
||||||
IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) ||
|
IsNDOperateWithMD(node, rank, 1) ||
|
||||||
IsNDOperateWithMD(node, 1, 4) || IsNDOperateWithMD(node, 5, 0) ||
|
IsNDOperateWithMD(node, rank, rank) ||
|
||||||
IsNDOperateWithMD(node, 5, 1) || IsNDOperateWithMD(node, 5, 5) ||
|
IsNDOperateWithMD(node, 0, rank) ||
|
||||||
IsNDOperateWithMD(node, 0, 5) || IsNDOperateWithMD(node, 1, 5));
|
IsNDOperateWithMD(node, 1, rank));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
|
std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
|
||||||
@ -1258,7 +1258,7 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
|
|||||||
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
|
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
|
||||||
dst_format_3d);
|
dst_format_3d);
|
||||||
}
|
}
|
||||||
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) ||
|
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
|
||||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||||
if (allow_5d) {
|
if (allow_5d) {
|
||||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||||
|
@ -347,7 +347,7 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
|
bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
|
||||||
bool IsFaninShapeSupported(const utils::MutableNodeView& node);
|
bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
|
||||||
std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
|
std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
|
||||||
int rank);
|
int rank);
|
||||||
Status AddNodeShapeConst(utils::Mutation* mutation,
|
Status AddNodeShapeConst(utils::Mutation* mutation,
|
||||||
|
Loading…
Reference in New Issue
Block a user