Add back in checks for explicit ranks of inputs/outputs of nodes in generic layout optimizer transposers.
PiperOrigin-RevId: 337961791 Change-Id: I6788bc1cdfe163a1ea0c1928cafec715d51e60bf
This commit is contained in:
parent
c6694e7a96
commit
444342be43
@ -67,6 +67,8 @@ constexpr char kOpConst[] = "Const";
|
|||||||
constexpr char kReshape[] = "Reshape";
|
constexpr char kReshape[] = "Reshape";
|
||||||
constexpr char kReshapeConst[] = "ReshapeConst";
|
constexpr char kReshapeConst[] = "ReshapeConst";
|
||||||
constexpr int kRank = 4;
|
constexpr int kRank = 4;
|
||||||
|
constexpr int kUnknownRank = -1;
|
||||||
|
constexpr int kInvalidRank = -2;
|
||||||
|
|
||||||
inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
|
inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
|
||||||
absl::string_view src_data_format,
|
absl::string_view src_data_format,
|
||||||
@ -554,15 +556,23 @@ Status Transposer::UpdateEdge(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
|
int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node,
|
||||||
int n) const {
|
int port) const {
|
||||||
const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
|
const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
|
||||||
if (output_shape_attr == nullptr ||
|
if (output_shape_attr == nullptr ||
|
||||||
output_shape_attr->list().shape_size() <= port) {
|
output_shape_attr->list().shape_size() <= port) {
|
||||||
return false;
|
return kInvalidRank;
|
||||||
}
|
}
|
||||||
const auto& shape = output_shape_attr->list().shape(port);
|
const auto& shape = output_shape_attr->list().shape(port);
|
||||||
return !shape.unknown_rank() && shape.dim_size() == n;
|
if (shape.unknown_rank()) {
|
||||||
|
return kUnknownRank;
|
||||||
|
}
|
||||||
|
return shape.dim_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
|
||||||
|
int n) const {
|
||||||
|
return GetFanoutPortRank(node, port) == n;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
|
bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
|
||||||
@ -575,14 +585,18 @@ bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
|
int Transposer::GetFaninPortRank(const utils::MutableNodeView& node,
|
||||||
int n) const {
|
int port) const {
|
||||||
if (port < node.NumRegularFanins() && port >= 0) {
|
if (port < node.NumRegularFanins() && port >= 0) {
|
||||||
const auto& regular_fanin = node.GetRegularFanin(port);
|
const auto& regular_fanin = node.GetRegularFanin(port);
|
||||||
return IsFanoutPortRankN(*regular_fanin.node_view(), regular_fanin.index(),
|
return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index());
|
||||||
n);
|
|
||||||
}
|
}
|
||||||
return false;
|
return kInvalidRank;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
|
||||||
|
int n) const {
|
||||||
|
return GetFaninPortRank(node, port) == n;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
|
bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
|
||||||
@ -719,11 +733,12 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
|
|||||||
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
|
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
|
||||||
TransposeContext* context, utils::MutableNodeView* node) {
|
TransposeContext* context, utils::MutableNodeView* node) {
|
||||||
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
|
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
|
||||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
const int rank = GetFanoutPortRank(*node, 0);
|
||||||
const auto& shape = output_shape_attr->list().shape(0);
|
if (rank != 4 && rank != 5) {
|
||||||
const int rank = shape.dim_size();
|
return Status::OK();
|
||||||
|
}
|
||||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
|
if (!ShouldProcess(*context, *node)) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||||
@ -904,12 +919,12 @@ bool FusedBatchNormGradTransposer::IsTraining(
|
|||||||
Status FusedBatchNormGradTransposer::TransposeNode(
|
Status FusedBatchNormGradTransposer::TransposeNode(
|
||||||
TransposeContext* context, utils::MutableNodeView* node) {
|
TransposeContext* context, utils::MutableNodeView* node) {
|
||||||
DCHECK(IsFusedBatchNormGrad(*node->node()));
|
DCHECK(IsFusedBatchNormGrad(*node->node()));
|
||||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
const int rank = GetFanoutPortRank(*node, 0);
|
||||||
const auto& shape = output_shape_attr->list().shape(0);
|
if (rank != 4 && rank != 5) {
|
||||||
const int rank = shape.dim_size();
|
return Status::OK();
|
||||||
|
}
|
||||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
|
if (!ShouldProcess(*context, *node) || !IsTraining(*node)) {
|
||||||
!IsTraining(*node)) {
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||||
@ -1089,9 +1104,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
|
|||||||
Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
|
Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
|
||||||
TransposeContext* context, utils::MutableNodeView* node) {
|
TransposeContext* context, utils::MutableNodeView* node) {
|
||||||
DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
|
DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
|
||||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
const int rank = GetFanoutPortRank(*node, 0);
|
||||||
const auto& shape = output_shape_attr->list().shape(0);
|
|
||||||
const int rank = shape.dim_size();
|
|
||||||
if (rank != 4 && rank != 5) {
|
if (rank != 4 && rank != 5) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -1249,9 +1262,10 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
|
|||||||
Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
|
Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
|
||||||
utils::MutableNodeView* node) {
|
utils::MutableNodeView* node) {
|
||||||
DCHECK(IsBinaryOp(*node->node()));
|
DCHECK(IsBinaryOp(*node->node()));
|
||||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
const int rank = GetFanoutPortRank(*node, 0);
|
||||||
const auto& shape = output_shape_attr->list().shape(0);
|
if (rank != 4 && rank != 5) {
|
||||||
const int rank = shape.dim_size();
|
return Status::OK();
|
||||||
|
}
|
||||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||||
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
|
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
|
||||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||||
@ -1432,13 +1446,12 @@ bool ReduceTransposer::IsReduceAxisSupported(
|
|||||||
Status ReduceTransposer::TransposeNode(TransposeContext* context,
|
Status ReduceTransposer::TransposeNode(TransposeContext* context,
|
||||||
utils::MutableNodeView* node) {
|
utils::MutableNodeView* node) {
|
||||||
DCHECK(IsReduceOp(*node->node()));
|
DCHECK(IsReduceOp(*node->node()));
|
||||||
const auto& regular_fanin = node->GetRegularFanin(0);
|
const int rank = GetFaninPortRank(*node, 0);
|
||||||
const auto* output_shape_attr =
|
if (rank != 4 && rank != 5) {
|
||||||
regular_fanin.node_view()->GetAttr(kAttrOutputShape);
|
return Status::OK();
|
||||||
const auto& shape = output_shape_attr->list().shape(0);
|
}
|
||||||
const int rank = shape.dim_size();
|
|
||||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||||
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) ||
|
if (!ShouldProcess(*context, *node) ||
|
||||||
!IsReduceAxisSupported(*context, *node) ||
|
!IsReduceAxisSupported(*context, *node) ||
|
||||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -149,10 +149,12 @@ class Transposer {
|
|||||||
utils::MutationNewNode* added_node);
|
utils::MutationNewNode* added_node);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const;
|
||||||
bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
|
bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
|
||||||
int n) const;
|
int n) const;
|
||||||
bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
|
bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
|
||||||
absl::Span<const int> ports, int n) const;
|
absl::Span<const int> ports, int n) const;
|
||||||
|
int GetFaninPortRank(const utils::MutableNodeView& node, int port) const;
|
||||||
bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
|
bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
|
||||||
int n) const;
|
int n) const;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user