[Grappler] Do not permute input_sizes of size 2.

PiperOrigin-RevId: 307496027
Change-Id: I4b3b5a726d334d224eaed65f99b7b3afc385be49
Author: Jingyue Wu, 2020-04-20 16:09:06 -07:00 (committed by TensorFlower Gardener)
parent 7c5934dfd9
commit 819330a213
2 changed files with 66 additions and 7 deletions
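Background for the fix: during an NHWC-to-NCHW rewrite, the layout optimizer normally wraps Conv2DBackpropInput's input_sizes fanin in a DataFormatVecPermute node, which reorders a layout-dependent shape vector from the source format to the destination format. A 2-element input_sizes, however, holds only the spatial dimensions {H, W}, which read the same under either layout, so permuting it would corrupt the shape. A minimal standalone sketch (PermuteShapeVector is a hypothetical stand-in for the op's semantics, not TensorFlow code) illustrates the distinction:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Reorder a layout-dependent shape vector from src_format to dst_format by
// matching format characters (e.g. 'N', 'H', 'W', 'C').
std::vector<int> PermuteShapeVector(const std::vector<int>& input,
                                    const std::string& src_format,
                                    const std::string& dst_format) {
  assert(input.size() == src_format.size());
  assert(src_format.size() == dst_format.size());
  std::vector<int> output(input.size());
  for (std::size_t d = 0; d < dst_format.size(); ++d) {
    // Each output slot takes the value of the axis with the same format
    // character in the source layout.
    output[d] = input[src_format.find(dst_format[d])];
  }
  return output;
}

int main() {
  // A 4-element input_sizes is layout dependent: {N, H, W, C} under NHWC
  // becomes {N, C, H, W} under NCHW, so it must be permuted.
  assert((PermuteShapeVector({128, 7, 7, 3}, "NHWC", "NCHW") ==
          std::vector<int>{128, 3, 7, 7}));
  // A 2-element input_sizes is just {H, W}. It reads the same under either
  // layout, so no permutation should be inserted for it.
  return 0;
}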

tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc

@@ -77,17 +77,21 @@ Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
 Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                  int filter_size, const string& padding,
-                                 bool dilated) {
+                                 bool dilated, const int input_sizes_length) {
   int batch_size = 128;
   int input_height = input_size;
   int input_width = input_size;
   int input_depth = 3;
   int filter_count = 2;
   int stride = 1;
-  TensorShape input_sizes_shape({4});
+  TensorShape input_sizes_shape({input_sizes_length});
   Tensor input_data(DT_INT32, input_sizes_shape);
-  test::FillValues<int>(&input_data,
-                        {batch_size, input_height, input_width, input_depth});
+  if (input_sizes_length == 4) {
+    test::FillValues<int>(&input_data,
+                          {batch_size, input_height, input_width, input_depth});
+  } else {
+    test::FillValues<int>(&input_data, {input_height, input_width});
+  }
   Output input_sizes =
       ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
@@ -353,7 +357,8 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false);
+  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                        /*input_sizes_length=*/4);
   Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
   GrapplerItem item;
   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
@@ -380,6 +385,30 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
   VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
 }
 
+TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInput2DInputSizes) {
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
+#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+  Scope s = Scope::NewRootScope();
+  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                        /*input_sizes_length=*/2);
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+  GenericLayoutOptimizer optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  Status status;
+  utils::GraphView graph_view(&output, &status);
+  TF_ASSERT_OK(status);
+  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
+  ASSERT_NE(conv2d_backprop_node, nullptr);
+  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
+  VerifyRegularFaninMatch(conv2d_backprop_node, 0, "InputSizesIdentity", 0);
+}
+
 TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
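The only difference between the two tests above is the input_sizes_length argument: with 4 the fanin is expected to be permuted, with 2 it must stay wired to "InputSizesIdentity". The helper's new branch can be read as this minimal standalone sketch (BuildInputSizes is a hypothetical stand-in for the TensorShape/Tensor plumbing, not code from the patch):

#include <cstdio>
#include <vector>

// Build the input_sizes contents for a given input_sizes_length; only the
// 4-element case carries the batch and depth dimensions.
std::vector<int> BuildInputSizes(int input_sizes_length, int batch_size,
                                 int input_height, int input_width,
                                 int input_depth) {
  if (input_sizes_length == 4) {
    return {batch_size, input_height, input_width, input_depth};  // NHWC
  }
  return {input_height, input_width};  // spatial dimensions only: {H, W}
}

int main() {
  // Constants match SimpleConv2DBackpropInput: batch 128, 7x7 input, depth 3.
  for (int v : BuildInputSizes(4, 128, 7, 7, 3)) std::printf("%d ", v);
  std::printf("\n");  // 128 7 7 3
  for (int v : BuildInputSizes(2, 128, 7, 7, 3)) std::printf("%d ", v);
  std::printf("\n");  // 7 7
  return 0;
}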

tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc

@@ -725,12 +725,42 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
   if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
     return Status::OK();
   }
+
+  const auto& fanin = node->GetRegularFanin(0);
+  auto* fanin_node = fanin.node_view();
+  const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
+  if (output_shape_attr == nullptr) {
+    VLOG(3) << "Cannot compute the shape of " << fanin_node->GetName()
+            << " because it is missing attribute " << kAttrOutputShape;
+    return Status::OK();
+  }
+  TensorShapeProto fanin_shape = output_shape_attr->list().shape(fanin.index());
+  if (fanin_shape.dim_size() != 1) {
+    VLOG(3) << fanin_node->GetName() << " is not a vector.";
+    return Status::OK();
+  }
+  int vector_size = fanin_shape.dim(0).size();
+  if (vector_size == -1) {
+    VLOG(3) << "The number of elements in " << fanin_node->GetName()
+            << " is unknown.";
+    return Status::OK();
+  }
+  if (vector_size != 2 && vector_size != 4) {
+    return errors::InvalidArgument(
+        fanin_node->GetName(), " must be a vector of size 2 or 4, but found ",
+        vector_size);
+  }
+
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
-  TF_RETURN_IF_ERROR(
-      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
+  // Do not permute an input_sizes of size 2, because it represents {H, W}
+  // regardless of whether the data format is NCHW or NHWC.
+  if (vector_size != 2) {
+    TF_RETURN_IF_ERROR(
+        UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
+  }
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   return context->graph_view->GetMutationBuilder()->Apply();
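Condensed to its decision logic, the guard added above can be read as the following sketch (Action and DecideInputSizesAction are hypothetical names for illustration, not part of the transposer):

#include <cassert>

enum class Action { kError, kSkipPermute, kInsertPermute };

Action DecideInputSizesAction(int vector_size) {
  // Mirrors the errors::InvalidArgument branch: only 2- and 4-element
  // input_sizes vectors are legal for Conv2DBackpropInput.
  if (vector_size != 2 && vector_size != 4) return Action::kError;
  // A 2-element input_sizes is {H, W} under every layout, so no
  // DataFormatVecPermute node is inserted on fanin 0.
  if (vector_size == 2) return Action::kSkipPermute;
  // A 4-element input_sizes is layout dependent and must be permuted.
  return Action::kInsertPermute;
}

int main() {
  assert(DecideInputSizesAction(2) == Action::kSkipPermute);
  assert(DecideInputSizesAction(4) == Action::kInsertPermute);
  assert(DecideInputSizesAction(3) == Action::kError);
  return 0;
}

Non-vector or statically unknown shapes never reach this decision in the real code: those cases return Status::OK() early and leave the node untouched, as the VLOG branches above show.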