DataFormatVecPermute accepts a vector of size 2.

This partially rolls back cl/307496027.

The code before cl/307496027 assumed the actual length of input_sizes was
always 4 and always permuted the vector. This was unsafe, because the length
of input_sizes can also be 2. cl/307496027 made the code safe, but at the
cost of LayoutOptimizer missing some optimizations, which apparently causes
higher memory usage.

This CL makes DataFormatVecPermute accept a vector of size 2 as well as a
vector of size 4. When the size is 2, the two elements are interpreted as
spatial dimensions. This way, LayoutOptimizer doesn't need to check the
static shape of input_sizes; it simply applies DataFormatVecPermute
regardless of the vector size.
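
For example (a sketch via the public raw-op endpoint, which is not part of
this CL's diff; the expected values mirror the tests added below):

    import tensorflow as tf

    # Size-4 input: the full NHWC shape vector [N, H, W, C] is permuted.
    tf.raw_ops.DataFormatVecPermute(
        x=[7, 4, 9, 3], src_format="NHWC", dst_format="NCHW")  # [7, 3, 4, 9]

    # Size-2 input: the two elements are treated as spatial dims [H, W].
    # NHWC and NCHW order H and W the same way, so nothing moves here...
    tf.raw_ops.DataFormatVecPermute(
        x=[4, 9], src_format="NHWC", dst_format="NCHW")  # [4, 9]

    # ...but a destination format that swaps H and W does permute them.
    tf.raw_ops.DataFormatVecPermute(
        x=[4, 9], src_format="NHWC", dst_format="WHCN")  # [9, 4]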

See b/156645925 for details.

PiperOrigin-RevId: 312571735
Change-Id: I257e2bef328882dbbcd0fe6bf07ef1f8989daf36
Authored by Jingyue Wu on 2020-05-20 15:54:53 -07:00; committed by TensorFlower Gardener
parent 786ee6565f
commit e622f15b21
7 changed files with 123 additions and 87 deletions


@@ -1297,8 +1297,8 @@ static LogicalResult Verify(DataFormatVecPermuteOp op) {
   if (rank == 1) {
     int64_t dim0 = input_ty.getDimSize(0);
-    if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
-      return op.emitOpError("requires 1D input of size 4");
+    if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
+      return op.emitOpError("requires 1D input of size 4 or size 2");
   }
 
   if (rank == 2) {


@@ -81,11 +81,21 @@ class XlaPermuteOpTest(xla_test.XLATestCase):
       x = np.array([7, 4, 9, 3], dtype=dtype)
       self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9])
 
+  def testNHWCToNCHW_Size2(self):
+    for dtype in {np.int32, np.int64}:
+      x = np.array([4, 9], dtype=dtype)
+      self._runPermuteAndCompare(x, "NHWC", "NCHW", [4, 9])
+
   def testNCHWToNHWC(self):
     for dtype in {np.int32, np.int64}:
       x = np.array([7, 4, 9, 3], dtype=dtype)
       self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4])
 
+  def testNCHWToNHWC_Size2(self):
+    for dtype in {np.int32, np.int64}:
+      x = np.array([9, 3], dtype=dtype)
+      self._runPermuteAndCompare(x, "NCHW", "NHWC", [9, 3])
+
   def testNHWCToHWNC(self):
     for dtype in {np.int32, np.int64}:
       x = np.array([7, 4, 9, 3], dtype=dtype)


@@ -106,8 +106,9 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
         errors::InvalidArgument(
             "Input must be a vector or matrix, but got shape ",
             input_tensor_shape.DebugString()));
+    const int dim0 = input_tensor_shape.dim_size(0);
     OP_REQUIRES(
-        ctx, input_tensor_shape.dim_size(0) == 4,
+        ctx, dim0 == 2 || dim0 == 4,
         errors::InvalidArgument(
             "First dimension of input must be of size 4, but got shape ",
             input_tensor_shape.DebugString()));
@@ -118,10 +119,25 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
               "Second dimension of 2D input must be of size 2, but got shape ",
               input_tensor_shape.DebugString()));
     }
 
-    int32 dst_indices[4];
-    for (int i = 0; i < 4; ++i) {
-      for (int j = 0; j < 4; ++j) {
-        if (src_format_[i] == dst_format_[j]) {
+    string src_format_str = src_format_;
+    string dst_format_str = dst_format_;
+    if (dim0 == 2) {
+      // If the input is a vector of size 2, treat the two elements as spatial
+      // dimensions.
+      auto keep_only_spatial_dimensions = [](string* format_str) -> void {
+        auto new_end = std::remove_if(
+            format_str->begin(), format_str->end(),
+            [](const char dim) { return dim != 'H' && dim != 'W'; });
+        format_str->erase(new_end, format_str->end());
+      };
+      keep_only_spatial_dimensions(&src_format_str);
+      keep_only_spatial_dimensions(&dst_format_str);
+    }
+    std::vector<int32> dst_indices(dim0);
+    for (int i = 0; i < dim0; ++i) {
+      for (int j = 0; j < dim0; ++j) {
+        if (src_format_str[i] == dst_format_str[j]) {
           dst_indices[j] = i;
           break;
         }

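The keep_only_spatial_dimensions trick above is the heart of the size-2
support: both format strings are first reduced to their spatial characters,
and the permutation indices are then computed over the reduced strings. A
minimal Python restatement of that reduction (hypothetical helper name, for
illustration only):

    def keep_only_spatial_dimensions(fmt):
      # Drop every dimension character except the spatial ones, 'H' and 'W',
      # mirroring the std::remove_if/erase idiom in the C++ kernel above.
      return "".join(d for d in fmt if d in "HW")

    keep_only_spatial_dimensions("NHWC")  # "HW"
    keep_only_spatial_dimensions("NCHW")  # "HW" -- same order, no-op permute
    keep_only_spatial_dimensions("WHCN")  # "WH" -- swaps a size-2 vector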

@@ -356,57 +356,35 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
-                                        /*input_sizes_length=*/4);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+  for (const int input_sizes_length : {2, 4}) {
+    Scope s = Scope::NewRootScope();
+    auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                          input_sizes_length);
+    Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+    GrapplerItem item;
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
 
-  GenericLayoutOptimizer optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+    GenericLayoutOptimizer optimizer;
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
 
-  Status status;
-  utils::GraphView graph_view(&output, &status);
-  TF_ASSERT_OK(status);
-  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
-  ASSERT_NE(conv2d_backprop_node, nullptr);
-  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
-  VerifyRegularFaninMatch(
-      conv2d_backprop_node, 0,
-      "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer",
-      0);
-  auto* input_sizes_node = graph_view.GetNode(
-      "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(input_sizes_node, nullptr);
-  EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
-  ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
-}
-
-TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInput2DInputSizes) {
-#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
-#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
-                                        /*input_sizes_length=*/2);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  GenericLayoutOptimizer optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
-
-  Status status;
-  utils::GraphView graph_view(&output, &status);
-  TF_ASSERT_OK(status);
-  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
-  ASSERT_NE(conv2d_backprop_node, nullptr);
-  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
-  VerifyRegularFaninMatch(conv2d_backprop_node, 0, "InputSizesIdentity", 0);
+    Status status;
+    utils::GraphView graph_view(&output, &status);
+    TF_ASSERT_OK(status);
+    auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
+    ASSERT_NE(conv2d_backprop_node, nullptr);
+    ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
+    VerifyRegularFaninMatch(
+        conv2d_backprop_node, 0,
+        "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer",
+        0);
+    auto* input_sizes_node = graph_view.GetNode(
+        "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(input_sizes_node, nullptr);
+    EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
+    ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
+  }
 }
 
 TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {


@@ -739,28 +739,13 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
     VLOG(3) << fanin_node->GetName() << " is not a vector.";
     return Status::OK();
   }
-  int vector_size = fanin_shape.dim(0).size();
-  if (vector_size == -1) {
-    VLOG(3) << "The number of elements in " << fanin_node->GetName()
-            << " is unknown.";
-    return Status::OK();
-  }
-  if (vector_size != 2 && vector_size != 4) {
-    return errors::InvalidArgument(
-        fanin_node->GetName(), " must be a vector of size 2 or 4, but found ",
-        vector_size);
-  }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
-  // Do not permute a input_sizes of size 2 because it represents HW regardless
-  // of whether NCHW or NHWC.
-  if (vector_size != 2) {
-    TF_RETURN_IF_ERROR(
-        UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
-  }
+  TF_RETURN_IF_ERROR(
+      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   return context->graph_view->GetMutationBuilder()->Apply();


@@ -90,16 +90,15 @@ class DataFormatVecPermuteOp : public OpKernel {
                     "input must be a vector or 2D tensor, but got shape ",
                     input.shape().DebugString()));
     if (input.dims() == 1) {
-      OP_REQUIRES(
-          context, input.NumElements() == 4,
-          errors::InvalidArgument("1D input must be of size 4, but got shape ",
-                                  input.shape().DebugString()));
+      OP_REQUIRES(context, input.NumElements() == 2 || input.NumElements() == 4,
+                  errors::InvalidArgument(
+                      "1D input must be of size 2 or 4, but got shape ",
+                      input.shape().DebugString()));
     } else if (input.dims() == 2) {
-      OP_REQUIRES(
-          context, input.dim_size(0) == 4,
-          errors::InvalidArgument(
-              "First dimension of 2D input must be of size 4, but got shape ",
-              input.shape().DebugString()));
+      OP_REQUIRES(context, input.dim_size(0) == 2 || input.dim_size(0) == 4,
+                  errors::InvalidArgument("First dimension of 2D input must be "
+                                          "of size 2 or 4, but got shape ",
+                                          input.shape().DebugString()));
       OP_REQUIRES(
           context, input.dim_size(1) == 2,
           errors::InvalidArgument(
@@ -112,7 +111,21 @@ class DataFormatVecPermuteOp : public OpKernel {
                    context->allocate_output(0, input.shape(), &output));
     // Support 1D and 2D cases.
     Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx;
-    ComputeDstIndex(input.dims(), &dst_idx);
+    string src_format_str = src_format_;
+    string dst_format_str = dst_format_;
+    if (input.dim_size(0) == 2) {
+      // If the input is a vector of size 2, treat the two elements as spatial
+      // dimensions.
+      auto keep_only_spatial_dimensions = [](string* format_str) -> void {
+        auto new_end = std::remove_if(
+            format_str->begin(), format_str->end(),
+            [](const char dim) { return dim != 'H' && dim != 'W'; });
+        format_str->erase(new_end, format_str->end());
+      };
+      keep_only_spatial_dimensions(&src_format_str);
+      keep_only_spatial_dimensions(&dst_format_str);
+    }
+    ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
 
     functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
                                                input.flat<T>(),
@@ -124,10 +137,12 @@ class DataFormatVecPermuteOp : public OpKernel {
   // Example: HWNC --> NHWC
   // 1D: dst = [1, 2, 0, 3],
   // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
-  void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
-    for (int i = 0; i < src_format_.size(); ++i) {
-      for (int j = 0; j < dst_format_.size(); ++j) {
-        if (dst_format_[j] != src_format_[i]) continue;
+  static void ComputeDstIndex(const string& src_format_str,
+                              const string& dst_format_str, int num_dim,
+                              Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
+    for (int i = 0; i < src_format_str.size(); ++i) {
+      for (int j = 0; j < dst_format_str.size(); ++j) {
+        if (dst_format_str[j] != src_format_str[i]) continue;
         // Found the dst index. Set output based on the number of dims.
         for (int k = 0; k < num_dim; ++k) {
           (*dst)[i * num_dim + k] = j * num_dim + k;

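To make the HWNC --> NHWC example in the comment concrete, here is the index
computation of ComputeDstIndex restated as a short Python sketch (hypothetical
function name; the real kernel writes into a fixed-size Eigen::DSizes buffer):

    def compute_dst_index(src_format, dst_format, num_dim):
      # For each source position i, find the destination position j that
      # holds the same dimension character, then fan out over num_dim
      # (1 for a vector input, 2 for an Nx2 matrix input).
      dst = [0] * (len(src_format) * num_dim)
      for i, dim in enumerate(src_format):
        j = dst_format.index(dim)
        for k in range(num_dim):
          dst[i * num_dim + k] = j * num_dim + k
      return dst

    compute_dst_index("HWNC", "NHWC", 1)  # [1, 2, 0, 3]
    compute_dst_index("HWNC", "NHWC", 2)  # [2, 3, 4, 5, 0, 1, 6, 7]
    compute_dst_index("HW", "WH", 1)      # the new size-2 case: [1, 0]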

@@ -1199,6 +1199,30 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, [7, 3, 4, 9])
 
+  def testNHWCToNCHW_Size2(self):
+    x_val = [4, 9]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x)
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [4, 9])
+
+  def testNHWCToWHCN(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 4, 3, 7])
+
+  def testNHWCToWHCN_Size2(self):
+    x_val = [4, 9]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 4])
+
   def testNCHWToNHWC(self):
     x_val = [7, 4, 9, 3]
     x = constant_op.constant(x_val)
@@ -1207,6 +1231,14 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, [7, 9, 3, 4])
 
+  def testNCHWToNHWC_Size2(self):
+    x_val = [9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x)
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 3])
+
   def testNHWCToHWNC(self):
     x_val = [7, 4, 9, 3]
     x = constant_op.constant(x_val)