Adding variant-based serialization and deserialization for sparse tensors.

PiperOrigin-RevId: 177971801
This commit is contained in:
Jiri Simsa 2017-12-05 10:13:41 -08:00 committed by TensorFlower Gardener
parent 77b60c1ac6
commit f88cd91955
10 changed files with 442 additions and 187 deletions

View File

@ -18,7 +18,14 @@ END
1-D. The `shape` of the minibatch `SparseTensor`.
END
}
summary: "Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`."
attr {
name: "out_type"
description: <<END
The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
END
}
summary: "Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object."
description: <<END
The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor`

View File

@ -18,5 +18,12 @@ END
1-D. The `shape` of the `SparseTensor`.
END
}
summary: "Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object."
attr {
name: "out_type"
description: <<END
The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
END
}
summary: "Serialize a `SparseTensor` into a `[3]` `Tensor` object."
}

View File

@ -73,12 +73,14 @@ REGISTER(quint16)
REGISTER(qint16)
REGISTER(qint32)
REGISTER(bfloat16)
TF_CALL_variant(REGISTER)
#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
!defined(__ANDROID_TYPES_FULL__)
// Primarily used for SavedModel support on mobile. Registering it here only if
// __ANDROID_TYPES_FULL__ is not defined, as that already register strings
REGISTER(string);
// Primarily used for SavedModel support on mobile. Registering it here only
// if __ANDROID_TYPES_FULL__ is not defined (which already registers string)
// to avoid duplicate registration.
REGISTER(string);
#endif // defined(IS_MOBILE_PLATFORM) &&
// !defined(SUPPORT_SELECTIVE_REGISTRATION) &&
// !defined(__ANDROID_TYPES_FULL__)

View File

@ -140,6 +140,7 @@ class PackOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);
TF_CALL_bfloat16(REGISTER_PACK);
TF_CALL_variant(REGISTER_PACK);
#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
// Primarily used for SavedModel support on mobile.

View File

@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/reshape_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
@ -35,15 +37,20 @@ namespace tensorflow {
using sparse::SparseTensor;
template <typename T>
class SerializeSparseOp : public OpKernel {
public:
explicit SerializeSparseOp(OpKernelConstruction* context)
: OpKernel(context) {}
Status Initialize(Tensor* result);
Status Serialize(const Tensor& input, T* result);
void Compute(OpKernelContext* context) override {
const Tensor* input_indices;
const Tensor* input_values;
const Tensor* input_shape;
OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
@ -62,33 +69,74 @@ class SerializeSparseOp : public OpKernel {
"Input shape should be a vector but received shape ",
input_shape->shape().DebugString()));
TensorProto proto_indices;
TensorProto proto_values;
TensorProto proto_shape;
Tensor serialized_sparse;
OP_REQUIRES_OK(context, Initialize(&serialized_sparse));
input_indices->AsProtoTensorContent(&proto_indices);
input_values->AsProtoTensorContent(&proto_values);
input_shape->AsProtoTensorContent(&proto_shape);
Tensor serialized_sparse(DT_STRING, TensorShape({3}));
auto serialized_sparse_t = serialized_sparse.vec<string>();
serialized_sparse_t(0) = proto_indices.SerializeAsString();
serialized_sparse_t(1) = proto_values.SerializeAsString();
serialized_sparse_t(2) = proto_shape.SerializeAsString();
auto serialized_sparse_t = serialized_sparse.vec<T>();
OP_REQUIRES_OK(context, Serialize(*input_indices, &serialized_sparse_t(0)));
OP_REQUIRES_OK(context, Serialize(*input_values, &serialized_sparse_t(1)));
OP_REQUIRES_OK(context, Serialize(*input_shape, &serialized_sparse_t(2)));
context->set_output(0, serialized_sparse);
}
};
REGISTER_KERNEL_BUILDER(Name("SerializeSparse").Device(DEVICE_CPU),
SerializeSparseOp);
template <>
Status SerializeSparseOp<string>::Initialize(Tensor* result) {
*result = Tensor(DT_STRING, TensorShape({3}));
return Status::OK();
}
template <>
Status SerializeSparseOp<string>::Serialize(const Tensor& input,
string* result) {
TensorProto proto;
input.AsProtoTensorContent(&proto);
*result = proto.SerializeAsString();
return Status::OK();
}
REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
.Device(DEVICE_CPU)
.TypeConstraint<string>("out_type"),
SerializeSparseOp<string>);
template <>
Status SerializeSparseOp<Variant>::Initialize(Tensor* result) {
*result = Tensor(DT_VARIANT, TensorShape({3}));
return Status::OK();
}
template <>
Status SerializeSparseOp<Variant>::Serialize(const Tensor& input,
Variant* result) {
*result = input;
return Status::OK();
}
REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
.Device(DEVICE_CPU)
.TypeConstraint<Variant>("out_type"),
SerializeSparseOp<Variant>);
template <typename T>
class SerializeManySparseOp : public OpKernel {
class SerializeManySparseOpBase : public OpKernel {
public:
explicit SerializeManySparseOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {}
protected:
Status Initialize(const int64 n, Tensor* result);
Status Serialize(const Tensor& input, T* result);
};
template <typename T, typename U>
class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
public:
explicit SerializeManySparseOp(OpKernelConstruction* context)
: OpKernel(context) {}
: SerializeManySparseOpBase<U>(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* input_indices;
@ -127,37 +175,31 @@ class SerializeManySparseOp : public OpKernel {
auto input_shape_t = input_shape->vec<int64>();
const int64 N = input_shape_t(0);
Tensor serialized_sparse(DT_STRING, TensorShape({N, 3}));
auto serialized_sparse_t = serialized_sparse.matrix<string>();
Tensor serialized_sparse;
OP_REQUIRES_OK(context, this->Initialize(N, &serialized_sparse));
auto serialized_sparse_t = serialized_sparse.matrix<U>();
OP_REQUIRES_OK(context, input_st.IndicesValid());
// We can generate the output shape proto string now, for all
// minibatch entries.
// Initialize output with empty values and the proper shapes.
Tensor output_blank_indices(DT_INT64, {0, rank - 1});
U serialized_indices;
OP_REQUIRES_OK(context,
this->Serialize(output_blank_indices, &serialized_indices));
serialized_sparse_t.template chip<1>(0).setConstant(serialized_indices);
Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
U serialized_values;
OP_REQUIRES_OK(context,
this->Serialize(output_blank_values, &serialized_values));
serialized_sparse_t.template chip<1>(1).setConstant(serialized_values);
Tensor output_shape(DT_INT64, {rank - 1});
auto output_shape_t = output_shape.vec<int64>();
for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);
TensorProto proto_shape;
output_shape.AsProtoTensorContent(&proto_shape);
const string proto_shape_string = proto_shape.SerializeAsString();
Tensor output_blank_indices(DT_INT64, {0, rank - 1});
Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
TensorProto proto_blank_indices;
TensorProto proto_blank_values;
output_blank_indices.AsProtoTensorContent(&proto_blank_indices);
output_blank_values.AsProtoTensorContent(&proto_blank_values);
const string proto_blank_indices_string =
proto_blank_indices.SerializeAsString();
const string proto_blank_values_string =
proto_blank_values.SerializeAsString();
// Initialize output with empty values and the proper shapes.
serialized_sparse_t.chip<1>(0).setConstant(proto_blank_indices_string);
serialized_sparse_t.chip<1>(1).setConstant(proto_blank_values_string);
serialized_sparse_t.chip<1>(2).setConstant(proto_shape_string);
U serialized_shape;
OP_REQUIRES_OK(context, this->Serialize(output_shape, &serialized_shape));
serialized_sparse_t.template chip<1>(2).setConstant(serialized_shape);
// Get groups by minibatch dimension
sparse::GroupIterable minibatch = input_st.group({0});
@ -186,33 +228,83 @@ class SerializeManySparseOp : public OpKernel {
output_values_t(i) = values(i);
}
TensorProto proto_indices;
TensorProto proto_values;
output_indices.AsProtoTensorContent(&proto_indices);
output_values.AsProtoTensorContent(&proto_values);
serialized_sparse_t(b, 0) = proto_indices.SerializeAsString();
serialized_sparse_t(b, 1) = proto_values.SerializeAsString();
OP_REQUIRES_OK(
context, this->Serialize(output_indices, &serialized_sparse_t(b, 0)));
OP_REQUIRES_OK(
context, this->Serialize(output_values, &serialized_sparse_t(b, 1)));
}
context->set_output(0, serialized_sparse);
}
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
SerializeManySparseOp<type>)
template <>
Status SerializeManySparseOpBase<string>::Initialize(const int64 n,
Tensor* result) {
*result = Tensor(DT_STRING, TensorShape({n, 3}));
return Status::OK();
}
template <>
Status SerializeManySparseOpBase<string>::Serialize(const Tensor& input,
string* result) {
TensorProto proto;
input.AsProtoTensorContent(&proto);
*result = proto.SerializeAsString();
return Status::OK();
}
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<string>("out_type"), \
SerializeManySparseOp<type, string>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
template <>
Status SerializeManySparseOpBase<Variant>::Initialize(const int64 n,
Tensor* result) {
*result = Tensor(DT_VARIANT, TensorShape({n, 3}));
return Status::OK();
}
template <>
Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
Variant* result) {
*result = input;
return Status::OK();
}
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<Variant>("out_type"), \
SerializeManySparseOp<type, Variant>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
template <typename T>
class DeserializeSparseOp : public OpKernel {
class DeserializeSparseOpBase : public OpKernel {
public:
explicit DeserializeSparseOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {}
protected:
Status Deserialize(const T& serialized, Tensor* result);
};
template <typename T, typename U>
class DeserializeSparseOp : public DeserializeSparseOpBase<U> {
public:
explicit DeserializeSparseOp(OpKernelConstruction* context)
: OpKernel(context) {}
: DeserializeSparseOpBase<U>(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& serialized_sparse = context->input(0);
@ -246,53 +338,30 @@ class DeserializeSparseOp : public OpKernel {
indices.reserve(num_sparse_tensors);
values.reserve(num_sparse_tensors);
const auto& serialized_sparse_t =
serialized_sparse.flat_inner_dims<string, 2>();
const auto& serialized_sparse_t = serialized_sparse.flat_inner_dims<U, 2>();
for (int i = 0; i < num_sparse_tensors; ++i) {
Tensor output_indices(DT_INT64);
Tensor output_values(DataTypeToEnum<T>::value);
Tensor output_shape(DT_INT64);
TensorProto proto_indices;
TensorProto proto_values;
TensorProto proto_shape;
OP_REQUIRES(
context,
ParseProtoUnlimited(&proto_indices, serialized_sparse_t(i, 0)),
errors::InvalidArgument("Could not parse serialized_sparse[", i,
", 0]"));
OP_REQUIRES(context,
ParseProtoUnlimited(&proto_values, serialized_sparse_t(i, 1)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 1]"));
OP_REQUIRES(context,
ParseProtoUnlimited(&proto_shape, serialized_sparse_t(i, 2)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 2]"));
OP_REQUIRES(context, output_indices.FromProto(proto_indices),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 0] (indices)"));
Tensor output_indices;
OP_REQUIRES_OK(context, this->Deserialize(serialized_sparse_t(i, 0),
&output_indices));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
errors::InvalidArgument(
"Expected serialized_sparse[", i,
", 0] to represent an index matrix but received shape ",
output_indices.shape().DebugString()));
OP_REQUIRES(context, output_values.FromProto(proto_values),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 1] (values)"));
Tensor output_values;
OP_REQUIRES_OK(context, this->Deserialize(serialized_sparse_t(i, 1),
&output_values));
OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
errors::InvalidArgument(
"Expected serialized_sparse[", i,
", 1] to represent a values vector but received shape ",
output_values.shape().DebugString()));
OP_REQUIRES(context, output_shape.FromProto(proto_shape),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 2] (shape)"));
Tensor output_shape;
OP_REQUIRES_OK(
context, this->Deserialize(serialized_sparse_t(i, 2), &output_shape));
OP_REQUIRES(
context, TensorShapeUtils::IsVector(output_shape.shape()),
errors::InvalidArgument("Expected serialized_sparse[", i,
@ -400,11 +469,27 @@ class DeserializeSparseOp : public OpKernel {
}
};
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
DeserializeSparseOp<type>)
template <>
Status DeserializeSparseOpBase<string>::Deserialize(const string& serialized,
Tensor* result) {
TensorProto proto;
if (!ParseProtoUnlimited(&proto, serialized)) {
return errors::InvalidArgument("Could not parse serialized proto");
}
Tensor tensor;
if (!tensor.FromProto(proto)) {
return errors::InvalidArgument("Could not construct tensor from proto");
}
*result = tensor;
return Status::OK();
}
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype") \
.TypeConstraint<string>("Tserialized"), \
DeserializeSparseOp<type, string>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
@ -413,7 +498,24 @@ TF_CALL_ALL_TYPES(REGISTER_KERNELS);
REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
DeserializeSparseOp<type>)
DeserializeSparseOp<type, string>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
template <>
Status DeserializeSparseOpBase<Variant>::Deserialize(const Variant& serialized,
Tensor* result) {
*result = *serialized.get<Tensor>();
return Status::OK();
}
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype") \
.TypeConstraint<Variant>("Tserialized"), \
DeserializeSparseOp<type, Variant>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

View File

@ -190,7 +190,8 @@ REGISTER_OP("SerializeSparse")
.Input("sparse_values: T")
.Input("sparse_shape: int64")
.Attr("T: type")
.Output("serialized_sparse: string")
.Output("serialized_sparse: out_type")
.Attr("out_type: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
@ -200,11 +201,13 @@ REGISTER_OP("SerializeSparse")
return Status::OK();
})
.Doc(R"doc(
Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object.
Serialize a `SparseTensor` into a `[3]` `Tensor` object.
sparse_indices: 2-D. The `indices` of the `SparseTensor`.
sparse_values: 1-D. The `values` of the `SparseTensor`.
sparse_shape: 1-D. The `shape` of the `SparseTensor`.
out_type: The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
)doc");
REGISTER_OP("SerializeManySparse")
@ -212,7 +215,8 @@ REGISTER_OP("SerializeManySparse")
.Input("sparse_values: T")
.Input("sparse_shape: int64")
.Attr("T: type")
.Output("serialized_sparse: string")
.Output("serialized_sparse: out_type")
.Attr("out_type: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
@ -222,7 +226,7 @@ REGISTER_OP("SerializeManySparse")
return Status::OK();
})
.Doc(R"doc(
Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`.
Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object.
The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor`
@ -235,14 +239,17 @@ The minibatch size `N` is extracted from `sparse_shape[0]`.
sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`.
sparse_values: 1-D. The `values` of the minibatch `SparseTensor`.
sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`.
out_type: The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
)doc");
REGISTER_OP("DeserializeSparse")
.Input("serialized_sparse: string")
.Attr("dtype: type")
.Input("serialized_sparse: Tserialized")
.Output("sparse_indices: int64")
.Output("sparse_values: dtype")
.Output("sparse_shape: int64")
.Attr("dtype: type")
.Attr("Tserialized: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?, ..., ?, 3] vector.
DimensionHandle unused;
@ -305,10 +312,10 @@ dtype: The `dtype` of the serialized `SparseTensor` objects.
REGISTER_OP("DeserializeManySparse")
.Input("serialized_sparse: string")
.Attr("dtype: type")
.Output("sparse_indices: int64")
.Output("sparse_values: dtype")
.Output("sparse_shape: int64")
.Attr("dtype: type")
.SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?,3] matrix.
ShapeHandle serialized_sparse;

View File

@ -64,12 +64,14 @@ class SerializeSparseTest(test.TestCase):
shape = np.array([3, 4, 5]).astype(np.int64)
return sparse_tensor_lib.SparseTensorValue(ind, val, shape)
def testSerializeDeserialize(self):
def _testSerializeDeserializeHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input)
sp_deserialized = sparse_ops.deserialize_sparse(
serialized, dtype=dtypes.int32)
serialized = serialize_fn(sp_input, out_type=out_type)
sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)
indices, values, shape = sess.run(sp_deserialized)
@ -77,14 +79,25 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(values, sp_input[1])
self.assertAllEqual(shape, sp_input[2])
def testSerializeDeserializeBatch(self):
def testSerializeDeserialize(self):
self._testSerializeDeserializeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testVariantSerializeDeserialize(self):
self._testSerializeDeserializeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)
def _testSerializeDeserializeBatchHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input)
serialized = serialize_fn(sp_input, out_type=out_type)
serialized = array_ops.stack([serialized, serialized])
sp_deserialized = sparse_ops.deserialize_sparse(
serialized, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)
combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized)
@ -97,16 +110,29 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], sp_input[1])
self.assertAllEqual(combined_shape, [2, 5, 6])
def testSerializeDeserializeBatchInconsistentShape(self):
def testSerializeDeserializeBatch(self):
self._testSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testSerializeDeserializeManyBatch(self):
self._testSerializeDeserializeBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
def testVariantSerializeDeserializeBatch(self):
self._testSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)
def _testSerializeDeserializeBatchInconsistentShapeHelper(
self, serialize_fn, deserialize_fn, out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_sparse(
serialized, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)
combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized)
@ -119,15 +145,26 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], sp_input1[1])
self.assertAllEqual(combined_shape, [2, 5, 6])
def testSerializeDeserializeNestedBatch(self):
def testSerializeDeserializeBatchInconsistentShape(self):
self._testSerializeDeserializeBatchInconsistentShapeHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)
def testVariantSerializeDeserializeBatchInconsistentShape(self):
self._testSerializeDeserializeBatchInconsistentShapeHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)
def _testSerializeDeserializeNestedBatchHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input)
serialized = serialize_fn(sp_input, out_type=out_type)
serialized = array_ops.stack([serialized, serialized])
serialized = array_ops.stack([serialized, serialized])
sp_deserialized = sparse_ops.deserialize_sparse(
serialized, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)
combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized)
@ -151,40 +188,29 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_shape, [2, 2, 5, 6])
def testSerializeDeserializeMany(self):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized_concat = array_ops.stack([serialized0, serialized1])
def testSerializeDeserializeNestedBatch(self):
self._testSerializeDeserializeNestedBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)
sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int32)
def testVariantSerializeDeserializeNestedBatch(self):
self._testSerializeDeserializeNestedBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)
combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized)
self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0
self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0])
self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1
self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0])
self.assertAllEqual(combined_values[:6], sp_input0[1])
self.assertAllEqual(combined_values[6:], sp_input1[1])
self.assertAllEqual(combined_shape, [2, 5, 6])
def testFeedSerializeDeserializeMany(self):
def _testFeedSerializeDeserializeBatchHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)
combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized, {sp_input0: input0_val,
@ -198,40 +224,96 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], input1_val[1])
self.assertAllEqual(combined_shape, [2, 5, 6])
def testSerializeManyDeserializeManyRoundTrip(self):
def testFeedSerializeDeserializeBatch(self):
self._testFeedSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testFeedSerializeDeserializeManyBatch(self):
self._testFeedSerializeDeserializeBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
def testFeedVariantSerializeDeserializeBatch(self):
self._testFeedSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)
def _testSerializeManyShapeHelper(self,
serialize_many_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
# N == 4 because shape_value == [4, 5]
indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
values_value = np.array([b"a", b"b", b"c"])
shape_value = np.array([4, 5], dtype=np.int64)
sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string)
serialized = sparse_ops.serialize_many_sparse(sparse_tensor)
deserialized = sparse_ops.deserialize_many_sparse(
serialized, dtype=dtypes.string)
serialized_value, deserialized_value = sess.run(
[serialized, deserialized],
serialized = serialize_many_fn(sparse_tensor, out_type=out_type)
serialized_value = sess.run(
serialized,
feed_dict={
sparse_tensor.indices: indices_value,
sparse_tensor.values: values_value,
sparse_tensor.dense_shape: shape_value
})
self.assertEqual(serialized_value.shape, (4, 3))
def testSerializeManyShape(self):
self._testSerializeManyShapeHelper(sparse_ops.serialize_many_sparse)
def testVariantSerializeManyShape(self):
# NOTE: The following test is a no-op as it is currently not possible to
# convert the serialized variant value to a numpy value.
pass
def _testSerializeManyDeserializeBatchHelper(self,
serialize_many_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
# N == 4 because shape_value == [4, 5]
indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
values_value = np.array([b"a", b"b", b"c"])
shape_value = np.array([4, 5], dtype=np.int64)
sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string)
serialized = serialize_many_fn(sparse_tensor, out_type=out_type)
deserialized = deserialize_fn(serialized, dtype=dtypes.string)
deserialized_value = sess.run(
deserialized,
feed_dict={
sparse_tensor.indices: indices_value,
sparse_tensor.values: values_value,
sparse_tensor.dense_shape: shape_value
})
self.assertAllEqual(deserialized_value.indices, indices_value)
self.assertAllEqual(deserialized_value.values, values_value)
self.assertAllEqual(deserialized_value.dense_shape, shape_value)
def testDeserializeFailsWrongType(self):
def testSerializeManyDeserializeBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse)
def testSerializeManyDeserializeManyBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_many_sparse)
def testVariantSerializeManyDeserializeBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)
def _testDeserializeFailsWrongTypeHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int64)
sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int64)
with self.assertRaisesOpError(
r"Requested SparseTensor of type int64 but "
@ -240,18 +322,33 @@ class SerializeSparseTest(test.TestCase):
{sp_input0: input0_val,
sp_input1: input1_val})
def testDeserializeFailsInconsistentRank(self):
def testDeserializeFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testDeserializeManyFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
def testVariantDeserializeFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)
def _testDeserializeFailsInconsistentRankHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_1x1x1()
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)
with self.assertRaisesOpError(
r"Inconsistent shape across SparseTensors: rank prior to "
@ -260,21 +357,43 @@ class SerializeSparseTest(test.TestCase):
{sp_input0: input0_val,
sp_input1: input1_val})
def testDeserializeFailsInvalidProto(self):
def testDeserializeFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)
def testDeserializeManyFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
def testVariantDeserializeFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)
def _testDeserializeFailsInvalidProtoHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = ["a", "b", "c"]
serialized_concat = array_ops.stack([serialized0, serialized1])
sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)
with self.assertRaisesOpError(
r"Could not parse serialized_sparse\[1, 0\]"):
with self.assertRaisesOpError(r"Could not parse serialized proto"):
sess.run(sp_deserialized, {sp_input0: input0_val})
def testDeserializeFailsInvalidProto(self):
self._testDeserializeFailsInvalidProtoHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)
def testDeserializeManyFailsInvalidProto(self):
self._testDeserializeFailsInvalidProtoHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)
if __name__ == "__main__":
test.main()

View File

@ -354,8 +354,8 @@ DestroyTemporaryVariable
AddSparseToTensorsMap
AddManySparseToTensorsMap
TakeManySparseFromTensorsMap
DeserializeSparse
DeserializeManySparse
DeserializeSparse
SerializeManySparse
SerializeSparse
SparseAdd

View File

@ -1385,16 +1385,17 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
empty_row_indicator)
def serialize_sparse(sp_input, name=None):
"""Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object.
def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
Args:
sp_input: The input `SparseTensor`.
name: A name prefix for the returned tensors (optional).
out_type: The `dtype` to use for serialization.
Returns:
A string 3-vector (1D `Tensor`), with each column representing the
serialized `SparseTensor`'s indices, values, and shape (respectively).
A 3-vector (1-D `Tensor`), with each column representing the serialized
`SparseTensor`'s indices, values, and shape (respectively).
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
@ -1402,11 +1403,15 @@ def serialize_sparse(sp_input, name=None):
sp_input = _convert_to_sparse_tensor(sp_input)
return gen_sparse_ops._serialize_sparse(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)
sp_input.indices,
sp_input.values,
sp_input.dense_shape,
name=name,
out_type=out_type)
def serialize_many_sparse(sp_input, name=None):
"""Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`.
def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor`
@ -1419,11 +1424,12 @@ def serialize_many_sparse(sp_input, name=None):
Args:
sp_input: The input rank `R` `SparseTensor`.
name: A name prefix for the returned tensors (optional).
out_type: The `dtype` to use for serialization.
Returns:
A string matrix (2-D `Tensor`) with `N` rows and `3` columns.
Each column represents serialized `SparseTensor`'s indices, values, and
shape (respectively).
A matrix (2-D `Tensor`) with `N` rows and `3` columns. Each column
represents serialized `SparseTensor`'s indices, values, and shape
(respectively).
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
@ -1431,7 +1437,11 @@ def serialize_many_sparse(sp_input, name=None):
sp_input = _convert_to_sparse_tensor(sp_input)
return gen_sparse_ops._serialize_many_sparse(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)
sp_input.indices,
sp_input.values,
sp_input.dense_shape,
name=name,
out_type=out_type)
def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):

View File

@ -1710,11 +1710,11 @@ tf_module {
}
member_method {
name: "serialize_many_sparse"
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
}
member_method {
name: "serialize_sparse"
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
}
member_method {
name: "serialize_tensor"