Adding support for bytes:// schema as a descriptor source for proto decoding and encoding ops.

This change makes using the ops with a custom proto simpler. Previously, users would need to either:

1) package the C++ target for the custom proto as a shared library and dynamically load the library in the Python program that wishes to use the ops with instances of the custom proto.

2) use the proto compiler to generate a file with a descriptor for the custom proto and make sure that the Python program that wishes to use the ops with instances of the custom proto has access to the file.

The `bytes://` schema makes it possible to embed the (serialized) descriptor proto into the descriptor source string, which means the Python binaries are self-contained (decoding does not depend on the ability to read the descriptor proto from a file) and do not need to jump through dynamic loading hoops.
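As a concrete sketch of what the change enables, the descriptor source can be assembled entirely in-process and passed to the ops as a single string. The example below hand-encodes a minimal serialized `FileDescriptorSet` on the protobuf wire format using only the standard library; this is for illustration only, and real callers would normally build the set with `google.protobuf.descriptor_pb2` or `protoc --descriptor_set_out` instead.

```python
# Hand-encode a one-file FileDescriptorSet on the protobuf wire format.
# (Illustrative only; use descriptor_pb2 or protoc in real code.)

def _len_delimited(field_number: int, payload: bytes) -> bytes:
    # Wire type 2 (length-delimited): key byte is (field_number << 3) | 2,
    # followed by a varint length (payloads here are < 128 bytes, so the
    # varint is a single byte).
    assert len(payload) < 128
    return bytes([(field_number << 3) | 2, len(payload)]) + payload

# FileDescriptorProto.name is field 1 (a string).
file_proto = _len_delimited(1, b"types.proto")
# FileDescriptorSet.file is field 1 (repeated FileDescriptorProto).
descriptor_set = _len_delimited(1, file_proto)

# The self-contained descriptor source string the ops accept after this
# change: no file on disk, no dynamically loaded library.
descriptor_source = b"bytes://" + descriptor_set
print(descriptor_source)
```

Because the serialized descriptor rides along inside the attribute value itself, the string can be baked into the Python program as a constant.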

PiperOrigin-RevId: 253327493
Jiri Simsa 2019-06-14 17:36:42 -07:00 committed by TensorFlower Gardener
parent db3ecd34a4
commit fb1a1c872c
6 changed files with 147 additions and 106 deletions


@@ -69,50 +69,48 @@ The `decode_proto` op extracts fields from a serialized protocol buffers
message into tensors. The fields in `field_names` are decoded and converted
to the corresponding `output_types` if possible.

A `message_type` name must be provided to give context for the field names.
The actual message descriptor can be looked up either in the linked-in
descriptor pool or a filename provided by the caller using the
`descriptor_source` attribute.

Each output tensor is a dense tensor. This means that it is padded to hold
the largest number of repeated elements seen in the input minibatch. (The
shape is also padded by one to prevent zero-sized dimensions). The actual
repeat counts for each example in the minibatch can be found in the `sizes`
output. In many cases the output of `decode_proto` is fed immediately into
tf.squeeze if missing values are not a concern. When using tf.squeeze, always
pass the squeeze dimension explicitly to avoid surprises.

For the most part, the mapping between Proto field types and TensorFlow dtypes
is straightforward. However, there are a few special cases:

- A proto field that contains a submessage or group can only be converted
to `DT_STRING` (the serialized submessage). This is to reduce the complexity
of the API. The resulting string can be used as input to another instance of
the decode_proto op.
- TensorFlow lacks support for unsigned integers. The ops represent uint64
types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
way). Unsigned int32 values can be represented exactly by specifying type
`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
the `output_types` attribute.

Both binary and text proto serializations are supported, and can be
chosen using the `format` attribute.

The `descriptor_source` attribute selects the source of protocol
descriptors to consult when looking up `message_type`. This may be:

- An empty string or "local://", in which case protocol descriptors are
created for C++ (not Python) proto definitions linked to the binary.
- A file, in which case protocol descriptors are created from the file,
which is expected to contain a `FileDescriptorSet` serialized as a string.
NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
and `--include_imports` options to the protocol compiler `protoc`.
- A "bytes://<bytes>" string, in which case protocol descriptors are created
from `<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a
string.
END
}
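The uint64-to-`DT_INT64` mapping described above can be checked with a few lines of plain Python, independent of TensorFlow: reinterpreting the 8-byte pattern of a uint64 as a signed 64-bit value is exactly the "same twos-complement bit pattern" the op documentation refers to.

```python
import struct

def uint64_as_int64(u: int) -> int:
    # Pack the value as an unsigned 64-bit little-endian integer, then
    # unpack the identical 8 bytes as a signed 64-bit integer: the bit
    # pattern is preserved, only the interpretation changes.
    return struct.unpack("<q", struct.pack("<Q", u))[0]

print(uint64_as_int64(2**64 - 1))  # prints -1
print(uint64_as_int64(7))          # small values are unchanged: prints 7
```

Values with the top bit set appear negative in the `DT_INT64` tensor, but no information is lost; the original uint64 can be recovered by the inverse reinterpretation.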


@@ -41,42 +41,45 @@ END
The op serializes protobuf messages provided in the input tensors.
END
description: <<END
The types of the tensors in `values` must match the schema for the fields
specified in `field_names`. All the tensors in `values` must have a common
shape prefix, *batch_shape*.

The `sizes` tensor specifies repeat counts for each field. The repeat count
(last dimension) of each tensor in `values` must be greater than or equal to
the corresponding repeat count in `sizes`.

A `message_type` name must be provided to give context for the field names.
The actual message descriptor can be looked up either in the linked-in
descriptor pool or a filename provided by the caller using the
`descriptor_source` attribute.

For the most part, the mapping between Proto field types and TensorFlow dtypes
is straightforward. However, there are a few special cases:

- A proto field that contains a submessage or group can only be converted
to `DT_STRING` (the serialized submessage). This is to reduce the complexity
of the API. The resulting string can be used as input to another instance of
the decode_proto op.
- TensorFlow lacks support for unsigned integers. The ops represent uint64
types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
way). Unsigned int32 values can be represented exactly by specifying type
`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
the `output_types` attribute.

The `descriptor_source` attribute selects the source of protocol
descriptors to consult when looking up `message_type`. This may be:

- An empty string or "local://", in which case protocol descriptors are
created for C++ (not Python) proto definitions linked to the binary.
- A file, in which case protocol descriptors are created from the file,
which is expected to contain a `FileDescriptorSet` serialized as a string.
NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
and `--include_imports` options to the protocol compiler `protoc`.
- A "bytes://<bytes>" string, in which case protocol descriptors are created
from `<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a
string.
END
}
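The unsigned int32 case above has the same flavor: every uint32 value fits exactly in an int64, while mapping onto a 32-bit signed type reinterprets the bit pattern. A TensorFlow-free sketch of the two behaviors:

```python
import struct

def uint32_as_int32(u: int) -> int:
    # Twos-complement reinterpretation: the value a caller observes when a
    # uint32 field is mapped onto a 32-bit signed tensor type.
    return struct.unpack("<i", struct.pack("<I", u))[0]

u = 2**32 - 1                 # largest uint32
print(uint32_as_int32(u))     # prints -1 (bit pattern reinterpreted)
print(u)                      # 4294967295: exact when widened to int64
```

Choosing the 64-bit type is the lossless option; the 32-bit type is convenient when the values are known to stay below 2**31.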


@@ -23,6 +23,7 @@ cc_library(
        ":local_descriptor_pool_registration",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)


@@ -13,15 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/proto/descriptors.h"

#include "absl/strings/match.h"
#include "absl/strings/strip.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/reader_op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"

namespace tensorflow {
namespace {

Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set,
                         std::unique_ptr<protobuf::DescriptorPool>* out_pool) {
  *out_pool = absl::make_unique<protobuf::DescriptorPool>();
  for (const auto& file : set.file()) {
    if ((*out_pool)->BuildFile(file) == nullptr) {
      return errors::InvalidArgument("Failed to load FileDescriptorProto: ",
                                     file.DebugString());
    }
  }
  return Status::OK();
}

// Build a `DescriptorPool` from the named file or URI. The file or URI
// must be available to the current TensorFlow environment.
//
@@ -29,15 +45,14 @@
// `GetDescriptorPool()` for more information.
Status GetDescriptorPoolFromFile(
    tensorflow::Env* env, const string& filename,
    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
  Status st = env->FileExists(filename);
  if (!st.ok()) {
    return st;
  }
  // Read and parse the FileDescriptorSet.
  protobuf::FileDescriptorSet descs;
  std::unique_ptr<ReadOnlyMemoryRegion> buf;
  st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
  if (!st.ok()) {
    return st;
@@ -46,25 +61,31 @@
    return errors::InvalidArgument(
        "descriptor_source contains invalid FileDescriptorSet: ", filename);
  }
  return CreatePoolFromSet(descs, owned_desc_pool);
}

Status GetDescriptorPoolFromBinary(
    const string& source,
    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
  if (!absl::StartsWith(source, "bytes://")) {
    return errors::InvalidArgument(
        "Source does not represent serialized file descriptor set proto.");
  }
  // Parse the FileDescriptorSet.
  protobuf::FileDescriptorSet proto;
  if (!proto.ParseFromString(string(absl::StripPrefix(source, "bytes://")))) {
    return errors::InvalidArgument(
        "Source does not represent serialized file descriptor set proto.");
  }
  return CreatePoolFromSet(proto, owned_desc_pool);
}

}  // namespace

Status GetDescriptorPool(
    Env* env, string const& descriptor_source,
    protobuf::DescriptorPool const** desc_pool,
    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
  // Attempt to lookup the pool in the registry.
  auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
  if (pool_fn != nullptr) {
@@ -77,7 +98,10 @@ Status GetDescriptorPool(
      GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
  if (status.ok()) {
    *desc_pool = owned_desc_pool->get();
    return Status::OK();
  }

  status = GetDescriptorPoolFromBinary(descriptor_source, owned_desc_pool);
  *desc_pool = owned_desc_pool->get();
  return status;
}
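The resolution order implemented by `GetDescriptorPool()` above can be summarized in a small, hypothetical Python model (the function and tuple tags are invented for illustration): a handler registered for the exact source string wins, then the source is tried as a file, and only then is the `bytes://` form attempted.

```python
import os

def resolve_descriptor_source(source, registry):
    # Hypothetical model of GetDescriptorPool()'s lookup order; `registry`
    # stands in for DescriptorPoolRegistry.
    handler = registry.get(source)
    if handler is not None:            # e.g. "local://" has a registered pool
        return ("registry", handler)
    if os.path.exists(source):         # GetDescriptorPoolFromFile path
        return ("file", source)
    if source.startswith("bytes://"):  # GetDescriptorPoolFromBinary path
        # Strip the schema; the rest is a serialized FileDescriptorSet.
        return ("bytes", source[len("bytes://"):])
    raise ValueError("invalid descriptor_source: " + source)

print(resolve_descriptor_source("bytes://payload", {}))
```

One consequence of this ordering is that a file whose name happens to start with `bytes://` would shadow the new schema, which is harmless in practice since such paths do not occur on ordinary filesystems.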


@@ -25,17 +25,27 @@ namespace tensorflow {
class Env;
class Status;

// Gets a `DescriptorPool` object from the `descriptor_source`. This may be:
//
// 1) An empty string or "local://", in which case the local descriptor pool
// created for proto definitions linked to the binary is returned.
//
// 2) A file path, in which case the descriptor pool is created from the
// contents of the file, which is expected to contain a `FileDescriptorSet`
// serialized as a string. The descriptor pool ownership is transferred to the
// caller via `owned_desc_pool`.
//
// 3) A "bytes://<bytes>" string, in which case the descriptor pool is created
// from `<bytes>`, which is expected to be a `FileDescriptorSet` serialized as
// a string. The descriptor pool ownership is transferred to the caller via
// `owned_desc_pool`.
//
// Custom schemas can be supported by registering a handler with the
// `DescriptorPoolRegistry`.
Status GetDescriptorPool(
    Env* env, string const& descriptor_source,
    protobuf::DescriptorPool const** desc_pool,
    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool);

}  // namespace tensorflow


@@ -51,18 +51,16 @@ class DescriptorSourceTestBase(test.TestCase):
  #
  # The generated descriptor should capture the subset of `test_example.proto`
  # used in `test_base.simple_test_case()`.
  def _createDescriptorProto(self):
    proto = FileDescriptorSet()
    file_proto = proto.file.add(
        name='types.proto', package='tensorflow', syntax='proto3')
    enum_proto = file_proto.enum_type.add(name='DataType')
    enum_proto.value.add(name='DT_DOUBLE', number=0)
    enum_proto.value.add(name='DT_BOOL', number=1)
    file_proto = proto.file.add(
        name='test_example.proto',
        package='tensorflow.contrib.proto',
        dependency=['types.proto'])
@@ -123,9 +121,12 @@ class DescriptorSourceTestBase(test.TestCase):
        type_name='.tensorflow.contrib.proto.TestValue',
        label=FieldDescriptorProto.LABEL_OPTIONAL)
    return proto

  def _writeProtoToFile(self, proto):
    fn = os.path.join(self.get_temp_dir(), 'descriptor.pb')
    with open(fn, 'wb') as f:
      f.write(proto.SerializeToString())
    return fn

  def _testRoundtrip(self, descriptor_source):
@@ -169,8 +170,12 @@ class DescriptorSourceTestBase(test.TestCase):
  def testWithFileDescriptorSet(self):
    # First try parsing with a local proto db, which should fail.
    with self.assertRaisesOpError('No descriptor found for message type'):
      self._testRoundtrip(b'local://')

    # Now try parsing with a FileDescriptorSet which contains the test proto.
    proto = self._createDescriptorProto()
    proto_file = self._writeProtoToFile(proto)
    self._testRoundtrip(proto_file)

    # Finally, try parsing the descriptor as a serialized string.
    self._testRoundtrip(b'bytes://' + proto.SerializeToString())