diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 174ccde8b7a..b51fc841d1c 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3343,6 +3343,20 @@ tf_kernel_library( deps = PARSING_DEPS, ) +tf_cc_test( + name = "parse_tensor_test", + srcs = ["parse_tensor_test.cc"], + deps = [ + ":ops_testutil", + ":ops_util", + ":parse_tensor_op", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "string_to_number_op", prefix = "string_to_number_op", diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc index 79199ff5c3f..dd645262d2e 100644 --- a/tensorflow/core/kernels/parse_tensor_op.cc +++ b/tensorflow/core/kernels/parse_tensor_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/register_types.h" namespace tensorflow { @@ -65,4 +66,32 @@ class ParseTensorOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("ParseTensor").Device(DEVICE_CPU), ParseTensorOp); + +template +class SerializeTensorOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + const Tensor& tensor = context->input(0); + TensorProto proto; + if (tensor.dtype() == DT_STRING) { + tensor.AsProtoField(&proto); + } else { + tensor.AsProtoTensorContent(&proto); + } + Tensor* proto_string = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, TensorShape({}), &proto_string)); + CHECK(proto.SerializeToString(&proto_string->scalar()())); + } +}; + +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint("T"), \ + SerializeTensorOp); +TF_CALL_ALL_TYPES(REGISTER) +#undef REGISTER + } // namespace tensorflow diff --git a/tensorflow/core/kernels/parse_tensor_test.cc b/tensorflow/core/kernels/parse_tensor_test.cc new file mode 100644 index 00000000000..f6f60fee71c --- /dev/null +++ b/tensorflow/core/kernels/parse_tensor_test.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" + +namespace tensorflow { +namespace { + +class SerializeTensorOpTest : public OpsTestBase { + protected: + template + void MakeOp(const TensorShape& input_shape, + std::function functor) { + TF_ASSERT_OK( + NodeDefBuilder("myop", "SerializeTensor") + .Input(FakeInput(DataTypeToEnum::value)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInput(input_shape, functor); + } + void ParseSerializedWithNodeDef(const NodeDef& parse_node_def, + Tensor* serialized, + Tensor* parse_output) { + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + gtl::InlinedVector inputs; + inputs.push_back({nullptr, serialized}); + Status status; + std::unique_ptr op( + CreateOpKernel(DEVICE_CPU, device.get(), + cpu_allocator(), parse_node_def, + TF_GRAPH_DEF_VERSION, &status)); + TF_EXPECT_OK(status); + OpKernelContext::Params params; + params.device = device.get(); + params.inputs = &inputs; + params.frame_iter = FrameAndIter(0, 0); + params.op_kernel = op.get(); + std::vector attrs; + test::SetOutputAttrs(¶ms, &attrs); + OpKernelContext ctx(¶ms); + op->Compute(&ctx); + TF_EXPECT_OK(status); + *parse_output = *ctx.mutable_output(0); + } + template + void ParseSerializedOutput(Tensor* serialized, Tensor* parse_output) { + NodeDef parse; + TF_ASSERT_OK(NodeDefBuilder("parse", "ParseTensor") + .Input(FakeInput(DT_STRING)) + .Attr("out_type", DataTypeToEnum::value) + .Finalize(&parse)); + ParseSerializedWithNodeDef(parse, serialized, parse_output); + } +}; + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_half) { + MakeOp(TensorShape({10}), [](int x) -> Eigen::half { + return static_cast(x / 10.); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_float) { + MakeOp(TensorShape({1, 10}), [](int x) -> float { + return static_cast(x / 10.); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_double) { + MakeOp(TensorShape({5, 5}), [](int x) -> double { + return static_cast(x / 10.); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int64) { + MakeOp(TensorShape({2, 3, 4}), [](int x) -> int64 { + return static_cast(x - 10); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int32) { + MakeOp(TensorShape({4, 2}), [](int x) -> int32 { + return static_cast(x + 7); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int16) { + MakeOp(TensorShape({8}), [](int x) -> int16 { + return static_cast(x + 18); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int8) { + MakeOp(TensorShape({2}), [](int x) -> int8 { + return static_cast(x + 8); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint16) { + MakeOp(TensorShape({1, 3}), [](int x) -> uint16 { + return static_cast(x + 2); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint8) { + MakeOp(TensorShape({2, 1, 1}), [](int x) -> uint8 { + return static_cast(x + 1); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex64) { + MakeOp(TensorShape({}), [](int x) -> complex64 { + return complex64{ static_cast(x / 8.), + static_cast(x / 2.) }; + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex128) { + MakeOp(TensorShape({3}), [](int x) -> complex128 { + return complex128{ x / 3., x / 2. }; + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_bool) { + MakeOp(TensorShape({1}), [](int x) -> bool { + return static_cast(x % 2); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_string) { + MakeOp(TensorShape({10}), [](int x) -> std::string { + return std::to_string(x / 10.); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 3a28ce3767d..35c31c6cb81 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -15804,6 +15804,25 @@ op { } summary: "Transforms a serialized tensorflow.TensorProto proto into a Tensor." } +op { + name: "SerializeTensor" + input_arg { + name: "tensor" + description: "A Tensor of type `T`." + type: "T" + } + output_arg { + name: "serialized" + description: "A serialized TensorProto proto of the input tensor." + type_attr: DT_STRING + } + attr { + name: "T" + type: "type" + description: "The type of the input tensor." + } + summary: "Transforms a Tensor into a serialized TensorProto proto." +} op { name: "Placeholder" output_arg { diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 2e605fdffcf..1f7ebe91cf0 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -292,6 +292,19 @@ out_type: The type of the serialized tensor. The provided type must match the output: A Tensor of type `out_type`. )doc"); +REGISTER_OP("SerializeTensor") + .Input("tensor: T") + .Output("serialized: string") + .Attr("T: type") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Transforms a Tensor into a serialized TensorProto proto. + +tensor: A Tensor of type `T`. +T: The type of the input tensor. +serialized: A serialized TensorProto proto of the input tensor. +)doc"); + REGISTER_OP("DecodeJSONExample") .Input("json_examples: string") .Output("binary_examples: string") diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 5cd5d7ba2f3..bd879ac4238 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -37,6 +37,7 @@ See the @{$python/io_ops} guide. @@parse_example @@parse_single_example @@parse_tensor +@@serialize_tensor @@decode_json_example @@QueueBase @@FIFOQueue diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index e0e3d08e7ce..bf7c9fac8ed 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -40,6 +40,7 @@ from tensorflow.python.platform import tf_logging ops.NotDifferentiable("DecodeRaw") ops.NotDifferentiable("ParseTensor") +ops.NotDifferentiable("SerializeTensor") ops.NotDifferentiable("StringToNumber") diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 667ae5cf6e5..7ad00281a13 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1684,6 +1684,10 @@ tf_module { name: "serialize_sparse" argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "serialize_tensor" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "set_random_seed" argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"