Add SerializeTensor operator (#11992)

* Add SerializeTensor operator
* Remove empty constructor for SerializeTensorOp
* Add test for SerializeTensorOp
* Explicitly test each dtype for SerializeTensorOp
* Fix BUILD file format
* update goldens
This commit is contained in:
Lin Min 2017-09-05 11:28:42 -05:00 committed by Martin Wicke
parent c8981445c5
commit 2306697c8a
8 changed files with 294 additions and 0 deletions

View File

@ -3343,6 +3343,20 @@ tf_kernel_library(
deps = PARSING_DEPS,
)
tf_cc_test(
name = "parse_tensor_test",
srcs = ["parse_tensor_test.cc"],
deps = [
":ops_testutil",
":ops_util",
":parse_tensor_op",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_kernel_library(
name = "string_to_number_op",
prefix = "string_to_number_op",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
@ -65,4 +66,32 @@ class ParseTensorOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ParseTensor").Device(DEVICE_CPU), ParseTensorOp);
template <typename T>
class SerializeTensorOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* context) override {
const Tensor& tensor = context->input(0);
TensorProto proto;
if (tensor.dtype() == DT_STRING) {
tensor.AsProtoField(&proto);
} else {
tensor.AsProtoTensorContent(&proto);
}
Tensor* proto_string = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape({}), &proto_string));
CHECK(proto.SerializeToString(&proto_string->scalar<string>()()));
}
};
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER( \
Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
SerializeTensorOp<T>);
TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER
} // namespace tensorflow

View File

@ -0,0 +1,213 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include <string>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
namespace tensorflow {
namespace {
class SerializeTensorOpTest : public OpsTestBase {
protected:
template <typename T>
void MakeOp(const TensorShape& input_shape,
std::function<T(int)> functor) {
TF_ASSERT_OK(
NodeDefBuilder("myop", "SerializeTensor")
.Input(FakeInput(DataTypeToEnum<T>::value))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInput<T>(input_shape, functor);
}
void ParseSerializedWithNodeDef(const NodeDef& parse_node_def,
Tensor* serialized,
Tensor* parse_output) {
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.push_back({nullptr, serialized});
Status status;
std::unique_ptr<OpKernel> op(
CreateOpKernel(DEVICE_CPU, device.get(),
cpu_allocator(), parse_node_def,
TF_GRAPH_DEF_VERSION, &status));
TF_EXPECT_OK(status);
OpKernelContext::Params params;
params.device = device.get();
params.inputs = &inputs;
params.frame_iter = FrameAndIter(0, 0);
params.op_kernel = op.get();
std::vector<AllocatorAttributes> attrs;
test::SetOutputAttrs(&params, &attrs);
OpKernelContext ctx(&params);
op->Compute(&ctx);
TF_EXPECT_OK(status);
*parse_output = *ctx.mutable_output(0);
}
template <typename T>
void ParseSerializedOutput(Tensor* serialized, Tensor* parse_output) {
NodeDef parse;
TF_ASSERT_OK(NodeDefBuilder("parse", "ParseTensor")
.Input(FakeInput(DT_STRING))
.Attr("out_type", DataTypeToEnum<T>::value)
.Finalize(&parse));
ParseSerializedWithNodeDef(parse, serialized, parse_output);
}
};
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_half) {
MakeOp<Eigen::half>(TensorShape({10}), [](int x) -> Eigen::half {
return static_cast<Eigen::half>(x / 10.);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<Eigen::half>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<Eigen::half>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_float) {
MakeOp<float>(TensorShape({1, 10}), [](int x) -> float {
return static_cast<float>(x / 10.);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<float>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<float>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_double) {
MakeOp<double>(TensorShape({5, 5}), [](int x) -> double {
return static_cast<double>(x / 10.);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<double>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<double>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int64) {
MakeOp<int64>(TensorShape({2, 3, 4}), [](int x) -> int64 {
return static_cast<int64>(x - 10);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<int64>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<int64>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int32) {
MakeOp<int32>(TensorShape({4, 2}), [](int x) -> int32 {
return static_cast<int32>(x + 7);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<int32>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<int32>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int16) {
MakeOp<int16>(TensorShape({8}), [](int x) -> int16 {
return static_cast<int16>(x + 18);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<int16>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<int16>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int8) {
MakeOp<int8>(TensorShape({2}), [](int x) -> int8 {
return static_cast<int8>(x + 8);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<int8>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<int8>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint16) {
MakeOp<uint16>(TensorShape({1, 3}), [](int x) -> uint16 {
return static_cast<uint16>(x + 2);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<uint16>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<uint16>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint8) {
MakeOp<uint8>(TensorShape({2, 1, 1}), [](int x) -> uint8 {
return static_cast<uint8>(x + 1);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<uint8>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<uint8>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex64) {
MakeOp<complex64>(TensorShape({}), [](int x) -> complex64 {
return complex64{ static_cast<float>(x / 8.),
static_cast<float>(x / 2.) };
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<complex64>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<complex64>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex128) {
MakeOp<complex128>(TensorShape({3}), [](int x) -> complex128 {
return complex128{ x / 3., x / 2. };
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<complex128>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<complex128>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_bool) {
MakeOp<bool>(TensorShape({1}), [](int x) -> bool {
return static_cast<bool>(x % 2);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<bool>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<bool>(parse_output, GetInput(0));
}
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_string) {
MakeOp<std::string>(TensorShape({10}), [](int x) -> std::string {
return std::to_string(x / 10.);
});
TF_ASSERT_OK(RunOpKernel());
Tensor parse_output;
ParseSerializedOutput<std::string>(GetOutput(0), &parse_output);
test::ExpectTensorEqual<std::string>(parse_output, GetInput(0));
}
} // namespace
} // namespace tensorflow

View File

@ -15804,6 +15804,25 @@ op {
}
summary: "Transforms a serialized tensorflow.TensorProto proto into a Tensor."
}
op {
name: "SerializeTensor"
input_arg {
name: "tensor"
description: "A Tensor of type `T`."
type: "T"
}
output_arg {
name: "serialized"
description: "A serialized TensorProto proto of the input tensor."
type_attr: DT_STRING
}
attr {
name: "T"
type: "type"
description: "The type of the input tensor."
}
summary: "Transforms a Tensor into a serialized TensorProto proto."
}
op {
name: "Placeholder"
output_arg {

View File

@ -292,6 +292,19 @@ out_type: The type of the serialized tensor. The provided type must match the
output: A Tensor of type `out_type`.
)doc");
REGISTER_OP("SerializeTensor")
.Input("tensor: T")
.Output("serialized: string")
.Attr("T: type")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Transforms a Tensor into a serialized TensorProto proto.
tensor: A Tensor of type `T`.
T: The type of the input tensor.
serialized: A serialized TensorProto proto of the input tensor.
)doc");
REGISTER_OP("DecodeJSONExample")
.Input("json_examples: string")
.Output("binary_examples: string")

View File

@ -37,6 +37,7 @@ See the @{$python/io_ops} guide.
@@parse_example
@@parse_single_example
@@parse_tensor
@@serialize_tensor
@@decode_json_example
@@QueueBase
@@FIFOQueue

View File

@ -40,6 +40,7 @@ from tensorflow.python.platform import tf_logging
ops.NotDifferentiable("DecodeRaw")
ops.NotDifferentiable("ParseTensor")
ops.NotDifferentiable("SerializeTensor")
ops.NotDifferentiable("StringToNumber")

View File

@ -1684,6 +1684,10 @@ tf_module {
name: "serialize_sparse"
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "serialize_tensor"
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "set_random_seed"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"