[tf.sparse.segment_*()] Add missing kernel registrations for DT_INT64 indices.

In addition, this change allows the segment_ids argument to be DT_INT32 (as at present) or DT_INT64. Allowing DT_INT64 segment_ids avoids the need to cast indices from a `tf.SparseTensor` before passing them to these ops, for example during `tf.nn.embedding_lookup_sparse()`.

PiperOrigin-RevId: 308136754
Change-Id: Ie9da7a5fb82507335943684ba9619e5ed6282647
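
To illustrate the motivation, a minimal sketch (hypothetical values; assumes a build that includes these kernels) of feeding the int64 indices of a `tf.SparseTensor` straight into `tf.sparse.segment_sum()` without a cast:

import tensorflow as tf

sp = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [2, 0]], values=[10, 20, 30], dense_shape=[3, 2])
data = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

segment_ids = sp.indices[:, 0]                    # int64, taken from the SparseTensor
indices = tf.constant([0, 1, 2], dtype=tf.int64)  # rows of `data` to gather

# Previously the int64 segment_ids had to be cast to int32 first; with the new
# kernel registrations both int64 tensors can be passed through unchanged.
out = tf.sparse.segment_sum(data, indices, segment_ids)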
Author: Derek Murray, 2020-04-23 15:00:54 -07:00 (committed by TensorFlower Gardener)
Commit: 3bef4c6606 (parent: 7653317576)
5 changed files with 188 additions and 121 deletions

@@ -425,7 +425,13 @@ class UnsortedSegmentReductionOp : public OpKernel {
// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
template <typename Device, class T>
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase : public OpKernel {
public:
explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
@@ -468,11 +474,10 @@ class SparseSegmentReductionOpBase : public OpKernel {
auto input_flat = input.flat_outer_dims<T>();
const int64 num_col = input_flat.dimension(1);
const auto indices_vec = indices.vec<Index>();
typedef int32 OutputRow;
const auto segment_vec = segment_ids.vec<OutputRow>();
const auto segment_vec = segment_ids.vec<SegmentId>();
// Note that the current implementation assumes that segment_vec values are
// sorted.
const OutputRow last_segment_id_plus_one =
const SegmentId last_segment_id_plus_one =
num_indices > 0
? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
: 0;
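
The sorted-segment-ids assumption noted in the comment above, and the fact that the output has last_segment_id_plus_one rows (missing segments filled with the default value, 0 for the sum/mean/sqrtn ops), are visible from Python as well. A small sketch with made-up values:

import tensorflow as tf

data = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
indices = tf.constant([0, 1, 2], dtype=tf.int64)
segment_ids = tf.constant([0, 0, 2], dtype=tf.int64)  # must be sorted

out = tf.sparse.segment_sum(data, indices, segment_ids)
# out has last_segment_id_plus_one == 3 rows; the empty segment 1 is zeros:
# [[4., 6.], [0., 0.], [5., 6.]]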
@@ -505,14 +510,14 @@ class SparseSegmentReductionOpBase : public OpKernel {
int64 start = 0, end = 1;
// Index from which the output is not initialized.
OutputRow uninitialized_index = 0;
OutputRow out_index = internal::SubtleMustCopy(segment_vec(start));
SegmentId uninitialized_index = 0;
SegmentId out_index = internal::SubtleMustCopy(segment_vec(start));
while (true) {
// We initialize next_index to 0 to avoid "warning: 'next_index' may be
// used uninitialized in this function" in the Mac build (since the
// compiler isn't smart enough to realize the code is safe).
OutputRow next_index = 0;
SegmentId next_index = 0;
if (end < num_indices) {
next_index = internal::SubtleMustCopy(segment_vec(end));
if (out_index == next_index) {
@@ -567,8 +572,6 @@ class SparseSegmentReductionOpBase : public OpKernel {
}
private:
typedef int32 Index;
int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
int64 num,
@@ -702,70 +705,78 @@ class SparseSegmentReductionOpBase : public OpKernel {
const T default_value_;
};
template <typename Device, class T>
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanOp
: public SparseSegmentReductionOpBase<Device, T> {
: public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
public:
explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
: SparseSegmentReductionOpBase<Device, T>(
: SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
context, true /*is_mean*/, false /*is_sqrtn*/,
false /* has_num_segments */, T(0) /* default_value */) {}
};
template <typename Device, class T>
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanWithNumSegmentsOp
: public SparseSegmentReductionOpBase<Device, T> {
: public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
public:
explicit SparseSegmentReductionMeanWithNumSegmentsOp(
OpKernelConstruction* context)
: SparseSegmentReductionOpBase<Device, T>(
: SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
context, true /*is_mean*/, false /*is_sqrtn*/,
true /* has_num_segments */, T(0) /* default_value */) {}
};
template <typename Device, class T>
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNOp
: public SparseSegmentReductionOpBase<Device, T> {
: public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
public:
explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
: SparseSegmentReductionOpBase<Device, T>(
: SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
context, false /*is_mean*/, true /*is_sqrtn*/,
false /* has_num_segments */, T(0) /* default_value */) {}
};
template <typename Device, class T>
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
: public SparseSegmentReductionOpBase<Device, T> {
: public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
public:
explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
OpKernelConstruction* context)
: SparseSegmentReductionOpBase<Device, T>(
: SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
context, false /*is_mean*/, true /*is_sqrtn*/,
true /* has_num_segments */, T(0) /* default_value */) {}
};
template <typename Device, class T>
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumOp
: public SparseSegmentReductionOpBase<Device, T> {
: public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
public:
explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
: SparseSegmentReductionOpBase<Device, T>(
: SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
context, false /*is_mean*/, false /*is_sqrtn*/,
false /* has_num_segments */, T(0) /* default_value */) {}
};
template <typename Device, class T>
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumWithNumSegmentsOp
: public SparseSegmentReductionOpBase<Device, T> {
: public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
public:
explicit SparseSegmentReductionSumWithNumSegmentsOp(
OpKernelConstruction* context)
: SparseSegmentReductionOpBase<Device, T>(
: SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
context, false /*is_mean*/, false /*is_sqrtn*/,
true /* has_num_segments */, T(0) /* default_value */) {}
};
template <class T>
// Implements the common logic for the gradients of SparseSegmentReduction
// kernels.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
template <class T, typename Index, typename SegmentId>
class SparseSegmentGradOpBase : public OpKernel {
public:
explicit SparseSegmentGradOpBase(OpKernelConstruction* context, bool is_sqrtn)
@@ -788,12 +799,9 @@ class SparseSegmentGradOpBase : public OpKernel {
OP_REQUIRES(context, N == segment_ids.NumElements(),
errors::InvalidArgument(
"segment_ids and indices should have same size."));
typedef int32 SegmentId;
const SegmentId M =
internal::SubtleMustCopy(output_dim0.scalar<SegmentId>()());
const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar<int32>()());
auto input_flat = input.flat_outer_dims<T>();
typedef int32 Index;
const auto indices_vec = indices.vec<Index>();
const auto segment_vec = segment_ids.vec<SegmentId>();
@@ -871,18 +879,22 @@ class SparseSegmentGradOpBase : public OpKernel {
const bool is_sqrtn_;
};
template <class T>
class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
template <class T, typename Index, typename SegmentId>
class SparseSegmentMeanGradOp
: public SparseSegmentGradOpBase<T, Index, SegmentId> {
public:
explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
: SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
: SparseSegmentGradOpBase<T, Index, SegmentId>(context,
false /*is_sqrtn*/) {}
};
template <class T>
class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
template <class T, typename Index, typename SegmentId>
class SparseSegmentSqrtNGradOp
: public SparseSegmentGradOpBase<T, Index, SegmentId> {
public:
explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
: SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
: SparseSegmentGradOpBase<T, Index, SegmentId>(context,
true /*is_sqrtn*/) {}
};
} // namespace tensorflow
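
Because the gradient kernels are templated on the same Index and SegmentId types, the backward pass accepts int64 tensors as well. A rough sketch (hypothetical values):

import tensorflow as tf

data = tf.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
indices = tf.constant([0, 1, 2], dtype=tf.int64)
segment_ids = tf.constant([0, 0, 1], dtype=tf.int64)

with tf.GradientTape() as tape:
  loss = tf.reduce_sum(tf.sparse.segment_mean(data, indices, segment_ids))

# The backward pass dispatches to SparseSegmentMeanGrad, which is now also
# registered for int64 indices and segment_ids.
grad = tape.gradient(loss, data)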

@@ -18,71 +18,100 @@ limitations under the License.
namespace tensorflow {
#define REGISTER_CPU_SPARSE_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
SparseSegmentReductionSumOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentSumWithNumSegments") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
#define REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, index_type) \
REGISTER_CPU_SPARSE_KERNELS(type, index_type, int32) \
REGISTER_CPU_SPARSE_KERNELS(type, index_type, int64)
#define REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(type) \
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int32) \
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int64)
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentSum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tidx") \
.TypeConstraint<segment_ids_type>("Tsegmentids"), \
SparseSegmentReductionSumOp<CPUDevice, type, index_type, \
segment_ids_type>); \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentSumWithNumSegments") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tidx") \
.TypeConstraint<segment_ids_type>("Tsegmentids"), \
SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type, index_type, \
segment_ids_type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
SparseSegmentReductionMeanOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentMeanWithNumSegments") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentMean") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tidx") \
.TypeConstraint<segment_ids_type>("Tsegmentids"), \
SparseSegmentReductionMeanOp<CPUDevice, type, index_type, \
segment_ids_type>); \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentMeanWithNumSegments") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tidx") \
.TypeConstraint<segment_ids_type>("Tsegmentids"), \
SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type, index_type, \
segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentSqrtNWithNumSegments") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentSqrtN") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tidx") \
.TypeConstraint<segment_ids_type>("Tsegmentids"), \
SparseSegmentReductionSqrtNOp<CPUDevice, type, index_type, \
segment_ids_type>); \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentSqrtNWithNumSegments") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tidx") \
.TypeConstraint<segment_ids_type>("Tsegmentids"), \
SparseSegmentReductionSqrtNWithNumSegmentsOp< \
CPUDevice, type, index_type, segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
SparseSegmentMeanGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentMeanGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tidx") \
.TypeConstraint<segment_ids_type>("Tsegmentids"), \
SparseSegmentMeanGradOp<type, index_type, segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
#undef REGISTER_CPU_SPARSE_KERNELS
#define REGISTER_CPU_SPARSE_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("Tidx"), \
SparseSegmentSqrtNGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
REGISTER_KERNEL_BUILDER( \
Name("SparseSegmentSqrtNGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tidx") \
.TypeConstraint<segment_ids_type>("Tsegmentids"), \
SparseSegmentSqrtNGradOp<type, index_type, segment_ids_type>);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
#undef REGISTER_CPU_SPARSE_KERNELS
#undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE
#undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE
} // namespace tensorflow
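
The macro cascade above expands to one CPU kernel per (Tidx, Tsegmentids) combination, so every int32/int64 pairing dispatches without a cast. A quick sketch of what this enables (made-up tensors):

import tensorflow as tf

data = tf.constant([[1.0, 2.0], [3.0, 4.0]])
indices = tf.constant([0, 1])
segment_ids = tf.constant([0, 0])

for index_dtype in (tf.int32, tf.int64):
  for segment_ids_dtype in (tf.int32, tf.int64):
    # Each combination hits a distinct registered kernel.
    out = tf.sparse.segment_sum(
        data,
        tf.cast(indices, index_dtype),
        tf.cast(segment_ids, segment_ids_dtype))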

@@ -1313,81 +1313,89 @@ REGISTER_OP("UnsortedSegmentProd")
REGISTER_OP("SparseSegmentSum")
.Input("data: T")
.Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: realnumbertype")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
REGISTER_OP("SparseSegmentSumWithNumSegments")
.Input("data: T")
.Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: realnumbertype")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
REGISTER_OP("SparseSegmentMean")
.Input("data: T")
.Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
REGISTER_OP("SparseSegmentMeanWithNumSegments")
.Input("data: T")
.Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
REGISTER_OP("SparseSegmentMeanGrad")
.Input("grad: T")
.Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("segment_ids: Tsegmentids")
.Input("output_dim0: int32")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionGradShapeFn);
REGISTER_OP("SparseSegmentSqrtN")
.Input("data: T")
.Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("segment_ids: Tsegmentids")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn);
REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
.Input("data: T")
.Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("segment_ids: Tsegmentids")
.Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
REGISTER_OP("SparseSegmentSqrtNGrad")
.Input("grad: T")
.Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("segment_ids: Tsegmentids")
.Input("output_dim0: int32")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionGradShapeFn);
REGISTER_OP("All")

@@ -516,6 +516,9 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
dtypes_lib.int32
]
index_dtypes = [dtypes_lib.int32, dtypes_lib.int64]
segment_ids_dtypes = [dtypes_lib.int32, dtypes_lib.int64]
mean_dtypes = [dtypes_lib.float32, dtypes_lib.float64]
# Each item is np_op1, np_op2, tf_op
@@ -531,22 +534,29 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
segment_indices.append(i)
num_indices = len(segment_indices)
for dtype in dtypes:
with self.cached_session(use_gpu=False):
tf_indices, np_indices, tf_x, np_x = self._sparse_input(
shape, num_indices, dtype=dtype)
for np_op1, np_op2, tf_op in ops_list:
if tf_op == math_ops.sparse_segment_mean and dtype not in mean_dtypes:
continue
np_ans = self._sparseSegmentReduce(np_x, np_indices, segment_indices,
np_op1, np_op2)
s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
tf_ans = self.evaluate(s)
self.assertAllClose(np_ans, tf_ans)
# NOTE(mrry): The static shape inference that computes
# `tf_ans.shape` can only infer that sizes from dimension 1
# onwards, because the size of dimension 0 is data-dependent
# and may therefore vary dynamically.
self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
for index_dtype in index_dtypes:
for segment_ids_dtype in segment_ids_dtypes:
with self.cached_session(use_gpu=False):
tf_indices, np_indices, tf_x, np_x = self._sparse_input(
shape, num_indices, dtype=dtype)
for np_op1, np_op2, tf_op in ops_list:
if (tf_op == math_ops.sparse_segment_mean
and dtype not in mean_dtypes):
continue
np_ans = self._sparseSegmentReduce(np_x, np_indices,
segment_indices, np_op1,
np_op2)
s = tf_op(
data=tf_x,
indices=math_ops.cast(tf_indices, index_dtype),
segment_ids=math_ops.cast(segment_indices, segment_ids_dtype))
tf_ans = self.evaluate(s)
self.assertAllClose(np_ans, tf_ans)
# NOTE(mrry): The static shape inference that computes
# `tf_ans.shape` can only infer that sizes from dimension 1
# onwards, because the size of dimension 0 is data-dependent
# and may therefore vary dynamically.
self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
def testSegmentIdsHole(self):
tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)

@@ -19,6 +19,7 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -481,8 +482,6 @@ def embedding_lookup_sparse(params,
with ops.name_scope(name, "embedding_lookup_sparse",
params + [sp_ids]) as name:
segment_ids = sp_ids.indices[:, 0]
if segment_ids.dtype != dtypes.int32:
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
ids = sp_ids.values
ids, idx = array_ops.unique(ids)
@@ -492,6 +491,9 @@ def embedding_lookup_sparse(params,
if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
embeddings = math_ops.cast(embeddings, dtypes.float32)
if not ignore_weights:
if segment_ids.dtype != dtypes.int32:
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
weights = sp_weights.values
if weights.dtype != embeddings.dtype:
weights = math_ops.cast(weights, embeddings.dtype)
@@ -531,6 +533,12 @@ def embedding_lookup_sparse(params,
else:
assert False, "Unrecognized combiner"
else:
if compat.forward_compatible(2020, 5, 14):
if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
else:
if segment_ids.dtype != dtypes.int32:
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
assert idx is not None
if combiner == "sum":
embeddings = math_ops.sparse_segment_sum(
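
Finally, a rough end-to-end sketch of the caller this change targets (hypothetical shapes; the no-cast path applies once the forward-compatibility window above has passed):

import tensorflow as tf

params = tf.random.normal([10, 4])
sp_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([3, 7, 2], dtype=tf.int64),
    dense_shape=[2, 2])

# sp_ids.indices is int64; with the int64 Tsegmentids kernels the segment ids
# derived from it no longer need an int32 cast before the sparse segment
# reduction.
emb = tf.nn.embedding_lookup_sparse(params, sp_ids, None, combiner="sum")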