DataFormatVecPermute accepts a vector of size 2.
This partially rolls back cl/307496027. The code before cl/307496027 assumed that the actual length of input_sizes was always 4 and always permuted the vector. That was unsafe, because the length of input_sizes can also be 2. cl/307496027 made the code safe, but as a result LayoutOptimizer misses some optimizations, which apparently causes higher memory usage.

This CL makes DataFormatVecPermute accept a vector of size 2 as well as a vector of size 4. When the size is 2, the two elements are interpreted as spatial dimensions. This way LayoutOptimizer does not need to check the static shape of input_sizes; instead, it applies DataFormatVecPermute regardless of the vector size. See b/156645925 for details.

PiperOrigin-RevId: 312571735
Change-Id: I257e2bef328882dbbcd0fe6bf07ef1f8989daf36
parent 786ee6565f
commit e622f15b21
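For intuition, here is the intended eager-level behavior as a sketch against the public raw op; the expected values are taken from the tests added below, not presented as output of this exact snippet:

import tensorflow as tf

# Size-4 input: a full NHWC shape vector [N, H, W, C] is permuted to NCHW.
full = tf.constant([7, 4, 9, 3])
print(tf.raw_ops.DataFormatVecPermute(x=full, src_format="NHWC", dst_format="NCHW"))
# -> [7, 3, 4, 9]

# Size-2 input: with this change, the two elements are treated as the spatial
# dimensions [H, W]; NHWC -> NCHW leaves their relative order unchanged.
spatial = tf.constant([4, 9])
print(tf.raw_ops.DataFormatVecPermute(x=spatial, src_format="NHWC", dst_format="NCHW"))
# -> [4, 9]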
@@ -1297,8 +1297,8 @@ static LogicalResult Verify(DataFormatVecPermuteOp op) {
   if (rank == 1) {
     int64_t dim0 = input_ty.getDimSize(0);
-    if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
-      return op.emitOpError("requires 1D input of size 4");
+    if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
+      return op.emitOpError("requires 1D input of size 4 or size 2");
   }
 
   if (rank == 2) {
@@ -81,11 +81,21 @@ class XlaPermuteOpTest(xla_test.XLATestCase):
       x = np.array([7, 4, 9, 3], dtype=dtype)
       self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9])
 
+  def testNHWCToNCHW_Size2(self):
+    for dtype in {np.int32, np.int64}:
+      x = np.array([4, 9], dtype=dtype)
+      self._runPermuteAndCompare(x, "NHWC", "NCHW", [4, 9])
+
   def testNCHWToNHWC(self):
     for dtype in {np.int32, np.int64}:
       x = np.array([7, 4, 9, 3], dtype=dtype)
       self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4])
 
+  def testNCHWToNHWC_Size2(self):
+    for dtype in {np.int32, np.int64}:
+      x = np.array([9, 3], dtype=dtype)
+      self._runPermuteAndCompare(x, "NCHW", "NHWC", [9, 3])
+
   def testNHWCToHWNC(self):
     for dtype in {np.int32, np.int64}:
       x = np.array([7, 4, 9, 3], dtype=dtype)
@@ -106,8 +106,9 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
                 errors::InvalidArgument(
                     "Input must be a vector or matrix, but got shape ",
                     input_tensor_shape.DebugString()));
+    const int dim0 = input_tensor_shape.dim_size(0);
     OP_REQUIRES(
-        ctx, input_tensor_shape.dim_size(0) == 4,
+        ctx, dim0 == 2 || dim0 == 4,
         errors::InvalidArgument(
             "First dimension of input must be of size 4, but got shape ",
             input_tensor_shape.DebugString()));
@@ -118,10 +119,25 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
              "Second dimension of 2D input must be of size 2, but got shape ",
              input_tensor_shape.DebugString()));
     }
-    int32 dst_indices[4];
-    for (int i = 0; i < 4; ++i) {
-      for (int j = 0; j < 4; ++j) {
-        if (src_format_[i] == dst_format_[j]) {
+
+    string src_format_str = src_format_;
+    string dst_format_str = dst_format_;
+    if (dim0 == 2) {
+      // If the input is a vector of size 2, treat the two elements as spatial
+      // dimensions.
+      auto keep_only_spatial_dimensions = [](string* format_str) -> void {
+        auto new_end = std::remove_if(
+            format_str->begin(), format_str->end(),
+            [](const char dim) { return dim != 'H' && dim != 'W'; });
+        format_str->erase(new_end, format_str->end());
+      };
+      keep_only_spatial_dimensions(&src_format_str);
+      keep_only_spatial_dimensions(&dst_format_str);
+    }
+    std::vector<int32> dst_indices(dim0);
+    for (int i = 0; i < dim0; ++i) {
+      for (int j = 0; j < dim0; ++j) {
+        if (src_format_str[i] == dst_format_str[j]) {
           dst_indices[j] = i;
           break;
         }
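The format-string reduction above is what lets the size-2 case reuse the existing permutation loop unchanged. A rough Python equivalent of that reduction (hypothetical helper name, for illustration only):

def keep_only_spatial_dimensions(format_str):
    # Mirrors the C++ lambda: keep only the 'H' and 'W' characters.
    return "".join(d for d in format_str if d in "HW")

# Both NHWC and NCHW reduce to "HW", so a size-2 NHWC -> NCHW permute is the
# identity, matching the Size2 tests above.
assert keep_only_spatial_dimensions("NHWC") == "HW"
assert keep_only_spatial_dimensions("NCHW") == "HW"

Note that 'H' precedes 'W' in both NHWC and NCHW, which is why the reduced permutation for that pair is the identity; a format like WHCN reduces to "WH" and does swap the two elements, as the WHCN tests below show.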
@@ -356,57 +356,35 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
-                                        /*input_sizes_length=*/4);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+  for (const int input_sizes_length : {2, 4}) {
+    Scope s = Scope::NewRootScope();
+    auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                          input_sizes_length);
+    Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+    GrapplerItem item;
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
 
-  GenericLayoutOptimizer optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+    GenericLayoutOptimizer optimizer;
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
 
-  Status status;
-  utils::GraphView graph_view(&output, &status);
-  TF_ASSERT_OK(status);
-  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
-  ASSERT_NE(conv2d_backprop_node, nullptr);
-  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
-  VerifyRegularFaninMatch(
-      conv2d_backprop_node, 0,
-      "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer",
-      0);
-  auto* input_sizes_node = graph_view.GetNode(
-      "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(input_sizes_node, nullptr);
-  EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
-  ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
-}
-
-TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInput2DInputSizes) {
-#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
-#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
-                                        /*input_sizes_length=*/2);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  GenericLayoutOptimizer optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
-
-  Status status;
-  utils::GraphView graph_view(&output, &status);
-  TF_ASSERT_OK(status);
-  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
-  ASSERT_NE(conv2d_backprop_node, nullptr);
-  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
-  VerifyRegularFaninMatch(conv2d_backprop_node, 0, "InputSizesIdentity", 0);
+    Status status;
+    utils::GraphView graph_view(&output, &status);
+    TF_ASSERT_OK(status);
+    auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
+    ASSERT_NE(conv2d_backprop_node, nullptr);
+    ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
+    VerifyRegularFaninMatch(
+        conv2d_backprop_node, 0,
+        "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer",
+        0);
+    auto* input_sizes_node = graph_view.GetNode(
+        "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(input_sizes_node, nullptr);
+    EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
+    ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
+  }
 }
 
 TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
@@ -739,28 +739,13 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
     VLOG(3) << fanin_node->GetName() << " is not a vector.";
     return Status::OK();
   }
-  int vector_size = fanin_shape.dim(0).size();
-  if (vector_size == -1) {
-    VLOG(3) << "The number of elements in " << fanin_node->GetName()
-            << " is unknown.";
-    return Status::OK();
-  }
-  if (vector_size != 2 && vector_size != 4) {
-    return errors::InvalidArgument(
-        fanin_node->GetName(), " must be a vector of size 2 or 4, but found ",
-        vector_size);
-  }
 
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
-  // Do not permute a input_sizes of size 2 because it represents HW regardless
-  // of whether NCHW or NHWC.
-  if (vector_size != 2) {
-    TF_RETURN_IF_ERROR(
-        UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
-  }
+  TF_RETURN_IF_ERROR(
+      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   return context->graph_view->GetMutationBuilder()->Apply();
@@ -90,16 +90,15 @@ class DataFormatVecPermuteOp : public OpKernel {
                 "input must be a vector or 2D tensor, but got shape ",
                 input.shape().DebugString()));
     if (input.dims() == 1) {
-      OP_REQUIRES(
-          context, input.NumElements() == 4,
-          errors::InvalidArgument("1D input must be of size 4, but got shape ",
-                                  input.shape().DebugString()));
+      OP_REQUIRES(context, input.NumElements() == 2 || input.NumElements() == 4,
+                  errors::InvalidArgument(
+                      "1D input must be of size 2 or 4, but got shape ",
+                      input.shape().DebugString()));
     } else if (input.dims() == 2) {
-      OP_REQUIRES(
-          context, input.dim_size(0) == 4,
-          errors::InvalidArgument(
-              "First dimension of 2D input must be of size 4, but got shape ",
-              input.shape().DebugString()));
+      OP_REQUIRES(context, input.dim_size(0) == 2 || input.dim_size(0) == 4,
+                  errors::InvalidArgument("First dimension of 2D input must be "
+                                          "of size 2 or 4, but got shape ",
+                                          input.shape().DebugString()));
       OP_REQUIRES(
           context, input.dim_size(1) == 2,
           errors::InvalidArgument(
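With the relaxed checks above, a 1D input of any length other than 2 or 4 should still be rejected. A hedged sketch of what that looks like from Python (assuming eager execution; the message text follows the new error string):

import tensorflow as tf

try:
    tf.raw_ops.DataFormatVecPermute(
        x=tf.constant([1, 2, 3]), src_format="NHWC", dst_format="NCHW")
except tf.errors.InvalidArgumentError as e:
    print(e.message)  # "1D input must be of size 2 or 4, but got shape [3]"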
@@ -112,7 +111,21 @@ class DataFormatVecPermuteOp : public OpKernel {
                    context->allocate_output(0, input.shape(), &output));
     // Support 1D and 2D cases.
     Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx;
-    ComputeDstIndex(input.dims(), &dst_idx);
+    string src_format_str = src_format_;
+    string dst_format_str = dst_format_;
+    if (input.dim_size(0) == 2) {
+      // If the input is a vector of size 2, treat the two elements as spatial
+      // dimensions.
+      auto keep_only_spatial_dimensions = [](string* format_str) -> void {
+        auto new_end = std::remove_if(
+            format_str->begin(), format_str->end(),
+            [](const char dim) { return dim != 'H' && dim != 'W'; });
+        format_str->erase(new_end, format_str->end());
+      };
+      keep_only_spatial_dimensions(&src_format_str);
+      keep_only_spatial_dimensions(&dst_format_str);
+    }
+    ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
 
     functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
                                                input.flat<T>(),
@@ -124,10 +137,12 @@ class DataFormatVecPermuteOp : public OpKernel {
   // Example: HWNC --> NHWC
   // 1D: dst = [1, 2, 0, 3],
   // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
-  void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
-    for (int i = 0; i < src_format_.size(); ++i) {
-      for (int j = 0; j < dst_format_.size(); ++j) {
-        if (dst_format_[j] != src_format_[i]) continue;
+  static void ComputeDstIndex(const string& src_format_str,
+                              const string& dst_format_str, int num_dim,
+                              Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
+    for (int i = 0; i < src_format_str.size(); ++i) {
+      for (int j = 0; j < dst_format_str.size(); ++j) {
+        if (dst_format_str[j] != src_format_str[i]) continue;
         // Found the dst index. Set output based on the number of dims.
         for (int k = 0; k < num_dim; ++k) {
           (*dst)[i * num_dim + k] = j * num_dim + k;
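A plain-Python sketch of this index computation, reproducing the HWNC --> NHWC example in the comment above (illustrative only; the kernel writes into a fixed-size Eigen array instead of a list):

def compute_dst_index(src_format, dst_format, num_dim):
    dst = [0] * (len(src_format) * num_dim)
    for i, s in enumerate(src_format):
        j = dst_format.index(s)  # where dimension s lands in the destination
        for k in range(num_dim):
            dst[i * num_dim + k] = j * num_dim + k
    return dst

assert compute_dst_index("HWNC", "NHWC", 1) == [1, 2, 0, 3]
assert compute_dst_index("HWNC", "NHWC", 2) == [2, 3, 4, 5, 0, 1, 6, 7]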
@@ -1199,6 +1199,30 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, [7, 3, 4, 9])
 
+  def testNHWCToNCHW_Size2(self):
+    x_val = [4, 9]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x)
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [4, 9])
+
+  def testNHWCToWHCN(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 4, 3, 7])
+
+  def testNHWCToWHCN_Size2(self):
+    x_val = [4, 9]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 4])
+
   def testNCHWToNHWC(self):
     x_val = [7, 4, 9, 3]
     x = constant_op.constant(x_val)
@@ -1207,6 +1231,14 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, [7, 9, 3, 4])
 
+  def testNCHWToNHWC_Size2(self):
+    x_val = [9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x)
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 3])
+
   def testNHWCToHWNC(self):
     x_val = [7, 4, 9, 3]
     x = constant_op.constant(x_val)