diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index b51fc841d1c..efc5d7c553a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2589,7 +2589,9 @@ tf_kernel_library(
 tf_kernel_library(
     name = "segment_reduction_ops",
     prefix = "segment_reduction_ops",
-    deps = MATH_DEPS,
+    deps = MATH_DEPS + if_cuda([
+        ":cuda_solvers",
+    ]),
 )
 
 tf_kernel_library(
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 9cdbe89457c..8f7eff113cd 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -16,6 +16,9 @@ limitations under the License.
 // See docs in ../ops/math_ops.cc.
 
 #define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA
 
 #include "tensorflow/core/kernels/segment_reduction_ops.h"
 #include <vector>
@@ -32,6 +35,15 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/util.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
+#include "tensorflow/core/platform/cuda.h"
+
+using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -183,6 +195,105 @@ class SegmentReductionOp : public OpKernel {
   }
 };
 
+#ifdef GOOGLE_CUDA
+// SegmentSumGPUOp is a segment sum operator implemented for GPU only.
+// TODO: This implementation of SegmentSumGPUOp is sometimes slower than
+// its unsorted counterpart (mostly when the problem size is small).
+// This is due to the two main reasons below; a cost-effective way to
+// resolve them is desirable.
+// 1. Sorted segment sum requires a memory transfer from device to host
+//    in order to know the size of the output dimension, whereas unsorted
+//    segment sum receives the size of the output dimension as an input
+//    parameter.
+// 2. Sorted segment sum is essentially a tiled version of unsorted segment
+//    sum, and therefore such an optimization comes at an inherent cost.
+//    However, that cost may not be justified when the problem size is
+//    small. Whether to use the tiled or the untiled version depends on
+//    many factors, including data alignment, the ratio of computation to
+//    memory traffic and, obviously, the problem sizes.
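+// As a concrete illustration of point 1 (the values here are only an
+// example, not taken from the code below): for segment_ids = [0, 0, 1, 2, 2]
+// the number of output rows is segment_ids[num_indices - 1] + 1 = 3. The
+// sorted kernel has to copy that last segment id back to the host before it
+// can allocate the output, whereas the unsorted kernel is simply told
+// num_segments = 3 up front.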
+template <class T, class Index>
+class SegmentSumGPUOp : public AsyncOpKernel {
+ public:
+  explicit SegmentSumGPUOp(OpKernelConstruction* context)
+      : AsyncOpKernel(context) {}
+
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+    const Tensor& input = context->input(0);
+    const Tensor& segment_ids = context->input(1);
+
+    OP_REQUIRES_ASYNC(
+        context, TensorShapeUtils::IsVector(segment_ids.shape()),
+        errors::InvalidArgument("segment_ids should be a vector."), done);
+
+    const int64 num_indices = segment_ids.NumElements();
+    OP_REQUIRES_ASYNC(
+        context, num_indices == input.dim_size(0),
+        errors::InvalidArgument(
+            "segment_ids should be the same size as dimension 0 of"
+            " input."),
+        done);
+
+    if (num_indices == 0) {
+      TensorShape output_shape = input.shape();
+      output_shape.set_dim(0, 0);
+
+      Tensor* output = nullptr;
+      OP_REQUIRES_OK_ASYNC(
+          context, context->allocate_output(0, output_shape, &output), done);
+      done();
+      return;
+    }
+
+    perftools::gputools::DeviceMemoryBase output_rows_device(
+        (void*)(segment_ids.template flat<Index>().data() +
+                (num_indices - 1)));
+    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);
+
+    auto stream = context->op_device_context()->stream();
+    OP_REQUIRES_ASYNC(
+        context, stream
+                     ->ThenMemcpy(output_rows_host.mutable_data(),
+                                  output_rows_device, sizeof(Index))
+                     .ok(),
+        errors::Internal(
+            "SegmentSumGPUOp: failed to copy output_rows from device"),
+        done);
+
+    functor::SegmentSumFunctor<T, Index> functor_;
+    auto create_and_check_output = [context, output_rows_host, &input,
+                                    &segment_ids, &functor_, done]() {
+      // Ensure that within the callback, the proper GPU settings are
+      // configured.
+      auto stream = context->op_device_context()->stream();
+      ScopedActivateExecutorContext scoped_activation{stream->parent()};
+
+      Index output_rows = *output_rows_host.data();
+      output_rows++;
+      OP_REQUIRES_ASYNC(context, output_rows > 0,
+                        errors::InvalidArgument("segment ids must be >= 0"),
+                        done);
+
+      TensorShape output_shape = input.shape();
+      output_shape.set_dim(0, output_rows);
+
+      Tensor* output = nullptr;
+      OP_REQUIRES_OK_ASYNC(
+          context, context->allocate_output(0, output_shape, &output), done);
+
+      auto output_flat = output->flat_outer_dims<T>();
+      auto data_ptr = input.template flat<T>().data();
+      auto segment_flat = segment_ids.flat<Index>();
+      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
+               segment_ids.shape(), segment_flat, input.NumElements(),
+               data_ptr, output_flat);
+
+      done();
+    };
+
+    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+        stream, create_and_check_output);
+  }
+};
+#endif  // GOOGLE_CUDA
+
 #define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                     default_value)                   \
   REGISTER_KERNEL_BUILDER(                                           \
@@ -227,6 +338,23 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
 #undef REGISTER_REAL_CPU_KERNELS_ALL
 #undef REGISTER_COMPLEX_CPU_KERNELS_ALL
 
+#if GOOGLE_CUDA
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                  \
+  REGISTER_KERNEL_BUILDER(Name("SegmentSum")                           \
+                              .Device(DEVICE_GPU)                      \
+                              .TypeConstraint<type>("T")               \
+                              .TypeConstraint<index_type>("Tindices"), \
+                          SegmentSumGPUOp<type, index_type>)
+
+#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
+  REGISTER_GPU_SORTED_KERNELS(type, int32);   \
+  REGISTER_GPU_SORTED_KERNELS(type, int64);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
+#undef REGISTER_GPU_SORTED_KERNELS
+#undef REGISTER_GPU_SORTED_KERNELS_ALL
+#endif  // GOOGLE_CUDA
+
 namespace functor {
 
 // UnsortedSegmentSumFunctor implementation for CPUDevice.
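For context, a minimal usage sketch rather than part of the patch (TF 1.x Python API, illustrative values only): the kernels registered above back tf.segment_sum, which infers the number of output rows from the last segment id, while tf.unsorted_segment_sum is given num_segments explicitly.

import numpy as np
import tensorflow as tf

data = tf.constant(np.arange(12, dtype=np.float32).reshape(4, 3))
segment_ids = tf.constant([0, 0, 1, 2])  # must be sorted and non-negative

# Sorted segment sum: the output has segment_ids[-1] + 1 == 3 rows.
sorted_sum = tf.segment_sum(data, segment_ids)

# The unsorted variant is told the number of segments up front instead.
unsorted_sum = tf.unsorted_segment_sum(data, segment_ids, num_segments=3)

with tf.Session() as sess:
  a, b = sess.run([sorted_sum, unsorted_sum])
  # Both yield [[3, 5, 7], [6, 7, 8], [9, 10, 11]].
  assert np.allclose(a, b)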
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index ee09c213b7c..412c1d601d3 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -26,6 +26,28 @@ namespace tensorflow {
 class OpKernelContext;
 
 namespace functor {
+
+#ifdef GOOGLE_CUDA
+typedef Eigen::GpuDevice GPUDevice;
+// Functor for SegmentSumGPUOp.
+// 'output_rows': the number of output segments (unique segment ids in
+//                'segment_ids').
+// 'segment_ids_shape': shape of 'segment_ids' tensor.
+// 'segment_ids': sorted map from input to output segment ids at which to
+//                perform the segment sum operation.
+// 'data_size': size of input data tensor.
+// 'data': input data tensor.
+// 'output': output reshaped to {output_rows, output.size/output_rows}.
+template <typename T, typename Index>
+struct SegmentSumFunctor {
+  void operator()(OpKernelContext* ctx, const GPUDevice& d,
+                  const Index output_rows,
+                  const TensorShape& segment_ids_shape,
+                  typename TTypes<Index>::ConstFlat segment_ids,
+                  const Index data_size, const T* data,
+                  typename TTypes<T, 2>::Tensor output);
+};
+#endif
+
 // BaseFunctor for definition of UnsorteSegmentReductionOp
 // for usage without templates.
 template <typename Device, typename T, typename Index>
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
index b132b1e8f8b..26fcafee34a 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -54,6 +54,77 @@ __device__ __forceinline__ void AccumulateInto(
   CudaAtomicAdd(dest_scalar + 1, value.imag());
 }
 
+// The SortedSegmentSumCustomKernel reduces input data just as
+// UnsortedSegmentSumCustomKernel does, except that the input data
+// is partitioned along the outer reduction dimension. This is because
+// consecutive rows (elements in a row share the same outer dimension
+// index) in the flattened 2D input data are likely to belong to the
+// same segment in a sorted segment sum operation. Such a partitioning
+// strategy therefore has two advantages over the
+// UnsortedSegmentSumCustomKernel:
+// 1. Each thread reduces across multiple rows before writing an answer
+//    to global memory, so reduction results are written to global
+//    memory less often.
+// 2. Because of the increasing nature of the segment ids, a thread may
+//    know that it is the only contributor to an output element. In such
+//    cases it does not need atomic operations to write the result to
+//    global memory.
+// In the flattened view of the input data (with only an outer and an
+// inner dimension), every thread processes a strip of input of size
+// OuterDimTileSize x 1. The strip runs across multiple rows of the
+// input, and all of its reduction elements share one inner dimension
+// index.
+template <typename T, typename Index, int OuterDimTileSize>
+__global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
+                                             const Index inner_dim_size,
+                                             const Index output_outer_dim_size,
+                                             const Index* segment_ids,
+                                             const T* input, T* output,
+                                             const Index total_stripe_count) {
+  CUDA_1D_KERNEL_LOOP(stripe_index, total_stripe_count) {
+    const Index segment_offset = stripe_index % inner_dim_size;
+    const Index input_outer_dim_index_base =
+        stripe_index / inner_dim_size * Index(OuterDimTileSize);
+
+    T sum = T(0);
+    Index first_segment_id = segment_ids[input_outer_dim_index_base];
+    Index last_output_segment_id = output_outer_dim_size;
+
+    const Index actual_stripe_height =
+        min(Index(OuterDimTileSize),
+            input_outer_dim_size - input_outer_dim_index_base);
+    for (Index j = 0; j < actual_stripe_height; j++) {
+      Index current_output_segment_id =
+          segment_ids[input_outer_dim_index_base + j];
+      // Decide whether to write the accumulated result to global memory.
+      // The result is only written out when we move on to another segment;
+      // otherwise we keep accumulating locally as long as the segment id
+      // stays the same.
+      if (current_output_segment_id > last_output_segment_id) {
+        const Index output_index =
+            last_output_segment_id * inner_dim_size + segment_offset;
+        // Use an atomic add only when another thread may also contribute
+        // to this output element (the first segment of the strip).
+        if (last_output_segment_id == first_segment_id) {
+          AccumulateInto(output + output_index, sum);
+        } else {
+          *(output + output_index) = sum;
+        }
+        sum = T(0);
+      }
+      sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
+                 segment_offset);
+      last_output_segment_id = current_output_segment_id;
+    }
+    // For the last result in a strip, always write using atomic operations
+    // due to possible race conditions with threads computing
+    // the following strip.
+    const Index output_index =
+        last_output_segment_id * inner_dim_size + segment_offset;
+    AccumulateInto(output + output_index, sum);
+  }
+}
+
 // UnsortedSegmentSumFunctor kernel processes 'input_total_size' elements.
 // Each element is mapped from input to output by a combination of its
 // 'segment_ids' mapping and 'inner_dim_size'.
@@ -80,6 +151,47 @@ __global__ void UnsortedSegmentSumCustomKernel(
 
 namespace functor {
 
+template <typename T, typename Index>
+void SegmentSumFunctor<T, Index>::operator()(
+    OpKernelContext* ctx, const GPUDevice& d, const Index output_rows,
+    const TensorShape& segment_ids_shape,
+    typename TTypes<Index>::ConstFlat segment_ids, const Index data_size,
+    const T* data, typename TTypes<T, 2>::Tensor output) {
+  if (output.size() == 0) {
+    return;
+  }
+  // Set 'output' to zeros.
+  CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
+  SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+      output.size(), output.data());
+  if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
+    return;
+  }
+
+  // Launch the kernel to compute the sorted segment sum.
+  // Notes:
+  // *) 'input_total_size' is the total number of elements to process.
+  // *) 'segment_ids.shape' is a prefix of data's shape.
+  // *) 'input_outer_dim_size' is the total number of segments to process.
+  const Index input_total_size = data_size;
+  const Index input_outer_dim_size = segment_ids.dimension(0);
+  const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
+
+  const int OuterDimTileSize = 8;
+
+  const Index input_outer_dim_num_stripe =
+      Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
+
+  const Index total_stripe_count =
+      input_inner_dim_size * input_outer_dim_num_stripe;
+
+  config = GetCudaLaunchConfig(total_stripe_count, d);
+  SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize><<<
+      config.block_count, config.thread_per_block, 0, d.stream()>>>(
+      input_outer_dim_size, input_inner_dim_size, output_rows,
+      segment_ids.data(), data, output.data(), total_stripe_count);
+}
+
 // UnsortedSegmentSumFunctor implementation for GPUDevice.
 template <typename T, typename Index>
 struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFunctor<GPUDevice, T, Index> {
@@ -117,6 +229,15 @@ struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFuncto
   }
 };
 
+#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
+  template struct SegmentSumFunctor<T, Index>
+
+#define DEFINE_SORTED_GPU_SPECS(T)         \
+  DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \
+  DEFINE_SORTED_GPU_SPECS_INDEX(T, int64);
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
+
 #define DEFINE_GPU_SPECS_INDEX(T, Index) \
   template struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index e432998c21d..9538be95e3e 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -683,13 +683,15 @@ cuda_py_test(
 
 tf_py_test(
     name = "segment_reduction_ops_test",
-    size = "small",
+    size = "medium",
     srcs = ["segment_reduction_ops_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:variables",
         "//tensorflow/python:nn_grad",
     ],
 )
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 33269c91234..5e426fc61a7 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -18,12 +18,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
 import numpy as np
 
+from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -107,19 +112,19 @@ class SegmentReductionOpTest(SegmentReductionHelper):
         curr_ops_list = complex_ops_list
       else:
         curr_ops_list = ops_list
-
-      with self.test_session(use_gpu=False):
-        tf_x, np_x = self._input(shape, dtype=dtype)
-        for np_op1, np_op2, tf_op in curr_ops_list:
-          np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
-          s = tf_op(data=tf_x, segment_ids=indices)
-          tf_ans = s.eval()
-          self.assertAllClose(np_ans, tf_ans)
-          # NOTE(mrry): The static shape inference that computes
-          # `tf_ans.shape` can only infer that sizes from dimension 1
-          # onwards, because the size of dimension 0 is data-dependent
-          # and may therefore vary dynamically.
-          self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
+      for use_gpu in [True, False]:
+        with self.test_session(use_gpu=use_gpu):
+          tf_x, np_x = self._input(shape, dtype=dtype)
+          for np_op1, np_op2, tf_op in curr_ops_list:
+            np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
+            s = tf_op(data=tf_x, segment_ids=indices)
+            tf_ans = s.eval()
+            self.assertAllClose(np_ans, tf_ans)
+            # NOTE(mrry): The static shape inference that computes
+            # `tf_ans.shape` can only infer the sizes from dimension 1
+            # onwards, because the size of dimension 0 is data-dependent
+            # and may therefore vary dynamically.
+            self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
 
   def testSegmentIdsShape(self):
     shape = [4, 4]
@@ -130,41 +135,45 @@ class SegmentReductionOpTest(SegmentReductionHelper):
 
   def testSegmentIdsSize(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, _ = self._input(shape)
-      indices = [0, 1]
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      with self.assertRaisesOpError("segment_ids should be the same size"):
-        s.eval()
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, _ = self._input(shape)
+        indices = [0, 1]
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        with self.assertRaisesOpError("segment_ids should be the same size"):
+          s.eval()
 
   def testSegmentIdsValid(self):
     # This is a baseline for the following SegmentIdsInvalid* tests.
     shape = [4, 4]
-    with self.test_session():
-      tf_x, _ = self._input(shape)
-      indices = [0, 0, 0, 1]
-      result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
-      self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [0, 0, 0, 1]
+        result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
+        self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
 
   def testSegmentIdsGreaterThanZero(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, np_x = self._input(shape)
-      indices = [1, 1, 2, 2]
-      np_ans = self._segmentReduce(indices, np_x, np.add)
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      tf_ans = s.eval()
-      self.assertAllClose(np_ans, tf_ans)
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [1, 1, 2, 2]
+        np_ans = self._segmentReduce(indices, np_x, np.add)
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        tf_ans = s.eval()
+        self.assertAllClose(np_ans, tf_ans)
 
   def testSegmentIdsHole(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, np_x = self._input(shape)
-      indices = [0, 0, 3, 3]
-      np_ans = self._segmentReduce(indices, np_x, np.add)
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      tf_ans = s.eval()
-      self.assertAllClose(np_ans, tf_ans)
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [0, 0, 3, 3]
+        np_ans = self._segmentReduce(indices, np_x, np.add)
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        tf_ans = s.eval()
+        self.assertAllClose(np_ans, tf_ans)
 
   def testSegmentIdsInvalid1(self):
     shape = [4, 4]
@@ -199,21 +208,23 @@ class SegmentReductionOpTest(SegmentReductionHelper):
 
   def testSegmentIdsInvalid4(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, _ = self._input(shape)
-      indices = [0, 0, 0, -1]
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      with self.assertRaisesOpError("segment ids must be >= 0"):
-        s.eval()
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [0, 0, 0, -1]
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        with self.assertRaisesOpError("segment ids must be >= 0"):
+          s.eval()
 
   def testSegmentIdsInvalid5(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, _ = self._input(shape)
-      indices = [0, 0, 0, -2]
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      with self.assertRaisesOpError("segment ids must be >= 0"):
-        s.eval()
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [0, 0, 0, -2]
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        with self.assertRaisesOpError("segment ids must be >= 0"):
+          s.eval()
 
   def testGradient(self):
     shape = [4, 4]
@@ -341,7 +352,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
     with self.test_session(use_gpu=True):
       tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
       s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices,
-                                        num_segments=num_segments)
+                                          num_segments=num_segments)
       jacob_t, jacob_n = gradient_checker.compute_gradient(
           tf_x,
           shape,
@@ -635,6 +646,64 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
     with self.assertRaisesOpError(r"Segment id 0 out of range \[0, 0\)"):
       s.eval()
 
+
+class SegmentReductionOpBenchmark(test.Benchmark):
+  outer_dim_options = [2**x for x in range(9, 14, 2)]
+  ratio_options = [2**x for x in range(1, 6, 2)]
+  inner_dim_options = [2**x for x in range(9, 14, 2)]
+  # Randomly generated sizes with fewer alignment guarantees.
+  inner_dim_options += [
+      1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584
+  ]
+  dtype_options = [np.float32, np.float64]
+  options = (outer_dim_options, ratio_options, inner_dim_options,
+             dtype_options)
+  op_functors = [
+      lambda vc, vs, seg_ids: ("sorted", math_ops.segment_sum(vc, vs)),
+      lambda vc, vs, seg_ids: ("unsorted",
+                               math_ops.unsorted_segment_sum(
+                                   vc, vs, seg_ids[-1] + 1))
+  ]
+  repeat = 10
+
+  def _npTypeToStr(self, t):
+    if t == np.float32:
+      return "fp32"
+    if t == np.float64:
+      return "fp64"
+
+  def _runGraph(self, op_functor, outer_dim, ratio, inner_dim, dtype):
+    output_outer_dim = int(outer_dim / ratio)
+    const = np.random.randint(5, size=(outer_dim, inner_dim))
+    seg_ids = np.sort(np.random.randint(output_outer_dim, size=outer_dim))
+    vs = variables.Variable(seg_ids.astype(np.int32))
+    with ops.device("/gpu:0"):
+      vc = variables.Variable(const.astype(dtype))
+      name, op = op_functor(vc, vs, seg_ids)
+    with session.Session() as sess:
+      variables.global_variables_initializer().run()
+      r = self.run_op_benchmark(
+          sess,
+          op,
+          min_iters=self.repeat,
+          name="_".join(
+              map(str, [name, outer_dim, ratio, inner_dim,
+                        self._npTypeToStr(dtype)])))
+    return name, r["wall_time"]
+
+  def benchmarkSegmentSumGPU(self):
+    if not test.is_gpu_available(cuda_only=True):
+      return
+    for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
+      op_functor = self.op_functors[0]
+      with ops.Graph().as_default():
+        self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
+
+  def benchmarkUnsortedSegmentSumGPU(self):
+    if not test.is_gpu_available(cuda_only=True):
+      return
+    for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
+      op_functor = self.op_functors[1]
+      with ops.Graph().as_default():
+        self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
+
 
 if __name__ == "__main__":
   test.main()
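For scale, the benchmark sweep defined above covers outer_dim in {512, 2048, 8192}, ratio in {2, 8, 32}, 13 inner dimensions (three powers of two plus the ten irregular sizes) and two dtypes, i.e. 3 * 3 * 13 * 2 = 234 configurations per benchmark method, each timed with min_iters = 10. A small standalone check of that count (plain Python, mirroring the class attributes above; not part of the patch):

import itertools

outer_dims = [2**x for x in range(9, 14, 2)]   # [512, 2048, 8192]
ratios = [2**x for x in range(1, 6, 2)]        # [2, 8, 32]
inner_dims = [2**x for x in range(9, 14, 2)] + [
    1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584]
dtypes = ["fp32", "fp64"]

configs = list(itertools.product(outer_dims, ratios, inner_dims, dtypes))
assert len(configs) == 3 * 3 * 13 * 2 == 234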