[Grappler] Do not permute input_sizes of size 2.
PiperOrigin-RevId: 307496027
Change-Id: I4b3b5a726d334d224eaed65f99b7b3afc385be49
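For context: Conv2DBackpropInput takes an input_sizes vector of either 4 elements (a full {batch, height, width, channels} shape) or 2 elements (spatial {height, width} only), as the validation added in this change reflects. The layout optimizer's DataFormatVecPermute rewrite reorders a 4-element shape vector between NHWC and NCHW, but a 2-element vector has no batch or channel entries to move, so permuting it would corrupt the spatial sizes. A minimal standalone sketch of the permutation semantics (not the TensorFlow implementation; PermuteNHWCToNCHW is a hypothetical helper for illustration):

#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring what DataFormatVecPermute does to a
// 4-element shape vector: reorder NHWC -> NCHW.
std::vector<int64_t> PermuteNHWCToNCHW(const std::vector<int64_t>& v) {
  assert(v.size() == 4 && "only a full NHWC shape vector may be permuted");
  // Destination index i takes source index perm[i]: {N,C,H,W} <- {N,H,W,C}.
  constexpr std::array<int, 4> perm = {0, 3, 1, 2};
  return {v[perm[0]], v[perm[1]], v[perm[2]], v[perm[3]]};
}

int main() {
  // A 4-element input_sizes is layout-dependent and must be permuted.
  std::vector<int64_t> nhwc = {128, 7, 7, 3};  // {N, H, W, C}
  std::vector<int64_t> nchw = PermuteNHWCToNCHW(nhwc);
  assert((nchw == std::vector<int64_t>{128, 3, 7, 7}));

  // A 2-element input_sizes is just {H, W}; it reads the same under
  // NHWC and NCHW, so the optimizer must leave it untouched.
  std::vector<int64_t> hw = {7, 7};
  return 0;
}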
tensorflow/core/grappler/optimizers
@@ -77,17 +77,21 @@ Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
 Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                  int filter_size, const string& padding,
-                                 bool dilated) {
+                                 bool dilated, const int input_sizes_length) {
   int batch_size = 128;
   int input_height = input_size;
   int input_width = input_size;
   int input_depth = 3;
   int filter_count = 2;
   int stride = 1;
-  TensorShape input_sizes_shape({4});
+  TensorShape input_sizes_shape({input_sizes_length});
   Tensor input_data(DT_INT32, input_sizes_shape);
-  test::FillValues<int>(&input_data,
-                        {batch_size, input_height, input_width, input_depth});
+  if (input_sizes_length == 4) {
+    test::FillValues<int>(&input_data,
+                          {batch_size, input_height, input_width, input_depth});
+  } else {
+    test::FillValues<int>(&input_data, {input_height, input_width});
+  }
   Output input_sizes =
       ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
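With this change the test helper fills input_data with the full {batch, height, width, depth} shape when input_sizes_length is 4, and with only the spatial {height, width} pair otherwise, so callers can exercise both forms of input_sizes that Conv2DBackpropInput accepts; the updated tests below pick one variant each.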
@@ -353,7 +357,8 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false);
+  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                        /*input_sizes_length=*/4);
   Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
   GrapplerItem item;
   TF_ASSERT_OK(s.ToGraphDef(&item.graph));
@@ -380,6 +385,30 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
   VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
 }
 
+TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInput2DInputSizes) {
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
+#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+  Scope s = Scope::NewRootScope();
+  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                        /*input_sizes_length=*/2);
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+  GenericLayoutOptimizer optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  Status status;
+  utils::GraphView graph_view(&output, &status);
+  TF_ASSERT_OK(status);
+  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
+  ASSERT_NE(conv2d_backprop_node, nullptr);
+  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
+  VerifyRegularFaninMatch(conv2d_backprop_node, 0, "InputSizesIdentity", 0);
+}
+
 TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
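The new Conv2DBackpropInput2DInputSizes test mirrors the non-const test above but builds the graph with a 2-element input_sizes. Its final assertion checks that fanin 0 of Conv2DBackpropInput still points directly at InputSizesIdentity, i.e. that the optimizer did not splice a DataFormatVecPermute node in front of the layout-independent {height, width} vector.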
@@ -725,12 +725,42 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
   if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
     return Status::OK();
   }
+
+  const auto& fanin = node->GetRegularFanin(0);
+  auto* fanin_node = fanin.node_view();
+  const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
+  if (output_shape_attr == nullptr) {
+    VLOG(3) << "Cannot compute the shape of " << fanin_node->GetName()
+            << " because it is missing attribute " << kAttrOutputShape;
+    return Status::OK();
+  }
+  TensorShapeProto fanin_shape = output_shape_attr->list().shape(fanin.index());
+  if (fanin_shape.dim_size() != 1) {
+    VLOG(3) << fanin_node->GetName() << " is not a vector.";
+    return Status::OK();
+  }
+  int vector_size = fanin_shape.dim(0).size();
+  if (vector_size == -1) {
+    VLOG(3) << "The number of elements in " << fanin_node->GetName()
+            << " is unknown.";
+    return Status::OK();
+  }
+  if (vector_size != 2 && vector_size != 4) {
+    return errors::InvalidArgument(
+        fanin_node->GetName(), " must be a vector of size 2 or 4, but found ",
+        vector_size);
+  }
+
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
-  TF_RETURN_IF_ERROR(
-      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
+  // Do not permute input_sizes of size 2 because it represents HW regardless
+  // of whether the data format is NCHW or NHWC.
+  if (vector_size != 2) {
+    TF_RETURN_IF_ERROR(
+        UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
+  }
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   return context->graph_view->GetMutationBuilder()->Apply();
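Summarizing the transposer change: the added guard derives the fanin's shape from its output-shape attribute, bails out (without error) when the shape is unavailable, not rank 1, or of unknown length, rejects lengths other than 2 or 4, and then skips the DataFormatVecPermute insertion for the 2-element case. A minimal sketch of that decision flow, with an invented enum and function name purely for illustration (the real code operates on utils::MutableNodeView and TensorShapeProto):

#include <optional>

// Outcome of inspecting the input_sizes fanin, per the logic added above.
// (Names are illustrative, not from the TensorFlow source.)
enum class InputSizesAction {
  kSkipNode,         // shape unknown or not a vector: leave the node alone
  kPermute,          // 4-element NHWC shape: insert DataFormatVecPermute
  kDoNotPermute,     // 2-element {H, W}: layout-independent, no permute
  kInvalidArgument,  // any other length is an error
};

InputSizesAction ClassifyInputSizes(std::optional<int> rank,
                                    std::optional<int> vector_size) {
  // Missing or non-vector shape: silently skip transposing this node.
  if (!rank.has_value() || *rank != 1) return InputSizesAction::kSkipNode;
  // Unknown element count (-1 in TensorShapeProto): also skip.
  if (!vector_size.has_value() || *vector_size == -1)
    return InputSizesAction::kSkipNode;
  if (*vector_size == 4) return InputSizesAction::kPermute;
  if (*vector_size == 2) return InputSizesAction::kDoNotPermute;
  return InputSizesAction::kInvalidArgument;
}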