Add bfloat16 support for SparseSegmentMean*/SparseSegmentSqrtN*

PiperOrigin-RevId: 313461080
Change-Id: Ibf4ed3b6531231e797358940e584471f2c682848

commit 67a155f782 (parent 1c1d4b619a)
@@ -508,6 +508,12 @@ class SparseSegmentReductionOpBase : public OpKernel {
                 errors::InvalidArgument("segment ids must be >= 0"));
 
     auto output_flat = output->flat_outer_dims<T>();
+    Tensor temp;
+    if (input.dtype() == DT_BFLOAT16) {
+      temp = tensorflow::Tensor(DT_FLOAT, output_shape);
+    }
+    auto temp_flat = temp.flat_outer_dims<float>();
+
     int64 start = 0, end = 1;
     // Index from which the output is not initialized.
     SegmentId uninitialized_index = 0;
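Note on the float `temp` tensor added above: bfloat16 keeps only 8 mantissa bits, so accumulating a long sum directly in bfloat16 compounds rounding error at every step; summing in float and casting once at the end pays a single rounding step. A standalone sketch of that effect (not TF code; bfloat16 is emulated here by truncating a float to its top 16 bits):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Emulate bfloat16 rounding by keeping only the top 16 bits of a float.
static float to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;  // drop the low 16 mantissa bits
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const int n = 1000;
  float bf16_sum = 0.0f;   // accumulate in emulated bfloat16
  float float_sum = 0.0f;  // accumulate in float, round once at the end
  for (int i = 0; i < n; ++i) {
    const float x = to_bf16(0.001f);
    bf16_sum = to_bf16(bf16_sum + x);  // the bf16 sum stalls once it grows
    float_sum += x;
  }
  printf("bfloat16 accumulation:         %f\n", bf16_sum);
  printf("float accumulation, cast once: %f\n", to_bf16(float_sum));
}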
@@ -546,8 +552,9 @@ class SparseSegmentReductionOpBase : public OpKernel {
       }
 
       auto out = output_flat.template chip<0>(out_index);
-      const int bad_offset =
-          Reduce(input_flat, indices_vec, start, end - start, out);
+      auto temp = temp_flat.template chip<0>(out_index);
+      const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
+                                              end - start, out, temp);
       OP_REQUIRES(context, bad_offset < 0,
                   errors::InvalidArgument(
                       "Bad: indices[", start + bad_offset,
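The `out` and `temp` values handed to `Reduce` above are Eigen tensor "chips": writable views of one row of a rank-2 tensor, so the kernel can reduce into the float scratch row and later cast it into the output row. A minimal sketch of that mechanism, assuming only Eigen's unsupported Tensor module is available:

#include <cstdio>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> temp(3, 4);
  temp.setZero();
  auto row = temp.chip<0>(1);  // writable view of row 1
  row.setConstant(2.5f);       // writes through to `temp`
  printf("%f\n", temp(1, 2));  // prints 2.500000
}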
@@ -572,40 +579,89 @@ class SparseSegmentReductionOpBase : public OpKernel {
   }
 
  private:
-  int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
-               const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
-               int64 num,
-               Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
+  template <typename Tin>
+  using EnableIfBfloat16 =
+      typename std::enable_if<std::is_same<Tin, bfloat16>::value, int>::type;
+  template <typename Tin>
+  using EnableIfNotBfloat16 =
+      typename std::enable_if<!std::is_same<Tin, bfloat16>::value, int>::type;
+
+  template <typename Tin, typename Tindex, EnableIfNotBfloat16<Tin> = 0>
+  EIGEN_ALWAYS_INLINE auto fetch_val(
+      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
+    return input_flat.template chip<0>(index);
+  }
+
+  template <typename Tin, typename Tindex, EnableIfBfloat16<Tin> = 0>
+  EIGEN_ALWAYS_INLINE auto fetch_val(
+      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
+    return input_flat.template chip<0>(index).template cast<float>();
+  }
+
+  template <typename Tout>
+  EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64 num) {
+    Tout m(1);
+    if (is_mean_ && (num < 10)) {
+      m = Tout(num);
+    }
+    if (is_sqrtn_ && (num < 10)) {
+      m = Tout(sqrt(num));
+    }
+    return Tout(1) / m;
+  }
+
+  template <typename Tin, typename Tindex, EnableIfNotBfloat16<Tin> = 0>
+  int64 Reduce(
+      const typename TTypes<Tin>::ConstMatrix& input_flat,
+      const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
+      int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
+      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
+    return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
+                                        out, get_scaling_factor<Tin>(num));
+  }
+
+  template <typename Tin, typename Tindex, EnableIfBfloat16<Tin> = 0>
+  int64 Reduce(
+      const typename TTypes<Tin>::ConstMatrix& input_flat,
+      const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
+      int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
+      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
+    int64 res =
+        ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
+                                       temp, get_scaling_factor<float>(num));
+    out = temp.template cast<bfloat16>();
+    return res;
+  }
+
+  template <typename Tin, typename Tindex, typename Tout>
+  int64 ReduceImpl(
+      const typename TTypes<Tin>::ConstMatrix& input_flat,
+      const typename TTypes<Tindex>::ConstVec& indices_vec, int64 start,
+      int64 num, Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
+      const Tout scaling_factor) {
 #define INDEX(n, i)                               \
   const auto index##n = indices_vec(start + (i)); \
   if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
 
-#define L(n) input_flat.template chip<0>(index##n)
+#define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)
 
     if (num == 1) {
       INDEX(0, 0);
       out = L(0);
     } else {
-      int64 r = num % 8;
-      T m(1);
-      if (is_mean_ && (num < 10)) {
-        m = T(num);
-      }
-      if (is_sqrtn_ && (num < 10)) {
-        m = T(sqrt(num));
-      }
+      int64 r = num & 7;
       switch (r) {
         case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
-          out = (L(0) + L(1)) / m;
+          out = (L(0) + L(1)) * scaling_factor;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
-          out = (L(0) + L(1) + L(2)) / m;
+          out = (L(0) + L(1) + L(2)) * scaling_factor;
          break;
        }
        case 4: {
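The `EnableIfBfloat16`/`EnableIfNotBfloat16` aliases above are standard SFINAE tag dispatch: the dummy `int` template parameter makes exactly one `fetch_val` overload well-formed for a given element type, so the bfloat16 path widens to float before any arithmetic while every other type passes through untouched. A self-contained sketch of the pattern (with a stand-in `BF16` struct instead of `tensorflow::bfloat16`):

#include <cstdio>
#include <type_traits>

struct BF16 { float value; };  // stand-in for tensorflow::bfloat16

template <typename T>
using EnableIfBF16 =
    typename std::enable_if<std::is_same<T, BF16>::value, int>::type;
template <typename T>
using EnableIfNotBF16 =
    typename std::enable_if<!std::is_same<T, BF16>::value, int>::type;

// Selected for float, double, ...: no conversion needed.
template <typename T, EnableIfNotBF16<T> = 0>
float fetch_val(T x) { return static_cast<float>(x); }

// Selected only for BF16: widen before any arithmetic.
template <typename T, EnableIfBF16<T> = 0>
float fetch_val(T x) { return x.value; }

int main() {
  printf("%f %f\n", fetch_val(2.0f), fetch_val(BF16{0.5f}));
}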
@@ -613,7 +669,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
-          out = (L(0) + L(1) + L(2) + L(3)) / m;
+          out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
          break;
        }
        case 5: {
@@ -622,7 +678,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
-          out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
+          out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
          break;
        }
        case 6: {
@@ -632,7 +688,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
-          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
+          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
          break;
        }
        case 7: {
@@ -643,7 +699,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
-          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
+          out =
+              (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor;
          break;
        }
        case 0: {
@@ -655,7 +712,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
-          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
+          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
+                scaling_factor;
          r = 8;
          break;
        }
@@ -669,8 +727,8 @@ class SparseSegmentReductionOpBase : public OpKernel {
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
-          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
-                m;
+          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) *
+                scaling_factor;
          r = 9;
          break;
        }
@@ -687,10 +745,10 @@ class SparseSegmentReductionOpBase : public OpKernel {
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
-        out = out / static_cast<T>(num);
+        out = out / static_cast<Tout>(num);
      }
      if (is_sqrtn_ && num >= 10) {
-        out = out / static_cast<T>(sqrt(num));
+        out = out / static_cast<Tout>(sqrt(num));
      }
    }
 
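The `/ m` → `* scaling_factor` rewrite running through the switch cases above hoists the mean/sqrt-N divisor out of the unrolled code: for short segments (num < 10) the factor is 1/num or 1/sqrt(num), otherwise it is 1 and the num >= 10 paths still divide once after the accumulation loop. A simplified sketch of the factor computation (free function instead of the kernel's member, with `is_mean`/`is_sqrtn` mirroring the kernel's flags):

#include <cmath>
#include <cstdio>

// Simplified get_scaling_factor: returns the reciprocal applied to the
// unrolled sum. The kernel only uses it when num < 10; longer segments
// divide after the loop instead.
double get_scaling_factor(long num, bool is_mean, bool is_sqrtn) {
  double m = 1.0;
  if (is_mean && num < 10) m = static_cast<double>(num);
  if (is_sqrtn && num < 10) m = std::sqrt(static_cast<double>(num));
  return 1.0 / m;
}

int main() {
  const double sum = 42.0;  // pretend this is L(0) + ... + L(6)
  printf("mean of 7:   %f\n", sum * get_scaling_factor(7, true, false));
  printf("sqrt-n of 7: %f\n", sum * get_scaling_factor(7, false, true));
}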
@@ -64,6 +64,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
                             segment_ids_type>);
 REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
 REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
 #undef REGISTER_CPU_SPARSE_KERNELS
 
 #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
@@ -85,6 +86,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
                             CPUDevice, type, index_type, segment_ids_type>);
 REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
 REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
 #undef REGISTER_CPU_SPARSE_KERNELS
 
 #define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
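Each `REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16)` line above stamps out one kernel registration per (index type, segment-id type) combination for the new element type, the same fan-out the macro already performs for float and double. A toy sketch of that X-macro style (illustrative names, not TF's actual macros):

#include <cstdio>

struct bfloat16 {};  // stand-in for tensorflow::bfloat16

template <typename T, typename Index>
void RegisterKernel(const char* t, const char* idx) {
  printf("registered kernel<%s, %s>\n", t, idx);  // TF registers a factory here
}

// One registration per index type, like the TF macro's int32/int64 fan-out.
#define REGISTER_FOR_EACH_INDEX_TYPE(type)   \
  RegisterKernel<type, int>(#type, "int32"); \
  RegisterKernel<type, long>(#type, "int64");

int main() {
  REGISTER_FOR_EACH_INDEX_TYPE(float)
  REGISTER_FOR_EACH_INDEX_TYPE(double)
  REGISTER_FOR_EACH_INDEX_TYPE(bfloat16)  // the new line in this commit
}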
@@ -1337,7 +1337,7 @@ REGISTER_OP("SparseSegmentMean")
     .Input("indices: Tidx")
     .Input("segment_ids: Tsegmentids")
     .Output("output: T")
-    .Attr("T: {float, double}")
+    .Attr("T: {bfloat16, float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionShapeFn);
@@ -1348,7 +1348,7 @@ REGISTER_OP("SparseSegmentMeanWithNumSegments")
     .Input("segment_ids: Tsegmentids")
     .Input("num_segments: Tnumsegments")
     .Output("output: T")
-    .Attr("T: {float, double}")
+    .Attr("T: {bfloat16, float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
@@ -1370,7 +1370,7 @@ REGISTER_OP("SparseSegmentSqrtN")
     .Input("indices: Tidx")
     .Input("segment_ids: Tsegmentids")
     .Output("output: T")
-    .Attr("T: {float, double}")
+    .Attr("T: {bfloat16, float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionShapeFn);
@@ -1381,7 +1381,7 @@ REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
     .Input("segment_ids: Tsegmentids")
     .Input("num_segments: Tnumsegments")
     .Output("output: T")
-    .Attr("T: {float, double}")
+    .Attr("T: {bfloat16, float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
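For reference, the semantics that the widened `T` attr now exposes for bfloat16 inputs: SparseSegmentMean gathers rows of `data` by `indices` and averages them per sorted `segment_ids` entry. A plain-C++ sketch of that computation (illustrative only, not TF API):

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> data = {1, 2, 3, 4, 5, 6};  // 6 rows, 1 column
  std::vector<int> indices = {0, 1, 2, 3};       // rows to gather
  std::vector<int> segment_ids = {0, 0, 1, 1};   // sorted segment ids
  std::vector<float> out(2, 0.0f);
  std::vector<int> count(2, 0);
  for (size_t i = 0; i < indices.size(); ++i) {
    out[segment_ids[i]] += data[indices[i]];
    ++count[segment_ids[i]];
  }
  for (size_t s = 0; s < out.size(); ++s) {
    out[s] /= count[s];                      // mean per segment
    printf("segment %zu: %f\n", s, out[s]);  // 1.5 and 3.5
  }
}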