Fix transpose bug for large dimensions.

Add random tests of large shapes for better coverage.
Update transpose benchmark with cases that swap one small dimension with one large dimension.

PiperOrigin-RevId: 171302097
Yangzihao Wang 2017-10-06 09:35:06 -07:00 committed by TensorFlower Gardener
parent 3251bc0792
commit 2daa40f9d0
5 changed files with 393 additions and 17 deletions

View File

@@ -272,6 +272,88 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
}
}
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor
// when only one of the dimension sizes is smaller than 16,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
// small_dim = the smaller of the two swapped dimension sizes
// large_dim = the larger of the two swapped dimension sizes
// kTileLength = blockDim.x = the tile extent along large_dim (currently 64)
// tile_num_per_block = gridDim.x = the number of tiles along large_dim
//
// Each thread block operates on a single rectangular tile whose extent is
// kTileLength along large_dim and small_dim along the other swapped dimension.
// We set the thread block's X dimension to kTileLength, and its Y and Z
// dimensions to one.
template <typename T, int ShmemSize, bool SmallDim2>
__global__ void SwapDimension1And2InTensor3SmallDim(const T* input,
int batch_per_block,
Dimension<3> input_dims,
T* output) {
// TODO(yangzihao): Avoid shared memory bank conflicts.
__shared__ T shared_memory_tile[ShmemSize];
eigen_assert(blockDim.y == 1);
eigen_assert(blockDim.z == 1);
eigen_assert(gridDim.z == 1);
int block_offset = blockIdx.x * blockDim.x;
int x = threadIdx.x;
int tile_height = blockDim.x;
// Get tile height, width, and thread/block origin indices.
int small_dim = SmallDim2 ? input_dims[2] : input_dims[1];
int large_dim = SmallDim2 ? input_dims[1] : input_dims[2];
int global_offset = small_dim * large_dim * (blockIdx.y * batch_per_block) +
(SmallDim2 ? block_offset * small_dim : block_offset);
if (global_offset >= (input_dims[0] * input_dims[1] * input_dims[2])) return;
for (int batch = 0; batch < batch_per_block; ++batch) {
int block_origin_idx =
small_dim * large_dim * (blockIdx.y * batch_per_block + batch);
int thread_origin_idx =
block_origin_idx +
(SmallDim2 ? block_offset * small_dim : block_offset) + x;
if (block_offset + blockDim.x > large_dim) {
tile_height = large_dim - block_offset;
}
__syncthreads();
// Load a contiguous memory region into the shared memory tile.
if (x < tile_height) {
for (int y = 0; y < small_dim; y++) {
int shmem_index =
SmallDim2 ? (x + y * tile_height) : (x * small_dim + y);
shared_memory_tile[shmem_index] =
ldg(input + thread_origin_idx +
y * (SmallDim2 ? tile_height : large_dim));
}
}
__syncthreads();
// Get block origin index for output array.
int output_block_offset = block_origin_idx;
int output_block_idx = SmallDim2 ? block_offset : block_offset * small_dim;
int output_block_origin_idx = output_block_offset + output_block_idx;
// Store the transposed tile from shared memory back to device memory.
if (x < tile_height) {
for (int y = 0; y < small_dim; y++) {
int output_idx = output_block_origin_idx + x +
y * (SmallDim2 ? large_dim : tile_height);
int shmem_index =
SmallDim2 ? (x * small_dim + y) : (x + y * tile_height);
output[output_idx] = shared_memory_tile[shmem_index];
}
}
}
}
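For intuition, the index mapping implemented by SwapDimension1And2InTensor3SmallDim can be reproduced on the host. The NumPy sketch below mirrors the tile decomposition for the SmallDim2 == true case (the small dimension is dimension 2): each block transposes one kTileLength x small_dim rectangle, and the last tile along large_dim is clamped. It is only an illustration of the mapping under an assumed tile_length of 64, not the CUDA code path itself.

import numpy as np

def swap_dim1_and_dim2_small_dim2(x, tile_length=64):
  # x has shape [batch, large_dim, small_dim]; the result has shape
  # [batch, small_dim, large_dim], i.e. output[i][j][k] = input[i][k][j].
  batch, large_dim, small_dim = x.shape
  out = np.empty((batch, small_dim, large_dim), dtype=x.dtype)
  tile_num = (large_dim + tile_length - 1) // tile_length
  for b in range(batch):                     # batches (gridDim.y * batch_per_block)
    for t in range(tile_num):                # tiles along large_dim (blockIdx.x)
      lo = t * tile_length
      hi = min(lo + tile_length, large_dim)  # clamp the last, partial tile
      # One "thread block" transposes a (hi - lo) x small_dim rectangle.
      out[b, :, lo:hi] = x[b, lo:hi, :].T
  return out

x = np.arange(3 * 1000 * 5).reshape(3, 1000, 5)
assert np.array_equal(swap_dim1_and_dim2_small_dim2(x), np.transpose(x, (0, 2, 1)))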
// A CUDA custom kernel that converts input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
@@ -420,25 +502,62 @@ template <typename T>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
const Dimension<3>& input_dims, T* output) {
// If neither swapped dimension is trivial (both >= kMinDimensionToUseTiles),
// use the tiled kernel for the actual swapping.
// If exactly one dimension is trivial, use the SmallDim kernel for swapping.
// Otherwise, the naive swapping relying on the ldg cache is more efficient.
static const int kMinDimensionToUseTiles = 16;
bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
input_dims[2] >= kMinDimensionToUseTiles);
bool use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles &&
input_dims[2] < kMinDimensionToUseTiles)) ||
((input_dims[1] < kMinDimensionToUseTiles &&
input_dims[2] >= kMinDimensionToUseTiles));
static const int NumSubTiles = 8;
if (use_tiles) {
// We get best performance when TileSize is the number of threads in a warp
// (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
// threads.
static const int TileSize = 32;
static const int NumSubTiles = 8;
Dimension<3> input_dims_in_tiles = {
input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
(input_dims[2] + TileSize - 1) / TileSize,
};
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
input_dims_in_tiles[2];
// We get best performance when TileSize is the number of threads in a warp
// (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
// threads.
SwapDimension1And2InTensor3UsingTiles<T, TileSize, NumSubTiles><<<
total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>(
input, input_dims, output);
} else if (use_small_dim) {
// When only one of the dimensions is smaller than kMinDimensionToUseTiles,
// we use one block to process a rectangular region of size
// kTileLength * small_dim. We found that setting kTileLength to 64 achieves
// the best performance on a Titan X (Maxwell) GPU.
//                      large_dim
//            +----------------...--------+
//            |           |     |         |
// small_dim  |           | ... |         |
//            |           |     |         |
//            +----------------...--------+
//             \---------/       \-------/
//                  V                 V
//     kTileLength(tile_height)   tile_height
static const int kTileLength = 64;
static const int kGridDimY = 65535;
int large_dim = std::max(input_dims[2], input_dims[1]);
int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength;
int grid_dim_y = std::min(input_dims[0], kGridDimY);
int batch_per_block = (input_dims[0] + grid_dim_y - 1) / grid_dim_y;
if (input_dims[2] < input_dims[1]) {
SwapDimension1And2InTensor3SmallDim<
T, kTileLength * kMinDimensionToUseTiles, true>
<<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
d.stream()>>>(input, batch_per_block, input_dims, output);
} else {
SwapDimension1And2InTensor3SmallDim<
T, kTileLength * kMinDimensionToUseTiles, false>
<<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
d.stream()>>>(input, batch_per_block, input_dims, output);
}
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
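To see which path a given shape takes and what launch geometry the SmallDim branch produces, the arithmetic above can be redone in plain Python. The sketch uses [1000000, 31, 3], one of the shapes added to the test below; the 65535 clamp corresponds to the CUDA limit on gridDim.y, which the batch_per_block loop inside the kernel compensates for.

# Dispatch and launch-geometry arithmetic from RunSwapDimension1And2InTensor3,
# redone for input_dims = [1000000, 31, 3].
kMinDimensionToUseTiles = 16
kTileLength = 64
kGridDimY = 65535

input_dims = [1000000, 31, 3]
use_tiles = (input_dims[1] >= kMinDimensionToUseTiles and
             input_dims[2] >= kMinDimensionToUseTiles)               # False
use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles) !=
                 (input_dims[2] >= kMinDimensionToUseTiles))         # True

large_dim = max(input_dims[1], input_dims[2])                        # 31
tile_num_per_block = (large_dim + kTileLength - 1) // kTileLength    # 1     -> gridDim.x
grid_dim_y = min(input_dims[0], kGridDimY)                           # 65535 -> gridDim.y
batch_per_block = (input_dims[0] + grid_dim_y - 1) // grid_dim_y     # 16 batches per block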

View File

@@ -4060,6 +4060,26 @@ cuda_py_test(
main = "ops/concat_benchmark.py",
)
cuda_py_test(
name = "conv2d_benchmark",
size = "large",
srcs = ["ops/conv2d_benchmark.py"],
additional_deps = [
":client",
":client_testlib",
":control_flow_ops",
":framework_for_generated_wrappers",
":nn_ops",
":platform",
":platform_benchmark",
":random_ops",
":variables",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
],
main = "ops/conv2d_benchmark.py",
)
cuda_py_test(
name = "split_benchmark",
srcs = ["ops/split_benchmark.py"],

View File

@@ -229,6 +229,80 @@ class TransposeTest(test.TestCase):
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
def testLargeSizeGPU(self):
# If no GPU available, skip the test
if not test.is_gpu_available(cuda_only=True):
return
large_shapes = [[1000000, 31, 3], [3, 1000000, 31], [3, 31, 1000000],
[10000, 310, 3], [3, 10000, 310], [3, 310, 10000],
[2, 1000, 1000], [1000, 2, 1000], [1000, 1000, 2]]
perms = [[0, 2, 1]] * 9
for input_shape, perm in zip(large_shapes, perms):
total_size = np.prod(input_shape)
inp = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_shape)
np_ans = self._np_transpose(inp, perm)
with self.test_session(use_gpu=True):
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = y.eval()
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
def testRandomizedSmallDimLargeSizeGPU(self):
# If no GPU available, skip the test
if not test.is_gpu_available(cuda_only=True):
return
# Draw 10 random shapes with large dimension sizes.
# 40% prob to generate dim[0] size within [1, 2048]
# 40% prob to generate dim[0] size within [2048, 4095]
# 20% prob to generate dim[0] size within [4096, 100000]
# 50% prob to use dim[1] as the small dim (<16)
num_samples = 10
total_size = 500000
small_size_limit = 2048
large_size_limit = 95905
small_size_percentage = 0.4
medium_size_percentage = 0.4
large_size_percentage = 0.2
perms = [[0, 2, 1]] * num_samples
dim_zero_sizes = []
dim_zero_sizes += list(
np.random.randint(
small_size_limit, size=int(small_size_percentage * num_samples)) +
1)
dim_zero_sizes += list(
np.random.randint(
small_size_limit, size=int(medium_size_percentage * num_samples)) +
small_size_limit)
dim_zero_sizes += list(
np.random.randint(
large_size_limit, size=int(large_size_percentage * num_samples)) +
small_size_limit * 2)
input_shapes = []
small_dim_limit = 16
for dim_zero_size in dim_zero_sizes:
small_dim_size = np.random.randint(small_dim_limit - 1) + 1
large_dim_size = int(
total_size / dim_zero_size / small_dim_size) + small_dim_limit
input_shapes += ([[dim_zero_size, small_dim_size, large_dim_size]]
if np.random.randint(2) else
[[dim_zero_size, large_dim_size, small_dim_size]])
for input_shape, perm in zip(input_shapes, perms):
# generate input data with random ints from 0 to 9.
inp = np.random.randint(10, size=input_shape)
np_ans = self._np_transpose(inp, perm)
with self.test_session(use_gpu=True):
inx = ops.convert_to_tensor(inp)
y = array_ops.transpose(inx, perm)
tf_ans = y.eval()
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
self._ClearCachedSession()
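One property of the sampled shapes is worth spelling out (a sanity sketch, not part of the test): small_dim_size is drawn from [1, 15] and large_dim_size is offset by small_dim_limit, so exactly one of the two swapped dimensions is below 16 and every generated case exercises the SmallDim kernel path.

# Sanity sketch with example dim[0] values; mirrors the shape construction above.
total_size, small_dim_limit = 500000, 16
for dim_zero_size in (100, 2048, 95905):
  for small_dim_size in range(1, small_dim_limit):
    large_dim_size = total_size // dim_zero_size // small_dim_size + small_dim_limit
    assert small_dim_size < small_dim_limit <= large_dim_size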
def testNop(self):
self._compareCpu(np.arange(0, 6).reshape([3, 2]).astype(np.float32), [0, 1])

View File

@@ -0,0 +1,141 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Conv2D op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import time
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def build_graph(device, input_shape, filter_shape, strides, padding, num_iters):
"""builds a graph containing a sequence of conv2d operations.
Args:
device: String, the device to run on.
input_shape: Shape of the input tensor.
filter_shape: Shape of the filter tensor.
strides: A list of ints. 1-D of length 4. The stride of sliding
window for each dimension of input.
padding: A string from: "SAME", "VALID". The type of padding
algorithm to use.
num_iters: number of iterations to run conv2d.
Returns:
An op that groups the chained conv2d operations.
"""
with ops.device("/%s:0" % device):
inp = variables.Variable(random_ops.truncated_normal(input_shape))
filt = variables.Variable(random_ops.truncated_normal(filter_shape))
outputs = []
conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC")
outputs.append(conv2d_op)
for _ in range(1, num_iters):
with ops.control_dependencies([conv2d_op]):
conv2d_op = nn_ops.conv2d(
inp, filt, strides, padding, data_format="NHWC")
outputs.append(conv2d_op)
return control_flow_ops.group(*outputs)
class Conv2DBenchmark(test.Benchmark):
"""Benchmark conv2d!"""
def _run_graph(self, device, input_shape, filter_shape, strides, padding,
num_iters):
"""runs the graph and print its execution time.
Args:
device: String, the device to run on.
input_shape: Shape of the input tensor.
filter_shape: Shape of the filter tensor.
strides: A list of ints. 1-D of length 4. The stride of sliding
window for each dimension of input.
padding: A string from: "SAME", "VALID". The type of padding
algorithm to use.
num_iters: number of iterations to run conv2d.
Returns:
The average duration of a single conv2d iteration, in seconds.
"""
graph = ops.Graph()
with graph.as_default():
outputs = build_graph(device, input_shape, filter_shape, strides, padding,
num_iters)
with session_lib.Session(graph=graph) as session:
variables.global_variables_initializer().run()
# warmup runs
session.run(outputs)
start_time = time.time()
session.run(outputs)
duration = (time.time() - start_time) / num_iters
print("%s inputshape:%s filtershape:%s strides:%s padding:%s "
"%d iters: %.8f sec" %
(device, str(input_shape).replace(" ", ""),
str(filter_shape).replace(" ", ""),
str(strides).replace(" ", ""), padding, num_iters, duration))
name_template = (
"conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_"
"strides_{strides}_padding_{padding}")
self.report_benchmark(
name=name_template.format(
device=device,
inputshape=str(input_shape).replace(" ", ""),
filtershape=str(filter_shape).replace(" ", ""),
strides=str(strides).replace(" ", ""),
padding=padding).replace(" ", ""),
iters=num_iters,
wall_time=duration)
return duration
def benchmark_conv2d(self):
print("conv2d benchmark:")
h = 500
w = 500
fh = 3
fw = 3
input_shapes = []
filter_shapes = []
for b, c in itertools.product([4, 16, 32], range(3, 16)):
input_shapes += [[b, h, w, c]]
filter_shapes += [[fh, fw, c, b]]
strides = [[1, 2, 2, 1]]
paddings = ["VALID", "SAME"]
for ishape, fshape in zip(input_shapes, filter_shapes):
for stride in strides:
for padding in paddings:
self._run_graph("gpu", ishape, fshape, stride, padding, 80)
if __name__ == "__main__":
test.main()
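In normal use the benchmark methods are selected with tf.test's --benchmarks regex flag; as a minimal sketch (assuming TensorFlow with a CUDA GPU, and calling the private helper directly purely for illustration), a single configuration can also be timed like this:

bench = Conv2DBenchmark()
bench._run_graph("gpu", [4, 500, 500, 3], [3, 3, 3, 4], [1, 2, 2, 1], "SAME", 80)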

View File

@@ -1,4 +1,4 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
def build_graph(device, input_shape, perm, datatype, num_iters):
"""Build a graph containing a sequence of conv2d operations.
"""builds a graph containing a sequence of conv2d operations.
Args:
device: String, the device to run on.
@@ -50,10 +50,12 @@ def build_graph(device, input_shape, perm, datatype, num_iters):
t = constant_op.constant(inp, shape=input_shape)
outputs = []
outputs.append(array_ops.transpose(t, perm))
for i in range(1, num_iters):
with ops.control_dependencies([outputs[i - 1]]):
outputs.append(array_ops.transpose(t, perm))
transpose_op = array_ops.transpose(t, perm)
outputs.append(transpose_op)
for _ in range(1, num_iters):
with ops.control_dependencies([transpose_op]):
transpose_op = array_ops.transpose(t, perm)
outputs.append(transpose_op)
return control_flow_ops.group(*outputs)
@@ -61,7 +63,7 @@ class TransposeBenchmark(test.Benchmark):
"""Benchmark transpose!"""
def _run_graph(self, device, input_shape, perm, num_iters, datatype):
"""Run the graph and print its execution time.
"""runs the graph and print its execution time.
Args:
device: String, the device to run on.
@@ -82,9 +84,11 @@ class TransposeBenchmark(test.Benchmark):
session.run(outputs)
start_time = time.time()
session.run(outputs)
duration = (time.time() - start_time) / num_iters
throughput = np.prod(
np.array(input_shape)) * datatype().itemsize * 2 / duration / 1e9
print("%s %s inputshape:%s perm:%s %d %.6fsec, %.4fGB/s." %
(device, str(datatype), str(input_shape).replace(" ", ""),
str(perm).replace(" ", ""), num_iters, duration, throughput))
@@ -108,12 +112,12 @@ class TransposeBenchmark(test.Benchmark):
datatypes = [np.complex128, np.float64, np.float32, np.float16, np.int8]
small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2 + [[
2, 100, 100, 16
], [2, 16, 100, 100]] * 2 + [[2, 5000, 16], [2, 16, 5000]] * 2
small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [
[0, 3, 1, 2], [0, 2, 3, 1]
] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2
small_shapes += [[2, 100, 100, 16], [2, 16, 100, 100]] * 2
small_shapes += [[2, 5000, 16], [2, 16, 5000]] * 2
small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2
small_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2
small_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
large_shapes = [[2, 40, 40, 40, 32], [2, 40, 40, 40, 64]] * 2 + [[
2, 300, 300, 32
@@ -132,5 +136,23 @@ class TransposeBenchmark(test.Benchmark):
for ishape, perm in zip(large_shapes, large_perms):
self._run_graph("gpu", ishape, perm, num_iters, datatype)
small_dim_large_shapes = [[2, 10000, 3], [2, 3, 10000], [2, 10000, 8],
[2, 8, 10000]]
small_dim_small_shapes = [[2, 5000, 3], [2, 3, 5000], [2, 5000, 8],
[2, 8, 5000]]
small_dim_perms = [[0, 2, 1]] * 4
num_iters = 320
small_dim_large_shape_datatypes = [np.float64, np.float32, np.int8]
for datatype in small_dim_large_shape_datatypes:
for ishape, perm in zip(small_dim_large_shapes, small_dim_perms):
self._run_graph("gpu", ishape, perm, num_iters, datatype)
small_dim_small_shape_datatypes = [np.complex128, np.float16]
for datatype in small_dim_small_shape_datatypes:
for ishape, perm in zip(small_dim_small_shapes, small_dim_perms):
self._run_graph("gpu", ishape, perm, num_iters, datatype)
if __name__ == "__main__":
test.main()