Add support for the `bytes://` schema as a descriptor source for the proto decoding and encoding ops.

This change makes it simpler to use the ops with a custom proto. Previously, users had to either:

1) package the C++ target for the custom proto as a shared library and dynamically load that library from the Python program that uses the ops with instances of the custom proto, or

2) use the proto compiler to generate a descriptor file for the custom proto and make sure that the Python program that uses the ops with instances of the custom proto has access to that file.

The `bytes://` schema makes it possible to embed the (serialized) descriptor proto into the descriptor source string, which means the Python binaries are self-contained (decoding does not depend on the ability to read the descriptor proto from a file) and do not need to jump through dynamic loading hoops.
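
For example, with the new schema a Python program can go from a generated proto module straight to a decode, with no descriptor file on disk. A minimal sketch (`my_proto_pb2`, its message `MyMessage`, and the int64 field `my_field` are hypothetical stand-ins for the user's custom proto):

import tensorflow as tf
from google.protobuf import descriptor_pb2

import my_proto_pb2  # hypothetical module generated by protoc

# Collect the descriptor of the generated file into a FileDescriptorSet,
# the same message that `protoc --descriptor_set_out` writes. If the
# .proto file imports other files, add their descriptors as well.
file_descriptor_set = descriptor_pb2.FileDescriptorSet()
file_descriptor_set.file.add().ParseFromString(
    my_proto_pb2.DESCRIPTOR.serialized_pb)

message = my_proto_pb2.MyMessage(my_field=42)
sizes, values = tf.io.decode_proto(
    bytes=tf.constant([message.SerializeToString()]),
    message_type=message.DESCRIPTOR.full_name,
    field_names=['my_field'],
    output_types=[tf.int64],
    # Embed the serialized FileDescriptorSet directly in the source string.
    descriptor_source=b'bytes://' + file_descriptor_set.SerializeToString())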

PiperOrigin-RevId: 253327493
Author: Jiri Simsa, 2019-06-14 17:36:42 -07:00 (committed by TensorFlower Gardener)
Parent: db3ecd34a4
Commit: fb1a1c872c
6 changed files with 147 additions and 106 deletions


@@ -69,50 +69,48 @@ The `decode_proto` op extracts fields from a serialized protocol buffers
 message into tensors. The fields in `field_names` are decoded and converted
 to the corresponding `output_types` if possible.
 
-A `message_type` name must be provided to give context for the field
-names. The actual message descriptor can be looked up either in the
-linked-in descriptor pool or a filename provided by the caller using
-the `descriptor_source` attribute.
+A `message_type` name must be provided to give context for the field names.
+The actual message descriptor can be looked up either in the linked-in
+descriptor pool or a filename provided by the caller using the
+`descriptor_source` attribute.
 
-Each output tensor is a dense tensor. This means that it is padded to
-hold the largest number of repeated elements seen in the input
-minibatch. (The shape is also padded by one to prevent zero-sized
-dimensions). The actual repeat counts for each example in the
-minibatch can be found in the `sizes` output. In many cases the output
-of `decode_proto` is fed immediately into tf.squeeze if missing values
-are not a concern. When using tf.squeeze, always pass the squeeze
-dimension explicitly to avoid surprises.
+Each output tensor is a dense tensor. This means that it is padded to hold
+the largest number of repeated elements seen in the input minibatch. (The
+shape is also padded by one to prevent zero-sized dimensions). The actual
+repeat counts for each example in the minibatch can be found in the `sizes`
+output. In many cases the output of `decode_proto` is fed immediately into
+tf.squeeze if missing values are not a concern. When using tf.squeeze, always
+pass the squeeze dimension explicitly to avoid surprises.
 
-For the most part, the mapping between Proto field types and
-TensorFlow dtypes is straightforward. However, there are a few
-special cases:
+For the most part, the mapping between Proto field types and TensorFlow dtypes
+is straightforward. However, there are a few special cases:
 
 - A proto field that contains a submessage or group can only be converted
-to `DT_STRING` (the serialized submessage). This is to reduce the
-complexity of the API. The resulting string can be used as input
-to another instance of the decode_proto op.
+to `DT_STRING` (the serialized submessage). This is to reduce the complexity
+of the API. The resulting string can be used as input to another instance of
+the decode_proto op.
 
 - TensorFlow lacks support for unsigned integers. The ops represent uint64
-types as a `DT_INT64` with the same twos-complement bit pattern
-(the obvious way). Unsigned int32 values can be represented exactly by
-specifying type `DT_INT64`, or using twos-complement if the caller
-specifies `DT_INT32` in the `output_types` attribute.
-
-The `descriptor_source` attribute selects a source of protocol
-descriptors to consult when looking up `message_type`. This may be a
-filename containing a serialized `FileDescriptorSet` message,
-or the special value `local://`, in which case only descriptors linked
-into the code will be searched; the filename can be on any filesystem
-accessible to TensorFlow.
-
-You can build a `descriptor_source` file using the `--descriptor_set_out`
-and `--include_imports` options to the protocol compiler `protoc`.
-
-The `local://` database only covers descriptors linked into the
-code via C++ libraries, not Python imports. You can link in a proto descriptor
-by creating a cc_library target with alwayslink=1.
+types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
+way). Unsigned int32 values can be represented exactly by specifying type
+`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
+the `output_types` attribute.
 
 Both binary and text proto serializations are supported, and can be
 chosen using the `format` attribute.
 
+The `descriptor_source` attribute selects the source of protocol
+descriptors to consult when looking up `message_type`. This may be:
+
+- An empty string or "local://", in which case protocol descriptors are
+created for C++ (not Python) proto definitions linked to the binary.
+
+- A file, in which case protocol descriptors are created from the file,
+which is expected to contain a `FileDescriptorSet` serialized as a string.
+NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
+and `--include_imports` options to the protocol compiler `protoc`.
+
+- A "bytes://<bytes>", in which case protocol descriptors are created from `<bytes>`,
+which is expected to be a `FileDescriptorSet` serialized as a string.
END
}
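
The dense-padding contract above is easiest to see with concrete shapes. A sketch, assuming a message `Sample { repeated int64 values = 1; }` and a descriptor file produced with `protoc --include_imports --descriptor_set_out=sample_descriptor.pb sample.proto` (all names here are hypothetical):

import tensorflow as tf

# serialized_batch is assumed: a string tensor of shape [2] holding two
# serialized Sample messages, with 3 and 1 repeated values respectively.
sizes, [vals] = tf.io.decode_proto(
    bytes=serialized_batch,
    message_type='Sample',
    field_names=['values'],
    output_types=[tf.int64],
    descriptor_source='sample_descriptor.pb')
# sizes: shape [2, 1], value [[3], [1]] -- per-example repeat counts.
# vals:  shape [2, 3] -- the second row is padded to the longest count.
# If a field is known to repeat exactly once, strip the repeat dimension
# with an explicit axis, e.g. tf.squeeze(vals, axis=-1), never implicitly.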


@@ -41,42 +41,45 @@ END
 The op serializes protobuf messages provided in the input tensors.
 END
 description: <<END
-The types of the tensors in `values` must match the schema for the
-fields specified in `field_names`. All the tensors in `values` must
-have a common shape prefix, *batch_shape*.
+The types of the tensors in `values` must match the schema for the fields
+specified in `field_names`. All the tensors in `values` must have a common
+shape prefix, *batch_shape*.
 
-The `sizes` tensor specifies repeat counts for each field. The repeat
-count (last dimension) of a each tensor in `values` must be greater
-than or equal to corresponding repeat count in `sizes`.
+The `sizes` tensor specifies repeat counts for each field. The repeat count
+(last dimension) of each tensor in `values` must be greater than or equal
+to the corresponding repeat count in `sizes`.
 
-A `message_type` name must be provided to give context for the field
-names. The actual message descriptor can be looked up either in the
-linked-in descriptor pool or a filename provided by the caller using
-the `descriptor_source` attribute.
+A `message_type` name must be provided to give context for the field names.
+The actual message descriptor can be looked up either in the linked-in
+descriptor pool or a filename provided by the caller using the
+`descriptor_source` attribute.
 
-The `descriptor_source` attribute selects a source of protocol
-descriptors to consult when looking up `message_type`. This may be a
-filename containing a serialized `FileDescriptorSet` message,
-or the special value `local://`, in which case only descriptors linked
-into the code will be searched; the filename can be on any filesystem
-accessible to TensorFlow.
+For the most part, the mapping between Proto field types and TensorFlow dtypes
+is straightforward. However, there are a few special cases:
 
-You can build a `descriptor_source` file using the `--descriptor_set_out`
+- A proto field that contains a submessage or group can only be converted
+to `DT_STRING` (the serialized submessage). This is to reduce the complexity
+of the API. The resulting string can be used as input to another instance of
+the decode_proto op.
+
+- TensorFlow lacks support for unsigned integers. The ops represent uint64
+types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
+way). Unsigned int32 values can be represented exactly by specifying type
+`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
+the `output_types` attribute.
+
+The `descriptor_source` attribute selects the source of protocol
+descriptors to consult when looking up `message_type`. This may be:
+
+- An empty string or "local://", in which case protocol descriptors are
+created for C++ (not Python) proto definitions linked to the binary.
+
+- A file, in which case protocol descriptors are created from the file,
+which is expected to contain a `FileDescriptorSet` serialized as a string.
+NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
 and `--include_imports` options to the protocol compiler `protoc`.
 
-The `local://` database only covers descriptors linked into the
-code via C++ libraries, not Python imports. You can link in a proto descriptor
-by creating a cc_library target with alwayslink=1.
-
-There are a few special cases in the value mapping:
-
-Submessage and group fields must be pre-serialized as TensorFlow strings.
-
-TensorFlow lacks support for unsigned int64s, so they must be
-represented as `tf.int64` with the same twos-complement bit pattern
-(the obvious way).
-
-Unsigned int32 values can be represented exactly with `tf.int64`, or
-with sign wrapping if the input is of type `tf.int32`.
+- A "bytes://<bytes>", in which case protocol descriptors are created from `<bytes>`,
+which is expected to be a `FileDescriptorSet` serialized as a string.
END
}
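
The encode direction mirrors this layout. A sketch under the same hypothetical `Sample` message and descriptor file as in the decode example above:

import tensorflow as tf

# Pad `values` to a common repeat count and record the real counts in
# `sizes`, whose shape is [batch_size, len(field_names)].
values = tf.constant([[1, 2, 3], [4, 0, 0]], dtype=tf.int64)
sizes = tf.constant([[3], [1]], dtype=tf.int32)
serialized = tf.io.encode_proto(
    sizes=sizes,
    values=[values],
    field_names=['values'],
    message_type='Sample',
    descriptor_source='sample_descriptor.pb')
# serialized: string tensor of shape [2]; feeding it back through
# tf.io.decode_proto with the same descriptor_source round-trips the data.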


@@ -23,6 +23,7 @@ cc_library(
         ":local_descriptor_pool_registration",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/strings",
     ],
 )


@@ -13,15 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/core/util/proto/descriptors.h"
+
+#include "absl/strings/match.h"
+#include "absl/strings/strip.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/reader_op_kernel.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/util/proto/descriptor_pool_registry.h"
-#include "tensorflow/core/util/proto/descriptors.h"
 
 namespace tensorflow {
 namespace {
 
+Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set,
+                         std::unique_ptr<protobuf::DescriptorPool>* out_pool) {
+  *out_pool = absl::make_unique<protobuf::DescriptorPool>();
+  for (const auto& file : set.file()) {
+    if ((*out_pool)->BuildFile(file) == nullptr) {
+      return errors::InvalidArgument("Failed to load FileDescriptorProto: ",
+                                     file.DebugString());
+    }
+  }
+  return Status::OK();
+}
+
 // Build a `DescriptorPool` from the named file or URI. The file or URI
 // must be available to the current TensorFlow environment.
 //
@@ -29,15 +45,14 @@
 // `GetDescriptorPool()` for more information.
 Status GetDescriptorPoolFromFile(
     tensorflow::Env* env, const string& filename,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
   Status st = env->FileExists(filename);
   if (!st.ok()) {
     return st;
   }
 
   // Read and parse the FileDescriptorSet.
-  tensorflow::protobuf::FileDescriptorSet descs;
-  std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf;
+  protobuf::FileDescriptorSet descs;
+  std::unique_ptr<ReadOnlyMemoryRegion> buf;
   st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
   if (!st.ok()) {
     return st;
@@ -46,25 +61,31 @@ Status GetDescriptorPoolFromFile(
     return errors::InvalidArgument(
         "descriptor_source contains invalid FileDescriptorSet: ", filename);
   }
+  return CreatePoolFromSet(descs, owned_desc_pool);
+}
 
-  // Build a DescriptorPool from the FileDescriptorSet.
-  owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool());
-  for (const auto& filedesc : descs.file()) {
-    if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) {
-      return errors::InvalidArgument(
-          "Problem loading FileDescriptorProto (missing dependencies?): ",
-          filename);
-    }
+Status GetDescriptorPoolFromBinary(
+    const string& source,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
+  if (!absl::StartsWith(source, "bytes://")) {
+    return errors::InvalidArgument(
+        "Source does not represent serialized file descriptor set proto.");
   }
-  return Status::OK();
+  // Parse the FileDescriptorSet.
+  protobuf::FileDescriptorSet proto;
+  if (!proto.ParseFromString(string(absl::StripPrefix(source, "bytes://")))) {
+    return errors::InvalidArgument(
+        "Source does not represent serialized file descriptor set proto.");
+  }
+  return CreatePoolFromSet(proto, owned_desc_pool);
 }
 
 }  // namespace
 
 Status GetDescriptorPool(
-    tensorflow::Env* env, string const& descriptor_source,
-    tensorflow::protobuf::DescriptorPool const** desc_pool,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+    Env* env, string const& descriptor_source,
+    protobuf::DescriptorPool const** desc_pool,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
   // Attempt to lookup the pool in the registry.
   auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
   if (pool_fn != nullptr) {
@@ -77,7 +98,10 @@ Status GetDescriptorPool(
       GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
   if (status.ok()) {
     *desc_pool = owned_desc_pool->get();
     return Status::OK();
   }
+
+  status = GetDescriptorPoolFromBinary(descriptor_source, owned_desc_pool);
+  *desc_pool = owned_desc_pool->get();
   return status;
 }
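
The resolution order implemented above (registry handler, then file, then "bytes://") can be paraphrased in Python. A rough sketch with the registry step omitted, assuming `descriptor_source` is `bytes`:

import os

from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool

def get_descriptor_pool(descriptor_source):
  # Mirrors GetDescriptorPool(): try a file first, then fall back to the
  # "bytes://" form; either way the payload is a FileDescriptorSet.
  if os.path.exists(descriptor_source):
    with open(descriptor_source, 'rb') as f:
      data = f.read()
  elif descriptor_source.startswith(b'bytes://'):
    data = descriptor_source[len(b'bytes://'):]
  else:
    raise ValueError('Unsupported descriptor_source.')
  file_descriptor_set = descriptor_pb2.FileDescriptorSet.FromString(data)
  pool = descriptor_pool.DescriptorPool()
  for file_proto in file_descriptor_set.file:  # like CreatePoolFromSet()
    pool.Add(file_proto)
  return pool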


@@ -25,17 +25,27 @@ namespace tensorflow {
 class Env;
 class Status;
 
-// Get a `DescriptorPool` object from the named `descriptor_source`.
-// `descriptor_source` may be a path to a file accessible to TensorFlow, in
-// which case it is parsed as a `FileDescriptorSet` and used to build the
-// `DescriptorPool`.
+// Gets a `DescriptorPool` object from the `descriptor_source`. This may be:
 //
-// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if
-// the caller should take ownership.
-extern tensorflow::Status GetDescriptorPool(
-    tensorflow::Env* env, string const& descriptor_source,
-    tensorflow::protobuf::DescriptorPool const** desc_pool,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool);
+// 1) An empty string or "local://", in which case the local descriptor pool
+// created for proto definitions linked to the binary is returned.
+//
+// 2) A file path, in which case the descriptor pool is created from the
+// contents of the file, which is expected to contain a `FileDescriptorSet`
+// serialized as a string. The descriptor pool ownership is transferred to the
+// caller via `owned_desc_pool`.
+//
+// 3) A "bytes://<bytes>", in which case the descriptor pool is created from
+// `<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a
+// string. The descriptor pool ownership is transferred to the caller via
+// `owned_desc_pool`.
+//
+// Custom schemas can be supported by registering a handler with the
+// `DescriptorPoolRegistry`.
+Status GetDescriptorPool(
+    Env* env, string const& descriptor_source,
+    protobuf::DescriptorPool const** desc_pool,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool);
 
 }  // namespace tensorflow


@@ -51,18 +51,16 @@ class DescriptorSourceTestBase(test.TestCase):
   #
   # The generated descriptor should capture the subset of `test_example.proto`
   # used in `test_base.simple_test_case()`.
-  def _createDescriptorFile(self):
-    set_proto = FileDescriptorSet()
+  def _createDescriptorProto(self):
+    proto = FileDescriptorSet()
 
-    file_proto = set_proto.file.add(
-        name='types.proto',
-        package='tensorflow',
-        syntax='proto3')
+    file_proto = proto.file.add(
+        name='types.proto', package='tensorflow', syntax='proto3')
     enum_proto = file_proto.enum_type.add(name='DataType')
     enum_proto.value.add(name='DT_DOUBLE', number=0)
     enum_proto.value.add(name='DT_BOOL', number=1)
 
-    file_proto = set_proto.file.add(
+    file_proto = proto.file.add(
         name='test_example.proto',
         package='tensorflow.contrib.proto',
         dependency=['types.proto'])
@@ -123,9 +121,12 @@ class DescriptorSourceTestBase(test.TestCase):
         type_name='.tensorflow.contrib.proto.TestValue',
         label=FieldDescriptorProto.LABEL_OPTIONAL)
+    return proto
+
+  def _writeProtoToFile(self, proto):
     fn = os.path.join(self.get_temp_dir(), 'descriptor.pb')
     with open(fn, 'wb') as f:
-      f.write(set_proto.SerializeToString())
+      f.write(proto.SerializeToString())
     return fn
 
   def _testRoundtrip(self, descriptor_source):
@@ -169,8 +170,12 @@
   def testWithFileDescriptorSet(self):
     # First try parsing with a local proto db, which should fail.
     with self.assertRaisesOpError('No descriptor found for message type'):
-      self._testRoundtrip('local://')
+      self._testRoundtrip(b'local://')
 
     # Now try parsing with a FileDescriptorSet which contains the test proto.
-    descriptor_file = self._createDescriptorFile()
-    self._testRoundtrip(descriptor_file)
+    proto = self._createDescriptorProto()
+    proto_file = self._writeProtoToFile(proto)
+    self._testRoundtrip(proto_file)
+
+    # Finally, try parsing the descriptor as a serialized string.
+    self._testRoundtrip(b'bytes://' + proto.SerializeToString())