Add bfloat16 support for SparseSegmentMean*/SparseSegmentSqrtN*

PiperOrigin-RevId: 312764313
Change-Id: I1e5de7e48f6e42a5c22012954b59ba1fea304441
This commit is contained in:
A. Unique TensorFlower 2020-05-21 16:37:52 -07:00 committed by TensorFlower Gardener
parent 19ed4a9ccf
commit 09243a984d
3 changed files with 154 additions and 123 deletions

View File

@ -508,6 +508,12 @@ class SparseSegmentReductionOpBase : public OpKernel {
errors::InvalidArgument("segment ids must be >= 0"));
auto output_flat = output->flat_outer_dims<T>();
Tensor temp;
if constexpr (std::is_same<T, bfloat16>::value) {
temp = tensorflow::Tensor(DT_FLOAT, output_shape);
}
auto temp_flat = temp.flat_outer_dims<float>();
int64 start = 0, end = 1;
// Index from which the output is not initialized.
SegmentId uninitialized_index = 0;
@ -546,8 +552,9 @@ class SparseSegmentReductionOpBase : public OpKernel {
}
auto out = output_flat.template chip<0>(out_index);
auto temp = temp_flat.template chip<0>(out_index);
const int bad_offset =
Reduce(input_flat, indices_vec, start, end - start, out);
Reduce(input_flat, indices_vec, start, end - start, out, temp);
OP_REQUIRES(context, bad_offset < 0,
errors::InvalidArgument(
"Bad: indices[", start + bad_offset,
@ -572,130 +579,152 @@ class SparseSegmentReductionOpBase : public OpKernel {
}
private:
int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
int64 num,
Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
// TODO(jaideepsi): re-write without macros, simplify Reduce b/157240265
int64 Reduce(
const typename TTypes<T>::ConstMatrix& input_flat,
const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
int64 num, Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out,
Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
#define REDUCE \
if (num == 1) { \
INDEX(0, 0); \
OUT = L(0); \
} else { \
int64 r = num & 7; \
DT m(1); \
if (is_mean_ && (num < 10)) { \
m = DT(num); \
} \
if (is_sqrtn_ && (num < 10)) { \
m = DT(sqrt(num)); \
} \
switch (r) { \
case 2: { \
INDEX(0, 0); \
INDEX(1, 1); \
OUT = (L(0) + L(1)) / m; \
break; \
} \
case 3: { \
INDEX(0, 0); \
INDEX(1, 1); \
INDEX(2, 2); \
OUT = (L(0) + L(1) + L(2)) / m; \
break; \
} \
case 4: { \
INDEX(0, 0); \
INDEX(1, 1); \
INDEX(2, 2); \
INDEX(3, 3); \
OUT = (L(0) + L(1) + L(2) + L(3)) / m; \
break; \
} \
case 5: { \
INDEX(0, 0); \
INDEX(1, 1); \
INDEX(2, 2); \
INDEX(3, 3); \
INDEX(4, 4); \
OUT = (L(0) + L(1) + L(2) + L(3) + L(4)) / m; \
break; \
} \
case 6: { \
INDEX(0, 0); \
INDEX(1, 1); \
INDEX(2, 2); \
INDEX(3, 3); \
INDEX(4, 4); \
INDEX(5, 5); \
OUT = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m; \
break; \
} \
case 7: { \
INDEX(0, 0); \
INDEX(1, 1); \
INDEX(2, 2); \
INDEX(3, 3); \
INDEX(4, 4); \
INDEX(5, 5); \
INDEX(6, 6); \
OUT = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m; \
break; \
} \
case 0: { \
INDEX(0, 0); \
INDEX(1, 1); \
INDEX(2, 2); \
INDEX(3, 3); \
INDEX(4, 4); \
INDEX(5, 5); \
INDEX(6, 6); \
INDEX(7, 7); \
OUT = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m; \
r = 8; \
break; \
} \
case 1: { \
INDEX(0, 0); \
INDEX(1, 1); \
INDEX(2, 2); \
INDEX(3, 3); \
INDEX(4, 4); \
INDEX(5, 5); \
INDEX(6, 6); \
INDEX(7, 7); \
INDEX(8, 8); \
OUT = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) / \
m; \
r = 9; \
break; \
} \
} \
for (; r < num; r += 8) { \
INDEX(0, r); \
INDEX(1, r + 1); \
INDEX(2, r + 2); \
INDEX(3, r + 3); \
INDEX(4, r + 4); \
INDEX(5, r + 5); \
INDEX(6, r + 6); \
INDEX(7, r + 7); \
OUT += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7); \
} \
if (is_mean_ && num >= 10) { \
OUT = OUT / static_cast<DT>(num); \
} \
if (is_sqrtn_ && num >= 10) { \
OUT = OUT / static_cast<DT>(sqrt(num)); \
} \
}
#define INDEX(n, i) \
const auto index##n = indices_vec(start + (i)); \
if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
#define L(n) input_flat.template chip<0>(index##n)
if constexpr (std::is_same<T, bfloat16>::value) {
#define L(n) input_flat.template chip<0>(index##n).template cast<float>()
#define OUT temp
#define DT float
if (num == 1) {
INDEX(0, 0);
out = L(0);
} else {
int64 r = num % 8;
T m(1);
if (is_mean_ && (num < 10)) {
m = T(num);
}
if (is_sqrtn_ && (num < 10)) {
m = T(sqrt(num));
}
switch (r) {
case 2: {
INDEX(0, 0);
INDEX(1, 1);
out = (L(0) + L(1)) / m;
break;
}
case 3: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
out = (L(0) + L(1) + L(2)) / m;
break;
}
case 4: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
INDEX(3, 3);
out = (L(0) + L(1) + L(2) + L(3)) / m;
break;
}
case 5: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
INDEX(3, 3);
INDEX(4, 4);
out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
break;
}
case 6: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
INDEX(3, 3);
INDEX(4, 4);
INDEX(5, 5);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
break;
}
case 7: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
INDEX(3, 3);
INDEX(4, 4);
INDEX(5, 5);
INDEX(6, 6);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
break;
}
case 0: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
INDEX(3, 3);
INDEX(4, 4);
INDEX(5, 5);
INDEX(6, 6);
INDEX(7, 7);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
r = 8;
break;
}
case 1: {
INDEX(0, 0);
INDEX(1, 1);
INDEX(2, 2);
INDEX(3, 3);
INDEX(4, 4);
INDEX(5, 5);
INDEX(6, 6);
INDEX(7, 7);
INDEX(8, 8);
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
m;
r = 9;
break;
}
}
for (; r < num; r += 8) {
INDEX(0, r);
INDEX(1, r + 1);
INDEX(2, r + 2);
INDEX(3, r + 3);
INDEX(4, r + 4);
INDEX(5, r + 5);
INDEX(6, r + 6);
INDEX(7, r + 7);
out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
}
if (is_mean_ && num >= 10) {
out = out / static_cast<T>(num);
}
if (is_sqrtn_ && num >= 10) {
out = out / static_cast<T>(sqrt(num));
}
}
return -1;
REDUCE;
out = temp.template cast<bfloat16>();
#undef DT
#undef OUT
#undef L
} else {
#define L(n) input_flat.template chip<0>(index##n)
#define OUT out
#define DT T
REDUCE;
#undef DT
#undef OUT
#undef L
}
return -1;
#undef REDUCE
#undef INDEX
}

View File

@ -64,6 +64,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
@ -85,6 +86,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
CPUDevice, type, index_type, segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \

View File

@ -1337,7 +1337,7 @@ REGISTER_OP("SparseSegmentMean")
.Input("indices: Tidx")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
@ -1348,7 +1348,7 @@ REGISTER_OP("SparseSegmentMeanWithNumSegments")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
@ -1370,7 +1370,7 @@ REGISTER_OP("SparseSegmentSqrtN")
.Input("indices: Tidx")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
@ -1381,7 +1381,7 @@ REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")