Fix transpose bug for large dimension sizes.

Add random tests of large shapes for better coverage. Update the transpose benchmark with cases that swap one small dimension with one large dimension.

PiperOrigin-RevId: 171302097
parent 3251bc0792
commit 2daa40f9d0
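The new tests target 3-D transposes with perm [0, 2, 1] in which the dimension sizes are large and, in some cases, one of the swapped dimensions is smaller than 16. As a rough, self-contained illustration of the property being verified (NumPy only; the shape here is illustrative, not taken from the test file):

import numpy as np

inp = np.arange(3 * 31 * 100000, dtype=np.float32).reshape([3, 31, 100000])
expected = np.transpose(inp, [0, 2, 1])  # reference result for perm = [0, 2, 1]
assert expected.shape == (3, 100000, 31)
assert expected[1, 12345, 7] == inp[1, 7, 12345]  # output[i][j][k] == input[i][k][j]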
@@ -272,6 +272,88 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
  }
}

// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor
// when only one of the dimension sizes is smaller than 16,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
// small_dim = the_smaller_dimension_size
// large_dim = the_larger_dimension_size
// tile_num_per_block = blockDim.x
// kTileLength = small_dim
//
// Each thread block operates on a single rectangle tile, where its width is
// kTileLength (we currently set it to 64) and its height is small_dim.
// We set the thread block's X dimension to be tile_num_per_block, and its Y
// and Z to be one.
template <typename T, int ShmemSize, bool SmallDim2>
__global__ void SwapDimension1And2InTensor3SmallDim(const T* input,
                                                    int batch_per_block,
                                                    Dimension<3> input_dims,
                                                    T* output) {
  // TODO(yangzihao): avoid shared memory bank conflicts.
  __shared__ T shared_memory_tile[ShmemSize];

  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.z == 1);

  int block_offset = blockIdx.x * blockDim.x;

  int x = threadIdx.x;
  int tile_height = blockDim.x;

  // Get tile height, width, and thread/block origin indices.
  int small_dim = SmallDim2 ? input_dims[2] : input_dims[1];
  int large_dim = SmallDim2 ? input_dims[1] : input_dims[2];

  int global_offset = small_dim * large_dim * (blockIdx.y * batch_per_block) +
                      (SmallDim2 ? block_offset * small_dim : block_offset);
  if (global_offset >= (input_dims[0] * input_dims[1] * input_dims[2])) return;

  for (int batch = 0; batch < batch_per_block; ++batch) {
    int block_origin_idx =
        small_dim * large_dim * (blockIdx.y * batch_per_block + batch);
    int thread_origin_idx =
        block_origin_idx +
        (SmallDim2 ? block_offset * small_dim : block_offset) + x;

    if (block_offset + blockDim.x > large_dim) {
      tile_height = large_dim - block_offset;
    }

    __syncthreads();

    // Load a contiguous memory region into the shared memory tile.
    if (x < tile_height) {
      for (int y = 0; y < small_dim; y++) {
        int shmem_index =
            SmallDim2 ? (x + y * tile_height) : (x * small_dim + y);
        shared_memory_tile[shmem_index] =
            ldg(input + thread_origin_idx +
                y * (SmallDim2 ? tile_height : large_dim));
      }
    }

    __syncthreads();

    // Get block origin index for output array.
    int output_block_offset = block_origin_idx;
    int output_block_idx = SmallDim2 ? block_offset : block_offset * small_dim;
    int output_block_origin_idx = output_block_offset + output_block_idx;

    // Store the transposed memory region from shared memory back to device
    // memory.
    if (x < tile_height) {
      for (int y = 0; y < small_dim; y++) {
        int output_idx = output_block_origin_idx + x +
                         y * (SmallDim2 ? large_dim : tile_height);
        int shmem_index =
            SmallDim2 ? (x * small_dim + y) : (x + y * tile_height);
        output[output_idx] = shared_memory_tile[shmem_index];
      }
    }
  }
}

// A Cuda custom kernel that converts input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
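To make the index arithmetic above easier to follow, here is a rough NumPy sketch of the SmallDim2 = true path (input laid out as [batch, large_dim, small_dim]); the sizes below are illustrative only, not taken from the kernel. Each block copies one contiguous tile_height * small_dim region into the tile and writes it back transposed.

import numpy as np

# Illustrative sizes; in the kernel blockDim.x (kTileLength) is 64 and this
# path is taken when one of the swapped dimensions is smaller than 16.
batch, large_dim, small_dim = 2, 100, 5  # input shape: [batch, large_dim, small_dim]
tile_length = 64

inp = np.arange(batch * large_dim * small_dim, dtype=np.float32)
out = np.empty_like(inp)

for b in range(batch):
  base = b * large_dim * small_dim  # same batch stride for input and output
  for block in range((large_dim + tile_length - 1) // tile_length):
    block_offset = block * tile_length
    tile_height = min(tile_length, large_dim - block_offset)
    # Load phase: one contiguous region of tile_height * small_dim elements,
    # which is what keeps the global-memory reads coalesced.
    shared = inp[base + block_offset * small_dim:
                 base + (block_offset + tile_height) * small_dim]
    # Store phase: tile element (x, y) goes to output[b][y][block_offset + x],
    # i.e. output[i][j][k] = input[i][k][j].
    for x in range(tile_height):
      for y in range(small_dim):
        out[base + block_offset + x + y * large_dim] = shared[x * small_dim + y]

expected = inp.reshape(batch, large_dim, small_dim).transpose(0, 2, 1).ravel()
assert np.array_equal(out, expected)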
@@ -420,25 +502,62 @@ template <typename T>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
                                    const Dimension<3>& input_dims, T* output) {
  // If both dimensions are not trivial, use tiles for the actual swapping.
  // If one dimension is trivial, use the SmallDim kernel for swapping.
  // Otherwise, the trivial swapping relying on the ldg cache is more efficient.
  static const int kMinDimensionToUseTiles = 16;
  bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
                    input_dims[2] >= kMinDimensionToUseTiles);
  bool use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles &&
                         input_dims[2] < kMinDimensionToUseTiles)) ||
                       ((input_dims[1] < kMinDimensionToUseTiles &&
                         input_dims[2] >= kMinDimensionToUseTiles));
  static const int NumSubTiles = 8;

  if (use_tiles) {
    // We get best performance when TileSize is the number of threads in a warp
    // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
    // threads.
    static const int TileSize = 32;
    static const int NumSubTiles = 8;
    Dimension<3> input_dims_in_tiles = {
        input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
        (input_dims[2] + TileSize - 1) / TileSize,
    };
    int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
                            input_dims_in_tiles[2];
    // We get best performance when TileSize is the number of threads in a warp
    // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
    // threads.
    SwapDimension1And2InTensor3UsingTiles<T, TileSize, NumSubTiles><<<
        total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>(
        input, input_dims, output);
  } else if (use_small_dim) {
    // When only one of the dimensions is smaller than kMinDimensionToUseTiles,
    // we use one block to process a rectangle region with the size of
    // kTileLength * small_dim. We found that setting kTileLength to 64 on a
    // TitanX Maxwell GPU achieves the best performance.
    //                          large_dim
    //            +---------------...--------+
    //            |        |        |        |
    // small_dim  |        |  ...   |        |
    //            |        |        |        |
    //            +--------------...---------+
    //            \----- ------/       \- -/
    //                  V                V
    //     kTileLength(tile_height)  tile_height
    static const int kTileLength = 64;
    static const int kGridDimY = 65535;
    int large_dim = std::max(input_dims[2], input_dims[1]);
    int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength;
    int grid_dim_y = std::min(input_dims[0], kGridDimY);
    int batch_per_block = (input_dims[0] + grid_dim_y - 1) / grid_dim_y;
    if (input_dims[2] < input_dims[1]) {
      SwapDimension1And2InTensor3SmallDim<
          T, kTileLength * kMinDimensionToUseTiles, true>
          <<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
             d.stream()>>>(input, batch_per_block, input_dims, output);
    } else {
      SwapDimension1And2InTensor3SmallDim<
          T, kTileLength * kMinDimensionToUseTiles, false>
          <<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
             d.stream()>>>(input, batch_per_block, input_dims, output);
    }
  } else {
    int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
    CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
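The SmallDim branch above is where the large-dimension handling lives: dimension 0 is mapped onto gridDim.y, which CUDA caps at 65535, and batch_per_block absorbs whatever does not fit. A hedged Python sketch of that dispatch arithmetic (the helper name and return strings are illustrative; the constants mirror the C++ code):

# Sketch of the dispatch logic in RunSwapDimension1And2InTensor3.
def describe_launch(input_dims, k_min_dim=16, k_tile_length=64, k_grid_dim_y=65535):
  d0, d1, d2 = input_dims
  if d1 >= k_min_dim and d2 >= k_min_dim:
    return "tiled kernel (SwapDimension1And2InTensor3UsingTiles)"
  if (d1 >= k_min_dim) != (d2 >= k_min_dim):
    large_dim = max(d1, d2)
    tile_num_per_block = (large_dim + k_tile_length - 1) // k_tile_length
    grid_dim_y = min(d0, k_grid_dim_y)  # gridDim.y is capped at 65535
    batch_per_block = (d0 + grid_dim_y - 1) // grid_dim_y
    return ("SmallDim kernel: grid=(%d, %d), block=%d, batch_per_block=%d"
            % (tile_num_per_block, grid_dim_y, k_tile_length, batch_per_block))
  return "plain element-wise swap relying on the ldg cache"


print(describe_launch([1000000, 31, 3]))
# SmallDim kernel: grid=(1, 65535), block=64, batch_per_block=16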
@@ -4060,6 +4060,26 @@ cuda_py_test(
    main = "ops/concat_benchmark.py",
)

cuda_py_test(
    name = "conv2d_benchmark",
    size = "large",
    srcs = ["ops/conv2d_benchmark.py"],
    additional_deps = [
        ":client",
        ":client_testlib",
        ":control_flow_ops",
        ":framework_for_generated_wrappers",
        ":nn_ops",
        ":platform",
        ":platform_benchmark",
        ":random_ops",
        ":variables",
        "//third_party/py/numpy",
        "//tensorflow/core:protos_all_py",
    ],
    main = "ops/conv2d_benchmark.py",
)

cuda_py_test(
    name = "split_benchmark",
    srcs = ["ops/split_benchmark.py"],
@@ -229,6 +229,80 @@ class TransposeTest(test.TestCase):
      self.assertAllEqual(np_ans, tf_ans)
      self.assertShapeEqual(np_ans, y)

  def testLargeSizeGPU(self):
    # If no GPU available, skip the test
    if not test.is_gpu_available(cuda_only=True):
      return

    large_shapes = [[1000000, 31, 3], [3, 1000000, 31], [3, 31, 1000000],
                    [10000, 310, 3], [3, 10000, 310], [3, 310, 10000],
                    [2, 1000, 1000], [1000, 2, 1000], [1000, 1000, 2]]
    perms = [[0, 2, 1]] * 9

    for input_shape, perm in zip(large_shapes, perms):
      total_size = np.prod(input_shape)
      inp = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_shape)
      np_ans = self._np_transpose(inp, perm)
      with self.test_session(use_gpu=True):
        inx = ops.convert_to_tensor(inp)
        y = array_ops.transpose(inx, perm)
        tf_ans = y.eval()
      self.assertAllEqual(np_ans, tf_ans)
      self.assertShapeEqual(np_ans, y)

  def testRandomizedSmallDimLargeSizeGPU(self):
    # If no GPU available, skip the test
    if not test.is_gpu_available(cuda_only=True):
      return

    # Draw 10 random shapes with large dimension sizes.
    # 40% prob to generate dim[0] size within [1, 2047]
    # 40% prob to generate dim[0] size within [2048, 4095]
    # 20% prob to generate dim[0] size within [4096, 100000]
    # 50% prob to use dim[1] as the small dim (<16)
    num_samples = 10
    total_size = 500000
    small_size_limit = 2048
    large_size_limit = 95905
    small_size_percentage = 0.4
    medium_size_percentage = 0.4
    large_size_percentage = 0.2
    perms = [[0, 2, 1]] * num_samples
    dim_zero_sizes = []
    dim_zero_sizes += list(
        np.random.randint(
            small_size_limit, size=int(small_size_percentage * num_samples)) +
        1)
    dim_zero_sizes += list(
        np.random.randint(
            small_size_limit, size=int(medium_size_percentage * num_samples)) +
        small_size_limit)
    dim_zero_sizes += list(
        np.random.randint(
            large_size_limit, size=int(large_size_percentage * num_samples)) +
        small_size_limit * 2)
    input_shapes = []
    small_dim_limit = 16
    for dim_zero_size in dim_zero_sizes:
      small_dim_size = np.random.randint(small_dim_limit - 1) + 1
      large_dim_size = int(
          total_size / dim_zero_size / small_dim_size) + small_dim_limit
      input_shapes += ([[dim_zero_size, small_dim_size, large_dim_size]]
                       if np.random.randint(2) else
                       [[dim_zero_size, large_dim_size, small_dim_size]])

    for input_shape, perm in zip(input_shapes, perms):
      # generate input data with random ints from 0 to 9.
      inp = np.random.randint(10, size=input_shape)
      np_ans = self._np_transpose(inp, perm)
      with self.test_session(use_gpu=True):
        inx = ops.convert_to_tensor(inp)
        y = array_ops.transpose(inx, perm)
        tf_ans = y.eval()
      self.assertAllEqual(np_ans, tf_ans)
      self.assertShapeEqual(np_ans, y)
      self._ClearCachedSession()

  def testNop(self):
    self._compareCpu(np.arange(0, 6).reshape([3, 2]).astype(np.float32), [0, 1])
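For a concrete sense of the shapes this generator produces, here is one hypothetical draw worked through by hand (the values are made up, not taken from the test):

dim_zero_size, small_dim_size = 1000, 5
large_dim_size = int(500000 / dim_zero_size / small_dim_size) + 16  # = 116
# Candidate shapes: [1000, 5, 116] or [1000, 116, 5]; either way one of the
# swapped dimensions is below 16, so the SmallDim kernel path is exercised.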
tensorflow/python/ops/conv2d_benchmark.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Conv2D op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import time

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def build_graph(device, input_shape, filter_shape, strides, padding, num_iters):
  """Builds a graph containing a sequence of conv2d operations.

  Args:
    device: String, the device to run on.
    input_shape: Shape of the input tensor.
    filter_shape: Shape of the filter tensor.
    strides: A list of ints. 1-D of length 4. The stride of the sliding
      window for each dimension of input.
    padding: A string from: "SAME", "VALID". The type of padding
      algorithm to use.
    num_iters: number of iterations to run conv2d.

  Returns:
    An array of tensors to run().
  """
  with ops.device("/%s:0" % device):
    inp = variables.Variable(random_ops.truncated_normal(input_shape))
    filt = variables.Variable(random_ops.truncated_normal(filter_shape))

    outputs = []
    conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC")
    outputs.append(conv2d_op)
    for _ in range(1, num_iters):
      with ops.control_dependencies([conv2d_op]):
        conv2d_op = nn_ops.conv2d(
            inp, filt, strides, padding, data_format="NHWC")
        outputs.append(conv2d_op)
    return control_flow_ops.group(*outputs)


class Conv2DBenchmark(test.Benchmark):
  """Benchmark conv2d!"""

  def _run_graph(self, device, input_shape, filter_shape, strides, padding,
                 num_iters):
    """Runs the graph and prints its execution time.

    Args:
      device: String, the device to run on.
      input_shape: Shape of the input tensor.
      filter_shape: Shape of the filter tensor.
      strides: A list of ints. 1-D of length 4. The stride of the sliding
        window for each dimension of input.
      padding: A string from: "SAME", "VALID". The type of padding
        algorithm to use.
      num_iters: Number of iterations to run the benchmark.

    Returns:
      The duration of the run in seconds.
    """
    graph = ops.Graph()
    with graph.as_default():
      outputs = build_graph(device, input_shape, filter_shape, strides, padding,
                            num_iters)
      with session_lib.Session(graph=graph) as session:
        variables.global_variables_initializer().run()
        # warmup runs
        session.run(outputs)

        start_time = time.time()
        session.run(outputs)
        duration = (time.time() - start_time) / num_iters

        print("%s inputshape:%s filtershape:%s strides:%s padding:%s "
              "%d iters: %.8f sec" %
              (device, str(input_shape).replace(" ", ""),
               str(filter_shape).replace(" ", ""),
               str(strides).replace(" ", ""), padding, num_iters, duration))

    name_template = (
        "conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_"
        "strides_{strides}_padding_{padding}")

    self.report_benchmark(
        name=name_template.format(
            device=device,
            inputshape=str(input_shape).replace(" ", ""),
            filtershape=str(filter_shape).replace(" ", ""),
            strides=str(strides).replace(" ", ""),
            padding=padding).replace(" ", ""),
        iters=num_iters,
        wall_time=duration / num_iters)

    return duration

  def benchmark_conv2d(self):
    print("conv2d benchmark:")

    h = 500
    w = 500
    fh = 3
    fw = 3
    input_shapes = []
    filter_shapes = []
    for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]):
      input_shapes += [[b, h, w, c]]
      filter_shapes += [[fh, fw, c, b]]
    strides = [[1, 2, 2, 1]]
    paddings = ["VALID", "SAME"]
    for ishape, fshape in zip(input_shapes, filter_shapes):
      for stride in strides:
        for padding in paddings:
          self._run_graph("gpu", ishape, fshape, stride, padding, 80)


if __name__ == "__main__":
  test.main()
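To see what one benchmark case measures, a minimal way to drive it directly from Python (the parameters below are illustrative; the file is normally executed through the TensorFlow benchmark runner):

bench = Conv2DBenchmark()
# 4 NHWC images of 500x500x3, a 3x3 filter producing 4 channels, stride 2, 80 iterations.
bench._run_graph("gpu", [4, 500, 500, 3], [3, 3, 3, 4], [1, 2, 2, 1], "SAME", 80)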
@@ -1,4 +1,4 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test


def build_graph(device, input_shape, perm, datatype, num_iters):
  """Build a graph containing a sequence of conv2d operations.
  """builds a graph containing a sequence of conv2d operations.

  Args:
    device: String, the device to run on.
@@ -50,10 +50,12 @@ def build_graph(device, input_shape, perm, datatype, num_iters):
    t = constant_op.constant(inp, shape=input_shape)

    outputs = []
    outputs.append(array_ops.transpose(t, perm))
    for i in range(1, num_iters):
      with ops.control_dependencies([outputs[i - 1]]):
        outputs.append(array_ops.transpose(t, perm))
    transpose_op = array_ops.transpose(t, perm)
    outputs.append(transpose_op)
    for _ in range(1, num_iters):
      with ops.control_dependencies([transpose_op]):
        transpose_op = array_ops.transpose(t, perm)
        outputs.append(transpose_op)
    return control_flow_ops.group(*outputs)

@@ -61,7 +63,7 @@ class TransposeBenchmark(test.Benchmark):
  """Benchmark transpose!"""

  def _run_graph(self, device, input_shape, perm, num_iters, datatype):
    """Run the graph and print its execution time.
    """runs the graph and print its execution time.

    Args:
      device: String, the device to run on.
@@ -82,9 +84,11 @@ class TransposeBenchmark(test.Benchmark):
      session.run(outputs)
      start_time = time.time()
      session.run(outputs)

      duration = (time.time() - start_time) / num_iters
      throughput = np.prod(
          np.array(input_shape)) * datatype().itemsize * 2 / duration / 1e9

    print("%s %s inputshape:%s perm:%s %d %.6fsec, %.4fGB/s." %
          (device, str(datatype), str(input_shape).replace(" ", ""),
           str(perm).replace(" ", ""), num_iters, duration, throughput))
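The throughput printed above counts each element once read and once written. A quick worked example with a hypothetical duration:

import numpy as np

input_shape, datatype, duration = [2, 10000, 3], np.float32, 1e-4  # duration is made up
throughput = np.prod(
    np.array(input_shape)) * datatype().itemsize * 2 / duration / 1e9
print("%.4f GB/s" % throughput)  # 60000 elements * 4 bytes * 2 -> 4.8 GB/s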
@@ -108,12 +112,12 @@ class TransposeBenchmark(test.Benchmark):

    datatypes = [np.complex128, np.float64, np.float32, np.float16, np.int8]

    small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2 + [[
        2, 100, 100, 16
    ], [2, 16, 100, 100]] * 2 + [[2, 5000, 16], [2, 16, 5000]] * 2
    small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [
        [0, 3, 1, 2], [0, 2, 3, 1]
    ] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2
    small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2
    small_shapes += [[2, 100, 100, 16], [2, 16, 100, 100]] * 2
    small_shapes += [[2, 5000, 16], [2, 16, 5000]] * 2
    small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2
    small_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2
    small_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2

    large_shapes = [[2, 40, 40, 40, 32], [2, 40, 40, 40, 64]] * 2 + [[
        2, 300, 300, 32
@@ -132,5 +136,23 @@ class TransposeBenchmark(test.Benchmark):
      for ishape, perm in zip(large_shapes, large_perms):
        self._run_graph("gpu", ishape, perm, num_iters, datatype)

    small_dim_large_shapes = [[2, 10000, 3], [2, 3, 10000], [2, 10000, 8],
                              [2, 8, 10000]]
    small_dim_small_shapes = [[2, 5000, 3], [2, 3, 5000], [2, 5000, 8],
                              [2, 8, 5000]]
    small_dim_perms = [[0, 2, 1]] * 4

    num_iters = 320
    small_dim_large_shape_datatypes = [np.float64, np.float32, np.int8]
    for datatype in small_dim_large_shape_datatypes:
      for ishape, perm in zip(small_dim_large_shapes, small_dim_perms):
        self._run_graph("gpu", ishape, perm, num_iters, datatype)

    small_dim_small_shape_datatypes = [np.complex128, np.float16]
    for datatype in small_dim_small_shape_datatypes:
      for ishape, perm in zip(small_dim_small_shapes, small_dim_perms):
        self._run_graph("gpu", ishape, perm, num_iters, datatype)


if __name__ == "__main__":
  test.main()