Adding variant-based serialization and deserialization for sparse tensors.
PiperOrigin-RevId: 177971801
Commit f88cd91955 (parent 77b60c1ac6)
@@ -18,7 +18,14 @@ END
1-D. The `shape` of the minibatch `SparseTensor`.
END
}
summary: "Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`."
attr {
name: "out_type"
description: <<END
The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
END
}
summary: "Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object."
description: <<END
The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor`

@@ -18,5 +18,12 @@ END
1-D. The `shape` of the `SparseTensor`.
END
}
summary: "Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object."
attr {
name: "out_type"
description: <<END
The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
END
}
summary: "Serialize a `SparseTensor` into a `[3]` `Tensor` object."
}
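A minimal usage sketch of the new `out_type` attribute described above (not part of the commit; it assumes the TF 1.x session API and the internal module paths already used by the tests in this change):

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import sparse_ops

# Build a small SparseTensor and round-trip it through variant serialization.
st = tf.SparseTensor(indices=[[0, 0], [1, 2]],
                     values=np.array([10, 20], dtype=np.int32),
                     dense_shape=[3, 4])

# The default out_type is string (a 3-vector of serialized TensorProtos);
# `variant` instead keeps the indices/values/shape tensors wrapped in a DT_VARIANT 3-vector.
serialized = sparse_ops.serialize_sparse(st, out_type=dtypes.variant)
deserialized = sparse_ops.deserialize_sparse(serialized, dtype=dtypes.int32)

with tf.Session() as sess:
    indices, values, shape = sess.run(deserialized)
    # indices -> [[0, 0], [1, 2]], values -> [10, 20], shape -> [3, 4]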
@@ -73,12 +73,14 @@ REGISTER(quint16)
REGISTER(qint16)
REGISTER(qint32)
REGISTER(bfloat16)
TF_CALL_variant(REGISTER)

#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
!defined(__ANDROID_TYPES_FULL__)
// Primarily used for SavedModel support on mobile. Registering it here only if
// __ANDROID_TYPES_FULL__ is not defined, as that already register strings
REGISTER(string);
// Primarily used for SavedModel support on mobile. Registering it here only
// if __ANDROID_TYPES_FULL__ is not defined (which already registers string)
// to avoid duplicate registration.
REGISTER(string);
#endif  // defined(IS_MOBILE_PLATFORM) &&
        // !defined(SUPPORT_SELECTIVE_REGISTRATION) &&
        // !defined(__ANDROID_TYPES_FULL__)
@@ -140,6 +140,7 @@ class PackOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);
TF_CALL_bfloat16(REGISTER_PACK);
TF_CALL_variant(REGISTER_PACK);

#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
// Primarily used for SavedModel support on mobile.
@@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/reshape_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

@@ -35,15 +37,20 @@ namespace tensorflow {

using sparse::SparseTensor;

template <typename T>
class SerializeSparseOp : public OpKernel {
public:
explicit SerializeSparseOp(OpKernelConstruction* context)
: OpKernel(context) {}

Status Initialize(Tensor* result);
Status Serialize(const Tensor& input, T* result);

void Compute(OpKernelContext* context) override {
const Tensor* input_indices;
const Tensor* input_values;
const Tensor* input_shape;

OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
@@ -62,33 +69,74 @@ class SerializeSparseOp : public OpKernel {
"Input shape should be a vector but received shape ",
input_shape->shape().DebugString()));

TensorProto proto_indices;
TensorProto proto_values;
TensorProto proto_shape;
Tensor serialized_sparse;
OP_REQUIRES_OK(context, Initialize(&serialized_sparse));

input_indices->AsProtoTensorContent(&proto_indices);
input_values->AsProtoTensorContent(&proto_values);
input_shape->AsProtoTensorContent(&proto_shape);

Tensor serialized_sparse(DT_STRING, TensorShape({3}));
auto serialized_sparse_t = serialized_sparse.vec<string>();

serialized_sparse_t(0) = proto_indices.SerializeAsString();
serialized_sparse_t(1) = proto_values.SerializeAsString();
serialized_sparse_t(2) = proto_shape.SerializeAsString();
auto serialized_sparse_t = serialized_sparse.vec<T>();
OP_REQUIRES_OK(context, Serialize(*input_indices, &serialized_sparse_t(0)));
OP_REQUIRES_OK(context, Serialize(*input_values, &serialized_sparse_t(1)));
OP_REQUIRES_OK(context, Serialize(*input_shape, &serialized_sparse_t(2)));

context->set_output(0, serialized_sparse);
}
};

REGISTER_KERNEL_BUILDER(Name("SerializeSparse").Device(DEVICE_CPU),
SerializeSparseOp);
template <>
Status SerializeSparseOp<string>::Initialize(Tensor* result) {
*result = Tensor(DT_STRING, TensorShape({3}));
return Status::OK();
}

template <>
Status SerializeSparseOp<string>::Serialize(const Tensor& input,
string* result) {
TensorProto proto;
input.AsProtoTensorContent(&proto);
*result = proto.SerializeAsString();
return Status::OK();
}

REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
.Device(DEVICE_CPU)
.TypeConstraint<string>("out_type"),
SerializeSparseOp<string>);

template <>
Status SerializeSparseOp<Variant>::Initialize(Tensor* result) {
*result = Tensor(DT_VARIANT, TensorShape({3}));
return Status::OK();
}

template <>
Status SerializeSparseOp<Variant>::Serialize(const Tensor& input,
Variant* result) {
*result = input;
return Status::OK();
}

REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
.Device(DEVICE_CPU)
.TypeConstraint<Variant>("out_type"),
SerializeSparseOp<Variant>);

template <typename T>
class SerializeManySparseOp : public OpKernel {
class SerializeManySparseOpBase : public OpKernel {
public:
explicit SerializeManySparseOpBase(OpKernelConstruction* context)
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {}

protected:
Status Initialize(const int64 n, Tensor* result);
Status Serialize(const Tensor& input, T* result);
};

template <typename T, typename U>
class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
public:
explicit SerializeManySparseOp(OpKernelConstruction* context)
: OpKernel(context) {}
: SerializeManySparseOpBase<U>(context) {}

void Compute(OpKernelContext* context) override {
const Tensor* input_indices;
@@ -127,37 +175,31 @@ class SerializeManySparseOp : public OpKernel {

auto input_shape_t = input_shape->vec<int64>();
const int64 N = input_shape_t(0);

Tensor serialized_sparse(DT_STRING, TensorShape({N, 3}));
auto serialized_sparse_t = serialized_sparse.matrix<string>();
Tensor serialized_sparse;
OP_REQUIRES_OK(context, this->Initialize(N, &serialized_sparse));
auto serialized_sparse_t = serialized_sparse.matrix<U>();

OP_REQUIRES_OK(context, input_st.IndicesValid());

// We can generate the output shape proto string now, for all
// minibatch entries.
// Initialize output with empty values and the proper shapes.
Tensor output_blank_indices(DT_INT64, {0, rank - 1});
U serialized_indices;
OP_REQUIRES_OK(context,
this->Serialize(output_blank_indices, &serialized_indices));
serialized_sparse_t.template chip<1>(0).setConstant(serialized_indices);

Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
U serialized_values;
OP_REQUIRES_OK(context,
this->Serialize(output_blank_values, &serialized_values));
serialized_sparse_t.template chip<1>(1).setConstant(serialized_values);

Tensor output_shape(DT_INT64, {rank - 1});
auto output_shape_t = output_shape.vec<int64>();
for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);
TensorProto proto_shape;
output_shape.AsProtoTensorContent(&proto_shape);
const string proto_shape_string = proto_shape.SerializeAsString();

Tensor output_blank_indices(DT_INT64, {0, rank - 1});
Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
TensorProto proto_blank_indices;
TensorProto proto_blank_values;
output_blank_indices.AsProtoTensorContent(&proto_blank_indices);
output_blank_values.AsProtoTensorContent(&proto_blank_values);

const string proto_blank_indices_string =
proto_blank_indices.SerializeAsString();
const string proto_blank_values_string =
proto_blank_values.SerializeAsString();

// Initialize output with empty values and the proper shapes.
serialized_sparse_t.chip<1>(0).setConstant(proto_blank_indices_string);
serialized_sparse_t.chip<1>(1).setConstant(proto_blank_values_string);
serialized_sparse_t.chip<1>(2).setConstant(proto_shape_string);
U serialized_shape;
OP_REQUIRES_OK(context, this->Serialize(output_shape, &serialized_shape));
serialized_sparse_t.template chip<1>(2).setConstant(serialized_shape);

// Get groups by minibatch dimension
sparse::GroupIterable minibatch = input_st.group({0});
@@ -186,33 +228,83 @@ class SerializeManySparseOp : public OpKernel {
output_values_t(i) = values(i);
}

TensorProto proto_indices;
TensorProto proto_values;
output_indices.AsProtoTensorContent(&proto_indices);
output_values.AsProtoTensorContent(&proto_values);

serialized_sparse_t(b, 0) = proto_indices.SerializeAsString();
serialized_sparse_t(b, 1) = proto_values.SerializeAsString();
OP_REQUIRES_OK(
context, this->Serialize(output_indices, &serialized_sparse_t(b, 0)));
OP_REQUIRES_OK(
context, this->Serialize(output_values, &serialized_sparse_t(b, 1)));
}

context->set_output(0, serialized_sparse);
}
};

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
SerializeManySparseOp<type>)
template <>
Status SerializeManySparseOpBase<string>::Initialize(const int64 n,
Tensor* result) {
*result = Tensor(DT_STRING, TensorShape({n, 3}));
return Status::OK();
}

template <>
Status SerializeManySparseOpBase<string>::Serialize(const Tensor& input,
string* result) {
TensorProto proto;
input.AsProtoTensorContent(&proto);
*result = proto.SerializeAsString();
return Status::OK();
}

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<string>("out_type"), \
SerializeManySparseOp<type, string>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

template <>
Status SerializeManySparseOpBase<Variant>::Initialize(const int64 n,
Tensor* result) {
*result = Tensor(DT_VARIANT, TensorShape({n, 3}));
return Status::OK();
}

template <>
Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
Variant* result) {
*result = input;
return Status::OK();
}

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<Variant>("out_type"), \
SerializeManySparseOp<type, Variant>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

template <typename T>
class DeserializeSparseOp : public OpKernel {
class DeserializeSparseOpBase : public OpKernel {
public:
explicit DeserializeSparseOpBase(OpKernelConstruction* context)
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {}

protected:
Status Deserialize(const T& serialized, Tensor* result);
};

template <typename T, typename U>
class DeserializeSparseOp : public DeserializeSparseOpBase<U> {
public:
explicit DeserializeSparseOp(OpKernelConstruction* context)
: OpKernel(context) {}
: DeserializeSparseOpBase<U>(context) {}

void Compute(OpKernelContext* context) override {
const Tensor& serialized_sparse = context->input(0);
@@ -246,53 +338,30 @@ class DeserializeSparseOp : public OpKernel {
indices.reserve(num_sparse_tensors);
values.reserve(num_sparse_tensors);

const auto& serialized_sparse_t =
serialized_sparse.flat_inner_dims<string, 2>();
const auto& serialized_sparse_t = serialized_sparse.flat_inner_dims<U, 2>();

for (int i = 0; i < num_sparse_tensors; ++i) {
Tensor output_indices(DT_INT64);
Tensor output_values(DataTypeToEnum<T>::value);
Tensor output_shape(DT_INT64);
TensorProto proto_indices;
TensorProto proto_values;
TensorProto proto_shape;

OP_REQUIRES(
context,
ParseProtoUnlimited(&proto_indices, serialized_sparse_t(i, 0)),
errors::InvalidArgument("Could not parse serialized_sparse[", i,
", 0]"));
OP_REQUIRES(context,
ParseProtoUnlimited(&proto_values, serialized_sparse_t(i, 1)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 1]"));
OP_REQUIRES(context,
ParseProtoUnlimited(&proto_shape, serialized_sparse_t(i, 2)),
errors::InvalidArgument("Could not parse serialized_sparse[",
i, ", 2]"));

OP_REQUIRES(context, output_indices.FromProto(proto_indices),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 0] (indices)"));
Tensor output_indices;
OP_REQUIRES_OK(context, this->Deserialize(serialized_sparse_t(i, 0),
&output_indices));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
errors::InvalidArgument(
"Expected serialized_sparse[", i,
", 0] to represent an index matrix but received shape ",
output_indices.shape().DebugString()));
OP_REQUIRES(context, output_values.FromProto(proto_values),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 1] (values)"));

Tensor output_values;
OP_REQUIRES_OK(context, this->Deserialize(serialized_sparse_t(i, 1),
&output_values));
OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
errors::InvalidArgument(
"Expected serialized_sparse[", i,
", 1] to represent a values vector but received shape ",
output_values.shape().DebugString()));
OP_REQUIRES(context, output_shape.FromProto(proto_shape),
errors::InvalidArgument(
"Could not construct Tensor serialized_sparse[", i,
", 2] (shape)"));

Tensor output_shape;
OP_REQUIRES_OK(
context, this->Deserialize(serialized_sparse_t(i, 2), &output_shape));
OP_REQUIRES(
context, TensorShapeUtils::IsVector(output_shape.shape()),
errors::InvalidArgument("Expected serialized_sparse[", i,
@@ -400,11 +469,27 @@ class DeserializeSparseOp : public OpKernel {
}
};

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
DeserializeSparseOp<type>)
template <>
Status DeserializeSparseOpBase<string>::Deserialize(const string& serialized,
Tensor* result) {
TensorProto proto;
if (!ParseProtoUnlimited(&proto, serialized)) {
return errors::InvalidArgument("Could not parse serialized proto");
}
Tensor tensor;
if (!tensor.FromProto(proto)) {
return errors::InvalidArgument("Could not construct tensor from proto");
}
*result = tensor;
return Status::OK();
}

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype") \
.TypeConstraint<string>("Tserialized"), \
DeserializeSparseOp<type, string>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

@@ -413,7 +498,24 @@ TF_CALL_ALL_TYPES(REGISTER_KERNELS);
REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
DeserializeSparseOp<type>)
DeserializeSparseOp<type, string>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

template <>
Status DeserializeSparseOpBase<Variant>::Deserialize(const Variant& serialized,
Tensor* result) {
*result = *serialized.get<Tensor>();
return Status::OK();
}

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype") \
.TypeConstraint<Variant>("Tserialized"), \
DeserializeSparseOp<type, Variant>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
@@ -190,7 +190,8 @@ REGISTER_OP("SerializeSparse")
.Input("sparse_values: T")
.Input("sparse_shape: int64")
.Attr("T: type")
.Output("serialized_sparse: string")
.Output("serialized_sparse: out_type")
.Attr("out_type: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));

@@ -200,11 +201,13 @@ REGISTER_OP("SerializeSparse")
return Status::OK();
})
.Doc(R"doc(
Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object.
Serialize a `SparseTensor` into a `[3]` `Tensor` object.

sparse_indices: 2-D. The `indices` of the `SparseTensor`.
sparse_values: 1-D. The `values` of the `SparseTensor`.
sparse_shape: 1-D. The `shape` of the `SparseTensor`.
out_type: The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
)doc");

REGISTER_OP("SerializeManySparse")

@@ -212,7 +215,8 @@ REGISTER_OP("SerializeManySparse")
.Input("sparse_values: T")
.Input("sparse_shape: int64")
.Attr("T: type")
.Output("serialized_sparse: string")
.Output("serialized_sparse: out_type")
.Attr("out_type: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));

@@ -222,7 +226,7 @@ REGISTER_OP("SerializeManySparse")
return Status::OK();
})
.Doc(R"doc(
Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`.
Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object.

The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor`

@@ -235,14 +239,17 @@ The minibatch size `N` is extracted from `sparse_shape[0]`.
sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`.
sparse_values: 1-D. The `values` of the minibatch `SparseTensor`.
sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`.
out_type: The `dtype` to use for serialization; the supported types are `string`
(default) and `variant`.
)doc");
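A hedged sketch of the minibatch path documented above (not part of the commit; it assumes the TF 1.x session API): SerializeManySparse emits an `[N, 3]` tensor whose rows are the per-minibatch serialized elements, and DeserializeManySparse reassembles them.

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import sparse_ops

# A rank-3 SparseTensor; the first dimension (N == 2) is the minibatch dimension.
batched = tf.SparseTensor(indices=[[0, 0, 0], [1, 1, 2]],
                          values=np.array([1.0, 2.0], dtype=np.float32),
                          dense_shape=[2, 3, 4])

serialized = sparse_ops.serialize_many_sparse(batched, out_type=dtypes.string)  # shape [2, 3]
roundtrip = sparse_ops.deserialize_many_sparse(serialized, dtype=dtypes.float32)

with tf.Session() as sess:
    result = sess.run(roundtrip)
    # result.dense_shape -> [2, 3, 4]; all minibatch entries must share the same dense shape.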

REGISTER_OP("DeserializeSparse")
.Input("serialized_sparse: string")
.Attr("dtype: type")
.Input("serialized_sparse: Tserialized")
.Output("sparse_indices: int64")
.Output("sparse_values: dtype")
.Output("sparse_shape: int64")
.Attr("dtype: type")
.Attr("Tserialized: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?, ..., ?, 3] vector.
DimensionHandle unused;

@@ -305,10 +312,10 @@ dtype: The `dtype` of the serialized `SparseTensor` objects.

REGISTER_OP("DeserializeManySparse")
.Input("serialized_sparse: string")
.Attr("dtype: type")
.Output("sparse_indices: int64")
.Output("sparse_values: dtype")
.Output("sparse_shape: int64")
.Attr("dtype: type")
.SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?,3] matrix.
ShapeHandle serialized_sparse;
@@ -64,12 +64,14 @@ class SerializeSparseTest(test.TestCase):
shape = np.array([3, 4, 5]).astype(np.int64)
return sparse_tensor_lib.SparseTensorValue(ind, val, shape)

def testSerializeDeserialize(self):
def _testSerializeDeserializeHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input)
sp_deserialized = sparse_ops.deserialize_sparse(
serialized, dtype=dtypes.int32)
serialized = serialize_fn(sp_input, out_type=out_type)
sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)

indices, values, shape = sess.run(sp_deserialized)
@@ -77,14 +79,25 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(values, sp_input[1])
self.assertAllEqual(shape, sp_input[2])

def testSerializeDeserializeBatch(self):
def testSerializeDeserialize(self):
self._testSerializeDeserializeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)

def testVariantSerializeDeserialize(self):
self._testSerializeDeserializeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)

def _testSerializeDeserializeBatchHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input)
serialized = serialize_fn(sp_input, out_type=out_type)
serialized = array_ops.stack([serialized, serialized])

sp_deserialized = sparse_ops.deserialize_sparse(
serialized, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)

combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized)
@@ -97,16 +110,29 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], sp_input[1])
self.assertAllEqual(combined_shape, [2, 5, 6])

def testSerializeDeserializeBatchInconsistentShape(self):
def testSerializeDeserializeBatch(self):
self._testSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)

def testSerializeDeserializeManyBatch(self):
self._testSerializeDeserializeBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)

def testVariantSerializeDeserializeBatch(self):
self._testSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)

def _testSerializeDeserializeBatchInconsistentShapeHelper(
self, serialize_fn, deserialize_fn, out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized = array_ops.stack([serialized0, serialized1])

sp_deserialized = sparse_ops.deserialize_sparse(
serialized, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)

combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized)
@@ -119,15 +145,26 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], sp_input1[1])
self.assertAllEqual(combined_shape, [2, 5, 6])

def testSerializeDeserializeNestedBatch(self):
def testSerializeDeserializeBatchInconsistentShape(self):
self._testSerializeDeserializeBatchInconsistentShapeHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)

def testVariantSerializeDeserializeBatchInconsistentShape(self):
self._testSerializeDeserializeBatchInconsistentShapeHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)

def _testSerializeDeserializeNestedBatchHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorValue_5x6(np.arange(6))
serialized = sparse_ops.serialize_sparse(sp_input)
serialized = serialize_fn(sp_input, out_type=out_type)
serialized = array_ops.stack([serialized, serialized])
serialized = array_ops.stack([serialized, serialized])

sp_deserialized = sparse_ops.deserialize_sparse(
serialized, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized, dtype=dtypes.int32)

combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized)
@@ -151,40 +188,29 @@ class SerializeSparseTest(test.TestCase):

self.assertAllEqual(combined_shape, [2, 2, 5, 6])

def testSerializeDeserializeMany(self):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized_concat = array_ops.stack([serialized0, serialized1])
def testSerializeDeserializeNestedBatch(self):
self._testSerializeDeserializeNestedBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)

sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int32)
def testVariantSerializeDeserializeNestedBatch(self):
self._testSerializeDeserializeNestedBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)

combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized)

self.assertAllEqual(combined_indices[:6, 0], [0] * 6)  # minibatch 0
self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0])
self.assertAllEqual(combined_indices[6:, 0], [1] * 6)  # minibatch 1
self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0])
self.assertAllEqual(combined_values[:6], sp_input0[1])
self.assertAllEqual(combined_values[6:], sp_input1[1])
self.assertAllEqual(combined_shape, [2, 5, 6])

def testFeedSerializeDeserializeMany(self):
def _testFeedSerializeDeserializeBatchHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1])

sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)

combined_indices, combined_values, combined_shape = sess.run(
sp_deserialized, {sp_input0: input0_val,
@@ -198,40 +224,96 @@ class SerializeSparseTest(test.TestCase):
self.assertAllEqual(combined_values[6:], input1_val[1])
self.assertAllEqual(combined_shape, [2, 5, 6])

def testSerializeManyDeserializeManyRoundTrip(self):
def testFeedSerializeDeserializeBatch(self):
self._testFeedSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)

def testFeedSerializeDeserializeManyBatch(self):
self._testFeedSerializeDeserializeBatchHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)

def testFeedVariantSerializeDeserializeBatch(self):
self._testFeedSerializeDeserializeBatchHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)

def _testSerializeManyShapeHelper(self,
serialize_many_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
# N == 4 because shape_value == [4, 5]
indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
values_value = np.array([b"a", b"b", b"c"])
shape_value = np.array([4, 5], dtype=np.int64)
sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string)
serialized = sparse_ops.serialize_many_sparse(sparse_tensor)
deserialized = sparse_ops.deserialize_many_sparse(
serialized, dtype=dtypes.string)
serialized_value, deserialized_value = sess.run(
[serialized, deserialized],
serialized = serialize_many_fn(sparse_tensor, out_type=out_type)
serialized_value = sess.run(
serialized,
feed_dict={
sparse_tensor.indices: indices_value,
sparse_tensor.values: values_value,
sparse_tensor.dense_shape: shape_value
})
self.assertEqual(serialized_value.shape, (4, 3))

def testSerializeManyShape(self):
self._testSerializeManyShapeHelper(sparse_ops.serialize_many_sparse)

def testVariantSerializeManyShape(self):
# NOTE: The following test is a no-op as it is currently not possible to
# convert the serialized variant value to a numpy value.
pass

def _testSerializeManyDeserializeBatchHelper(self,
serialize_many_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
# N == 4 because shape_value == [4, 5]
indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
values_value = np.array([b"a", b"b", b"c"])
shape_value = np.array([4, 5], dtype=np.int64)
sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string)
serialized = serialize_many_fn(sparse_tensor, out_type=out_type)
deserialized = deserialize_fn(serialized, dtype=dtypes.string)
deserialized_value = sess.run(
deserialized,
feed_dict={
sparse_tensor.indices: indices_value,
sparse_tensor.values: values_value,
sparse_tensor.dense_shape: shape_value
})
self.assertAllEqual(deserialized_value.indices, indices_value)
self.assertAllEqual(deserialized_value.values, values_value)
self.assertAllEqual(deserialized_value.dense_shape, shape_value)

def testDeserializeFailsWrongType(self):
def testSerializeManyDeserializeBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse)

def testSerializeManyDeserializeManyBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_many_sparse)

def testVariantSerializeManyDeserializeBatch(self):
self._testSerializeManyDeserializeBatchHelper(
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)

def _testDeserializeFailsWrongTypeHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_3x4(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1])

sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int64)
sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int64)

with self.assertRaisesOpError(
r"Requested SparseTensor of type int64 but "
@@ -240,18 +322,33 @@ class SerializeSparseTest(test.TestCase):
{sp_input0: input0_val,
sp_input1: input1_val})

def testDeserializeFailsInconsistentRank(self):
def testDeserializeFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)

def testDeserializeManyFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)

def testVariantDeserializeFailsWrongType(self):
self._testDeserializeFailsWrongTypeHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse,
dtypes.variant)

def _testDeserializeFailsInconsistentRankHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
sp_input1 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
input1_val = self._SparseTensorValue_1x1x1()
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized1 = sparse_ops.serialize_sparse(sp_input1)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = serialize_fn(sp_input1, out_type=out_type)
serialized_concat = array_ops.stack([serialized0, serialized1])

sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)

with self.assertRaisesOpError(
r"Inconsistent shape across SparseTensors: rank prior to "
@@ -260,21 +357,43 @@ class SerializeSparseTest(test.TestCase):
{sp_input0: input0_val,
sp_input1: input1_val})

def testDeserializeFailsInvalidProto(self):
def testDeserializeFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse)

def testDeserializeManyFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)

def testVariantDeserializeFailsInconsistentRank(self):
self._testDeserializeFailsInconsistentRankHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)

def _testDeserializeFailsInvalidProtoHelper(self,
serialize_fn,
deserialize_fn,
out_type=dtypes.string):
with self.test_session(use_gpu=False) as sess:
sp_input0 = self._SparseTensorPlaceholder()
input0_val = self._SparseTensorValue_5x6(np.arange(6))
serialized0 = sparse_ops.serialize_sparse(sp_input0)
serialized0 = serialize_fn(sp_input0, out_type=out_type)
serialized1 = ["a", "b", "c"]
serialized_concat = array_ops.stack([serialized0, serialized1])

sp_deserialized = sparse_ops.deserialize_many_sparse(
serialized_concat, dtype=dtypes.int32)
sp_deserialized = deserialize_fn(serialized_concat, dtype=dtypes.int32)

with self.assertRaisesOpError(
r"Could not parse serialized_sparse\[1, 0\]"):
with self.assertRaisesOpError(r"Could not parse serialized proto"):
sess.run(sp_deserialized, {sp_input0: input0_val})

def testDeserializeFailsInvalidProto(self):
self._testDeserializeFailsInvalidProtoHelper(sparse_ops.serialize_sparse,
sparse_ops.deserialize_sparse)

def testDeserializeManyFailsInvalidProto(self):
self._testDeserializeFailsInvalidProtoHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)


if __name__ == "__main__":
test.main()
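The batch tests above stack individually serialized `[3]` vectors and hand the result to `deserialize_sparse`, which restores the stacked dimensions as leading minibatch dimensions. A standalone version of that pattern (a sketch only, assuming the TF 1.x session API; the `variant` case additionally relies on the Pack kernel registration added in this change):

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops, sparse_ops

sp = tf.SparseTensor(indices=[[0, 0], [4, 5]],
                     values=np.array([7, 8], dtype=np.int32),
                     dense_shape=[5, 6])

serialized = sparse_ops.serialize_sparse(sp, out_type=dtypes.variant)  # shape [3], dtype variant
stacked = array_ops.stack([serialized, serialized])                    # shape [2, 3]
combined = sparse_ops.deserialize_sparse(stacked, dtype=dtypes.int32)

with tf.Session() as sess:
    combined_indices, combined_values, combined_shape = sess.run(combined)
    # combined_shape -> [2, 5, 6]; each index gains a leading minibatch coordinate.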
@@ -354,8 +354,8 @@ DestroyTemporaryVariable
AddSparseToTensorsMap
AddManySparseToTensorsMap
TakeManySparseFromTensorsMap
DeserializeSparse
DeserializeManySparse
DeserializeSparse
SerializeManySparse
SerializeSparse
SparseAdd
@@ -1385,16 +1385,17 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
empty_row_indicator)


def serialize_sparse(sp_input, name=None):
"""Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object.
def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.

Args:
sp_input: The input `SparseTensor`.
name: A name prefix for the returned tensors (optional).
out_type: The `dtype` to use for serialization.

Returns:
A string 3-vector (1D `Tensor`), with each column representing the
serialized `SparseTensor`'s indices, values, and shape (respectively).
A 3-vector (1-D `Tensor`), with each column representing the serialized
`SparseTensor`'s indices, values, and shape (respectively).

Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
@@ -1402,11 +1403,15 @@ def serialize_sparse(sp_input, name=None):
sp_input = _convert_to_sparse_tensor(sp_input)

return gen_sparse_ops._serialize_sparse(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)
sp_input.indices,
sp_input.values,
sp_input.dense_shape,
name=name,
out_type=out_type)


def serialize_many_sparse(sp_input, name=None):
"""Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`.
def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.

The `SparseTensor` must have rank `R` greater than 1, and the first dimension
is treated as the minibatch dimension. Elements of the `SparseTensor`
@@ -1419,11 +1424,12 @@ def serialize_many_sparse(sp_input, name=None):
Args:
sp_input: The input rank `R` `SparseTensor`.
name: A name prefix for the returned tensors (optional).
out_type: The `dtype` to use for serialization.

Returns:
A string matrix (2-D `Tensor`) with `N` rows and `3` columns.
Each column represents serialized `SparseTensor`'s indices, values, and
shape (respectively).
A matrix (2-D `Tensor`) with `N` rows and `3` columns. Each column
represents serialized `SparseTensor`'s indices, values, and shape
(respectively).

Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
@@ -1431,7 +1437,11 @@ def serialize_many_sparse(sp_input, name=None):
sp_input = _convert_to_sparse_tensor(sp_input)

return gen_sparse_ops._serialize_many_sparse(
sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)
sp_input.indices,
sp_input.values,
sp_input.dense_shape,
name=name,
out_type=out_type)


def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):
@@ -1710,11 +1710,11 @@ tf_module {
}
member_method {
name: "serialize_many_sparse"
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
}
member_method {
name: "serialize_sparse"
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
}
member_method {
name: "serialize_tensor"