Add support for a bytes:// schema as a descriptor source for proto decoding and encoding ops.

This change makes using the ops with a custom proto simpler. Previously, users needed to either:

1) package the C++ target for the custom proto as a shared library and dynamically load that library in the Python program that wishes to use the ops with instances of the custom proto, or

2) use the proto compiler to generate a file with a descriptor for the custom proto and ensure that the Python program that wishes to use the ops has access to that file.

The `bytes://` schema makes it possible to embed the (serialized) descriptor proto directly in the descriptor source string. The resulting Python binaries are self-contained: decoding does not depend on the ability to read the descriptor proto from a file, and there is no need to jump through dynamic-loading hoops.

PiperOrigin-RevId: 253327493
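As an illustration of the new scheme, a `bytes://` descriptor source can be assembled in plain Python with the `protobuf` package. This is a sketch only: the helper name `make_bytes_descriptor_source` is invented for the example, and the well-known `Duration` message stands in for any generated message class.

```python
# Sketch: building a self-contained "bytes://" descriptor source.
# Assumes only the `protobuf` pip package. `duration_pb2.Duration` is a
# stand-in for whatever custom message the ops should decode.
from google.protobuf import descriptor_pb2
from google.protobuf import duration_pb2


def make_bytes_descriptor_source(message_class):
    """Embed the message's file descriptor in a bytes:// source string."""
    file_set = descriptor_pb2.FileDescriptorSet()
    # Copy the generated file descriptor into the set. For messages with
    # imports, each dependency's file would need to be added as well
    # (protoc's --include_imports does this automatically for files).
    message_class.DESCRIPTOR.file.CopyToProto(file_set.file.add())
    return b'bytes://' + file_set.SerializeToString()


source = make_bytes_descriptor_source(duration_pb2.Duration)
print(source[:8])  # b'bytes://'
```

The resulting `source` string could then be passed wherever a `descriptor_source` is accepted, with no descriptor file shipped alongside the binary.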
This commit is contained in:
parent db3ecd34a4
commit fb1a1c872c

@@ -69,50 +69,48 @@ The `decode_proto` op extracts fields from a serialized protocol buffers
 message into tensors. The fields in `field_names` are decoded and converted
 to the corresponding `output_types` if possible.
 
-A `message_type` name must be provided to give context for the field
-names. The actual message descriptor can be looked up either in the
-linked-in descriptor pool or a filename provided by the caller using
-the `descriptor_source` attribute.
+A `message_type` name must be provided to give context for the field names.
+The actual message descriptor can be looked up either in the linked-in
+descriptor pool or a filename provided by the caller using the
+`descriptor_source` attribute.
 
-Each output tensor is a dense tensor. This means that it is padded to
-hold the largest number of repeated elements seen in the input
-minibatch. (The shape is also padded by one to prevent zero-sized
-dimensions). The actual repeat counts for each example in the
-minibatch can be found in the `sizes` output. In many cases the output
-of `decode_proto` is fed immediately into tf.squeeze if missing values
-are not a concern. When using tf.squeeze, always pass the squeeze
-dimension explicitly to avoid surprises.
+Each output tensor is a dense tensor. This means that it is padded to hold
+the largest number of repeated elements seen in the input minibatch. (The
+shape is also padded by one to prevent zero-sized dimensions). The actual
+repeat counts for each example in the minibatch can be found in the `sizes`
+output. In many cases the output of `decode_proto` is fed immediately into
+tf.squeeze if missing values are not a concern. When using tf.squeeze, always
+pass the squeeze dimension explicitly to avoid surprises.
 
-For the most part, the mapping between Proto field types and
-TensorFlow dtypes is straightforward. However, there are a few
-special cases:
+For the most part, the mapping between Proto field types and TensorFlow dtypes
+is straightforward. However, there are a few special cases:
 
 - A proto field that contains a submessage or group can only be converted
-to `DT_STRING` (the serialized submessage). This is to reduce the
-complexity of the API. The resulting string can be used as input
-to another instance of the decode_proto op.
+to `DT_STRING` (the serialized submessage). This is to reduce the complexity
+of the API. The resulting string can be used as input to another instance of
+the decode_proto op.
 
 - TensorFlow lacks support for unsigned integers. The ops represent uint64
-types as a `DT_INT64` with the same twos-complement bit pattern
-(the obvious way). Unsigned int32 values can be represented exactly by
-specifying type `DT_INT64`, or using twos-complement if the caller
-specifies `DT_INT32` in the `output_types` attribute.
+types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
+way). Unsigned int32 values can be represented exactly by specifying type
+`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
+the `output_types` attribute.
 
-The `descriptor_source` attribute selects a source of protocol
-descriptors to consult when looking up `message_type`. This may be a
-filename containing a serialized `FileDescriptorSet` message,
-or the special value `local://`, in which case only descriptors linked
-into the code will be searched; the filename can be on any filesystem
-accessible to TensorFlow.
-
-You can build a `descriptor_source` file using the `--descriptor_set_out`
-and `--include_imports` options to the protocol compiler `protoc`.
-
-The `local://` database only covers descriptors linked into the
-code via C++ libraries, not Python imports. You can link in a proto descriptor
-by creating a cc_library target with alwayslink=1.
-
 Both binary and text proto serializations are supported, and can be
 chosen using the `format` attribute.
+
+The `descriptor_source` attribute selects the source of protocol
+descriptors to consult when looking up `message_type`. This may be:
+
+- An empty string or "local://", in which case protocol descriptors are
+created for C++ (not Python) proto definitions linked to the binary.
+
+- A file, in which case protocol descriptors are created from the file,
+which is expected to contain a `FileDescriptorSet` serialized as a string.
+NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
+and `--include_imports` options to the protocol compiler `protoc`.
+
+- A "bytes://<bytes>", in which protocol descriptors are created from `<bytes>`,
+which is expected to be a `FileDescriptorSet` serialized as a string.
 END
 }

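The twos-complement behavior the decode_proto description promises for `DT_INT32` outputs can be checked with a few lines of standard-library Python; `struct` reinterprets the same four bytes as unsigned or signed:

```python
# Sketch of the uint32 special case: a uint32 value that does not fit in
# int32 survives a signed round trip because only the bit pattern is kept.
import struct


def uint32_as_int32(value):
    """Reinterpret a uint32 bit pattern as a signed int32 (twos-complement)."""
    return struct.unpack('<i', struct.pack('<I', value))[0]


def int32_as_uint32(value):
    """Inverse reinterpretation: recover the original uint32 bit pattern."""
    return struct.unpack('<I', struct.pack('<i', value))[0]


print(uint32_as_int32(4294967295))  # -1
print(int32_as_uint32(-1))          # 4294967295
```

Values that fit in 31 bits pass through unchanged; only values with the high bit set appear negative on the signed side.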
@@ -41,42 +41,45 @@ END
 The op serializes protobuf messages provided in the input tensors.
 END
 description: <<END
-The types of the tensors in `values` must match the schema for the
-fields specified in `field_names`. All the tensors in `values` must
-have a common shape prefix, *batch_shape*.
+The types of the tensors in `values` must match the schema for the fields
+specified in `field_names`. All the tensors in `values` must have a common
+shape prefix, *batch_shape*.
 
-The `sizes` tensor specifies repeat counts for each field. The repeat
-count (last dimension) of a each tensor in `values` must be greater
-than or equal to corresponding repeat count in `sizes`.
+The `sizes` tensor specifies repeat counts for each field. The repeat count
+(last dimension) of each tensor in `values` must be greater than or equal
+to the corresponding repeat count in `sizes`.
 
-A `message_type` name must be provided to give context for the field
-names. The actual message descriptor can be looked up either in the
-linked-in descriptor pool or a filename provided by the caller using
-the `descriptor_source` attribute.
+A `message_type` name must be provided to give context for the field names.
+The actual message descriptor can be looked up either in the linked-in
+descriptor pool or a filename provided by the caller using the
+`descriptor_source` attribute.
 
-The `descriptor_source` attribute selects a source of protocol
-descriptors to consult when looking up `message_type`. This may be a
-filename containing a serialized `FileDescriptorSet` message,
-or the special value `local://`, in which case only descriptors linked
-into the code will be searched; the filename can be on any filesystem
-accessible to TensorFlow.
+For the most part, the mapping between Proto field types and TensorFlow dtypes
+is straightforward. However, there are a few special cases:
 
-You can build a `descriptor_source` file using the `--descriptor_set_out`
-and `--include_imports` options to the protocol compiler `protoc`.
+- A proto field that contains a submessage or group can only be converted
+to `DT_STRING` (the serialized submessage). This is to reduce the complexity
+of the API. The resulting string can be used as input to another instance of
+the decode_proto op.
 
-The `local://` database only covers descriptors linked into the
-code via C++ libraries, not Python imports. You can link in a proto descriptor
-by creating a cc_library target with alwayslink=1.
+- TensorFlow lacks support for unsigned integers. The ops represent uint64
+types as a `DT_INT64` with the same twos-complement bit pattern (the obvious
+way). Unsigned int32 values can be represented exactly by specifying type
+`DT_INT64`, or using twos-complement if the caller specifies `DT_INT32` in
+the `output_types` attribute.
 
-There are a few special cases in the value mapping:
-
-Submessage and group fields must be pre-serialized as TensorFlow strings.
-
-TensorFlow lacks support for unsigned int64s, so they must be
-represented as `tf.int64` with the same twos-complement bit pattern
-(the obvious way).
-
-Unsigned int32 values can be represented exactly with `tf.int64`, or
-with sign wrapping if the input is of type `tf.int32`.
+The `descriptor_source` attribute selects the source of protocol
+descriptors to consult when looking up `message_type`. This may be:
+
+- An empty string or "local://", in which case protocol descriptors are
+created for C++ (not Python) proto definitions linked to the binary.
+
+- A file, in which case protocol descriptors are created from the file,
+which is expected to contain a `FileDescriptorSet` serialized as a string.
+NOTE: You can build a `descriptor_source` file using the `--descriptor_set_out`
+and `--include_imports` options to the protocol compiler `protoc`.
+
+- A "bytes://<bytes>", in which protocol descriptors are created from `<bytes>`,
+which is expected to be a `FileDescriptorSet` serialized as a string.
 END
 }

@@ -23,6 +23,7 @@ cc_library(
         ":local_descriptor_pool_registration",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/strings",
     ],
 )

@@ -13,15 +13,31 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/util/proto/descriptors.h"
+
+#include "absl/strings/match.h"
+#include "absl/strings/strip.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/util/proto/descriptor_pool_registry.h"
-
-#include "tensorflow/core/util/proto/descriptors.h"
 
 namespace tensorflow {
 namespace {
 
+Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set,
+                         std::unique_ptr<protobuf::DescriptorPool>* out_pool) {
+  *out_pool = absl::make_unique<protobuf::DescriptorPool>();
+  for (const auto& file : set.file()) {
+    if ((*out_pool)->BuildFile(file) == nullptr) {
+      return errors::InvalidArgument("Failed to load FileDescriptorProto: ",
+                                     file.DebugString());
+    }
+  }
+  return Status::OK();
+}
+
 // Build a `DescriptorPool` from the named file or URI. The file or URI
 // must be available to the current TensorFlow environment.
 //

@@ -29,15 +45,14 @@ namespace {
 // `GetDescriptorPool()` for more information.
 Status GetDescriptorPoolFromFile(
     tensorflow::Env* env, const string& filename,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
   Status st = env->FileExists(filename);
   if (!st.ok()) {
     return st;
   }
 
   // Read and parse the FileDescriptorSet.
-  tensorflow::protobuf::FileDescriptorSet descs;
-  std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf;
+  protobuf::FileDescriptorSet descs;
+  std::unique_ptr<ReadOnlyMemoryRegion> buf;
   st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
   if (!st.ok()) {
     return st;

@@ -46,25 +61,31 @@ Status GetDescriptorPoolFromFile(
     return errors::InvalidArgument(
         "descriptor_source contains invalid FileDescriptorSet: ", filename);
   }
-
-  // Build a DescriptorPool from the FileDescriptorSet.
-  owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool());
-  for (const auto& filedesc : descs.file()) {
-    if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) {
-      return errors::InvalidArgument(
-          "Problem loading FileDescriptorProto (missing dependencies?): ",
-          filename);
-    }
-  }
-  return Status::OK();
+  return CreatePoolFromSet(descs, owned_desc_pool);
+}
+
+Status GetDescriptorPoolFromBinary(
+    const string& source,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
+  if (!absl::StartsWith(source, "bytes://")) {
+    return errors::InvalidArgument(
+        "Source does not represent serialized file descriptor set proto.");
+  }
+  // Parse the FileDescriptorSet.
+  protobuf::FileDescriptorSet proto;
+  if (!proto.ParseFromString(string(absl::StripPrefix(source, "bytes://")))) {
+    return errors::InvalidArgument(
+        "Source does not represent serialized file descriptor set proto.");
+  }
+  return CreatePoolFromSet(proto, owned_desc_pool);
 }
 
 } // namespace
 
 Status GetDescriptorPool(
-    tensorflow::Env* env, string const& descriptor_source,
-    tensorflow::protobuf::DescriptorPool const** desc_pool,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+    Env* env, string const& descriptor_source,
+    protobuf::DescriptorPool const** desc_pool,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
   // Attempt to lookup the pool in the registry.
   auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
   if (pool_fn != nullptr) {

@@ -77,7 +98,10 @@ Status GetDescriptorPool(
       GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
   if (status.ok()) {
     *desc_pool = owned_desc_pool->get();
+    return Status::OK();
   }
+
+  status = GetDescriptorPoolFromBinary(descriptor_source, owned_desc_pool);
   *desc_pool = owned_desc_pool->get();
   return status;
 }

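The `bytes://` handling added in `CreatePoolFromSet` and `GetDescriptorPoolFromBinary` above can be mirrored in a few lines of Python. This is an illustrative analogue using only the `protobuf` package, not TensorFlow code; `pool_from_bytes_source` is an invented name:

```python
# Python analogue of GetDescriptorPoolFromBinary + CreatePoolFromSet: strip
# the "bytes://" prefix, parse a FileDescriptorSet, and load every file
# into a fresh DescriptorPool.
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool


def pool_from_bytes_source(source):
    """Build a DescriptorPool from a b'bytes://<serialized set>' source."""
    prefix = b'bytes://'
    if not source.startswith(prefix):
        raise ValueError(
            'Source does not represent serialized file descriptor set proto.')
    file_set = descriptor_pb2.FileDescriptorSet()
    file_set.ParseFromString(source[len(prefix):])
    pool = descriptor_pool.DescriptorPool()
    for file_proto in file_set.file:  # dependencies must precede dependents
        pool.Add(file_proto)
    return pool


# Build a tiny FileDescriptorSet by hand, much like the test below does.
file_set = descriptor_pb2.FileDescriptorSet()
file_proto = file_set.file.add(
    name='types.proto', package='tensorflow', syntax='proto3')
enum_proto = file_proto.enum_type.add(name='DataType')
enum_proto.value.add(name='DT_DOUBLE', number=0)
enum_proto.value.add(name='DT_BOOL', number=1)

pool = pool_from_bytes_source(b'bytes://' + file_set.SerializeToString())
print(pool.FindEnumTypeByName('tensorflow.DataType').name)  # DataType
```

As in the C++ version, a source without the `bytes://` prefix is rejected outright rather than guessed at.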
@@ -25,17 +25,27 @@ namespace tensorflow {
 class Env;
 class Status;
 
-// Get a `DescriptorPool` object from the named `descriptor_source`.
-// `descriptor_source` may be a path to a file accessible to TensorFlow, in
-// which case it is parsed as a `FileDescriptorSet` and used to build the
-// `DescriptorPool`.
+// Gets a `DescriptorPool` object from the `descriptor_source`. This may be:
 //
-// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if
-// the caller should take ownership.
-extern tensorflow::Status GetDescriptorPool(
-    tensorflow::Env* env, string const& descriptor_source,
-    tensorflow::protobuf::DescriptorPool const** desc_pool,
-    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool);
+// 1) An empty string or "local://", in which case the local descriptor pool
+// created for proto definitions linked to the binary is returned.
+//
+// 2) A file path, in which case the descriptor pool is created from the
+// contents of the file, which is expected to contain a `FileDescriptorSet`
+// serialized as a string. The descriptor pool ownership is transferred to the
+// caller via `owned_desc_pool`.
+//
+// 3) A "bytes://<bytes>", in which case the descriptor pool is created from
+// `<bytes>`, which is expected to be a `FileDescriptorSet` serialized as a
+// string. The descriptor pool ownership is transferred to the caller via
+// `owned_desc_pool`.
+//
+// Custom schemas can be supported by registering a handler with the
+// `DescriptorPoolRegistry`.
+Status GetDescriptorPool(
+    Env* env, string const& descriptor_source,
+    protobuf::DescriptorPool const** desc_pool,
+    std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool);
 
 } // namespace tensorflow

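The `DescriptorPoolRegistry` hook mentioned in the header comment can be pictured as a prefix-keyed dispatch table consulted before the file and `bytes://` fallbacks. The sketch below is purely illustrative Python, not a TensorFlow API; every name in it is invented for the example:

```python
# Hypothetical sketch of scheme-based dispatch in the spirit of
# DescriptorPoolRegistry: handlers are registered under a scheme prefix
# and consulted first when resolving a descriptor_source.
_handlers = {}


def register_handler(scheme, factory):
    """Register a pool factory for sources that start with `scheme`."""
    _handlers[scheme] = factory


def get_descriptor_pool(descriptor_source):
    """Return the pool from the first matching registered handler."""
    for scheme, factory in _handlers.items():
        if descriptor_source.startswith(scheme):
            return factory(descriptor_source)
    raise LookupError('no handler registered for: ' + descriptor_source)


register_handler('local://', lambda src: 'linked-in descriptor pool')
print(get_descriptor_pool('local://'))  # linked-in descriptor pool
```

In the real code, unmatched sources fall through to the file and then the `bytes://` paths instead of raising.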
@@ -51,18 +51,16 @@ class DescriptorSourceTestBase(test.TestCase):
   #
   # The generated descriptor should capture the subset of `test_example.proto`
   # used in `test_base.simple_test_case()`.
-  def _createDescriptorFile(self):
-    set_proto = FileDescriptorSet()
-
-    file_proto = set_proto.file.add(
-        name='types.proto',
-        package='tensorflow',
-        syntax='proto3')
+  def _createDescriptorProto(self):
+    proto = FileDescriptorSet()
+
+    file_proto = proto.file.add(
+        name='types.proto', package='tensorflow', syntax='proto3')
     enum_proto = file_proto.enum_type.add(name='DataType')
     enum_proto.value.add(name='DT_DOUBLE', number=0)
     enum_proto.value.add(name='DT_BOOL', number=1)
 
-    file_proto = set_proto.file.add(
+    file_proto = proto.file.add(
         name='test_example.proto',
         package='tensorflow.contrib.proto',
         dependency=['types.proto'])

@@ -123,9 +121,12 @@ class DescriptorSourceTestBase(test.TestCase):
         type_name='.tensorflow.contrib.proto.TestValue',
         label=FieldDescriptorProto.LABEL_OPTIONAL)
 
+    return proto
+
+  def _writeProtoToFile(self, proto):
     fn = os.path.join(self.get_temp_dir(), 'descriptor.pb')
     with open(fn, 'wb') as f:
-      f.write(set_proto.SerializeToString())
+      f.write(proto.SerializeToString())
     return fn
 
   def _testRoundtrip(self, descriptor_source):

@@ -169,8 +170,12 @@ class DescriptorSourceTestBase(test.TestCase):
   def testWithFileDescriptorSet(self):
     # First try parsing with a local proto db, which should fail.
     with self.assertRaisesOpError('No descriptor found for message type'):
-      self._testRoundtrip('local://')
+      self._testRoundtrip(b'local://')
 
     # Now try parsing with a FileDescriptorSet which contains the test proto.
-    descriptor_file = self._createDescriptorFile()
-    self._testRoundtrip(descriptor_file)
+    proto = self._createDescriptorProto()
+    proto_file = self._writeProtoToFile(proto)
+    self._testRoundtrip(proto_file)
+
+    # Finally, try parsing the descriptor as a serialized string.
+    self._testRoundtrip(b'bytes://' + proto.SerializeToString())