Add back in checks for explicit ranks of inputs/outputs of nodes in generic layout optimizer transposers.

PiperOrigin-RevId: 337961791
Change-Id: I6788bc1cdfe163a1ea0c1928cafec715d51e60bf
This commit is contained in:
Andy Ly 2020-10-19 16:45:59 -07:00 committed by TensorFlower Gardener
parent c6694e7a96
commit 444342be43
2 changed files with 45 additions and 30 deletions

View File

@ -67,6 +67,8 @@ constexpr char kOpConst[] = "Const";
constexpr char kReshape[] = "Reshape"; constexpr char kReshape[] = "Reshape";
constexpr char kReshapeConst[] = "ReshapeConst"; constexpr char kReshapeConst[] = "ReshapeConst";
constexpr int kRank = 4; constexpr int kRank = 4;
constexpr int kUnknownRank = -1;
constexpr int kInvalidRank = -2;
inline bool AttrDataFormatMatch(const utils::MutableNodeView& node, inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
absl::string_view src_data_format, absl::string_view src_data_format,
@ -554,15 +556,23 @@ Status Transposer::UpdateEdge(
return Status::OK(); return Status::OK();
} }
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port, int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node,
int n) const { int port) const {
const auto* output_shape_attr = node.GetAttr(kAttrOutputShape); const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
if (output_shape_attr == nullptr || if (output_shape_attr == nullptr ||
output_shape_attr->list().shape_size() <= port) { output_shape_attr->list().shape_size() <= port) {
return false; return kInvalidRank;
} }
const auto& shape = output_shape_attr->list().shape(port); const auto& shape = output_shape_attr->list().shape(port);
return !shape.unknown_rank() && shape.dim_size() == n; if (shape.unknown_rank()) {
return kUnknownRank;
}
return shape.dim_size();
}
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
int n) const {
return GetFanoutPortRank(node, port) == n;
} }
bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node, bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
@ -575,14 +585,18 @@ bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
return true; return true;
} }
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port, int Transposer::GetFaninPortRank(const utils::MutableNodeView& node,
int n) const { int port) const {
if (port < node.NumRegularFanins() && port >= 0) { if (port < node.NumRegularFanins() && port >= 0) {
const auto& regular_fanin = node.GetRegularFanin(port); const auto& regular_fanin = node.GetRegularFanin(port);
return IsFanoutPortRankN(*regular_fanin.node_view(), regular_fanin.index(), return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index());
n);
} }
return false; return kInvalidRank;
}
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
int n) const {
return GetFaninPortRank(node, port) == n;
} }
bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node, bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
@ -719,11 +733,12 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
Status DefaultLayoutSensitiveOpTransposer::TransposeNode( Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) { TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutSensitiveOp(*node->node())); DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); const int rank = GetFanoutPortRank(*node, 0);
const auto& shape = output_shape_attr->list().shape(0); if (rank != 4 && rank != 5) {
const int rank = shape.dim_size(); return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank); ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) { if (!ShouldProcess(*context, *node)) {
return Status::OK(); return Status::OK();
} }
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -904,12 +919,12 @@ bool FusedBatchNormGradTransposer::IsTraining(
Status FusedBatchNormGradTransposer::TransposeNode( Status FusedBatchNormGradTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) { TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsFusedBatchNormGrad(*node->node())); DCHECK(IsFusedBatchNormGrad(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); const int rank = GetFanoutPortRank(*node, 0);
const auto& shape = output_shape_attr->list().shape(0); if (rank != 4 && rank != 5) {
const int rank = shape.dim_size(); return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank); ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) || if (!ShouldProcess(*context, *node) || !IsTraining(*node)) {
!IsTraining(*node)) {
return Status::OK(); return Status::OK();
} }
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -1089,9 +1104,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
Status DefaultLayoutAgnosticOpTransposer::TransposeNode( Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) { TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutAgnosticOp(*node->node())); DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); const int rank = GetFanoutPortRank(*node, 0);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
if (rank != 4 && rank != 5) { if (rank != 4 && rank != 5) {
return Status::OK(); return Status::OK();
} }
@ -1249,9 +1262,10 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
Status BinaryOpTransposer::TransposeNode(TransposeContext* context, Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) { utils::MutableNodeView* node) {
DCHECK(IsBinaryOp(*node->node())); DCHECK(IsBinaryOp(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); const int rank = GetFanoutPortRank(*node, 0);
const auto& shape = output_shape_attr->list().shape(0); if (rank != 4 && rank != 5) {
const int rank = shape.dim_size(); return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank); ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) || if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
!IsAfterDstToSrcTransform(*context, *node)) { !IsAfterDstToSrcTransform(*context, *node)) {
@ -1432,13 +1446,12 @@ bool ReduceTransposer::IsReduceAxisSupported(
Status ReduceTransposer::TransposeNode(TransposeContext* context, Status ReduceTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) { utils::MutableNodeView* node) {
DCHECK(IsReduceOp(*node->node())); DCHECK(IsReduceOp(*node->node()));
const auto& regular_fanin = node->GetRegularFanin(0); const int rank = GetFaninPortRank(*node, 0);
const auto* output_shape_attr = if (rank != 4 && rank != 5) {
regular_fanin.node_view()->GetAttr(kAttrOutputShape); return Status::OK();
const auto& shape = output_shape_attr->list().shape(0); }
const int rank = shape.dim_size();
ScopedDataFormatUpgrader data_format_upgrader(context, rank); ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) || if (!ShouldProcess(*context, *node) ||
!IsReduceAxisSupported(*context, *node) || !IsReduceAxisSupported(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) { !IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK(); return Status::OK();

View File

@ -149,10 +149,12 @@ class Transposer {
utils::MutationNewNode* added_node); utils::MutationNewNode* added_node);
protected: protected:
int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const;
bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port, bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
int n) const; int n) const;
bool IsFanoutPortsRankN(const utils::MutableNodeView& node, bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
absl::Span<const int> ports, int n) const; absl::Span<const int> ports, int n) const;
int GetFaninPortRank(const utils::MutableNodeView& node, int port) const;
bool IsFaninPortRankN(const utils::MutableNodeView& node, int port, bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
int n) const; int n) const;