Add SerializeTensor operator (#11992)
* Add SerializeTensor operator * Remove empty constructor for SerializeTensorOp * Add test for SerializeTensorOp * Explicitly test each dtype for SerializeTensorOp * Fix BUILD file format * update goldens
This commit is contained in:
parent
c8981445c5
commit
2306697c8a
@ -3343,6 +3343,20 @@ tf_kernel_library(
|
|||||||
deps = PARSING_DEPS,
|
deps = PARSING_DEPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "parse_tensor_test",
|
||||||
|
srcs = ["parse_tensor_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":ops_testutil",
|
||||||
|
":ops_util",
|
||||||
|
":parse_tensor_op",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "string_to_number_op",
|
name = "string_to_number_op",
|
||||||
prefix = "string_to_number_op",
|
prefix = "string_to_number_op",
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -65,4 +66,32 @@ class ParseTensorOp : public OpKernel {
|
|||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("ParseTensor").Device(DEVICE_CPU), ParseTensorOp);
|
REGISTER_KERNEL_BUILDER(Name("ParseTensor").Device(DEVICE_CPU), ParseTensorOp);
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class SerializeTensorOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
using OpKernel::OpKernel;
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
const Tensor& tensor = context->input(0);
|
||||||
|
TensorProto proto;
|
||||||
|
if (tensor.dtype() == DT_STRING) {
|
||||||
|
tensor.AsProtoField(&proto);
|
||||||
|
} else {
|
||||||
|
tensor.AsProtoTensorContent(&proto);
|
||||||
|
}
|
||||||
|
Tensor* proto_string = nullptr;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->allocate_output(0, TensorShape({}), &proto_string));
|
||||||
|
CHECK(proto.SerializeToString(&proto_string->scalar<string>()()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER(T) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||||
|
SerializeTensorOp<T>);
|
||||||
|
TF_CALL_ALL_TYPES(REGISTER)
|
||||||
|
#undef REGISTER
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
213
tensorflow/core/kernels/parse_tensor_test.cc
Normal file
213
tensorflow/core/kernels/parse_tensor_test.cc
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class SerializeTensorOpTest : public OpsTestBase {
|
||||||
|
protected:
|
||||||
|
template <typename T>
|
||||||
|
void MakeOp(const TensorShape& input_shape,
|
||||||
|
std::function<T(int)> functor) {
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
NodeDefBuilder("myop", "SerializeTensor")
|
||||||
|
.Input(FakeInput(DataTypeToEnum<T>::value))
|
||||||
|
.Finalize(node_def()));
|
||||||
|
TF_ASSERT_OK(InitOp());
|
||||||
|
AddInput<T>(input_shape, functor);
|
||||||
|
}
|
||||||
|
void ParseSerializedWithNodeDef(const NodeDef& parse_node_def,
|
||||||
|
Tensor* serialized,
|
||||||
|
Tensor* parse_output) {
|
||||||
|
std::unique_ptr<Device> device(
|
||||||
|
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
inputs.push_back({nullptr, serialized});
|
||||||
|
Status status;
|
||||||
|
std::unique_ptr<OpKernel> op(
|
||||||
|
CreateOpKernel(DEVICE_CPU, device.get(),
|
||||||
|
cpu_allocator(), parse_node_def,
|
||||||
|
TF_GRAPH_DEF_VERSION, &status));
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
|
OpKernelContext::Params params;
|
||||||
|
params.device = device.get();
|
||||||
|
params.inputs = &inputs;
|
||||||
|
params.frame_iter = FrameAndIter(0, 0);
|
||||||
|
params.op_kernel = op.get();
|
||||||
|
std::vector<AllocatorAttributes> attrs;
|
||||||
|
test::SetOutputAttrs(¶ms, &attrs);
|
||||||
|
OpKernelContext ctx(¶ms);
|
||||||
|
op->Compute(&ctx);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
|
*parse_output = *ctx.mutable_output(0);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
void ParseSerializedOutput(Tensor* serialized, Tensor* parse_output) {
|
||||||
|
NodeDef parse;
|
||||||
|
TF_ASSERT_OK(NodeDefBuilder("parse", "ParseTensor")
|
||||||
|
.Input(FakeInput(DT_STRING))
|
||||||
|
.Attr("out_type", DataTypeToEnum<T>::value)
|
||||||
|
.Finalize(&parse));
|
||||||
|
ParseSerializedWithNodeDef(parse, serialized, parse_output);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_half) {
|
||||||
|
MakeOp<Eigen::half>(TensorShape({10}), [](int x) -> Eigen::half {
|
||||||
|
return static_cast<Eigen::half>(x / 10.);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<Eigen::half>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<Eigen::half>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_float) {
|
||||||
|
MakeOp<float>(TensorShape({1, 10}), [](int x) -> float {
|
||||||
|
return static_cast<float>(x / 10.);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<float>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<float>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_double) {
|
||||||
|
MakeOp<double>(TensorShape({5, 5}), [](int x) -> double {
|
||||||
|
return static_cast<double>(x / 10.);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<double>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<double>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int64) {
|
||||||
|
MakeOp<int64>(TensorShape({2, 3, 4}), [](int x) -> int64 {
|
||||||
|
return static_cast<int64>(x - 10);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<int64>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<int64>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int32) {
|
||||||
|
MakeOp<int32>(TensorShape({4, 2}), [](int x) -> int32 {
|
||||||
|
return static_cast<int32>(x + 7);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<int32>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<int32>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int16) {
|
||||||
|
MakeOp<int16>(TensorShape({8}), [](int x) -> int16 {
|
||||||
|
return static_cast<int16>(x + 18);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<int16>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<int16>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int8) {
|
||||||
|
MakeOp<int8>(TensorShape({2}), [](int x) -> int8 {
|
||||||
|
return static_cast<int8>(x + 8);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<int8>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<int8>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint16) {
|
||||||
|
MakeOp<uint16>(TensorShape({1, 3}), [](int x) -> uint16 {
|
||||||
|
return static_cast<uint16>(x + 2);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<uint16>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<uint16>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint8) {
|
||||||
|
MakeOp<uint8>(TensorShape({2, 1, 1}), [](int x) -> uint8 {
|
||||||
|
return static_cast<uint8>(x + 1);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<uint8>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<uint8>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex64) {
|
||||||
|
MakeOp<complex64>(TensorShape({}), [](int x) -> complex64 {
|
||||||
|
return complex64{ static_cast<float>(x / 8.),
|
||||||
|
static_cast<float>(x / 2.) };
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<complex64>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<complex64>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex128) {
|
||||||
|
MakeOp<complex128>(TensorShape({3}), [](int x) -> complex128 {
|
||||||
|
return complex128{ x / 3., x / 2. };
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<complex128>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<complex128>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_bool) {
|
||||||
|
MakeOp<bool>(TensorShape({1}), [](int x) -> bool {
|
||||||
|
return static_cast<bool>(x % 2);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<bool>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<bool>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_string) {
|
||||||
|
MakeOp<std::string>(TensorShape({10}), [](int x) -> std::string {
|
||||||
|
return std::to_string(x / 10.);
|
||||||
|
});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor parse_output;
|
||||||
|
ParseSerializedOutput<std::string>(GetOutput(0), &parse_output);
|
||||||
|
test::ExpectTensorEqual<std::string>(parse_output, GetInput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -15804,6 +15804,25 @@ op {
|
|||||||
}
|
}
|
||||||
summary: "Transforms a serialized tensorflow.TensorProto proto into a Tensor."
|
summary: "Transforms a serialized tensorflow.TensorProto proto into a Tensor."
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "SerializeTensor"
|
||||||
|
input_arg {
|
||||||
|
name: "tensor"
|
||||||
|
description: "A Tensor of type `T`."
|
||||||
|
type: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "serialized"
|
||||||
|
description: "A serialized TensorProto proto of the input tensor."
|
||||||
|
type_attr: DT_STRING
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
description: "The type of the input tensor."
|
||||||
|
}
|
||||||
|
summary: "Transforms a Tensor into a serialized TensorProto proto."
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "Placeholder"
|
name: "Placeholder"
|
||||||
output_arg {
|
output_arg {
|
||||||
|
@ -292,6 +292,19 @@ out_type: The type of the serialized tensor. The provided type must match the
|
|||||||
output: A Tensor of type `out_type`.
|
output: A Tensor of type `out_type`.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("SerializeTensor")
|
||||||
|
.Input("tensor: T")
|
||||||
|
.Output("serialized: string")
|
||||||
|
.Attr("T: type")
|
||||||
|
.SetShapeFn(shape_inference::ScalarShape)
|
||||||
|
.Doc(R"doc(
|
||||||
|
Transforms a Tensor into a serialized TensorProto proto.
|
||||||
|
|
||||||
|
tensor: A Tensor of type `T`.
|
||||||
|
T: The type of the input tensor.
|
||||||
|
serialized: A serialized TensorProto proto of the input tensor.
|
||||||
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("DecodeJSONExample")
|
REGISTER_OP("DecodeJSONExample")
|
||||||
.Input("json_examples: string")
|
.Input("json_examples: string")
|
||||||
.Output("binary_examples: string")
|
.Output("binary_examples: string")
|
||||||
|
@ -37,6 +37,7 @@ See the @{$python/io_ops} guide.
|
|||||||
@@parse_example
|
@@parse_example
|
||||||
@@parse_single_example
|
@@parse_single_example
|
||||||
@@parse_tensor
|
@@parse_tensor
|
||||||
|
@@serialize_tensor
|
||||||
@@decode_json_example
|
@@decode_json_example
|
||||||
@@QueueBase
|
@@QueueBase
|
||||||
@@FIFOQueue
|
@@FIFOQueue
|
||||||
|
@ -40,6 +40,7 @@ from tensorflow.python.platform import tf_logging
|
|||||||
|
|
||||||
ops.NotDifferentiable("DecodeRaw")
|
ops.NotDifferentiable("DecodeRaw")
|
||||||
ops.NotDifferentiable("ParseTensor")
|
ops.NotDifferentiable("ParseTensor")
|
||||||
|
ops.NotDifferentiable("SerializeTensor")
|
||||||
ops.NotDifferentiable("StringToNumber")
|
ops.NotDifferentiable("StringToNumber")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1684,6 +1684,10 @@ tf_module {
|
|||||||
name: "serialize_sparse"
|
name: "serialize_sparse"
|
||||||
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "serialize_tensor"
|
||||||
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "set_random_seed"
|
name: "set_random_seed"
|
||||||
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
Reference in New Issue
Block a user