Add back in checks for explicit ranks of inputs/outputs of nodes in generic layout optimizer transposers.
PiperOrigin-RevId: 337961791 Change-Id: I6788bc1cdfe163a1ea0c1928cafec715d51e60bf
This commit is contained in:
parent
c6694e7a96
commit
444342be43
@ -67,6 +67,8 @@ constexpr char kOpConst[] = "Const";
|
||||
constexpr char kReshape[] = "Reshape";
|
||||
constexpr char kReshapeConst[] = "ReshapeConst";
|
||||
constexpr int kRank = 4;
|
||||
constexpr int kUnknownRank = -1;
|
||||
constexpr int kInvalidRank = -2;
|
||||
|
||||
inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
|
||||
absl::string_view src_data_format,
|
||||
@ -554,15 +556,23 @@ Status Transposer::UpdateEdge(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
|
||||
int n) const {
|
||||
int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node,
|
||||
int port) const {
|
||||
const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
|
||||
if (output_shape_attr == nullptr ||
|
||||
output_shape_attr->list().shape_size() <= port) {
|
||||
return false;
|
||||
return kInvalidRank;
|
||||
}
|
||||
const auto& shape = output_shape_attr->list().shape(port);
|
||||
return !shape.unknown_rank() && shape.dim_size() == n;
|
||||
if (shape.unknown_rank()) {
|
||||
return kUnknownRank;
|
||||
}
|
||||
return shape.dim_size();
|
||||
}
|
||||
|
||||
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
|
||||
int n) const {
|
||||
return GetFanoutPortRank(node, port) == n;
|
||||
}
|
||||
|
||||
bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
|
||||
@ -575,14 +585,18 @@ bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
|
||||
int n) const {
|
||||
int Transposer::GetFaninPortRank(const utils::MutableNodeView& node,
|
||||
int port) const {
|
||||
if (port < node.NumRegularFanins() && port >= 0) {
|
||||
const auto& regular_fanin = node.GetRegularFanin(port);
|
||||
return IsFanoutPortRankN(*regular_fanin.node_view(), regular_fanin.index(),
|
||||
n);
|
||||
return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index());
|
||||
}
|
||||
return false;
|
||||
return kInvalidRank;
|
||||
}
|
||||
|
||||
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
|
||||
int n) const {
|
||||
return GetFaninPortRank(node, port) == n;
|
||||
}
|
||||
|
||||
bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
|
||||
@ -719,11 +733,12 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
|
||||
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
|
||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
|
||||
if (!ShouldProcess(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -904,12 +919,12 @@ bool FusedBatchNormGradTransposer::IsTraining(
|
||||
Status FusedBatchNormGradTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsFusedBatchNormGrad(*node->node()));
|
||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
|
||||
!IsTraining(*node)) {
|
||||
if (!ShouldProcess(*context, *node) || !IsTraining(*node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -1089,9 +1104,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
|
||||
Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
|
||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1249,9 +1262,10 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
|
||||
Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsBinaryOp(*node->node()));
|
||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
const int rank = GetFanoutPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
|
||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||
@ -1432,13 +1446,12 @@ bool ReduceTransposer::IsReduceAxisSupported(
|
||||
Status ReduceTransposer::TransposeNode(TransposeContext* context,
|
||||
utils::MutableNodeView* node) {
|
||||
DCHECK(IsReduceOp(*node->node()));
|
||||
const auto& regular_fanin = node->GetRegularFanin(0);
|
||||
const auto* output_shape_attr =
|
||||
regular_fanin.node_view()->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
const int rank = GetFaninPortRank(*node, 0);
|
||||
if (rank != 4 && rank != 5) {
|
||||
return Status::OK();
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) ||
|
||||
if (!ShouldProcess(*context, *node) ||
|
||||
!IsReduceAxisSupported(*context, *node) ||
|
||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||
return Status::OK();
|
||||
|
@ -149,10 +149,12 @@ class Transposer {
|
||||
utils::MutationNewNode* added_node);
|
||||
|
||||
protected:
|
||||
int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const;
|
||||
bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
|
||||
int n) const;
|
||||
bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
|
||||
absl::Span<const int> ports, int n) const;
|
||||
int GetFaninPortRank(const utils::MutableNodeView& node, int port) const;
|
||||
bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
|
||||
int n) const;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user