diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
index 4ceb0264909..c85d85e69ff 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
@@ -77,17 +77,21 @@ Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
 
 Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                  int filter_size, const string& padding,
-                                 bool dilated) {
+                                 bool dilated, const int input_sizes_length) {
   int batch_size = 128;
   int input_height = input_size;
   int input_width = input_size;
   int input_depth = 3;
   int filter_count = 2;
   int stride = 1;
-  TensorShape input_sizes_shape({4});
+  TensorShape input_sizes_shape({input_sizes_length});
   Tensor input_data(DT_INT32, input_sizes_shape);
-  test::FillValues<int>(&input_data,
-                        {batch_size, input_height, input_width, input_depth});
+  if (input_sizes_length == 4) {
+    test::FillValues<int>(&input_data,
+                          {batch_size, input_height, input_width, input_depth});
+  } else {
+    test::FillValues<int>(&input_data, {input_height, input_width});
+  }
   Output input_sizes =
       ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
 
@@ -353,7 +357,8 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false);
+  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                        /*input_sizes_length=*/4);
   Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
   GrapplerItem item;
   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
@@ -380,6 +385,30 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
   VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
 }
 
+TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInput2DInputSizes) {
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
+#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+  Scope s = Scope::NewRootScope();
+  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                        /*input_sizes_length=*/2);
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+  GenericLayoutOptimizer optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  Status status;
+  utils::GraphView graph_view(&output, &status);
+  TF_ASSERT_OK(status);
+  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
+  ASSERT_NE(conv2d_backprop_node, nullptr);
+  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
+  VerifyRegularFaninMatch(conv2d_backprop_node, 0, "InputSizesIdentity", 0);
+}
+
 TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
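The new Conv2DBackpropInput2DInputSizes test above drives the helper with a 2-element input_sizes, which carries only the spatial dimensions. As a standalone illustration of why such a vector needs no layout permutation (a minimal sketch, not TensorFlow code; PermuteInputSizes is a hypothetical name standing in for what a DataFormatVecPermute node does to the constant):

```cpp
// Minimal standalone sketch (not TensorFlow code): why the transposer change
// below can skip DataFormatVecPermute for a 2-element input_sizes. The name
// PermuteInputSizes is hypothetical, for illustration only.
#include <cassert>
#include <vector>

// Reorders a 4-element input_sizes from NHWC to NCHW, mirroring what a
// DataFormatVecPermute node does. A 2-element vector is returned unchanged:
// it holds only the spatial dims {H, W}, which are the same in both layouts.
std::vector<int> PermuteInputSizes(const std::vector<int>& sizes) {
  if (sizes.size() == 2) return sizes;  // {H, W}: layout-independent.
  assert(sizes.size() == 4);
  // NHWC {N, H, W, C} -> NCHW {N, C, H, W}.
  return {sizes[0], sizes[3], sizes[1], sizes[2]};
}

int main() {
  std::vector<int> nhwc = {128, 7, 7, 3};
  assert((PermuteInputSizes(nhwc) == std::vector<int>{128, 3, 7, 7}));
  std::vector<int> hw = {7, 7};
  assert(PermuteInputSizes(hw) == hw);  // Size 2: nothing to permute.
  return 0;
}
```

This is exactly why the guarded `if (vector_size != 2)` in the transposer hunk below omits the DataFormatVecPermute fanin update for the size-2 case.
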
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index ca28fafe80a..e9691a13b30 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -725,12 +725,42 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
   if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
     return Status::OK();
   }
+
+  const auto& fanin = node->GetRegularFanin(0);
+  auto* fanin_node = fanin.node_view();
+  const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
+  if (output_shape_attr == nullptr) {
+    VLOG(3) << "Cannot compute the shape of " << fanin_node->GetName()
+            << " because it is missing attribute " << kAttrOutputShape;
+    return Status::OK();
+  }
+  TensorShapeProto fanin_shape = output_shape_attr->list().shape(fanin.index());
+  if (fanin_shape.dim_size() != 1) {
+    VLOG(3) << fanin_node->GetName() << " is not a vector.";
+    return Status::OK();
+  }
+  int vector_size = fanin_shape.dim(0).size();
+  if (vector_size == -1) {
+    VLOG(3) << "The number of elements in " << fanin_node->GetName()
+            << " is unknown.";
+    return Status::OK();
+  }
+  if (vector_size != 2 && vector_size != 4) {
+    return errors::InvalidArgument(
+        fanin_node->GetName(), " must be a vector of size 2 or 4, but found ",
+        vector_size);
+  }
+
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
-  TF_RETURN_IF_ERROR(
-      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
+  // Do not permute an input_sizes of size 2 because it represents HW,
+  // regardless of whether the data format is NCHW or NHWC.
+  if (vector_size != 2) {
+    TF_RETURN_IF_ERROR(
+        UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
+  }
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   return context->graph_view->GetMutationBuilder()->Apply();
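To summarize the decision ladder the new checks in TransposeNode implement, here is a minimal standalone sketch (not TensorFlow code; ClassifyInputSizes and Action are hypothetical names, and the optional rank stands in for the kAttrOutputShape lookup on the fanin node):

```cpp
// Minimal standalone sketch (not TensorFlow code) of the decision ladder the
// new checks in TransposeNode implement. ClassifyInputSizes and Action are
// hypothetical names; the optional rank stands in for the kAttrOutputShape
// lookup on the fanin node.
#include <cassert>
#include <optional>

enum class Action { kSkip, kError, kPermute, kNoPermute };

// Missing or partial shape information is skipped conservatively, an
// unexpected vector length is an error, a 4-vector gets layout-permuted,
// and a 2-vector is left untouched.
Action ClassifyInputSizes(std::optional<int> rank, int size) {
  if (!rank.has_value()) return Action::kSkip;  // No output_shape attribute.
  if (*rank != 1) return Action::kSkip;         // Fanin is not a vector.
  if (size == -1) return Action::kSkip;         // Element count unknown.
  if (size != 2 && size != 4) return Action::kError;
  return size == 4 ? Action::kPermute : Action::kNoPermute;
}

int main() {
  assert(ClassifyInputSizes(std::nullopt, 4) == Action::kSkip);
  assert(ClassifyInputSizes(2, 4) == Action::kSkip);   // Rank 2: not a vector.
  assert(ClassifyInputSizes(1, -1) == Action::kSkip);  // Unknown size.
  assert(ClassifyInputSizes(1, 3) == Action::kError);
  assert(ClassifyInputSizes(1, 4) == Action::kPermute);
  assert(ClassifyInputSizes(1, 2) == Action::kNoPermute);
  return 0;
}
```

Note the asymmetry in the diff itself: unknown or missing shape information makes the transposer bail out with Status::OK() (layout optimization is best-effort), while a concretely known but invalid input_sizes length is surfaced as an InvalidArgument error.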