[tf.sparse.segment_*()] Add missing kernel registrations for DT_INT64 indices.
In addition, this change allows the segment_ids argument to be DT_INT32 (as at present) or DT_INT64. Allowing DT_INT64 segment_ids avoids the need to cast indices from a `tf.SparseTensor` before passing them to these ops, for example during `tf.nn.embedding_lookup_sparse()`.

PiperOrigin-RevId: 308136754
Change-Id: Ie9da7a5fb82507335943684ba9619e5ed6282647
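As an illustration of the usage this enables, here is a minimal sketch (the tensor values and variable names are made up, and it assumes a TensorFlow build that includes these registrations): the int64 indices of a `tf.SparseTensor` can be passed straight to `tf.sparse.segment_sum()` as segment_ids, with no cast to int32.

    import tensorflow as tf

    # Hypothetical inputs; a SparseTensor's indices are always int64.
    sp = tf.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
        values=[1.0, 2.0, 3.0, 4.0],
        dense_shape=[3, 4])
    data = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])

    segment_ids = sp.indices[:, 0]           # int64, sorted: [0, 0, 1, 2]
    indices = tf.range(4, dtype=tf.int64)    # int64 Tidx
    # With the DT_INT64 kernel registrations, no cast to int32 is needed here.
    out = tf.sparse.segment_sum(data, indices, segment_ids)  # shape [3, 2]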
commit 3bef4c6606
parent 7653317576
@@ -425,7 +425,13 @@ class UnsortedSegmentReductionOp : public OpKernel {
 // Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
 // by two dense tensors, one containing the data, and the other containing
 // indices into the data.
-template <typename Device, class T>
+//
+// The template parameters are:
+// * Device: An Eigen device object, on which the kernel will execute.
+// * T: The value type.
+// * Index: The element type of the indices tensor (int32 or int64).
+// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
+template <typename Device, class T, typename Index, typename SegmentId>
 class SparseSegmentReductionOpBase : public OpKernel {
  public:
   explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
@@ -468,11 +474,10 @@ class SparseSegmentReductionOpBase : public OpKernel {
     auto input_flat = input.flat_outer_dims<T>();
     const int64 num_col = input_flat.dimension(1);
     const auto indices_vec = indices.vec<Index>();
-    typedef int32 OutputRow;
-    const auto segment_vec = segment_ids.vec<OutputRow>();
+    const auto segment_vec = segment_ids.vec<SegmentId>();
     // Note that the current implementation assumes that segment_vec values are
     // sorted.
-    const OutputRow last_segment_id_plus_one =
+    const SegmentId last_segment_id_plus_one =
         num_indices > 0
             ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
             : 0;
@@ -505,14 +510,14 @@ class SparseSegmentReductionOpBase : public OpKernel {
 
     int64 start = 0, end = 1;
     // Index from which the output is not initialized.
-    OutputRow uninitialized_index = 0;
-    OutputRow out_index = internal::SubtleMustCopy(segment_vec(start));
+    SegmentId uninitialized_index = 0;
+    SegmentId out_index = internal::SubtleMustCopy(segment_vec(start));
 
     while (true) {
       // We initialize next_index to 0 to avoid "warning: 'next_index' may be
       // used uninitialized in this function" in the Mac build (since the
       // compiler isn't smart enough to realize the code is safe).
-      OutputRow next_index = 0;
+      SegmentId next_index = 0;
       if (end < num_indices) {
         next_index = internal::SubtleMustCopy(segment_vec(end));
         if (out_index == next_index) {
@@ -567,8 +572,6 @@ class SparseSegmentReductionOpBase : public OpKernel {
   }
 
  private:
-  typedef int32 Index;
-
   int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
                const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
                int64 num,
@@ -702,70 +705,78 @@ class SparseSegmentReductionOpBase : public OpKernel {
   const T default_value_;
 };
 
-template <typename Device, class T>
+template <typename Device, class T, typename Index, typename SegmentId>
 class SparseSegmentReductionMeanOp
-    : public SparseSegmentReductionOpBase<Device, T> {
+    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
  public:
   explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
-      : SparseSegmentReductionOpBase<Device, T>(
+      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
 };
 
-template <typename Device, class T>
+template <typename Device, class T, typename Index, typename SegmentId>
 class SparseSegmentReductionMeanWithNumSegmentsOp
-    : public SparseSegmentReductionOpBase<Device, T> {
+    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
  public:
   explicit SparseSegmentReductionMeanWithNumSegmentsOp(
       OpKernelConstruction* context)
-      : SparseSegmentReductionOpBase<Device, T>(
+      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
 };
 
-template <typename Device, class T>
+template <typename Device, class T, typename Index, typename SegmentId>
 class SparseSegmentReductionSqrtNOp
-    : public SparseSegmentReductionOpBase<Device, T> {
+    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
  public:
   explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
-      : SparseSegmentReductionOpBase<Device, T>(
+      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
           context, false /*is_mean*/, true /*is_sqrtn*/,
          false /* has_num_segments */, T(0) /* default_value */) {}
 };
 
-template <typename Device, class T>
+template <typename Device, class T, typename Index, typename SegmentId>
 class SparseSegmentReductionSqrtNWithNumSegmentsOp
-    : public SparseSegmentReductionOpBase<Device, T> {
+    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
  public:
   explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
       OpKernelConstruction* context)
-      : SparseSegmentReductionOpBase<Device, T>(
+      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
           context, false /*is_mean*/, true /*is_sqrtn*/,
          true /* has_num_segments */, T(0) /* default_value */) {}
 };
 
-template <typename Device, class T>
+template <typename Device, class T, typename Index, typename SegmentId>
 class SparseSegmentReductionSumOp
-    : public SparseSegmentReductionOpBase<Device, T> {
+    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
  public:
   explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
-      : SparseSegmentReductionOpBase<Device, T>(
+      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
          context, false /*is_mean*/, false /*is_sqrtn*/,
          false /* has_num_segments */, T(0) /* default_value */) {}
 };
 
-template <typename Device, class T>
+template <typename Device, class T, typename Index, typename SegmentId>
 class SparseSegmentReductionSumWithNumSegmentsOp
-    : public SparseSegmentReductionOpBase<Device, T> {
+    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
  public:
   explicit SparseSegmentReductionSumWithNumSegmentsOp(
       OpKernelConstruction* context)
-      : SparseSegmentReductionOpBase<Device, T>(
+      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
          context, false /*is_mean*/, false /*is_sqrtn*/,
          true /* has_num_segments */, T(0) /* default_value */) {}
 };
 
-template <class T>
+// Implements the common logic for the gradients of SparseSegmentReduction
+// kernels.
+//
+// The template parameters are:
+// * Device: An Eigen device object, on which the kernel will execute.
+// * T: The value type.
+// * Index: The element type of the indices tensor (int32 or int64).
+// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
+template <class T, typename Index, typename SegmentId>
 class SparseSegmentGradOpBase : public OpKernel {
  public:
   explicit SparseSegmentGradOpBase(OpKernelConstruction* context, bool is_sqrtn)
@@ -788,12 +799,9 @@ class SparseSegmentGradOpBase : public OpKernel {
     OP_REQUIRES(context, N == segment_ids.NumElements(),
                 errors::InvalidArgument(
                     "segment_ids and indices should have same size."));
-    typedef int32 SegmentId;
-    const SegmentId M =
-        internal::SubtleMustCopy(output_dim0.scalar<SegmentId>()());
+    const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar<int32>()());
 
     auto input_flat = input.flat_outer_dims<T>();
-    typedef int32 Index;
     const auto indices_vec = indices.vec<Index>();
     const auto segment_vec = segment_ids.vec<SegmentId>();
 
@@ -871,18 +879,22 @@ class SparseSegmentGradOpBase : public OpKernel {
   const bool is_sqrtn_;
 };
 
-template <class T>
-class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
+template <class T, typename Index, typename SegmentId>
+class SparseSegmentMeanGradOp
+    : public SparseSegmentGradOpBase<T, Index, SegmentId> {
  public:
   explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
-      : SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
+      : SparseSegmentGradOpBase<T, Index, SegmentId>(context,
+                                                     false /*is_sqrtn*/) {}
 };
 
-template <class T>
-class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
+template <class T, typename Index, typename SegmentId>
+class SparseSegmentSqrtNGradOp
+    : public SparseSegmentGradOpBase<T, Index, SegmentId> {
  public:
   explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
-      : SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
+      : SparseSegmentGradOpBase<T, Index, SegmentId>(context,
+                                                     true /*is_sqrtn*/) {}
 };
 
 }  // namespace tensorflow
@@ -18,71 +18,100 @@ limitations under the License.
 
 namespace tensorflow {
 
-#define REGISTER_CPU_SPARSE_KERNELS(type) \
-  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum") \
-                              .Device(DEVICE_CPU) \
-                              .TypeConstraint<type>("T") \
-                              .TypeConstraint<int32>("Tidx"), \
-                          SparseSegmentReductionSumOp<CPUDevice, type>); \
-  REGISTER_KERNEL_BUILDER( \
-      Name("SparseSegmentSumWithNumSegments") \
-          .Device(DEVICE_CPU) \
-          .TypeConstraint<type>("T") \
-          .TypeConstraint<int32>("Tidx"), \
-      SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
+#define REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, index_type) \
+  REGISTER_CPU_SPARSE_KERNELS(type, index_type, int32) \
+  REGISTER_CPU_SPARSE_KERNELS(type, index_type, int64)
+#define REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(type) \
+  REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int32) \
+  REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE(type, int64)
+
+#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("SparseSegmentSum") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tidx") \
+          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
+      SparseSegmentReductionSumOp<CPUDevice, type, index_type, \
+                                  segment_ids_type>); \
+  REGISTER_KERNEL_BUILDER( \
+      Name("SparseSegmentSumWithNumSegments") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tidx") \
+          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
+      SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type, index_type, \
+                                                 segment_ids_type>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
 #undef REGISTER_CPU_SPARSE_KERNELS
 
-#define REGISTER_CPU_SPARSE_KERNELS(type) \
-  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean") \
-                              .Device(DEVICE_CPU) \
-                              .TypeConstraint<type>("T") \
-                              .TypeConstraint<int32>("Tidx"), \
-                          SparseSegmentReductionMeanOp<CPUDevice, type>); \
-  REGISTER_KERNEL_BUILDER( \
-      Name("SparseSegmentMeanWithNumSegments") \
-          .Device(DEVICE_CPU) \
-          .TypeConstraint<type>("T") \
-          .TypeConstraint<int32>("Tidx"), \
-      SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
-REGISTER_CPU_SPARSE_KERNELS(float);
-REGISTER_CPU_SPARSE_KERNELS(double);
+#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("SparseSegmentMean") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tidx") \
+          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
+      SparseSegmentReductionMeanOp<CPUDevice, type, index_type, \
+                                   segment_ids_type>); \
+  REGISTER_KERNEL_BUILDER( \
+      Name("SparseSegmentMeanWithNumSegments") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tidx") \
+          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
+      SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type, index_type, \
+                                                  segment_ids_type>);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
 #undef REGISTER_CPU_SPARSE_KERNELS
 
-#define REGISTER_CPU_SPARSE_KERNELS(type) \
-  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN") \
-                              .Device(DEVICE_CPU) \
-                              .TypeConstraint<type>("T") \
-                              .TypeConstraint<int32>("Tidx"), \
-                          SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
-  REGISTER_KERNEL_BUILDER( \
-      Name("SparseSegmentSqrtNWithNumSegments") \
-          .Device(DEVICE_CPU) \
-          .TypeConstraint<type>("T") \
-          .TypeConstraint<int32>("Tidx"), \
-      SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
-REGISTER_CPU_SPARSE_KERNELS(float);
-REGISTER_CPU_SPARSE_KERNELS(double);
+#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("SparseSegmentSqrtN") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tidx") \
+          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
+      SparseSegmentReductionSqrtNOp<CPUDevice, type, index_type, \
+                                    segment_ids_type>); \
+  REGISTER_KERNEL_BUILDER( \
+      Name("SparseSegmentSqrtNWithNumSegments") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tidx") \
+          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
+      SparseSegmentReductionSqrtNWithNumSegmentsOp< \
+          CPUDevice, type, index_type, segment_ids_type>);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
 #undef REGISTER_CPU_SPARSE_KERNELS
 
-#define REGISTER_CPU_SPARSE_KERNELS(type) \
-  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad") \
-                              .Device(DEVICE_CPU) \
-                              .TypeConstraint<type>("T") \
-                              .TypeConstraint<int32>("Tidx"), \
-                          SparseSegmentMeanGradOp<type>);
-REGISTER_CPU_SPARSE_KERNELS(float);
-REGISTER_CPU_SPARSE_KERNELS(double);
+#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("SparseSegmentMeanGrad") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tidx") \
+          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
+      SparseSegmentMeanGradOp<type, index_type, segment_ids_type>);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
 #undef REGISTER_CPU_SPARSE_KERNELS
 
-#define REGISTER_CPU_SPARSE_KERNELS(type) \
-  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad") \
-                              .Device(DEVICE_CPU) \
-                              .TypeConstraint<type>("T") \
-                              .TypeConstraint<int32>("Tidx"), \
-                          SparseSegmentSqrtNGradOp<type>);
-REGISTER_CPU_SPARSE_KERNELS(float);
-REGISTER_CPU_SPARSE_KERNELS(double);
+#define REGISTER_CPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("SparseSegmentSqrtNGrad") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tidx") \
+          .TypeConstraint<segment_ids_type>("Tsegmentids"), \
+      SparseSegmentSqrtNGradOp<type, index_type, segment_ids_type>);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(float);
+REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE(double);
 #undef REGISTER_CPU_SPARSE_KERNELS
 
+#undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE
+#undef REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_SEGMENT_ID_TYPE
+
 }  // namespace tensorflow
@@ -1313,81 +1313,89 @@ REGISTER_OP("UnsortedSegmentProd")
 REGISTER_OP("SparseSegmentSum")
     .Input("data: T")
     .Input("indices: Tidx")
-    .Input("segment_ids: int32")
+    .Input("segment_ids: Tsegmentids")
     .Output("output: T")
     .Attr("T: realnumbertype")
     .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionShapeFn);
 
 REGISTER_OP("SparseSegmentSumWithNumSegments")
     .Input("data: T")
     .Input("indices: Tidx")
-    .Input("segment_ids: int32")
+    .Input("segment_ids: Tsegmentids")
     .Input("num_segments: Tnumsegments")
     .Output("output: T")
     .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
 
 REGISTER_OP("SparseSegmentMean")
     .Input("data: T")
     .Input("indices: Tidx")
-    .Input("segment_ids: int32")
+    .Input("segment_ids: Tsegmentids")
     .Output("output: T")
     .Attr("T: {float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionShapeFn);
 
 REGISTER_OP("SparseSegmentMeanWithNumSegments")
     .Input("data: T")
     .Input("indices: Tidx")
-    .Input("segment_ids: int32")
+    .Input("segment_ids: Tsegmentids")
     .Input("num_segments: Tnumsegments")
     .Output("output: T")
     .Attr("T: {float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
 
 REGISTER_OP("SparseSegmentMeanGrad")
     .Input("grad: T")
     .Input("indices: Tidx")
-    .Input("segment_ids: int32")
+    .Input("segment_ids: Tsegmentids")
     .Input("output_dim0: int32")
     .Output("output: T")
     .Attr("T: {float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionGradShapeFn);
 
 REGISTER_OP("SparseSegmentSqrtN")
     .Input("data: T")
     .Input("indices: Tidx")
-    .Input("segment_ids: int32")
+    .Input("segment_ids: Tsegmentids")
     .Output("output: T")
     .Attr("T: {float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionShapeFn);
 
 REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
     .Input("data: T")
     .Input("indices: Tidx")
-    .Input("segment_ids: int32")
+    .Input("segment_ids: Tsegmentids")
     .Input("num_segments: Tnumsegments")
     .Output("output: T")
     .Attr("T: {float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
 
 REGISTER_OP("SparseSegmentSqrtNGrad")
     .Input("grad: T")
     .Input("indices: Tidx")
-    .Input("segment_ids: int32")
+    .Input("segment_ids: Tsegmentids")
     .Input("output_dim0: int32")
     .Output("output: T")
     .Attr("T: {float, double}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
     .SetShapeFn(SparseSegmentReductionGradShapeFn);
 
 REGISTER_OP("All")
@@ -516,6 +516,9 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
         dtypes_lib.int32
     ]
 
+    index_dtypes = [dtypes_lib.int32, dtypes_lib.int64]
+    segment_ids_dtypes = [dtypes_lib.int32, dtypes_lib.int64]
+
     mean_dtypes = [dtypes_lib.float32, dtypes_lib.float64]
 
     # Each item is np_op1, np_op2, tf_op
@@ -531,22 +534,29 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
         segment_indices.append(i)
     num_indices = len(segment_indices)
     for dtype in dtypes:
-      with self.cached_session(use_gpu=False):
-        tf_indices, np_indices, tf_x, np_x = self._sparse_input(
-            shape, num_indices, dtype=dtype)
-        for np_op1, np_op2, tf_op in ops_list:
-          if tf_op == math_ops.sparse_segment_mean and dtype not in mean_dtypes:
-            continue
-          np_ans = self._sparseSegmentReduce(np_x, np_indices, segment_indices,
-                                             np_op1, np_op2)
-          s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
-          tf_ans = self.evaluate(s)
-          self.assertAllClose(np_ans, tf_ans)
-          # NOTE(mrry): The static shape inference that computes
-          # `tf_ans.shape` can only infer that sizes from dimension 1
-          # onwards, because the size of dimension 0 is data-dependent
-          # and may therefore vary dynamically.
-          self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
+      for index_dtype in index_dtypes:
+        for segment_ids_dtype in segment_ids_dtypes:
+          with self.cached_session(use_gpu=False):
+            tf_indices, np_indices, tf_x, np_x = self._sparse_input(
+                shape, num_indices, dtype=dtype)
+            for np_op1, np_op2, tf_op in ops_list:
+              if (tf_op == math_ops.sparse_segment_mean
+                  and dtype not in mean_dtypes):
+                continue
+              np_ans = self._sparseSegmentReduce(np_x, np_indices,
+                                                 segment_indices, np_op1,
+                                                 np_op2)
+              s = tf_op(
+                  data=tf_x,
+                  indices=math_ops.cast(tf_indices, index_dtype),
+                  segment_ids=math_ops.cast(segment_indices, segment_ids_dtype))
+              tf_ans = self.evaluate(s)
+              self.assertAllClose(np_ans, tf_ans)
+              # NOTE(mrry): The static shape inference that computes
+              # `tf_ans.shape` can only infer that sizes from dimension 1
+              # onwards, because the size of dimension 0 is data-dependent
+              # and may therefore vary dynamically.
+              self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
 
   def testSegmentIdsHole(self):
     tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -481,8 +482,6 @@ def embedding_lookup_sparse(params,
   with ops.name_scope(name, "embedding_lookup_sparse",
                       params + [sp_ids]) as name:
     segment_ids = sp_ids.indices[:, 0]
-    if segment_ids.dtype != dtypes.int32:
-      segment_ids = math_ops.cast(segment_ids, dtypes.int32)
 
     ids = sp_ids.values
     ids, idx = array_ops.unique(ids)
@@ -492,6 +491,9 @@ def embedding_lookup_sparse(params,
     if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
       embeddings = math_ops.cast(embeddings, dtypes.float32)
     if not ignore_weights:
+      if segment_ids.dtype != dtypes.int32:
+        segment_ids = math_ops.cast(segment_ids, dtypes.int32)
+
       weights = sp_weights.values
       if weights.dtype != embeddings.dtype:
         weights = math_ops.cast(weights, embeddings.dtype)
@@ -531,6 +533,12 @@ def embedding_lookup_sparse(params,
       else:
         assert False, "Unrecognized combiner"
     else:
+      if compat.forward_compatible(2020, 5, 14):
+        if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
+          segment_ids = math_ops.cast(segment_ids, dtypes.int32)
+      else:
+        if segment_ids.dtype != dtypes.int32:
+          segment_ids = math_ops.cast(segment_ids, dtypes.int32)
       assert idx is not None
       if combiner == "sum":
         embeddings = math_ops.sparse_segment_sum(