From fb1a1c872cc101b8eef3755a610380623e9cee8e Mon Sep 17 00:00:00 2001
From: Jiri Simsa <jsimsa@google.com>
Date: Fri, 14 Jun 2019 17:36:42 -0700
Subject: [PATCH] Adding support for `bytes://` schema as a descriptor source
 for proto decoding and encoding ops.

This change makes using the ops with a custom proto simpler. Previously, users would need to either:

1) package the C++ target for the custom proto as a shared library and dynamically load the library in the Python program that wishes to use the ops with instances of the custom proto.

2) use the proto compiler to generate a file with a descriptor for the custom proto and make sure that the Python program that wishes to use the ops with instances of the custom proto has access to the file.

The `bytes://` schema makes it possible to embed the (serialized) descriptor proto into the descriptor source string, which means the Python binaries are self-contained (decoding does not depend on the ability to read the descriptor proto from a file) and do not need to jump through dynamic loading hoops.

PiperOrigin-RevId: 253327493
---
 .../base_api/api_def_DecodeProtoV2.pbtxt      | 70 +++++++++----------
 .../base_api/api_def_EncodeProto.pbtxt        | 65 +++++++++--------
 tensorflow/core/util/proto/BUILD              |  1 +
 tensorflow/core/util/proto/descriptors.cc     | 60 +++++++++++-----
 tensorflow/core/util/proto/descriptors.h      | 30 +++++---
 .../proto/descriptor_source_test_base.py      | 27 ++++---
 6 files changed, 147 insertions(+), 106 deletions(-)

diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
index c9e1fc58aad..9adb1a4056c 100644
--- a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
@@ -69,50 +69,48 @@ The `decode_proto` op extracts fields from a serialized protocol buffers
 message into tensors.  The fields in `field_names` are decoded and converted
 to the corresponding `output_types` if possible.
 
-A `message_type` name must be provided to give context for the field
-names. The actual message descriptor can be looked up either in the
-linked-in descriptor pool or a filename provided by the caller using
-the `descriptor_source` attribute.
+A `message_type` name must be provided to give context for the field names.
+The actual message descriptor can be looked up either in the linked-in
+descriptor pool or a filename provided by the caller using the
+`descriptor_source` attribute.
 
-Each output tensor is a dense tensor. This means that it is padded to
-hold the largest number of repeated elements seen in the input
-minibatch. (The shape is also padded by one to prevent zero-sized
-dimensions). The actual repeat counts for each example in the
-minibatch can be found in the `sizes` output. In many cases the output
-of `decode_proto` is fed immediately into tf.squeeze if missing values
-are not a concern. When using tf.squeeze, always pass the squeeze
-dimension explicitly to avoid surprises.
+Each output tensor is a dense tensor. This means that it is padded to hold
+the largest number of repeated elements seen in the input minibatch. (The
+shape is also padded by one to prevent zero-sized dimensions). The actual
+repeat counts for each example in the minibatch can be found in the `sizes`
+output. In many cases the output of `decode_proto` is fed immediately into
+tf.squeeze if missing values are not a concern. When using tf.squeeze, always
+pass the squeeze dimension explicitly to avoid surprises.
 
-For the most part, the mapping between Proto field types and
-TensorFlow dtypes is straightforward. However, there are a few
-special cases:
+For the most part, the mapping between Proto field types and TensorFlow dtypes
+is straightforward. However, there are a few special cases:
 
 - A proto field that contains a submessage or group can only be converted
-to `DT_STRING` (the serialized submessage). This is to reduce the
-complexity of the API. The resulting string can be used as input
-to another instance of the decode_proto op.
+to `DT_STRING` (the serialized submessage). This is to reduce the complexity
+of the API. The resulting string can be used as input to another instance of
+the decode_proto op.
 
 - TensorFlow lacks support for unsigned integers. The ops represent uint64
-types as a `DT_INT64` with the same twos-complement bit pattern
-(the obvious way). Unsigned int32 values can be represented exactly by
-specifying type `DT_INT64`, or using twos-complement if the caller
-specifies `DT_INT32` in the `output_types` attribute.
-
-The `descriptor_source` attribute selects a source of protocol
-descriptors to consult when looking up `message_type`. This may be a
-filename containing a serialized `FileDescriptorSet` message,
-or the special value `local://`, in which case only descriptors linked
-into the code will be searched; the filename can be on any filesystem
-accessible to TensorFlow.
-
-You can build a `descriptor_source` file using the `--descriptor_set_out`
-and `--include_imports` options to the protocol compiler `protoc`.
-
-The `local://` database only covers descriptors linked into the
-code via C++ libraries, not Python imports. You can link in a proto descriptor
-by creating a cc_library target with alwayslink=1.
+types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
+way). Unsigned int32 values can be represented exactly by specifying type
+`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
+the `output_types` attribute.
 
 Both binary and text proto serializations are supported, and can be
 chosen using the `format` attribute.
+
+The `descriptor_source` attribute selects the source of protocol
+descriptors to consult when looking up `message_type`. This may be:
+
+- An empty string or "local://", in which case protocol descriptors are
+created for C++ (not Python) proto definitions linked to the binary.
+
+- A file, in which case protocol descriptors are created from the file,
+which is expected to contain a `FileDescriptorSet` serialized as a string.
+NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
+and `--include_imports` options to the protocol compiler `protoc`.
+
+- A "bytes://<bytes>", in which case protocol descriptors are created from `<bytes>`,
+which is expected to be a `FileDescriptorSet` serialized as a string.
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
index 49b7b21b36e..b15abd2fce2 100644
--- a/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
@@ -41,42 +41,45 @@ END
 The op serializes protobuf messages provided in the input tensors.
 END
   description: <<END
-The types of the tensors in `values` must match the schema for the
-fields specified in `field_names`. All the tensors in `values` must
-have a common shape prefix, *batch_shape*.
+The types of the tensors in `values` must match the schema for the fields
+specified in `field_names`. All the tensors in `values` must have a common
+shape prefix, *batch_shape*.
 
-The `sizes` tensor specifies repeat counts for each field.  The repeat
-count (last dimension) of a each tensor in `values` must be greater
-than or equal to corresponding repeat count in `sizes`.
+The `sizes` tensor specifies repeat counts for each field.  The repeat count
+(last dimension) of each tensor in `values` must be greater than or equal
+to the corresponding repeat count in `sizes`.
 
-A `message_type` name must be provided to give context for the field
-names. The actual message descriptor can be looked up either in the
-linked-in descriptor pool or a filename provided by the caller using
-the `descriptor_source` attribute.
+A `message_type` name must be provided to give context for the field names.
+The actual message descriptor can be looked up either in the linked-in
+descriptor pool or a filename provided by the caller using the
+`descriptor_source` attribute.
 
-The `descriptor_source` attribute selects a source of protocol
-descriptors to consult when looking up `message_type`. This may be a
-filename containing a serialized `FileDescriptorSet` message,
-or the special value `local://`, in which case only descriptors linked
-into the code will be searched; the filename can be on any filesystem
-accessible to TensorFlow.
+For the most part, the mapping between Proto field types and TensorFlow dtypes
+is straightforward. However, there are a few special cases:
 
-You can build a `descriptor_source` file using the `--descriptor_set_out`
+- A proto field that contains a submessage or group can only be converted
+to `DT_STRING` (the serialized submessage). This is to reduce the complexity
+of the API. The resulting string can be used as input to another instance of
+the decode_proto op.
+
+- TensorFlow lacks support for unsigned integers. The ops represent uint64
+types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
+way). Unsigned int32 values can be represented exactly by specifying type
+`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
+the `output_types` attribute.
+
+The `descriptor_source` attribute selects the source of protocol
+descriptors to consult when looking up `message_type`. This may be:
+
+- An empty string or "local://", in which case protocol descriptors are
+created for C++ (not Python) proto definitions linked to the binary.
+
+- A file, in which case protocol descriptors are created from the file,
+which is expected to contain a `FileDescriptorSet` serialized as a string.
+NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
 and `--include_imports` options to the protocol compiler `protoc`.
 
-The `local://` database only covers descriptors linked into the
-code via C++ libraries, not Python imports. You can link in a proto descriptor
-by creating a cc_library target with alwayslink=1.
-
-There are a few special cases in the value mapping:
-
-Submessage and group fields must be pre-serialized as TensorFlow strings.
-
-TensorFlow lacks support for unsigned int64s, so they must be
-represented as `tf.int64` with the same twos-complement bit pattern
-(the obvious way).
-
-Unsigned int32 values can be represented exactly with `tf.int64`, or
-with sign wrapping if the input is of type `tf.int32`.
+- A "bytes://<bytes>", in which case protocol descriptors are created from `<bytes>`,
+which is expected to be a `FileDescriptorSet` serialized as a string.
 END
 }
diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD
index 6ca6582e20e..0c1b905e812 100644
--- a/tensorflow/core/util/proto/BUILD
+++ b/tensorflow/core/util/proto/BUILD
@@ -23,6 +23,7 @@ cc_library(
         ":local_descriptor_pool_registration",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/core/util/proto/descriptors.cc b/tensorflow/core/util/proto/descriptors.cc
index c3797f1a8a8..3f82091ba91 100644
--- a/tensorflow/core/util/proto/descriptors.cc
+++ b/tensorflow/core/util/proto/descriptors.cc
@@ -13,15 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/util/proto/descriptors.h"
+
+#include "absl/strings/match.h"
+#include "absl/strings/strip.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/util/proto/descriptor_pool_registry.h"
 
-#include "tensorflow/core/util/proto/descriptors.h"
-
 namespace tensorflow {
 namespace {
 
+Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set,
+                         std::unique_ptr<protobuf::DescriptorPool>* out_pool) {
+  *out_pool = absl::make_unique<protobuf::DescriptorPool>();
+  for (const auto& file : set.file()) {
+    if ((*out_pool)->BuildFile(file) == nullptr) {
+      return errors::InvalidArgument("Failed to load FileDescriptorProto: ",
+                                     file.DebugString());
+    }
+  }
+  return Status::OK();
+}
+
 // Build a `DescriptorPool` from the named file or URI. The file or URI
 // must be available to the current TensorFlow environment.
 //
@@ -29,15 +45,14 @@ namespace {
 // `GetDescriptorPool()` for more information.
 Status GetDescriptorPoolFromFile(
     tensorflow::Env* env, const string& filename,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
   Status st = env->FileExists(filename);
   if (!st.ok()) {
     return st;
   }
-
   // Read and parse the FileDescriptorSet.
-  tensorflow::protobuf::FileDescriptorSet descs;
-  std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf;
+  protobuf::FileDescriptorSet descs;
+  std::unique_ptr<ReadOnlyMemoryRegion> buf;
   st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
   if (!st.ok()) {
     return st;
@@ -46,25 +61,31 @@ Status GetDescriptorPoolFromFile(
     return errors::InvalidArgument(
         "descriptor_source contains invalid FileDescriptorSet: ", filename);
   }
+  return CreatePoolFromSet(descs, owned_desc_pool);
+}
 
-  // Build a DescriptorPool from the FileDescriptorSet.
-  owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool());
-  for (const auto& filedesc : descs.file()) {
-    if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) {
-      return errors::InvalidArgument(
-          "Problem loading FileDescriptorProto (missing dependencies?): ",
-          filename);
-    }
+Status GetDescriptorPoolFromBinary(
+    const string& source,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
+  if (!absl::StartsWith(source, "bytes://")) {
+    return errors::InvalidArgument(
+        "Source does not represent serialized file descriptor set proto.");
   }
-  return Status::OK();
+  // Parse the FileDescriptorSet.
+  protobuf::FileDescriptorSet proto;
+  if (!proto.ParseFromString(string(absl::StripPrefix(source, "bytes://")))) {
+    return errors::InvalidArgument(
+        "Source does not represent serialized file descriptor set proto.");
+  }
+  return CreatePoolFromSet(proto, owned_desc_pool);
 }
 
 }  // namespace
 
 Status GetDescriptorPool(
-    tensorflow::Env* env, string const& descriptor_source,
-    tensorflow::protobuf::DescriptorPool const** desc_pool,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+    Env* env, string const& descriptor_source,
+    protobuf::DescriptorPool const** desc_pool,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
   // Attempt to lookup the pool in the registry.
   auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
   if (pool_fn != nullptr) {
@@ -77,7 +98,10 @@ Status GetDescriptorPool(
       GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
   if (status.ok()) {
     *desc_pool = owned_desc_pool->get();
+    return Status::OK();
   }
+
+  status = GetDescriptorPoolFromBinary(descriptor_source, owned_desc_pool);
   *desc_pool = owned_desc_pool->get();
   return status;
 }
diff --git a/tensorflow/core/util/proto/descriptors.h b/tensorflow/core/util/proto/descriptors.h
index 92ee8997ab2..a9942d312fc 100644
--- a/tensorflow/core/util/proto/descriptors.h
+++ b/tensorflow/core/util/proto/descriptors.h
@@ -25,17 +25,27 @@ namespace tensorflow {
 class Env;
 class Status;
 
-// Get a `DescriptorPool` object from the named `descriptor_source`.
-// `descriptor_source` may be a path to a file accessible to TensorFlow, in
-// which case it is parsed as a `FileDescriptorSet` and used to build the
-// `DescriptorPool`.
+// Gets a `DescriptorPool` object from the `descriptor_source`. This may be:
 //
-// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if
-// the caller should take ownership.
-extern tensorflow::Status GetDescriptorPool(
-    tensorflow::Env* env, string const& descriptor_source,
-    tensorflow::protobuf::DescriptorPool const** desc_pool,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool);
+// 1) An empty string or "local://", in which case the local descriptor pool
+// created for proto definitions linked to the binary is returned.
+//
+// 2) A file path, in which case the descriptor pool is created from the
+// contents of the file, which is expected to contain a `FileDescriptorSet`
+// serialized as a string. The descriptor pool ownership is transferred to the
+// caller via `owned_desc_pool`.
+//
+// 3) A "bytes://<bytes>", in which case the descriptor pool is created from
+// `<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a
+// string. The descriptor pool ownership is transferred to the caller via
+// `owned_desc_pool`.
+//
+// Custom schemas can be supported by registering a handler with the
+// `DescriptorPoolRegistry`.
+Status GetDescriptorPool(
+    Env* env, string const& descriptor_source,
+    protobuf::DescriptorPool const** desc_pool,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/python/kernel_tests/proto/descriptor_source_test_base.py b/tensorflow/python/kernel_tests/proto/descriptor_source_test_base.py
index 24f154b90f6..831c7403da3 100644
--- a/tensorflow/python/kernel_tests/proto/descriptor_source_test_base.py
+++ b/tensorflow/python/kernel_tests/proto/descriptor_source_test_base.py
@@ -51,18 +51,16 @@ class DescriptorSourceTestBase(test.TestCase):
   #
   # The generated descriptor should capture the subset of `test_example.proto`
   # used in `test_base.simple_test_case()`.
-  def _createDescriptorFile(self):
-    set_proto = FileDescriptorSet()
+  def _createDescriptorProto(self):
+    proto = FileDescriptorSet()
 
-    file_proto = set_proto.file.add(
-        name='types.proto',
-        package='tensorflow',
-        syntax='proto3')
+    file_proto = proto.file.add(
+        name='types.proto', package='tensorflow', syntax='proto3')
     enum_proto = file_proto.enum_type.add(name='DataType')
     enum_proto.value.add(name='DT_DOUBLE', number=0)
     enum_proto.value.add(name='DT_BOOL', number=1)
 
-    file_proto = set_proto.file.add(
+    file_proto = proto.file.add(
         name='test_example.proto',
         package='tensorflow.contrib.proto',
         dependency=['types.proto'])
@@ -123,9 +121,12 @@ class DescriptorSourceTestBase(test.TestCase):
         type_name='.tensorflow.contrib.proto.TestValue',
         label=FieldDescriptorProto.LABEL_OPTIONAL)
 
+    return proto
+
+  def _writeProtoToFile(self, proto):
     fn = os.path.join(self.get_temp_dir(), 'descriptor.pb')
     with open(fn, 'wb') as f:
-      f.write(set_proto.SerializeToString())
+      f.write(proto.SerializeToString())
     return fn
 
   def _testRoundtrip(self, descriptor_source):
@@ -169,8 +170,12 @@ class DescriptorSourceTestBase(test.TestCase):
   def testWithFileDescriptorSet(self):
     # First try parsing with a local proto db, which should fail.
     with self.assertRaisesOpError('No descriptor found for message type'):
-      self._testRoundtrip('local://')
+      self._testRoundtrip(b'local://')
 
     # Now try parsing with a FileDescriptorSet which contains the test proto.
-    descriptor_file = self._createDescriptorFile()
-    self._testRoundtrip(descriptor_file)
+    proto = self._createDescriptorProto()
+    proto_file = self._writeProtoToFile(proto)
+    self._testRoundtrip(proto_file)
+
+    # Finally, try parsing the descriptor as a serialized string.
+    self._testRoundtrip(b'bytes://' + proto.SerializeToString())