Add a bfloat16 sum reducer that uses float32 accumulators. Fix existing tests.
The majority of the changes are from PR #38630 ([Intel MKL] Enable BF16 Softmax/SoftmaxGrad) which was reverted because of test failures. PiperOrigin-RevId: 314152011 Change-Id: Ib50e1ae90016c05a6fc62b8d21ce7b3f34d28833
This commit is contained in:
parent
3f91d43368
commit
cbac31d59d
@ -58,7 +58,9 @@ namespace tensorflow {
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);
|
||||
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);
|
||||
|
||||
TF_CALL_bfloat16(REGISTER_CPU);
|
||||
#undef REGISTER_CPU
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
|
||||
// Functor definitions for Reduction ops, must be compilable by nvcc.
|
||||
|
||||
#include <iostream>
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
@ -58,6 +57,29 @@ struct ReduceEigenImpl {
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization for BF16 Reducer to fix accuracy.
|
||||
// TODO: All BF16 reducers should have specializations to fix accuracy.
|
||||
#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType) \
|
||||
template <typename Device, typename OUT_T, typename IN_T, \
|
||||
typename ReductionAxes> \
|
||||
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes, \
|
||||
Reducer<ScalarType>> { \
|
||||
void operator()(const Device& d, OUT_T out, IN_T in, \
|
||||
const ReductionAxes& reduction_axes, \
|
||||
const Reducer<ScalarType>& reducer) { \
|
||||
static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
|
||||
""); \
|
||||
Reducer<IntermediateType> intermediate_reducer; \
|
||||
auto in_as_intermediate = in.template cast<IntermediateType>(); \
|
||||
out.device(d) = \
|
||||
in_as_intermediate.reduce(reduction_axes, intermediate_reducer) \
|
||||
.template cast<ScalarType>(); \
|
||||
} \
|
||||
};
|
||||
|
||||
CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
|
||||
#undef CASTING_SPECIALIZATION
|
||||
|
||||
template <typename Device, typename OUT_T, typename IN_T,
|
||||
typename ReductionAxes, typename Scalar>
|
||||
struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
|
||||
|
@ -82,6 +82,7 @@ class SoftmaxOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
SoftmaxOp<CPUDevice, T>);
|
||||
TF_CALL_bfloat16(REGISTER_CPU);
|
||||
TF_CALL_half(REGISTER_CPU);
|
||||
TF_CALL_float(REGISTER_CPU);
|
||||
TF_CALL_double(REGISTER_CPU);
|
||||
|
@ -31,7 +31,7 @@ Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// Ret val defs
|
||||
{"grad_x: T"},
|
||||
// Attr defs
|
||||
{{"T: {float, double}"}},
|
||||
{{"T: {float, double, bfloat16}"}},
|
||||
// Nodes
|
||||
// Based on _SoftmaxGrad in nn_grad.py.
|
||||
{
|
||||
|
@ -44,6 +44,16 @@ class ReduceTest(test_util.TensorFlowTestCase):
|
||||
y_tf = self.evaluate(math_ops.reduce_sum(x))
|
||||
self.assertEqual(y_tf, 21)
|
||||
|
||||
def testReduceExtendType(self):
|
||||
in_f32 = np.random.randn(1000, 1000).astype(np.float32)
|
||||
in_bf16 = math_ops.cast(in_f32, dtypes.bfloat16)
|
||||
|
||||
out_f32 = self.evaluate(math_ops.reduce_sum(in_f32))
|
||||
out_bf16 = self.evaluate(math_ops.reduce_sum(in_bf16))
|
||||
expected = math_ops.cast(out_f32, dtypes.bfloat16)
|
||||
|
||||
self.assertAllClose(out_bf16, expected, 1e-3)
|
||||
|
||||
def testReduceExplicitAxes(self):
|
||||
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
|
||||
with test_util.device(use_gpu=True):
|
||||
|
@ -20,12 +20,14 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import nn_impl
|
||||
@ -33,6 +35,29 @@ from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SoftmaxOpTest(test.TestCase):
|
||||
|
||||
# This test is for bfloat16, but the type has a problem with compute_gradient.
|
||||
# TODO(penporn): Change the data type back to bfloat16 once b/157773623 is
|
||||
# fixed. (compute_gradient internally converts bfloat16 to float32 for
|
||||
# calculation anyway.)
|
||||
def testSoftmaxGradGradExtendType(self):
|
||||
with self.cached_session():
|
||||
|
||||
def f(x):
|
||||
assert x.dtype == dtypes.float32
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = nn_ops.softmax(x)
|
||||
return tape.gradient(y, x)
|
||||
|
||||
x = constant_op.constant([[-2, -1, 1, 3], [5, 7, 8, 9]],
|
||||
dtype=dtypes.float32)
|
||||
error = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
self.assertLess(error, 1e-4)
|
||||
|
||||
|
||||
class Relu6OpTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
|
@ -130,6 +130,18 @@ class SoftmaxTest(test_lib.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps)
|
||||
self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps)
|
||||
|
||||
def testSoftmaxExtendType(self):
|
||||
x_shape = [5, 10]
|
||||
x_np = np.random.randn(*x_shape).astype(np.float32)
|
||||
|
||||
x_f32_tf = constant_op.constant(x_np)
|
||||
x_bf16_tf = math_ops.cast(x_f32_tf, dtypes.bfloat16)
|
||||
y_f32_tf = self.evaluate(nn_ops.softmax(x_f32_tf))
|
||||
y_bf16_tf = self.evaluate(nn_ops.softmax(x_bf16_tf))
|
||||
expected = math_ops.cast(y_f32_tf, dtypes.bfloat16)
|
||||
tol = x_shape[1] * 1e-3
|
||||
self.assertAllClose(y_bf16_tf, expected, rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.parameters(((5, 10),), ((2, 3, 4),))
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradient(self, x_shape):
|
||||
|
Loading…
x
Reference in New Issue
Block a user