DataFormatVecPermute accepts a vector of size 2.
This partially rolls back cl/307496027. The code before cl/307496027 assumed the actual length of input_sizes was always 4 and always permuted the vector. That was unsafe, because the length of input_sizes can also be 2. cl/307496027 made the code safe, but as a result LayoutOptimizer missed some optimizations, which apparently caused more memory usage.

This CL makes DataFormatVecPermute accept a vector of size 2 as well as a vector of size 4. When the size is 2, the two elements are interpreted as spatial dimensions. This way LayoutOptimizer does not need to check the static shape of input_sizes; instead, it applies DataFormatVecPermute regardless of the vector size. See b/156645925 for details.

PiperOrigin-RevId: 312571735
Change-Id: I257e2bef328882dbbcd0fe6bf07ef1f8989daf36
parent 786ee6565f
commit e622f15b21
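A usage-level sketch of the new behavior (illustrative, not part of this change; it assumes a TensorFlow build that includes this commit):

```python
import tensorflow as tf

# Size-4 input: the full layout permutation applies, [N, H, W, C] -> [N, C, H, W].
print(tf.raw_ops.DataFormatVecPermute(
    x=tf.constant([7, 4, 9, 3]), src_format="NHWC", dst_format="NCHW"))
# => [7, 3, 4, 9]

# Size-2 input: the two elements are treated as the spatial dimensions [H, W].
# "NHWC" and "NCHW" both reduce to "HW", so the result is unchanged; this is
# what lets LayoutOptimizer insert the op without checking the static shape.
print(tf.raw_ops.DataFormatVecPermute(
    x=tf.constant([4, 9]), src_format="NHWC", dst_format="NCHW"))
# => [4, 9]
```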
```diff
@@ -1297,8 +1297,8 @@ static LogicalResult Verify(DataFormatVecPermuteOp op) {
   if (rank == 1) {
     int64_t dim0 = input_ty.getDimSize(0);
-    if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
-      return op.emitOpError("requires 1D input of size 4");
+    if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
+      return op.emitOpError("requires 1D input of size 4 or size 2");
   }
 
   if (rank == 2) {
```
```diff
@@ -81,11 +81,21 @@ class XlaPermuteOpTest(xla_test.XLATestCase):
       x = np.array([7, 4, 9, 3], dtype=dtype)
       self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9])
 
+  def testNHWCToNCHW_Size2(self):
+    for dtype in {np.int32, np.int64}:
+      x = np.array([4, 9], dtype=dtype)
+      self._runPermuteAndCompare(x, "NHWC", "NCHW", [4, 9])
+
   def testNCHWToNHWC(self):
     for dtype in {np.int32, np.int64}:
       x = np.array([7, 4, 9, 3], dtype=dtype)
       self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4])
 
+  def testNCHWToNHWC_Size2(self):
+    for dtype in {np.int32, np.int64}:
+      x = np.array([9, 3], dtype=dtype)
+      self._runPermuteAndCompare(x, "NCHW", "NHWC", [9, 3])
+
   def testNHWCToHWNC(self):
     for dtype in {np.int32, np.int64}:
       x = np.array([7, 4, 9, 3], dtype=dtype)
```
```diff
@@ -106,8 +106,9 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
                 errors::InvalidArgument(
                     "Input must be a vector or matrix, but got shape ",
                     input_tensor_shape.DebugString()));
+    const int dim0 = input_tensor_shape.dim_size(0);
     OP_REQUIRES(
-        ctx, input_tensor_shape.dim_size(0) == 4,
+        ctx, dim0 == 2 || dim0 == 4,
         errors::InvalidArgument(
             "First dimension of input must be of size 4, but got shape ",
             input_tensor_shape.DebugString()));
```
```diff
@@ -118,10 +119,25 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
             "Second dimension of 2D input must be of size 2, but got shape ",
             input_tensor_shape.DebugString()));
     }
-    int32 dst_indices[4];
-    for (int i = 0; i < 4; ++i) {
-      for (int j = 0; j < 4; ++j) {
-        if (src_format_[i] == dst_format_[j]) {
+    string src_format_str = src_format_;
+    string dst_format_str = dst_format_;
+    if (dim0 == 2) {
+      // If the input is a vector of size 2, treat the two elements as spatial
+      // dimensions.
+      auto keep_only_spatial_dimensions = [](string* format_str) -> void {
+        auto new_end = std::remove_if(
+            format_str->begin(), format_str->end(),
+            [](const char dim) { return dim != 'H' && dim != 'W'; });
+        format_str->erase(new_end, format_str->end());
+      };
+      keep_only_spatial_dimensions(&src_format_str);
+      keep_only_spatial_dimensions(&dst_format_str);
+    }
+    std::vector<int32> dst_indices(dim0);
+    for (int i = 0; i < dim0; ++i) {
+      for (int j = 0; j < dim0; ++j) {
+        if (src_format_str[i] == dst_format_str[j]) {
           dst_indices[j] = i;
           break;
         }
```
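For reference, a Python sketch of the index computation introduced above (illustrative only; the helper names mirror the C++, but the code is not from this change):

```python
def keep_only_spatial_dimensions(format_str):
  # Mirrors the C++ lambda above: keep only the spatial dimensions 'H' and 'W'.
  return "".join(dim for dim in format_str if dim in ("H", "W"))

def dst_indices(src_format, dst_format, dim0):
  if dim0 == 2:
    src_format = keep_only_spatial_dimensions(src_format)
    dst_format = keep_only_spatial_dimensions(dst_format)
  # dst_indices[j] = i whenever src_format[i] == dst_format[j], i.e. the
  # output gathers output[j] = input[dst_indices[j]].
  return [src_format.index(dim) for dim in dst_format]

print(dst_indices("NHWC", "NCHW", 4))  # [0, 3, 1, 2]: [7, 4, 9, 3] -> [7, 3, 4, 9]
print(dst_indices("NHWC", "NCHW", 2))  # [0, 1]: the identity, so [4, 9] stays [4, 9]
print(dst_indices("NHWC", "WHCN", 2))  # [1, 0]: [4, 9] -> [9, 4]
```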
```diff
@@ -356,57 +356,35 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
-                                        /*input_sizes_length=*/4);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  GenericLayoutOptimizer optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
-
-  Status status;
-  utils::GraphView graph_view(&output, &status);
-  TF_ASSERT_OK(status);
-  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
-  ASSERT_NE(conv2d_backprop_node, nullptr);
-  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
-  VerifyRegularFaninMatch(
-      conv2d_backprop_node, 0,
-      "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer",
-      0);
-  auto* input_sizes_node = graph_view.GetNode(
-      "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(input_sizes_node, nullptr);
-  EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
-  ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
-}
-
-TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInput2DInputSizes) {
-#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
-#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
-                                        /*input_sizes_length=*/2);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  GenericLayoutOptimizer optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
-
-  Status status;
-  utils::GraphView graph_view(&output, &status);
-  TF_ASSERT_OK(status);
-  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
-  ASSERT_NE(conv2d_backprop_node, nullptr);
-  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
-  VerifyRegularFaninMatch(conv2d_backprop_node, 0, "InputSizesIdentity", 0);
+  for (const int input_sizes_length : {2, 4}) {
+    Scope s = Scope::NewRootScope();
+    auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                          input_sizes_length);
+    Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+    GrapplerItem item;
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    GenericLayoutOptimizer optimizer;
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+    Status status;
+    utils::GraphView graph_view(&output, &status);
+    TF_ASSERT_OK(status);
+    auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
+    ASSERT_NE(conv2d_backprop_node, nullptr);
+    ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
+    VerifyRegularFaninMatch(
+        conv2d_backprop_node, 0,
+        "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer",
+        0);
+    auto* input_sizes_node = graph_view.GetNode(
+        "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(input_sizes_node, nullptr);
+    EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
+    ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
+  }
 }
 
 TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
```
```diff
@@ -739,28 +739,13 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
     VLOG(3) << fanin_node->GetName() << " is not a vector.";
     return Status::OK();
   }
-  int vector_size = fanin_shape.dim(0).size();
-  if (vector_size == -1) {
-    VLOG(3) << "The number of elements in " << fanin_node->GetName()
-            << " is unknown.";
-    return Status::OK();
-  }
-  if (vector_size != 2 && vector_size != 4) {
-    return errors::InvalidArgument(
-        fanin_node->GetName(), " must be a vector of size 2 or 4, but found ",
-        vector_size);
-  }
-
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
-  // Do not permute a input_sizes of size 2 because it represents HW regardless
-  // of whether NCHW or NHWC.
-  if (vector_size != 2) {
-    TF_RETURN_IF_ERROR(
-        UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
-  }
+  TF_RETURN_IF_ERROR(
+      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   return context->graph_view->GetMutationBuilder()->Apply();
```
```diff
@@ -90,16 +90,15 @@ class DataFormatVecPermuteOp : public OpKernel {
                     "input must be a vector or 2D tensor, but got shape ",
                     input.shape().DebugString()));
     if (input.dims() == 1) {
-      OP_REQUIRES(
-          context, input.NumElements() == 4,
-          errors::InvalidArgument("1D input must be of size 4, but got shape ",
-                                  input.shape().DebugString()));
+      OP_REQUIRES(context, input.NumElements() == 2 || input.NumElements() == 4,
+                  errors::InvalidArgument(
+                      "1D input must be of size 2 or 4, but got shape ",
+                      input.shape().DebugString()));
     } else if (input.dims() == 2) {
-      OP_REQUIRES(
-          context, input.dim_size(0) == 4,
-          errors::InvalidArgument(
-              "First dimension of 2D input must be of size 4, but got shape ",
-              input.shape().DebugString()));
+      OP_REQUIRES(context, input.dim_size(0) == 2 || input.dim_size(0) == 4,
+                  errors::InvalidArgument("First dimension of 2D input must be "
+                                          "of size 2 or 4, but got shape ",
+                                          input.shape().DebugString()));
       OP_REQUIRES(
           context, input.dim_size(1) == 2,
           errors::InvalidArgument(
```
```diff
@@ -112,7 +111,21 @@ class DataFormatVecPermuteOp : public OpKernel {
                    context->allocate_output(0, input.shape(), &output));
     // Support 1D and 2D cases.
     Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx;
-    ComputeDstIndex(input.dims(), &dst_idx);
+    string src_format_str = src_format_;
+    string dst_format_str = dst_format_;
+    if (input.dim_size(0) == 2) {
+      // If the input is a vector of size 2, treat the two elements as spatial
+      // dimensions.
+      auto keep_only_spatial_dimensions = [](string* format_str) -> void {
+        auto new_end = std::remove_if(
+            format_str->begin(), format_str->end(),
+            [](const char dim) { return dim != 'H' && dim != 'W'; });
+        format_str->erase(new_end, format_str->end());
+      };
+      keep_only_spatial_dimensions(&src_format_str);
+      keep_only_spatial_dimensions(&dst_format_str);
+    }
+    ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
 
     functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
                                                input.flat<T>(),
```
```diff
@@ -124,10 +137,12 @@ class DataFormatVecPermuteOp : public OpKernel {
   // Example: HWNC --> NHWC
   // 1D: dst = [1, 2, 0, 3],
   // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
-  void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
-    for (int i = 0; i < src_format_.size(); ++i) {
-      for (int j = 0; j < dst_format_.size(); ++j) {
-        if (dst_format_[j] != src_format_[i]) continue;
+  static void ComputeDstIndex(const string& src_format_str,
+                              const string& dst_format_str, int num_dim,
+                              Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
+    for (int i = 0; i < src_format_str.size(); ++i) {
+      for (int j = 0; j < dst_format_str.size(); ++j) {
+        if (dst_format_str[j] != src_format_str[i]) continue;
         // Found the dst index. Set output based on the number of dims.
         for (int k = 0; k < num_dim; ++k) {
           (*dst)[i * num_dim + k] = j * num_dim + k;
```
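The comment's examples can be checked with a short Python re-implementation of ComputeDstIndex (a sketch under the same indexing convention, not code from this change):

```python
def compute_dst_index(src_format, dst_format, num_dim):
  # (*dst)[i * num_dim + k] = j * num_dim + k, where j is the position of
  # src_format[i] within dst_format.
  dst = [0] * (len(src_format) * num_dim)
  for i, dim in enumerate(src_format):
    j = dst_format.index(dim)
    for k in range(num_dim):
      dst[i * num_dim + k] = j * num_dim + k
  return dst

print(compute_dst_index("HWNC", "NHWC", 1))  # [1, 2, 0, 3]
print(compute_dst_index("HWNC", "NHWC", 2))  # [2, 3, 4, 5, 0, 1, 6, 7]
print(compute_dst_index("HW", "WH", 1))      # [1, 0] -- the reduced size-2 case
```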
```diff
@@ -1199,6 +1199,30 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, [7, 3, 4, 9])
 
+  def testNHWCToNCHW_Size2(self):
+    x_val = [4, 9]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x)
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [4, 9])
+
+  def testNHWCToWHCN(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 4, 3, 7])
+
+  def testNHWCToWHCN_Size2(self):
+    x_val = [4, 9]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 4])
+
   def testNCHWToNHWC(self):
     x_val = [7, 4, 9, 3]
     x = constant_op.constant(x_val)
```
```diff
@@ -1207,6 +1231,14 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, [7, 9, 3, 4])
 
+  def testNCHWToNHWC_Size2(self):
+    x_val = [9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x)
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 3])
+
   def testNHWCToHWNC(self):
     x_val = [7, 4, 9, 3]
     x = constant_op.constant(x_val)
```