Add missing softmax kernels
We were missing a bfloat16 kernel for LogSoftMax and a double kernel on the GPU.

PiperOrigin-RevId: 314646942
Change-Id: Ifb235609c129f373d4ba30b698f8d906596627fe
parent a628c339c5
commit bc27e470f3
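For context, a minimal sketch (not part of the commit) of what the newly registered kernels allow, assuming a TensorFlow 2.x build that includes this change; the float64 case additionally assumes a GPU is available:

# bfloat16 LogSoftmax on the CPU and float64 LogSoftmax on the GPU previously
# had no registered kernel; with this change both dispatch normally.
import tensorflow as tf

x = tf.constant([[1., 1., 1., 1.], [1., 2., 3., 4.]])

# bfloat16 LogSoftmax (CPU).
y_bf16 = tf.nn.log_softmax(tf.cast(x, tf.bfloat16), axis=-1)

# float64 (double) LogSoftmax on the GPU.
if tf.config.list_physical_devices("GPU"):
  with tf.device("/GPU:0"):
    y_f64 = tf.nn.log_softmax(tf.cast(x, tf.float64), axis=-1)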
@@ -82,19 +82,16 @@ class SoftmaxOp : public OpKernel {
   REGISTER_KERNEL_BUILDER(                                       \
       Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       SoftmaxOp<CPUDevice, T>);
-TF_CALL_bfloat16(REGISTER_CPU);
-TF_CALL_half(REGISTER_CPU);
-TF_CALL_float(REGISTER_CPU);
-TF_CALL_double(REGISTER_CPU);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU);
 
 #undef REGISTER_CPU
 #define REGISTER_CPU(T)                                             \
   REGISTER_KERNEL_BUILDER(                                          \
       Name("LogSoftmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       SoftmaxOp<CPUDevice, T>);
-TF_CALL_half(REGISTER_CPU);
-TF_CALL_float(REGISTER_CPU);
-TF_CALL_double(REGISTER_CPU);
+TF_CALL_FLOAT_TYPES(REGISTER_CPU);
 #undef REGISTER_CPU
 
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER_KERNEL_BUILDER(
|
@@ -281,21 +281,20 @@ class SoftmaxOpGPU : public OpKernel {
   bool log_;
 };
 
-REGISTER_KERNEL_BUILDER(
-    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    SoftmaxOpGPU<Eigen::half>);
-REGISTER_KERNEL_BUILDER(
-    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    SoftmaxOpGPU<float>);
-REGISTER_KERNEL_BUILDER(
-    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<double>("T"),
-    SoftmaxOpGPU<double>);
-REGISTER_KERNEL_BUILDER(
-    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    SoftmaxOpGPU<Eigen::half>);
-REGISTER_KERNEL_BUILDER(
-    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    SoftmaxOpGPU<float>);
+#define REGISTER_GPU(T)                                          \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("Softmax").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      SoftmaxOpGPU<T>);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+
+#undef REGISTER_GPU
+#define REGISTER_GPU(T)                                             \
+  REGISTER_KERNEL_BUILDER(                                          \
+      Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      SoftmaxOpGPU<T>);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+
+#undef REGISTER_GPU
 
 }  // end namespace tensorflow
 
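To sanity-check the new GPU registration from Python, one could compare against a NumPy reference, much like the testDoubleGPU case added to the test below. This is only an illustrative sketch; the np_log_softmax helper is local to the example, not a TensorFlow API.

# Sketch: float64 log-softmax on the GPU checked against a NumPy reference.
import numpy as np
import tensorflow as tf

def np_log_softmax(x, axis=-1):
  # Subtract the row max for numerical stability before exponentiating.
  shifted = x - np.max(x, axis=axis, keepdims=True)
  return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

data = np.random.rand(8, 16).astype(np.float64)
if tf.config.list_physical_devices("GPU"):
  with tf.device("/GPU:0"):
    out = tf.nn.log_softmax(tf.constant(data), axis=-1).numpy()
  np.testing.assert_allclose(out, np_log_softmax(data), rtol=1e-6, atol=1e-12)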
|
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
@@ -55,13 +56,21 @@ class SoftmaxTest(test.TestCase):
       res = res.astype(np.float16)
     return res
 
-  def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
+  def _testSoftmax(self,
+                   np_features,
+                   dim=-1,
+                   log=False,
+                   dtype=None,
+                   use_gpu=False):
     # A previous version of the code checked the op name rather than the op type
     # to distinguish between log and non-log.  Use an arbitrary name to catch
     # this bug in future.
     name = "arbitrary"
     np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
     with self.cached_session(use_gpu=use_gpu):
+      if dtype is not None:
+        np_features = math_ops.cast(np_features, dtype=dtype)
+
       if log:
         tf_softmax = nn_ops.log_softmax(np_features, axis=dim, name=name)
       else:
@@ -69,15 +78,15 @@ class SoftmaxTest(test.TestCase):
       out = self.evaluate(tf_softmax)
       self.assertAllCloseAccordingToType(np_softmax, out)
       self.assertShapeEqual(np_softmax, tf_softmax)
-      if not log:
+      if not log and dtype is None:
         # Bonus check: the softmaxes should add to one in dimension dim.
         sum_along_dim = np.sum(out, axis=dim)
         self.assertAllCloseAccordingToType(
             np.ones(sum_along_dim.shape), sum_along_dim)
 
-  def _testAll(self, features):
-    self._testSoftmax(features, use_gpu=True)
-    self._testSoftmax(features, log=True, use_gpu=True)
+  def _testAll(self, features, dtype=None):
+    self._testSoftmax(features, dtype=dtype, use_gpu=True)
+    self._testSoftmax(features, dtype=dtype, log=True, use_gpu=True)
     self._testOverflow(use_gpu=True)
 
   def testNpSoftmax(self):
@@ -158,6 +167,22 @@ class SoftmaxTest(test.TestCase):
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64))
     self._testOverflow()
 
+  @unittest.skipUnless(test.is_built_with_gpu_support(),
+                       "Test only applicable when running on GPUs")
+  def testDoubleGPU(self):
+    if test.is_gpu_available(cuda_only=True):
+      rows = [2**x + np.random.randint(0, 16) for x in range(1, 4)]
+      cols = [2**x + np.random.randint(0, 16) for x in range(1, 4)]
+      for row, col in zip(rows, cols):
+        logging.info("Testing softmax float dtype in shape [%d, %d]", row, col)
+        data = np.random.rand(row, col)
+        self._testAll(data.astype(np.float64))
+
+  def testBfloat16(self):
+    self._testAll(
+        np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32),
+        dtype=dtypes.bfloat16)
+
   def test1DTensorAsInput(self):
     self._testSoftmax(
         np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)
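A note on the test changes above: once the inputs are cast to a reduced-precision dtype such as bfloat16, the computed softmax only sums to one approximately, which is why the bonus check is now skipped when dtype is set and why assertAllCloseAccordingToType is used. A rough standalone illustration, assuming TF 2.x eager mode (the tolerance below is only indicative):

# Sketch: bfloat16 softmax agrees with a float32 reference only to a few
# significant digits, and its rows sum only approximately to 1.0.
import numpy as np
import tensorflow as tf

x = np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]], dtype=np.float32)
ref = tf.nn.softmax(tf.constant(x), axis=-1).numpy()

out = tf.cast(tf.nn.softmax(tf.cast(x, tf.bfloat16), axis=-1), tf.float32).numpy()

np.testing.assert_allclose(out, ref, atol=1e-2)  # loose tolerance for bfloat16
print(out.sum(axis=-1))  # close to, but not exactly, [1., 1.]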