From 67a155f782aa86511bff264b89f3a8b8012c3c00 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 27 May 2020 14:10:58 -0700 Subject: [PATCH] Add bfloat16 support for SparseSegmentMean*/SparseSegmentSqrtN* PiperOrigin-RevId: 313461080 Change-Id: Ibf4ed3b6531231e797358940e584471f2c682848 --- .../core/kernels/segment_reduction_ops_impl.h | 110 +++++++++++++----- .../kernels/segment_reduction_ops_impl_5.cc | 2 + tensorflow/core/ops/math_ops.cc | 8 +- 3 files changed, 90 insertions(+), 30 deletions(-) diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h index 8954dcd4681..6c3fad668ae 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl.h +++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h @@ -508,6 +508,12 @@ class SparseSegmentReductionOpBase : public OpKernel { errors::InvalidArgument("segment ids must be >= 0")); auto output_flat = output->flat_outer_dims(); + Tensor temp; + if (input.dtype() == DT_BFLOAT16) { + temp = tensorflow::Tensor(DT_FLOAT, output_shape); + } + auto temp_flat = temp.flat_outer_dims(); + int64 start = 0, end = 1; // Index from which the output is not initialized. SegmentId uninitialized_index = 0; @@ -546,8 +552,9 @@ class SparseSegmentReductionOpBase : public OpKernel { } auto out = output_flat.template chip<0>(out_index); - const int bad_offset = - Reduce(input_flat, indices_vec, start, end - start, out); + auto temp = temp_flat.template chip<0>(out_index); + const int bad_offset = Reduce(input_flat, indices_vec, start, + end - start, out, temp); OP_REQUIRES(context, bad_offset < 0, errors::InvalidArgument( "Bad: indices[", start + bad_offset, @@ -572,40 +579,89 @@ class SparseSegmentReductionOpBase : public OpKernel { } private: - int64 Reduce(const typename TTypes::ConstMatrix& input_flat, - const typename TTypes::ConstVec& indices_vec, int64 start, - int64 num, - Eigen::TensorChippingOp<0, typename TTypes::Matrix> out) { + template + using EnableIfBfloat16 = + typename std::enable_if::value, int>::type; + template + using EnableIfNotBfloat16 = + typename std::enable_if::value, int>::type; + + template = 0> + EIGEN_ALWAYS_INLINE auto fetch_val( + const typename TTypes::ConstMatrix& input_flat, Tindex index) { + return input_flat.template chip<0>(index); + } + + template = 0> + EIGEN_ALWAYS_INLINE auto fetch_val( + const typename TTypes::ConstMatrix& input_flat, Tindex index) { + return input_flat.template chip<0>(index).template cast(); + } + + template + EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64 num) { + Tout m(1); + if (is_mean_ && (num < 10)) { + m = Tout(num); + } + if (is_sqrtn_ && (num < 10)) { + m = Tout(sqrt(num)); + } + return Tout(1) / m; + } + + template = 0> + int64 Reduce( + const typename TTypes::ConstMatrix& input_flat, + const typename TTypes::ConstVec& indices_vec, int64 start, + int64 num, Eigen::TensorChippingOp<0, typename TTypes::Matrix> out, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> temp) { + return ReduceImpl(input_flat, indices_vec, start, num, + out, get_scaling_factor(num)); + } + + template = 0> + int64 Reduce( + const typename TTypes::ConstMatrix& input_flat, + const typename TTypes::ConstVec& indices_vec, int64 start, + int64 num, Eigen::TensorChippingOp<0, typename TTypes::Matrix> out, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> temp) { + int64 res = + ReduceImpl(input_flat, indices_vec, start, num, + temp, get_scaling_factor(num)); + out = temp.template cast(); + return res; + } + + template + int64 ReduceImpl( + const typename TTypes::ConstMatrix& input_flat, + const typename TTypes::ConstVec& indices_vec, int64 start, + int64 num, Eigen::TensorChippingOp<0, typename TTypes::Matrix> out, + const Tout scaling_factor) { #define INDEX(n, i) \ const auto index##n = indices_vec(start + (i)); \ if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i); -#define L(n) input_flat.template chip<0>(index##n) +#define L(n) fetch_val(input_flat, index##n) if (num == 1) { INDEX(0, 0); out = L(0); } else { - int64 r = num % 8; - T m(1); - if (is_mean_ && (num < 10)) { - m = T(num); - } - if (is_sqrtn_ && (num < 10)) { - m = T(sqrt(num)); - } + int64 r = num & 7; switch (r) { case 2: { INDEX(0, 0); INDEX(1, 1); - out = (L(0) + L(1)) / m; + out = (L(0) + L(1)) * scaling_factor; break; } case 3: { INDEX(0, 0); INDEX(1, 1); INDEX(2, 2); - out = (L(0) + L(1) + L(2)) / m; + out = (L(0) + L(1) + L(2)) * scaling_factor; break; } case 4: { @@ -613,7 +669,7 @@ class SparseSegmentReductionOpBase : public OpKernel { INDEX(1, 1); INDEX(2, 2); INDEX(3, 3); - out = (L(0) + L(1) + L(2) + L(3)) / m; + out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor; break; } case 5: { @@ -622,7 +678,7 @@ class SparseSegmentReductionOpBase : public OpKernel { INDEX(2, 2); INDEX(3, 3); INDEX(4, 4); - out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m; + out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor; break; } case 6: { @@ -632,7 +688,7 @@ class SparseSegmentReductionOpBase : public OpKernel { INDEX(3, 3); INDEX(4, 4); INDEX(5, 5); - out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m; + out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor; break; } case 7: { @@ -643,7 +699,8 @@ class SparseSegmentReductionOpBase : public OpKernel { INDEX(4, 4); INDEX(5, 5); INDEX(6, 6); - out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m; + out = + (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor; break; } case 0: { @@ -655,7 +712,8 @@ class SparseSegmentReductionOpBase : public OpKernel { INDEX(5, 5); INDEX(6, 6); INDEX(7, 7); - out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m; + out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) * + scaling_factor; r = 8; break; } @@ -669,8 +727,8 @@ class SparseSegmentReductionOpBase : public OpKernel { INDEX(6, 6); INDEX(7, 7); INDEX(8, 8); - out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) / - m; + out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) * + scaling_factor; r = 9; break; } @@ -687,10 +745,10 @@ class SparseSegmentReductionOpBase : public OpKernel { out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7); } if (is_mean_ && num >= 10) { - out = out / static_cast(num); + out = out / static_cast(num); } if (is_sqrtn_ && num >= 10) { - out = out / static_cast(sqrt(num)); + out = out / static_cast(sqrt(num)); } } diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_5.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_5.cc index fee0f818c5e..03a448e52b3 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_5.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_5.cc @@ -64,6 +64,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE); segment_ids_type>); REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float); REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double); +REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16); #undef REGISTER_CPU_SPARSE_KERNELS #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \ @@ -85,6 +86,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double); CPUDevice, type, index_type, segment_ids_type>); REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float); REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double); +REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16); #undef REGISTER_CPU_SPARSE_KERNELS #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \ diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 972d6e27b75..dfc2463915c 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1337,7 +1337,7 @@ REGISTER_OP("SparseSegmentMean") .Input("indices: Tidx") .Input("segment_ids: Tsegmentids") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .Attr("Tidx: {int32, int64} = DT_INT32") .Attr("Tsegmentids: {int32, int64} = DT_INT32") .SetShapeFn(SparseSegmentReductionShapeFn); @@ -1348,7 +1348,7 @@ REGISTER_OP("SparseSegmentMeanWithNumSegments") .Input("segment_ids: Tsegmentids") .Input("num_segments: Tnumsegments") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .Attr("Tidx: {int32, int64} = DT_INT32") .Attr("Tnumsegments: {int32,int64} = DT_INT32") .Attr("Tsegmentids: {int32, int64} = DT_INT32") @@ -1370,7 +1370,7 @@ REGISTER_OP("SparseSegmentSqrtN") .Input("indices: Tidx") .Input("segment_ids: Tsegmentids") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .Attr("Tidx: {int32, int64} = DT_INT32") .Attr("Tsegmentids: {int32, int64} = DT_INT32") .SetShapeFn(SparseSegmentReductionShapeFn); @@ -1381,7 +1381,7 @@ REGISTER_OP("SparseSegmentSqrtNWithNumSegments") .Input("segment_ids: Tsegmentids") .Input("num_segments: Tnumsegments") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {bfloat16, float, double}") .Attr("Tidx: {int32, int64} = DT_INT32") .Attr("Tnumsegments: {int32,int64} = DT_INT32") .Attr("Tsegmentids: {int32, int64} = DT_INT32")