Add a bfloat16 sum reducer that uses float32 accumulators. Fix existing tests.

The majority of the changes are from PR #38630 ([Intel MKL] Enable BF16 Softmax/SoftmaxGrad), which was reverted because of test failures.

PiperOrigin-RevId: 314152011
Change-Id: Ib50e1ae90016c05a6fc62b8d21ce7b3f34d28833
Author:       Penporn Koanantakool
Date:         2020-06-01 10:14:16 -07:00
Committed by: TensorFlower Gardener
Parent:       3f91d43368
Commit:       cbac31d59d

7 changed files with 75 additions and 3 deletions
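
Why float32 accumulators matter for a bfloat16 sum: bfloat16 keeps only 8 significant bits, so at magnitude 256 the spacing between representable values is 2 and a running bfloat16 sum of ones stalls there, while accumulating in float32 and rounding once at the end recovers the exact total. A minimal Python sketch, illustrative only and not part of this commit (assumes eager TensorFlow on CPU):

import tensorflow as tf

ones = tf.ones([1000], dtype=tf.bfloat16)

# Naive accumulation entirely in bfloat16: once the partial sum reaches 256,
# adding 1 rounds back to 256, so the running total stalls.
acc = tf.constant(0.0, dtype=tf.bfloat16)
for v in ones:
  acc += v                                   # bfloat16 + bfloat16 -> bfloat16
print(float(acc))                            # 256.0

# What the new reducer does: accumulate in float32, cast the result once.
total = tf.reduce_sum(tf.cast(ones, tf.float32))
print(float(tf.cast(total, tf.bfloat16)))    # 1000.0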


@@ -58,7 +58,9 @@ namespace tensorflow {
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);  \
   REGISTER_KERNEL_BUILDER(                                                    \
-      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);
+      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp); \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);
 
 TF_CALL_bfloat16(REGISTER_CPU);
 #undef REGISTER_CPU


@@ -18,7 +18,6 @@ limitations under the License.
 
 // Functor definitions for Reduction ops, must be compilable by nvcc.
 
-#include <iostream>
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
@@ -58,6 +57,29 @@ struct ReduceEigenImpl {
   }
 };
 
+// Specialization for BF16 Reducer to fix accuracy.
+// TODO: All BF16 reducers should have specializations to fix accuracy.
+#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType)        \
+  template <typename Device, typename OUT_T, typename IN_T,                  \
+            typename ReductionAxes>                                          \
+  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                 \
+                         Reducer<ScalarType>> {                              \
+    void operator()(const Device& d, OUT_T out, IN_T in,                     \
+                    const ReductionAxes& reduction_axes,                     \
+                    const Reducer<ScalarType>& reducer) {                    \
+      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
+                    "");                                                     \
+      Reducer<IntermediateType> intermediate_reducer;                        \
+      auto in_as_intermediate = in.template cast<IntermediateType>();        \
+      out.device(d) =                                                        \
+          in_as_intermediate.reduce(reduction_axes, intermediate_reducer)    \
+              .template cast<ScalarType>();                                  \
+    }                                                                        \
+  };
+
+CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
+#undef CASTING_SPECIALIZATION
+
 template <typename Device, typename OUT_T, typename IN_T,
           typename ReductionAxes, typename Scalar>
 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,


@@ -82,6 +82,7 @@ class SoftmaxOp : public OpKernel {
   REGISTER_KERNEL_BUILDER(                                       \
       Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       SoftmaxOp<CPUDevice, T>);
+TF_CALL_bfloat16(REGISTER_CPU);
 TF_CALL_half(REGISTER_CPU);
 TF_CALL_float(REGISTER_CPU);
 TF_CALL_double(REGISTER_CPU);
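
With the bfloat16 Softmax kernel registered above, tf.nn.softmax can run directly on a bfloat16 tensor on CPU. A small usage sketch, illustrative only (assumes eager TensorFlow):

import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0]], dtype=tf.bfloat16)
y = tf.nn.softmax(x)   # dispatches to the CPU bfloat16 Softmax kernel
print(y.dtype)         # <dtype: 'bfloat16'>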


@@ -31,7 +31,7 @@ Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) {
       // Ret val defs
       {"grad_x: T"},
       // Attr defs
-      {{"T: {float, double}"}},
+      {{"T: {float, double, bfloat16}"}},
       // Nodes
       // Based on _SoftmaxGrad in nn_grad.py.
       {
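
For reference, the FunctionDef extended here encodes the standard softmax backprop identity grad_x = (grad_y - sum(grad_y * y, axis=-1, keepdims=True)) * y with y = softmax(x), the same computation as _SoftmaxGrad in nn_grad.py cited in the comment above. An equivalent Python sketch (illustrative, not part of the commit):

import tensorflow as tf

def softmax_grad(grad_y, y):
  # Backprop for y = softmax(x): returns dL/dx given dL/dy and y.
  sum_channels = tf.reduce_sum(grad_y * y, axis=-1, keepdims=True)
  return (grad_y - sum_channels) * y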


@@ -44,6 +44,16 @@ class ReduceTest(test_util.TensorFlowTestCase):
       y_tf = self.evaluate(math_ops.reduce_sum(x))
       self.assertEqual(y_tf, 21)
 
+  def testReduceExtendType(self):
+    in_f32 = np.random.randn(1000, 1000).astype(np.float32)
+    in_bf16 = math_ops.cast(in_f32, dtypes.bfloat16)
+
+    out_f32 = self.evaluate(math_ops.reduce_sum(in_f32))
+    out_bf16 = self.evaluate(math_ops.reduce_sum(in_bf16))
+    expected = math_ops.cast(out_f32, dtypes.bfloat16)
+
+    self.assertAllClose(out_bf16, expected, 1e-3)
+
   def testReduceExplicitAxes(self):
     x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
     with test_util.device(use_gpu=True):


@@ -20,12 +20,14 @@ from __future__ import print_function
 
 import numpy as np
 
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import nn_impl
@@ -33,6 +35,29 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 
 
+class SoftmaxOpTest(test.TestCase):
+
+  # This test is for bfloat16, but the type has a problem with compute_gradient.
+  # TODO(penporn): Change the data type back to bfloat16 once b/157773623 is
+  # fixed. (compute_gradient internally converts bfloat16 to float32 for
+  # calculation anyway.)
+  def testSoftmaxGradGradExtendType(self):
+    with self.cached_session():
+
+      def f(x):
+        assert x.dtype == dtypes.float32
+        with backprop.GradientTape() as tape:
+          tape.watch(x)
+          y = nn_ops.softmax(x)
+        return tape.gradient(y, x)
+
+      x = constant_op.constant([[-2, -1, 1, 3], [5, 7, 8, 9]],
+                               dtype=dtypes.float32)
+      error = gradient_checker_v2.max_error(
+          *gradient_checker_v2.compute_gradient(f, [x]))
+      self.assertLess(error, 1e-4)
+
+
 class Relu6OpTest(test.TestCase):
 
   @test_util.run_deprecated_v1
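
If b/157773623 gets fixed, the bfloat16 variant that the TODO above asks for could look roughly like the sketch below (hypothetical: the test name and the looser 1e-2 bound are assumptions; bfloat16 carries only about 2-3 decimal digits of precision):

  def testSoftmaxGradGradBfloat16(self):  # hypothetical follow-up test
    with self.cached_session():

      def f(x):
        assert x.dtype == dtypes.bfloat16
        with backprop.GradientTape() as tape:
          tape.watch(x)
          y = nn_ops.softmax(x)
        return tape.gradient(y, x)

      x = constant_op.constant([[-2, -1, 1, 3], [5, 7, 8, 9]],
                               dtype=dtypes.bfloat16)
      error = gradient_checker_v2.max_error(
          *gradient_checker_v2.compute_gradient(f, [x]))
      self.assertLess(error, 1e-2)  # assumed tolerance for bfloat16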


@@ -130,6 +130,18 @@ class SoftmaxTest(test_lib.TestCase, parameterized.TestCase):
     self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps)
     self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps)
 
+  def testSoftmaxExtendType(self):
+    x_shape = [5, 10]
+    x_np = np.random.randn(*x_shape).astype(np.float32)
+
+    x_f32_tf = constant_op.constant(x_np)
+    x_bf16_tf = math_ops.cast(x_f32_tf, dtypes.bfloat16)
+    y_f32_tf = self.evaluate(nn_ops.softmax(x_f32_tf))
+    y_bf16_tf = self.evaluate(nn_ops.softmax(x_bf16_tf))
+    expected = math_ops.cast(y_f32_tf, dtypes.bfloat16)
+    tol = x_shape[1] * 1e-3
+    self.assertAllClose(y_bf16_tf, expected, rtol=tol, atol=tol)
+
   @parameterized.parameters(((5, 10),), ((2, 3, 4),))
   @test_util.run_deprecated_v1
   def testGradient(self, x_shape):