[tf.data] Optimize SerializeManySparseOp implementation used in unbatching tf.SparseTensor.

This change makes the following optimizations:

1. Split the template specialization (between tstring and Variant) so that it
   applies at the entire op level, rather than a per-element level. This permits
   us to specialize for the (overwhelmingly more common) Variant case:

   * Use `Variant::emplace()` instead of the move assignment operator to avoid
     copying the inline data (viz. the TensorShape) in a Tensor.

2. Only set empty elements when the input is empty. Currently we call
   setConstant() on the entire output to set empty elements. With this change
   we only set those elements if there is no matching group in the input. This
   prevents wasted work (i) in the assignment and (ii) in destroying the
   unnecessarily assigned Tensors.

3. Introduce `sparse::Group::group_at()` to avoid the need for constructing a
   temporary vector on each group access, only to access the 0th element.

4. Optimize `sparse::GroupIterable::GroupMatches()` to return immediately when a
   mismatch is detected.

PiperOrigin-RevId: 289209832
Change-Id: I22df11bf474eab117307931908cef9c601d98226
This commit is contained in:
Derek Murray 2020-01-10 20:36:58 -08:00 committed by TensorFlower Gardener
parent 0f3c91c2bb
commit 880cad8598
2 changed files with 151 additions and 92 deletions
tensorflow/core

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/kernels/reshape_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/util/sparse/group_iterator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
@ -139,24 +140,150 @@ REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
.TypeConstraint<Variant>("out_type"),
SerializeSparseOp<Variant>);
template <typename T, typename U>
struct SerializeGroups {};
template <typename T>
class SerializeManySparseOpBase : public OpKernel {
public:
explicit SerializeManySparseOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
struct SerializeGroups<T, tstring> {
Status operator()(sparse::GroupIterable* minibatch,
const Tensor& output_shape, int64 N, int rank,
Tensor* serialized_sparse) {
auto serialized_sparse_t = serialized_sparse->matrix<tstring>();
void Compute(OpKernelContext* context) override {}
int64 last_nonempty_group = -1;
protected:
Status Initialize(const int64 n, Tensor* result);
Status Serialize(const Tensor& input, T* result);
auto serialize = [](const Tensor& input, tstring* result) {
TensorProto proto;
input.AsProtoTensorContent(&proto);
*result = proto.SerializeAsString();
};
tstring serialized_shape;
serialize(output_shape, &serialized_shape);
auto serialize_empty_element = [&](int64 b) {
serialize(Tensor(DT_INT64, {0, rank - 1}), &serialized_sparse_t(b, 0));
serialize(Tensor(DataTypeToEnum<T>::value, {0}),
&serialized_sparse_t(b, 1));
serialized_sparse_t(b, 2) = serialized_shape;
};
for (const auto& subset : *minibatch) {
const int64 b = subset.group_at(0);
if (b < 0 || b >= N) {
return errors::InvalidArgument(
"Received unexpected column 0 value in input SparseTensor: ", b,
" < 0 or >= N (= ", N, ")");
}
// GroupIterable generates only the non-empty groups of rows, so we must
// generate empty outputs for any empty rows since the last non-empty
// group that was generated.
for (int64 empty_b = last_nonempty_group + 1; empty_b < b; ++empty_b) {
serialize_empty_element(empty_b);
}
last_nonempty_group = b;
const auto indices = subset.indices();
const auto values = subset.values<T>();
const int64 num_entries = values.size();
Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});
auto output_indices_t = output_indices.matrix<int64>();
auto output_values_t = output_values.vec<T>();
for (int i = 0; i < num_entries; ++i) {
for (int d = 1; d < rank; ++d) {
output_indices_t(i, d - 1) = indices(i, d);
}
output_values_t(i) = values(i);
}
serialize(output_indices, &serialized_sparse_t(b, 0));
serialize(output_values, &serialized_sparse_t(b, 1));
serialized_sparse_t(b, 2) = serialized_shape;
}
for (int64 empty_b = last_nonempty_group + 1; empty_b < N; ++empty_b) {
serialize_empty_element(empty_b);
}
return Status::OK();
}
};
template <typename T>
struct SerializeGroups<T, Variant> {
Status operator()(sparse::GroupIterable* minibatch,
const Tensor& output_shape, int64 N, int rank,
Tensor* serialized_sparse) {
auto serialized_sparse_t = serialized_sparse->template matrix<Variant>();
int64 last_nonempty_group = -1;
auto serialize_empty_element = [&](int64 b) {
serialized_sparse_t(b, 0).emplace<Tensor>(DT_INT64,
TensorShape({0, rank - 1}));
serialized_sparse_t(b, 1).emplace<Tensor>(DataTypeToEnum<T>::value,
TensorShape({0}));
serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
};
for (const auto& subset : *minibatch) {
const int64 b = subset.group_at(0);
if (b < 0 || b >= N) {
return errors::InvalidArgument(
"Received unexpected column 0 value in input SparseTensor: ", b,
" < 0 or >= N (= ", N, ")");
}
// GroupIterable generates only the non-empty groups of rows, so we must
// generate empty outputs for any empty rows since the last non-empty
// group that was generated.
for (int64 empty_b = last_nonempty_group + 1; empty_b < b; ++empty_b) {
serialize_empty_element(empty_b);
}
last_nonempty_group = b;
const auto indices = subset.indices();
const auto values = subset.values<T>();
const int64 num_entries = values.size();
Tensor& output_indices = serialized_sparse_t(b, 0).emplace<Tensor>(
DT_INT64, TensorShape({num_entries, rank - 1}));
Tensor& output_values = serialized_sparse_t(b, 1).emplace<Tensor>(
DataTypeToEnum<T>::value, TensorShape({num_entries}));
auto output_indices_t = output_indices.matrix<int64>();
auto output_values_t = output_values.vec<T>();
for (int i = 0; i < num_entries; ++i) {
for (int d = 1; d < rank; ++d) {
output_indices_t(i, d - 1) = indices(i, d);
}
output_values_t(i) = values(i);
}
serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
}
for (int64 empty_b = last_nonempty_group + 1; empty_b < N; ++empty_b) {
serialize_empty_element(empty_b);
}
return Status::OK();
}
};
template <typename T, typename U>
class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
class SerializeManySparseOp : public OpKernel {
public:
explicit SerializeManySparseOp(OpKernelConstruction* context)
: SerializeManySparseOpBase<U>(context) {}
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* input_indices;
@ -197,85 +324,25 @@ class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
auto input_shape_t = input_shape->vec<int64>();
const int64 N = input_shape_t(0);
Tensor serialized_sparse;
OP_REQUIRES_OK(context, this->Initialize(N, &serialized_sparse));
auto serialized_sparse_t = serialized_sparse.matrix<U>();
Tensor* serialized_sparse;
OP_REQUIRES_OK(context,
context->allocate_output(0, {N, 3}, &serialized_sparse));
OP_REQUIRES_OK(context, input_st.IndicesValid());
// Initialize output with empty values and the proper shapes.
Tensor output_blank_indices(DT_INT64, {0, rank - 1});
U serialized_indices;
OP_REQUIRES_OK(context,
this->Serialize(output_blank_indices, &serialized_indices));
serialized_sparse_t.template chip<1>(0).setConstant(serialized_indices);
Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
U serialized_values;
OP_REQUIRES_OK(context,
this->Serialize(output_blank_values, &serialized_values));
serialized_sparse_t.template chip<1>(1).setConstant(serialized_values);
Tensor output_shape(DT_INT64, {rank - 1});
auto output_shape_t = output_shape.vec<int64>();
for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);
U serialized_shape;
OP_REQUIRES_OK(context, this->Serialize(output_shape, &serialized_shape));
serialized_sparse_t.template chip<1>(2).setConstant(serialized_shape);
// Get groups by minibatch dimension
sparse::GroupIterable minibatch = input_st.group({0});
for (const auto& subset : minibatch) {
const int64 b = subset.group()[0];
OP_REQUIRES(
context, b > -1 && b < N,
errors::InvalidArgument(
"Received unexpected column 0 value in input SparseTensor: ", b,
" < 0 or >= N (= ", N, ")"));
const auto indices = subset.indices();
const auto values = subset.values<T>();
const int64 num_entries = values.size();
Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});
auto output_indices_t = output_indices.matrix<int64>();
auto output_values_t = output_values.vec<T>();
for (int i = 0; i < num_entries; ++i) {
for (int d = 1; d < rank; ++d) {
output_indices_t(i, d - 1) = indices(i, d);
}
output_values_t(i) = values(i);
}
OP_REQUIRES_OK(
context, this->Serialize(output_indices, &serialized_sparse_t(b, 0)));
OP_REQUIRES_OK(
context, this->Serialize(output_values, &serialized_sparse_t(b, 1)));
}
context->set_output(0, serialized_sparse);
OP_REQUIRES_OK(context, SerializeGroups<T, U>()(&minibatch, output_shape, N,
rank, serialized_sparse));
}
};
template <>
Status SerializeManySparseOpBase<tstring>::Initialize(const int64 n,
Tensor* result) {
*result = Tensor(DT_STRING, TensorShape({n, 3}));
return Status::OK();
}
template <>
Status SerializeManySparseOpBase<tstring>::Serialize(const Tensor& input,
tstring* result) {
TensorProto proto;
input.AsProtoTensorContent(&proto);
*result = proto.SerializeAsString();
return Status::OK();
}
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
@ -286,19 +353,6 @@ Status SerializeManySparseOpBase<tstring>::Serialize(const Tensor& input,
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
template <>
Status SerializeManySparseOpBase<Variant>::Initialize(const int64 n,
Tensor* result) {
*result = Tensor(DT_VARIANT, TensorShape({n, 3}));
return Status::OK();
}
template <>
Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
Variant* result) {
*result = input;
return Status::OK();
}
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \

View File

@ -37,6 +37,7 @@ class Group {
: iter_(iter), loc_(loc), next_loc_(next_loc) {}
std::vector<int64> group() const;
int64 group_at(size_t index) const;
TTypes<int64>::UnalignedConstMatrix indices() const;
template <typename T>
typename TTypes<T>::UnalignedVec values() const;
@ -96,13 +97,12 @@ class GroupIterable {
template <typename TIX>
inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
bool matches = true;
for (int d : group_dims_) {
if (ix(loc_a, d) != ix(loc_b, d)) {
matches = false;
return false;
}
}
return matches;
return true;
}
class IteratorStep {
@ -135,6 +135,11 @@ class GroupIterable {
const gtl::InlinedVector<int64, 8> group_dims_;
};
inline int64 Group::group_at(size_t index) const {
const auto& ix_t = iter_->ix_matrix_;
return ix_t(loc_, index);
}
// Implementation of Group::values<T>()
template <typename T>
typename TTypes<T>::UnalignedVec Group::values() const {