Add bfloat16 support for SparseSegmentMean*/SparseSegmentSqrtN*

PiperOrigin-RevId: 313461080
Change-Id: Ibf4ed3b6531231e797358940e584471f2c682848
This commit is contained in:
A. Unique TensorFlower 2020-05-27 14:10:58 -07:00 committed by TensorFlower Gardener
parent 1c1d4b619a
commit 67a155f782
3 changed files with 90 additions and 30 deletions

View File

@ -508,6 +508,12 @@ class SparseSegmentReductionOpBase : public OpKernel {
errors::InvalidArgument("segment ids must be >= 0"));
auto output_flat = output->flat_outer_dims<T>();
Tensor temp;
if (input.dtype() == DT_BFLOAT16) {
temp = tensorflow::Tensor(DT_FLOAT, output_shape);
}
auto temp_flat = temp.flat_outer_dims<float>();
int64 start = 0, end = 1;
// Index from which the output is not initialized.
SegmentId uninitialized_index = 0;
@ -546,8 +552,9 @@ class SparseSegmentReductionOpBase : public OpKernel {
}
auto out = output_flat.template chip<0>(out_index);
const int bad_offset =
Reduce(input_flat, indices_vec, start, end - start, out);
auto temp = temp_flat.template chip<0>(out_index);
const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
end - start, out, temp);
OP_REQUIRES(context, bad_offset < 0,
errors::InvalidArgument(
"Bad: indices[", start + bad_offset,
@ -572,40 +579,89 @@ class SparseSegmentReductionOpBase : public OpKernel {
}
private:
int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
int64 num,
Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
// SFINAE selector: alias is well-formed only when Tin is bfloat16.
template <typename Tin>
using EnableIfBfloat16 =
typename std::enable_if<std::is_same<Tin, bfloat16>::value, int>::type;
// SFINAE selector: alias is well-formed only when Tin is NOT bfloat16.
template <typename Tin>
using EnableIfNotBfloat16 =
typename std::enable_if<!std::is_same<Tin, bfloat16>::value, int>::type;
// Returns row `index` of the input matrix. Non-bfloat16 overload: the row
// expression is returned unchanged, so accumulation happens in Tin.
template <typename Tin, typename Tindex, EnableIfNotBfloat16<Tin> = 0>
EIGEN_ALWAYS_INLINE auto fetch_val(
const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
return input_flat.template chip<0>(index);
}
// bfloat16 overload: casts the row to float so the summation downstream
// (in ReduceImpl) accumulates at float precision rather than bfloat16.
template <typename Tin, typename Tindex, EnableIfBfloat16<Tin> = 0>
EIGEN_ALWAYS_INLINE auto fetch_val(
const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
return input_flat.template chip<0>(index).template cast<float>();
}
// Returns the multiplicative factor applied to a small segment's sum:
// 1/num for mean (is_mean_), 1/sqrt(num) for sqrt-n (is_sqrtn_), and 1
// otherwise. Only segments with num < 10 are scaled this way; larger
// segments are divided once at the end of ReduceImpl instead (see the
// `num >= 10` branches there).
// NOTE(review): multiplying by a reciprocal instead of dividing can differ
// from the pre-change `/ m` form by one ULP — presumably acceptable here,
// but worth confirming against the op's numerical tests.
template <typename Tout>
EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64 num) {
Tout m(1);
if (is_mean_ && (num < 10)) {
m = Tout(num);
}
if (is_sqrtn_ && (num < 10)) {
// sqrt of the segment length; relies on an overload accepting int64.
m = Tout(sqrt(num));
}
return Tout(1) / m;
}
// Reduces rows indices_vec[start .. start+num) of input_flat into `out`.
// Non-bfloat16 overload: accumulates directly into `out` in the input type
// Tin. The `temp` row is unused here; it exists only so both overloads
// share one call signature at the call site in Compute.
// Returns the offset of the first out-of-bounds index, or a negative value
// on success (see ReduceImpl's FastBoundsCheck).
template <typename Tin, typename Tindex, EnableIfNotBfloat16<Tin> = 0>
int64 Reduce(
const typename TTypes<Tin>::ConstMatrix& input_flat,
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
out, get_scaling_factor<Tin>(num));
}
// bfloat16 overload of Reduce: accumulates into the float-typed `temp` row
// (fetch_val casts each input row to float), then casts the finished result
// back to bfloat16 into `out`. This keeps the intermediate sum at float
// precision instead of losing bits per-addition in bfloat16.
// Returns the offset of the first out-of-bounds index, or a negative value
// on success.
template <typename Tin, typename Tindex, EnableIfBfloat16<Tin> = 0>
int64 Reduce(
const typename TTypes<Tin>::ConstMatrix& input_flat,
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
int64 res =
ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
temp, get_scaling_factor<float>(num));
// Write-back: narrow the float accumulator to the bfloat16 output row.
out = temp.template cast<bfloat16>();
return res;
}
template <typename Tin, typename Tindex, typename Tout>
int64 ReduceImpl(
const typename TTypes<Tin>::ConstMatrix& input_flat,
const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
const Tout scaling_factor) {
#define INDEX(n, i) \
const auto index##n = indices_vec(start + (i)); \
if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
#define L(n) input_flat.template chip<0>(index##n)
#define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)
if (num == 1) {
INDEX(0, 0);
out = L(0);
} else {
int64 r = num % 8;
T m(1);
if (is_mean_ && (num < 10)) {
m = T(num);
}
if (is_sqrtn_ && (num < 10)) {
m = T(sqrt(num));
}
int64 r = num & 7;
switch (r) {
case 2: {
INDEX(0, 0);
INDEX(1, 1);
out = (L(0) + L(1)) / m;
out = (L(0) + L(1)) * scaling_factor;
break;
}
case 3: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
out = (L(0) + L(1) + L(2)) / m;
out = (L(0) + L(1) + L(2)) * scaling_factor;
break;
}
case 4: {
@ -613,7 +669,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(1, 1);
INDEX(2, 2);
INDEX(3, 3);
out = (L(0) + L(1) + L(2) + L(3)) / m;
out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
break;
}
case 5: {
@ -622,7 +678,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(2, 2);
INDEX(3, 3);
INDEX(4, 4);
out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
break;
}
case 6: {
@ -632,7 +688,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(3, 3);
INDEX(4, 4);
INDEX(5, 5);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
break;
}
case 7: {
@ -643,7 +699,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(4, 4);
INDEX(5, 5);
INDEX(6, 6);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
out =
(L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor;
break;
}
case 0: {
@ -655,7 +712,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(5, 5);
INDEX(6, 6);
INDEX(7, 7);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
scaling_factor;
r = 8;
break;
}
@ -669,8 +727,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
INDEX(6, 6);
INDEX(7, 7);
INDEX(8, 8);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
m;
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) *
scaling_factor;
r = 9;
break;
}
@ -687,10 +745,10 @@ class SparseSegmentReductionOpBase : public OpKernel {
out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
}
if (is_mean_ && num >= 10) {
out = out / static_cast<T>(num);
out = out / static_cast<Tout>(num);
}
if (is_sqrtn_ && num >= 10) {
out = out / static_cast<T>(sqrt(num));
out = out / static_cast<Tout>(sqrt(num));
}
}

View File

@ -64,6 +64,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
@ -85,6 +86,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
CPUDevice, type, index_type, segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \

View File

@ -1337,7 +1337,7 @@ REGISTER_OP("SparseSegmentMean")
.Input("indices: Tidx")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
@ -1348,7 +1348,7 @@ REGISTER_OP("SparseSegmentMeanWithNumSegments")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
@ -1370,7 +1370,7 @@ REGISTER_OP("SparseSegmentSqrtN")
.Input("indices: Tidx")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
@ -1381,7 +1381,7 @@ REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")