Add bfloat16 support for SparseSegmentMean*/SparseSegmentSqrtN*
PiperOrigin-RevId: 312764313 Change-Id: I1e5de7e48f6e42a5c22012954b59ba1fea304441
This commit is contained in:
parent
19ed4a9ccf
commit
09243a984d
|
@ -508,6 +508,12 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
|||
errors::InvalidArgument("segment ids must be >= 0"));
|
||||
auto output_flat = output->flat_outer_dims<T>();
|
||||
|
||||
Tensor temp;
|
||||
if constexpr (std::is_same<T, bfloat16>::value) {
|
||||
temp = tensorflow::Tensor(DT_FLOAT, output_shape);
|
||||
}
|
||||
auto temp_flat = temp.flat_outer_dims<float>();
|
||||
|
||||
int64 start = 0, end = 1;
|
||||
// Index from which the output is not initialized.
|
||||
SegmentId uninitialized_index = 0;
|
||||
|
@ -546,8 +552,9 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
|||
}
|
||||
|
||||
auto out = output_flat.template chip<0>(out_index);
|
||||
auto temp = temp_flat.template chip<0>(out_index);
|
||||
const int bad_offset =
|
||||
Reduce(input_flat, indices_vec, start, end - start, out);
|
||||
Reduce(input_flat, indices_vec, start, end - start, out, temp);
|
||||
OP_REQUIRES(context, bad_offset < 0,
|
||||
errors::InvalidArgument(
|
||||
"Bad: indices[", start + bad_offset,
|
||||
|
@ -572,130 +579,152 @@ class SparseSegmentReductionOpBase : public OpKernel {
|
|||
}
|
||||
|
||||
private:
|
||||
int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
|
||||
const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
|
||||
int64 num,
|
||||
Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
|
||||
// TODO(jaideepsi): Rewrite without macros and simplify Reduce (b/157240265).
|
||||
int64 Reduce(
|
||||
const typename TTypes<T>::ConstMatrix& input_flat,
|
||||
const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
|
||||
int64 num, Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out,
|
||||
Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
|
||||
// REDUCE expands to the accumulation body of Reduce. The expansion site must
// provide INDEX(n, i), L(n), OUT and DT, and have `num`, `is_mean_` and
// `is_sqrtn_` in scope. Each INDEX use may return early from the enclosing
// function when an index fails its bounds check.
//  - num == 1: OUT is assigned L(0) with no division.
//  - otherwise: a head of r = num & 7 rows is summed first; remainders 0 and 1
//    are widened to 8 and 9 rows, so the head always assigns OUT from at
//    least two rows. The loop then adds the remaining rows 8 at a time.
//  - For num < 10 the head covers every row, so the divisor m (num for mean,
//    sqrt(num) for sqrtn) is applied inside the head; for num >= 10, m stays 1
//    and OUT is divided once after the loop.
#define REDUCE \
|
||||
if (num == 1) { \
|
||||
INDEX(0, 0); \
|
||||
OUT = L(0); \
|
||||
} else { \
|
||||
int64 r = num & 7; \
|
||||
DT m(1); \
|
||||
if (is_mean_ && (num < 10)) { \
|
||||
m = DT(num); \
|
||||
} \
|
||||
if (is_sqrtn_ && (num < 10)) { \
|
||||
m = DT(sqrt(num)); \
|
||||
} \
|
||||
switch (r) { \
|
||||
case 2: { \
|
||||
INDEX(0, 0); \
|
||||
INDEX(1, 1); \
|
||||
OUT = (L(0) + L(1)) / m; \
|
||||
break; \
|
||||
} \
|
||||
case 3: { \
|
||||
INDEX(0, 0); \
|
||||
INDEX(1, 1); \
|
||||
INDEX(2, 2); \
|
||||
OUT = (L(0) + L(1) + L(2)) / m; \
|
||||
break; \
|
||||
} \
|
||||
case 4: { \
|
||||
INDEX(0, 0); \
|
||||
INDEX(1, 1); \
|
||||
INDEX(2, 2); \
|
||||
INDEX(3, 3); \
|
||||
OUT = (L(0) + L(1) + L(2) + L(3)) / m; \
|
||||
break; \
|
||||
} \
|
||||
case 5: { \
|
||||
INDEX(0, 0); \
|
||||
INDEX(1, 1); \
|
||||
INDEX(2, 2); \
|
||||
INDEX(3, 3); \
|
||||
INDEX(4, 4); \
|
||||
OUT = (L(0) + L(1) + L(2) + L(3) + L(4)) / m; \
|
||||
break; \
|
||||
} \
|
||||
case 6: { \
|
||||
INDEX(0, 0); \
|
||||
INDEX(1, 1); \
|
||||
INDEX(2, 2); \
|
||||
INDEX(3, 3); \
|
||||
INDEX(4, 4); \
|
||||
INDEX(5, 5); \
|
||||
OUT = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m; \
|
||||
break; \
|
||||
} \
|
||||
case 7: { \
|
||||
INDEX(0, 0); \
|
||||
INDEX(1, 1); \
|
||||
INDEX(2, 2); \
|
||||
INDEX(3, 3); \
|
||||
INDEX(4, 4); \
|
||||
INDEX(5, 5); \
|
||||
INDEX(6, 6); \
|
||||
OUT = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m; \
|
||||
break; \
|
||||
} \
|
||||
case 0: { \
|
||||
INDEX(0, 0); \
|
||||
INDEX(1, 1); \
|
||||
INDEX(2, 2); \
|
||||
INDEX(3, 3); \
|
||||
INDEX(4, 4); \
|
||||
INDEX(5, 5); \
|
||||
INDEX(6, 6); \
|
||||
INDEX(7, 7); \
|
||||
OUT = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m; \
|
||||
r = 8; \
|
||||
break; \
|
||||
} \
|
||||
case 1: { \
|
||||
INDEX(0, 0); \
|
||||
INDEX(1, 1); \
|
||||
INDEX(2, 2); \
|
||||
INDEX(3, 3); \
|
||||
INDEX(4, 4); \
|
||||
INDEX(5, 5); \
|
||||
INDEX(6, 6); \
|
||||
INDEX(7, 7); \
|
||||
INDEX(8, 8); \
|
||||
OUT = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) / \
|
||||
m; \
|
||||
r = 9; \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
for (; r < num; r += 8) { \
|
||||
INDEX(0, r); \
|
||||
INDEX(1, r + 1); \
|
||||
INDEX(2, r + 2); \
|
||||
INDEX(3, r + 3); \
|
||||
INDEX(4, r + 4); \
|
||||
INDEX(5, r + 5); \
|
||||
INDEX(6, r + 6); \
|
||||
INDEX(7, r + 7); \
|
||||
OUT += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7); \
|
||||
} \
|
||||
if (is_mean_ && num >= 10) { \
|
||||
OUT = OUT / static_cast<DT>(num); \
|
||||
} \
|
||||
if (is_sqrtn_ && num >= 10) { \
|
||||
OUT = OUT / static_cast<DT>(sqrt(num)); \
|
||||
} \
|
||||
}
|
||||
|
||||
// INDEX(n, i): reads indices_vec(start + (i)) into a const local named
// index<n>. If FastBoundsCheck rejects that value against
// input_flat.dimension(0), the enclosing function returns (i), i.e. the
// offset of the rejected entry relative to `start`.
#define INDEX(n, i) \
|
||||
const auto index##n = indices_vec(start + (i)); \
|
||||
if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
|
||||
|
||||
#define L(n) input_flat.template chip<0>(index##n)
|
||||
if constexpr (std::is_same<T, bfloat16>::value) {
|
||||
#define L(n) input_flat.template chip<0>(index##n).template cast<float>()
|
||||
#define OUT temp
|
||||
#define DT float
|
||||
|
||||
if (num == 1) {
|
||||
INDEX(0, 0);
|
||||
out = L(0);
|
||||
} else {
|
||||
int64 r = num % 8;
|
||||
T m(1);
|
||||
if (is_mean_ && (num < 10)) {
|
||||
m = T(num);
|
||||
}
|
||||
if (is_sqrtn_ && (num < 10)) {
|
||||
m = T(sqrt(num));
|
||||
}
|
||||
switch (r) {
|
||||
case 2: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
out = (L(0) + L(1)) / m;
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
out = (L(0) + L(1) + L(2)) / m;
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
INDEX(3, 3);
|
||||
out = (L(0) + L(1) + L(2) + L(3)) / m;
|
||||
break;
|
||||
}
|
||||
case 5: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
INDEX(3, 3);
|
||||
INDEX(4, 4);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
|
||||
break;
|
||||
}
|
||||
case 6: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
INDEX(3, 3);
|
||||
INDEX(4, 4);
|
||||
INDEX(5, 5);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
|
||||
break;
|
||||
}
|
||||
case 7: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
INDEX(3, 3);
|
||||
INDEX(4, 4);
|
||||
INDEX(5, 5);
|
||||
INDEX(6, 6);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
|
||||
break;
|
||||
}
|
||||
case 0: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
INDEX(3, 3);
|
||||
INDEX(4, 4);
|
||||
INDEX(5, 5);
|
||||
INDEX(6, 6);
|
||||
INDEX(7, 7);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
|
||||
r = 8;
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
INDEX(0, 0);
|
||||
INDEX(1, 1);
|
||||
INDEX(2, 2);
|
||||
INDEX(3, 3);
|
||||
INDEX(4, 4);
|
||||
INDEX(5, 5);
|
||||
INDEX(6, 6);
|
||||
INDEX(7, 7);
|
||||
INDEX(8, 8);
|
||||
out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
|
||||
m;
|
||||
r = 9;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (; r < num; r += 8) {
|
||||
INDEX(0, r);
|
||||
INDEX(1, r + 1);
|
||||
INDEX(2, r + 2);
|
||||
INDEX(3, r + 3);
|
||||
INDEX(4, r + 4);
|
||||
INDEX(5, r + 5);
|
||||
INDEX(6, r + 6);
|
||||
INDEX(7, r + 7);
|
||||
out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
|
||||
}
|
||||
if (is_mean_ && num >= 10) {
|
||||
out = out / static_cast<T>(num);
|
||||
}
|
||||
if (is_sqrtn_ && num >= 10) {
|
||||
out = out / static_cast<T>(sqrt(num));
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
REDUCE;
|
||||
out = temp.template cast<bfloat16>();
|
||||
#undef DT
|
||||
#undef OUT
|
||||
#undef L
|
||||
} else {
|
||||
#define L(n) input_flat.template chip<0>(index##n)
|
||||
#define OUT out
|
||||
#define DT T
|
||||
|
||||
REDUCE;
|
||||
|
||||
#undef DT
|
||||
#undef OUT
|
||||
#undef L
|
||||
}
|
||||
return -1;
|
||||
#undef REDUCE
|
||||
#undef INDEX
|
||||
}
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
|
|||
segment_ids_type>);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
|
||||
#undef REGISTER_CPU_SPARSE_KERNELS
|
||||
|
||||
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
|
||||
|
@ -85,6 +86,7 @@ REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
|
|||
CPUDevice, type, index_type, segment_ids_type>);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
|
||||
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(bfloat16);
|
||||
#undef REGISTER_CPU_SPARSE_KERNELS
|
||||
|
||||
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
|
||||
|
|
|
@ -1337,7 +1337,7 @@ REGISTER_OP("SparseSegmentMean")
|
|||
.Input("indices: Tidx")
|
||||
.Input("segment_ids: Tsegmentids")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {bfloat16, float, double}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(SparseSegmentReductionShapeFn);
|
||||
|
@ -1348,7 +1348,7 @@ REGISTER_OP("SparseSegmentMeanWithNumSegments")
|
|||
.Input("segment_ids: Tsegmentids")
|
||||
.Input("num_segments: Tnumsegments")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {bfloat16, float, double}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
|
||||
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
|
||||
|
@ -1370,7 +1370,7 @@ REGISTER_OP("SparseSegmentSqrtN")
|
|||
.Input("indices: Tidx")
|
||||
.Input("segment_ids: Tsegmentids")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {bfloat16, float, double}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(SparseSegmentReductionShapeFn);
|
||||
|
@ -1381,7 +1381,7 @@ REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
|
|||
.Input("segment_ids: Tsegmentids")
|
||||
.Input("num_segments: Tnumsegments")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {bfloat16, float, double}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
|
||||
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
|
||||
|
|
Loading…
Reference in New Issue