diff --git a/tensorflow/core/kernels/mkl_tmp_bf16_ops.cc b/tensorflow/core/kernels/mkl_tmp_bf16_ops.cc
index 9b2d09fb827..ed5fec677e8 100644
--- a/tensorflow/core/kernels/mkl_tmp_bf16_ops.cc
+++ b/tensorflow/core/kernels/mkl_tmp_bf16_ops.cc
@@ -58,7 +58,9 @@ namespace tensorflow {
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);  \
   REGISTER_KERNEL_BUILDER(                                                    \
-      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);
+      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp); \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);
 
 TF_CALL_bfloat16(REGISTER_CPU);
 #undef REGISTER_CPU
diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h
index 46d8051fff1..d1066d6556b 100644
--- a/tensorflow/core/kernels/reduction_ops.h
+++ b/tensorflow/core/kernels/reduction_ops.h
@@ -18,7 +18,6 @@ limitations under the License.
 
 // Functor definitions for Reduction ops, must be compilable by nvcc.
 
-#include
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
@@ -58,6 +57,29 @@ struct ReduceEigenImpl {
   }
 };
 
+// Specialization for BF16 Reducer to fix accuracy.
+// TODO: All BF16 reducers should have specializations to fix accuracy.
+#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType)        \
+  template <typename Device, typename OUT_T, typename IN_T,                  \
+            typename ReductionAxes>                                          \
+  struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,                 \
+                         Reducer<ScalarType>> {                              \
+    void operator()(const Device& d, OUT_T out, IN_T in,                     \
+                    const ReductionAxes& reduction_axes,                     \
+                    const Reducer<ScalarType>& reducer) {                    \
+      static_assert(std::is_same<ScalarType, typename OUT_T::Scalar>::value, \
+                    "");                                                     \
+      Reducer<IntermediateType> intermediate_reducer;                        \
+      auto in_as_intermediate = in.template cast<IntermediateType>();        \
+      out.device(d) =                                                        \
+          in_as_intermediate.reduce(reduction_axes, intermediate_reducer)    \
+              .template cast<ScalarType>();                                  \
+    }                                                                        \
+  };
+
+CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float);
+#undef CASTING_SPECIALIZATION
+
 template <typename Device, typename OUT_T, typename IN_T,
           typename ReductionAxes, typename Scalar>
 struct ReduceEigenImpl<Device, OUT_T, IN_T, ReductionAxes,
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc
--- a/tensorflow/core/kernels/softmax_op.cc
+++ b/tensorflow/core/kernels/softmax_op.cc
   REGISTER_KERNEL_BUILDER(                                       \
       Name("Softmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       SoftmaxOp<CPUDevice, T>);
+TF_CALL_bfloat16(REGISTER_CPU);
 TF_CALL_half(REGISTER_CPU);
 TF_CALL_float(REGISTER_CPU);
 TF_CALL_double(REGISTER_CPU);
diff --git a/tensorflow/core/ops/nn_grad.cc b/tensorflow/core/ops/nn_grad.cc
index 7beaf57c10b..30c0c96c6c8 100644
--- a/tensorflow/core/ops/nn_grad.cc
+++ b/tensorflow/core/ops/nn_grad.cc
@@ -31,7 +31,7 @@ Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) {
       // Ret val defs
       {"grad_x: T"},
       // Attr defs
-      {{"T: {float, double}"}},
+      {{"T: {float, double, bfloat16}"}},
       // Nodes
       // Based on _SoftmaxGrad in nn_grad.py.
       {
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index afa1dbdbaf7..ca34a0012f1 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -44,6 +44,16 @@ class ReduceTest(test_util.TensorFlowTestCase):
     y_tf = self.evaluate(math_ops.reduce_sum(x))
     self.assertEqual(y_tf, 21)
 
+  def testReduceExtendType(self):
+    in_f32 = np.random.randn(1000, 1000).astype(np.float32)
+    in_bf16 = math_ops.cast(in_f32, dtypes.bfloat16)
+
+    out_f32 = self.evaluate(math_ops.reduce_sum(in_f32))
+    out_bf16 = self.evaluate(math_ops.reduce_sum(in_bf16))
+    expected = math_ops.cast(out_f32, dtypes.bfloat16)
+
+    self.assertAllClose(out_bf16, expected, 1e-3)
+
   def testReduceExplicitAxes(self):
     x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
     with test_util.device(use_gpu=True):
diff --git a/tensorflow/python/ops/nn_grad_test.py b/tensorflow/python/ops/nn_grad_test.py
index 9da56cb7200..490451c16c9 100644
--- a/tensorflow/python/ops/nn_grad_test.py
+++ b/tensorflow/python/ops/nn_grad_test.py
@@ -20,12 +20,14 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import nn_impl
@@ -33,6 +35,29 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 
 
+class SoftmaxOpTest(test.TestCase):
+
+  # This test is for bfloat16, but the type has a problem with compute_gradient.
+  # TODO(penporn): Change the data type back to bfloat16 once b/157773623 is
+  # fixed. (compute_gradient internally converts bfloat16 to float32 for
+  # calculation anyway.)
+  def testSoftmaxGradGradExtendType(self):
+    with self.cached_session():
+
+      def f(x):
+        assert x.dtype == dtypes.float32
+        with backprop.GradientTape() as tape:
+          tape.watch(x)
+          y = nn_ops.softmax(x)
+        return tape.gradient(y, x)
+
+      x = constant_op.constant([[-2, -1, 1, 3], [5, 7, 8, 9]],
+                               dtype=dtypes.float32)
+      error = gradient_checker_v2.max_error(
+          *gradient_checker_v2.compute_gradient(f, [x]))
+      self.assertLess(error, 1e-4)
+
+
 class Relu6OpTest(test.TestCase):
 
   @test_util.run_deprecated_v1
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 911eca9fbae..e672018bcf6 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -130,6 +130,18 @@ class SoftmaxTest(test_lib.TestCase, parameterized.TestCase):
     self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps)
     self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps)
 
+  def testSoftmaxExtendType(self):
+    x_shape = [5, 10]
+    x_np = np.random.randn(*x_shape).astype(np.float32)
+
+    x_f32_tf = constant_op.constant(x_np)
+    x_bf16_tf = math_ops.cast(x_f32_tf, dtypes.bfloat16)
+    y_f32_tf = self.evaluate(nn_ops.softmax(x_f32_tf))
+    y_bf16_tf = self.evaluate(nn_ops.softmax(x_bf16_tf))
+    expected = math_ops.cast(y_f32_tf, dtypes.bfloat16)
+    tol = x_shape[1] * 1e-3
+    self.assertAllClose(y_bf16_tf, expected, rtol=tol, atol=tol)
+
   @parameterized.parameters(((5, 10),), ((2, 3, 4),))
   @test_util.run_deprecated_v1
   def testGradient(self, x_shape):
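
For context, not part of the patch: the SumReducer specialization above is equivalent to accumulating a bfloat16 reduction in float32 and casting the result back, and the kernel and gradient registrations let Softmax dispatch for bfloat16 inputs on CPU. A minimal Python sketch of the expected user-visible behavior, assuming a TensorFlow build that includes this change:

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.randn(1000, 1000).astype(np.float32))
x_bf16 = tf.cast(x, tf.bfloat16)

# With this change, reduce_sum on a bfloat16 tensor accumulates in float32
# internally, so it should closely match an explicit float32 reduction that is
# cast back to bfloat16 afterwards.
reference = tf.cast(tf.reduce_sum(tf.cast(x_bf16, tf.float32)), tf.bfloat16)
total = tf.reduce_sum(x_bf16)

# Softmax (and its registered gradient) now also accepts bfloat16 on CPU.
probs = tf.nn.softmax(x_bf16)

print(total.numpy(), reference.numpy(), probs.dtype)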