Adding support for `bytes://` schema as a descriptor source for proto decoding and encoding ops.
This change makes using the ops with a custom proto simpler. Previously, users would need to either:

1) package the C++ target for the custom proto as a shared library and dynamically load the library in the Python program that wishes to use the ops with instances of the custom proto, or
2) use the proto compiler to generate a file with a descriptor for the custom proto and make sure that the Python program that wishes to use the ops with instances of the custom proto has access to the file.

The `bytes://` schema makes it possible to embed the (serialized) descriptor proto into the descriptor source string, which means the Python binaries are self-contained (decoding does not depend on the ability to read the descriptor proto from a file) and do not need to jump through dynamic loading hoops.

PiperOrigin-RevId: 253327493
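The three `descriptor_source` forms can be illustrated with a short, pure-Python sketch. Here `serialized_fds` is a hypothetical placeholder for the bytes produced by `FileDescriptorSet.SerializeToString()`, and the file path is likewise illustrative:

```python
# The three descriptor_source forms the ops accept after this change.
# serialized_fds is a hypothetical stand-in for the output of
# FileDescriptorSet.SerializeToString() from the protobuf runtime.
serialized_fds = b'\x0a\x0d\x0a\x0btypes.proto'  # placeholder bytes, not a real descriptor

local_source = b'local://'                    # only descriptors linked into the binary
file_source = b'/tmp/descriptor.pb'           # illustrative path to a protoc-generated file
bytes_source = b'bytes://' + serialized_fds   # descriptor embedded in the string itself

# The bytes:// form is what makes the program self-contained: the descriptor
# travels inside the attribute value, so no file read or shared-library
# loading is needed at run time.
```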
parent db3ecd34a4
commit fb1a1c872c
@@ -69,50 +69,48 @@ The `decode_proto` op extracts fields from a serialized protocol buffers
message into tensors. The fields in `field_names` are decoded and converted
to the corresponding `output_types` if possible.

A `message_type` name must be provided to give context for the field
names. The actual message descriptor can be looked up either in the
linked-in descriptor pool or a filename provided by the caller using
the `descriptor_source` attribute.
A `message_type` name must be provided to give context for the field names.
The actual message descriptor can be looked up either in the linked-in
descriptor pool or a filename provided by the caller using the
`descriptor_source` attribute.

Each output tensor is a dense tensor. This means that it is padded to
hold the largest number of repeated elements seen in the input
minibatch. (The shape is also padded by one to prevent zero-sized
dimensions). The actual repeat counts for each example in the
minibatch can be found in the `sizes` output. In many cases the output
of `decode_proto` is fed immediately into tf.squeeze if missing values
are not a concern. When using tf.squeeze, always pass the squeeze
dimension explicitly to avoid surprises.
Each output tensor is a dense tensor. This means that it is padded to hold
the largest number of repeated elements seen in the input minibatch. (The
shape is also padded by one to prevent zero-sized dimensions). The actual
repeat counts for each example in the minibatch can be found in the `sizes`
output. In many cases the output of `decode_proto` is fed immediately into
tf.squeeze if missing values are not a concern. When using tf.squeeze, always
pass the squeeze dimension explicitly to avoid surprises.

For the most part, the mapping between Proto field types and
TensorFlow dtypes is straightforward. However, there are a few
special cases:
For the most part, the mapping between Proto field types and TensorFlow dtypes
is straightforward. However, there are a few special cases:

- A proto field that contains a submessage or group can only be converted
to `DT_STRING` (the serialized submessage). This is to reduce the
complexity of the API. The resulting string can be used as input
to another instance of the decode_proto op.
to `DT_STRING` (the serialized submessage). This is to reduce the complexity
of the API. The resulting string can be used as input to another instance of
the decode_proto op.

- TensorFlow lacks support for unsigned integers. The ops represent uint64
types as a `DT_INT64` with the same twos-complement bit pattern
(the obvious way). Unsigned int32 values can be represented exactly by
specifying type `DT_INT64`, or using twos-complement if the caller
specifies `DT_INT32` in the `output_types` attribute.

The `descriptor_source` attribute selects a source of protocol
descriptors to consult when looking up `message_type`. This may be a
filename containing a serialized `FileDescriptorSet` message,
or the special value `local://`, in which case only descriptors linked
into the code will be searched; the filename can be on any filesystem
accessible to TensorFlow.

You can build a `descriptor_source` file using the `--descriptor_set_out`
and `--include_imports` options to the protocol compiler `protoc`.

The `local://` database only covers descriptors linked into the
code via C++ libraries, not Python imports. You can link in a proto descriptor
by creating a cc_library target with alwayslink=1.
types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
way). Unsigned int32 values can be represented exactly by specifying type
`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
the `output_types` attribute.

Both binary and text proto serializations are supported, and can be
chosen using the `format` attribute.

The `descriptor_source` attribute selects the source of protocol
descriptors to consult when looking up `message_type`. This may be:

- An empty string or "local://", in which case protocol descriptors are
created for C++ (not Python) proto definitions linked to the binary.

- A file, in which case protocol descriptors are created from the file,
which is expected to contain a `FileDescriptorSet` serialized as a string.
NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
and `--include_imports` options to the protocol compiler `protoc`.

- A "bytes://<bytes>", in which case protocol descriptors are created from
`<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a string.
END
}
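The three accepted `descriptor_source` forms described above can be told apart with a minimal Python sketch (illustrative only; the real selection logic is the C++ `GetDescriptorPool` change further down in this diff, and the function name here is hypothetical):

```python
def classify_descriptor_source(source: str) -> str:
    # Rough sketch of how the three descriptor_source forms from the
    # decode_proto documentation are distinguished.
    if source in ('', 'local://'):
        return 'local'   # descriptors linked into the binary (C++, not Python)
    if source.startswith('bytes://'):
        return 'bytes'   # serialized FileDescriptorSet embedded in the string
    return 'file'        # path to a file containing a FileDescriptorSet
```

Note that the file form is the fallback: anything that is not empty, `local://`, or `bytes://`-prefixed is treated as a filename.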
@@ -41,42 +41,45 @@ END
The op serializes protobuf messages provided in the input tensors.
END
description: <<END
The types of the tensors in `values` must match the schema for the
fields specified in `field_names`. All the tensors in `values` must
have a common shape prefix, *batch_shape*.
The types of the tensors in `values` must match the schema for the fields
specified in `field_names`. All the tensors in `values` must have a common
shape prefix, *batch_shape*.

The `sizes` tensor specifies repeat counts for each field. The repeat
count (last dimension) of each tensor in `values` must be greater
than or equal to the corresponding repeat count in `sizes`.
The `sizes` tensor specifies repeat counts for each field. The repeat count
(last dimension) of each tensor in `values` must be greater than or equal
to the corresponding repeat count in `sizes`.

A `message_type` name must be provided to give context for the field
names. The actual message descriptor can be looked up either in the
linked-in descriptor pool or a filename provided by the caller using
the `descriptor_source` attribute.
A `message_type` name must be provided to give context for the field names.
The actual message descriptor can be looked up either in the linked-in
descriptor pool or a filename provided by the caller using the
`descriptor_source` attribute.

The `descriptor_source` attribute selects a source of protocol
descriptors to consult when looking up `message_type`. This may be a
filename containing a serialized `FileDescriptorSet` message,
or the special value `local://`, in which case only descriptors linked
into the code will be searched; the filename can be on any filesystem
accessible to TensorFlow.
For the most part, the mapping between Proto field types and TensorFlow dtypes
is straightforward. However, there are a few special cases:

You can build a `descriptor_source` file using the `--descriptor_set_out`
- A proto field that contains a submessage or group can only be converted
to `DT_STRING` (the serialized submessage). This is to reduce the complexity
of the API. The resulting string can be used as input to another instance of
the decode_proto op.

- TensorFlow lacks support for unsigned integers. The ops represent uint64
types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
way). Unsigned int32 values can be represented exactly by specifying type
`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
the `output_types` attribute.

The `descriptor_source` attribute selects the source of protocol
descriptors to consult when looking up `message_type`. This may be:

- An empty string or "local://", in which case protocol descriptors are
created for C++ (not Python) proto definitions linked to the binary.

- A file, in which case protocol descriptors are created from the file,
which is expected to contain a `FileDescriptorSet` serialized as a string.
NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
and `--include_imports` options to the protocol compiler `protoc`.

The `local://` database only covers descriptors linked into the
code via C++ libraries, not Python imports. You can link in a proto descriptor
by creating a cc_library target with alwayslink=1.

There are a few special cases in the value mapping:

Submessage and group fields must be pre-serialized as TensorFlow strings.

TensorFlow lacks support for unsigned int64s, so they must be
represented as `tf.int64` with the same twos-complement bit pattern
(the obvious way).

Unsigned int32 values can be represented exactly with `tf.int64`, or
with sign wrapping if the input is of type `tf.int32`.
- A "bytes://<bytes>", in which case protocol descriptors are created from
`<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a string.
END
}
@@ -23,6 +23,7 @@ cc_library(
        ":local_descriptor_pool_registration",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)
@@ -13,15 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/proto/descriptors.h"

#include "absl/strings/match.h"
#include "absl/strings/strip.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/reader_op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"

#include "tensorflow/core/util/proto/descriptors.h"

namespace tensorflow {
namespace {

Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set,
                         std::unique_ptr<protobuf::DescriptorPool>* out_pool) {
  *out_pool = absl::make_unique<protobuf::DescriptorPool>();
  for (const auto& file : set.file()) {
    if ((*out_pool)->BuildFile(file) == nullptr) {
      return errors::InvalidArgument("Failed to load FileDescriptorProto: ",
                                     file.DebugString());
    }
  }
  return Status::OK();
}

// Build a `DescriptorPool` from the named file or URI. The file or URI
// must be available to the current TensorFlow environment.
//
@@ -29,15 +45,14 @@ namespace {
// `GetDescriptorPool()` for more information.
Status GetDescriptorPoolFromFile(
    tensorflow::Env* env, const string& filename,
    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
  Status st = env->FileExists(filename);
  if (!st.ok()) {
    return st;
  }

  // Read and parse the FileDescriptorSet.
  tensorflow::protobuf::FileDescriptorSet descs;
  std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf;
  protobuf::FileDescriptorSet descs;
  std::unique_ptr<ReadOnlyMemoryRegion> buf;
  st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
  if (!st.ok()) {
    return st;
@@ -46,25 +61,31 @@ Status GetDescriptorPoolFromFile(
    return errors::InvalidArgument(
        "descriptor_source contains invalid FileDescriptorSet: ", filename);
  }
  return CreatePoolFromSet(descs, owned_desc_pool);
}

  // Build a DescriptorPool from the FileDescriptorSet.
  owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool());
  for (const auto& filedesc : descs.file()) {
    if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) {
      return errors::InvalidArgument(
          "Problem loading FileDescriptorProto (missing dependencies?): ",
          filename);
    }
  }
Status GetDescriptorPoolFromBinary(
    const string& source,
    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
  if (!absl::StartsWith(source, "bytes://")) {
    return errors::InvalidArgument(
        "Source does not represent serialized file descriptor set proto.");
  }
  return Status::OK();
  // Parse the FileDescriptorSet.
  protobuf::FileDescriptorSet proto;
  if (!proto.ParseFromString(string(absl::StripPrefix(source, "bytes://")))) {
    return errors::InvalidArgument(
        "Source does not represent serialized file descriptor set proto.");
  }
  return CreatePoolFromSet(proto, owned_desc_pool);
}

}  // namespace

Status GetDescriptorPool(
    tensorflow::Env* env, string const& descriptor_source,
    tensorflow::protobuf::DescriptorPool const** desc_pool,
    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
    Env* env, string const& descriptor_source,
    protobuf::DescriptorPool const** desc_pool,
    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
  // Attempt to lookup the pool in the registry.
  auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
  if (pool_fn != nullptr) {
@@ -77,7 +98,10 @@ Status GetDescriptorPool(
      GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
  if (status.ok()) {
    *desc_pool = owned_desc_pool->get();
    return Status::OK();
  }

  status = GetDescriptorPoolFromBinary(descriptor_source, owned_desc_pool);
  *desc_pool = owned_desc_pool->get();
  return status;
}
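The `bytes://` prefix handling added by `GetDescriptorPoolFromBinary` above can be mirrored in a small Python sketch (the function name is illustrative; the real parsing into a `FileDescriptorSet` is done by the C++ protobuf runtime):

```python
def bytes_source_payload(source: bytes) -> bytes:
    # Sketch of GetDescriptorPoolFromBinary's prefix handling: reject a
    # source without the bytes:// prefix, otherwise return the payload
    # that the C++ code parses as a serialized FileDescriptorSet.
    prefix = b'bytes://'
    if not source.startswith(prefix):
        # Same error text the C++ code returns for a malformed source.
        raise ValueError(
            'Source does not represent serialized file descriptor set proto.')
    return source[len(prefix):]
```

As in the C++ code, this check is the last fallback: it only runs after the registry lookup and the file lookup have both failed.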
@@ -25,17 +25,27 @@ namespace tensorflow {
class Env;
class Status;

// Get a `DescriptorPool` object from the named `descriptor_source`.
// `descriptor_source` may be a path to a file accessible to TensorFlow, in
// which case it is parsed as a `FileDescriptorSet` and used to build the
// `DescriptorPool`.
// Gets a `DescriptorPool` object from the `descriptor_source`. This may be:
//
// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if
// the caller should take ownership.
extern tensorflow::Status GetDescriptorPool(
    tensorflow::Env* env, string const& descriptor_source,
    tensorflow::protobuf::DescriptorPool const** desc_pool,
    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool);
// 1) An empty string or "local://", in which case the local descriptor pool
// created for proto definitions linked to the binary is returned.
//
// 2) A file path, in which case the descriptor pool is created from the
// contents of the file, which is expected to contain a `FileDescriptorSet`
// serialized as a string. The descriptor pool ownership is transferred to the
// caller via `owned_desc_pool`.
//
// 3) A "bytes://<bytes>", in which case the descriptor pool is created from
// `<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a
// string. The descriptor pool ownership is transferred to the caller via
// `owned_desc_pool`.
//
// Custom schemas can be supported by registering a handler with the
// `DescriptorPoolRegistry`.
Status GetDescriptorPool(
    Env* env, string const& descriptor_source,
    protobuf::DescriptorPool const** desc_pool,
    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool);

}  // namespace tensorflow
@@ -51,18 +51,16 @@ class DescriptorSourceTestBase(test.TestCase):
  #
  # The generated descriptor should capture the subset of `test_example.proto`
  # used in `test_base.simple_test_case()`.
  def _createDescriptorFile(self):
    set_proto = FileDescriptorSet()
  def _createDescriptorProto(self):
    proto = FileDescriptorSet()

    file_proto = set_proto.file.add(
        name='types.proto',
        package='tensorflow',
        syntax='proto3')
    file_proto = proto.file.add(
        name='types.proto', package='tensorflow', syntax='proto3')
    enum_proto = file_proto.enum_type.add(name='DataType')
    enum_proto.value.add(name='DT_DOUBLE', number=0)
    enum_proto.value.add(name='DT_BOOL', number=1)

    file_proto = set_proto.file.add(
    file_proto = proto.file.add(
        name='test_example.proto',
        package='tensorflow.contrib.proto',
        dependency=['types.proto'])
@@ -123,9 +121,12 @@ class DescriptorSourceTestBase(test.TestCase):
        type_name='.tensorflow.contrib.proto.TestValue',
        label=FieldDescriptorProto.LABEL_OPTIONAL)

    return proto

  def _writeProtoToFile(self, proto):
    fn = os.path.join(self.get_temp_dir(), 'descriptor.pb')
    with open(fn, 'wb') as f:
      f.write(set_proto.SerializeToString())
      f.write(proto.SerializeToString())
    return fn

  def _testRoundtrip(self, descriptor_source):
@@ -169,8 +170,12 @@ class DescriptorSourceTestBase(test.TestCase):
  def testWithFileDescriptorSet(self):
    # First try parsing with a local proto db, which should fail.
    with self.assertRaisesOpError('No descriptor found for message type'):
      self._testRoundtrip('local://')
      self._testRoundtrip(b'local://')

    # Now try parsing with a FileDescriptorSet which contains the test proto.
    descriptor_file = self._createDescriptorFile()
    self._testRoundtrip(descriptor_file)
    proto = self._createDescriptorProto()
    proto_file = self._writeProtoToFile(proto)
    self._testRoundtrip(proto_file)

    # Finally, try parsing the descriptor as a serialized string.
    self._testRoundtrip(b'bytes://' + proto.SerializeToString())