From fb1a1c872cc101b8eef3755a610380623e9cee8e Mon Sep 17 00:00:00 2001 From: Jiri Simsa <jsimsa@google.com> Date: Fri, 14 Jun 2019 17:36:42 -0700 Subject: [PATCH] Adding support for `bytes://` schema as a descriptor source for proto decoding and encoding ops. This change makes using the ops with a custom proto simpler. Previously, users would need to either: 1) package the C++ target for the custom proto as a shared library and dynamically load the library in the Python program that wishes to use the ops with instances of the custom proto. 2) use the proto compiler to generate a file with a descriptor for the custom proto and make sure that the Python program that wishes to use the ops with instances of the custom proto has access to the file. The `bytes://` schema makes it possible to embed the (serialized) descriptor proto into the descriptor source string, which means the Python binaries are self-contained (decoding does not depend on the ability to read the descriptor proto from a file) and do not need to jump through dynamic loading hoops. PiperOrigin-RevId: 253327493 --- .../base_api/api_def_DecodeProtoV2.pbtxt | 70 +++++++++---------- .../base_api/api_def_EncodeProto.pbtxt | 65 +++++++++-------- tensorflow/core/util/proto/BUILD | 1 + tensorflow/core/util/proto/descriptors.cc | 60 +++++++++++----- tensorflow/core/util/proto/descriptors.h | 30 +++++--- .../proto/descriptor_source_test_base.py | 27 ++++--- 6 files changed, 147 insertions(+), 106 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt index c9e1fc58aad..9adb1a4056c 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt @@ -69,50 +69,48 @@ The `decode_proto` op extracts fields from a serialized protocol buffers message into tensors. 
The fields in `field_names` are decoded and converted to the corresponding `output_types` if possible. -A `message_type` name must be provided to give context for the field -names. The actual message descriptor can be looked up either in the -linked-in descriptor pool or a filename provided by the caller using -the `descriptor_source` attribute. +A `message_type` name must be provided to give context for the field names. +The actual message descriptor can be looked up either in the linked-in +descriptor pool or a filename provided by the caller using the +`descriptor_source` attribute. -Each output tensor is a dense tensor. This means that it is padded to -hold the largest number of repeated elements seen in the input -minibatch. (The shape is also padded by one to prevent zero-sized -dimensions). The actual repeat counts for each example in the -minibatch can be found in the `sizes` output. In many cases the output -of `decode_proto` is fed immediately into tf.squeeze if missing values -are not a concern. When using tf.squeeze, always pass the squeeze -dimension explicitly to avoid surprises. +Each output tensor is a dense tensor. This means that it is padded to hold +the largest number of repeated elements seen in the input minibatch. (The +shape is also padded by one to prevent zero-sized dimensions). The actual +repeat counts for each example in the minibatch can be found in the `sizes` +output. In many cases the output of `decode_proto` is fed immediately into +tf.squeeze if missing values are not a concern. When using tf.squeeze, always +pass the squeeze dimension explicitly to avoid surprises. -For the most part, the mapping between Proto field types and -TensorFlow dtypes is straightforward. However, there are a few -special cases: +For the most part, the mapping between Proto field types and TensorFlow dtypes +is straightforward. 
However, there are a few special cases: - A proto field that contains a submessage or group can only be converted -to `DT_STRING` (the serialized submessage). This is to reduce the -complexity of the API. The resulting string can be used as input -to another instance of the decode_proto op. +to `DT_STRING` (the serialized submessage). This is to reduce the complexity +of the API. The resulting string can be used as input to another instance of +the decode_proto op. - TensorFlow lacks support for unsigned integers. The ops represent uint64 -types as a `DT_INT64` with the same twos-complement bit pattern -(the obvious way). Unsigned int32 values can be represented exactly by -specifying type `DT_INT64`, or using twos-complement if the caller -specifies `DT_INT32` in the `output_types` attribute. - -The `descriptor_source` attribute selects a source of protocol -descriptors to consult when looking up `message_type`. This may be a -filename containing a serialized `FileDescriptorSet` message, -or the special value `local://`, in which case only descriptors linked -into the code will be searched; the filename can be on any filesystem -accessible to TensorFlow. - -You can build a `descriptor_source` file using the `--descriptor_set_out` -and `--include_imports` options to the protocol compiler `protoc`. - -The `local://` database only covers descriptors linked into the -code via C++ libraries, not Python imports. You can link in a proto descriptor -by creating a cc_library target with alwayslink=1. +types as a `DT_INT64` with the same twos-complement bit pattern (the obvious +way). Unsigned int32 values can be represented exactly by specifying type +`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in +the `output_types` attribute. Both binary and text proto serializations are supported, and can be chosen using the `format` attribute. 
+ +The `descriptor_source` attribute selects the source of protocol +descriptors to consult when looking up `message_type`. This may be: + +- An empty string or "local://", in which case protocol descriptors are +created for C++ (not Python) proto definitions linked to the binary. + +- A file, in which case protocol descriptors are created from the file, +which is expected to contain a `FileDescriptorSet` serialized as a string. +NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out` +and `--include_imports` options to the protocol compiler `protoc`. + +- A "bytes://<bytes>", in which case protocol descriptors are created from `<bytes>`, +which is expected to be a `FileDescriptorSet` serialized as a string. END } diff --git a/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt index 49b7b21b36e..b15abd2fce2 100644 --- a/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt @@ -41,42 +41,45 @@ END The op serializes protobuf messages provided in the input tensors. END description: <<END -The types of the tensors in `values` must match the schema for the -fields specified in `field_names`. All the tensors in `values` must -have a common shape prefix, *batch_shape*. +The types of the tensors in `values` must match the schema for the fields +specified in `field_names`. All the tensors in `values` must have a common +shape prefix, *batch_shape*. -The `sizes` tensor specifies repeat counts for each field. The repeat -count (last dimension) of a each tensor in `values` must be greater -than or equal to corresponding repeat count in `sizes`. +The `sizes` tensor specifies repeat counts for each field. The repeat count +(last dimension) of each tensor in `values` must be greater than or equal +to the corresponding repeat count in `sizes`. -A `message_type` name must be provided to give context for the field -names.
The actual message descriptor can be looked up either in the -linked-in descriptor pool or a filename provided by the caller using -the `descriptor_source` attribute. +A `message_type` name must be provided to give context for the field names. +The actual message descriptor can be looked up either in the linked-in +descriptor pool or a filename provided by the caller using the +`descriptor_source` attribute. -The `descriptor_source` attribute selects a source of protocol -descriptors to consult when looking up `message_type`. This may be a -filename containing a serialized `FileDescriptorSet` message, -or the special value `local://`, in which case only descriptors linked -into the code will be searched; the filename can be on any filesystem -accessible to TensorFlow. +For the most part, the mapping between Proto field types and TensorFlow dtypes +is straightforward. However, there are a few special cases: -You can build a `descriptor_source` file using the `--descriptor_set_out` +- A proto field that contains a submessage or group can only be converted +to `DT_STRING` (the serialized submessage). This is to reduce the complexity +of the API. The resulting string can be used as input to another instance of +the decode_proto op. + +- TensorFlow lacks support for unsigned integers. The ops represent uint64 +types as a `DT_INT64` with the same twos-complement bit pattern (the obvious +way). Unsigned int32 values can be represented exactly by specifying type +`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in +the `output_types` attribute. + +The `descriptor_source` attribute selects the source of protocol +descriptors to consult when looking up `message_type`. This may be: + +- An empty string or "local://", in which case protocol descriptors are +created for C++ (not Python) proto definitions linked to the binary. 
+ +- A file, in which case protocol descriptors are created from the file, +which is expected to contain a `FileDescriptorSet` serialized as a string. +NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out` and `--include_imports` options to the protocol compiler `protoc`. -The `local://` database only covers descriptors linked into the -code via C++ libraries, not Python imports. You can link in a proto descriptor -by creating a cc_library target with alwayslink=1. - -There are a few special cases in the value mapping: - -Submessage and group fields must be pre-serialized as TensorFlow strings. - -TensorFlow lacks support for unsigned int64s, so they must be -represented as `tf.int64` with the same twos-complement bit pattern -(the obvious way). - -Unsigned int32 values can be represented exactly with `tf.int64`, or -with sign wrapping if the input is of type `tf.int32`. +- A "bytes://<bytes>", in which case protocol descriptors are created from `<bytes>`, +which is expected to be a `FileDescriptorSet` serialized as a string. END } diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD index 6ca6582e20e..0c1b905e812 100644 --- a/tensorflow/core/util/proto/BUILD +++ b/tensorflow/core/util/proto/BUILD @@ -23,6 +23,7 @@ cc_library( ":local_descriptor_pool_registration", "//tensorflow/core:framework", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/util/proto/descriptors.cc b/tensorflow/core/util/proto/descriptors.cc index c3797f1a8a8..3f82091ba91 100644 --- a/tensorflow/core/util/proto/descriptors.cc +++ b/tensorflow/core/util/proto/descriptors.cc @@ -13,15 +13,31 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ +#include "tensorflow/core/util/proto/descriptors.h" + +#include "absl/strings/match.h" +#include "absl/strings/strip.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/util/proto/descriptor_pool_registry.h" -#include "tensorflow/core/util/proto/descriptors.h" - namespace tensorflow { namespace { +Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set, + std::unique_ptr<protobuf::DescriptorPool>* out_pool) { + *out_pool = absl::make_unique<protobuf::DescriptorPool>(); + for (const auto& file : set.file()) { + if ((*out_pool)->BuildFile(file) == nullptr) { + return errors::InvalidArgument("Failed to load FileDescriptorProto: ", + file.DebugString()); + } + } + return Status::OK(); +} + // Build a `DescriptorPool` from the named file or URI. The file or URI // must be available to the current TensorFlow environment. // @@ -29,15 +45,14 @@ namespace { // `GetDescriptorPool()` for more information. Status GetDescriptorPoolFromFile( tensorflow::Env* env, const string& filename, - std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) { + std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) { Status st = env->FileExists(filename); if (!st.ok()) { return st; } - // Read and parse the FileDescriptorSet. 
- tensorflow::protobuf::FileDescriptorSet descs; - std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf; + protobuf::FileDescriptorSet descs; + std::unique_ptr<ReadOnlyMemoryRegion> buf; st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf); if (!st.ok()) { return st; @@ -46,25 +61,31 @@ Status GetDescriptorPoolFromFile( return errors::InvalidArgument( "descriptor_source contains invalid FileDescriptorSet: ", filename); } + return CreatePoolFromSet(descs, owned_desc_pool); +} - // Build a DescriptorPool from the FileDescriptorSet. - owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool()); - for (const auto& filedesc : descs.file()) { - if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) { - return errors::InvalidArgument( - "Problem loading FileDescriptorProto (missing dependencies?): ", - filename); - } +Status GetDescriptorPoolFromBinary( + const string& source, + std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) { + if (!absl::StartsWith(source, "bytes://")) { + return errors::InvalidArgument( + "Source does not represent serialized file descriptor set proto."); } - return Status::OK(); + // Parse the FileDescriptorSet. + protobuf::FileDescriptorSet proto; + if (!proto.ParseFromString(string(absl::StripPrefix(source, "bytes://")))) { + return errors::InvalidArgument( + "Source does not represent serialized file descriptor set proto."); + } + return CreatePoolFromSet(proto, owned_desc_pool); } } // namespace Status GetDescriptorPool( - tensorflow::Env* env, string const& descriptor_source, - tensorflow::protobuf::DescriptorPool const** desc_pool, - std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) { + Env* env, string const& descriptor_source, + protobuf::DescriptorPool const** desc_pool, + std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) { // Attempt to lookup the pool in the registry. 
auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source); if (pool_fn != nullptr) { @@ -77,7 +98,10 @@ Status GetDescriptorPool( GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool); if (status.ok()) { *desc_pool = owned_desc_pool->get(); + return Status::OK(); } + + status = GetDescriptorPoolFromBinary(descriptor_source, owned_desc_pool); *desc_pool = owned_desc_pool->get(); return status; } diff --git a/tensorflow/core/util/proto/descriptors.h b/tensorflow/core/util/proto/descriptors.h index 92ee8997ab2..a9942d312fc 100644 --- a/tensorflow/core/util/proto/descriptors.h +++ b/tensorflow/core/util/proto/descriptors.h @@ -25,17 +25,27 @@ namespace tensorflow { class Env; class Status; -// Get a `DescriptorPool` object from the named `descriptor_source`. -// `descriptor_source` may be a path to a file accessible to TensorFlow, in -// which case it is parsed as a `FileDescriptorSet` and used to build the -// `DescriptorPool`. +// Gets a `DescriptorPool` object from the `descriptor_source`. This may be: // -// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if -// the caller should take ownership. -extern tensorflow::Status GetDescriptorPool( - tensorflow::Env* env, string const& descriptor_source, - tensorflow::protobuf::DescriptorPool const** desc_pool, - std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool); +// 1) An empty string or "local://", in which case the local descriptor pool +// created for proto definitions linked to the binary is returned. +// +// 2) A file path, in which case the descriptor pool is created from the +// contents of the file, which is expected to contain a `FileDescriptorSet` +// serialized as a string. The descriptor pool ownership is transferred to the +// caller via `owned_desc_pool`. +// +// 3) A "bytes://<bytes>", in which case the descriptor pool is created from +// `<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a +// string. 
The descriptor pool ownership is transferred to the caller via +// `owned_desc_pool`. +// +// Custom schemas can be supported by registering a handler with the +// `DescriptorPoolRegistry`. +Status GetDescriptorPool( + Env* env, string const& descriptor_source, + protobuf::DescriptorPool const** desc_pool, + std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool); } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/proto/descriptor_source_test_base.py b/tensorflow/python/kernel_tests/proto/descriptor_source_test_base.py index 24f154b90f6..831c7403da3 100644 --- a/tensorflow/python/kernel_tests/proto/descriptor_source_test_base.py +++ b/tensorflow/python/kernel_tests/proto/descriptor_source_test_base.py @@ -51,18 +51,16 @@ class DescriptorSourceTestBase(test.TestCase): # # The generated descriptor should capture the subset of `test_example.proto` # used in `test_base.simple_test_case()`. - def _createDescriptorFile(self): - set_proto = FileDescriptorSet() + def _createDescriptorProto(self): + proto = FileDescriptorSet() - file_proto = set_proto.file.add( - name='types.proto', - package='tensorflow', - syntax='proto3') + file_proto = proto.file.add( + name='types.proto', package='tensorflow', syntax='proto3') enum_proto = file_proto.enum_type.add(name='DataType') enum_proto.value.add(name='DT_DOUBLE', number=0) enum_proto.value.add(name='DT_BOOL', number=1) - file_proto = set_proto.file.add( + file_proto = proto.file.add( name='test_example.proto', package='tensorflow.contrib.proto', dependency=['types.proto']) @@ -123,9 +121,12 @@ class DescriptorSourceTestBase(test.TestCase): type_name='.tensorflow.contrib.proto.TestValue', label=FieldDescriptorProto.LABEL_OPTIONAL) + return proto + + def _writeProtoToFile(self, proto): fn = os.path.join(self.get_temp_dir(), 'descriptor.pb') with open(fn, 'wb') as f: - f.write(set_proto.SerializeToString()) + f.write(proto.SerializeToString()) return fn def _testRoundtrip(self, descriptor_source): @@ -169,8 
+170,12 @@ class DescriptorSourceTestBase(test.TestCase): def testWithFileDescriptorSet(self): # First try parsing with a local proto db, which should fail. with self.assertRaisesOpError('No descriptor found for message type'): - self._testRoundtrip('local://') + self._testRoundtrip(b'local://') # Now try parsing with a FileDescriptorSet which contains the test proto. - descriptor_file = self._createDescriptorFile() - self._testRoundtrip(descriptor_file) + proto = self._createDescriptorProto() + proto_file = self._writeProtoToFile(proto) + self._testRoundtrip(proto_file) + + # Finally, try parsing the descriptor as a serialized string. + self._testRoundtrip(b'bytes://' + proto.SerializeToString())