Adding variant-based serialization and deserialization for sparse tensors.

PiperOrigin-RevId: 177971801
This commit is contained in:
Jiri Simsa 2017-12-05 10:13:41 -08:00 committed by TensorFlower Gardener
parent 77b60c1ac6
commit f88cd91955
10 changed files with 442 additions and 187 deletions

View File

@ -18,7 +18,14 @@ END
1-D. The `shape` of the minibatch `SparseTensor`. 1-D. The `shape` of the minibatch `SparseTensor`.
END END
} }
summary: "Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`." attr {
name: "out_type"
description: <<END
The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
END
}
summary: "Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object."
description: <<END description: <<END
The `SparseTensor` must have rank `R` greater than 1, and the first dimension The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor` is treated as the minibatch dimension. Elements of the `SparseTensor`

View File

@ -18,5 +18,12 @@ END
1-D. The `shape` of the `SparseTensor`. 1-D. The `shape` of the `SparseTensor`.
END END
} }
summary: "Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object." attr {
name: "out_type"
description: <<END
The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
END
}
summary: "Serialize a `SparseTensor` into a `[3]` `Tensor` object."
} }

View File

@ -73,12 +73,14 @@ REGISTER(quint16)
REGISTER(qint16) REGISTER(qint16)
REGISTER(qint32) REGISTER(qint32)
REGISTER(bfloat16) REGISTER(bfloat16)
TF_CALL_variant(REGISTER)
#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \ #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
!defined(__ANDROID_TYPES_FULL__) !defined(__ANDROID_TYPES_FULL__)
// Primarily used for SavedModel support on mobile. Registering it here only if // Primarily used for SavedModel support on mobile. Registering it here only
// __ANDROID_TYPES_FULL__ is not defined, as that already register strings // if __ANDROID_TYPES_FULL__ is not defined (which already registers string)
REGISTER(string); // to avoid duplicate registration.
REGISTER(string);
#endif // defined(IS_MOBILE_PLATFORM) && #endif // defined(IS_MOBILE_PLATFORM) &&
// !defined(SUPPORT_SELECTIVE_REGISTRATION) && // !defined(SUPPORT_SELECTIVE_REGISTRATION) &&
// !defined(__ANDROID_TYPES_FULL__) // !defined(__ANDROID_TYPES_FULL__)

View File

@ -140,6 +140,7 @@ class PackOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_PACK); TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK); TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);
TF_CALL_bfloat16(REGISTER_PACK); TF_CALL_bfloat16(REGISTER_PACK);
TF_CALL_variant(REGISTER_PACK);
#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
// Primarily used for SavedModel support on mobile. // Primarily used for SavedModel support on mobile.

View File

@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/reshape_util.h" #include "tensorflow/core/kernels/reshape_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h" #include "tensorflow/core/util/sparse/sparse_tensor.h"
@ -35,15 +37,20 @@ namespace tensorflow {
using sparse::SparseTensor; using sparse::SparseTensor;
template <typename T>
class SerializeSparseOp : public OpKernel { class SerializeSparseOp : public OpKernel {
public: public:
explicit SerializeSparseOp(OpKernelConstruction* context) explicit SerializeSparseOp(OpKernelConstruction* context)
: OpKernel(context) {} : OpKernel(context) {}
Status Initialize(Tensor* result);
Status Serialize(const Tensor& input, T* result);
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
const Tensor* input_indices; const Tensor* input_indices;
const Tensor* input_values; const Tensor* input_values;
const Tensor* input_shape; const Tensor* input_shape;
OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
@ -62,33 +69,74 @@ class SerializeSparseOp : public OpKernel {
"Input shape should be a vector but received shape ", "Input shape should be a vector but received shape ",
input_shape->shape().DebugString())); input_shape->shape().DebugString()));
TensorProto proto_indices; Tensor serialized_sparse;
TensorProto proto_values; OP_REQUIRES_OK(context, Initialize(&serialized_sparse));
TensorProto proto_shape;
input_indices->AsProtoTensorContent(&proto_indices); auto serialized_sparse_t = serialized_sparse.vec<T>();
input_values->AsProtoTensorContent(&proto_values); OP_REQUIRES_OK(context, Serialize(*input_indices, &serialized_sparse_t(0)));
input_shape->AsProtoTensorContent(&proto_shape); OP_REQUIRES_OK(context, Serialize(*input_values, &serialized_sparse_t(1)));
OP_REQUIRES_OK(context, Serialize(*input_shape, &serialized_sparse_t(2)));
Tensor serialized_sparse(DT_STRING, TensorShape({3}));
auto serialized_sparse_t = serialized_sparse.vec<string>();
serialized_sparse_t(0) = proto_indices.SerializeAsString();
serialized_sparse_t(1) = proto_values.SerializeAsString();
serialized_sparse_t(2) = proto_shape.SerializeAsString();
context->set_output(0, serialized_sparse); context->set_output(0, serialized_sparse);
} }
}; };
REGISTER_KERNEL_BUILDER(Name("SerializeSparse").Device(DEVICE_CPU), template <>
SerializeSparseOp); Status SerializeSparseOp<string>::Initialize(Tensor* result) {
*result = Tensor(DT_STRING, TensorShape({3}));
return Status::OK();
}
// String specialization of Serialize: converts `input` to a `TensorProto`
// (via AsProtoTensorContent) and stores the proto's serialized bytes in
// `*result`. Always returns OK.
template <>
Status SerializeSparseOp<string>::Serialize(const Tensor& input,
                                            string* result) {
  TensorProto tensor_proto;
  input.AsProtoTensorContent(&tensor_proto);

  *result = tensor_proto.SerializeAsString();
  return Status::OK();
}
// Registers the CPU kernel for the "SerializeSparse" op that applies when the
// op's `out_type` attr is constrained to `string`; the kernel is implemented
// by SerializeSparseOp<string>.
REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
.Device(DEVICE_CPU)
.TypeConstraint<string>("out_type"),
SerializeSparseOp<string>);
// Variant specialization of Initialize: replaces `*result` with a DT_VARIANT
// tensor of shape [3]. Compute() fills the three slots with the serialized
// indices, values and shape. Always returns OK.
template <>
Status SerializeSparseOp<Variant>::Initialize(Tensor* result) {
  const TensorShape output_shape({3});
  *result = Tensor(DT_VARIANT, output_shape);
  return Status::OK();
}
// Variant specialization of Serialize: assigns `input` directly to the
// variant `*result` instead of encoding it as a serialized TensorProto the
// way the string specialization does. Always returns OK.
template <>
Status SerializeSparseOp<Variant>::Serialize(const Tensor& input,
Variant* result) {
// No proto round-trip here: the variant element holds the tensor itself.
*result = input;
return Status::OK();
}
// Registers the CPU kernel for the "SerializeSparse" op that applies when the
// op's `out_type` attr is constrained to `variant`; the kernel is implemented
// by SerializeSparseOp<Variant>.
REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
.Device(DEVICE_CPU)
.TypeConstraint<Variant>("out_type"),
SerializeSparseOp<Variant>);
template <typename T> template <typename T>
class SerializeManySparseOp : public OpKernel { class SerializeManySparseOpBase : public OpKernel {
public:
explicit SerializeManySparseOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {}
protected:
Status Initialize(const int64 n, Tensor* result);
Status Serialize(const Tensor& input, T* result);
};
template <typename T, typename U>
class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
public: public:
explicit SerializeManySparseOp(OpKernelConstruction* context) explicit SerializeManySparseOp(OpKernelConstruction* context)
: OpKernel(context) {} : SerializeManySparseOpBase<U>(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
const Tensor* input_indices; const Tensor* input_indices;
@ -127,37 +175,31 @@ class SerializeManySparseOp : public OpKernel {
auto input_shape_t = input_shape->vec<int64>(); auto input_shape_t = input_shape->vec<int64>();
const int64 N = input_shape_t(0); const int64 N = input_shape_t(0);
Tensor serialized_sparse;
Tensor serialized_sparse(DT_STRING, TensorShape({N, 3})); OP_REQUIRES_OK(context, this->Initialize(N, &serialized_sparse));
auto serialized_sparse_t = serialized_sparse.matrix<string>(); auto serialized_sparse_t = serialized_sparse.matrix<U>();
OP_REQUIRES_OK(context, input_st.IndicesValid()); OP_REQUIRES_OK(context, input_st.IndicesValid());
// We can generate the output shape proto string now, for all // Initialize output with empty values and the proper shapes.
// minibatch entries. Tensor output_blank_indices(DT_INT64, {0, rank - 1});
U serialized_indices;
OP_REQUIRES_OK(context,
this->Serialize(output_blank_indices, &serialized_indices));
serialized_sparse_t.template chip<1>(0).setConstant(serialized_indices);
Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
U serialized_values;
OP_REQUIRES_OK(context,
this->Serialize(output_blank_values, &serialized_values));
serialized_sparse_t.template chip<1>(1).setConstant(serialized_values);
Tensor output_shape(DT_INT64, {rank - 1}); Tensor output_shape(DT_INT64, {rank - 1});
auto output_shape_t = output_shape.vec<int64>(); auto output_shape_t = output_shape.vec<int64>();
for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d); for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);
TensorProto proto_shape; U serialized_shape;
output_shape.AsProtoTensorContent(&proto_shape); OP_REQUIRES_OK(context, this->Serialize(output_shape, &serialized_shape));
const string proto_shape_string = proto_shape.SerializeAsString(); serialized_sparse_t.template chip<1>(2).setConstant(serialized_shape);
Tensor output_blank_indices(DT_INT64, {0, rank - 1});
Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
TensorProto proto_blank_indices;
TensorProto proto_blank_values;
output_blank_indices.AsProtoTensorContent(&proto_blank_indices);
output_blank_values.AsProtoTensorContent(&proto_blank_values);
const string proto_blank_indices_string =
proto_blank_indices.SerializeAsString();
const string proto_blank_values_string =
proto_blank_values.SerializeAsString();
// Initialize output with empty values and the proper shapes.
serialized_sparse_t.chip<1>(0).setConstant(proto_blank_indices_string);
serialized_sparse_t.chip<1>(1).setConstant(proto_blank_values_string);
serialized_sparse_t.chip<1>(2).setConstant(proto_shape_string);
// Get groups by minibatch dimension // Get groups by minibatch dimension
sparse::GroupIterable minibatch = input_st.group({0}); sparse::GroupIterable minibatch = input_st.group({0});
@ -186,33 +228,83 @@ class SerializeManySparseOp : public OpKernel {
output_values_t(i) = values(i); output_values_t(i) = values(i);
} }
TensorProto proto_indices; OP_REQUIRES_OK(
TensorProto proto_values; context, this->Serialize(output_indices, &serialized_sparse_t(b, 0)));
output_indices.AsProtoTensorContent(&proto_indices); OP_REQUIRES_OK(
output_values.AsProtoTensorContent(&proto_values); context, this->Serialize(output_values, &serialized_sparse_t(b, 1)));
serialized_sparse_t(b, 0) = proto_indices.SerializeAsString();
serialized_sparse_t(b, 1) = proto_values.SerializeAsString();
} }
context->set_output(0, serialized_sparse); context->set_output(0, serialized_sparse);
} }
}; };
#define REGISTER_KERNELS(type) \ template <>
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \ Status SerializeManySparseOpBase<string>::Initialize(const int64 n,
.Device(DEVICE_CPU) \ Tensor* result) {
.TypeConstraint<type>("T"), \ *result = Tensor(DT_STRING, TensorShape({n, 3}));
SerializeManySparseOp<type>) return Status::OK();
}
// String specialization of Serialize for the minibatch op: converts `input`
// to a `TensorProto` (via AsProtoTensorContent) and stores the proto's
// serialized bytes in `*result`. Always returns OK.
template <>
Status SerializeManySparseOpBase<string>::Serialize(const Tensor& input,
                                                    string* result) {
  TensorProto tensor_proto;
  input.AsProtoTensorContent(&tensor_proto);

  *result = tensor_proto.SerializeAsString();
  return Status::OK();
}
// Registers a CPU "SerializeManySparse" kernel for value type `type` (attr
// "T") with the `out_type` attr constrained to `string`, implemented by
// SerializeManySparseOp<type, string>.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<string>("out_type"), \
SerializeManySparseOp<type, string>)
// Presumably expands REGISTER_KERNELS once per type in TF_CALL_ALL_TYPES'
// list -- TODO confirm against the macro's definition.
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
// Undefine so the REGISTER_KERNELS name can be redefined for the next set of
// registrations.
#undef REGISTER_KERNELS
// Variant specialization of Initialize for the minibatch op: replaces
// `*result` with a DT_VARIANT tensor of shape [n, 3], i.e. one row of
// (indices, values, shape) slots per minibatch entry. Always returns OK.
template <>
Status SerializeManySparseOpBase<Variant>::Initialize(const int64 n,
                                                      Tensor* result) {
  const TensorShape output_shape({n, 3});
  *result = Tensor(DT_VARIANT, output_shape);
  return Status::OK();
}
// Variant specialization of Serialize for the minibatch op: assigns `input`
// directly to the variant `*result` instead of encoding it as a serialized
// TensorProto the way the string specialization does. Always returns OK.
template <>
Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
Variant* result) {
// No proto round-trip here: the variant element holds the tensor itself.
*result = input;
return Status::OK();
}
// Registers a CPU "SerializeManySparse" kernel for value type `type` (attr
// "T") with the `out_type` attr constrained to `variant`, implemented by
// SerializeManySparseOp<type, Variant>.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<Variant>("out_type"), \
SerializeManySparseOp<type, Variant>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS #undef REGISTER_KERNELS
template <typename T> template <typename T>
class DeserializeSparseOp : public OpKernel { class DeserializeSparseOpBase : public OpKernel {
public:
explicit DeserializeSparseOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {}
protected:
Status Deserialize(const T& serialized, Tensor* result);
};
template <typename T, typename U>
class DeserializeSparseOp : public DeserializeSparseOpBase<U> {
public: public:
explicit DeserializeSparseOp(OpKernelConstruction* context) explicit DeserializeSparseOp(OpKernelConstruction* context)
: OpKernel(context) {} : DeserializeSparseOpBase<U>(context) {}
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
const Tensor& serialized_sparse = context->input(0); const Tensor& serialized_sparse = context->input(0);
@ -246,53 +338,30 @@ class DeserializeSparseOp : public OpKernel {
indices.reserve(num_sparse_tensors); indices.reserve(num_sparse_tensors);
values.reserve(num_sparse_tensors); values.reserve(num_sparse_tensors);
const auto& serialized_sparse_t = const auto& serialized_sparse_t = serialized_sparse.flat_inner_dims<U, 2>();
serialized_sparse.flat_inner_dims<string, 2>();
for (int i = 0; i < num_sparse_tensors; ++i) { for (int i = 0; i < num_sparse_tensors; ++i) {
Tensor output_indices(DT_INT64); Tensor output_indices;
Tensor output_values(DataTypeToEnum<T>::value); OP_REQUIRES_OK(context, this->Deserialize(serialized_sparse_t(i, 0),
Tensor output_shape(DT_INT64); &output_indices));
TensorProto proto_indices;
TensorProto proto_values;
TensorProto proto_shape;
OP_REQUIRES(
context,
ParseProtoUnlimited(&proto_indices, serialized_sparse_t(i, 0)),
errors::InvalidArgument("Could not parse serialized_sparse[", i,
", 0]"));
OP_REQUIRES(context,
ParseProtoUnlimited(&proto_values, serialized_sparse_t(i, 1)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 1]"));
OP_REQUIRES(context,
ParseProtoUnlimited(&proto_shape, serialized_sparse_t(i, 2)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 2]"));
OP_REQUIRES(context, output_indices.FromProto(proto_indices),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 0] (indices)"));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()), OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
errors::InvalidArgument( errors::InvalidArgument(
"Expected serialized_sparse[", i, "Expected serialized_sparse[", i,
", 0] to represent an index matrix but received shape ", ", 0] to represent an index matrix but received shape ",
output_indices.shape().DebugString())); output_indices.shape().DebugString()));
OP_REQUIRES(context, output_values.FromProto(proto_values),
errors::InvalidArgument( Tensor output_values;
"Could not construct Tensor serialized_sparse[", i, OP_REQUIRES_OK(context, this->Deserialize(serialized_sparse_t(i, 1),
", 1] (values)")); &output_values));
OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()), OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
errors::InvalidArgument( errors::InvalidArgument(
"Expected serialized_sparse[", i, "Expected serialized_sparse[", i,
", 1] to represent a values vector but received shape ", ", 1] to represent a values vector but received shape ",
output_values.shape().DebugString())); output_values.shape().DebugString()));
OP_REQUIRES(context, output_shape.FromProto(proto_shape),
errors::InvalidArgument( Tensor output_shape;
"Could not construct Tensor serialized_sparse[", i, OP_REQUIRES_OK(
", 2] (shape)")); context, this->Deserialize(serialized_sparse_t(i, 2), &output_shape));
OP_REQUIRES( OP_REQUIRES(
context, TensorShapeUtils::IsVector(output_shape.shape()), context, TensorShapeUtils::IsVector(output_shape.shape()),
errors::InvalidArgument("Expected serialized_sparse[", i, errors::InvalidArgument("Expected serialized_sparse[", i,
@ -400,11 +469,27 @@ class DeserializeSparseOp : public OpKernel {
} }
}; };
#define REGISTER_KERNELS(type) \ template <>
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \ Status DeserializeSparseOpBase<string>::Deserialize(const string& serialized,
.Device(DEVICE_CPU) \ Tensor* result) {
.TypeConstraint<type>("dtype"), \ TensorProto proto;
DeserializeSparseOp<type>) if (!ParseProtoUnlimited(&proto, serialized)) {
return errors::InvalidArgument("Could not parse serialized proto");
}
Tensor tensor;
if (!tensor.FromProto(proto)) {
return errors::InvalidArgument("Could not construct tensor from proto");
}
*result = tensor;
return Status::OK();
}
// Registers a CPU "DeserializeSparse" kernel for value type `type` (attr
// "dtype") with the `Tserialized` attr constrained to `string`, implemented
// by DeserializeSparseOp<type, string>.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype") \
.TypeConstraint<string>("Tserialized"), \
DeserializeSparseOp<type, string>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS #undef REGISTER_KERNELS
@ -413,7 +498,24 @@ TF_CALL_ALL_TYPES(REGISTER_KERNELS);
REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse") \ REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse") \
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \ .TypeConstraint<type>("dtype"), \
DeserializeSparseOp<type>) DeserializeSparseOp<type, string>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
// Variant specialization of Deserialize: extracts the `Tensor` held by the
// variant element `serialized` and assigns it to `*result`.
//
// Returns InvalidArgument when `serialized` does not hold a `Tensor` (for
// example an empty variant, or a variant produced by an unrelated op).
// Without this check the null pointer returned by `get<Tensor>()` would be
// dereferenced.
template <>
Status DeserializeSparseOpBase<Variant>::Deserialize(const Variant& serialized,
                                                     Tensor* result) {
  // get<Tensor>() yields nullptr when the stored type is not Tensor.
  const Tensor* tensor = serialized.get<Tensor>();
  if (tensor == nullptr) {
    return errors::InvalidArgument(
        "Could not get a tensor from serialized variant");
  }
  *result = *tensor;
  return Status::OK();
}
// Registers a CPU "DeserializeSparse" kernel for value type `type` (attr
// "dtype") with the `Tserialized` attr constrained to `variant`, implemented
// by DeserializeSparseOp<type, Variant>.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype") \
.TypeConstraint<Variant>("Tserialized"), \
DeserializeSparseOp<type, Variant>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS #undef REGISTER_KERNELS

View File

@ -190,7 +190,8 @@ REGISTER_OP("SerializeSparse")
.Input("sparse_values: T") .Input("sparse_values: T")
.Input("sparse_shape: int64") .Input("sparse_shape: int64")
.Attr("T: type") .Attr("T: type")
.Output("serialized_sparse: string") .Output("serialized_sparse: out_type")
.Attr("out_type: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
ShapeHandle unused; ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
@ -200,11 +201,13 @@ REGISTER_OP("SerializeSparse")
return Status::OK(); return Status::OK();
}) })
.Doc(R"doc( .Doc(R"doc(
Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object. Serialize a `SparseTensor` into a `[3]` `Tensor` object.
sparse_indices: 2-D. The `indices` of the `SparseTensor`. sparse_indices: 2-D. The `indices` of the `SparseTensor`.
sparse_values: 1-D. The `values` of the `SparseTensor`. sparse_values: 1-D. The `values` of the `SparseTensor`.
sparse_shape: 1-D. The `shape` of the `SparseTensor`. sparse_shape: 1-D. The `shape` of the `SparseTensor`.
out_type: The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
)doc"); )doc");
REGISTER_OP("SerializeManySparse") REGISTER_OP("SerializeManySparse")
@ -212,7 +215,8 @@ REGISTER_OP("SerializeManySparse")
.Input("sparse_values: T") .Input("sparse_values: T")
.Input("sparse_shape: int64") .Input("sparse_shape: int64")
.Attr("T: type") .Attr("T: type")
.Output("serialized_sparse: string") .Output("serialized_sparse: out_type")
.Attr("out_type: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
ShapeHandle unused; ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
@ -222,7 +226,7 @@ REGISTER_OP("SerializeManySparse")
return Status::OK(); return Status::OK();
}) })
.Doc(R"doc( .Doc(R"doc(
Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`. Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object.
The `SparseTensor` must have rank `R` greater than 1, and the first dimension The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor` is treated as the minibatch dimension. Elements of the `SparseTensor`
@ -235,14 +239,17 @@ The minibatch size `N` is extracted from `sparse_shape[0]`.
sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`.
sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. sparse_values: 1-D. The `values` of the minibatch `SparseTensor`.
sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`.
out_type: The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
)doc"); )doc");
REGISTER_OP("DeserializeSparse") REGISTER_OP("DeserializeSparse")
.Input("serialized_sparse: string") .Input("serialized_sparse: Tserialized")
.Attr("dtype: type")
.Output("sparse_indices: int64") .Output("sparse_indices: int64")
.Output("sparse_values: dtype") .Output("sparse_values: dtype")
.Output("sparse_shape: int64") .Output("sparse_shape: int64")
.Attr("dtype: type")
.Attr("Tserialized: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?, ..., ?, 3] vector. // serialized sparse is [?, ..., ?, 3] vector.
DimensionHandle unused; DimensionHandle unused;
@ -305,10 +312,10 @@ dtype: The `dtype` of the serialized `SparseTensor` objects.
REGISTER_OP("DeserializeManySparse") REGISTER_OP("DeserializeManySparse")
.Input("serialized_sparse: string") .Input("serialized_sparse: string")
.Attr("dtype: type")
.Output("sparse_indices: int64") .Output("sparse_indices: int64")
.Output("sparse_values: dtype") .Output("sparse_values: dtype")
.Output("sparse_shape: int64") .Output("sparse_shape: int64")
.Attr("dtype: type")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?,3] matrix. // serialized sparse is [?,3] matrix.
ShapeHandle serialized_sparse; ShapeHandle serialized_sparse;

View File

@ -64,12 +64,14 @@ class SerializeSparseTest(test.TestCase):
shape = np.array([3, 4, 5]).astype(np.int64) shape = np.array([3, 4, 5]).astype(np.int64)
return sparse_tensor_lib.SparseTensorValue(ind, val, shape) return sparse_tensor_lib.SparseTensorValue(ind, val, shape)
def testSerializeDeserialize(self): def _testSerializeDeserializeHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6)) sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input) serialized = serialize_fn(sp_input, out_type=out_type)
sp_deserialized = sparse_ops.deserialize_sparse( sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)
serialized, dtype=dtypes.int32)
indices, values, shape = sess.run(sp_deserialized) indices, values, shape = sess.run(sp_deserialized)
@ -77,14 +79,25 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(values, sp_input[1]) self.assertAllEqual(values, sp_input[1])
self.assertAllEqual(shape, sp_input[2]) self.assertAllEqual(shape, sp_input[2])
def testSerializeDeserializeBatch(self): def testSerializeDeserialize(self):
self._testSerializeDeserializeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testVariantSerializeDeserialize(self):
  # Same single-tensor round trip as the string test, but serialized with
  # out_type=variant.
  self._testSerializeDeserializeHelper(
      serialize_fn=sparse_ops.serialize_sparse,
      deserialize_fn=sparse_ops.deserialize_sparse,
      out_type=dtypes.variant)
def _testSerializeDeserializeBatchHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6)) sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input) serialized = serialize_fn(sp_input, out_type=out_type)
serialized = array_ops.stack([serialized, serialized]) serialized = array_ops.stack([serialized, serialized])
sp_deserialized = sparse_ops.deserialize_sparse( sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)
serialized, dtype=dtypes.int32)
combined_indices, combined_values, combined_shape = sess.run( combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized) sp_deserialized)
@ -97,16 +110,29 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], sp_input[1]) self.assertAllEqual(combined_values[6:], sp_input[1])
self.assertAllEqual(combined_shape, [2, 5, 6]) self.assertAllEqual(combined_shape, [2, 5, 6])
def testSerializeDeserializeBatchInconsistentShape(self): def testSerializeDeserializeBatch(self):
self._testSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testSerializeDeserializeManyBatch(self):
  # Stacked string serializations are read back with deserialize_many_sparse;
  # out_type is left at the helper's default.
  self._testSerializeDeserializeBatchHelper(
      serialize_fn=sparse_ops.serialize_sparse,
      deserialize_fn=sparse_ops.deserialize_many_sparse)
def testVariantSerializeDeserializeBatch(self):
  # Batch round trip (two stacked copies of one tensor) using variant
  # serialization.
  self._testSerializeDeserializeBatchHelper(
      serialize_fn=sparse_ops.serialize_sparse,
      deserialize_fn=sparse_ops.deserialize_sparse,
      out_type=dtypes.variant)
def _testSerializeDeserializeBatchInconsistentShapeHelper(
self, serialize_fn, deserialize_fn, out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorValue_5x6(np.arange(6)) sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
sp_input1 = self._SparseTensorValue_3x4(np.arange(6)) sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0) serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = sparse_ops.serialize_sparse(sp_input1) serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized = array_ops.stack([serialized0, serialized1]) serialized = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_sparse( sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)
serialized, dtype=dtypes.int32)
combined_indices, combined_values, combined_shape = sess.run( combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized) sp_deserialized)
@ -119,15 +145,26 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], sp_input1[1]) self.assertAllEqual(combined_values[6:], sp_input1[1])
self.assertAllEqual(combined_shape, [2, 5, 6]) self.assertAllEqual(combined_shape, [2, 5, 6])
def testSerializeDeserializeNestedBatch(self): def testSerializeDeserializeBatchInconsistentShape(self):
self._testSerializeDeserializeBatchInconsistentShapeHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)
def testVariantSerializeDeserializeBatchInconsistentShape(self):
  # Batch round trip of a 5x6 and a 3x4 tensor using variant serialization.
  self._testSerializeDeserializeBatchInconsistentShapeHelper(
      serialize_fn=sparse_ops.serialize_sparse,
      deserialize_fn=sparse_ops.deserialize_sparse,
      out_type=dtypes.variant)
def _testSerializeDeserializeNestedBatchHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6)) sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input) serialized = serialize_fn(sp_input, out_type=out_type)
serialized = array_ops.stack([serialized, serialized]) serialized = array_ops.stack([serialized, serialized])
serialized = array_ops.stack([serialized, serialized]) serialized = array_ops.stack([serialized, serialized])
sp_deserialized = sparse_ops.deserialize_sparse( sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)
serialized, dtype=dtypes.int32)
combined_indices, combined_values, combined_shape = sess.run( combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized) sp_deserialized)
@ -151,40 +188,29 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_shape, [2, 2, 5, 6]) self.assertAllEqual(combined_shape, [2, 2, 5, 6])
def testSerializeDeserializeMany(self): def testSerializeDeserializeNestedBatch(self):
with self.test_session(use_gpu=False) as sess: self._testSerializeDeserializeNestedBatchHelper(
sp_input0 = self._SparseTensorValue_5x6(np.arange(6)) sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)
sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse( def testVariantSerializeDeserializeNestedBatch(self):
serialized_concat, dtype=dtypes.int32) self._testSerializeDeserializeNestedBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)
combined_indices, combined_values, combined_shape = sess.run( def _testFeedSerializeDeserializeBatchHelper(self,
sp_deserialized) serialize_fn,
deserialize_fn,
self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0 out_type=dtypes.string):
self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0])
self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1
self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0])
self.assertAllEqual(combined_values[:6], sp_input0[1])
self.assertAllEqual(combined_values[6:], sp_input1[1])
self.assertAllEqual(combined_shape, [2, 5, 6])
def testFeedSerializeDeserializeMany(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder() sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder() sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6)) input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_3x4(np.arange(6)) input1_val = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0) serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = sparse_ops.serialize_sparse(sp_input1) serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1]) serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse( sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)
serialized_concat, dtype=dtypes.int32)
combined_indices, combined_values, combined_shape = sess.run( combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized, {sp_input0: input0_val, sp_deserialized, {sp_input0: input0_val,
@ -198,40 +224,96 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], input1_val[1]) self.assertAllEqual(combined_values[6:], input1_val[1])
self.assertAllEqual(combined_shape, [2, 5, 6]) self.assertAllEqual(combined_shape, [2, 5, 6])
def testSerializeManyDeserializeManyRoundTrip(self): def testFeedSerializeDeserializeBatch(self):
self._testFeedSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testFeedSerializeDeserializeManyBatch(self):
self._testFeedSerializeDeserializeBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
def testFeedVariantSerializeDeserializeBatch(self):
self._testFeedSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)
def _testSerializeManyShapeHelper(self,
serialize_many_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
# N == 4 because shape_value == [4, 5] # N == 4 because shape_value == [4, 5]
indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64) indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
values_value = np.array([b"a", b"b", b"c"]) values_value = np.array([b"a", b"b", b"c"])
shape_value = np.array([4, 5], dtype=np.int64) shape_value = np.array([4, 5], dtype=np.int64)
sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string) sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string)
serialized = sparse_ops.serialize_many_sparse(sparse_tensor) serialized = serialize_many_fn(sparse_tensor, out_type=out_type)
deserialized = sparse_ops.deserialize_many_sparse( serialized_value = sess.run(
serialized, dtype=dtypes.string) serialized,
serialized_value, deserialized_value = sess.run(
[serialized, deserialized],
feed_dict={ feed_dict={
sparse_tensor.indices: indices_value, sparse_tensor.indices: indices_value,
sparse_tensor.values: values_value, sparse_tensor.values: values_value,
sparse_tensor.dense_shape: shape_value sparse_tensor.dense_shape: shape_value
}) })
self.assertEqual(serialized_value.shape, (4, 3)) self.assertEqual(serialized_value.shape, (4, 3))
def testSerializeManyShape(self):
self._testSerializeManyShapeHelper(sparse_ops.serialize_many_sparse)
def testVariantSerializeManyShape(self):
# NOTE: The following test is a no-op as it is currently not possible to
# convert the serialized variant value to a numpy value.
pass
def _testSerializeManyDeserializeBatchHelper(self,
serialize_many_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
# N == 4 because shape_value == [4, 5]
indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
values_value = np.array([b"a", b"b", b"c"])
shape_value = np.array([4, 5], dtype=np.int64)
sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string)
serialized = serialize_many_fn(sparse_tensor, out_type=out_type)
deserialized = deserialize_fn(serialized, dtype=dtypes.string)
deserialized_value = sess.run(
deserialized,
feed_dict={
sparse_tensor.indices: indices_value,
sparse_tensor.values: values_value,
sparse_tensor.dense_shape: shape_value
})
self.assertAllEqual(deserialized_value.indices, indices_value) self.assertAllEqual(deserialized_value.indices, indices_value)
self.assertAllEqual(deserialized_value.values, values_value) self.assertAllEqual(deserialized_value.values, values_value)
self.assertAllEqual(deserialized_value.dense_shape, shape_value) self.assertAllEqual(deserialized_value.dense_shape, shape_value)
def testDeserializeFailsWrongType(self): def testSerializeManyDeserializeBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse)
def testSerializeManyDeserializeManyBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_many_sparse)
def testVariantSerializeManyDeserializeBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)
def _testDeserializeFailsWrongTypeHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder() sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder() sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6)) input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_3x4(np.arange(6)) input1_val = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0) serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = sparse_ops.serialize_sparse(sp_input1) serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1]) serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse( sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int64)
serialized_concat, dtype=dtypes.int64)
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"Requested SparseTensor of type int64 but " r"Requested SparseTensor of type int64 but "
@ -240,18 +322,33 @@ class SerializeSparseTest(test.TestCase):
{sp_input0: input0_val, {sp_input0: input0_val,
sp_input1: input1_val}) sp_input1: input1_val})
def testDeserializeFailsInconsistentRank(self): def testDeserializeFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testDeserializeManyFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
def testVariantDeserializeFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)
def _testDeserializeFailsInconsistentRankHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder() sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder() sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6)) input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_1x1x1() input1_val = self._SparseTensorValue_1x1x1()
serialized0 = sparse_ops.serialize_sparse(sp_input0) serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = sparse_ops.serialize_sparse(sp_input1) serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1]) serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse( sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)
serialized_concat, dtype=dtypes.int32)
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"Inconsistent shape across SparseTensors: rank prior to " r"Inconsistent shape across SparseTensors: rank prior to "
@ -260,21 +357,43 @@ class SerializeSparseTest(test.TestCase):
{sp_input0: input0_val, {sp_input0: input0_val,
sp_input1: input1_val}) sp_input1: input1_val})
def testDeserializeFailsInvalidProto(self): def testDeserializeFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)
def testDeserializeManyFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
def testVariantDeserializeFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)
def _testDeserializeFailsInvalidProtoHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder() sp_input0 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6)) input0_val = self._SparseTensorValue_5x6(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0) serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = ["a", "b", "c"] serialized1 = ["a", "b", "c"]
serialized_concat = array_ops.stack([serialized0, serialized1]) serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse( sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)
serialized_concat, dtype=dtypes.int32)
with self.assertRaisesOpError( with self.assertRaisesOpError(r"Could not parse serialized proto"):
r"Could not parse serialized_sparse\[1, 0\]"):
sess.run(sp_deserialized, {sp_input0: input0_val}) sess.run(sp_deserialized, {sp_input0: input0_val})
def testDeserializeFailsInvalidProto(self):
self._testDeserializeFailsInvalidProtoHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testDeserializeManyFailsInvalidProto(self):
self._testDeserializeFailsInvalidProtoHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -354,8 +354,8 @@ DestroyTemporaryVariable
AddSparseToTensorsMap AddSparseToTensorsMap
AddManySparseToTensorsMap AddManySparseToTensorsMap
TakeManySparseFromTensorsMap TakeManySparseFromTensorsMap
DeserializeSparse
DeserializeManySparse DeserializeManySparse
DeserializeSparse
SerializeManySparse SerializeManySparse
SerializeSparse SerializeSparse
SparseAdd SparseAdd

View File

@ -1385,16 +1385,17 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
empty_row_indicator) empty_row_indicator)
def serialize_sparse(sp_input, name=None): def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object. """Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
Args: Args:
sp_input: The input `SparseTensor`. sp_input: The input `SparseTensor`.
name: A name prefix for the returned tensors (optional). name: A name prefix for the returned tensors (optional).
out_type: The `dtype` to use for serialization.
Returns: Returns:
A string 3-vector (1D `Tensor`), with each column representing the A 3-vector (1-D `Tensor`), with each column representing the serialized
serialized `SparseTensor`'s indices, values, and shape (respectively). `SparseTensor`'s indices, values, and shape (respectively).
Raises: Raises:
TypeError: If `sp_input` is not a `SparseTensor`. TypeError: If `sp_input` is not a `SparseTensor`.
@ -1402,11 +1403,15 @@ def serialize_sparse(sp_input, name=None):
sp_input = _convert_to_sparse_tensor(sp_input) sp_input = _convert_to_sparse_tensor(sp_input)
return gen_sparse_ops._serialize_sparse( return gen_sparse_ops._serialize_sparse(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name) sp_input.indices,
sp_input.values,
sp_input.dense_shape,
name=name,
out_type=out_type)
def serialize_many_sparse(sp_input, name=None): def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`. """Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
The `SparseTensor` must have rank `R` greater than 1, and the first dimension The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor` is treated as the minibatch dimension. Elements of the `SparseTensor`
@ -1419,11 +1424,12 @@ def serialize_many_sparse(sp_input, name=None):
Args: Args:
sp_input: The input rank `R` `SparseTensor`. sp_input: The input rank `R` `SparseTensor`.
name: A name prefix for the returned tensors (optional). name: A name prefix for the returned tensors (optional).
out_type: The `dtype` to use for serialization.
Returns: Returns:
A string matrix (2-D `Tensor`) with `N` rows and `3` columns. A matrix (2-D `Tensor`) with `N` rows and `3` columns. Each column
Each column represents serialized `SparseTensor`'s indices, values, and represents serialized `SparseTensor`'s indices, values, and shape
shape (respectively). (respectively).
Raises: Raises:
TypeError: If `sp_input` is not a `SparseTensor`. TypeError: If `sp_input` is not a `SparseTensor`.
@ -1431,7 +1437,11 @@ def serialize_many_sparse(sp_input, name=None):
sp_input = _convert_to_sparse_tensor(sp_input) sp_input = _convert_to_sparse_tensor(sp_input)
return gen_sparse_ops._serialize_many_sparse( return gen_sparse_ops._serialize_many_sparse(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name) sp_input.indices,
sp_input.values,
sp_input.dense_shape,
name=name,
out_type=out_type)
def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None): def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):

View File

@ -1710,11 +1710,11 @@ tf_module {
} }
member_method { member_method {
name: "serialize_many_sparse" name: "serialize_many_sparse"
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
} }
member_method { member_method {
name: "serialize_sparse" name: "serialize_sparse"
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
} }
member_method { member_method {
name: "serialize_tensor" name: "serialize_tensor"