Add missing softmax kernels

We were missing a bfloat16 kernel for LogSoftmax on the CPU and a double
kernel for LogSoftmax on the GPU.
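
A quick sanity check of what this enables (illustrative only, not part of
the change; the float64 case assumes a TensorFlow build with GPU support):

  import tensorflow as tf

  x = tf.constant([[1., 2., 3., 4.]])

  # bfloat16 LogSoftmax now has a registered CPU kernel.
  print(tf.nn.log_softmax(tf.cast(x, tf.bfloat16)))

  # float64 LogSoftmax now has a registered GPU kernel.
  if tf.config.list_physical_devices("GPU"):
    with tf.device("/GPU:0"):
      print(tf.nn.log_softmax(tf.cast(x, tf.float64)))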

PiperOrigin-RevId: 314646942
Change-Id: Ifb235609c129f373d4ba30b698f8d906596627fe
Gaurav Jain 2020-06-03 18:53:52 -07:00 committed by TensorFlower Gardener
parent a628c339c5
commit bc27e470f3
3 changed files with 48 additions and 27 deletions


@@ -82,19 +82,16 @@ class SoftmaxOp : public OpKernel {
   REGISTER_KERNEL_BUILDER( \
       Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       SoftmaxOp<CPUDevice, T>);
-TF_CALL_bfloat16(REGISTER_CPU);
-TF_CALL_half(REGISTER_CPU);
-TF_CALL_float(REGISTER_CPU);
-TF_CALL_double(REGISTER_CPU);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU);
 #undef REGISTER_CPU
 #define REGISTER_CPU(T) \
   REGISTER_KERNEL_BUILDER( \
       Name("LogSoftmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       SoftmaxOp<CPUDevice, T>);
-TF_CALL_half(REGISTER_CPU);
-TF_CALL_float(REGISTER_CPU);
-TF_CALL_double(REGISTER_CPU);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU);
 #undef REGISTER_CPU
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER_KERNEL_BUILDER(

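TF_CALL_FLOAT_TYPES stamps out the registration above for half, bfloat16,
float, and double in one macro call, which is how LogSoftmax picks up the
previously missing bfloat16 CPU kernel. A standalone sketch of what the
updated Python test exercises (illustrative; the loose tolerance below is
chosen for bfloat16's reduced precision, not taken from the test):

  import numpy as np
  import tensorflow as tf

  features = np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]], dtype=np.float32)

  ref = tf.nn.log_softmax(features).numpy()                # float32 reference
  low = tf.nn.log_softmax(tf.cast(features, tf.bfloat16))  # bfloat16 kernel

  # bfloat16 keeps only ~3 significant decimal digits, so compare loosely,
  # in the same spirit as assertAllCloseAccordingToType in the test.
  np.testing.assert_allclose(ref, low.numpy().astype(np.float32),
                             rtol=1e-2, atol=1e-2)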

@@ -281,21 +281,20 @@ class SoftmaxOpGPU : public OpKernel {
   bool log_;
 };
-REGISTER_KERNEL_BUILDER(
-    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    SoftmaxOpGPU<Eigen::half>);
-REGISTER_KERNEL_BUILDER(
-    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    SoftmaxOpGPU<float>);
-REGISTER_KERNEL_BUILDER(
-    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<double>("T"),
-    SoftmaxOpGPU<double>);
-REGISTER_KERNEL_BUILDER(
-    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    SoftmaxOpGPU<Eigen::half>);
-REGISTER_KERNEL_BUILDER(
-    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    SoftmaxOpGPU<float>);
+#define REGISTER_GPU(T) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("Softmax").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      SoftmaxOpGPU<T>);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+#undef REGISTER_GPU
+#define REGISTER_GPU(T) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      SoftmaxOpGPU<T>);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+#undef REGISTER_GPU
 }  // end namespace tensorflow

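TF_CALL_GPU_NUMBER_TYPES covers half, float, and double, so rewriting the
registrations as a macro adds the double-precision LogSoftmax GPU kernel
that was missing (Softmax already had one). A hedged check, assuming a CUDA
build of TensorFlow with at least one visible GPU:

  import numpy as np
  import tensorflow as tf

  if tf.config.list_physical_devices("GPU"):
    x = tf.constant(np.random.rand(8, 16), dtype=tf.float64)
    with tf.device("/GPU:0"):
      probs = tf.nn.softmax(x)     # double Softmax kernel (pre-existing)
      logp = tf.nn.log_softmax(x)  # double LogSoftmax kernel (added here)
    assert logp.dtype == tf.float64
    # In float64 the rows should sum to one to tight tolerance.
    np.testing.assert_allclose(tf.reduce_sum(probs, axis=-1).numpy(),
                               np.ones(8), atol=1e-12)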

@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
@@ -55,13 +56,21 @@ class SoftmaxTest(test.TestCase):
       res = res.astype(np.float16)
     return res

-  def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
+  def _testSoftmax(self,
+                   np_features,
+                   dim=-1,
+                   log=False,
+                   dtype=None,
+                   use_gpu=False):
     # A previous version of the code checked the op name rather than the op type
     # to distinguish between log and non-log. Use an arbitrary name to catch
     # this bug in future.
     name = "arbitrary"
     np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
     with self.cached_session(use_gpu=use_gpu):
+      if dtype is not None:
+        np_features = math_ops.cast(np_features, dtype=dtype)
       if log:
         tf_softmax = nn_ops.log_softmax(np_features, axis=dim, name=name)
       else:
@@ -69,15 +78,15 @@ class SoftmaxTest(test.TestCase):
       out = self.evaluate(tf_softmax)
       self.assertAllCloseAccordingToType(np_softmax, out)
       self.assertShapeEqual(np_softmax, tf_softmax)
-      if not log:
+      if not log and dtype is None:
         # Bonus check: the softmaxes should add to one in dimension dim.
         sum_along_dim = np.sum(out, axis=dim)
         self.assertAllCloseAccordingToType(
             np.ones(sum_along_dim.shape), sum_along_dim)

-  def _testAll(self, features):
-    self._testSoftmax(features, use_gpu=True)
-    self._testSoftmax(features, log=True, use_gpu=True)
+  def _testAll(self, features, dtype=None):
+    self._testSoftmax(features, dtype=dtype, use_gpu=True)
+    self._testSoftmax(features, dtype=dtype, log=True, use_gpu=True)
     self._testOverflow(use_gpu=True)

   def testNpSoftmax(self):
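
The "dtype is None" guard above is presumably there because the bonus
row-sum check is only meaningful at full precision: once the inputs are
cast to a low-precision type such as bfloat16, the summed probabilities
can drift noticeably from one. A small illustration (not from the test):

  import numpy as np
  import tensorflow as tf

  x = tf.constant([[1., 2., 3., 4.]], dtype=tf.bfloat16)
  row_sum = np.sum(tf.nn.softmax(x).numpy().astype(np.float32), axis=-1)
  print(row_sum)  # close to 1, but possibly off by up to ~1e-2 in bfloat16
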
@@ -158,6 +167,22 @@ class SoftmaxTest(test.TestCase):
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64))
     self._testOverflow()

+  @unittest.skipUnless(test.is_built_with_gpu_support(),
+                       "Test only applicable when running on GPUs")
+  def testDoubleGPU(self):
+    if test.is_gpu_available(cuda_only=True):
+      rows = [2**x + np.random.randint(0, 16) for x in range(1, 4)]
+      cols = [2**x + np.random.randint(0, 16) for x in range(1, 4)]
+      for row, col in zip(rows, cols):
+        logging.info("Testing softmax double dtype in shape [%d, %d]", row, col)
+        data = np.random.rand(row, col)
+        self._testAll(data.astype(np.float64))
+
+  def testBfloat16(self):
+    self._testAll(
+        np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32),
+        dtype=dtypes.bfloat16)
+
   def test1DTensorAsInput(self):
     self._testSoftmax(
         np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)