Improve performance of argmax and argmin GPU kernel in some cases.

Eigen has very poor performance when the output tensor has few elements. With this change, if the output has at most 1024 elements, a different implementation is used. The new implementation uses the functor::ReduceImpl function that is used for most other TF reductions. Eigen performs better than functor::ReduceImpl when there are many output elements, which is why Eigen is still used when the number of output elements is greater than 1024.
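The new kernel reduces (flattened index, value) pairs rather than plain values. As a minimal CPU sketch of what it computes (illustrative only; ArgMaxDim1 and the plain-vector interface are made up for this example and are not part of the change):

#include <cstdint>
#include <limits>
#include <vector>

// Views the input as a row-major [dim0, dim1, dim2] tensor and reduces over
// the middle axis: for each (i, k), keep the index j of the largest value.
// This mirrors the (index, value) pair reduction the GPU kernel performs.
std::vector<int64_t> ArgMaxDim1(const std::vector<float>& in, int64_t dim0,
                                int64_t dim1, int64_t dim2) {
  std::vector<int64_t> out(dim0 * dim2, 0);
  for (int64_t i = 0; i < dim0; ++i) {
    for (int64_t k = 0; k < dim2; ++k) {
      float best = -std::numeric_limits<float>::infinity();
      int64_t best_j = 0;
      for (int64_t j = 0; j < dim1; ++j) {
        const float v = in[(i * dim1 + j) * dim2 + k];
        if (v > best) {  // NaNs never win this comparison.
          best = v;
          best_j = j;
        }
      }
      out[i * dim2 + k] = best_j;
    }
  }
  return out;
}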

A benchmark was added. The results from running on my machine, which has two Xeon E5-2690 v4 CPUs and a Titan V GPU, are shown below. All times are in seconds. The benchmarks were run in the internal version of TensorFlow. Only float32 benchmarks are shown, as the float16 and float64 results are similar. Also, only benchmarks where the new implementation is used are shown; when the old implementation is used, performance is the same as before this change.

Benchmark                   New time (s)  Old time (s)  New/old %
1d_float32_dim0             0.00089       0.06431         1.4%
rectangle1_2d_float32_dim1  0.00285       0.06736         4.2%
rectangle2_2d_float32_dim0  0.00298       0.05501         5.2%
rectangle1_3d_float32_dim0  0.07876       0.12668        62.2%
rectangle2_3d_float32_dim1  0.07869       0.12757        61.7%
rectangle3_3d_float32_dim2  0.07847       0.78461        10.0%

PiperOrigin-RevId: 292206797
Change-Id: Ic586910e0935463190761dc3ec9e7122bba06bd6
Authored by Reed Wanderman-Milne on 2020-01-29 13:30:34 -08:00; committed by TensorFlower Gardener
parent 720b16121e
commit 837b673aa7
5 changed files with 239 additions and 4 deletions

tensorflow/core/kernels/BUILD

@@ -3913,7 +3913,7 @@ tf_kernel_library(
tf_kernel_library(
    name = "argmax_op",
    prefix = "argmax_op",
-    deps = MATH_DEPS,
+    deps = MATH_DEPS + if_cuda_or_rocm([":reduction_ops"]),
)
tf_kernel_library(

tensorflow/core/kernels/argmax_op.cc

@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/math_ops.cc.
#define EIGEN_USE_THREADS
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
@@ -41,6 +39,39 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
// Determines whether the custom kernel in argmax_op_gpu.cu.cc should be
// used, and if so, runs it by calling DoGpuArgOp. If the kernel was run,
// CustomArgFunc returns true. Otherwise, it returns false and the caller must
// calculate the arg min or max itself.
template <typename Device>
struct CustomArgOp;

template <>
struct CustomArgOp<CPUDevice> {
  template <typename T, typename Tout, typename ArgFunctor>
  static bool CustomArgFunc(OpKernelContext* context, const Tensor& input,
                            int axis, Tensor* output) {
    return false;
  }
};
template <>
struct CustomArgOp<GPUDevice> {
  template <typename T, typename Tout, typename ArgFunctor>
  static bool CustomArgFunc(OpKernelContext* context, const Tensor& input,
                            int axis, Tensor* output) {
    if (output->NumElements() <= 1024 || output->dims() > 7) {
      // The custom kernel is faster than Eigen when the number of output
      // elements is relatively small. The Eigen code path also only handles
      // up to 7 dimensions, so the custom kernel must be used above that.
      DoGpuArgOp<T, Tout, ArgFunctor::is_argmax>(context, input, axis,
                                                 output);
      return true;
    } else {
      return false;
    }
  }
};
template <typename Device, typename T, typename Tout, typename ArgFunctor>
class ArgOp : public OpKernel {
public:
@@ -81,6 +112,11 @@ class ArgOp : public OpKernel {
return;
}
    if (CustomArgOp<Device>::template CustomArgFunc<T, Tout, ArgFunctor>(
            context, input, axis, output)) {
      return;
    }
#define HANDLE_DIM(NDIM) \
case NDIM: \
ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(), \

tensorflow/core/kernels/argmax_op.h

@@ -18,6 +18,8 @@ limitations under the License.
// Generator definition for ArgMaxOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
@@ -43,6 +45,7 @@ struct ArgMax {
DECLARE_COMPUTE_SPEC(7);
#undef DECLARE_COMPUTE_SPEC
enum { is_argmax = true };
};
template <typename Device, typename T, typename Tout>
@@ -63,10 +66,15 @@ struct ArgMin {
DECLARE_COMPUTE_SPEC(7);
#undef DECLARE_COMPUTE_SPEC
enum { is_argmax = false };
};
} // namespace functor
template <typename T, typename Tout, bool is_argmax>
void DoGpuArgOp(OpKernelContext* context, const Tensor& input, int axis,
Tensor* output);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_

tensorflow/core/kernels/argmax_op_gpu.cu.cc

@@ -20,11 +20,147 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/argmax_op.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef tensorflow::TTypes<float>::Tensor::Index Index;
// To compute the argmax/argmin, we perform a reduction on KeyValuePairs, which
// are (flattened index, value) pairs.
template <typename T>
using KeyValuePair = cub::KeyValuePair<Index, T>;
namespace {
template <typename T, bool is_argmax>
struct MaxOrMinFunc;
// The reduction operator: Returns the KeyValuePair with the highest or lowest
// value.
template <typename T>
struct MaxOrMinFunc<T, true> {
  __host__ __device__ __forceinline__ KeyValuePair<T> operator()(
      const KeyValuePair<T>& lhs, const KeyValuePair<T>& rhs) {
    // If one value is NaN, we choose the other value. This behavior is not
    // guaranteed by the op and may change in the future.
    return (lhs.value > rhs.value || Eigen::numext::isnan(rhs.value)) ? lhs
                                                                      : rhs;
  }
};
template <typename T>
struct MaxOrMinFunc<T, false> {
  __host__ __device__ __forceinline__ KeyValuePair<T> operator()(
      const KeyValuePair<T>& lhs, const KeyValuePair<T>& rhs) {
    return (lhs.value < rhs.value || Eigen::numext::isnan(rhs.value)) ? lhs
                                                                      : rhs;
  }
};
// The output converter: converts from a KeyValuePair to an index into a
// specific dimension. dim1 is the size of the dimension being reduced. dim2
// is the combined size of the dimension(s) after the reduced one.
template <typename T, typename Tout>
struct OutputConverter {
  OutputConverter(Index dim1, Index dim2) : dim1_(dim1), dim2_(dim2) {}

  __host__ __device__ __forceinline__ Tout
  operator()(const KeyValuePair<T>& key_value_pair) const {
    return static_cast<Tout>((key_value_pair.key / dim2_) % dim1_);
  }

  Index dim1_;
  Index dim2_;
};
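// Worked example for OutputConverter (illustrative; not part of the original
// code): with dim1 = 3 and dim2 = 4, the element at (i, j, k) has flattened
// index key = (i * 3 + j) * 4 + k. Since k < 4, key / 4 = i * 3 + j, and so
// (key / 4) % 3 = j, the index within the reduced dimension, as desired.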
} // namespace
namespace functor {
namespace reduction_op_helper {
// Template specialization of IdentityValue, to return the identity value for
// the reduction. This is needed for ReduceImpl, a function we call. We return
// (0, -inf) for argmax and (0, inf) for argmin.
template <typename T>
struct IdentityValue<KeyValuePair<T>, MaxOrMinFunc<T, true>> {
  KeyValuePair<T> operator()() {
    return {0, -std::numeric_limits<T>::infinity()};
  }
};

template <typename T>
struct IdentityValue<KeyValuePair<T>, MaxOrMinFunc<T, false>> {
  KeyValuePair<T> operator()() {
    return {0, std::numeric_limits<T>::infinity()};
  }
};
} // namespace reduction_op_helper
} // namespace functor
template <typename T, typename Tout, bool is_argmax>
void DoGpuArgOp(OpKernelContext* context, const Tensor& input, int axis,
                Tensor* output) {
  // We collapse adjacent axes of the input tensor in order to view it as a
  // 3-dimensional tensor. The reduction axis is not collapsed, so the three
  // new axes will be the input axes to the left of the reduction axis, the
  // reduction axis itself, and the input axes to the right of the reduction
  // axis.
  Index dim0 = 1;
  for (Index i = 0; i < axis; i++) {
    dim0 *= input.dim_size(i);
  }
  Index dim1 = input.dim_size(axis);
  Index dim2 = 1;
  for (Index i = axis + 1; i < input.dims(); i++) {
    dim2 *= input.dim_size(i);
  }
  DCHECK_EQ(dim0 * dim1 * dim2, input.NumElements());

  auto inp = input.shaped<T, 3>({dim0, dim1, dim2});
  auto out = output->shaped<Tout, 2>({dim0, dim2});

  // We call ReduceImpl to perform the reduction. The input iterator returns
  // KeyValuePairs. The reduction functor returns the KeyValuePair with the
  // max or min value. The output iterator converts the KeyValuePair into an
  // index into dim1.
  using InputIterType = cub::ArgIndexInputIterator<const T*>;
  using Functor = MaxOrMinFunc<T, is_argmax>;
  using OutputIterType =
      TransformOutputIterator<Tout, KeyValuePair<T>, OutputConverter<T, Tout>>;

  InputIterType inp_wrapper(inp.data());
  OutputIterType out_wrapper(out.data(), OutputConverter<T, Tout>(dim1, dim2));
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;

  // TODO(reedwm): We can probably improve performance by writing specialized
  // argmax kernels instead of relying on the generic ReduceImpl function.
  functor::ReduceImpl<KeyValuePair<T>, Functor, OutputIterType, InputIterType,
                      ReductionAxes>(context, out_wrapper, inp_wrapper, 3,
                                     dim0, dim1, dim2, 2, constants.kOne,
                                     Functor());
}
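// Concrete example of the axis collapse above (illustrative; not part of the
// original code): for an input of shape [2, 3, 4, 5] reduced over axis = 1,
// dim0 = 2, dim1 = 3, and dim2 = 4 * 5 = 20. ReduceImpl then writes a
// [2, 20] result, which occupies the same contiguous row-major buffer as the
// op's [2, 4, 5] output tensor.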
#define DEFINE_GPU_ARG_OPS(T) \
template void DoGpuArgOp<T, int64, true>(OpKernelContext * context, \
const Tensor& input, int axis, \
Tensor* output); \
template void DoGpuArgOp<T, int64, false>(OpKernelContext * context, \
const Tensor& input, int axis, \
Tensor* output); \
template void DoGpuArgOp<T, int32, true>(OpKernelContext * context, \
const Tensor& input, int axis, \
Tensor* output); \
template void DoGpuArgOp<T, int32, false>(OpKernelContext * context, \
const Tensor& input, int axis, \
Tensor* output);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_ARG_OPS);
#define DEFINE_GPU_SPEC(T) \
template struct functor::ArgMax<GPUDevice, T, int64>; \
template struct functor::ArgMin<GPUDevice, T, int64>; \

tensorflow/python/kernel_tests/argmax_op_test.py

@@ -21,10 +21,15 @@ import functools
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -69,7 +74,7 @@ class ArgMaxTest(test.TestCase):
    self._testBothArg(math_ops.argmin, x, 0, x.argmin())

  def _testDim(self, dtype):
-    shape = (3, 2, 4, 5, 6, 3, 7)
+    shape = (3, 2, 4, 1, 5, 3, 2)
    x = np.arange(functools.reduce(lambda x, y: x * y, shape), dtype=dtype)
    np.random.shuffle(x)
    x = x.reshape(shape)
@@ -79,9 +84,17 @@ class ArgMaxTest(test.TestCase):
      self._testBothArg(math_ops.argmax, x, axis, x.argmax(axis))
      self._testBothArg(math_ops.argmin, x, axis, x.argmin(axis))

  def _testLargeOutput(self, dtype):
    # Test case where output size is greater than 1024, which uses a different
    # codepath on the GPU.
    x = np.asarray(100 * np.random.randn(11, 10, 5, 11), dtype=dtype)
    self._testBothArg(math_ops.argmax, x, 2, x.argmax(2))
    self._testBothArg(math_ops.argmin, x, 2, x.argmin(2))

  def testFloat(self):
    self._testBasic(np.float32)
    self._testDim(np.float32)
    self._testLargeOutput(np.float32)

  def testFloatInt32Output(self):
    x = np.asarray(100 * np.random.randn(200), dtype=np.float32)
@@ -103,6 +116,12 @@ class ArgMaxTest(test.TestCase):
  def testDouble(self):
    self._testBasic(np.float64)
    self._testDim(np.float64)
    self._testLargeOutput(np.float64)

  def testHalf(self):
    self._testBasic(np.float16)
    self._testDim(np.float16)
    self._testLargeOutput(np.float16)

  def testInt32(self):
    self._testBasic(np.int32)
@@ -134,5 +153,41 @@ class ArgMaxTest(test.TestCase):
    self.assertEqual(ret.shape, (1, 0))


class ArgMaxBenchmark(test.Benchmark):

  def _RunSingleBenchmark(self, shape, dtype, bench_name):
    with session.Session(config=benchmark.benchmark_config()) as sess:
      num_dims = len(shape)
      var = variables.Variable(random_ops.random_uniform(shape, dtype=dtype))
      variables.variables_initializer([var]).run()
      for dim in range(num_dims):
        num_ops_in_group = 15
        op = control_flow_ops.group(*(math_ops.argmax(var, dimension=dim)
                                      for _ in range(num_ops_in_group)))
        op_name = "%s_%s_dim%d" % (bench_name, dtype.name, dim)
        num_bytes = num_ops_in_group * np.prod(shape) * dtype.size
        self.run_op_benchmark(sess, op, burn_iters=5, min_iters=20,
                              name=op_name, mbs=num_bytes / 1e6)

  def _runBenchmarksWithDtype(self, dtype):
    self._RunSingleBenchmark((2**17,), dtype, "1d")
    self._RunSingleBenchmark((2**13, 2**13), dtype, "square_2d")
    self._RunSingleBenchmark((2**5, 2**16), dtype, "rectangle1_2d")
    self._RunSingleBenchmark((2**16, 2**5), dtype, "rectangle2_2d")
    self._RunSingleBenchmark((2**8, 2**8, 2**8), dtype, "cube_3d")
    self._RunSingleBenchmark((2**16, 2**5, 2**5), dtype, "rectangle1_3d")
    self._RunSingleBenchmark((2**5, 2**16, 2**5), dtype, "rectangle2_3d")
    self._RunSingleBenchmark((2**5, 2**5, 2**16), dtype, "rectangle3_3d")

  def benchmarkFloat(self):
    self._runBenchmarksWithDtype(dtypes.float32)

  def benchmarkDouble(self):
    self._runBenchmarksWithDtype(dtypes.float64)

  def benchmarkHalf(self):
    self._runBenchmarksWithDtype(dtypes.float16)
if __name__ == "__main__":
test.main()