From ffc6b92b2839df5e7ffd438c42a81bb1912fe103 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Thu, 20 Jun 2019 17:46:58 -0700 Subject: [PATCH] [Grappler] Have GenericLayoutOptimizer be specific to GPUs for now. PiperOrigin-RevId: 254307839 --- .../optimizers/generic_layout_optimizer.cc | 53 +- .../optimizers/generic_layout_optimizer.h | 5 + .../generic_layout_optimizer_transposer.cc | 1066 ++++++----------- .../generic_layout_optimizer_transposer.h | 123 +- ...eneric_layout_optimizer_transposer_test.cc | 127 +- 5 files changed, 488 insertions(+), 886 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc index 5d8906161bf..f27300d8705 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc @@ -31,6 +31,17 @@ namespace grappler { namespace { +inline int GetNumGPUs(const Cluster& cluster) { + auto devices = cluster.GetDevices(); + int num_gpus = 0; + for (const auto& device : devices) { + if (device.second.type() == "GPU") { + num_gpus++; + } + } + return num_gpus; +} + Status ExpandLayoutSensitiveOp(TransposeContext* context, TransposerFactory* transposer_factory) { const int num_nodes = context->num_nodes; @@ -180,9 +191,10 @@ Status EraseOutputShapeAttrs(TransposeContext* context) { utils::Mutation* mutation = graph_view->GetMutationBuilder(); const int num_nodes = graph_view->NumNodes(); for (int i = 0; i < num_nodes; ++i) { - mutation->RemoveNodeAttr(graph_view->GetNode(i), "_output_shapes"); + mutation->RemoveNodeAttr(graph_view->GetNode(i), kAttrOutputShape); + TF_RETURN_IF_ERROR(mutation->Apply()); } - return mutation->Apply(); + return Status::OK(); } } // namespace @@ -190,15 +202,23 @@ Status EraseOutputShapeAttrs(TransposeContext* context) { Status GenericLayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { - // If optimizer returns 
early with error, output will be the input graph. - *output = item.graph; + if (cluster == nullptr) { + LOG(WARNING) + << "generic layout optimizer was called with cluster == nullptr"; + return errors::Aborted("cluster == nullptr."); + } + if (GetNumGPUs(*cluster) < 1) { + return errors::Aborted( + "No GPUs found: GenericLayoutOptimizer is currently only tuned for " + "GPU."); + } + TransposeContext context; - TF_RETURN_IF_ERROR( - TransposeContext::InitializeTransposeContext(item, cluster, &context)); + TF_RETURN_IF_ERROR(TransposeContext::InitializeTransposeContext( + item, cluster, src_format_, dst_format_, target_device_, &context)); TransposerFactory transposer_factory; TF_RETURN_IF_ERROR(ExpandLayoutSensitiveOp(&context, &transposer_factory)); TF_RETURN_IF_ERROR(ExpandLayoutAgnosticOp(&context, &transposer_factory)); - // TODO(lyandy): Merge non cancellable nodes. TF_RETURN_IF_ERROR(EraseCancellableNodes(&context)); TF_RETURN_IF_ERROR(EraseOutputShapeAttrs(&context)); @@ -213,25 +233,6 @@ void GenericLayoutOptimizer::Feedback(Cluster* cluster, // Takes no feedback. 
} -string GetAndValidateParameter(const string& parameter, - const AttrValueMap& parameter_map, - const std::set& valid_inputs, - std::vector* validation_errors, - std::vector* missing_parameters) { - if (parameter_map.find(parameter) != parameter_map.end()) { - string input = str_util::Uppercase(parameter_map.at(parameter).s()); - if (valid_inputs.find(input) != valid_inputs.end()) { - return input; - } - validation_errors->push_back(absl::StrCat( - "Invalid input ", input, " for parameter ", parameter, - ", must be one of [", str_util::Join(valid_inputs, ", "), "].")); - } else { - missing_parameters->push_back(parameter); - } - return ""; -} - Status GenericLayoutOptimizer::Init( const RewriterConfig_CustomGraphOptimizer* config) { return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h index 3189412a094..38ac492da86 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h @@ -36,6 +36,11 @@ class GenericLayoutOptimizer : public CustomGraphOptimizer { const GraphDef& optimize_output, double result) override; Status Init(const RewriterConfig_CustomGraphOptimizer* config) final; + + private: + string target_device_ = "GPU"; + string src_format_ = "NHWC"; + string dst_format_ = "NCHW"; }; } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index 9fc439c6420..803db44a589 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -42,7 +42,6 @@ namespace grappler { namespace { constexpr char kOptimizedSuffix[] = "LayoutOptimizer"; -constexpr char kAttrOutputShape[] = "_output_shapes"; constexpr char kAttrKSize[] = "ksize"; 
constexpr char kAttrStrides[] = "strides"; constexpr char kAttrDataFormat[] = "data_format"; @@ -144,23 +143,12 @@ bool IsHostMemory(const NodeDef& node, int output_port) { } // namespace -// NodeLayoutContext. - -void NodeLayoutContext::Initialize(absl::string_view src_format, - absl::string_view dst_format, - bool is_transposable) { - src = string(src_format); - dst = string(dst_format); - src_to_dst = GetPermutation(src_format, dst_format); - dst_to_src = GetPermutation(dst_format, src_format); - transposable = is_transposable; -} - // TransposeContext. -Status TransposeContext::InitializeTransposeContext(const GrapplerItem& item, - const Cluster* cluster, - TransposeContext* context) { +Status TransposeContext::InitializeTransposeContext( + const GrapplerItem& item, const Cluster* cluster, + absl::string_view src_format, absl::string_view dst_format, + absl::string_view target_device, TransposeContext* context) { DCHECK(context != nullptr); context->graph_properties = absl::make_unique(item); TF_RETURN_IF_ERROR(context->graph_properties->InferStatically(false)); @@ -179,6 +167,11 @@ Status TransposeContext::InitializeTransposeContext(const GrapplerItem& item, context->virtual_placer = absl::make_unique(cluster->GetDevices()); } + context->src_format = string(src_format); + context->dst_format = string(dst_format); + context->target_device = string(target_device); + context->src_to_dst = GetPermutation(src_format, dst_format); + context->dst_to_src = GetPermutation(dst_format, src_format); return Status::OK(); } @@ -191,12 +184,24 @@ string Transposer::GetDeviceName(const VirtualPlacer* virtual_placer, : node.device(); } -string Transposer::GetDstDataFormatForDevice( - const DeviceProperties& device) const { - if (device.type() == kGPU) { - return kNCHW; - } - return ""; +bool Transposer::ShouldProcess(const TransposeContext& context, + const utils::MutableNodeView& node) const { + const auto* node_def = node.node(); + const string& device_name = + 
GetDeviceName(context.virtual_placer.get(), *node_def); + string device; + string task; + bool is_on_target_device = + DeviceNameUtils::SplitDeviceName(device_name, &task, &device) && + absl::StrContains(absl::AsciiStrToLower(device), + absl::AsciiStrToLower(context.target_device)); + + // Only checks data format for layout sensitive op. + bool data_format_match = !IsLayoutSensitiveOp(*node_def) || + AttrDataFormatMatch(node, context.src_format); + return is_on_target_device && data_format_match && + !context.nodes_to_preserve.contains(node_def->name()) && + !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0); } Status Transposer::CreateConstPermNode(TransposeContext* context, @@ -279,7 +284,6 @@ Status Transposer::CreateTransposeNode( } Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context, - const NodeLayoutContext& layout, absl::Span dst_ports, utils::MutableNodeView* dst_node, absl::string_view op) { @@ -288,11 +292,10 @@ Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context, auto* fanin_node_view = fanin_port.node_view(); TF_RETURN_IF_ERROR( - UpdateEdge(context, layout, - GetFaninNameFormat(dst_node->GetName(), dst_port, layout.src, - layout.dst), - op, - /*input_shape=*/nullptr, + UpdateEdge(context, + GetFaninNameFormat(dst_node->GetName(), dst_port, + context->src_format, context->dst_format), + op, /*input_shape=*/nullptr, /*is_src_format_to_dst_format=*/true, fanin_port.index(), dst_port, fanin_node_view, dst_node)); } @@ -300,7 +303,6 @@ Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context, } Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context, - const NodeLayoutContext& layout, absl::Span src_ports, utils::MutableNodeView* src_node, absl::string_view op) { @@ -311,7 +313,7 @@ Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context, shape_attr_copy = *output_shape_attr; for (int port : src_ports) { TF_RETURN_IF_ERROR(Permute( - layout.src_to_dst, + 
context->src_to_dst, shape_attr_copy.mutable_list()->mutable_shape(port)->mutable_dim())); } context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr( @@ -329,24 +331,24 @@ Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context, ComparatorByNodeNameAndIndex()); int num_downstream_transposers = 0; for (const auto& fanout : sorted_fanouts) { - TF_RETURN_IF_ERROR( - UpdateEdge(context, layout, - GetFanoutNameFormat(src_node->GetName(), src_port, - num_downstream_transposers++, - layout.src, layout.dst), - op, &shape_attr_copy, - /*is_src_format_to_dst_format=*/false, src_port, - fanout.index(), src_node, fanout.node_view())); + TF_RETURN_IF_ERROR(UpdateEdge( + context, + GetFanoutNameFormat(src_node->GetName(), src_port, + num_downstream_transposers++, context->src_format, + context->dst_format), + op, &shape_attr_copy, + /*is_src_format_to_dst_format=*/false, src_port, fanout.index(), + src_node, fanout.node_view())); } } return Status::OK(); } Status Transposer::CreateDataFormatNode( - TransposeContext* context, const NodeLayoutContext& layout, - absl::string_view node_name, absl::string_view op, absl::string_view device, - const DataType& data_type, bool is_fanin_on_host, - bool is_src_format_to_dst_format, utils::MutationNewNode* added_node) { + TransposeContext* context, absl::string_view node_name, + absl::string_view op, absl::string_view device, const DataType& data_type, + bool is_fanin_on_host, bool is_src_format_to_dst_format, + utils::MutationNewNode* added_node) { auto* graph_view = context->graph_view.get(); DCHECK(!graph_view->HasNode(node_name)); @@ -370,10 +372,12 @@ Status Transposer::CreateDataFormatNode( } AttrValue src_format; - src_format.set_s(is_src_format_to_dst_format ? layout.src : layout.dst); + src_format.set_s(is_src_format_to_dst_format ? context->src_format + : context->dst_format); node.mutable_attr()->insert({kAttrSrcFormat, src_format}); AttrValue dst_format; - dst_format.set_s(is_src_format_to_dst_format ? 
layout.dst : layout.src); + dst_format.set_s(is_src_format_to_dst_format ? context->dst_format + : context->src_format); node.mutable_attr()->insert({kAttrDstFormat, dst_format}); // Add place holder for 1st input field. @@ -386,11 +390,10 @@ Status Transposer::CreateDataFormatNode( } Status Transposer::UpdateEdge( - TransposeContext* context, const NodeLayoutContext& layout, - absl::string_view name_format, absl::string_view op, - const AttrValue* input_shape, bool is_src_format_to_dst_format, - const int src_port, const int dst_port, utils::MutableNodeView* src_node, - utils::MutableNodeView* dst_node) { + TransposeContext* context, absl::string_view name_format, + absl::string_view op, const AttrValue* input_shape, + bool is_src_format_to_dst_format, const int src_port, const int dst_port, + utils::MutableNodeView* src_node, utils::MutableNodeView* dst_node) { DCHECK(src_node != nullptr); DCHECK(dst_node != nullptr); auto* src_node_def = src_node->node(); @@ -427,7 +430,7 @@ Status Transposer::UpdateEdge( ? AsControlDependency(src_node_def->name()) : ""; const std::vector& permutation = - is_src_format_to_dst_format ? layout.src_to_dst : layout.dst_to_src; + is_src_format_to_dst_format ? 
context->src_to_dst : context->dst_to_src; TF_RETURN_IF_ERROR(CreateTransposeNode( context, name_format, data_type, device, input_shape_proto, permutation, control_node_name, &added_node, &added_node_name)); @@ -440,7 +443,7 @@ Status Transposer::UpdateEdge( parsed_name.type != "CPU" && IsHostMemory(*src_node_def, src_port); const string node_name = absl::Substitute(name_format, op); TF_RETURN_IF_ERROR(CreateDataFormatNode( - context, layout, node_name, op, device, data_type, is_fanin_on_host, + context, node_name, op, device, data_type, is_fanin_on_host, is_src_format_to_dst_format, &added_node)); added_node_name = node_name; } else { @@ -540,49 +543,18 @@ inline string GetLayoutSensitiveNodeDataFormat( return ""; } -const NodeLayoutContext& LayoutSensitiveOpTransposer::CheckNodeTransposable( - TransposeContext* context, const utils::MutableNodeView& node) const { - const auto* node_def = node.node(); - - string src_format = GetLayoutSensitiveNodeDataFormat(node); - if (src_format.empty()) { - src_format = kNHWC; - } - - NodeLayoutContext layout; - DeviceProperties device; - string dst_format; - if (context->virtual_placer != nullptr) { - device = context->virtual_placer->get_device(*node_def); - dst_format = GetDstDataFormatForDevice(device); - } - - if (dst_format.empty() || src_format.length() != dst_format.length() || - src_format == dst_format || !CanProcessNode(*context, node)) { - layout.Initialize(src_format, src_format, - /*is_transposable=*/false); - return context->sensitive_node_layouts.emplace(node.node_index(), layout) - .first->second; - } - - layout.Initialize(src_format, dst_format, /*is_transposable=*/true); - return context->sensitive_node_layouts.emplace(node.node_index(), layout) - .first->second; -} - Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context, - const NodeLayoutContext& layout, utils::MutableNodeView* node) { utils::Mutation* mutation = context->graph_view->GetMutationBuilder(); AttrValue data_format_attr; - 
data_format_attr.set_s(layout.dst); + data_format_attr.set_s(context->dst_format); mutation->AddOrUpdateNodeAttr(node, kAttrDataFormat, data_format_attr); // Update attrs strides and ksize. const auto* strides_attr = node->GetAttr(kAttrStrides); if (strides_attr != nullptr) { AttrValue strides_attr_copy(*strides_attr); - TF_RETURN_IF_ERROR(Permute(layout.src_to_dst, + TF_RETURN_IF_ERROR(Permute(context->src_to_dst, strides_attr_copy.mutable_list()->mutable_i())); mutation->AddOrUpdateNodeAttr(node, kAttrStrides, strides_attr_copy); } @@ -590,103 +562,66 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context, const auto* ksize_attr = node->GetAttr(kAttrKSize); if (ksize_attr != nullptr) { AttrValue ksize_attr_copy(*ksize_attr); - TF_RETURN_IF_ERROR(Permute(layout.src_to_dst, + TF_RETURN_IF_ERROR(Permute(context->src_to_dst, ksize_attr_copy.mutable_list()->mutable_i())); mutation->AddOrUpdateNodeAttr(node, kAttrKSize, ksize_attr_copy); } return Status::OK(); } -Status LayoutSensitiveOpTransposer::CommitMutation( - TransposeContext* context, const utils::MutableNodeView& node) { - Status status = context->graph_view->GetMutationBuilder()->Apply(); - if (!status.ok()) { - // If mutation failed, update node layout to reflect current state in graph. 
- auto& it = context->sensitive_node_layouts[node.node_index()]; - it.dst = it.src; - it.transposable = false; - } - return status; -} - Status DefaultLayoutSensitiveOpTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsDefaultLayoutSensitiveOp(*node->node())); - const NodeLayoutContext& layout = CheckNodeTransposable(context, *node); - if (!layout.transposable || !IsFanoutPortDimsN(*node, 0, 4)) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateNode(context, layout, node)); - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - return CommitMutation(context, *node); + TF_RETURN_IF_ERROR(UpdateNode(context, node)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); } Status BiasAddGradTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsBiasAddGrad(*node->node())); - const NodeLayoutContext& layout = CheckNodeTransposable(context, *node); - if (!layout.transposable || !IsFaninPortDimsN(*node, 0, 4)) { + if (!ShouldProcess(*context, *node) || !IsFaninPortDimsN(*node, 0, 4)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateNode(context, layout, node)); - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateNode(context, node)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); // No need to update output shape, as it is always of shape 1-D with size the // feature dimension of `out_backprop`, regardless of whether NCHW or NHWC is // used. 
- return CommitMutation(context, *node); + return context->graph_view->GetMutationBuilder()->Apply(); } Status Conv2DBackpropFilterTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv2DBackpropFilter(*node->node())); - const NodeLayoutContext& layout = CheckNodeTransposable(context, *node); - if (!layout.transposable || !IsFanoutPortDimsN(*node, 0, 4)) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateNode(context, layout, node)); + TF_RETURN_IF_ERROR(UpdateNode(context, node)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0, 2}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose)); // No need to update output shape, as it is always of shape // [filter_height, filter_width, in_channels, out_channels], regardless of // whether NCHW or NHWC is used. - return CommitMutation(context, *node); + return context->graph_view->GetMutationBuilder()->Apply(); } Status Conv2DBackpropInputTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv2DBackpropInput(*node->node())); - const NodeLayoutContext& layout = CheckNodeTransposable(context, *node); - if (!layout.transposable || !IsFanoutPortDimsN(*node, 0, 4)) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateNode(context, layout, node)); + TF_RETURN_IF_ERROR(UpdateNode(context, node)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {2}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {0}, node, - kOpDataFormatVecPermute)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - return CommitMutation(context, *node); -} - -Status 
FusedBatchNormGradTransposer::TransposeNode( - TransposeContext* context, utils::MutableNodeView* node) { - DCHECK(IsFusedBatchNormGrad(*node->node())); - const NodeLayoutContext& layout = CheckNodeTransposable(context, *node); - if (!layout.transposable || !IsFanoutPortDimsN(*node, 0, 4) || - !IsTraining(*node)) { - return Status::OK(); - } - TF_RETURN_IF_ERROR(UpdateNode(context, layout, node)); - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0, 1}, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - return CommitMutation(context, *node); + UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); } bool FusedBatchNormGradTransposer::IsTraining( @@ -698,6 +633,20 @@ bool FusedBatchNormGradTransposer::IsTraining( return false; } +Status FusedBatchNormGradTransposer::TransposeNode( + TransposeContext* context, utils::MutableNodeView* node) { + DCHECK(IsFusedBatchNormGrad(*node->node())); + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsTraining(*node)) { + return Status::OK(); + } + TF_RETURN_IF_ERROR(UpdateNode(context, node)); + TF_RETURN_IF_ERROR( + UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); +} + Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMaxPoolV2(*node->node())); @@ -706,316 +655,163 @@ Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context, // constant. 
const auto& data_fanin = node->GetRegularFanin(0); auto* data_fanin_node = data_fanin.node_view(); - NodeLayoutContext layout = CheckNodeTransposable(context, *node); - if (!layout.transposable || + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*data_fanin_node, data_fanin.index(), 4)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateNode(context, layout, node)); + TF_RETURN_IF_ERROR(UpdateNode(context, node)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {1, 2}, node, - kOpDataFormatVecPermute)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - return CommitMutation(context, *node); + UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); } Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMaxPoolGrad(*node->node())); - const NodeLayoutContext& layout = CheckNodeTransposable(context, *node); - if (!layout.transposable || !IsFanoutPortDimsN(*node, 0, 4)) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateNode(context, layout, node)); + TF_RETURN_IF_ERROR(UpdateNode(context, node)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0, 1, 2}, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - return CommitMutation(context, *node); + UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); } Status 
MaxPoolGradV2Transposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMaxPoolGradV2(*node->node())); - const NodeLayoutContext& layout = CheckNodeTransposable(context, *node); - if (!layout.transposable || !IsFanoutPortDimsN(*node, 0, 4)) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateNode(context, layout, node)); + TF_RETURN_IF_ERROR(UpdateNode(context, node)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0, 1, 2}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {3, 4}, node, - kOpDataFormatVecPermute)); + UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - return CommitMutation(context, *node); + UpdateFaninEdgesWithOp(context, {3, 4}, node, kOpDataFormatVecPermute)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); } // Layout agnostic transposer. -// Checks if fanin is a Transpose or DataFormatDimMap/DataFormatVecPermute added -// by the transposer. 
-inline bool IsTransposerAddedOp(const utils::MutableFanoutView& fanin, - int num_nodes) { - if (fanin.node_index() < num_nodes) { +bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node, + absl::Span permutation) { + Tensor tensor; + if (!GetValueAttrIfConstPermTransposeNode(node, &tensor)) { return false; } - const auto* fanin_node = fanin.node_view(); - if (IsDataFormatOp(*fanin_node)) { - return true; + if (tensor.NumElements() != permutation.size()) { + return false; } - const auto* fanin_node_def = fanin_node->node(); - if (IsTranspose(*fanin_node_def)) { - const auto& regular_fanin_1 = fanin_node->GetRegularFanin(1); - return regular_fanin_1.node_index() >= num_nodes && - IsConstant(*regular_fanin_1.node_view()->node()); + + const auto& tensor_data = tensor.unaligned_flat(); + for (int i = 0; i < permutation.size(); i++) { + if (permutation[i] != tensor_data(i)) { + return false; + } + } + return true; +} + +bool IsValidDataFormatNode(const utils::MutableNodeView& node, + absl::string_view src_format, + absl::string_view dst_format) { + if (!IsDataFormatOp(node)) { + return false; + } + const auto* src_format_attr = node.GetAttr(kAttrSrcFormat); + if (src_format_attr == nullptr || src_format_attr->s() != src_format) { + return false; + } + const auto* dst_format_attr = node.GetAttr(kAttrDstFormat); + if (dst_format_attr == nullptr || dst_format_attr->s() != dst_format) { + return false; + } + return true; +} + +bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform( + const TransposeContext& context, const utils::MutableNodeView& node) const { + std::deque queue; + absl::flat_hash_set visited_nodes; + auto data_node_pos = GetDataFaninPorts(node); + for (const int pos : data_node_pos) { + const auto& fanin = node.GetRegularFanin(pos); + auto* fanin_node = fanin.node_view(); + queue.push_back(fanin_node); + visited_nodes.insert(fanin_node); + } + // The code will exit this while loop in one iteration in most cases, as the + // graph is 
already topologically sorted. + while (!queue.empty()) { + utils::MutableNodeView* current_node = queue.front(); + queue.pop_front(); + if (IsValidConstPermTransposeNode(*current_node, context.dst_to_src) || + IsValidDataFormatNode(*current_node, context.dst_format, + context.src_format)) { + return true; + } + // We only continue searching if the path is connected through + // format-agnostic nodes. + if (IsLayoutAgnosticOp(*current_node->node())) { + auto current_node_pos = GetDataFaninPorts(*current_node); + for (const auto& pos : current_node_pos) { + const auto& fanin = current_node->GetRegularFanin(pos); + auto* fanin_node = fanin.node_view(); + if (visited_nodes.insert(fanin_node).second) { + queue.push_back(fanin_node); + } + } + } } return false; } -// Check if fanin is of a processed layout sensitive or agnostic node and return -// its data layout that was used to transform it. -NodeLayoutContext GetTransposerAddedOpDataFormat( - const TransposeContext& context, const utils::MutableFanoutView& fanin) { - NodeLayoutContext layout; - if (!IsTransposerAddedOp(fanin, context.num_nodes)) { - return layout; - } - const auto* fanin_node = fanin.node_view(); - const auto& regular_fanin_0 = fanin_node->GetRegularFanin(0); - const auto* node_def = regular_fanin_0.node_view()->node(); - if (IsLayoutSensitiveOp(*node_def)) { - auto it = context.sensitive_node_layouts.find(regular_fanin_0.node_index()); - if (it != context.sensitive_node_layouts.end()) { - layout = it->second; - } - } else if (IsLayoutAgnosticOp(*node_def)) { - auto it = context.agnostic_node_layouts.find(regular_fanin_0.node_index()); - if (it != context.agnostic_node_layouts.end()) { - layout = it->second.fanout_layouts[regular_fanin_0.index()]; +std::vector LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts( + const TransposeContext& context, const utils::MutableNodeView& node) const { + std::vector ports; + const int num_regular_fanins = node.NumRegularFanins(); + 
ports.reserve(num_regular_fanins); + for (int i = 0; i < num_regular_fanins; ++i) { + const auto& regular_fanin = node.GetRegularFanin(i); + auto* regular_fanin_node = regular_fanin.node_view(); + int regular_fanin_port = regular_fanin.index(); + if (IsFanoutPortDimsN(*regular_fanin_node, regular_fanin_port, 4) && + ((IsAfterDstToSrcTransform(context, *regular_fanin_node) && + IsLayoutAgnosticOp(*regular_fanin_node->node())) || + IsValidConstPermTransposeNode(*regular_fanin_node, + context.dst_to_src))) { + ports.push_back(i); } } - return layout; -} - -// Resolve fanout data layouts from fanin data layouts for a given node. -void SetFanoutDataLayouts(const TransposeContext& context, - const utils::MutableNodeView& node, - absl::Span fanin_layouts, - absl::Span fanin_ports, - std::vector* fanout_layouts, - absl::Span fanout_ports, - bool* transposable) { - const auto* node_def = node.node(); - DeviceProperties device; - if (context.virtual_placer != nullptr) { - device = context.virtual_placer->get_device(*node_def); - } - if (IsMerge(*node_def)) { - // Check if both fanins of Merge are being transposed in the same manner. - // TODO(lyandy): Determine if the restriction of both fanins needing to be - // transposed is necessary. - const auto& fanin_layout_0 = fanin_layouts[0]; - if (fanin_layout_0.transposable && fanin_layout_0 == fanin_layouts[1]) { - (*fanout_layouts)[0] = fanin_layout_0; - *transposable = true; - } - } else if (IsIdentityN(*node_def) || IsShapeN(*node_def)) { - // Simply forward transposed fanins to fanouts as fanins of IdentityN/ShapeN - // are independent from one another. - for (const auto& port : fanin_ports) { - const auto& layout = fanin_layouts[port]; - if (layout.transposable) { - (*fanout_layouts)[port] = layout; - *transposable = true; - } - } - } else { - // Check if all fanins that are transposed have been transposed from and to - // the same data format layouts. 
- // TODO(lyandy): Implement a better selection of data format if only some - // fanins are transposed. - bool is_transposable = true; - bool found_one_transposable = false; - NodeLayoutContext layout; - for (const auto& port : fanin_ports) { - if (fanin_layouts[port].transposable) { - if (!found_one_transposable) { - found_one_transposable = true; - layout = fanin_layouts[port]; - } else if (layout != fanin_layouts[port]) { - is_transposable = false; - break; - } - } - } - if (is_transposable && found_one_transposable) { - for (const auto& port : fanout_ports) { - (*fanout_layouts)[port] = layout; - } - *transposable = true; - } - } -} - -void LayoutAgnosticOpTransposer::ProcessFanoutDataLayouts( - TransposeContext* context, const utils::MutableNodeView& node) { - const int node_index = node.node_index(); - auto it = context->agnostic_node_layouts.emplace(node_index, - LayoutAgnosticNodeFanouts()); - auto fanout_ports = GetDataFanoutPorts(node); - auto& node_layout = it.first->second; - if (!fanout_ports.empty()) { - node_layout.fanout_layouts.resize(fanout_ports[fanout_ports.size() - 1] + - 1); - } - auto fanin_ports = GetDataFaninPorts(node); - if (!fanin_ports.empty()) { - node_layout.fanin_layouts.resize(fanin_ports[fanin_ports.size() - 1] + 1); - } - // Copy over fanout data layouts for each fanin. 
- for (const auto& port : fanin_ports) { - const auto& fanin = node.GetRegularFanin(port); - if (IsLayoutAgnosticOp(*fanin.node_view()->node())) { - node_layout.fanin_layouts[port] = - context->agnostic_node_layouts[fanin.node_index()] - .fanout_layouts[fanin.index()]; - } else { - node_layout.fanin_layouts[port] = - GetTransposerAddedOpDataFormat(*context, fanin); - } - } - SetFanoutDataLayouts(*context, node, node_layout.fanin_layouts, fanin_ports, - &node_layout.fanout_layouts, fanout_ports, - &node_layout.transposable); -} - -// Performs a DFS traversal processing layout agnostic parent nodes before -// layout agnostic children nodes to propagate data layouts from processed -// layout sensitive nodes and layout agnostic nodes. Nodes are only visited once -// and results are cached. -void LayoutAgnosticOpTransposer::FetchNodeDataLayouts( - TransposeContext* context, const utils::MutableNodeView& node) { - const int num_nodes = context->graph_view->NumNodes(); - std::vector visited(num_nodes); - for (const auto& it : context->agnostic_node_layouts) { - visited[it.first] = true; - } - if (visited[node.node_index()]) { - return; - } - - // TODO(lyandy): Pull out DFS traversal into a GraphView util. 
- enum RecursionStackState : bool { ENTER, EXIT }; - - struct RecursionStackEntry { - RecursionStackEntry(int node_index, RecursionStackState recursion_state) - : node_index(node_index), recursion_state(recursion_state) {} - - const int node_index; - const RecursionStackState recursion_state; - }; - - std::vector recursion_stack; - recursion_stack.push_back({node.node_index(), ENTER}); - visited[node.node_index()] = true; - - while (!recursion_stack.empty()) { - auto curr_entry = recursion_stack.back(); - recursion_stack.pop_back(); - const auto* curr_node = context->graph_view->GetNode(curr_entry.node_index); - if (curr_entry.recursion_state == ENTER) { - recursion_stack.push_back({curr_entry.node_index, EXIT}); - auto fanin_ports = GetDataFaninPorts(*curr_node); - for (const auto& port : fanin_ports) { - const auto& fanin = curr_node->GetRegularFanin(port); - const auto* fanin_node = fanin.node_view(); - // Traverse up chains of layout agnostic nodes. - if (IsLayoutAgnosticOp(*fanin_node->node()) && - !visited[fanin.node_index()]) { - // Unprocessed layout agnostic node fanin. - recursion_stack.push_back({fanin.node_index(), ENTER}); - visited[fanin.node_index()] = true; - } else if (IsTransposerAddedOp(fanin, context->num_nodes)) { - const auto& fanin_fanin_0 = fanin_node->GetRegularFanin(0); - if (IsLayoutAgnosticOp(*fanin_fanin_0.node_view()->node()) && - !visited[fanin_fanin_0.node_index()]) { - // Processed layout agnostic node fanin. 
- recursion_stack.push_back({fanin_fanin_0.node_index(), ENTER}); - visited[fanin_fanin_0.node_index()] = true; - } - } - } - } else { - ProcessFanoutDataLayouts(context, *curr_node); - } - } -} - -const LayoutAgnosticNodeFanouts& -LayoutAgnosticOpTransposer::CheckNodeTransposable( - TransposeContext* context, const utils::MutableNodeView& node) { - FetchNodeDataLayouts(context, node); - if (!CanProcessNode(*context, node)) { - return context->unknown_agnostic_node_layout; - } - auto it = context->agnostic_node_layouts.find(node.node_index()); - if (it != context->agnostic_node_layouts.end()) { - return it->second; - } - return context->unknown_agnostic_node_layout; -} - -const NodeLayoutContext& LayoutAgnosticOpTransposer::GetDataFormatLayoutForNode( - TransposeContext* context, const utils::MutableNodeView& node, - int fanout_port, - std::function node_check) { - if (!node_check(node)) { - return context->unknown_node_layout; - } - const LayoutAgnosticNodeFanouts& layout = - CheckNodeTransposable(context, node); - if (!layout.transposable) { - return context->unknown_node_layout; - } - return layout.fanout_layouts[fanout_port]; + return ports; } Status DefaultLayoutAgnosticOpTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsDefaultLayoutAgnosticOp(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); + 
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status AddNTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsAddN(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp( - context, layout, GetDataFaninPorts(*node), node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node), + node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1102,8 +898,7 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation, } Status BinaryOpTransposer::MaybeReshapeVectorFanin( - TransposeContext* context, const NodeLayoutContext& layout, - utils::MutableNodeView* node) { + TransposeContext* context, utils::MutableNodeView* node) { int vector_index = -1; if (IsNDOperateWithMD(*node, 4, 1)) { vector_index = 1; @@ -1114,7 +909,7 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin( const string& node_name = node->GetName(); const string& node_device = node->GetDevice(); string reshape_node_name = LayoutOptimizerNode(GetReshapeNodeNameFormat( - node_name, vector_index, layout.src, layout.dst)); + node_name, vector_index, context->src_format, context->dst_format)); string shape_const_node_name = LayoutOptimizerNode( GetShapeConstNodeNameFormat(node_name, vector_index)); const auto& fanin = node->GetRegularFanin(vector_index); @@ 
-1147,35 +942,26 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin( Status BinaryOpTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsBinaryOp(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFaninShapeSupported(node); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp( - context, layout, Get4DDataFaninPorts(*node), node, kOpTranspose)); - TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, layout, node)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, Get4DDataFaninPorts(*node), + node, kOpTranspose)); + TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status ConcatOpTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConcat(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp( - context, layout, GetConcatDataFaninPorts(*node), node, kOpTranspose)); + context, GetConcatDataFaninPorts(*node), node, kOpTranspose)); int axis_node = 0; if (node->GetOp() == "ConcatV2") { const auto* n_attr = node->GetAttr(kAttrN); @@ -1183,72 +969,65 @@ Status 
ConcatOpTransposer::TransposeNode(TransposeContext* context, axis_node = n_attr->i(); } } - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {axis_node}, node, - kOpDataFormatDimMap)); TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status FillOpTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsFill(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {0}, node, - kOpDataFormatVecPermute)); TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status IdentityNTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsIdentityN(*node->node())); - const LayoutAgnosticNodeFanouts& layout = - CheckNodeTransposable(context, *node); - if (!layout.transposable) { + const auto ports = GetVariadic4DFaninPorts(*context, *node); + if (!ShouldProcess(*context, *node) || ports.empty()) { return Status::OK(); } - bool transposed = false; - const int num_layout_ports = layout.fanout_layouts.size(); - for (int i = 0; i < num_layout_ports; ++i) { - const auto& layout_port = layout.fanout_layouts[i]; - if 
(layout_port.transposable && IsFanoutPortDimsN(*node, i, 4)) { - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout_port, {i}, node, - kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, layout_port, {i}, - node, kOpTranspose)); - transposed = true; + TF_RETURN_IF_ERROR( + UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose)); + TF_RETURN_IF_ERROR( + UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); +} + +bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform( + const TransposeContext& context, const utils::MutableNodeView& node) const { + for (const auto& regular_fanin : node.GetRegularFanins()) { + auto* regular_fanin_node = regular_fanin.node_view(); + if (IsFanoutPortDimsN(*regular_fanin_node, regular_fanin.index(), 4) && + ((IsAfterDstToSrcTransform(context, *regular_fanin_node) && + IsLayoutAgnosticOp(*regular_fanin_node->node())) || + IsValidConstPermTransposeNode(*regular_fanin_node, + context.dst_to_src))) { + continue; } + return false; } - if (transposed) { - return context->graph_view->GetMutationBuilder()->Apply(); - } - return Status::OK(); + return true; } Status MergeTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMerge(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsEveryFaninAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp( - context, layout, GetDataFaninPorts(*node), node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node), + node, 
kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1256,42 +1035,14 @@ Status PadTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMirrorPad(*node->node()) || IsMirrorPadGrad(*node->node()) || IsPad(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {1}, node, - kOpDataFormatVecPermute)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - return context->graph_view->GetMutationBuilder()->Apply(); -} - -Status ReduceTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { - DCHECK(IsReduceOp(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFaninPortDimsN(node, 0, 4); - }); - if (!layout.transposable || !IsReduceAxisSupported(layout, *node)) { - return Status::OK(); - } - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {1}, node, kOpDataFormatDimMap)); - if (KeepDims(*node)) { - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - } + UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute)); + 
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1326,60 +1077,47 @@ bool ReduceTransposer::IsAlongAxis(const utils::MutableNodeView& axis_node, } bool ReduceTransposer::IsReduceAxisSupported( - const NodeLayoutContext& layout, const utils::MutableNodeView& node) { + const TransposeContext& context, const utils::MutableNodeView& node) { const auto& regular_fanin_1 = node.GetRegularFanin(1); auto* axis_node = regular_fanin_1.node_view(); // TODO(lyandy): Generalize this for other data format conversions. return KeepDims(node) || - (layout.src == kNHWC && layout.dst == kNCHW && + (context.src_format == kNHWC && context.dst_format == kNCHW && (IsAlongAxis(*axis_node, {0, 1, 2, 3}) || IsAlongAxis(*axis_node, {1, 2, 3}) || IsAlongAxis(*axis_node, {0, 1, 2}) || IsAlongAxis(*axis_node, {1, 2}) || IsAlongAxis(*axis_node, {3}))); } -Status ReverseV2Transposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { - DCHECK(IsReverseV2(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { +Status ReduceTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { + DCHECK(IsReduceOp(*node->node())); + if (!ShouldProcess(*context, *node) || !IsFaninPortDimsN(*node, 0, 4) || + !IsReduceAxisSupported(*context, *node) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {1}, node, kOpDataFormatDimMap)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + 
UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap)); + if (KeepDims(*node)) { + TF_RETURN_IF_ERROR( + UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + } return context->graph_view->GetMutationBuilder()->Apply(); } -Status SelectTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { - DCHECK(IsSelect(*node->node())); - const auto& regular_fanin_0 = node->GetRegularFanin(0); - auto* regular_fanin_0_node = regular_fanin_0.node_view(); - if (!IsFaninScalarVector4D(*regular_fanin_0_node, regular_fanin_0.index())) { +Status ReverseV2Transposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { + DCHECK(IsReverseV2(*node->node())); + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { - return Status::OK(); - } - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp( - context, layout, - GetFaninPorts(*regular_fanin_0_node, regular_fanin_0.index()), node, - kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1399,127 +1137,93 @@ std::vector SelectTransposer::GetFaninPorts( return {1, 2}; } +Status SelectTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { + DCHECK(IsSelect(*node->node())); + const auto& regular_fanin_0 = node->GetRegularFanin(0); + auto* regular_fanin_0_node = regular_fanin_0.node_view(); + if 
(!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsFaninScalarVector4D(*regular_fanin_0_node, regular_fanin_0.index()) || + !IsAfterDstToSrcTransform(*context, *node)) { + return Status::OK(); + } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp( + context, GetFaninPorts(*regular_fanin_0_node, regular_fanin_0.index()), + node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); +} + Status ShapeTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsShape(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFaninPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFaninPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, layout, {0}, node, - kOpDataFormatVecPermute)); + UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute)); return context->graph_view->GetMutationBuilder()->Apply(); } Status ShapeNTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsShapeN(*node->node())); - const LayoutAgnosticNodeFanouts& layout = - CheckNodeTransposable(context, *node); - if (!layout.transposable) { + const auto ports = GetVariadic4DFaninPorts(*context, *node); + if (!ShouldProcess(*context, *node) || ports.empty()) { return Status::OK(); } - bool transposed = false; - const int num_layout_ports = layout.fanout_layouts.size(); - for (int i = 0; i < num_layout_ports; ++i) { - const auto& layout_port = layout.fanout_layouts[i]; - if 
(layout_port.transposable && IsFaninPortDimsN(*node, i, 4)) { - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout_port, {i}, node, - kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp( - context, layout_port, {i}, node, kOpDataFormatVecPermute)); - transposed = true; - } - } - if (transposed) { - return context->graph_view->GetMutationBuilder()->Apply(); - } - return Status::OK(); + TF_RETURN_IF_ERROR( + UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose)); + TF_RETURN_IF_ERROR( + UpdateFanoutEdgesWithOp(context, ports, node, kOpDataFormatVecPermute)); + return context->graph_view->GetMutationBuilder()->Apply(); } Status SliceTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsSlice(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {1, 2}, node, - kOpDataFormatVecPermute)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status SplitTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsSplit(*node->node())); - auto output_ports = GetDataFanoutPorts(*node); - const NodeLayoutContext& layout = GetDataFormatLayoutForNode( - context, *node, 
/*fanout_port=*/0, - [this, output_ports](const utils::MutableNodeView& node) { - return IsFanoutPortsDimsN(node, output_ports, 4); - }); - if (!layout.transposable) { + const auto ports = GetDataFanoutPorts(*node); + if (!ShouldProcess(*context, *node) || !IsFanoutPortsDimsN(*node, ports, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {1}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatDimMap)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpDataFormatDimMap)); - TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, layout, output_ports, - node, kOpTranspose)); + UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status SplitVTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsSplitV(*node->node())); - auto output_ports = GetDataFanoutPorts(*node); - const NodeLayoutContext& layout = GetDataFormatLayoutForNode( - context, *node, /*fanout_port=*/0, - [this, output_ports](const utils::MutableNodeView& node) { - return IsFanoutPortsDimsN(node, output_ports, 4); - }); - if (!layout.transposable) { + const auto ports = GetDataFanoutPorts(*node); + if (!ShouldProcess(*context, *node) || !IsFanoutPortsDimsN(*node, ports, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {2}, node, kOpDataFormatDimMap)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {2}, node, kOpDataFormatDimMap)); - TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, layout, output_ports, - node, kOpTranspose)); - return 
context->graph_view->GetMutationBuilder()->Apply(); -} - -Status SqueezeTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { - DCHECK(IsSqueeze(*node->node())); - const NodeLayoutContext& layout = GetDataFormatLayoutForNode( - context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsDimsSupported(node) && IsInputConvertible(node); - }); - // TODO(lyandy): Generalize this for other data format conversions. - if (!layout.transposable || layout.src != kNHWC || layout.dst != kNCHW) { - return Status::OK(); - } - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateSqueezeDims(context, node)); + UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1592,25 +1296,20 @@ Status SqueezeTransposer::UpdateSqueezeDims(TransposeContext* context, return Status::OK(); } -Status StridedSliceTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { - DCHECK(IsStridedSlice(*node->node())); - const NodeLayoutContext& layout = GetDataFormatLayoutForNode( - context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4) && HasOnlyBeginEndMask(node); - }); - if (!layout.transposable) { +Status SqueezeTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { + DCHECK(IsSqueeze(*node->node())); + // TODO(lyandy): Generalize this for other data format conversions. 
+ if (context->src_format != kNHWC || context->dst_format != kNCHW) { return Status::OK(); } - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(PermuteMask(context, layout, node, "begin_mask")); - TF_RETURN_IF_ERROR(PermuteMask(context, layout, node, "end_mask")); - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {1, 2, 3}, node, - kOpDataFormatVecPermute)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + if (!ShouldProcess(*context, *node) || !IsDimsSupported(*node) || + !IsInputConvertible(*node) || + !IsAfterDstToSrcTransform(*context, *node)) { + return Status::OK(); + } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateSqueezeDims(context, node)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1631,7 +1330,6 @@ bool StridedSliceTransposer::HasOnlyBeginEndMask( } Status StridedSliceTransposer::PermuteMask(TransposeContext* context, - const NodeLayoutContext& layout, utils::MutableNodeView* node, absl::string_view mask) { // Computers the permutation of the masks based on the src and dst format. 
@@ -1647,8 +1345,8 @@ Status StridedSliceTransposer::PermuteMask(TransposeContext* context, return errors::InvalidArgument("invalid mask value: ", mask_i); } int result = 0; - for (int i = 0; i < layout.src_to_dst.size(); i++) { - const int final_pos = layout.src_to_dst[i]; + for (int i = 0; i < context->src_to_dst.size(); i++) { + const int final_pos = context->src_to_dst[i]; const int position_mask = 1 << final_pos; const int bit_i = (mask_i & position_mask) >> final_pos; result |= bit_i << i; @@ -1660,77 +1358,73 @@ Status StridedSliceTransposer::PermuteMask(TransposeContext* context, return Status::OK(); } +Status StridedSliceTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { + DCHECK(IsStridedSlice(*node->node())); + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !HasOnlyBeginEndMask(*node) || + !IsAfterDstToSrcTransform(*context, *node)) { + return Status::OK(); + } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(PermuteMask(context, node, "begin_mask")); + TF_RETURN_IF_ERROR(PermuteMask(context, node, "end_mask")); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3}, node, + kOpDataFormatVecPermute)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); +} + Status SwitchTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsSwitch(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFaninPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFaninPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - 
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp( - context, layout, GetDataFanoutPorts(*node), node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, GetDataFanoutPorts(*node), + node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status TernaryOpTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsTernaryOp(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0, 1, 2}, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status TileTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsTile(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0}, node, kOpTranspose)); - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, layout, {1}, node, 
- kOpDataFormatVecPermute)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } Status UnaryGradTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsUnaryGrad(*node->node())); - const NodeLayoutContext& layout = - GetDataFormatLayoutForNode(context, *node, /*fanout_port=*/0, - [this](const utils::MutableNodeView& node) { - return IsFanoutPortDimsN(node, 0, 4); - }); - if (!layout.transposable) { + if (!ShouldProcess(*context, *node) || !IsFanoutPortDimsN(*node, 0, 4) || + !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, layout, {0, 1}, node, kOpTranspose)); - TF_RETURN_IF_ERROR( - UpdateFanoutEdgesWithOp(context, layout, {0}, node, kOpTranspose)); + UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1908,47 +1602,11 @@ bool GetValueAttrIfConstPermTransposeNode(const utils::MutableNodeView& node, return true; } -bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node, - absl::Span permutation) { - Tensor tensor; - if (!GetValueAttrIfConstPermTransposeNode(node, &tensor)) { - return false; - } - if (tensor.NumElements() != permutation.size()) { - return false; - } - - const auto& tensor_data = tensor.unaligned_flat(); - for (int i = 0; i < permutation.size(); i++) { - if (permutation[i] != tensor_data(i)) { - return false; - } - } - return true; -} - bool IsDataFormatOp(const utils::MutableNodeView& node) { const string& op = node.GetOp(); return op == kOpDataFormatDimMap || op == kOpDataFormatVecPermute; } -bool 
IsValidDataFormatNode(const utils::MutableNodeView& node, - absl::string_view src_format, - absl::string_view dst_format) { - if (!IsDataFormatOp(node)) { - return false; - } - const auto* src_format_attr = node.GetAttr(kAttrSrcFormat); - if (src_format_attr == nullptr || src_format_attr->s() != src_format) { - return false; - } - const auto* dst_format_attr = node.GetAttr(kAttrDstFormat); - if (dst_format_attr == nullptr || dst_format_attr->s() != dst_format) { - return false; - } - return true; -} - std::vector GetPermutation(absl::string_view src_format, absl::string_view dst_format) { // Generate permutation for transformation between src and dst format. diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h index ff60fcaafa6..05100d7508b 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h @@ -36,35 +36,7 @@ namespace grappler { constexpr char kAttrSrcFormat[] = "src_format"; constexpr char kAttrDstFormat[] = "dst_format"; - -// NodeLayoutContext holds the data formats to convert from and to for a layout -// sensitive node. -struct NodeLayoutContext { - void Initialize(absl::string_view src_format, absl::string_view dst_format, - bool is_transposable); - - bool operator==(const NodeLayoutContext& other) const { - return src == other.src && dst == other.dst; - } - - bool operator!=(const NodeLayoutContext& other) const { - return !(*this == other); - } - - string src; - string dst; - std::vector src_to_dst; - std::vector dst_to_src; - bool transposable = false; -}; - -// TransposeContext holds the data formats to convert from and to of each data -// fanin and fanout for a layout agnostic node. 
-struct LayoutAgnosticNodeFanouts { - std::vector fanout_layouts; - std::vector fanin_layouts; - bool transposable = false; -}; +constexpr char kAttrOutputShape[] = "_output_shapes"; // TransposeContext owns all data members. Must initialize GraphProperties, // FrameView, GraphDef and MutableGraphView with the same graph. NodeDef @@ -76,8 +48,10 @@ struct TransposeContext { // TransposeContext outside constructor. static Status InitializeTransposeContext(const GrapplerItem& item, const Cluster* cluster, + absl::string_view src_format, + absl::string_view dst_format, + absl::string_view target_device, TransposeContext* context); - FrameView frames; GraphDef graph; // Number of nodes in the original graph. As new nodes are appended to the end @@ -88,12 +62,12 @@ struct TransposeContext { std::unique_ptr graph_properties; std::unique_ptr graph_view; std::unique_ptr virtual_placer; - // Mapping of node index to data layout conversion of layout sensitive nodes. - absl::flat_hash_map sensitive_node_layouts; - // Mapping of node index to data layout conversion of layout agnostic nodes. - absl::flat_hash_map agnostic_node_layouts; - NodeLayoutContext unknown_node_layout; - LayoutAgnosticNodeFanouts unknown_agnostic_node_layout; + + string src_format; + string dst_format; + string target_device; + std::vector src_to_dst; + std::vector dst_to_src; }; class Transposer { @@ -105,6 +79,16 @@ class Transposer { virtual ~Transposer() {} + // Returns true iff the node should be processed by this transposer. + // NodeProcessors may perform additional operand specific checks before + // processing if necessary. + // Following common conditions are checked: + // * node's device matches target device + // * node's source format matches config's source format + // * node has output + virtual bool ShouldProcess(const TransposeContext& context, + const utils::MutableNodeView& node) const; + // Transposes given node from src format to dst format.
Also perform other // necessary operations to guarantee the graph produce the same result. // Eg. Add Transpose node sets before fanin ports and after fanout ports. @@ -133,7 +117,6 @@ class Transposer { // Update all edges between dst_node->fanin[dst_ports] and dst_node by // inserting an op node. Status UpdateFaninEdgesWithOp(TransposeContext* context, - const NodeLayoutContext& layout, absl::Span dst_ports, utils::MutableNodeView* dst_node, absl::string_view op); @@ -141,7 +124,6 @@ class Transposer { // Update all edges between src_node:src_ports and nodes take // src_node:src_ports as fanin. Also update attr _output_shape of src_node. Status UpdateFanoutEdgesWithOp(TransposeContext* context, - const NodeLayoutContext& layout, absl::Span src_ports, utils::MutableNodeView* src_node, absl::string_view op); @@ -149,7 +131,6 @@ class Transposer { // Creates a DataFromat node with given properties. // DataFromat op is either DataFormatVecPermute or DataFormatDimMap. Status CreateDataFormatNode(TransposeContext* context, - const NodeLayoutContext& layout, absl::string_view node_name, absl::string_view op, absl::string_view device, const DataType& data_type, bool is_fanin_on_host, @@ -167,14 +148,11 @@ class Transposer { const utils::MutableNodeView& node) const; string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) const; - virtual string GetDstDataFormatForDevice( - const DeviceProperties& device) const; // Update all edges between dst_node->fanin[dst_ports] and dst_node. // A node with op is created and insereted between all edges. // op is one of Transpose, DataFormatVecPermute or DataFormatDimMap. 
- Status UpdateEdge(TransposeContext* context, const NodeLayoutContext& layout, - absl::string_view name_format, absl::string_view op, - const AttrValue* input_shape, + Status UpdateEdge(TransposeContext* context, absl::string_view name_format, + absl::string_view op, const AttrValue* input_shape, bool is_src_format_to_dst_format, const int src_port, const int dst_port, utils::MutableNodeView* src_node, utils::MutableNodeView* dst_node); @@ -197,19 +175,7 @@ class LayoutSensitiveOpTransposer : public Transposer { // Updates attrs data_format, ksize, strides of the given node to dst_format. // _output_shape is updated during UpdateOutputEdges. - Status UpdateNode(TransposeContext* context, const NodeLayoutContext& layout, - utils::MutableNodeView* node); - - protected: - // Check if a node is transposable, and return what data format to transpose - // from and to. - const NodeLayoutContext& CheckNodeTransposable( - TransposeContext* context, const utils::MutableNodeView& node) const; - - // Apply mutation. If mutation fails, reset layout being transpose to - // associated with the node. - Status CommitMutation(TransposeContext* context, - const utils::MutableNodeView& node); + Status UpdateNode(TransposeContext* context, utils::MutableNodeView* node); }; // Layout sensitive op transposers. @@ -289,23 +255,12 @@ class LayoutAgnosticOpTransposer : public Transposer { explicit LayoutAgnosticOpTransposer() : Transposer() {} protected: - void ProcessFanoutDataLayouts(TransposeContext* context, - const utils::MutableNodeView& node); + bool IsAfterDstToSrcTransform(const TransposeContext& context, + const utils::MutableNodeView& node) const; - void FetchNodeDataLayouts(TransposeContext* context, - const utils::MutableNodeView& node); - - // Check if a node is transposable, and return what data format to transpose - // from and to. 
- const LayoutAgnosticNodeFanouts& CheckNodeTransposable( - TransposeContext* context, const utils::MutableNodeView& node); - - // Validate node and get data format layout to transpose from and to for a - // given fanout port of a node. - const NodeLayoutContext& GetDataFormatLayoutForNode( - TransposeContext* context, const utils::MutableNodeView& node, - int fanout_port, - std::function node_check); + std::vector GetVariadic4DFaninPorts( + const TransposeContext& context, + const utils::MutableNodeView& node) const; }; class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer { @@ -345,7 +300,6 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer { absl::string_view shape_const_node_name, const DataType& data_type); Status MaybeReshapeVectorFanin(TransposeContext* context, - const NodeLayoutContext& layout, utils::MutableNodeView* node); }; @@ -381,7 +335,9 @@ class MergeTransposer : public LayoutAgnosticOpTransposer { utils::MutableNodeView* node) override; private: - std::vector GetFaninPorts(const utils::MutableNodeView& node); + bool IsEveryFaninAfterDstToSrcTransform( + const TransposeContext& context, + const utils::MutableNodeView& node) const; }; class PadTransposer : public LayoutAgnosticOpTransposer { @@ -403,7 +359,7 @@ class ReduceTransposer : public LayoutAgnosticOpTransposer { bool KeepDims(const utils::MutableNodeView& node); bool IsAlongAxis(const utils::MutableNodeView& axis_node, absl::Span axis); - bool IsReduceAxisSupported(const NodeLayoutContext& layout, + bool IsReduceAxisSupported(const TransposeContext& context, const utils::MutableNodeView& node); }; @@ -495,8 +451,8 @@ class StridedSliceTransposer : public LayoutAgnosticOpTransposer { private: bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask); bool HasOnlyBeginEndMask(const utils::MutableNodeView& node); - Status PermuteMask(TransposeContext* context, const NodeLayoutContext& layout, - utils::MutableNodeView* node, absl::string_view 
mask); + Status PermuteMask(TransposeContext* context, utils::MutableNodeView* node, + absl::string_view mask); }; class SwitchTransposer : public LayoutAgnosticOpTransposer { @@ -580,19 +536,8 @@ std::vector GetDataFanoutPorts(const utils::MutableNodeView& node); bool GetValueAttrIfConstPermTransposeNode(const utils::MutableNodeView& node, Tensor* tensor); -// Returns true if the given node is a transpose op performing given const -// permutation. -bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node, - absl::Span permutation); - bool IsDataFormatOp(const utils::MutableNodeView& node); -// Returns true if the given node is DataformatDimMap or DataformatVecPermute -// performing given permutation from scr_format to dst_format. -bool IsValidDataFormatNode(const utils::MutableNodeView& node, - absl::string_view src_format, - absl::string_view dst_format); - std::vector GetPermutation(absl::string_view src_format, absl::string_view dst_format); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc index 33f2cec8a7d..e328a3a0327 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc @@ -49,6 +49,7 @@ constexpr int kOutHeight = 5; constexpr int kDepthOut = 16; constexpr char kSrcFormat[] = "NHWC"; constexpr char kDstFormat[] = "NCHW"; +constexpr char kGPU[] = "GPU"; constexpr char kAttrOutputShapes[] = "_output_shapes"; constexpr char kAttrDataFormat[] = "data_format"; constexpr char kOpTranspose[] = "Transpose"; @@ -316,7 +317,7 @@ TEST_F(TransposerTest, CreateConstPermNode) { TransposeContext context; TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, 
kGPU, &context)); TransposerImpl transposer; constexpr char kNodeName[] = "const_perm_node"; @@ -358,8 +359,8 @@ TEST_F(TransposerTest, CreateTransposeNode) { GrapplerItem item; TransposeContext context; TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); - TF_ASSERT_OK( - TransposeContext::InitializeTransposeContext(item, nullptr, &context)); + TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); TransposerImpl transposer; constexpr char kNodeNameFormat[] = @@ -398,14 +399,12 @@ TEST_F(TransposerTest, UpdateNode) { TransposeContext context; TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer transposer; auto* conv2d = context.graph_view->GetNode("conv2d"); ASSERT_NE(conv2d, nullptr); - NodeLayoutContext layout; - layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true); - TF_ASSERT_OK(transposer.UpdateNode(&context, layout, conv2d)); + TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d)); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); auto* updated_conv2d = context.graph_view->GetNode("conv2d"); @@ -430,7 +429,7 @@ TEST_F(TransposerTest, UpdateStrides) { TransposeContext context; TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), "ABCD", "ACBD", kGPU, &context)); AttrValue_ListValue expected_original_strides = MakeAttrValueListValueFromVector({1, 2, 4, 1}); @@ -449,9 +448,7 @@ TEST_F(TransposerTest, UpdateStrides) { TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); DefaultLayoutSensitiveOpTransposer transposer; - NodeLayoutContext layout; - layout.Initialize("ABCD", "ACBD", /*is_transposable=*/true); 
- TF_ASSERT_OK(transposer.UpdateNode(&context, layout, conv2d)); + TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d)); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); auto* updated_conv2d = context.graph_view->GetNode("conv2d"); @@ -468,19 +465,17 @@ TEST_F(TransposerTest, UpdateFaninEdgesTranspose) { GrapplerItem item; TransposeContext context; TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true)); - TF_ASSERT_OK( - TransposeContext::InitializeTransposeContext(item, nullptr, &context)); + TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); FusedBatchNormGradTransposer transposer; - NodeLayoutContext layout; - layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true); auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad"); ASSERT_NE(fbng, nullptr); const auto& fbng_output_shapes_attr = fbng->GetAttr("_output_shapes"); ASSERT_NE(fbng_output_shapes_attr, nullptr); const TensorShapeProto& expected_shape = fbng_output_shapes_attr->shape(); - TF_ASSERT_OK(transposer.UpdateFaninEdgesWithOp(&context, layout, {0, 1}, fbng, - kOpTranspose)); + TF_ASSERT_OK( + transposer.UpdateFaninEdgesWithOp(&context, {0, 1}, fbng, kOpTranspose)); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); // Verify output shape matches input shape. 
@@ -528,12 +523,10 @@ TEST_F(TransposerTest, UpdateFanoutEdgesTranspose) { GrapplerItem item; TransposeContext context; TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); - TF_ASSERT_OK( - TransposeContext::InitializeTransposeContext(item, nullptr, &context)); + TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); TransposerImpl transposer; - NodeLayoutContext layout; - layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true); TensorShapeProto expected_original_shape = MakeTensorShapeFromDimensions({32, 5, 3, 16}); TensorShapeProto expected_updated_shape = @@ -543,8 +536,8 @@ TEST_F(TransposerTest, UpdateFanoutEdgesTranspose) { ASSERT_NE(conv2d, nullptr); VerifyShapeAttributeMatch(conv2d, 0, expected_original_shape.DebugString()); - TF_ASSERT_OK(transposer.UpdateFanoutEdgesWithOp(&context, layout, {0}, conv2d, - kOpTranspose)); + TF_ASSERT_OK( + transposer.UpdateFanoutEdgesWithOp(&context, {0}, conv2d, kOpTranspose)); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); auto* updated_conv2d = context.graph_view->GetNode("conv2d"); @@ -584,7 +577,7 @@ TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestFusedBatchNorm) { TransposeContext context; TF_ASSERT_OK(CreateSimpleFusedBatchNorm(&item.graph)); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer transposer; auto* bn = context.graph_view->GetNode("bn"); @@ -639,7 +632,7 @@ TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestConv2D) { TransposeContext context; TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer transposer; auto* 
conv2d = context.graph_view->GetNode("conv2d"); @@ -681,7 +674,7 @@ TEST_F(TransposerTest, MaxPoolGradTransposerTest) { TransposeContext context; TF_ASSERT_OK(CreateSimpleMaxPoolGrad(&item.graph)); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); MaxPoolGradTransposer transposer; auto* maxpool_grad = context.graph_view->GetNode("maxpool_grad"); @@ -733,7 +726,7 @@ TEST_F(TransposerTest, BiasAddGradTransposerTest) { TF_ASSERT_OK(CreateSimpleBiasAddGrad( &item.graph, {kBatchSize, kHeight, kWidth, kDepthIn})); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); BiasAddGradTransposer transposer; auto* bag = context.graph_view->GetNode("bag"); @@ -768,7 +761,7 @@ TEST_F(TransposerTest, BiasAddGradTransposerIncorrectInputTest) { TF_ASSERT_OK( CreateSimpleBiasAddGrad(&item.graph, {kHeight, kWidth, kDepthIn})); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); BiasAddGradTransposer transposer; auto* bag = context.graph_view->GetNode("bag"); @@ -802,7 +795,7 @@ TEST_F(TransposerTest, Conv2DBackpropFilterTransposerTest) { TransposeContext context; TF_ASSERT_OK(CreateSimpleConv2DBackpropFilter(&item.graph)); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); Conv2DBackpropFilterTransposer transposer; auto* conv2d_bf = context.graph_view->GetNode("conv2d_backprop_filter"); @@ -854,7 +847,7 @@ TEST_F(TransposerTest, FusedBatchNormGradTransposerIsTrainingTest) { TransposeContext context; TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true)); 
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); FusedBatchNormGradTransposer transposer; auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad"); @@ -922,7 +915,7 @@ TEST_F(TransposerTest, FusedBatchNormGradTransposerNotTrainingTest) { TransposeContext context; TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, false)); TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); FusedBatchNormGradTransposer transposer; auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad"); @@ -992,7 +985,7 @@ TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityTest) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1043,7 +1036,7 @@ TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityBadInputTest) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1084,7 +1077,7 @@ TEST_F(TransposerTest, AddNTransposerTest) { TF_ASSERT_OK(CreateSimpleAddN(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); 
DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* conv2d = context.graph_view->GetNode("conv2d"); @@ -1147,7 +1140,7 @@ TEST_F(TransposerTest, AddNTransposerNotAfterTransformTest) { TF_ASSERT_OK(CreateSimpleAddN(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); AddNTransposer addn_transposer; auto* an = context.graph_view->GetNode("add_n"); @@ -1197,7 +1190,7 @@ TEST_F(TransposerTest, IdentityNTransposerTest) { TF_ASSERT_OK(CreateSimpleIdentityN(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* conv2d_1 = context.graph_view->GetNode("conv2d_1"); @@ -1295,7 +1288,7 @@ TEST_F(TransposerTest, MergeTransposerTestMergeBothInputsConvertible) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1354,7 +1347,7 @@ TEST_F(TransposerTest, MergeTransposerTestMergeOneInputNotConvertible) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1407,7 +1400,7 @@ TEST_F(TransposerTest, PadTransposerTest) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; 
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1467,7 +1460,7 @@ TEST_F(TransposerTest, SwitchTransposerTest) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1533,7 +1526,7 @@ TEST_F(TransposerTest, TernaryOpTransposerTest) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1600,7 +1593,7 @@ TEST_F(TransposerTest, UnaryGradTransposerTestTanhGrad) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1663,7 +1656,7 @@ TEST_F(TransposerTest, UnaryGradTransposerTestRelu6Grad) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ 
-1731,7 +1724,7 @@ TEST_F(TransposerTest, SqueezeTransposerTest) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1785,7 +1778,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestUnsupportedInputShape) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1824,7 +1817,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestInvalidHWAxis) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1863,7 +1856,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestInvalidNHWAxis) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1902,7 +1895,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestSqueezeDimsUpdated) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); 
+ item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -1960,7 +1953,7 @@ TEST_F(TransposerTest, MaxPoolV2Transposer) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); MaxPoolV2Transposer maxpool_transposer; auto* maxpool = context.graph_view->GetNode("maxpoolv2"); @@ -2023,7 +2016,7 @@ TEST_F(TransposerTest, MaxPoolGradV2Transposer) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); MaxPoolGradV2Transposer maxpoolgrad_transposer; auto* maxpoolgrad = context.graph_view->GetNode("maxpoolgradv2"); @@ -2091,7 +2084,7 @@ TEST_F(TransposerTest, BinaryOpTransposerAdd) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2162,7 +2155,7 @@ TEST_F(TransposerTest, BinaryOpTransposerMul) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2235,7 +2228,7 @@ TEST_F(TransposerTest, BinaryOpTransposerPolygamma) { TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); 
TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2326,7 +2319,7 @@ TEST_F(TransposerTest, ConcatOpTransposerConcat) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2402,7 +2395,7 @@ TEST_F(TransposerTest, ConcatOpTransposerConcatV2) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2474,7 +2467,7 @@ TEST_F(TransposerTest, ReverseV2Transposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2541,7 +2534,7 @@ TEST_F(TransposerTest, TileTransposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2606,7 +2599,7 @@ TEST_F(TransposerTest, ShapeTransposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), 
&context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2671,7 +2664,7 @@ TEST_F(TransposerTest, ShapeNTransposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d_1 = context.graph_view->GetNode("conv2d_1"); @@ -2765,7 +2758,7 @@ TEST_F(TransposerTest, FillOpTransposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2825,7 +2818,7 @@ TEST_F(TransposerTest, SliceTransposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2901,7 +2894,7 @@ TEST_F(TransposerTest, SplitTransposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -2989,7 +2982,7 @@ TEST_F(TransposerTest, SplitVTransposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer 
conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -3080,7 +3073,7 @@ TEST_F(TransposerTest, StridedSliceTransposer) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -3170,7 +3163,7 @@ TEST_F(TransposerTest, StridedSliceTransposerEllipsisMaskPresent) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -3226,7 +3219,7 @@ TEST_F(TransposerTest, ReduceTransposerKeepDims) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d"); @@ -3291,7 +3284,7 @@ TEST_F(TransposerTest, ReduceTransposerValidAxisNode) { TransposeContext context; TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( - item, virtual_cluster_.get(), &context)); + item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context)); DefaultLayoutSensitiveOpTransposer conv2d_transposer; auto* c2d = context.graph_view->GetNode("conv2d");