Open sourcing proto/rpc ops.

PiperOrigin-RevId: 191962572
This commit is contained in:
Jiri Simsa 2018-04-06 17:17:22 -07:00 committed by TensorFlower Gardener
parent ddf54d1c24
commit 5e11bbacaf
45 changed files with 4394 additions and 0 deletions

View File

@ -354,6 +354,9 @@ tensorflow/contrib/periodic_resample
tensorflow/contrib/periodic_resample/python
tensorflow/contrib/periodic_resample/python/ops
tensorflow/contrib/predictor
tensorflow/contrib/proto
tensorflow/contrib/proto/python
tensorflow/contrib/proto/python/ops
tensorflow/contrib/quantization
tensorflow/contrib/quantization/python
tensorflow/contrib/quantize
@ -382,6 +385,9 @@ tensorflow/contrib/rnn/ops
tensorflow/contrib/rnn/python
tensorflow/contrib/rnn/python/kernel_tests
tensorflow/contrib/rnn/python/ops
tensorflow/contrib/rpc
tensorflow/contrib/rpc/python
tensorflow/contrib/rpc/python/ops
tensorflow/contrib/saved_model
tensorflow/contrib/saved_model/python
tensorflow/contrib/saved_model/python/saved_model

View File

@ -25,6 +25,8 @@ set(tf_op_lib_names
"cudnn_rnn_ops"
"data_flow_ops"
"dataset_ops"
"decode_proto_ops"
"encode_proto_ops"
"functional_ops"
"image_ops"
"io_ops"
@ -40,6 +42,7 @@ set(tf_op_lib_names
"random_ops"
"remote_fused_graph_ops"
"resource_variable_ops"
"rpc_ops"
"script_ops"
"sdca_ops"
"set_ops"

View File

@ -330,6 +330,8 @@ GENERATE_PYTHON_OP_LIB("ctc_ops")
GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops")
GENERATE_PYTHON_OP_LIB("data_flow_ops")
GENERATE_PYTHON_OP_LIB("dataset_ops")
GENERATE_PYTHON_OP_LIB("decode_proto_ops")
GENERATE_PYTHON_OP_LIB("encode_proto_ops")
GENERATE_PYTHON_OP_LIB("image_ops")
GENERATE_PYTHON_OP_LIB("io_ops")
GENERATE_PYTHON_OP_LIB("linalg_ops")
@ -343,6 +345,7 @@ GENERATE_PYTHON_OP_LIB("random_ops")
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py)
GENERATE_PYTHON_OP_LIB("resource_variable_ops")
GENERATE_PYTHON_OP_LIB("rpc_ops")
GENERATE_PYTHON_OP_LIB("script_ops")
GENERATE_PYTHON_OP_LIB("sdca_ops")
GENERATE_PYTHON_OP_LIB("set_ops")

View File

@ -301,3 +301,5 @@ tensorflow/core/kernels/warn_about_ints.cc
tensorflow/core/kernels/segment_reduction_ops.cc
tensorflow/core/kernels/batch_util.cc
tensorflow/core/ops/audio_ops.cc
tensorflow/core/kernels/decode_proto_op.cc
tensorflow/core/kernels/encode_proto_op.cc

View File

@ -0,0 +1,16 @@
# BUILD file for tensorflow/contrib/proto: the public proto-ops Python package.
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# Top-level library that re-exports decode_proto/encode_proto from the
# op-wrapper packages below.
py_library(
    name = "proto",
    srcs = [
        "__init__.py",
    ],
    deps = [
        "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
        "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
    ],
)

View File

@ -0,0 +1,28 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops and modules related to proto.

@@decode_proto
@@encode_proto
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Public API: re-export the two proto ops from their wrapper modules.
from tensorflow.contrib.proto.python.ops.decode_proto_op import decode_proto
from tensorflow.contrib.proto.python.ops.encode_proto_op import encode_proto

from tensorflow.python.util.all_util import remove_undocumented
# Prune module attributes that are not listed in the docstring above.
remove_undocumented(__name__)

View File

@ -0,0 +1,44 @@
# BUILD file for the generated + hand-written proto op wrappers.
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load(
    "//tensorflow:tensorflow.bzl",
    "tf_gen_op_wrapper_py",
)

# Hand-written wrapper around the generated decode_proto op bindings.
py_library(
    name = "decode_proto_op_py",
    srcs = ["decode_proto_op.py"],
    deps = [
        ":gen_decode_proto_op_py",
        "//tensorflow/python:framework_ops",
    ],
)

# Auto-generated Python bindings for the DecodeProtoV2 kernel.
tf_gen_op_wrapper_py(
    name = "gen_decode_proto_op_py",
    out = "gen_decode_proto_op.py",
    deps = [
        "//tensorflow/core:decode_proto_ops_op_lib",
    ],
)

# Hand-written wrapper around the generated encode_proto op bindings.
py_library(
    name = "encode_proto_op_py",
    srcs = ["encode_proto_op.py"],
    deps = [
        ":gen_encode_proto_op_py",
        "//tensorflow/python:framework_ops",
    ],
)

# Auto-generated Python bindings for the EncodeProto kernel.
tf_gen_op_wrapper_py(
    name = "gen_encode_proto_op_py",
    out = "gen_encode_proto_op.py",
    deps = [
        "//tensorflow/core:encode_proto_ops_op_lib",
    ],
)

View File

@ -0,0 +1,25 @@
# =============================================================================
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# pylint: disable=wildcard-import,unused-import
"""Protocol Buffer decoding from tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Expose the generated DecodeProtoV2 wrapper under the public name
# `decode_proto`.
from tensorflow.contrib.proto.python.ops.gen_decode_proto_op import decode_proto_v2 as decode_proto
from tensorflow.python.framework import ops

# No gradient is registered for proto decoding.
ops.NotDifferentiable("DecodeProtoV2")

View File

@ -0,0 +1,25 @@
# =============================================================================
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# pylint: disable=wildcard-import,unused-import
"""Protocol Buffer encoding from tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Expose the generated EncodeProto wrapper as the public `encode_proto`.
from tensorflow.contrib.proto.python.ops.gen_encode_proto_op import encode_proto
from tensorflow.python.framework import ops

# No gradient is registered for proto encoding.
ops.NotDifferentiable("EncodeProto")

View File

@ -0,0 +1,13 @@
# BUILD file for tensorflow/contrib/rpc: the public RPC-ops Python package.
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# Top-level library that re-exports rpc/try_rpc from the op-wrapper package.
py_library(
    name = "rpc",
    srcs = [
        "__init__.py",
    ],
    deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
)

View File

@ -0,0 +1,28 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops and modules related to RPC.

@@rpc
@@try_rpc
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Public API: re-export the two RPC ops from the wrapper module.
from tensorflow.contrib.rpc.python.ops.rpc_op import rpc
from tensorflow.contrib.rpc.python.ops.rpc_op import try_rpc

from tensorflow.python.util.all_util import remove_undocumented
# Prune module attributes that are not listed in the docstring above.
remove_undocumented(__name__)

View File

@ -0,0 +1,24 @@
# BUILD file for the generated + hand-written RPC op wrappers.
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")

# Hand-written wrapper around the generated Rpc/TryRpc op bindings.
py_library(
    name = "rpc_op_py",
    srcs = ["rpc_op.py"],
    deps = [
        ":gen_rpc_op_py",
        "//tensorflow/python:framework_ops",
    ],
)

# Auto-generated Python bindings for the Rpc and TryRpc kernels.
tf_gen_op_wrapper_py(
    name = "gen_rpc_op_py",
    out = "gen_rpc_op.py",
    deps = [
        "//tensorflow/core:rpc_ops_op_lib",
    ],
)

View File

@ -0,0 +1,26 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# pylint: disable=wildcard-import,unused-import
"""RPC communication."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Expose the generated Rpc/TryRpc wrappers under their public names.
from tensorflow.contrib.rpc.python.ops.gen_rpc_op import rpc
from tensorflow.contrib.rpc.python.ops.gen_rpc_op import try_rpc
from tensorflow.python.framework import ops

# RPC ops carry string payloads; no gradients are defined.
ops.NotDifferentiable("Rpc")
ops.NotDifferentiable("TryRpc")

View File

@ -637,6 +637,8 @@ tf_gen_op_libs(
"ctc_ops",
"data_flow_ops",
"dataset_ops",
"decode_proto_ops",
"encode_proto_ops",
"function_ops",
"functional_ops",
"image_ops",
@ -653,6 +655,7 @@ tf_gen_op_libs(
"random_ops",
"remote_fused_graph_ops",
"resource_variable_ops",
"rpc_ops",
"scoped_allocator_ops",
"sdca_ops",
"set_ops",
@ -751,6 +754,8 @@ cc_library(
":cudnn_rnn_ops_op_lib",
":data_flow_ops_op_lib",
":dataset_ops_op_lib",
":decode_proto_ops_op_lib",
":encode_proto_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@ -767,6 +772,7 @@ cc_library(
":random_ops_op_lib",
":remote_fused_graph_ops_op_lib",
":resource_variable_ops_op_lib",
":rpc_ops_op_lib",
":scoped_allocator_ops_op_lib",
":script_ops_op_lib",
":sdca_ops_op_lib",
@ -893,6 +899,8 @@ cc_library(
"//tensorflow/core/kernels:cudnn_rnn_kernels",
"//tensorflow/core/kernels:data_flow",
"//tensorflow/core/kernels:dataset_ops",
"//tensorflow/core/kernels:decode_proto_op",
"//tensorflow/core/kernels:encode_proto_op",
"//tensorflow/core/kernels:fake_quant_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:functional_ops",
@ -914,6 +922,7 @@ cc_library(
"//tensorflow/core/kernels:remote_fused_graph_ops",
"//tensorflow/core/kernels:required",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:rpc_op",
"//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:set_kernels",

View File

@ -0,0 +1,116 @@
op {
graph_op_name: "DecodeProtoV2"
in_arg {
name: "bytes"
description: <<END
Tensor of serialized protos with shape `batch_shape`.
END
}
out_arg {
name: "sizes"
description: <<END
Tensor of int32 with shape `[batch_shape, len(field_names)]`.
Each entry is the number of values found for the corresponding field.
Optional fields may have 0 or 1 values.
END
}
out_arg {
name: "values"
description: <<END
List of tensors containing values for the corresponding field.
`values[i]` has datatype `output_types[i]`
and shape `[batch_shape, max(sizes[...,i])]`.
END
}
attr {
name: "message_type"
description: <<END
Name of the proto message type to decode.
END
}
attr {
name: "field_names"
description: <<END
List of strings containing proto field names.
END
}
attr {
name: "output_types"
description: <<END
List of TF types to use for the respective field in field_names.
END
}
attr {
name: "descriptor_source"
description: <<END
Either the special value `local://` or a path to a file containing
a serialized `FileDescriptorSet`.
END
}
attr {
name: "message_format"
description: <<END
Either `binary` or `text`.
END
}
attr {
name: "sanitize"
description: <<END
Whether to sanitize the result or not.
END
}
summary: <<END
The op extracts fields from a serialized protocol buffers message into tensors.
END
description: <<END
The `decode_proto` op extracts fields from a serialized protocol buffers
message into tensors. The fields in `field_names` are decoded and converted
to the corresponding `output_types` if possible.
A `message_type` name must be provided to give context for the field
names. The actual message descriptor can be looked up either in the
linked-in descriptor pool or a filename provided by the caller using
the `descriptor_source` attribute.
Each output tensor is a dense tensor. This means that it is padded to
hold the largest number of repeated elements seen in the input
minibatch. (The shape is also padded by one to prevent zero-sized
dimensions). The actual repeat counts for each example in the
minibatch can be found in the `sizes` output. In many cases the output
of `decode_proto` is fed immediately into tf.squeeze if missing values
are not a concern. When using tf.squeeze, always pass the squeeze
dimension explicitly to avoid surprises.
For the most part, the mapping between Proto field types and
TensorFlow dtypes is straightforward. However, there are a few
special cases:
- A proto field that contains a submessage or group can only be converted
to `DT_STRING` (the serialized submessage). This is to reduce the
complexity of the API. The resulting string can be used as input
to another instance of the decode_proto op.
- TensorFlow lacks support for unsigned integers. The ops represent uint64
types as a `DT_INT64` with the same twos-complement bit pattern
(the obvious way). Unsigned int32 values can be represented exactly by
specifying type `DT_INT64`, or using twos-complement if the caller
specifies `DT_INT32` in the `output_types` attribute.
The `descriptor_source` attribute selects a source of protocol
descriptors to consult when looking up `message_type`. This may be a
filename containing a serialized `FileDescriptorSet` message,
or the special value `local://`, in which case only descriptors linked
into the code will be searched; the filename can be on any filesystem
accessible to TensorFlow.
You can build a `descriptor_source` file using the `--descriptor_set_out`
and `--include_imports` options to the protocol compiler `protoc`.
The `local://` database only covers descriptors linked into the
code via C++ libraries, not Python imports. You can link in a proto descriptor
by creating a cc_library target with alwayslink=1.
Both binary and text proto serializations are supported, and can be
chosen using the `message_format` attribute.
END
}

View File

@ -0,0 +1,81 @@
op {
graph_op_name: "EncodeProto"
in_arg {
name: "sizes"
description: <<END
Tensor of int32 with shape `[batch_shape, len(field_names)]`.
END
}
in_arg {
name: "values"
description: <<END
List of tensors containing values for the corresponding field.
END
}
out_arg {
name: "bytes"
description: <<END
Tensor of serialized protos with shape `batch_shape`.
END
}
attr {
name: "message_type"
description: <<END
Name of the proto message type to encode.
END
}
attr {
name: "field_names"
description: <<END
List of strings containing proto field names.
END
}
attr {
name: "Tinput_types"
description: <<END
The input types.
END
}
summary: <<END
The op serializes protobuf messages provided in the input tensors.
END
description: <<END
The types of the tensors in `values` must match the schema for the
fields specified in `field_names`. All the tensors in `values` must
have a common shape prefix, *batch_shape*.
The `sizes` tensor specifies repeat counts for each field. The repeat
count (last dimension) of each tensor in `values` must be greater
than or equal to corresponding repeat count in `sizes`.
A `message_type` name must be provided to give context for the field
names. The actual message descriptor can be looked up either in the
linked-in descriptor pool or a filename provided by the caller using
the `descriptor_source` attribute.
The `descriptor_source` attribute selects a source of protocol
descriptors to consult when looking up `message_type`. This may be a
filename containing a serialized `FileDescriptorSet` message,
or the special value `local://`, in which case only descriptors linked
into the code will be searched; the filename can be on any filesystem
accessible to TensorFlow.
You can build a `descriptor_source` file using the `--descriptor_set_out`
and `--include_imports` options to the protocol compiler `protoc`.
The `local://` database only covers descriptors linked into the
code via C++ libraries, not Python imports. You can link in a proto descriptor
by creating a cc_library target with alwayslink=1.
There are a few special cases in the value mapping:
Submessage and group fields must be pre-serialized as TensorFlow strings.
TensorFlow lacks support for unsigned int64s, so they must be
represented as `tf.int64` with the same twos-complement bit pattern
(the obvious way).
Unsigned int32 values can be represented exactly with `tf.int64`, or
with sign wrapping if the input is of type `tf.int32`.
END
}

View File

@ -0,0 +1,108 @@
op {
graph_op_name: "Rpc"
in_arg {
name: "address"
description: <<END
`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `method` and `request`.
END
}
in_arg {
name: "method"
description: <<END
`0-D` or `1-D`. The method address on the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `request`.
END
}
in_arg {
name: "request"
description: <<END
`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `method`.
END
}
out_arg {
name: "response"
description: <<END
Same shape as `request`. Serialized proto strings: the rpc responses.
END
}
attr {
name: "protocol"
description: <<END
RPC protocol to use. Empty string means use the default protocol.
Options include 'grpc'.
END
}
attr {
name: "fail_fast"
description: <<END
`boolean`. If `true` (default), then failures to connect
(i.e., the server does not immediately respond) cause an RPC failure.
END
}
attr {
name: "timeout_in_ms"
description: <<END
`int`. If `0` (default), then the kernel will run the RPC
request and only time out if the RPC deadline passes or the session times out.
If this value is greater than `0`, then the op will raise an exception if
the RPC takes longer than `timeout_in_ms`.
END
}
summary: <<END
Perform batches of RPC requests.
END
description: <<END
This op asynchronously performs either a single RPC request, or a batch
of requests. RPC requests are defined by three main parameters:
- `address` (the host+port or BNS address of the request)
- `method` (the RPC method name for the request)
- `request` (the serialized proto string, or vector of strings,
of the RPC request argument).
For example, if you have an RPC service running on port localhost:2345,
and its interface is configured with the following proto declaration:
```
service MyService {
rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
}
};
```
then call this op with arguments:
```
address = "localhost:2345"
method = "MyService/MyMethod"
```
The `request` tensor is a string tensor representing serialized `MyRequestProto`
strings; and the output string tensor `response` will have the same shape
and contain (upon successful completion) corresponding serialized
`MyResponseProto` strings.
For example, to send a single, empty, `MyRequestProto`, call
this op with `request = ""`. To send 5 **parallel** empty requests,
call this op with `request = ["", "", "", "", ""]`.
More generally, one can create a batch of `MyRequestProto` serialized protos
from regular batched tensors using the `encode_proto` op, and convert
the response `MyResponseProto` serialized protos to batched tensors
using the `decode_proto` op.
**NOTE** Working with serialized proto strings is faster than instantiating
actual proto objects in memory, so no performance degradation is expected
compared to writing custom kernels for this workflow.
If the connection fails or the remote worker returns an error
status, the op reraises this exception locally.
See the `TryRpc` op if you prefer to handle RPC failures manually in the graph.
END
}

View File

@ -0,0 +1,123 @@
op {
graph_op_name: "TryRpc"
in_arg {
name: "address"
description: <<END
`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `method` and `request`.
END
}
in_arg {
name: "method"
description: <<END
`0-D` or `1-D`. The method address on the RPC server.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `request`.
END
}
in_arg {
name: "request"
description: <<END
`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
If this tensor has more than 1 element, then multiple parallel rpc requests
are sent. This argument broadcasts with `address` and `method`.
END
}
out_arg {
name: "response"
description: <<END
Same shape as `request`. Serialized proto strings: the rpc responses.
END
}
out_arg {
name: "status_code"
description: <<END
Same shape as `request`. Values correspond to tensorflow Status enum codes.
END
}
out_arg {
name: "status_message"
description: <<END
Same shape as `request`. Values correspond to Status messages
returned from the RPC calls.
END
}
attr {
name: "protocol"
description: <<END
RPC protocol to use. Empty string means use the default protocol.
Options include 'grpc'.
END
}
attr {
name: "fail_fast"
description: <<END
`boolean`. If `true` (default), then failures to connect
(i.e., the server does not immediately respond) cause an RPC failure.
END
}
attr {
name: "timeout_in_ms"
description: <<END
`int`. If `0` (default), then the kernel will run the RPC
request and only time out if the RPC deadline passes or the session times out.
If this value is greater than `0`, then the op will raise an exception if
the RPC takes longer than `timeout_in_ms`.
END
}
summary: <<END
Perform batches of RPC requests.
END
description: <<END
This op asynchronously performs either a single RPC request, or a batch
of requests. RPC requests are defined by three main parameters:
- `address` (the host+port or BNS address of the request)
- `method` (the method name for the request)
- `request` (the serialized proto string, or vector of strings,
of the RPC request argument).
For example, if you have an RPC service running on port localhost:2345,
and its interface is configured with the following proto declaration:
```
service MyService {
rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
}
};
```
then call this op with arguments:
```
address = "localhost:2345"
method = "MyService/MyMethod"
```
The `request` tensor is a string tensor representing serialized `MyRequestProto`
strings; and the output string tensor `response` will have the same shape
and contain (upon successful completion) corresponding serialized
`MyResponseProto` strings.
For example, to send a single, empty, `MyRequestProto`, call
this op with `request = ""`. To send 5 **parallel** empty requests,
call this op with `request = ["", "", "", "", ""]`.
More generally, one can create a batch of `MyRequestProto` serialized protos
from regular batched tensors using the `encode_proto` op, and convert
the response `MyResponseProto` serialized protos to batched tensors
using the `decode_proto` op.
**NOTE** Working with serialized proto strings is faster than instantiating
actual proto objects in memory, so no performance degradation is expected
compared to writing custom kernels for this workflow.
Unlike the standard `Rpc` op, if the connection fails or the remote worker
returns an error status, this op does **not** reraise the exception.
Instead, the `status_code` and `status_message` entry for the corresponding RPC
call is set with the error returned from the RPC call. The `response` tensor
will contain valid response values for those minibatch entries whose RPCs did
not fail; the rest of the entries will have empty strings.
END
}

View File

@ -499,3 +499,33 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:variable_ops",
],
)
cc_library(
name = "grpc_rpc_factory",
srcs = [
"grpc_rpc_factory.cc",
],
hdrs = ["grpc_rpc_factory.h"],
deps = [
":grpc_state",
":grpc_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/util/rpc:call_container",
"//tensorflow/core/util/rpc:rpc_factory",
],
)
cc_library(
name = "grpc_rpc_factory_registration",
srcs = [
"grpc_rpc_factory_registration.cc",
],
deps = [
":grpc_rpc_factory",
"//tensorflow/core/util/rpc:rpc_factory",
"//tensorflow/core/util/rpc:rpc_factory_registry",
],
alwayslink = 1,
)

View File

@ -0,0 +1,213 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
namespace tensorflow {
namespace {
class GrpcCall {
public:
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
const string* request_msg, string* response_msg,
int32* status_code, string* status_message)
: container_(container),
index_(index),
try_rpc_(try_rpc),
request_msg_(request_msg),
response_msg_(response_msg),
status_code_(status_code),
status_message_(status_message) {}
void StartCancel() { call_opts_.StartCancel(); }
void Done(const Status& s) {
DCHECK(container_ != nullptr);
if (!s.ok() && try_rpc_) {
DCHECK(status_code_ != nullptr);
DCHECK(status_message_ != nullptr);
*status_code_ = s.code();
*status_message_ = s.error_message();
}
container_->Done(s, index_);
}
const string& request() const { return *request_msg_; }
string* response() const { return response_msg_; }
CallOptions* call_opts() { return &call_opts_; }
private:
CallContainer<GrpcCall>* const container_;
const int index_;
bool try_rpc_;
CallOptions call_opts_;
const string* request_msg_;
string* response_msg_;
int* status_code_;
string* status_message_;
};
} // namespace
// Constructs the factory and starts a dedicated thread that drains the gRPC
// completion queue, dispatching each completed event to its callback tag.
GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
                               int64 timeout_in_ms)
    : RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) {
  // TODO(ebrevdo): Investigate possible performance improvements by
  // replacing this thread with a threadpool.
  polling_thread_ =
      ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() {
        void* tag;
        bool ok;
        // Next() blocks until an event completes; it returns false only
        // after the queue is shut down (see the destructor).
        while (completion_queue_.Next(&tag, &ok)) {
          GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
          callback_tag->OnCompleted(ok);
        }
      });
}
// Shuts down the completion queue (unblocking the polling loop) and then
// reclaims the polling thread. NOTE(review): deleting the Thread object is
// assumed to join it, per the Env::StartThread contract — confirm.
GrpcRPCFactory::~GrpcRPCFactory() {
  // The amount of time we wait depends on several parameters, including:
  //   - the value of the fail_fast attribute.
  //   - the timeout option of the rpc call in the proto declaration.
  //   - the network roundtrip time and service's execution time.
  //
  // If a connection is made but the service doesn't ever respond, and
  // there is no timeout option set for this rpc call, then it is
  // possible the RPC request will wait forever.
  //
  completion_queue_.Shutdown();
  delete polling_thread_;
}
// Issues `num_elements` asynchronous RPCs described by the (broadcastable)
// address, method, and request string tensors. Responses are written into
// `response_t`; when `try_rpc` is true, per-call status goes into
// `status_code_t` / `status_message_t` instead of failing the op. `done` is
// invoked (via the container) once every call has completed.
void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
                          const Tensor& address_t, const Tensor& method_t,
                          const Tensor& request_t, const bool try_rpc,
                          Tensor* response_t, Tensor* status_code_t,
                          Tensor* status_message_t,
                          AsyncOpKernel::DoneCallback done) {
  auto address = address_t.flat<string>();
  auto method = method_t.flat<string>();
  auto request = request_t.flat<string>();

  // Stubs are maintained by the GrpcRPCFactory class and will be
  // deleted when the class is destroyed.
  ::grpc::GenericStub* singleton_stub = nullptr;
  if (address.size() == 1) {
    singleton_stub = GetOrCreateStubForAddress(address(0));
  }
  // Broadcasting accessors: a size-1 tensor supplies the same value for
  // every element of the batch.
  auto get_stub = [&address, this,
                   singleton_stub](int64 ix) -> ::grpc::GenericStub* {
    return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
                                : singleton_stub;
  };
  auto get_method_ptr = [&method](int64 ix) -> const string* {
    return (method.size() > 1) ? &(method(ix)) : &(method(0));
  };
  auto get_request_ptr = [&request](int64 ix) -> const string* {
    return (request.size() > 1) ? &(request(ix)) : &(request(0));
  };

  if (try_rpc) {
    // In this case status_code will never be set in the response,
    // so we just set it to OK.
    DCHECK(status_code_t != nullptr);
    status_code_t->flat<int32>().setConstant(
        static_cast<int>(errors::Code::OK));
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken cancellation_token = cm->get_cancellation_token();

  // This object will delete itself when done.
  auto* container =
      new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
                                  std::move(done), cancellation_token);

  auto response = response_t->flat<string>();
  int32* status_code_ptr = nullptr;
  string* status_message_ptr = nullptr;
  if (try_rpc) {
    status_code_ptr = status_code_t->flat<int32>().data();
    status_message_ptr = status_message_t->flat<string>().data();
  }
  // Build one GrpcCall per batch element; each call writes its response
  // (and, for TryRpc, its status) directly into the output tensor storage.
  for (int i = 0; i < num_elements; ++i) {
    container->calls()->emplace_back(
        container, i, try_rpc, get_request_ptr(i), &response(i),
        (try_rpc) ? &status_code_ptr[i] : nullptr,
        (try_rpc) ? &status_message_ptr[i] : nullptr);
  }

  // Launch the RPCs; completions are driven by the polling thread.
  int i = 0;
  for (GrpcCall& call : *(container->calls())) {
    // This object will delete itself when done.
    new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i),
                         call.request(), call.response(),
                         /*done=*/[&call](const Status& s) { call.Done(s); },
                         call.call_opts(), fail_fast_, timeout_in_ms_);
    ++i;
  }

  // Need to register this callback after all the RPCs are in
  // flight; otherwise we may try to cancel an RPC *before* it
  // launches, which is a no-op, and then fall into a deadlock.
  bool is_cancelled = !cm->RegisterCallback(
      cancellation_token, [container]() { container->StartCancel(); });

  if (is_cancelled) {
    ctx->SetStatus(errors::Cancelled("Operation has been cancelled."));
    // container's reference counter will take care of calling done().
    container->StartCancel();
  }
}
// Returns the stub cached for `address`, creating and caching a new one on
// first use. Thread-safe: the stub map is guarded by mu_.
::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
    const string& address) {
  mutex_lock lock(mu_);

  auto it = stubs_.find(address);
  if (it == stubs_.end()) {
    ChannelPtr fresh_channel = CreateChannelForAddress(address);
    auto* fresh_stub = new ::grpc::GenericStub(fresh_channel);
    stubs_[address].reset(fresh_stub);
    return fresh_stub;
  }
  return it->second.get();
}
// Creates a new gRPC channel to `address` with insecure credentials,
// unbounded message size, and a short (1s) reconnect backoff.
GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
    const string& address) {
  ::grpc::ChannelArguments channel_args;
  // Allow messages up to the int32 maximum rather than gRPC's default cap.
  channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
                      std::numeric_limits<int32>::max());
  // Use a fixed 1s reconnect backoff instead of the (sometimes default) 20s.
  channel_args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
  auto credentials = ::grpc::InsecureChannelCredentials();
  return ::grpc::CreateCustomChannel(/*target=*/address, credentials,
                                     channel_args);
}
} // namespace tensorflow

View File

@ -0,0 +1,59 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
// RPCFactory implementation backed by gRPC generic stubs: issues the RPCs
// requested by the Rpc/TryRpc ops over per-address cached channels.
class GrpcRPCFactory : public RPCFactory {
 public:
  explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
                          int64 timeout_in_ms);
  // Explicit destructor to control destruction order.
  ~GrpcRPCFactory() override;
  // Issues `num_elements` RPCs described by the address/method/request
  // tensors. Responses are written into *response_t; when try_rpc is true,
  // per-element status codes and messages are written into *status_code_t
  // and *status_message_t. `done` is invoked once all calls complete.
  void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t,
            const Tensor& method_t, const Tensor& request_t, const bool try_rpc,
            Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t,
            AsyncOpKernel::DoneCallback done) override;
 protected:
  typedef std::shared_ptr<::grpc::Channel> ChannelPtr;
  // Creates a channel for `address`; virtual so subclasses (e.g. tests) can
  // substitute their own channel construction.
  virtual ChannelPtr CreateChannelForAddress(const string& address);
 private:
  // Returns a cached stub for `address`, creating it on first use.
  ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
  bool fail_fast_;
  int64 timeout_in_ms_;
  ::grpc::CompletionQueue completion_queue_;
  Thread* polling_thread_;  // Owned.
  // Guards the stub cache below.
  mutex mu_;
  typedef std::unique_ptr<::grpc::GenericStub> StubPtr;
  std::unordered_map<string, StubPtr> stubs_ GUARDED_BY(mu_);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_

View File

@ -0,0 +1,34 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
namespace tensorflow {
namespace {
// Used for adding the grpc factory to the RPC factory registry.
struct Value {
  // Factory function matching RPCFactoryRegistry::RPCFactoryFn: constructs a
  // GrpcRPCFactory. The caller takes ownership of the returned pointer.
  static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
                              int64 timeout_in_ms) {
    return new GrpcRPCFactory(ctx, fail_fast, timeout_in_ms);
  }
};
// Makes protocol="grpc" resolvable by the Rpc/TryRpc kernels at load time.
REGISTER_RPC_FACTORY("grpc", Value::Function);
} // namespace
} // namespace tensorflow

View File

@ -5121,6 +5121,9 @@ filegroup(
"summary_interface.*",
"summary_kernels.*",
"spectrogram_convert_test_data.cc",
"decode_proto_op.cc",
"encode_proto_op.cc",
"rpc_op.cc",
# Excluded due to experimental status:
"debug_ops.*",
"scatter_nd_op*",
@ -6153,6 +6156,50 @@ tf_kernel_library(
],
)
# Kernel for the DecodeProtoV2 op: parses serialized protos into tensors.
tf_kernel_library(
    name = "decode_proto_op",
    srcs = [
        "decode_proto_op.cc",
    ],
    deps = [
        "//tensorflow/core:decode_proto_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/util/proto:decode",
        "//tensorflow/core/util/proto:descriptors",
        "//third_party/eigen3",
    ],
)

# Kernel for the EncodeProto op: serializes tensors into proto wire format.
tf_kernel_library(
    name = "encode_proto_op",
    srcs = ["encode_proto_op.cc"],
    deps = [
        "//tensorflow/core:encode_proto_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/util/proto:descriptors",
        "//third_party/eigen3",
    ],
)

# Kernel for the Rpc/TryRpc ops: dispatches calls through the RPC factory
# registry (e.g. the grpc factory registered elsewhere).
tf_kernel_library(
    name = "rpc_op",
    srcs = [
        "rpc_op.cc",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:rpc_ops_op_lib",
        "//tensorflow/core/util/rpc:call_container",
        "//tensorflow/core/util/rpc:rpc_factory",
        "//tensorflow/core/util/rpc:rpc_factory_registry",
        "//third_party/eigen3",
    ],
)
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,591 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// EncodeProto is a TensorFlow Op which serializes tensors into
// arbitrary protobufs.
//
// See the docstring in ../ops/encode_proto_op.cc for usage of the op.
//
// This implementation writes the serialized format using a handful of
// calls from the WireFormatLite API.
#include <memory>
#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/descriptors.h"
namespace tensorflow {
namespace {
using ::tensorflow::protobuf::Descriptor;
using ::tensorflow::protobuf::DescriptorPool;
using ::tensorflow::protobuf::FieldDescriptor;
using ::tensorflow::protobuf::internal::WireFormatLite;
using ::tensorflow::protobuf::io::CodedOutputStream;
using ::tensorflow::protobuf::io::StringOutputStream;
// Computes the total serialized size for a packed repeated field.
// For fixed-size types this can just multiply, but for variable-sized
// types it has to iterate through the values in the tensor.
//
// The primary template is declaration-only: every (FieldType, TensorT)
// combination instantiated by the WriteField dispatcher below must have an
// explicit specialization here, so a missing pairing fails at link time.
template <WireFormatLite::FieldType FieldType, typename TensorT>
size_t TotalPackedSize(const Tensor& input, int message_index, int size);

// Fixed-width double: 8 bytes per element.
template <>
size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input,
                                                            int message_index,
                                                            int size) {
  return size * WireFormatLite::kDoubleSize;
}

// DT_DOUBLE tensors may feed float fields; each element is still a fixed
// 4 bytes on the wire after narrowing.
template <>
size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input,
                                                           int message_index,
                                                           int size) {
  return size * WireFormatLite::kFloatSize;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input,
                                                          int message_index,
                                                          int size) {
  return size * WireFormatLite::kFloatSize;
}

// Varint-encoded: wire size depends on each value, so walk row
// `message_index` of the tensor and sum per-value sizes.
template <>
size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input,
                                                          int message_index,
                                                          int size) {
  size_t data_size = 0;
  auto input_t = input.flat_inner_dims<int64>();
  for (int64 i = 0; i < size; i++) {
    data_size += WireFormatLite::Int64Size(
        input_t(static_cast<int64>(message_index), i));
  }
  return data_size;
}

// TensorFlow has no unsigned dtypes, so unsigned proto values travel in
// signed tensors; the implicit signed->unsigned conversion into UInt64Size
// mirrors the conversion the writer performs.
template <>
size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input,
                                                           int message_index,
                                                           int size) {
  size_t data_size = 0;
  auto input_t = input.flat_inner_dims<int64>();
  for (int64 i = 0; i < size; i++) {
    data_size += WireFormatLite::UInt64Size(
        input_t(static_cast<int64>(message_index), i));
  }
  return data_size;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
                                                          int message_index,
                                                          int size) {
  size_t data_size = 0;
  auto input_t = input.flat_inner_dims<int32>();
  for (int64 i = 0; i < size; i++) {
    data_size += WireFormatLite::Int32Size(
        input_t(static_cast<int64>(message_index), i));
  }
  return data_size;
}

// Fixed-width 64-bit.
template <>
size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, int64>(const Tensor& input,
                                                            int message_index,
                                                            int size) {
  return size * WireFormatLite::kFixed64Size;
}

// fixed32 values may arrive in either DT_INT64 or DT_INT32 tensors.
template <>
size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int64>(const Tensor& input,
                                                            int message_index,
                                                            int size) {
  return size * WireFormatLite::kFixed32Size;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int32>(const Tensor& input,
                                                            int message_index,
                                                            int size) {
  return size * WireFormatLite::kFixed32Size;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input,
                                                        int message_index,
                                                        int size) {
  return size * WireFormatLite::kBoolSize;
}

// uint32 values may arrive in either DT_INT64 or DT_INT32 tensors.
template <>
size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input,
                                                           int message_index,
                                                           int size) {
  size_t data_size = 0;
  auto input_t = input.flat_inner_dims<int64>();
  for (int64 i = 0; i < size; i++) {
    data_size += WireFormatLite::UInt32Size(
        input_t(static_cast<int64>(message_index), i));
  }
  return data_size;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int32>(const Tensor& input,
                                                           int message_index,
                                                           int size) {
  size_t data_size = 0;
  auto input_t = input.flat_inner_dims<int32>();
  for (int64 i = 0; i < size; i++) {
    data_size += WireFormatLite::UInt32Size(
        input_t(static_cast<int64>(message_index), i));
  }
  return data_size;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input,
                                                         int message_index,
                                                         int size) {
  size_t data_size = 0;
  auto input_t = input.flat_inner_dims<int32>();
  for (int64 i = 0; i < size; i++) {
    data_size +=
        WireFormatLite::EnumSize(input_t(static_cast<int64>(message_index), i));
  }
  return data_size;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>(
    const Tensor& input, int message_index, int size) {
  return size * WireFormatLite::kSFixed32Size;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>(
    const Tensor& input, int message_index, int size) {
  return size * WireFormatLite::kSFixed64Size;
}

// Zigzag varints: size depends on each value.
template <>
size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input,
                                                           int message_index,
                                                           int size) {
  size_t data_size = 0;
  auto input_t = input.flat_inner_dims<int32>();
  for (int64 i = 0; i < size; i++) {
    data_size += WireFormatLite::SInt32Size(
        input_t(static_cast<int64>(message_index), i));
  }
  return data_size;
}

template <>
size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input,
                                                           int message_index,
                                                           int size) {
  size_t data_size = 0;
  auto input_t = input.flat_inner_dims<int64>();
  for (int64 i = 0; i < size; i++) {
    data_size += WireFormatLite::SInt64Size(
        input_t(static_cast<int64>(message_index), i));
  }
  return data_size;
}
// Writes a possibly repeated primitive field.
// TensorFlow does not have unsigned types, so we decode them to signed and
// encode them back to unsigned.
template <typename TensorT, typename ProtoT,
          WireFormatLite::FieldType FieldType,
          void Writer(ProtoT, CodedOutputStream*)>
void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
                int message_index, int size, CodedOutputStream* output) {
  // Wire type used by the unpacked (tag-per-value) branch below; derived
  // from the descriptor's declared field type.
  auto wire_type = WireFormatLite::WireTypeForFieldType(
      WireFormatLite::FieldType(field_desc.type()));
  auto input_t = input.flat_inner_dims<TensorT>();
  if (field_desc.options().packed()) {
    // Write the tag for the packed field.
    WireFormatLite::WriteTag(field_desc.number(),
                             WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
    // Write the total packed length.
    // NOTE(review): data_size (size_t) narrows implicitly to the uint32
    // taken by WriteVarint32; overflow would require a >4GB packed field.
    size_t data_size =
        TotalPackedSize<FieldType, TensorT>(input, message_index, size);
    output->WriteVarint32(data_size);
    // Write individual values.
    for (int64 i = 0; i < size; i++) {
      // Note implicit cast from signed to unsigned.
      const ProtoT& value = input_t(static_cast<int64>(message_index), i);
      Writer(value, output);
    }
  } else {
    // Unpacked encoding: each value carries its own tag.
    for (int64 i = 0; i < size; i++) {
      WireFormatLite::WriteTag(field_desc.number(), wire_type, output);
      // Note implicit cast from signed to unsigned.
      const ProtoT& value = input_t(static_cast<int64>(message_index), i);
      Writer(value, output);
    }
  }
}
// Writes a possibly repeated string, bytes, or message field.
// `Writer` is one of the WireFormatLite Write* helpers taking
// (field_number, value, output); it emits the tag and length itself.
template <typename T, void Writer(int, const T&, CodedOutputStream*)>
void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
                      int message_index, int size, CodedOutputStream* output) {
  auto input_t = input.flat_inner_dims<T>();
  for (int64 i = 0; i < size; i++) {
    const T& value = input_t(static_cast<int64>(message_index), i);
    // TODO(nix): there doesn't seem to be an inlined version of
    // WireFormatLite::WriteString or its relatives, which might allow a
    // small speedup.
    Writer(field_desc.number(), value, output);
  }
}
// Writes a group field.
// Groups are treated like submessages, but tag-delimited
// instead of length-delimited. WireFormatLite handles this
// differently so we code it ourselves.
void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
int message_index, int size, CodedOutputStream* output) {
auto input_t = input.flat_inner_dims<string>();
for (int64 i = 0; i < size; i++) {
const string& value = input_t(static_cast<int64>(message_index), i);
WireFormatLite::WriteTag(field_desc.number(),
WireFormatLite::WIRETYPE_START_GROUP, output);
// Note the use of WriteRaw instead of WriteString to skip the length.
output->WriteRaw(value.data(), value.size());
WireFormatLite::WriteTag(field_desc.number(),
WireFormatLite::WIRETYPE_END_GROUP, output);
}
}
// Writes a (possibly repeated) field into an output stream.
// It is the caller's responsibility to ensure that the type of
// the input tensor is compatible with the type of the proto
// field descriptor, and that (message_index, size-1) is within
// bounds.
//
// Compatibility is checked by IsCompatibleType() (called by the op's
// Compute); the inner `default: return;` branches therefore never fire on
// validated inputs and deliberately write nothing otherwise.
void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
                int message_index, int size, CodedOutputStream* output) {
  DataType tf_type = input.dtype();
  switch (field_desc.type()) {
    case WireFormatLite::TYPE_DOUBLE:
      return WriteField<double, double, WireFormatLite::TYPE_DOUBLE,
                        WireFormatLite::WriteDoubleNoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_FLOAT:
      // Both DT_FLOAT and DT_DOUBLE tensors may feed a float field; doubles
      // are narrowed to float on the wire.
      switch (tf_type) {
        case DataType::DT_FLOAT:
          return WriteField<float, float, WireFormatLite::TYPE_FLOAT,
                            WireFormatLite::WriteFloatNoTag>(
              field_desc, input, message_index, size, output);
        case DataType::DT_DOUBLE:
          return WriteField<double, float, WireFormatLite::TYPE_FLOAT,
                            WireFormatLite::WriteFloatNoTag>(
              field_desc, input, message_index, size, output);
        default:
          return;
      }
    case WireFormatLite::TYPE_INT64:
      return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64,
                        WireFormatLite::WriteInt64NoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_UINT64:
      // Unsigned values travel in signed tensors; the ProtoT template
      // argument performs the conversion back to unsigned.
      return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
                        WireFormatLite::WriteUInt64NoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_INT32:
      return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
                        WireFormatLite::WriteInt32NoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_FIXED64:
      return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
                        WireFormatLite::WriteFixed64NoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_FIXED32:
      // fixed32 accepts DT_INT64 (lossless) or DT_INT32 (bit-cast).
      switch (tf_type) {
        case DataType::DT_INT64:
          return WriteField<int64, uint32, WireFormatLite::TYPE_FIXED32,
                            WireFormatLite::WriteFixed32NoTag>(
              field_desc, input, message_index, size, output);
        case DataType::DT_INT32:
          return WriteField<int32, uint32, WireFormatLite::TYPE_FIXED32,
                            WireFormatLite::WriteFixed32NoTag>(
              field_desc, input, message_index, size, output);
        default:
          return;
      }
    case WireFormatLite::TYPE_BOOL:
      return WriteField<bool, bool, WireFormatLite::TYPE_BOOL,
                        WireFormatLite::WriteBoolNoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_STRING:
      return WriteVarLenField<string, WireFormatLite::WriteString>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_GROUP:
      return WriteGroup(field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_MESSAGE:
      // Submessages arrive pre-serialized in string tensors, so they are
      // written as bytes.
      return WriteVarLenField<string, WireFormatLite::WriteBytes>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_BYTES:
      return WriteVarLenField<string, WireFormatLite::WriteBytes>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_UINT32:
      switch (tf_type) {
        case DataType::DT_INT64:
          return WriteField<int64, uint32, WireFormatLite::TYPE_UINT32,
                            WireFormatLite::WriteUInt32NoTag>(
              field_desc, input, message_index, size, output);
        case DataType::DT_INT32:
          return WriteField<int32, uint32, WireFormatLite::TYPE_UINT32,
                            WireFormatLite::WriteUInt32NoTag>(
              field_desc, input, message_index, size, output);
        default:
          return;
      }
    case WireFormatLite::TYPE_ENUM:
      return WriteField<int32, int32, WireFormatLite::TYPE_ENUM,
                        WireFormatLite::WriteEnumNoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_SFIXED32:
      return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
                        WireFormatLite::WriteSFixed32NoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_SFIXED64:
      return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64,
                        WireFormatLite::WriteSFixed64NoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_SINT32:
      return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
                        WireFormatLite::WriteSInt32NoTag>(
          field_desc, input, message_index, size, output);
    case WireFormatLite::TYPE_SINT64:
      return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64,
                        WireFormatLite::WriteSInt64NoTag>(
          field_desc, input, message_index, size, output);
      // default: intentionally omitted in order to enable static checking.
  }
}
// Checks that a Protobuf field is compatible with a TensorFlow datatype.
// This is separated from WriteField to lift it out of the inner loop.
bool IsCompatibleType(const FieldDescriptor& field_desc, DataType tf_type) {
  switch (field_desc.type()) {
    case WireFormatLite::TYPE_DOUBLE:
      return tf_type == DataType::DT_DOUBLE;
    case WireFormatLite::TYPE_FLOAT:
      // DT_DOUBLE is accepted and narrowed to float on encode.
      return tf_type == DataType::DT_FLOAT || tf_type == DataType::DT_DOUBLE;
    case WireFormatLite::TYPE_INT64:
    case WireFormatLite::TYPE_SFIXED64:
    case WireFormatLite::TYPE_SINT64:
    case WireFormatLite::TYPE_UINT64:
    case WireFormatLite::TYPE_FIXED64:
      // All 64-bit integral fields (TF has no unsigned dtypes).
      return tf_type == DataType::DT_INT64;
    case WireFormatLite::TYPE_INT32:
    case WireFormatLite::TYPE_ENUM:
    case WireFormatLite::TYPE_SFIXED32:
    case WireFormatLite::TYPE_SINT32:
      return tf_type == DataType::DT_INT32;
    case WireFormatLite::TYPE_FIXED32:
    case WireFormatLite::TYPE_UINT32:
      // Unsigned 32-bit values fit losslessly in DT_INT64 or bit-cast into
      // DT_INT32.
      return tf_type == DataType::DT_INT64 || tf_type == DataType::DT_INT32;
    case WireFormatLite::TYPE_BOOL:
      return tf_type == DataType::DT_BOOL;
    case WireFormatLite::TYPE_STRING:
    case WireFormatLite::TYPE_GROUP:
    case WireFormatLite::TYPE_MESSAGE:
    case WireFormatLite::TYPE_BYTES:
      return tf_type == DataType::DT_STRING;
      // default: intentionally omitted in order to enable static checking.
  }
  return false;
}
class EncodeProtoOp : public OpKernel {
public:
explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
string descriptor_source;
OP_REQUIRES_OK(context,
context->GetAttr("descriptor_source", &descriptor_source));
// We always get back a desc_pool, but we may not own it. If we own it,
// owned_desc_pool_ will be filled in.
DescriptorPool const* desc_pool;
OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
&desc_pool, &owned_desc_pool_));
string message_type;
OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
const Descriptor* message_desc =
desc_pool->FindMessageTypeByName(message_type);
OP_REQUIRES(context, message_desc != nullptr,
errors::InvalidArgument("No descriptor found for message type ",
message_type));
OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names_));
// Gather the field descriptors for the given field_names.
field_descs_.resize(field_names_.size());
for (int i = 0; i < field_names_.size(); i++) {
const string& name = field_names_[i];
auto field_desc = message_desc->FindFieldByName(name);
OP_REQUIRES(context, field_desc != nullptr,
errors::InvalidArgument("Unknown field: ", name,
" in message type ", message_type));
field_descs_[i] = field_desc;
}
// Build a list of indices into field_descs sorted by increasing
// field_number. This will be used to output fields in sorted order,
// which is strongly encouraged when serializing protobufs.
sorted_field_index_.resize(field_names_.size());
// Start with the fields sorted by current index.
for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i;
// Then sort the field indices by their proto field number.
std::sort(sorted_field_index_.begin(), sorted_field_index_.end(),
[this](int a, int b) -> bool {
return field_descs_[a]->number() < field_descs_[b]->number();
});
}
void Compute(OpKernelContext* cx) override {
const Tensor* sizes_tensor;
OP_REQUIRES_OK(cx, cx->input("sizes", &sizes_tensor));
OpInputList values;
OP_REQUIRES_OK(cx, cx->input_list("values", &values));
OP_REQUIRES(cx, field_descs_.size() == values.size(),
errors::InvalidArgument(
"Length of inputs list must match field_names"));
// Check the arguments for consistency.
TensorShape common_prefix;
int message_count;
for (int i = 0; i < field_descs_.size(); i++) {
const Tensor& v = values[i];
// The type of each value tensor must match the corresponding field.
OP_REQUIRES(cx, IsCompatibleType(*field_descs_[i], v.dtype()),
errors::InvalidArgument(
"Incompatible type for field " + field_names_[i] +
". Saw dtype: ",
DataTypeString(v.dtype()),
" but field type is: ", field_descs_[i]->type_name()));
// All value tensors must have the same shape prefix (i.e. batch size).
TensorShape shape_prefix = v.shape();
shape_prefix.RemoveDim(shape_prefix.dims() - 1);
// Do some initialization on the first input value. The rest will
// have to match this one.
if (i == 0) {
OP_REQUIRES(cx, v.dims() >= 1,
errors::InvalidArgument(
"Expected value to be at least a vector, saw shape: ",
v.shape().DebugString()));
common_prefix = shape_prefix;
message_count = common_prefix.num_elements();
} else {
OP_REQUIRES(cx, shape_prefix == common_prefix,
errors::InvalidArgument(
"Values must match up to the last dimension"));
}
}
TensorShape expected_sizes_shape = common_prefix;
expected_sizes_shape.AddDim(field_descs_.size());
OP_REQUIRES(cx, sizes_tensor->shape() == expected_sizes_shape,
errors::InvalidArgument(
"sizes should be batch_size + [len(field_names)]. Saw: ",
sizes_tensor->shape().DebugString(),
" but expected: ", expected_sizes_shape.DebugString()));
auto sizes = sizes_tensor->flat_inner_dims<int32>();
for (int i = 0; i < field_descs_.size(); ++i) {
const Tensor& v = values[i];
int max_size = v.dim_size(v.dims() - 1);
// The last dimension of a value tensor must be greater than the
// corresponding
// size in the sizes tensor.
for (int message_index = 0; message_index < message_count;
message_index++) {
OP_REQUIRES(
cx, sizes(message_index, i) <= max_size,
errors::InvalidArgument(
"Size to write must not be larger than value tensor; but saw: ",
sizes(message_index, i), " > ", max_size, " at message ",
message_index, " field ", i));
}
}
// This pointer is owned by the context.
Tensor* output_tensor;
OP_REQUIRES_OK(cx, cx->allocate_output(0, common_prefix, &output_tensor));
auto bufs = output_tensor->flat<string>();
for (int message_index = 0; message_index < message_count;
message_index++) {
// TODO(nix): possibly optimize allocation here by calling
// bufs(message_index).reserve(DEFAULT_BUF_SIZE);
StringOutputStream output_string(&bufs(message_index));
CodedOutputStream out(&output_string);
// Write fields in ascending field_number order.
for (int i : sorted_field_index_) {
auto& field_desc = *field_descs_[i];
const Tensor& v = values[i];
int size = sizes(message_index, i);
if (!size) continue;
WriteField(field_desc, v, message_index, size, &out);
}
}
}
private:
std::vector<string> field_names_;
std::vector<const FieldDescriptor*> field_descs_;
// Owned_desc_pool_ is null when using descriptor_source=local.
std::unique_ptr<DescriptorPool> owned_desc_pool_;
// Contains indices into field_names_, sorted by field number since
// that's the order of writing.
std::vector<int> sorted_field_index_;
TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp);
};
REGISTER_KERNEL_BUILDER(Name("EncodeProto").Device(DEVICE_CPU), EncodeProtoOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,129 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// RpcOp is a TensorFlow op that sends and receives arbitrary messages.
//
// See docs in ../ops/rpc_op.cc.
#include <memory>
#include <string>
#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
namespace tensorflow {
// Async kernel backing the Rpc and TryRpc ops: issues one RPC per element of
// the broadcast-compatible address/method/request inputs via the RPCFactory
// registered for the `protocol` attr.
class RpcOp : public AsyncOpKernel {
 public:
  explicit RpcOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("protocol", &protocol_));
    OP_REQUIRES(context, !protocol_.empty(),
                errors::InvalidArgument("protocol must be non-empty."));
    bool fail_fast;
    OP_REQUIRES_OK(context, context->GetAttr("fail_fast", &fail_fast));
    int64 timeout_in_ms;
    OP_REQUIRES_OK(context, context->GetAttr("timeout_in_ms", &timeout_in_ms));
    // Look up the factory function registered for this protocol (e.g.
    // "grpc") and construct the factory once, at kernel-construction time.
    RPCFactoryRegistry::RPCFactoryFn* rpc_factory_fn =
        RPCFactoryRegistry::Global()->Get(protocol_);
    OP_REQUIRES(context, rpc_factory_fn != nullptr,
                errors::InvalidArgument("The protocol ", protocol_,
                                        " was not recognized."));
    rpc_factory_.reset((*rpc_factory_fn)(context, fail_fast, timeout_in_ms));
  }

  ~RpcOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    const Tensor& address_t = ctx->input(0);
    const Tensor& method_t = ctx->input(1);
    const Tensor& request_t = ctx->input(2);
    // Each input may independently be a scalar (broadcast) or a vector.
    OP_REQUIRES_ASYNC(
        ctx, address_t.dims() == 0 || address_t.dims() == 1,
        errors::InvalidArgument("address must be a scalar or vector."), done);
    OP_REQUIRES_ASYNC(
        ctx, method_t.dims() == 0 || method_t.dims() == 1,
        errors::InvalidArgument("method must be a scalar or vector."), done);
    OP_REQUIRES_ASYNC(
        ctx, request_t.dims() == 0 || request_t.dims() == 1,
        errors::InvalidArgument("request must be a scalar or vector."), done);
    // The output shape is the common vector shape of the inputs (or scalar
    // if all three are scalars); any vector inputs must agree on length.
    TensorShape output_shape({});
    for (const Tensor& t : {address_t, method_t, request_t}) {
      if (t.dims() == 1) {
        OP_REQUIRES_ASYNC(
            ctx,
            output_shape.dims() == 0 ||
                output_shape.dim_size(0) == t.dim_size(0),
            errors::InvalidArgument(
                "Input vector shapes don't match: ", output_shape.DebugString(),
                " vs. ", t.shape().DebugString()),
            done);
        output_shape = t.shape();
      }
    }
    Tensor* response_t;
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->allocate_output(0, output_shape, &response_t), done);
    // TryRpc has three outputs (response, status_code, status_message);
    // Rpc has only the response.
    const bool try_rpc = (ctx->num_outputs() > 1);
    Tensor* status_code_t = nullptr;
    Tensor* status_message_t = nullptr;
    if (try_rpc) {
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->allocate_output(1, output_shape, &status_code_t), done);
      OP_REQUIRES_OK_ASYNC(
          ctx, ctx->allocate_output(2, output_shape, &status_message_t), done);
    }
    if (request_t.NumElements() == 0) {
      // Special case, we finished early!
      done();
      return;
    }
    int64 num_elements = output_shape.num_elements();
    // The factory invokes `done` when all calls have completed.
    rpc_factory_->Call(ctx, num_elements, address_t, method_t, request_t,
                       try_rpc, response_t, status_code_t, status_message_t,
                       std::move(done));
  }

 private:
  string protocol_;
  std::unique_ptr<RPCFactory> rpc_factory_;
  TF_DISALLOW_COPY_AND_ASSIGN(RpcOp);
};

// Rpc and TryRpc share the kernel; they differ only in output arity, which
// ComputeAsync detects via ctx->num_outputs().
REGISTER_KERNEL_BUILDER(Name("Rpc").Device(DEVICE_CPU), RpcOp);
REGISTER_KERNEL_BUILDER(Name("TryRpc").Device(DEVICE_CPU), RpcOp);
} // namespace tensorflow

View File

@ -0,0 +1,67 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;
REGISTER_OP("DecodeProtoV2")
    .Input("bytes: string")
    .Attr("message_type: string")
    .Attr("field_names: list(string)")
    .Attr("output_types: list(type) >= 0")
    .Attr("descriptor_source: string = 'local://'")
    .Attr("message_format: string = 'binary'")
    .Attr("sanitize: bool = false")
    .Output("sizes: int32")
    .Output("values: output_types")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input = c->input(0);
      std::vector<tensorflow::DataType> output_types;
      TF_RETURN_IF_ERROR(c->GetAttr("output_types", &output_types));
      // sizes: input shape extended with one dimension per requested field.
      ShapeHandle sizes;
      TF_RETURN_IF_ERROR(
          c->Concatenate(input, c->Vector(output_types.size()), &sizes));
      c->set_output(0, sizes);
      // TODO(nix): to do the best possible job of shape inference, we
      // should examine the proto descriptors here in order to set shape
      // indices to 1 instead of unknown for optional or required fields.
      // Any general-purpose code will have to handle the unknown case,
      // but there might be XLA code that could be sped up with the additional
      // knowledge.
      // Each value output is the input shape plus an unknown inner dimension
      // (the per-message repeat count is data-dependent).
      for (int i = 0; i < output_types.size(); ++i) {
        ShapeHandle values;
        TF_RETURN_IF_ERROR(
            c->Concatenate(input, c->Vector(c->UnknownDim()), &values));
        c->set_output(i + 1, values);
      }
      return Status::OK();
    });
// TODO(nix): Consider adding an additional input argument that truncates
// repeated fields to a maximum count. For now this could be done by passing
// the output through tf.slice.
// TODO(nix): define missing value behavior.
} // namespace tensorflow

View File

@ -0,0 +1,49 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;
// EncodeProto serializes the given field value tensors into proto
// strings. Output shape is the common batch shape of the inputs.
REGISTER_OP("EncodeProto")
    .Input("sizes: int32")
    .Input("values: Tinput_types")
    .Attr("field_names: list(string)")
    .Attr("message_type: string")
    .Attr("descriptor_source: string = 'local://'")
    .Attr("Tinput_types: list(type)")
    .Output("bytes: string")
    .SetShapeFn([](InferenceContext* c) {
      // Input 0 is `sizes`; the per-field value tensors start at input 1.
      int first_field_index = 1;
      int num_fields = c->num_inputs() - 1;

      // Each value tensor must be batch_shape + [repeat_count] (rank >= 2).
      // `bytes` has shape batch_shape, so merge the all-but-last
      // dimensions of every value input into a single output shape.
      ShapeHandle output;
      for (int i = num_fields - 1; i >= 0; --i) {
        ShapeHandle input = c->input(first_field_index + i);
        TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input));
        ShapeHandle inner;
        TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &inner));
        TF_RETURN_IF_ERROR(c->Merge(inner, output, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });
} // namespace tensorflow

View File

@ -0,0 +1,81 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;
// Shared shape function for the Rpc and TryRpc ops.
// Each of the three string inputs (address, method, request) must be a
// scalar or a vector; every vector present is merged into a common
// output shape. If none is a vector, the shape of `request` is used.
Status RpcShapeOp(InferenceContext* c, bool try_rpc) {
  ShapeHandle merged;   // Stays unknown until a rank-1 input is merged in.
  ShapeHandle request;  // Shape of input 2, kept as the fallback.
  for (int i = 0; i < 3; ++i) {
    ShapeHandle arg;
    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &arg));
    if (c->Rank(arg) == 1) {
      TF_RETURN_IF_ERROR(c->Merge(merged, arg, &merged));
    }
    if (i == 2) {
      request = arg;
    }
  }
  if (!c->RankKnown(merged)) {
    merged = request;
  }
  c->set_output(0, merged);  // response
  if (try_rpc) {
    c->set_output(1, merged);  // status_code
    c->set_output(2, merged);  // status_message
  }
  return Status::OK();
}
// Rpc: performs RPCs given (broadcastable) address/method/request string
// tensors and returns the response strings. Marked stateful so the
// runtime never constant-folds or caches it — presumably because the
// calls have external side effects (kernel not visible here).
REGISTER_OP("Rpc")
    .Input("address: string")
    .Input("method: string")
    .Input("request: string")
    .Attr("protocol: string = ''")
    .Attr("fail_fast: bool = true")
    .Attr("timeout_in_ms: int = 0")
    .Output("response: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      return RpcShapeOp(c, /*try_rpc=*/false);
    });
// TryRpc: like Rpc, but additionally emits per-element status_code and
// status_message outputs (all three outputs share the same shape, see
// RpcShapeOp). Stateful for the same reason as Rpc.
REGISTER_OP("TryRpc")
    .Input("address: string")
    .Input("method: string")
    .Input("request: string")
    .Attr("protocol: string = ''")
    .Attr("fail_fast: bool = true")
    .Attr("timeout_in_ms: int = 0")
    .Output("response: string")
    .Output("status_code: int32")
    .Output("status_message: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      return RpcShapeOp(c, /*try_rpc=*/true);
    });
} // namespace tensorflow

View File

@ -0,0 +1,62 @@
package(
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_cc_test")

# Header-only wire-format parsing helpers used by the decode_proto op.
cc_library(
    name = "decode",
    hdrs = ["decode.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)

# Resolves a descriptor_source string to a protobuf DescriptorPool,
# consulting the registry below.
cc_library(
    name = "descriptors",
    srcs = ["descriptors.cc"],
    hdrs = ["descriptors.h"],
    deps = [
        ":descriptor_pool_registry",
        ":local_descriptor_pool_registration",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)

# Global registry mapping descriptor-source names to pool factories.
cc_library(
    name = "descriptor_pool_registry",
    srcs = ["descriptor_pool_registry.cc"],
    hdrs = ["descriptor_pool_registry.h"],
    deps = [
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "descriptor_pool_registry_test",
    srcs = ["descriptor_pool_registry_test.cc"],
    deps = [
        ":descriptor_pool_registry",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

# Depending on this target adds support for using the special
# value "local://" (or "") for descriptor source, in which case
# descriptors linked into the code will be searched.
# alwayslink keeps the static registration from being dead-stripped.
cc_library(
    name = "local_descriptor_pool_registration",
    srcs = ["local_descriptor_pool_registration.cc"],
    deps = [
        ":descriptor_pool_registry",
        "//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

View File

@ -0,0 +1,592 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Inline functions for parsing the protocol buffers wire format.
//
// These functions have been optimized at the expense of safety.
// They are broken out into a separate file for readability but are
// not intended for use by clients other than the decode_proto op.
//
// The calling code in the decode_proto op does some fairly
// complicated things to ensure that this code is called
// safely. Changes to this code should be thoroughly fuzz tested.
#ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
#define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
using tensorflow::protobuf::internal::WireFormatLite;
using tensorflow::protobuf::io::CodedInputStream;
using tensorflow::protobuf::io::CodedOutputStream;
using tensorflow::protobuf::io::StringOutputStream;
// Converts an uint64 to an int64 without loss of information.
// Unsigned values greater than INT64_MAX are represented as
// negative numbers by wrapping (same as twos-complement bit equivalence).
inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) {
  // For a detailed explanation of why this works to wrap unsigned ints, see
  // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
  // Both if tests should be optimized out.
  if (unsigned_value <= INT64_MAX) {
    return static_cast<int64>(unsigned_value);
  }
  // The C++ spec allows an architecture where this test is required.
  // Note: INT64_MIN converts to uint64 as 2^63, so this tests
  // unsigned_value >= 2^63 — always true on twos-complement hardware
  // once the branch above has failed.
  if (unsigned_value >= INT64_MIN) {
    return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN;
  }
  return 0;  // This should never occur.
}
// Converts an uint32 to an int32 without loss of information.
// Unsigned values greater than INT_MAX are represented as
// negative numbers by wrapping (same as twos-complement bit equivalence).
inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
  // For a detailed explanation of why this works to wrap unsigned ints, see
  // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
  // Both if tests should be optimized out.
  if (unsigned_value <= INT_MAX) {
    return static_cast<int32>(unsigned_value);
  }
  // The C++ spec allows an architecture where this test is required.
  // Note: INT_MIN converts to uint32 as 2^31, so this tests
  // unsigned_value >= 2^31 — always true on twos-complement hardware
  // once the branch above has failed.
  if (unsigned_value >= INT_MIN) {
    return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
  }
  return 0;  // This should never occur.
}
// Reads a single varint32 from a byte array.
// It is the caller's responsibility to ensure that there is enough
// space in the buffer.
// The ok value will be set to false if the buffer does not contain
// a valid varint.
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
uint64* value);
// Reads a single varint32 from a byte array by delegating to the
// 64-bit reader and keeping the low 32 bits of the result.
// Slightly less efficient than a dedicated 32-bit parse (the private
// one in coded_stream.cc), but avoids duplicating that code.
// It is the caller's responsibility to ensure that there is enough
// space in the buffer; *ok is set to false if the bytes do not form
// a valid varint.
inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
                                          uint32* value) {
  uint64 wide;
  const uint8* next = ReadVarint64FromArray(buffer, ok, &wide);
  *value = wide & 0xffffffff;
  return next;
}
// Reads a single proto field value from a byte array into an array.
// The array is part of a Tensor that was allocated by the caller
// with type TensorType, while DeclaredType is the proto field type.
//
// Conventions across the specializations below:
//  - unsigned wire types decoded into a same-width signed tensor type
//    are wrapped (WrapUnsignedAsSigned*); decoded into a wider type
//    they are zero-extended (value preserved);
//  - varint-based types ignore the parse `ok` flag because the earlier
//    counting pass over the same buffer already validated it.
template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
const uint8* ReadFromArray(const uint8* buf, TensorType* value);

template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
    const uint8* buf, int32* value) {
  uint32 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
  *value = static_cast<int32>(temp);
  return buf;
}

template <>
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
    const uint8* buf, int64* value) {
  uint64 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
  *value = WrapUnsignedAsSigned64(temp);
  return buf;
}

// TYPE_UINT32 into int64: widening, so the value fits without wrapping.
template <>
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>(
    const uint8* buf, int64* value) {
  uint32 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
  *value = temp;
  return buf;
}

// TYPE_UINT32 into int32: same width, so large values wrap to negative.
template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_UINT32>(
    const uint8* buf, int32* value) {
  uint32 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
  *value = WrapUnsignedAsSigned32(temp);
  return buf;
}

template <>
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT64>(
    const uint8* buf, int64* value) {
  uint64 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
  *value = static_cast<int64>(temp);
  return buf;
}

// Signed zig-zag encoded varints (SINT32/SINT64).
template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
    const uint8* buf, int32* value) {
  uint32 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
  *value = WireFormatLite::ZigZagDecode32(temp);
  return buf;
}

template <>
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
    const uint8* buf, int64* value) {
  uint64 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
  *value = WireFormatLite::ZigZagDecode64(temp);
  return buf;
}

// Fixed-width wire types delegate to WireFormatLite's array readers.
template <>
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>(
    const uint8* buf, int64* value) {
  uint32 temp;
  buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
                                               WireFormatLite::TYPE_FIXED32>(
      buf, &temp);
  *value = temp;
  return buf;
}

template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>(
    const uint8* buf, int32* value) {
  uint32 temp;
  buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
                                               WireFormatLite::TYPE_FIXED32>(
      buf, &temp);
  *value = WrapUnsignedAsSigned32(temp);
  return buf;
}

template <>
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>(
    const uint8* buf, int64* value) {
  protobuf_uint64 temp;
  buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
                                               WireFormatLite::TYPE_FIXED64>(
      buf, &temp);
  *value = WrapUnsignedAsSigned64(temp);
  return buf;
}

template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
    const uint8* buf, int32* value) {
  return WireFormatLite::ReadPrimitiveFromArray<int32,
                                                WireFormatLite::TYPE_SFIXED32>(
      buf, value);
}

template <>
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>(
    const uint8* buf, int64* value) {
  protobuf_int64 temp;
  buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
                                               WireFormatLite::TYPE_SFIXED64>(
      buf, &temp);
  *value = temp;
  return buf;
}

template <>
inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
    const uint8* buf, float* value) {
  return WireFormatLite::ReadPrimitiveFromArray<float,
                                                WireFormatLite::TYPE_FLOAT>(
      buf, value);
}

template <>
inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
    const uint8* buf, double* value) {
  return WireFormatLite::ReadPrimitiveFromArray<double,
                                                WireFormatLite::TYPE_DOUBLE>(
      buf, value);
}

// Bools are varint-encoded on the wire; any nonzero value is true.
template <>
inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
    const uint8* buf, bool* value) {
  uint64 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
  *value = temp != 0;
  return buf;
}

// Enums are varint-encoded; stored as the raw int value.
template <>
inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
    const uint8* buf, int* value) {
  uint32 temp;
  bool unused_ok;  // The Counting pass would have failed if this were corrupt.
  buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
  *value = static_cast<int>(temp);
  return buf;
}
// Reads packed values from an array.
// Stride is set to 1 for repeated fields, and 0 for non-repeated fields
// (where any value overwrites previous values).
// `index` is the starting write offset into the tensor buffer `datap`;
// the return value is how far the write index advanced (number of
// values read times the stride, so 0 when stride is 0).
template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
inline int ReadPackedPrimitives(const void* bufp, const size_t len,
                                const int index, const int stride,
                                void* datap) {
  const uint8* buf = reinterpret_cast<const uint8*>(bufp);
  const uint8* bound = buf + len;  // One past the last byte of packed data.
  TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
  int count;
  // This could overrun the bound by stride-1. This is defended
  // against in the caller, where it ensures that the input buffer
  // contains complete values.
  for (count = 0; buf < bound; count += stride) {
    buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
  }
  return count;
}
// Reads a primitive value field from a serialized proto.
// The value is parsed from the serialized format, then static_cast
// to the desired type for TensorFlow and stored.
//
// ValueType is the wire-side C++ type, TensorType the element type of
// the destination tensor buffer `data`, and DeclaredType the declared
// proto field type. The parsed value is stored at offset `index`.
// Returns DataLoss if the stream does not hold a valid value.
template <class ValueType, class TensorType,
          enum WireFormatLite::FieldType DeclaredType>
inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
  ValueType v;
  if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
    return errors::DataLoss("Failed reading primitive");
  }
  // Implicit conversion from ValueType to TensorType happens here.
  reinterpret_cast<TensorType*>(data)[index] = v;
  return Status::OK();
}
// Reads a string, submessage, or other variable-length field from a
// serialized proto.
// May read all or part of a repeated field.
// `datap` points at a buffer of `string` elements; the bytes are
// stored into the element at offset `index`.
inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
  string* data = reinterpret_cast<string*>(datap) + index;
  if (!WireFormatLite::ReadBytes(input, data)) {
    return errors::DataLoss("Failed reading bytes");
  }
  return Status::OK();
}
// Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
// as a bytestring. The group's raw bytes (everything up to the matching
// end-group tag for `field_number`) are captured into the string at
// offset `index` of `datap`.
inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
                             int index, void* datap) {
  // WireFormatLite::SkipField has an option to emit the
  // skipped bytes to an output stream. We could do better by implementing our
  // own scanner but this is simpler for now.
  // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
  // on input->IsFlat() == true and using input->GetDirectBufferPointer()
  // with input->CurrentPosition().
  string* data = reinterpret_cast<string*>(datap) + index;
  StringOutputStream string_stream(data);
  CodedOutputStream out(&string_stream);
  if (!WireFormatLite::SkipField(
          input,
          WireFormatLite::MakeTag(field_number,
                                  WireFormatLite::WIRETYPE_START_GROUP),
          &out)) {
    return errors::DataLoss("Failed reading group");
  }
  return Status::OK();
}
// Reads a single field value from a CodedInputStream into a tensor.
// Dispatches on the declared proto field type. Where one wire type may
// land in more than one tensor dtype (FLOAT -> float/double,
// FIXED32/UINT32 -> int32/int64), `dtype` selects the destination;
// an unexpected dtype falls through to a DataLoss error, which earlier
// attribute validation should have made unreachable.
inline Status ReadValue(CodedInputStream* input,
                        WireFormatLite::FieldType field_type, int field_number,
                        DataType dtype, int index, void* datap) {
  // Dispatch to the appropriately typed field reader based on the
  // schema type.
  switch (field_type) {
    case WireFormatLite::TYPE_DOUBLE:
      return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
          input, index, datap);
    case WireFormatLite::TYPE_FLOAT:
      if (dtype == DataType::DT_FLOAT) {
        return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
            input, index, datap);
      }
      if (dtype == DataType::DT_DOUBLE) {
        return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
            input, index, datap);
      }
      // Any case that reaches this point should have triggered an error
      // already.
      return errors::DataLoss("Failed reading TYPE_FLOAT");
    case WireFormatLite::TYPE_INT64:
      return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
          input, index, datap);
    case WireFormatLite::TYPE_UINT64:
      return ReadPrimitive<protobuf_uint64, int64, WireFormatLite::TYPE_UINT64>(
          input, index, datap);
    case WireFormatLite::TYPE_INT32:
      return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
          input, index, datap);
    case WireFormatLite::TYPE_FIXED64:
      return ReadPrimitive<protobuf_uint64, int64,
                           WireFormatLite::TYPE_FIXED64>(input, index, datap);
    case WireFormatLite::TYPE_FIXED32:
      if (dtype == DataType::DT_INT64) {
        return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_FIXED32>(
            input, index, datap);
      }
      if (dtype == DataType::DT_INT32) {
        return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_FIXED32>(
            input, index, datap);
      }
      // Any case that reaches this point should have triggered an error
      // already.
      return errors::DataLoss("Failed reading TYPE_FIXED32");
    case WireFormatLite::TYPE_BOOL:
      return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
                                                                  datap);
    case WireFormatLite::TYPE_STRING:
      return ReadBytes(input, index, datap);
    case WireFormatLite::TYPE_GROUP:
      // Groups need the field number so the reader can recognize the
      // matching end-group tag.
      return ReadGroupBytes(input, field_number, index, datap);
    case WireFormatLite::TYPE_MESSAGE:
      // Submessages are captured as raw serialized bytes.
      return ReadBytes(input, index, datap);
    case WireFormatLite::TYPE_BYTES:
      return ReadBytes(input, index, datap);
    case WireFormatLite::TYPE_UINT32:
      if (dtype == DataType::DT_INT64) {
        return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_UINT32>(
            input, index, datap);
      }
      if (dtype == DataType::DT_INT32) {
        return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_UINT32>(
            input, index, datap);
      }
      // Any case that reaches this point should have triggered an error
      // already.
      return errors::DataLoss("Failed reading TYPE_UINT32");
    case WireFormatLite::TYPE_ENUM:
      return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
          input, index, datap);
    case WireFormatLite::TYPE_SFIXED32:
      return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
          input, index, datap);
    case WireFormatLite::TYPE_SFIXED64:
      return ReadPrimitive<protobuf_int64, int64,
                           WireFormatLite::TYPE_SFIXED64>(input, index, datap);
    case WireFormatLite::TYPE_SINT32:
      return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
          input, index, datap);
    case WireFormatLite::TYPE_SINT64:
      return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
          input, index, datap);
      // default: intentionally omitted in order to enable static checking.
  }
  // Unreachable.
  return errors::DataLoss("Failed reading unknown wire type");
}
// Reads and stores a length-delimited list of values.
// `buf`/`buf_size` delimit the packed payload. Values are written into
// `data` starting at *index, and *index is advanced by the amount
// consumed (see ReadPackedPrimitives). Variable-length types can never
// be packed, so they return DataLoss.
inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
                                  const WireFormatLite::FieldType field_type,
                                  const int field_number, const DataType dtype,
                                  const int stride, int* index, void* data) {
  // Dispatch to the appropriately typed field reader based on the
  // schema type.
  switch (field_type) {
    case WireFormatLite::TYPE_DOUBLE:
      *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_FLOAT:
      *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_INT64:
      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_UINT64:
      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT64>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_INT32:
      *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_FIXED64:
      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED64>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_FIXED32:
      // FIXED32 may be decoded into either int64 or int32 tensors.
      if (dtype == DataType::DT_INT64) {
        *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED32>(
            buf, buf_size, *index, stride, data);
        return Status::OK();
      }
      if (dtype == DataType::DT_INT32) {
        *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_FIXED32>(
            buf, buf_size, *index, stride, data);
        return Status::OK();
      }
      // Any case that reaches this point should have triggered an error
      // already.
      return errors::DataLoss("Failed reading TYPE_FIXED32");
    case WireFormatLite::TYPE_BOOL:
      *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_STRING:
    case WireFormatLite::TYPE_GROUP:
    case WireFormatLite::TYPE_MESSAGE:
    case WireFormatLite::TYPE_BYTES:
      return errors::DataLoss("Non-primitive type encountered as packed");
    case WireFormatLite::TYPE_UINT32:
      // UINT32 may be decoded into either int64 or int32 tensors.
      if (dtype == DataType::DT_INT64) {
        *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT32>(
            buf, buf_size, *index, stride, data);
        return Status::OK();
      }
      if (dtype == DataType::DT_INT32) {
        *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_UINT32>(
            buf, buf_size, *index, stride, data);
        return Status::OK();
      }
      // Any case that reaches this point should have triggered an error
      // already.
      return errors::DataLoss("Failed reading TYPE_UINT32");
    case WireFormatLite::TYPE_ENUM:
      *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_SFIXED32:
      *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_SFIXED64:
      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_SINT32:
      *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
    case WireFormatLite::TYPE_SINT64:
      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
          buf, buf_size, *index, stride, data);
      return Status::OK();
      // default: intentionally omitted in order to enable static checking.
  }
  // Unreachable.
  return errors::DataLoss("Failed reading unknown wire type");
}
// Reads a varint from the given buffer, write it to *value, and return the
// new buffer pointer.
// This was copied from coded_stream.cc where it is private.
// Important: This routine may read as much as kMaxVarintBytes from
// the buffer. It is the caller's responsibility to make sure that there is
// enough space in the buffer.
//
// Each encoded byte contributes its low 7 bits; a set high bit (0x80)
// means another byte follows. The `-= 0x80 << k` lines subtract the
// continuation bit that was accumulated along with the payload bits.
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
                                          uint64* value) {
  const uint8* ptr = buffer;
  uint32 b;
  // Splitting into 32-bit pieces gives better performance on 32-bit
  // processors.
  uint32 part0 = 0, part1 = 0, part2 = 0;
  b = *(ptr++);
  part0 = b;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80;
  b = *(ptr++);
  part0 += b << 7;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 7;
  b = *(ptr++);
  part0 += b << 14;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 14;
  b = *(ptr++);
  part0 += b << 21;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 21;
  b = *(ptr++);
  part1 = b;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80;
  b = *(ptr++);
  part1 += b << 7;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 7;
  b = *(ptr++);
  part1 += b << 14;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 14;
  b = *(ptr++);
  part1 += b << 21;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 21;
  b = *(ptr++);
  part2 = b;
  if (!(b & 0x80)) goto done;
  part2 -= 0x80;
  b = *(ptr++);
  part2 += b << 7;
  if (!(b & 0x80)) goto done;
  // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.
  // We have overrun the maximum size of a varint (10 bytes). Assume
  // the data is corrupt.
  *ok = false;
  return ptr;

done:
  *ok = true;
  // Reassemble: part0 holds bits 0-27, part1 bits 28-55, part2 bits 56-63.
  *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
           (static_cast<uint64>(part2) << 56);
  return ptr;
}
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_

View File

@ -0,0 +1,45 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
namespace tensorflow {
// Returns the process-wide registry singleton. The instance is created
// on first use and intentionally leaked (never destroyed), so lookups
// remain valid for the life of the process.
DescriptorPoolRegistry* DescriptorPoolRegistry::Global() {
  static DescriptorPoolRegistry* registry = new DescriptorPoolRegistry;
  return registry;
}
// Looks up the pool factory registered under `source`.
// Returns nullptr when nothing has been registered for that name.
DescriptorPoolRegistry::DescriptorPoolFn* DescriptorPoolRegistry::Get(
    const string& source) {
  auto it = fns_.find(source);
  return (it == fns_.end()) ? nullptr : &it->second;
}
// Registers `pool_fn` as the descriptor-pool factory for `source`.
// CHECK-fails if a factory is already registered under the same name,
// since silently replacing one would hide a programming error.
void DescriptorPoolRegistry::Register(
    const string& source,
    const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) {
  auto existing = Get(source);
  CHECK_EQ(existing, nullptr)
      << "descriptor pool for source: " << source << " already registered";
  // emplace copies the key and functor directly into the map; the
  // original spelled this as an insert of a pair<const string&, ...>,
  // which needlessly routed the key through a reference member.
  fns_.emplace(source, pool_fn);
}
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
// Registry mapping a "descriptor source" string to a factory function
// that produces a protobuf DescriptorPool for that source.
class DescriptorPoolRegistry {
 public:
  // Factory signature: presumably the factory either points *desc_pool
  // at a pool it retains ownership of, or fills *owned_desc_pool to
  // transfer ownership to the caller — confirm against implementations.
  typedef std::function<Status(
      tensorflow::protobuf::DescriptorPool const** desc_pool,
      std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool)>
      DescriptorPoolFn;

  // Returns a pointer to a global DescriptorPoolRegistry object.
  static DescriptorPoolRegistry* Global();

  // Returns a pointer to a descriptor pool function for the given source,
  // or nullptr if none is registered.
  DescriptorPoolFn* Get(const string& source);

  // Registers a descriptor pool factory.
  void Register(const string& source, const DescriptorPoolFn& pool_fn);

 private:
  // NOTE(review): this is a std::map although <unordered_map> is
  // included above — one of the two looks vestigial; verify before
  // changing either.
  std::map<string, DescriptorPoolFn> fns_;
};
namespace descriptor_pool_registration {

// Helper whose constructor performs the registration; static instances
// of this class are what the REGISTER_DESCRIPTOR_POOL macro creates.
class DescriptorPoolRegistration {
 public:
  DescriptorPoolRegistration(
      const string& source,
      const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) {
    DescriptorPoolRegistry::Global()->Register(source, pool_fn);
  }
};

}  // namespace descriptor_pool_registration

// Registers `pool_fn` for `source` at static-initialization time.
// The two-level indirection expands __COUNTER__ before pasting, giving
// each expansion a uniquely named static registration object.
#define REGISTER_DESCRIPTOR_POOL(source, pool_fn) \
  REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(__COUNTER__, source, pool_fn)

#define REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(ctr, source, pool_fn) \
  REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn)

#define REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn) \
  static descriptor_pool_registration::DescriptorPoolRegistration \
      descriptor_pool_registration_fn_##ctr(source, pool_fn)
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_

View File

@ -0,0 +1,43 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
struct Value {
  // No-op factory matching DescriptorPoolRegistry::DescriptorPoolFn;
  // the test exercises registration/lookup only, not pool construction.
  static Status Function(
      tensorflow::protobuf::DescriptorPool const** desc_pool,
      std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
    return Status::OK();
  }
};

// Register the same factory under two distinct source names.
REGISTER_DESCRIPTOR_POOL("TEST POOL 1", Value::Function);
REGISTER_DESCRIPTOR_POOL("TEST POOL 2", Value::Function);
} // namespace
TEST(DescriptorPoolRegistryTest, TestBasic) {
  // An unregistered source yields nullptr rather than an error.
  EXPECT_EQ(DescriptorPoolRegistry::Global()->Get("NON-EXISTENT"), nullptr);
  // Both statically registered sources are retrievable.
  auto pool1 = DescriptorPoolRegistry::Global()->Get("TEST POOL 1");
  EXPECT_NE(pool1, nullptr);
  auto pool2 = DescriptorPoolRegistry::Global()->Get("TEST POOL 2");
  EXPECT_NE(pool2, nullptr);
}
} // namespace tensorflow

View File

@ -0,0 +1,85 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/reader_op_kernel.h"
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
#include "tensorflow/core/util/proto/descriptors.h"
namespace tensorflow {
namespace {
// Builds a `DescriptorPool` from the named file or URI. The file or URI
// must be readable by the current TensorFlow environment and must contain
// a serialized `FileDescriptorSet`. See `GetDescriptorPool()` for more
// information.
Status GetDescriptorPoolFromFile(
    tensorflow::Env* env, const string& filename,
    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
  Status exists_status = env->FileExists(filename);
  if (!exists_status.ok()) {
    return exists_status;
  }

  // Map the file into memory and parse it as a FileDescriptorSet.
  std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> region;
  Status read_status = env->NewReadOnlyMemoryRegionFromFile(filename, &region);
  if (!read_status.ok()) {
    return read_status;
  }
  tensorflow::protobuf::FileDescriptorSet descs;
  if (!descs.ParseFromArray(region->data(), region->length())) {
    return errors::InvalidArgument(
        "descriptor_source contains invalid FileDescriptorSet: ", filename);
  }

  // Populate a fresh DescriptorPool with every file descriptor in the set.
  // BuildFile returns nullptr when a file cannot be added (e.g. one of its
  // dependencies has not been loaded yet).
  owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool());
  for (const auto& file_proto : descs.file()) {
    if ((*owned_desc_pool)->BuildFile(file_proto) == nullptr) {
      return errors::InvalidArgument(
          "Problem loading FileDescriptorProto (missing dependencies?): ",
          filename);
    }
  }
  return Status::OK();
}
} // namespace
// Resolves `descriptor_source` to a `DescriptorPool`, first via the
// registry of named pool functions and otherwise by treating the source as
// a file containing a serialized `FileDescriptorSet`.
//
// On success either `*desc_pool` aliases a registry-provided pool, or
// `*owned_desc_pool` owns the newly built pool and `*desc_pool` points at
// it. On failure neither output is published.
Status GetDescriptorPool(
    tensorflow::Env* env, string const& descriptor_source,
    tensorflow::protobuf::DescriptorPool const** desc_pool,
    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
  // Attempt to lookup the pool in the registry.
  auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
  if (pool_fn != nullptr) {
    return (*pool_fn)(desc_pool, owned_desc_pool);
  }

  // If there is no pool function registered for the given source, let the
  // runtime find the file or URL.
  Status status =
      GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
  // Publish the pool only on success. The previous version additionally
  // assigned `*desc_pool = owned_desc_pool->get()` unconditionally after
  // this block, handing callers a null pointer when loading failed.
  if (status.ok()) {
    *desc_pool = owned_desc_pool->get();
  }
  return status;
}
} // namespace tensorflow

View File

@ -0,0 +1,42 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
#include <memory>
#include <string>
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
// Forward declarations; full definitions come from core/platform headers.
class Env;
class Status;
// Get a `DescriptorPool` object from the named `descriptor_source`.
// `descriptor_source` may be a path to a file accessible to TensorFlow, in
// which case it is parsed as a `FileDescriptorSet` and used to build the
// `DescriptorPool`.
//
// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if
// the caller should take ownership.
extern tensorflow::Status GetDescriptorPool(
    tensorflow::Env* env, string const& descriptor_source,
    tensorflow::protobuf::DescriptorPool const** desc_pool,
    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool);
}  // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_

View File

@ -0,0 +1,39 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
namespace tensorflow {
namespace {
// Descriptor-pool factory backed by the protobufs compiled into this binary.
struct LocalDescriptorPool {
  static Status Function(
      tensorflow::protobuf::DescriptorPool const** desc_pool,
      std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
    // The generated pool is owned by the protobuf runtime, so
    // `owned_desc_pool` is deliberately left untouched.
    *desc_pool = ::tensorflow::protobuf::DescriptorPool::generated_pool();
    if (*desc_pool != nullptr) {
      return Status::OK();
    }
    return errors::InvalidArgument("Problem loading protobuf generated_pool");
  }
};
// Both the empty source string and the explicit "local://" scheme select the
// pool generated into this binary.
REGISTER_DESCRIPTOR_POOL("", LocalDescriptorPool::Function);
REGISTER_DESCRIPTOR_POOL("local://", LocalDescriptorPool::Function);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
# Build rules for tensorflow/core/util/rpc: shared RPC plumbing used by the
# rpc ops (call bookkeeping, the factory interface, and the factory registry).
package(
    default_visibility = ["//visibility:public"],
)

licenses(["notice"]) # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_cc_test")

# Header-only helper that tracks the lifetime of a batch of in-flight calls.
cc_library(
    name = "call_container",
    hdrs = ["call_container.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

# Abstract RPCFactory interface plus the GetEnvVar helpers.
cc_library(
    name = "rpc_factory",
    srcs = ["rpc_factory.cc"],
    hdrs = ["rpc_factory.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)

# Maps a protocol name to a function creating an RPCFactory for it.
cc_library(
    name = "rpc_factory_registry",
    srcs = ["rpc_factory_registry.cc"],
    hdrs = ["rpc_factory_registry.h"],
    deps = [
        ":rpc_factory",
        "//tensorflow/core:framework",
    ],
)

tf_cc_test(
    name = "rpc_factory_registry_test",
    srcs = ["rpc_factory_registry_test.cc"],
    deps = [
        ":rpc_factory_registry",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

View File

@ -0,0 +1,90 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
#include <list>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
// Tracks a fixed-size batch of in-flight RPC `Call` objects for one async op
// invocation. Completion is reference-counted: each call holds one reference
// on `reffed_status_callback_`, and when the last reference is released the
// callback deregisters the cancellation token, propagates the final status
// into the kernel context, invokes `done`, and deletes this container.
//
// NOTE(review): `fail_fast_` is stored but never read in this header —
// presumably consulted by the RPC implementation that populates `calls_`;
// confirm against callers.
template <typename Call>
class CallContainer {
 public:
  // `num_calls` must be positive. `done` is the async kernel's completion
  // callback; `token` is a cancellation token already registered with
  // `ctx`'s cancellation manager (it is deregistered on completion).
  explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
                         bool try_rpc, AsyncOpKernel::DoneCallback done,
                         CancellationToken token)
      : ctx_(ctx),
        done_(std::move(done)),
        token_(token),
        fail_fast_(fail_fast),
        try_rpc_(try_rpc) {
    CHECK_GT(num_calls, 0);
    // This will run when all RPCs are finished.
    reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
      ctx_->cancellation_manager()->DeregisterCallback(token_);
      ctx_->SetStatus(s);
      done_();
      delete this;  // Self-destructs once every call has reported in.
    });
    // Subtract reference count from the initial creation.
    core::ScopedUnref unref(reffed_status_callback_);
    for (int i = 0; i < num_calls; ++i) {
      // Increase the reference on the callback for each new RPC.
      reffed_status_callback_->Ref();
    }
  }
  // Storage for the per-element calls; populated and driven by the caller.
  std::list<Call>* calls() { return &calls_; }
  // Requests cancellation of every outstanding call.
  void StartCancel() {
    // Once this loop is done, can no longer assume anything is valid
    // because "delete this" may have been immediately called.
    // Nothing should run after this loop.
    for (auto& call : calls_) {
      call.StartCancel();
    }
  }
  // Reports completion of one call. When `try_rpc_` is true, individual
  // failures do not affect the aggregate status (per-call statuses are
  // presumably surfaced elsewhere — confirm with callers); `index` is
  // currently unused.
  void Done(const Status& s, int index) {
    if (!try_rpc_) {
      reffed_status_callback_->UpdateStatus(s);
    }
    reffed_status_callback_->Unref();
  }

 private:
  OpKernelContext* ctx_;
  std::list<Call> calls_;
  const AsyncOpKernel::DoneCallback done_;
  const CancellationToken token_;
  const bool fail_fast_;
  const bool try_rpc_;
  // Performs its own reference counting.
  ReffedStatusCallback* reffed_status_callback_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_

View File

@ -0,0 +1,53 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
// Reads environment variable `key` into `*value`. An unset or empty
// variable yields `default_value`. String lookup cannot fail to parse, so
// this specialization always returns true.
template <>
bool GetEnvVar(const char* key, const string& default_value, string* value) {
  const char* raw = std::getenv(key);
  const bool use_default = (raw == nullptr) || (raw[0] == '\0');
  *value = use_default ? default_value : string(raw);
  return true;
}
// Reads environment variable `key` as a signed 64-bit integer. An unset or
// empty variable yields `default_value`; otherwise returns false when the
// value does not parse.
template <>
bool GetEnvVar(const char* key, const int64& default_value, int64* value) {
  const char* raw = std::getenv(key);
  if (raw != nullptr && raw[0] != '\0') {
    return strings::safe_strto64(raw, value);
  }
  *value = default_value;
  return true;
}
// Reads environment variable `key` as an unsigned 64-bit integer. An unset
// or empty variable yields `default_value`; otherwise returns false when
// the value does not parse.
template <>
bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) {
  const char* raw = std::getenv(key);
  if (raw != nullptr && raw[0] != '\0') {
    return strings::safe_strtou64(raw, value);
  }
  *value = default_value;
  return true;
}
} // namespace tensorflow

View File

@ -0,0 +1,70 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
// Return the environment variable `key`. If the variable is not set,
// use the default value. If it is set but could not be parsed,
// return `false`. Otherwise set `value` and return `true`.
// Specializations (string, int64, uint64) live in rpc_factory.cc.
template <typename T>
bool GetEnvVar(const char* key, const T& default_value, T* value);

// Abstract factory for issuing batched RPCs from within an op kernel.
// Concrete subclasses implement a single wire protocol and are created via
// functions registered in RPCFactoryRegistry.
class RPCFactory {
 public:
  RPCFactory() {}
  virtual ~RPCFactory() {}

  // Start a Call() to methods `method_t` at addresses `address_t` with
  // request strings from `request_t`. Any of these may be scalar
  // Tensors, in which case the operands are broadcasted.
  // Upon completion of all requests, `response_t` will be populated.
  //
  // If `try_rpc` is `true`, then `status_message_t` and
  // `status_code_t` will be populated as well.
  //
  // If `try_rpc` is `false`, then `status_message_t` and
  // `status_code_t` are ignored (and may be nullptr). Instead, the
  // status of any failed call will be propagated to the op.
  //
  // REQUIRES:
  //   - `response_t` is not null, and is a string Tensor with the same
  //     shape as `request_t`.
  //
  //   If `try_rpc` is `true`:
  //   - `status_code_t` and `status_message_t` are not null.
  //   - `status_code_t` is an int32 Tensor with the same shape as
  //     `request_t`.
  //   - `status_message_t` is a string Tensor with the same shape as
  //     `request_t`.
  virtual void Call(OpKernelContext* ctx, int64 num_elements,
                    const Tensor& address_t, const Tensor& method_t,
                    const Tensor& request_t, const bool try_rpc,
                    Tensor* response_t, Tensor* status_code_t,
                    Tensor* status_message_t,
                    AsyncOpKernel::DoneCallback done) = 0;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_

View File

@ -0,0 +1,44 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/core/util/rpc/rpc_factory.h"
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
namespace tensorflow {
// Returns the process-wide registry. The instance is intentionally leaked so
// it remains valid during static destruction.
RPCFactoryRegistry* RPCFactoryRegistry::Global() {
  static RPCFactoryRegistry* const singleton = new RPCFactoryRegistry;
  return singleton;
}
// Looks up the factory function registered for `protocol`; returns nullptr
// when nothing was registered under that name.
RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get(
    const string& protocol) {
  const auto it = fns_.find(protocol);
  return it == fns_.end() ? nullptr : &it->second;
}
void RPCFactoryRegistry::Register(const string& protocol,
const RPCFactoryFn& factory_fn) {
auto existing = Get(protocol);
CHECK_EQ(existing, nullptr)
<< "RPC factory for protocol: " << protocol << " already registered";
fns_.insert(std::pair<const string&, RPCFactoryFn>(protocol, factory_fn));
}
} // namespace tensorflow

View File

@ -0,0 +1,72 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
#include <map>
#include <string>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
// Registry mapping a protocol name to a function that creates an RPCFactory
// for that protocol.
class RPCFactoryRegistry {
 public:
  typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast,
                                    int64 timeout_in_ms)>
      RPCFactoryFn;

  // Returns a pointer to a global RPCFactoryRegistry object.
  static RPCFactoryRegistry* Global();

  // Returns a pointer to a function that creates an RPC factory for the
  // given protocol, or nullptr if none is registered.
  RPCFactoryFn* Get(const string& protocol);

  // Registers a function that creates an RPC factory for the given protocol.
  // The function should transfer the ownership of the factory to its caller.
  void Register(const string& protocol, const RPCFactoryFn& factory_fn);

 private:
  std::map<string, RPCFactoryFn> fns_;
};
namespace rpc_factory_registration {
// Helper whose constructor performs the registration; instantiated as a
// static object by the REGISTER_RPC_FACTORY macro below so registration
// happens during static initialization.
class RPCFactoryRegistration {
 public:
  RPCFactoryRegistration(const string& protocol,
                         const RPCFactoryRegistry::RPCFactoryFn& factory_fn) {
    RPCFactoryRegistry::Global()->Register(protocol, factory_fn);
  }
};
}  // namespace rpc_factory_registration

// Registers `factory_fn` as the creator for `protocol`. The __COUNTER__
// indirection gives each registration object a unique name within the
// translation unit.
#define REGISTER_RPC_FACTORY(protocol, factory_fn) \
  REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn)
#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \
  REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn)
#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \
  static rpc_factory_registration::RPCFactoryRegistration    \
      rpc_factory_registration_fn_##ctr(protocol, factory_fn)
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_

View File

@ -0,0 +1,41 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// Factory-function stub: the registry test only needs a registrable
// signature, so it returns a null factory and ignores its arguments.
struct Value {
  static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
                              int64 timeout_in_ms) {
    return nullptr;
  }
};
// Register the stub under two distinct names so the test below can verify
// that lookups are keyed by the exact registration string.
REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function);
REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function);
} // namespace
// Unknown protocols resolve to nullptr; both registered stubs resolve to
// non-null factory functions.
TEST(RPCFactoryRegistryTest, TestBasic) {
  auto* registry = RPCFactoryRegistry::Global();
  EXPECT_EQ(nullptr, registry->Get("NON-EXISTENT"));
  EXPECT_NE(nullptr, registry->Get("TEST FACTORY 1"));
  EXPECT_NE(nullptr, registry->Get("TEST FACTORY 2"));
}
} // namespace tensorflow

View File

@ -3370,6 +3370,7 @@ tf_py_wrap_cc(
"//tensorflow/c:python_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//tensorflow/core/grappler:grappler_item",