[tf.data] Optimize SerializeManySparseOp implementation used in unbatching tf.SparseTensor.
This change makes the following optimizations: 1. Split the template specialization (between tstring and Variant) so that it applies at the entire op level, rather than a per-element level. This permits us to specialize for the (overwhelmingly more common) Variant case: * Use `Variant::emplace()` instead of the move assignment operator to avoid copying the inline data (viz. the TensorShape) in a Tensor. 2. Only set empty elements when the input is empty. Currently we call setConstant() on the entire output to set empty elements. With this change we only set those elements if there is no matching group in the input. This prevents wasted work (i) in the assignment and (ii) in destroying the unnecessarily assigned Tensors. 3. Introduce `sparse::Group::group_at()` to avoid the need for constructing a temporary vector on each group access, only to access the 0th element. 4. Optimize `sparse::GroupIterable::GroupMatches()` to return immediately when a mismatch is detected. PiperOrigin-RevId: 289209832 Change-Id: I22df11bf474eab117307931908cef9c601d98226
This commit is contained in:
parent
0f3c91c2bb
commit
880cad8598
tensorflow/core
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/reshape_util.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/optional.h"
|
||||
#include "tensorflow/core/util/sparse/group_iterator.h"
|
||||
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -139,24 +140,150 @@ REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
|
||||
.TypeConstraint<Variant>("out_type"),
|
||||
SerializeSparseOp<Variant>);
|
||||
|
||||
template <typename T, typename U>
|
||||
struct SerializeGroups {};
|
||||
|
||||
template <typename T>
|
||||
class SerializeManySparseOpBase : public OpKernel {
|
||||
public:
|
||||
explicit SerializeManySparseOpBase(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
struct SerializeGroups<T, tstring> {
|
||||
Status operator()(sparse::GroupIterable* minibatch,
|
||||
const Tensor& output_shape, int64 N, int rank,
|
||||
Tensor* serialized_sparse) {
|
||||
auto serialized_sparse_t = serialized_sparse->matrix<tstring>();
|
||||
|
||||
void Compute(OpKernelContext* context) override {}
|
||||
int64 last_nonempty_group = -1;
|
||||
|
||||
protected:
|
||||
Status Initialize(const int64 n, Tensor* result);
|
||||
Status Serialize(const Tensor& input, T* result);
|
||||
auto serialize = [](const Tensor& input, tstring* result) {
|
||||
TensorProto proto;
|
||||
input.AsProtoTensorContent(&proto);
|
||||
*result = proto.SerializeAsString();
|
||||
};
|
||||
|
||||
tstring serialized_shape;
|
||||
serialize(output_shape, &serialized_shape);
|
||||
|
||||
auto serialize_empty_element = [&](int64 b) {
|
||||
serialize(Tensor(DT_INT64, {0, rank - 1}), &serialized_sparse_t(b, 0));
|
||||
serialize(Tensor(DataTypeToEnum<T>::value, {0}),
|
||||
&serialized_sparse_t(b, 1));
|
||||
serialized_sparse_t(b, 2) = serialized_shape;
|
||||
};
|
||||
|
||||
for (const auto& subset : *minibatch) {
|
||||
const int64 b = subset.group_at(0);
|
||||
if (b < 0 || b >= N) {
|
||||
return errors::InvalidArgument(
|
||||
"Received unexpected column 0 value in input SparseTensor: ", b,
|
||||
" < 0 or >= N (= ", N, ")");
|
||||
}
|
||||
|
||||
// GroupIterable generates only the non-empty groups of rows, so we must
|
||||
// generate empty outputs for any empty rows since the last non-empty
|
||||
// group that was generated.
|
||||
for (int64 empty_b = last_nonempty_group + 1; empty_b < b; ++empty_b) {
|
||||
serialize_empty_element(empty_b);
|
||||
}
|
||||
|
||||
last_nonempty_group = b;
|
||||
|
||||
const auto indices = subset.indices();
|
||||
const auto values = subset.values<T>();
|
||||
const int64 num_entries = values.size();
|
||||
|
||||
Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
|
||||
Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});
|
||||
|
||||
auto output_indices_t = output_indices.matrix<int64>();
|
||||
auto output_values_t = output_values.vec<T>();
|
||||
|
||||
for (int i = 0; i < num_entries; ++i) {
|
||||
for (int d = 1; d < rank; ++d) {
|
||||
output_indices_t(i, d - 1) = indices(i, d);
|
||||
}
|
||||
output_values_t(i) = values(i);
|
||||
}
|
||||
|
||||
serialize(output_indices, &serialized_sparse_t(b, 0));
|
||||
serialize(output_values, &serialized_sparse_t(b, 1));
|
||||
serialized_sparse_t(b, 2) = serialized_shape;
|
||||
}
|
||||
|
||||
for (int64 empty_b = last_nonempty_group + 1; empty_b < N; ++empty_b) {
|
||||
serialize_empty_element(empty_b);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SerializeGroups<T, Variant> {
|
||||
Status operator()(sparse::GroupIterable* minibatch,
|
||||
const Tensor& output_shape, int64 N, int rank,
|
||||
Tensor* serialized_sparse) {
|
||||
auto serialized_sparse_t = serialized_sparse->template matrix<Variant>();
|
||||
|
||||
int64 last_nonempty_group = -1;
|
||||
|
||||
auto serialize_empty_element = [&](int64 b) {
|
||||
serialized_sparse_t(b, 0).emplace<Tensor>(DT_INT64,
|
||||
TensorShape({0, rank - 1}));
|
||||
serialized_sparse_t(b, 1).emplace<Tensor>(DataTypeToEnum<T>::value,
|
||||
TensorShape({0}));
|
||||
serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
|
||||
};
|
||||
|
||||
for (const auto& subset : *minibatch) {
|
||||
const int64 b = subset.group_at(0);
|
||||
if (b < 0 || b >= N) {
|
||||
return errors::InvalidArgument(
|
||||
"Received unexpected column 0 value in input SparseTensor: ", b,
|
||||
" < 0 or >= N (= ", N, ")");
|
||||
}
|
||||
|
||||
// GroupIterable generates only the non-empty groups of rows, so we must
|
||||
// generate empty outputs for any empty rows since the last non-empty
|
||||
// group that was generated.
|
||||
for (int64 empty_b = last_nonempty_group + 1; empty_b < b; ++empty_b) {
|
||||
serialize_empty_element(empty_b);
|
||||
}
|
||||
|
||||
last_nonempty_group = b;
|
||||
|
||||
const auto indices = subset.indices();
|
||||
const auto values = subset.values<T>();
|
||||
const int64 num_entries = values.size();
|
||||
|
||||
Tensor& output_indices = serialized_sparse_t(b, 0).emplace<Tensor>(
|
||||
DT_INT64, TensorShape({num_entries, rank - 1}));
|
||||
Tensor& output_values = serialized_sparse_t(b, 1).emplace<Tensor>(
|
||||
DataTypeToEnum<T>::value, TensorShape({num_entries}));
|
||||
|
||||
auto output_indices_t = output_indices.matrix<int64>();
|
||||
auto output_values_t = output_values.vec<T>();
|
||||
|
||||
for (int i = 0; i < num_entries; ++i) {
|
||||
for (int d = 1; d < rank; ++d) {
|
||||
output_indices_t(i, d - 1) = indices(i, d);
|
||||
}
|
||||
output_values_t(i) = values(i);
|
||||
}
|
||||
|
||||
serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
|
||||
}
|
||||
|
||||
for (int64 empty_b = last_nonempty_group + 1; empty_b < N; ++empty_b) {
|
||||
serialize_empty_element(empty_b);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
|
||||
class SerializeManySparseOp : public OpKernel {
|
||||
public:
|
||||
explicit SerializeManySparseOp(OpKernelConstruction* context)
|
||||
: SerializeManySparseOpBase<U>(context) {}
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor* input_indices;
|
||||
@ -197,85 +324,25 @@ class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
|
||||
|
||||
auto input_shape_t = input_shape->vec<int64>();
|
||||
const int64 N = input_shape_t(0);
|
||||
Tensor serialized_sparse;
|
||||
OP_REQUIRES_OK(context, this->Initialize(N, &serialized_sparse));
|
||||
auto serialized_sparse_t = serialized_sparse.matrix<U>();
|
||||
|
||||
Tensor* serialized_sparse;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, {N, 3}, &serialized_sparse));
|
||||
|
||||
OP_REQUIRES_OK(context, input_st.IndicesValid());
|
||||
|
||||
// Initialize output with empty values and the proper shapes.
|
||||
Tensor output_blank_indices(DT_INT64, {0, rank - 1});
|
||||
U serialized_indices;
|
||||
OP_REQUIRES_OK(context,
|
||||
this->Serialize(output_blank_indices, &serialized_indices));
|
||||
serialized_sparse_t.template chip<1>(0).setConstant(serialized_indices);
|
||||
|
||||
Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
|
||||
U serialized_values;
|
||||
OP_REQUIRES_OK(context,
|
||||
this->Serialize(output_blank_values, &serialized_values));
|
||||
serialized_sparse_t.template chip<1>(1).setConstant(serialized_values);
|
||||
|
||||
Tensor output_shape(DT_INT64, {rank - 1});
|
||||
auto output_shape_t = output_shape.vec<int64>();
|
||||
for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);
|
||||
U serialized_shape;
|
||||
OP_REQUIRES_OK(context, this->Serialize(output_shape, &serialized_shape));
|
||||
serialized_sparse_t.template chip<1>(2).setConstant(serialized_shape);
|
||||
|
||||
// Get groups by minibatch dimension
|
||||
sparse::GroupIterable minibatch = input_st.group({0});
|
||||
for (const auto& subset : minibatch) {
|
||||
const int64 b = subset.group()[0];
|
||||
OP_REQUIRES(
|
||||
context, b > -1 && b < N,
|
||||
errors::InvalidArgument(
|
||||
"Received unexpected column 0 value in input SparseTensor: ", b,
|
||||
" < 0 or >= N (= ", N, ")"));
|
||||
|
||||
const auto indices = subset.indices();
|
||||
const auto values = subset.values<T>();
|
||||
const int64 num_entries = values.size();
|
||||
|
||||
Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
|
||||
Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});
|
||||
|
||||
auto output_indices_t = output_indices.matrix<int64>();
|
||||
auto output_values_t = output_values.vec<T>();
|
||||
|
||||
for (int i = 0; i < num_entries; ++i) {
|
||||
for (int d = 1; d < rank; ++d) {
|
||||
output_indices_t(i, d - 1) = indices(i, d);
|
||||
}
|
||||
output_values_t(i) = values(i);
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context, this->Serialize(output_indices, &serialized_sparse_t(b, 0)));
|
||||
OP_REQUIRES_OK(
|
||||
context, this->Serialize(output_values, &serialized_sparse_t(b, 1)));
|
||||
}
|
||||
|
||||
context->set_output(0, serialized_sparse);
|
||||
OP_REQUIRES_OK(context, SerializeGroups<T, U>()(&minibatch, output_shape, N,
|
||||
rank, serialized_sparse));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
Status SerializeManySparseOpBase<tstring>::Initialize(const int64 n,
|
||||
Tensor* result) {
|
||||
*result = Tensor(DT_STRING, TensorShape({n, 3}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status SerializeManySparseOpBase<tstring>::Serialize(const Tensor& input,
|
||||
tstring* result) {
|
||||
TensorProto proto;
|
||||
input.AsProtoTensorContent(&proto);
|
||||
*result = proto.SerializeAsString();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REGISTER_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
|
||||
.Device(DEVICE_CPU) \
|
||||
@ -286,19 +353,6 @@ Status SerializeManySparseOpBase<tstring>::Serialize(const Tensor& input,
|
||||
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
template <>
|
||||
Status SerializeManySparseOpBase<Variant>::Initialize(const int64 n,
|
||||
Tensor* result) {
|
||||
*result = Tensor(DT_VARIANT, TensorShape({n, 3}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
|
||||
Variant* result) {
|
||||
*result = input;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REGISTER_KERNELS(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
|
||||
|
@ -37,6 +37,7 @@ class Group {
|
||||
: iter_(iter), loc_(loc), next_loc_(next_loc) {}
|
||||
|
||||
std::vector<int64> group() const;
|
||||
int64 group_at(size_t index) const;
|
||||
TTypes<int64>::UnalignedConstMatrix indices() const;
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedVec values() const;
|
||||
@ -96,13 +97,12 @@ class GroupIterable {
|
||||
|
||||
template <typename TIX>
|
||||
inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
|
||||
bool matches = true;
|
||||
for (int d : group_dims_) {
|
||||
if (ix(loc_a, d) != ix(loc_b, d)) {
|
||||
matches = false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return matches;
|
||||
return true;
|
||||
}
|
||||
|
||||
class IteratorStep {
|
||||
@ -135,6 +135,11 @@ class GroupIterable {
|
||||
const gtl::InlinedVector<int64, 8> group_dims_;
|
||||
};
|
||||
|
||||
inline int64 Group::group_at(size_t index) const {
|
||||
const auto& ix_t = iter_->ix_matrix_;
|
||||
return ix_t(loc_, index);
|
||||
}
|
||||
|
||||
// Implementation of Group::values<T>()
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedVec Group::values() const {
|
||||
|
Loading…
Reference in New Issue
Block a user