Add back in checks for explicit ranks of inputs/outputs of nodes in generic layout optimizer transposers.

PiperOrigin-RevId: 337961791
Change-Id: I6788bc1cdfe163a1ea0c1928cafec715d51e60bf
This commit is contained in:
Andy Ly 2020-10-19 16:45:59 -07:00 committed by TensorFlower Gardener
parent c6694e7a96
commit 444342be43
2 changed files with 45 additions and 30 deletions

View File

@ -67,6 +67,8 @@ constexpr char kOpConst[] = "Const";
constexpr char kReshape[] = "Reshape";
constexpr char kReshapeConst[] = "ReshapeConst";
constexpr int kRank = 4;
constexpr int kUnknownRank = -1;
constexpr int kInvalidRank = -2;
inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
absl::string_view src_data_format,
@ -554,15 +556,23 @@ Status Transposer::UpdateEdge(
return Status::OK();
}
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
int n) const {
int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node,
int port) const {
const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
if (output_shape_attr == nullptr ||
output_shape_attr->list().shape_size() <= port) {
return false;
return kInvalidRank;
}
const auto& shape = output_shape_attr->list().shape(port);
return !shape.unknown_rank() && shape.dim_size() == n;
if (shape.unknown_rank()) {
return kUnknownRank;
}
return shape.dim_size();
}
bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
int n) const {
return GetFanoutPortRank(node, port) == n;
}
bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
@ -575,14 +585,18 @@ bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
return true;
}
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
int n) const {
int Transposer::GetFaninPortRank(const utils::MutableNodeView& node,
int port) const {
if (port < node.NumRegularFanins() && port >= 0) {
const auto& regular_fanin = node.GetRegularFanin(port);
return IsFanoutPortRankN(*regular_fanin.node_view(), regular_fanin.index(),
n);
return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index());
}
return false;
return kInvalidRank;
}
bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
int n) const {
return GetFaninPortRank(node, port) == n;
}
bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
@ -719,11 +733,12 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -904,12 +919,12 @@ bool FusedBatchNormGradTransposer::IsTraining(
Status FusedBatchNormGradTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsFusedBatchNormGrad(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
!IsTraining(*node)) {
if (!ShouldProcess(*context, *node) || !IsTraining(*node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@ -1089,9 +1104,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
@ -1249,9 +1262,10 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsBinaryOp(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
!IsAfterDstToSrcTransform(*context, *node)) {
@ -1432,13 +1446,12 @@ bool ReduceTransposer::IsReduceAxisSupported(
Status ReduceTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsReduceOp(*node->node()));
const auto& regular_fanin = node->GetRegularFanin(0);
const auto* output_shape_attr =
regular_fanin.node_view()->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
const int rank = GetFaninPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) ||
if (!ShouldProcess(*context, *node) ||
!IsReduceAxisSupported(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();

View File

@ -149,10 +149,12 @@ class Transposer {
utils::MutationNewNode* added_node);
protected:
int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const;
bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
int n) const;
bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
absl::Span<const int> ports, int n) const;
int GetFaninPortRank(const utils::MutableNodeView& node, int port) const;
bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
int n) const;