Open sourcing proto/rpc ops.
PiperOrigin-RevId: 191962572
This commit is contained in:
parent
ddf54d1c24
commit
5e11bbacaf
@ -354,6 +354,9 @@ tensorflow/contrib/periodic_resample
|
||||
tensorflow/contrib/periodic_resample/python
|
||||
tensorflow/contrib/periodic_resample/python/ops
|
||||
tensorflow/contrib/predictor
|
||||
tensorflow/contrib/proto
|
||||
tensorflow/contrib/proto/python
|
||||
tensorflow/contrib/proto/python/ops
|
||||
tensorflow/contrib/quantization
|
||||
tensorflow/contrib/quantization/python
|
||||
tensorflow/contrib/quantize
|
||||
@ -382,6 +385,9 @@ tensorflow/contrib/rnn/ops
|
||||
tensorflow/contrib/rnn/python
|
||||
tensorflow/contrib/rnn/python/kernel_tests
|
||||
tensorflow/contrib/rnn/python/ops
|
||||
tensorflow/contrib/rpc
|
||||
tensorflow/contrib/rpc/python
|
||||
tensorflow/contrib/rpc/python/ops
|
||||
tensorflow/contrib/saved_model
|
||||
tensorflow/contrib/saved_model/python
|
||||
tensorflow/contrib/saved_model/python/saved_model
|
||||
|
@ -25,6 +25,8 @@ set(tf_op_lib_names
|
||||
"cudnn_rnn_ops"
|
||||
"data_flow_ops"
|
||||
"dataset_ops"
|
||||
"decode_proto_ops"
|
||||
"encode_proto_ops"
|
||||
"functional_ops"
|
||||
"image_ops"
|
||||
"io_ops"
|
||||
@ -40,6 +42,7 @@ set(tf_op_lib_names
|
||||
"random_ops"
|
||||
"remote_fused_graph_ops"
|
||||
"resource_variable_ops"
|
||||
"rpc_ops"
|
||||
"script_ops"
|
||||
"sdca_ops"
|
||||
"set_ops"
|
||||
|
@ -330,6 +330,8 @@ GENERATE_PYTHON_OP_LIB("ctc_ops")
|
||||
GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops")
|
||||
GENERATE_PYTHON_OP_LIB("data_flow_ops")
|
||||
GENERATE_PYTHON_OP_LIB("dataset_ops")
|
||||
GENERATE_PYTHON_OP_LIB("decode_proto_ops")
|
||||
GENERATE_PYTHON_OP_LIB("encode_proto_ops")
|
||||
GENERATE_PYTHON_OP_LIB("image_ops")
|
||||
GENERATE_PYTHON_OP_LIB("io_ops")
|
||||
GENERATE_PYTHON_OP_LIB("linalg_ops")
|
||||
@ -343,6 +345,7 @@ GENERATE_PYTHON_OP_LIB("random_ops")
|
||||
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("resource_variable_ops")
|
||||
GENERATE_PYTHON_OP_LIB("rpc_ops")
|
||||
GENERATE_PYTHON_OP_LIB("script_ops")
|
||||
GENERATE_PYTHON_OP_LIB("sdca_ops")
|
||||
GENERATE_PYTHON_OP_LIB("set_ops")
|
||||
|
@ -301,3 +301,5 @@ tensorflow/core/kernels/warn_about_ints.cc
|
||||
tensorflow/core/kernels/segment_reduction_ops.cc
|
||||
tensorflow/core/kernels/batch_util.cc
|
||||
tensorflow/core/ops/audio_ops.cc
|
||||
tensorflow/core/kernels/decode_proto_op.cc
|
||||
tensorflow/core/kernels/encode_proto_op.cc
|
||||
|
16
tensorflow/contrib/proto/BUILD
Normal file
16
tensorflow/contrib/proto/BUILD
Normal file
@ -0,0 +1,16 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "proto",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
|
||||
"//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
|
||||
],
|
||||
)
|
28
tensorflow/contrib/proto/__init__.py
Normal file
28
tensorflow/contrib/proto/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops and modules related to proto.
|
||||
|
||||
@@decode_proto
|
||||
@@encode_proto
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.proto.python.ops.decode_proto_op import decode_proto
|
||||
from tensorflow.contrib.proto.python.ops.encode_proto_op import encode_proto
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
remove_undocumented(__name__)
|
44
tensorflow/contrib/proto/python/ops/BUILD
Normal file
44
tensorflow/contrib/proto/python/ops/BUILD
Normal file
@ -0,0 +1,44 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_gen_op_wrapper_py",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "decode_proto_op_py",
|
||||
srcs = ["decode_proto_op.py"],
|
||||
deps = [
|
||||
":gen_decode_proto_op_py",
|
||||
"//tensorflow/python:framework_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_decode_proto_op_py",
|
||||
out = "gen_decode_proto_op.py",
|
||||
deps = [
|
||||
"//tensorflow/core:decode_proto_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "encode_proto_op_py",
|
||||
srcs = ["encode_proto_op.py"],
|
||||
deps = [
|
||||
":gen_encode_proto_op_py",
|
||||
"//tensorflow/python:framework_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_encode_proto_op_py",
|
||||
out = "gen_encode_proto_op.py",
|
||||
deps = [
|
||||
"//tensorflow/core:encode_proto_ops_op_lib",
|
||||
],
|
||||
)
|
25
tensorflow/contrib/proto/python/ops/decode_proto_op.py
Normal file
25
tensorflow/contrib/proto/python/ops/decode_proto_op.py
Normal file
@ -0,0 +1,25 @@
|
||||
# =============================================================================
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
"""Protocol Buffer decoding from tensors."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.proto.python.ops.gen_decode_proto_op import decode_proto_v2 as decode_proto
|
||||
from tensorflow.python.framework import ops
|
||||
ops.NotDifferentiable("DecodeProtoV2")
|
25
tensorflow/contrib/proto/python/ops/encode_proto_op.py
Normal file
25
tensorflow/contrib/proto/python/ops/encode_proto_op.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
"""Protocol Buffer encoding from tensors."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.proto.python.ops.gen_encode_proto_op import encode_proto
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
ops.NotDifferentiable("EncodeProto")
|
13
tensorflow/contrib/rpc/BUILD
Normal file
13
tensorflow/contrib/rpc/BUILD
Normal file
@ -0,0 +1,13 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "rpc",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
|
||||
)
|
28
tensorflow/contrib/rpc/__init__.py
Normal file
28
tensorflow/contrib/rpc/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops and modules related to RPC.
|
||||
|
||||
@@rpc
|
||||
@@try_rpc
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.rpc.python.ops.rpc_op import rpc
|
||||
from tensorflow.contrib.rpc.python.ops.rpc_op import try_rpc
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
remove_undocumented(__name__)
|
24
tensorflow/contrib/rpc/python/ops/BUILD
Normal file
24
tensorflow/contrib/rpc/python/ops/BUILD
Normal file
@ -0,0 +1,24 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
|
||||
py_library(
|
||||
name = "rpc_op_py",
|
||||
srcs = ["rpc_op.py"],
|
||||
deps = [
|
||||
":gen_rpc_op_py",
|
||||
"//tensorflow/python:framework_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_rpc_op_py",
|
||||
out = "gen_rpc_op.py",
|
||||
deps = [
|
||||
"//tensorflow/core:rpc_ops_op_lib",
|
||||
],
|
||||
)
|
26
tensorflow/contrib/rpc/python/ops/rpc_op.py
Normal file
26
tensorflow/contrib/rpc/python/ops/rpc_op.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
"""RPC communication."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.rpc.python.ops.gen_rpc_op import rpc
|
||||
from tensorflow.contrib.rpc.python.ops.gen_rpc_op import try_rpc
|
||||
from tensorflow.python.framework import ops
|
||||
ops.NotDifferentiable("Rpc")
|
||||
ops.NotDifferentiable("TryRpc")
|
@ -637,6 +637,8 @@ tf_gen_op_libs(
|
||||
"ctc_ops",
|
||||
"data_flow_ops",
|
||||
"dataset_ops",
|
||||
"decode_proto_ops",
|
||||
"encode_proto_ops",
|
||||
"function_ops",
|
||||
"functional_ops",
|
||||
"image_ops",
|
||||
@ -653,6 +655,7 @@ tf_gen_op_libs(
|
||||
"random_ops",
|
||||
"remote_fused_graph_ops",
|
||||
"resource_variable_ops",
|
||||
"rpc_ops",
|
||||
"scoped_allocator_ops",
|
||||
"sdca_ops",
|
||||
"set_ops",
|
||||
@ -751,6 +754,8 @@ cc_library(
|
||||
":cudnn_rnn_ops_op_lib",
|
||||
":data_flow_ops_op_lib",
|
||||
":dataset_ops_op_lib",
|
||||
":decode_proto_ops_op_lib",
|
||||
":encode_proto_ops_op_lib",
|
||||
":function_ops_op_lib",
|
||||
":functional_ops_op_lib",
|
||||
":image_ops_op_lib",
|
||||
@ -767,6 +772,7 @@ cc_library(
|
||||
":random_ops_op_lib",
|
||||
":remote_fused_graph_ops_op_lib",
|
||||
":resource_variable_ops_op_lib",
|
||||
":rpc_ops_op_lib",
|
||||
":scoped_allocator_ops_op_lib",
|
||||
":script_ops_op_lib",
|
||||
":sdca_ops_op_lib",
|
||||
@ -893,6 +899,8 @@ cc_library(
|
||||
"//tensorflow/core/kernels:cudnn_rnn_kernels",
|
||||
"//tensorflow/core/kernels:data_flow",
|
||||
"//tensorflow/core/kernels:dataset_ops",
|
||||
"//tensorflow/core/kernels:decode_proto_op",
|
||||
"//tensorflow/core/kernels:encode_proto_op",
|
||||
"//tensorflow/core/kernels:fake_quant_ops",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:functional_ops",
|
||||
@ -914,6 +922,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:remote_fused_graph_ops",
|
||||
"//tensorflow/core/kernels:required",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:rpc_op",
|
||||
"//tensorflow/core/kernels:scoped_allocator_ops",
|
||||
"//tensorflow/core/kernels:sdca_ops",
|
||||
"//tensorflow/core/kernels:set_kernels",
|
||||
|
116
tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
Normal file
116
tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
Normal file
@ -0,0 +1,116 @@
|
||||
op {
|
||||
graph_op_name: "DecodeProtoV2"
|
||||
in_arg {
|
||||
name: "bytes"
|
||||
description: <<END
|
||||
Tensor of serialized protos with shape `batch_shape`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "sizes"
|
||||
description: <<END
|
||||
Tensor of int32 with shape `[batch_shape, len(field_names)]`.
|
||||
Each entry is the number of values found for the corresponding field.
|
||||
Optional fields may have 0 or 1 values.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "values"
|
||||
description: <<END
|
||||
List of tensors containing values for the corresponding field.
|
||||
`values[i]` has datatype `output_types[i]`
|
||||
and shape `[batch_shape, max(sizes[...,i])]`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "message_type"
|
||||
description: <<END
|
||||
Name of the proto message type to decode.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "field_names"
|
||||
description: <<END
|
||||
List of strings containing proto field names.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
description: <<END
|
||||
List of TF types to use for the respective field in field_names.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "descriptor_source"
|
||||
description: <<END
|
||||
Either the special value `local://` or a path to a file containing
|
||||
a serialized `FileDescriptorSet`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "message_format"
|
||||
description: <<END
|
||||
Either `binary` or `text`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "sanitize"
|
||||
description: <<END
|
||||
Whether to sanitize the result or not.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
The op extracts fields from a serialized protocol buffers message into tensors.
|
||||
END
|
||||
description: <<END
|
||||
The `decode_proto` op extracts fields from a serialized protocol buffers
|
||||
message into tensors. The fields in `field_names` are decoded and converted
|
||||
to the corresponding `output_types` if possible.
|
||||
|
||||
A `message_type` name must be provided to give context for the field
|
||||
names. The actual message descriptor can be looked up either in the
|
||||
linked-in descriptor pool or a filename provided by the caller using
|
||||
the `descriptor_source` attribute.
|
||||
|
||||
Each output tensor is a dense tensor. This means that it is padded to
|
||||
hold the largest number of repeated elements seen in the input
|
||||
minibatch. (The shape is also padded by one to prevent zero-sized
|
||||
dimensions). The actual repeat counts for each example in the
|
||||
minibatch can be found in the `sizes` output. In many cases the output
|
||||
of `decode_proto` is fed immediately into tf.squeeze if missing values
|
||||
are not a concern. When using tf.squeeze, always pass the squeeze
|
||||
dimension explicitly to avoid surprises.
|
||||
|
||||
For the most part, the mapping between Proto field types and
|
||||
TensorFlow dtypes is straightforward. However, there are a few
|
||||
special cases:
|
||||
|
||||
- A proto field that contains a submessage or group can only be converted
|
||||
to `DT_STRING` (the serialized submessage). This is to reduce the
|
||||
complexity of the API. The resulting string can be used as input
|
||||
to another instance of the decode_proto op.
|
||||
|
||||
- TensorFlow lacks support for unsigned integers. The ops represent uint64
|
||||
types as a `DT_INT64` with the same twos-complement bit pattern
|
||||
(the obvious way). Unsigned int32 values can be represented exactly by
|
||||
specifying type `DT_INT64`, or using twos-complement if the caller
|
||||
specifies `DT_INT32` in the `output_types` attribute.
|
||||
|
||||
The `descriptor_source` attribute selects a source of protocol
|
||||
descriptors to consult when looking up `message_type`. This may be a
|
||||
filename containing a serialized `FileDescriptorSet` message,
|
||||
or the special value `local://`, in which case only descriptors linked
|
||||
into the code will be searched; the filename can be on any filesystem
|
||||
accessible to TensorFlow.
|
||||
|
||||
You can build a `descriptor_source` file using the `--descriptor_set_out`
|
||||
and `--include_imports` options to the protocol compiler `protoc`.
|
||||
|
||||
The `local://` database only covers descriptors linked into the
|
||||
code via C++ libraries, not Python imports. You can link in a proto descriptor
|
||||
by creating a cc_library target with alwayslink=1.
|
||||
|
||||
Both binary and text proto serializations are supported, and can be
|
||||
chosen using the `format` attribute.
|
||||
END
|
||||
}
|
81
tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
Normal file
81
tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
Normal file
@ -0,0 +1,81 @@
|
||||
op {
|
||||
graph_op_name: "EncodeProto"
|
||||
in_arg {
|
||||
name: "sizes"
|
||||
description: <<END
|
||||
Tensor of int32 with shape `[batch_shape, len(field_names)]`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "values"
|
||||
description: <<END
|
||||
List of tensors containing values for the corresponding field.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "bytes"
|
||||
description: <<END
|
||||
Tensor of serialized protos with shape `batch_shape`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "message_type"
|
||||
description: <<END
|
||||
Name of the proto message type to decode.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "field_names"
|
||||
description: <<END
|
||||
List of strings containing proto field names.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "Tinput_types"
|
||||
description: <<END
|
||||
The input types.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
The op serializes protobuf messages provided in the input tensors.
|
||||
END
|
||||
description: <<END
|
||||
The types of the tensors in `values` must match the schema for the
|
||||
fields specified in `field_names`. All the tensors in `values` must
|
||||
have a common shape prefix, *batch_shape*.
|
||||
|
||||
The `sizes` tensor specifies repeat counts for each field. The repeat
|
||||
count (last dimension) of a each tensor in `values` must be greater
|
||||
than or equal to corresponding repeat count in `sizes`.
|
||||
|
||||
A `message_type` name must be provided to give context for the field
|
||||
names. The actual message descriptor can be looked up either in the
|
||||
linked-in descriptor pool or a filename provided by the caller using
|
||||
the `descriptor_source` attribute.
|
||||
|
||||
The `descriptor_source` attribute selects a source of protocol
|
||||
descriptors to consult when looking up `message_type`. This may be a
|
||||
filename containing a serialized `FileDescriptorSet` message,
|
||||
or the special value `local://`, in which case only descriptors linked
|
||||
into the code will be searched; the filename can be on any filesystem
|
||||
accessible to TensorFlow.
|
||||
|
||||
You can build a `descriptor_source` file using the `--descriptor_set_out`
|
||||
and `--include_imports` options to the protocol compiler `protoc`.
|
||||
|
||||
The `local://` database only covers descriptors linked into the
|
||||
code via C++ libraries, not Python imports. You can link in a proto descriptor
|
||||
by creating a cc_library target with alwayslink=1.
|
||||
|
||||
There are a few special cases in the value mapping:
|
||||
|
||||
Submessage and group fields must be pre-serialized as TensorFlow strings.
|
||||
|
||||
TensorFlow lacks support for unsigned int64s, so they must be
|
||||
represented as `tf.int64` with the same twos-complement bit pattern
|
||||
(the obvious way).
|
||||
|
||||
Unsigned int32 values can be represented exactly with `tf.int64`, or
|
||||
with sign wrapping if the input is of type `tf.int32`.
|
||||
END
|
||||
}
|
108
tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt
Normal file
108
tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt
Normal file
@ -0,0 +1,108 @@
|
||||
op {
|
||||
graph_op_name: "Rpc"
|
||||
in_arg {
|
||||
name: "address"
|
||||
description: <<END
|
||||
`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
|
||||
If this tensor has more than 1 element, then multiple parallel rpc requests
|
||||
are sent. This argument broadcasts with `method` and `request`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "method"
|
||||
description: <<END
|
||||
`0-D` or `1-D`. The method address on the RPC server.
|
||||
If this tensor has more than 1 element, then multiple parallel rpc requests
|
||||
are sent. This argument broadcasts with `address` and `request`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "request"
|
||||
description: <<END
|
||||
`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
|
||||
If this tensor has more than 1 element, then multiple parallel rpc requests
|
||||
are sent. This argument broadcasts with `address` and `method`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "response"
|
||||
description: <<END
|
||||
Same shape as `request`. Serialized proto strings: the rpc responses.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "protocol"
|
||||
description: <<END
|
||||
RPC protocol to use. Empty string means use the default protocol.
|
||||
Options include 'grpc'.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "fail_fast"
|
||||
description: <<END
|
||||
`boolean`. If `true` (default), then failures to connect
|
||||
(i.e., the server does not immediately respond) cause an RPC failure.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "timeout_in_ms"
|
||||
description: <<END
|
||||
`int`. If `0` (default), then the kernel will run the RPC
|
||||
request and only time out if the RPC deadline passes or the session times out.
|
||||
If this value is greater than `0`, then the op will raise an exception if
|
||||
the RPC takes longer than `timeout_in_ms`.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
Perform batches of RPC requests.
|
||||
END
|
||||
description: <<END
|
||||
This op asynchronously performs either a single RPC request, or a batch
|
||||
of requests. RPC requests are defined by three main parameters:
|
||||
|
||||
- `address` (the host+port or BNS address of the request)
|
||||
- `method` (the RPC method name for the request)
|
||||
- `request` (the serialized proto string, or vector of strings,
|
||||
of the RPC request argument).
|
||||
|
||||
For example, if you have an RPC service running on port localhost:2345,
|
||||
and its interface is configured with the following proto declaration:
|
||||
|
||||
```
|
||||
service MyService {
|
||||
rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
then call this op with arguments:
|
||||
|
||||
```
|
||||
address = "localhost:2345"
|
||||
method = "MyService/MyMethod"
|
||||
```
|
||||
|
||||
The `request` tensor is a string tensor representing serialized `MyRequestProto`
|
||||
strings; and the output string tensor `response` will have the same shape
|
||||
and contain (upon successful completion) corresponding serialized
|
||||
`MyResponseProto` strings.
|
||||
|
||||
For example, to send a single, empty, `MyRequestProto`, call
|
||||
this op with `request = ""`. To send 5 **parallel** empty requests,
|
||||
call this op with `request = ["", "", "", "", ""]`.
|
||||
|
||||
More generally, one can create a batch of `MyRequestProto` serialized protos
|
||||
from regular batched tensors using the `encode_proto` op, and convert
|
||||
the response `MyResponseProto` serialized protos to batched tensors
|
||||
using the `decode_proto` op.
|
||||
|
||||
**NOTE** Working with serialized proto strings is faster than instantiating
|
||||
actual proto objects in memory, so no performance degradation is expected
|
||||
compared to writing custom kernels for this workflow.
|
||||
|
||||
If the connection fails or the remote worker returns an error
|
||||
status, the op reraises this exception locally.
|
||||
|
||||
See the `TryRpc` op if you prefer to handle RPC failures manually in the graph.
|
||||
END
|
||||
}
|
123
tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt
Normal file
123
tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt
Normal file
@ -0,0 +1,123 @@
|
||||
op {
|
||||
graph_op_name: "TryRpc"
|
||||
in_arg {
|
||||
name: "address"
|
||||
description: <<END
|
||||
`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
|
||||
If this tensor has more than 1 element, then multiple parallel rpc requests
|
||||
are sent. This argument broadcasts with `method` and `request`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "method"
|
||||
description: <<END
|
||||
`0-D` or `1-D`. The method address on the RPC server.
|
||||
If this tensor has more than 1 element, then multiple parallel rpc requests
|
||||
are sent. This argument broadcasts with `address` and `request`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "request"
|
||||
description: <<END
|
||||
`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
|
||||
If this tensor has more than 1 element, then multiple parallel rpc requests
|
||||
are sent. This argument broadcasts with `address` and `method`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "response"
|
||||
description: <<END
|
||||
Same shape as `request`. Serialized proto strings: the rpc responses.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "status_code"
|
||||
description: <<END
|
||||
Same shape as `request`. Values correspond to tensorflow Status enum codes.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "status_message"
|
||||
description: <<END
|
||||
Same shape as `request`. Values correspond to Status messages
|
||||
returned from the RPC calls.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "protocol"
|
||||
description: <<END
|
||||
RPC protocol to use. Empty string means use the default protocol.
|
||||
Options include 'grpc'.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "fail_fast"
|
||||
description: <<END
|
||||
`boolean`. If `true` (default), then failures to connect
|
||||
(i.e., the server does not immediately respond) cause an RPC failure.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "timeout_in_ms"
|
||||
description: <<END
|
||||
`int`. If `0` (default), then the kernel will run the RPC
|
||||
request and only time out if the RPC deadline passes or the session times out.
|
||||
If this value is greater than `0`, then the op will raise an exception if
|
||||
the RPC takes longer than `timeout_in_ms`.
|
||||
END
|
||||
}
|
||||
summary: <<END
|
||||
Perform batches of RPC requests.
|
||||
END
|
||||
description: <<END
|
||||
This op asynchronously performs either a single RPC request, or a batch
|
||||
of requests. RPC requests are defined by three main parameters:
|
||||
|
||||
- `address` (the host+port or BNS address of the request)
|
||||
- `method` (the method name for the request)
|
||||
- `request` (the serialized proto string, or vector of strings,
|
||||
of the RPC request argument).
|
||||
|
||||
For example, if you have an RPC service running on port localhost:2345,
|
||||
and its interface is configured with the following proto declaration:
|
||||
|
||||
```
|
||||
service MyService {
|
||||
rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
then call this op with arguments:
|
||||
|
||||
```
|
||||
address = "localhost:2345"
|
||||
method = "MyService/MyMethod"
|
||||
```
|
||||
|
||||
The `request` tensor is a string tensor representing serialized `MyRequestProto`
|
||||
strings; and the output string tensor `response` will have the same shape
|
||||
and contain (upon successful completion) corresponding serialized
|
||||
`MyResponseProto` strings.
|
||||
|
||||
For example, to send a single, empty, `MyRequestProto`, call
|
||||
this op with `request = ""`. To send 5 **parallel** empty requests,
|
||||
call this op with `request = ["", "", "", "", ""]`.
|
||||
|
||||
More generally, one can create a batch of `MyRequestProto` serialized protos
|
||||
from regular batched tensors using the `encode_proto` op, and convert
|
||||
the response `MyResponseProto` serialized protos to batched tensors
|
||||
using the `decode_proto` op.
|
||||
|
||||
**NOTE** Working with serialized proto strings is faster than instantiating
|
||||
actual proto objects in memory, so no performance degradation is expected
|
||||
compared to writing custom kernels for this workflow.
|
||||
|
||||
Unlike the standard `Rpc` op, if the connection fails or the remote worker
|
||||
returns an error status, this op does **not** reraise the exception.
|
||||
Instead, the `status_code` and `status_message` entry for the corresponding RPC
|
||||
call is set with the error returned from the RPC call. The `response` tensor
|
||||
will contain valid response values for those minibatch entries whose RPCs did
|
||||
not fail; the rest of the entries will have empty strings.
|
||||
END
|
||||
}
|
@ -499,3 +499,33 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_rpc_factory",
|
||||
srcs = [
|
||||
"grpc_rpc_factory.cc",
|
||||
],
|
||||
hdrs = ["grpc_rpc_factory.h"],
|
||||
deps = [
|
||||
":grpc_state",
|
||||
":grpc_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/util/rpc:call_container",
|
||||
"//tensorflow/core/util/rpc:rpc_factory",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_rpc_factory_registration",
|
||||
srcs = [
|
||||
"grpc_rpc_factory_registration.cc",
|
||||
],
|
||||
deps = [
|
||||
":grpc_rpc_factory",
|
||||
"//tensorflow/core/util/rpc:rpc_factory",
|
||||
"//tensorflow/core/util/rpc:rpc_factory_registry",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
213
tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
Normal file
213
tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
Normal file
@ -0,0 +1,213 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/util/rpc/call_container.h"
|
||||
#include "tensorflow/core/util/rpc/rpc_factory.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
class GrpcCall {
|
||||
public:
|
||||
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
|
||||
const string* request_msg, string* response_msg,
|
||||
int32* status_code, string* status_message)
|
||||
: container_(container),
|
||||
index_(index),
|
||||
try_rpc_(try_rpc),
|
||||
request_msg_(request_msg),
|
||||
response_msg_(response_msg),
|
||||
status_code_(status_code),
|
||||
status_message_(status_message) {}
|
||||
|
||||
void StartCancel() { call_opts_.StartCancel(); }
|
||||
|
||||
void Done(const Status& s) {
|
||||
DCHECK(container_ != nullptr);
|
||||
if (!s.ok() && try_rpc_) {
|
||||
DCHECK(status_code_ != nullptr);
|
||||
DCHECK(status_message_ != nullptr);
|
||||
*status_code_ = s.code();
|
||||
*status_message_ = s.error_message();
|
||||
}
|
||||
container_->Done(s, index_);
|
||||
}
|
||||
|
||||
const string& request() const { return *request_msg_; }
|
||||
string* response() const { return response_msg_; }
|
||||
CallOptions* call_opts() { return &call_opts_; }
|
||||
|
||||
private:
|
||||
CallContainer<GrpcCall>* const container_;
|
||||
const int index_;
|
||||
bool try_rpc_;
|
||||
CallOptions call_opts_;
|
||||
const string* request_msg_;
|
||||
string* response_msg_;
|
||||
int* status_code_;
|
||||
string* status_message_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
|
||||
int64 timeout_in_ms)
|
||||
: RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) {
|
||||
// TODO(ebrevdo): Investigate possible performance improvements by
|
||||
// replacing this thread with a threadpool.
|
||||
polling_thread_ =
|
||||
ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() {
|
||||
void* tag;
|
||||
bool ok;
|
||||
while (completion_queue_.Next(&tag, &ok)) {
|
||||
GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
|
||||
callback_tag->OnCompleted(ok);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
GrpcRPCFactory::~GrpcRPCFactory() {
|
||||
// The amount of time we wait depends on several parameters, including:
|
||||
// - the value of the fail_fast attribute.
|
||||
// - the timeout option of the rpc call in the proto declaration.
|
||||
// - the network roundtrip time and service's execution time.
|
||||
//
|
||||
// If a connection is made but the service doesn't ever respond, and
|
||||
// there is no timeout option set for this rpc call, then it is
|
||||
// possible the RPC request will wait forever.
|
||||
//
|
||||
completion_queue_.Shutdown();
|
||||
delete polling_thread_;
|
||||
}
|
||||
|
||||
void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
|
||||
const Tensor& address_t, const Tensor& method_t,
|
||||
const Tensor& request_t, const bool try_rpc,
|
||||
Tensor* response_t, Tensor* status_code_t,
|
||||
Tensor* status_message_t,
|
||||
AsyncOpKernel::DoneCallback done) {
|
||||
auto address = address_t.flat<string>();
|
||||
auto method = method_t.flat<string>();
|
||||
auto request = request_t.flat<string>();
|
||||
|
||||
// Stubs are maintained by the GrpcRPCFactory class and will be
|
||||
// deleted when the class is destroyed.
|
||||
::grpc::GenericStub* singleton_stub = nullptr;
|
||||
if (address.size() == 1) {
|
||||
singleton_stub = GetOrCreateStubForAddress(address(0));
|
||||
}
|
||||
auto get_stub = [&address, this,
|
||||
singleton_stub](int64 ix) -> ::grpc::GenericStub* {
|
||||
return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
|
||||
: singleton_stub;
|
||||
};
|
||||
auto get_method_ptr = [&method](int64 ix) -> const string* {
|
||||
return (method.size() > 1) ? &(method(ix)) : &(method(0));
|
||||
};
|
||||
auto get_request_ptr = [&request](int64 ix) -> const string* {
|
||||
return (request.size() > 1) ? &(request(ix)) : &(request(0));
|
||||
};
|
||||
|
||||
if (try_rpc) {
|
||||
// In this case status_code will never be set in the response,
|
||||
// so we just set it to OK.
|
||||
DCHECK(status_code_t != nullptr);
|
||||
status_code_t->flat<int32>().setConstant(
|
||||
static_cast<int>(errors::Code::OK));
|
||||
}
|
||||
|
||||
CancellationManager* cm = ctx->cancellation_manager();
|
||||
CancellationToken cancellation_token = cm->get_cancellation_token();
|
||||
|
||||
// This object will delete itself when done.
|
||||
auto* container =
|
||||
new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
|
||||
std::move(done), cancellation_token);
|
||||
|
||||
auto response = response_t->flat<string>();
|
||||
int32* status_code_ptr = nullptr;
|
||||
string* status_message_ptr = nullptr;
|
||||
if (try_rpc) {
|
||||
status_code_ptr = status_code_t->flat<int32>().data();
|
||||
status_message_ptr = status_message_t->flat<string>().data();
|
||||
}
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
container->calls()->emplace_back(
|
||||
container, i, try_rpc, get_request_ptr(i), &response(i),
|
||||
(try_rpc) ? &status_code_ptr[i] : nullptr,
|
||||
(try_rpc) ? &status_message_ptr[i] : nullptr);
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (GrpcCall& call : *(container->calls())) {
|
||||
// This object will delete itself when done.
|
||||
new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i),
|
||||
call.request(), call.response(),
|
||||
/*done=*/[&call](const Status& s) { call.Done(s); },
|
||||
call.call_opts(), fail_fast_, timeout_in_ms_);
|
||||
++i;
|
||||
}
|
||||
|
||||
// Need to register this callback after all the RPCs are in
|
||||
// flight; otherwise we may try to cancel an RPC *before* it
|
||||
// launches, which is a no-op, and then fall into a deadlock.
|
||||
bool is_cancelled = !cm->RegisterCallback(
|
||||
cancellation_token, [container]() { container->StartCancel(); });
|
||||
|
||||
if (is_cancelled) {
|
||||
ctx->SetStatus(errors::Cancelled("Operation has been cancelled."));
|
||||
// container's reference counter will take care of calling done().
|
||||
container->StartCancel();
|
||||
}
|
||||
}
|
||||
|
||||
::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
|
||||
const string& address) {
|
||||
mutex_lock lock(mu_);
|
||||
|
||||
auto stub = stubs_.find(address);
|
||||
if (stub != stubs_.end()) return stub->second.get();
|
||||
|
||||
ChannelPtr channel = CreateChannelForAddress(address);
|
||||
auto* created = new ::grpc::GenericStub(channel);
|
||||
stubs_[address].reset(created);
|
||||
return created;
|
||||
}
|
||||
|
||||
GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
|
||||
const string& address) {
|
||||
::grpc::ChannelArguments args;
|
||||
args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
|
||||
|
||||
// Set a standard backoff timeout of 1s instead of the
|
||||
// (sometimes default) 20s.
|
||||
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
|
||||
return ::grpc::CreateCustomChannel(
|
||||
/*target=*/address, ::grpc::InsecureChannelCredentials(), args);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
59
tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
Normal file
59
tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/util/rpc/rpc_factory.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class GrpcRPCFactory : public RPCFactory {
|
||||
public:
|
||||
explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
|
||||
int64 timeout_in_ms);
|
||||
|
||||
// Explicit destructor to control destruction order.
|
||||
~GrpcRPCFactory() override;
|
||||
|
||||
void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t,
|
||||
const Tensor& method_t, const Tensor& request_t, const bool try_rpc,
|
||||
Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t,
|
||||
AsyncOpKernel::DoneCallback done) override;
|
||||
|
||||
protected:
|
||||
typedef std::shared_ptr<::grpc::Channel> ChannelPtr;
|
||||
virtual ChannelPtr CreateChannelForAddress(const string& address);
|
||||
|
||||
private:
|
||||
::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
|
||||
|
||||
bool fail_fast_;
|
||||
int64 timeout_in_ms_;
|
||||
::grpc::CompletionQueue completion_queue_;
|
||||
Thread* polling_thread_; // Owned.
|
||||
|
||||
mutex mu_;
|
||||
typedef std::unique_ptr<::grpc::GenericStub> StubPtr;
|
||||
std::unordered_map<string, StubPtr> stubs_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
|
@ -0,0 +1,34 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
|
||||
#include "tensorflow/core/util/rpc/rpc_factory.h"
|
||||
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Used for adding the grpc factory to the RPC factory registry.
|
||||
struct Value {
|
||||
static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
|
||||
int64 timeout_in_ms) {
|
||||
return new GrpcRPCFactory(ctx, fail_fast, timeout_in_ms);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_RPC_FACTORY("grpc", Value::Function);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -5121,6 +5121,9 @@ filegroup(
|
||||
"summary_interface.*",
|
||||
"summary_kernels.*",
|
||||
"spectrogram_convert_test_data.cc",
|
||||
"decode_proto_op.cc",
|
||||
"encode_proto_op.cc",
|
||||
"rpc_op.cc",
|
||||
# Excluded due to experimental status:
|
||||
"debug_ops.*",
|
||||
"scatter_nd_op*",
|
||||
@ -6153,6 +6156,50 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "decode_proto_op",
|
||||
srcs = [
|
||||
"decode_proto_op.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:decode_proto_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/util/proto:decode",
|
||||
"//tensorflow/core/util/proto:descriptors",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "encode_proto_op",
|
||||
srcs = ["encode_proto_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:encode_proto_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/util/proto:descriptors",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "rpc_op",
|
||||
srcs = [
|
||||
"rpc_op.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:rpc_ops_op_lib",
|
||||
"//tensorflow/core/util/rpc:call_container",
|
||||
"//tensorflow/core/util/rpc:rpc_factory",
|
||||
"//tensorflow/core/util/rpc:rpc_factory_registry",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google-internal targets. These must be at the end for syncrepo.
|
||||
|
||||
|
1011
tensorflow/core/kernels/decode_proto_op.cc
Normal file
1011
tensorflow/core/kernels/decode_proto_op.cc
Normal file
File diff suppressed because it is too large
Load Diff
591
tensorflow/core/kernels/encode_proto_op.cc
Normal file
591
tensorflow/core/kernels/encode_proto_op.cc
Normal file
@ -0,0 +1,591 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// EncodeProto is a TensorFlow Op which serializes tensors into
|
||||
// arbitrary protobufs.
|
||||
//
|
||||
// See the docstring in ../ops/encode_proto_op.cc for usage of the op.
|
||||
//
|
||||
// This implementation writes the serialized format using a handful of
|
||||
// calls from the WireFormatLite API.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/util/proto/descriptors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::protobuf::Descriptor;
|
||||
using ::tensorflow::protobuf::DescriptorPool;
|
||||
using ::tensorflow::protobuf::FieldDescriptor;
|
||||
using ::tensorflow::protobuf::internal::WireFormatLite;
|
||||
using ::tensorflow::protobuf::io::CodedOutputStream;
|
||||
using ::tensorflow::protobuf::io::StringOutputStream;
|
||||
|
||||
// Computes the total serialized size for a packed repeated field.
|
||||
// For fixed-size types this can just multiply, but for variable-sized
|
||||
// types it has to iterate through the values in the tensor.
|
||||
template <WireFormatLite::FieldType FieldType, typename TensorT>
|
||||
size_t TotalPackedSize(const Tensor& input, int message_index, int size);
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
return size * WireFormatLite::kDoubleSize;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
return size * WireFormatLite::kFloatSize;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
return size * WireFormatLite::kFloatSize;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
size_t data_size = 0;
|
||||
auto input_t = input.flat_inner_dims<int64>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
data_size += WireFormatLite::Int64Size(
|
||||
input_t(static_cast<int64>(message_index), i));
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
size_t data_size = 0;
|
||||
auto input_t = input.flat_inner_dims<int64>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
data_size += WireFormatLite::UInt64Size(
|
||||
input_t(static_cast<int64>(message_index), i));
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
size_t data_size = 0;
|
||||
auto input_t = input.flat_inner_dims<int32>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
data_size += WireFormatLite::Int32Size(
|
||||
input_t(static_cast<int64>(message_index), i));
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, int64>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
return size * WireFormatLite::kFixed64Size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int64>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
return size * WireFormatLite::kFixed32Size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int32>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
return size * WireFormatLite::kFixed32Size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
return size * WireFormatLite::kBoolSize;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
size_t data_size = 0;
|
||||
auto input_t = input.flat_inner_dims<int64>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
data_size += WireFormatLite::UInt32Size(
|
||||
input_t(static_cast<int64>(message_index), i));
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int32>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
size_t data_size = 0;
|
||||
auto input_t = input.flat_inner_dims<int32>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
data_size += WireFormatLite::UInt32Size(
|
||||
input_t(static_cast<int64>(message_index), i));
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
size_t data_size = 0;
|
||||
auto input_t = input.flat_inner_dims<int32>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
data_size +=
|
||||
WireFormatLite::EnumSize(input_t(static_cast<int64>(message_index), i));
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>(
|
||||
const Tensor& input, int message_index, int size) {
|
||||
return size * WireFormatLite::kSFixed32Size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>(
|
||||
const Tensor& input, int message_index, int size) {
|
||||
return size * WireFormatLite::kSFixed64Size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
size_t data_size = 0;
|
||||
auto input_t = input.flat_inner_dims<int32>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
data_size += WireFormatLite::SInt32Size(
|
||||
input_t(static_cast<int64>(message_index), i));
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input,
|
||||
int message_index,
|
||||
int size) {
|
||||
size_t data_size = 0;
|
||||
auto input_t = input.flat_inner_dims<int64>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
data_size += WireFormatLite::SInt64Size(
|
||||
input_t(static_cast<int64>(message_index), i));
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
// Writes a possibly repeated primitive field.
|
||||
// TensorFlow does not have unsigned types, so we decode them to signed and
|
||||
// encode them back to unsigned.
|
||||
template <typename TensorT, typename ProtoT,
|
||||
WireFormatLite::FieldType FieldType,
|
||||
void Writer(ProtoT, CodedOutputStream*)>
|
||||
void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
|
||||
int message_index, int size, CodedOutputStream* output) {
|
||||
auto wire_type = WireFormatLite::WireTypeForFieldType(
|
||||
WireFormatLite::FieldType(field_desc.type()));
|
||||
|
||||
auto input_t = input.flat_inner_dims<TensorT>();
|
||||
if (field_desc.options().packed()) {
|
||||
// Write the tag for the packed field.
|
||||
WireFormatLite::WriteTag(field_desc.number(),
|
||||
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
|
||||
|
||||
// Write the total packed length.
|
||||
size_t data_size =
|
||||
TotalPackedSize<FieldType, TensorT>(input, message_index, size);
|
||||
output->WriteVarint32(data_size);
|
||||
|
||||
// Write individual values.
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
// Note implicit cast from signed to unsigned.
|
||||
const ProtoT& value = input_t(static_cast<int64>(message_index), i);
|
||||
Writer(value, output);
|
||||
}
|
||||
} else {
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
WireFormatLite::WriteTag(field_desc.number(), wire_type, output);
|
||||
|
||||
// Note implicit cast from signed to unsigned.
|
||||
const ProtoT& value = input_t(static_cast<int64>(message_index), i);
|
||||
Writer(value, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Writes a possibly repeated string, bytes, or message field.
|
||||
template <typename T, void Writer(int, const T&, CodedOutputStream*)>
|
||||
void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
|
||||
int message_index, int size, CodedOutputStream* output) {
|
||||
auto input_t = input.flat_inner_dims<T>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
const T& value = input_t(static_cast<int64>(message_index), i);
|
||||
// TODO(nix): there doesn't seem to be an inlined version of
|
||||
// WireFormatLite::WriteString or its relatives, which might allow a
|
||||
// small speedup.
|
||||
Writer(field_desc.number(), value, output);
|
||||
}
|
||||
}
|
||||
|
||||
// Writes a group field.
|
||||
// Groups are treated like submessages, but tag-delimited
|
||||
// instead of length-delimited. WireFormatLite handles this
|
||||
// differently so we code it ourselves.
|
||||
void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
|
||||
int message_index, int size, CodedOutputStream* output) {
|
||||
auto input_t = input.flat_inner_dims<string>();
|
||||
for (int64 i = 0; i < size; i++) {
|
||||
const string& value = input_t(static_cast<int64>(message_index), i);
|
||||
WireFormatLite::WriteTag(field_desc.number(),
|
||||
WireFormatLite::WIRETYPE_START_GROUP, output);
|
||||
// Note the use of WriteRaw instead of WriteString to skip the length.
|
||||
output->WriteRaw(value.data(), value.size());
|
||||
WireFormatLite::WriteTag(field_desc.number(),
|
||||
WireFormatLite::WIRETYPE_END_GROUP, output);
|
||||
}
|
||||
}
|
||||
|
||||
// Writes a (possibly repeated) field into an output stream.
|
||||
// It is the caller's responsibility to ensure that the type of
|
||||
// the input tensor is compatible with the type of the proto
|
||||
// field descriptor, and that (message_index, size-1) is within
|
||||
// bounds.
|
||||
void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
|
||||
int message_index, int size, CodedOutputStream* output) {
|
||||
DataType tf_type = input.dtype();
|
||||
|
||||
switch (field_desc.type()) {
|
||||
case WireFormatLite::TYPE_DOUBLE:
|
||||
return WriteField<double, double, WireFormatLite::TYPE_DOUBLE,
|
||||
WireFormatLite::WriteDoubleNoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_FLOAT:
|
||||
switch (tf_type) {
|
||||
case DataType::DT_FLOAT:
|
||||
return WriteField<float, float, WireFormatLite::TYPE_FLOAT,
|
||||
WireFormatLite::WriteFloatNoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case DataType::DT_DOUBLE:
|
||||
return WriteField<double, float, WireFormatLite::TYPE_FLOAT,
|
||||
WireFormatLite::WriteFloatNoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
default:
|
||||
return;
|
||||
}
|
||||
case WireFormatLite::TYPE_INT64:
|
||||
return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64,
|
||||
WireFormatLite::WriteInt64NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
|
||||
WireFormatLite::WriteUInt64NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
|
||||
WireFormatLite::WriteInt32NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_FIXED64:
|
||||
return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
|
||||
WireFormatLite::WriteFixed64NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_FIXED32:
|
||||
switch (tf_type) {
|
||||
case DataType::DT_INT64:
|
||||
return WriteField<int64, uint32, WireFormatLite::TYPE_FIXED32,
|
||||
WireFormatLite::WriteFixed32NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case DataType::DT_INT32:
|
||||
return WriteField<int32, uint32, WireFormatLite::TYPE_FIXED32,
|
||||
WireFormatLite::WriteFixed32NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
default:
|
||||
return;
|
||||
}
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
return WriteField<bool, bool, WireFormatLite::TYPE_BOOL,
|
||||
WireFormatLite::WriteBoolNoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
return WriteVarLenField<string, WireFormatLite::WriteString>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_GROUP:
|
||||
return WriteGroup(field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_MESSAGE:
|
||||
return WriteVarLenField<string, WireFormatLite::WriteBytes>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_BYTES:
|
||||
return WriteVarLenField<string, WireFormatLite::WriteBytes>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_UINT32:
|
||||
switch (tf_type) {
|
||||
case DataType::DT_INT64:
|
||||
return WriteField<int64, uint32, WireFormatLite::TYPE_UINT32,
|
||||
WireFormatLite::WriteUInt32NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case DataType::DT_INT32:
|
||||
return WriteField<int32, uint32, WireFormatLite::TYPE_UINT32,
|
||||
WireFormatLite::WriteUInt32NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
default:
|
||||
return;
|
||||
}
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
return WriteField<int32, int32, WireFormatLite::TYPE_ENUM,
|
||||
WireFormatLite::WriteEnumNoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_SFIXED32:
|
||||
return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
|
||||
WireFormatLite::WriteSFixed32NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_SFIXED64:
|
||||
return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64,
|
||||
WireFormatLite::WriteSFixed64NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
|
||||
WireFormatLite::WriteSInt32NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
case WireFormatLite::TYPE_SINT64:
|
||||
return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64,
|
||||
WireFormatLite::WriteSInt64NoTag>(
|
||||
field_desc, input, message_index, size, output);
|
||||
// default: intentionally omitted in order to enable static checking.
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that a Protobuf field is compatible with a TensorFlow datatype.
|
||||
// This is separated from WriteField to lift it out of the inner loop.
|
||||
bool IsCompatibleType(const FieldDescriptor& field_desc, DataType tf_type) {
|
||||
switch (field_desc.type()) {
|
||||
case WireFormatLite::TYPE_DOUBLE:
|
||||
return tf_type == DataType::DT_DOUBLE;
|
||||
case WireFormatLite::TYPE_FLOAT:
|
||||
return tf_type == DataType::DT_FLOAT || tf_type == DataType::DT_DOUBLE;
|
||||
case WireFormatLite::TYPE_INT64:
|
||||
case WireFormatLite::TYPE_SFIXED64:
|
||||
case WireFormatLite::TYPE_SINT64:
|
||||
return tf_type == DataType::DT_INT64;
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
return tf_type == DataType::DT_INT64;
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
case WireFormatLite::TYPE_SFIXED32:
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
return tf_type == DataType::DT_INT32;
|
||||
case WireFormatLite::TYPE_FIXED64:
|
||||
return tf_type == DataType::DT_INT64;
|
||||
case WireFormatLite::TYPE_FIXED32:
|
||||
case WireFormatLite::TYPE_UINT32:
|
||||
return tf_type == DataType::DT_INT64 || tf_type == DataType::DT_INT32;
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
return tf_type == DataType::DT_BOOL;
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
case WireFormatLite::TYPE_GROUP:
|
||||
case WireFormatLite::TYPE_MESSAGE:
|
||||
case WireFormatLite::TYPE_BYTES:
|
||||
return tf_type == DataType::DT_STRING;
|
||||
// default: intentionally omitted in order to enable static checking.
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
class EncodeProtoOp : public OpKernel {
|
||||
public:
|
||||
explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
string descriptor_source;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("descriptor_source", &descriptor_source));
|
||||
// We always get back a desc_pool, but we may not own it. If we own it,
|
||||
// owned_desc_pool_ will be filled in.
|
||||
DescriptorPool const* desc_pool;
|
||||
OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
|
||||
&desc_pool, &owned_desc_pool_));
|
||||
|
||||
string message_type;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
|
||||
const Descriptor* message_desc =
|
||||
desc_pool->FindMessageTypeByName(message_type);
|
||||
OP_REQUIRES(context, message_desc != nullptr,
|
||||
errors::InvalidArgument("No descriptor found for message type ",
|
||||
message_type));
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names_));
|
||||
|
||||
// Gather the field descriptors for the given field_names.
|
||||
field_descs_.resize(field_names_.size());
|
||||
for (int i = 0; i < field_names_.size(); i++) {
|
||||
const string& name = field_names_[i];
|
||||
auto field_desc = message_desc->FindFieldByName(name);
|
||||
OP_REQUIRES(context, field_desc != nullptr,
|
||||
errors::InvalidArgument("Unknown field: ", name,
|
||||
" in message type ", message_type));
|
||||
|
||||
field_descs_[i] = field_desc;
|
||||
}
|
||||
|
||||
// Build a list of indices into field_descs sorted by increasing
|
||||
// field_number. This will be used to output fields in sorted order,
|
||||
// which is strongly encouraged when serializing protobufs.
|
||||
sorted_field_index_.resize(field_names_.size());
|
||||
// Start with the fields sorted by current index.
|
||||
for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i;
|
||||
// Then sort the field indices by their proto field number.
|
||||
std::sort(sorted_field_index_.begin(), sorted_field_index_.end(),
|
||||
[this](int a, int b) -> bool {
|
||||
return field_descs_[a]->number() < field_descs_[b]->number();
|
||||
});
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* cx) override {
|
||||
const Tensor* sizes_tensor;
|
||||
OP_REQUIRES_OK(cx, cx->input("sizes", &sizes_tensor));
|
||||
|
||||
OpInputList values;
|
||||
OP_REQUIRES_OK(cx, cx->input_list("values", &values));
|
||||
|
||||
OP_REQUIRES(cx, field_descs_.size() == values.size(),
|
||||
errors::InvalidArgument(
|
||||
"Length of inputs list must match field_names"));
|
||||
|
||||
// Check the arguments for consistency.
|
||||
TensorShape common_prefix;
|
||||
int message_count;
|
||||
for (int i = 0; i < field_descs_.size(); i++) {
|
||||
const Tensor& v = values[i];
|
||||
|
||||
// The type of each value tensor must match the corresponding field.
|
||||
OP_REQUIRES(cx, IsCompatibleType(*field_descs_[i], v.dtype()),
|
||||
errors::InvalidArgument(
|
||||
"Incompatible type for field " + field_names_[i] +
|
||||
". Saw dtype: ",
|
||||
DataTypeString(v.dtype()),
|
||||
" but field type is: ", field_descs_[i]->type_name()));
|
||||
|
||||
// All value tensors must have the same shape prefix (i.e. batch size).
|
||||
TensorShape shape_prefix = v.shape();
|
||||
shape_prefix.RemoveDim(shape_prefix.dims() - 1);
|
||||
|
||||
// Do some initialization on the first input value. The rest will
|
||||
// have to match this one.
|
||||
if (i == 0) {
|
||||
OP_REQUIRES(cx, v.dims() >= 1,
|
||||
errors::InvalidArgument(
|
||||
"Expected value to be at least a vector, saw shape: ",
|
||||
v.shape().DebugString()));
|
||||
common_prefix = shape_prefix;
|
||||
message_count = common_prefix.num_elements();
|
||||
} else {
|
||||
OP_REQUIRES(cx, shape_prefix == common_prefix,
|
||||
errors::InvalidArgument(
|
||||
"Values must match up to the last dimension"));
|
||||
}
|
||||
}
|
||||
|
||||
TensorShape expected_sizes_shape = common_prefix;
|
||||
expected_sizes_shape.AddDim(field_descs_.size());
|
||||
|
||||
OP_REQUIRES(cx, sizes_tensor->shape() == expected_sizes_shape,
|
||||
errors::InvalidArgument(
|
||||
"sizes should be batch_size + [len(field_names)]. Saw: ",
|
||||
sizes_tensor->shape().DebugString(),
|
||||
" but expected: ", expected_sizes_shape.DebugString()));
|
||||
|
||||
auto sizes = sizes_tensor->flat_inner_dims<int32>();
|
||||
|
||||
for (int i = 0; i < field_descs_.size(); ++i) {
|
||||
const Tensor& v = values[i];
|
||||
int max_size = v.dim_size(v.dims() - 1);
|
||||
|
||||
// The last dimension of a value tensor must be greater than the
|
||||
// corresponding
|
||||
// size in the sizes tensor.
|
||||
for (int message_index = 0; message_index < message_count;
|
||||
message_index++) {
|
||||
OP_REQUIRES(
|
||||
cx, sizes(message_index, i) <= max_size,
|
||||
errors::InvalidArgument(
|
||||
"Size to write must not be larger than value tensor; but saw: ",
|
||||
sizes(message_index, i), " > ", max_size, " at message ",
|
||||
message_index, " field ", i));
|
||||
}
|
||||
}
|
||||
|
||||
// This pointer is owned by the context.
|
||||
Tensor* output_tensor;
|
||||
OP_REQUIRES_OK(cx, cx->allocate_output(0, common_prefix, &output_tensor));
|
||||
|
||||
auto bufs = output_tensor->flat<string>();
|
||||
for (int message_index = 0; message_index < message_count;
|
||||
message_index++) {
|
||||
// TODO(nix): possibly optimize allocation here by calling
|
||||
// bufs(message_index).reserve(DEFAULT_BUF_SIZE);
|
||||
StringOutputStream output_string(&bufs(message_index));
|
||||
CodedOutputStream out(&output_string);
|
||||
// Write fields in ascending field_number order.
|
||||
for (int i : sorted_field_index_) {
|
||||
auto& field_desc = *field_descs_[i];
|
||||
const Tensor& v = values[i];
|
||||
int size = sizes(message_index, i);
|
||||
if (!size) continue;
|
||||
WriteField(field_desc, v, message_index, size, &out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<string> field_names_;
|
||||
std::vector<const FieldDescriptor*> field_descs_;
|
||||
|
||||
// Owned_desc_pool_ is null when using descriptor_source=local.
|
||||
std::unique_ptr<DescriptorPool> owned_desc_pool_;
|
||||
|
||||
// Contains indices into field_names_, sorted by field number since
|
||||
// that's the order of writing.
|
||||
std::vector<int> sorted_field_index_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("EncodeProto").Device(DEVICE_CPU), EncodeProtoOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
129
tensorflow/core/kernels/rpc_op.cc
Normal file
129
tensorflow/core/kernels/rpc_op.cc
Normal file
@ -0,0 +1,129 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// RpcOp is a TensorFlow op that sends and receives arbitrary messages.
|
||||
//
|
||||
// See docs in ../ops/rpc_op.cc.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/rpc/call_container.h"
|
||||
#include "tensorflow/core/util/rpc/rpc_factory.h"
|
||||
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RpcOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit RpcOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("protocol", &protocol_));
|
||||
OP_REQUIRES(context, !protocol_.empty(),
|
||||
errors::InvalidArgument("protocol must be non-empty."));
|
||||
bool fail_fast;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("fail_fast", &fail_fast));
|
||||
int64 timeout_in_ms;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("timeout_in_ms", &timeout_in_ms));
|
||||
|
||||
RPCFactoryRegistry::RPCFactoryFn* rpc_factory_fn =
|
||||
RPCFactoryRegistry::Global()->Get(protocol_);
|
||||
OP_REQUIRES(context, rpc_factory_fn != nullptr,
|
||||
errors::InvalidArgument("The protocol ", protocol_,
|
||||
" was not recognized."));
|
||||
|
||||
rpc_factory_.reset((*rpc_factory_fn)(context, fail_fast, timeout_in_ms));
|
||||
}
|
||||
|
||||
~RpcOp() override {}
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
const Tensor& address_t = ctx->input(0);
|
||||
const Tensor& method_t = ctx->input(1);
|
||||
const Tensor& request_t = ctx->input(2);
|
||||
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, address_t.dims() == 0 || address_t.dims() == 1,
|
||||
errors::InvalidArgument("address must be a scalar or vector."), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, method_t.dims() == 0 || method_t.dims() == 1,
|
||||
errors::InvalidArgument("method must be a scalar or vector."), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, request_t.dims() == 0 || request_t.dims() == 1,
|
||||
errors::InvalidArgument("request must be a scalar or vector."), done);
|
||||
|
||||
TensorShape output_shape({});
|
||||
for (const Tensor& t : {address_t, method_t, request_t}) {
|
||||
if (t.dims() == 1) {
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx,
|
||||
output_shape.dims() == 0 ||
|
||||
output_shape.dim_size(0) == t.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"Input vector shapes don't match: ", output_shape.DebugString(),
|
||||
" vs. ", t.shape().DebugString()),
|
||||
done);
|
||||
output_shape = t.shape();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor* response_t;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, ctx->allocate_output(0, output_shape, &response_t), done);
|
||||
|
||||
const bool try_rpc = (ctx->num_outputs() > 1);
|
||||
|
||||
Tensor* status_code_t = nullptr;
|
||||
Tensor* status_message_t = nullptr;
|
||||
if (try_rpc) {
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, ctx->allocate_output(1, output_shape, &status_code_t), done);
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, ctx->allocate_output(2, output_shape, &status_message_t), done);
|
||||
}
|
||||
|
||||
if (request_t.NumElements() == 0) {
|
||||
// Special case, we finished early!
|
||||
done();
|
||||
return;
|
||||
}
|
||||
|
||||
int64 num_elements = output_shape.num_elements();
|
||||
|
||||
rpc_factory_->Call(ctx, num_elements, address_t, method_t, request_t,
|
||||
try_rpc, response_t, status_code_t, status_message_t,
|
||||
std::move(done));
|
||||
}
|
||||
|
||||
private:
|
||||
string protocol_;
|
||||
std::unique_ptr<RPCFactory> rpc_factory_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Rpc").Device(DEVICE_CPU), RpcOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("TryRpc").Device(DEVICE_CPU), RpcOp);
|
||||
|
||||
} // namespace tensorflow
|
67
tensorflow/core/ops/decode_proto_ops.cc
Normal file
67
tensorflow/core/ops/decode_proto_ops.cc
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using tensorflow::shape_inference::InferenceContext;
|
||||
using tensorflow::shape_inference::ShapeHandle;
|
||||
|
||||
REGISTER_OP("DecodeProtoV2")
|
||||
.Input("bytes: string")
|
||||
.Attr("message_type: string")
|
||||
.Attr("field_names: list(string)")
|
||||
.Attr("output_types: list(type) >= 0")
|
||||
.Attr("descriptor_source: string = 'local://'")
|
||||
.Attr("message_format: string = 'binary'")
|
||||
.Attr("sanitize: bool = false")
|
||||
.Output("sizes: int32")
|
||||
.Output("values: output_types")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input = c->input(0);
|
||||
|
||||
std::vector<tensorflow::DataType> output_types;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("output_types", &output_types));
|
||||
|
||||
ShapeHandle sizes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Concatenate(input, c->Vector(output_types.size()), &sizes));
|
||||
c->set_output(0, sizes);
|
||||
|
||||
// TODO(nix): to do the best possible job of shape inference, we
|
||||
// should examine the proto descriptors here in order to set shape
|
||||
// indices to 1 instead of unknown for optional or required fields.
|
||||
// Any general-purpose code will have to handle the unknown case,
|
||||
// but there might be XLA code that could be sped up with the additional
|
||||
// knowledge.
|
||||
for (int i = 0; i < output_types.size(); ++i) {
|
||||
ShapeHandle values;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Concatenate(input, c->Vector(c->UnknownDim()), &values));
|
||||
c->set_output(i + 1, values);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// TODO(nix): Consider adding an additional input argument that truncates
|
||||
// repeated fields to a maximum count. For now this could be done by passing
|
||||
// the output through tf.slice.
|
||||
|
||||
// TODO(nix): define missing value behavior.
|
||||
|
||||
} // namespace tensorflow
|
49
tensorflow/core/ops/encode_proto_ops.cc
Normal file
49
tensorflow/core/ops/encode_proto_ops.cc
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using tensorflow::shape_inference::InferenceContext;
|
||||
using tensorflow::shape_inference::ShapeHandle;
|
||||
|
||||
REGISTER_OP("EncodeProto")
|
||||
.Input("sizes: int32")
|
||||
.Input("values: Tinput_types")
|
||||
.Attr("field_names: list(string)")
|
||||
.Attr("message_type: string")
|
||||
.Attr("descriptor_source: string = 'local://'")
|
||||
.Attr("Tinput_types: list(type)")
|
||||
.Output("bytes: string")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
int first_field_index = 1;
|
||||
int num_fields = c->num_inputs() - 1;
|
||||
|
||||
ShapeHandle output;
|
||||
for (int i = num_fields - 1; i >= 0; --i) {
|
||||
ShapeHandle input = c->input(first_field_index + i);
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input));
|
||||
ShapeHandle inner;
|
||||
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &inner));
|
||||
TF_RETURN_IF_ERROR(c->Merge(inner, output, &output));
|
||||
}
|
||||
|
||||
c->set_output(0, output);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
} // namespace tensorflow
|
81
tensorflow/core/ops/rpc_ops.cc
Normal file
81
tensorflow/core/ops/rpc_ops.cc
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using tensorflow::shape_inference::DimensionHandle;
|
||||
using tensorflow::shape_inference::InferenceContext;
|
||||
using tensorflow::shape_inference::ShapeHandle;
|
||||
|
||||
Status RpcShapeOp(InferenceContext* c, bool try_rpc) {
|
||||
ShapeHandle address;
|
||||
ShapeHandle method;
|
||||
ShapeHandle request;
|
||||
ShapeHandle output;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &address));
|
||||
if (c->Rank(address) == 1) {
|
||||
TF_RETURN_IF_ERROR(c->Merge(output, address, &output));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &method));
|
||||
if (c->Rank(method) == 1) {
|
||||
TF_RETURN_IF_ERROR(c->Merge(output, method, &output));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &request));
|
||||
if (c->Rank(request) == 1) {
|
||||
TF_RETURN_IF_ERROR(c->Merge(output, request, &output));
|
||||
}
|
||||
if (!c->RankKnown(output)) {
|
||||
output = request;
|
||||
}
|
||||
c->set_output(0, output); // response
|
||||
if (try_rpc) {
|
||||
c->set_output(1, output); // status_code
|
||||
c->set_output(2, output); // status_message
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_OP("Rpc")
|
||||
.Input("address: string")
|
||||
.Input("method: string")
|
||||
.Input("request: string")
|
||||
.Attr("protocol: string = ''")
|
||||
.Attr("fail_fast: bool = true")
|
||||
.Attr("timeout_in_ms: int = 0")
|
||||
.Output("response: string")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
return RpcShapeOp(c, /*try_rpc=*/false);
|
||||
});
|
||||
|
||||
REGISTER_OP("TryRpc")
|
||||
.Input("address: string")
|
||||
.Input("method: string")
|
||||
.Input("request: string")
|
||||
.Attr("protocol: string = ''")
|
||||
.Attr("fail_fast: bool = true")
|
||||
.Attr("timeout_in_ms: int = 0")
|
||||
.Output("response: string")
|
||||
.Output("status_code: int32")
|
||||
.Output("status_message: string")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
return RpcShapeOp(c, /*try_rpc=*/true);
|
||||
});
|
||||
|
||||
} // namespace tensorflow
|
62
tensorflow/core/util/proto/BUILD
Normal file
62
tensorflow/core/util/proto/BUILD
Normal file
@ -0,0 +1,62 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
cc_library(
|
||||
name = "decode",
|
||||
hdrs = ["decode.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "descriptors",
|
||||
srcs = ["descriptors.cc"],
|
||||
hdrs = ["descriptors.h"],
|
||||
deps = [
|
||||
":descriptor_pool_registry",
|
||||
":local_descriptor_pool_registration",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "descriptor_pool_registry",
|
||||
srcs = ["descriptor_pool_registry.cc"],
|
||||
hdrs = ["descriptor_pool_registry.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "descriptor_pool_registry_test",
|
||||
srcs = ["descriptor_pool_registry_test.cc"],
|
||||
deps = [
|
||||
":descriptor_pool_registry",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
# Depending on this target adds support for using the special
|
||||
# value "local://" (or "") for descriptor source, in which case
|
||||
# descriptors linked into the code will be searched.
|
||||
cc_library(
|
||||
name = "local_descriptor_pool_registration",
|
||||
srcs = ["local_descriptor_pool_registration.cc"],
|
||||
deps = [
|
||||
":descriptor_pool_registry",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
592
tensorflow/core/util/proto/decode.h
Normal file
592
tensorflow/core/util/proto/decode.h
Normal file
@ -0,0 +1,592 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Inline functions for parsing the protocol buffers wire format.
|
||||
//
|
||||
// These functions have been optimized at the expense of safety.
|
||||
// They are broken out into a separate file for readability but are
|
||||
// not intended for use by clients other than the decode_proto op.
|
||||
//
|
||||
// The calling code in the decode_proto op does some fairly
|
||||
// complicated things to ensure that this code is called
|
||||
// safely. Changes to this code should be thoroughly fuzz tested.
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
|
||||
#define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
using tensorflow::protobuf::internal::WireFormatLite;
|
||||
using tensorflow::protobuf::io::CodedInputStream;
|
||||
using tensorflow::protobuf::io::CodedOutputStream;
|
||||
using tensorflow::protobuf::io::StringOutputStream;
|
||||
|
||||
// Converts an uint64 to an int64 without loss of information.
|
||||
// Unsigned values greater than INT64_MAX are represented as
|
||||
// negative numbers by wrapping (same as twos-complement bit equivalence).
|
||||
inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) {
|
||||
// For a detailed explanation of why this works to wrap unsigned ints, see
|
||||
// http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
|
||||
// Both if tests should be optimized out.
|
||||
if (unsigned_value <= INT64_MAX) {
|
||||
return static_cast<int64>(unsigned_value);
|
||||
}
|
||||
// The C++ spec allows an architecture where this test is required.
|
||||
if (unsigned_value >= INT64_MIN) {
|
||||
return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN;
|
||||
}
|
||||
return 0; // This should never occur.
|
||||
}
|
||||
|
||||
// Converts an uint32 to an int32 without loss of information.
|
||||
// Unsigned values greater than INT_MAX are represented as
|
||||
// negative numbers by wrapping (same as twos-complement bit equivalence).
|
||||
inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
|
||||
// For a detailed explanation of why this works to wrap unsigned ints, see
|
||||
// http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
|
||||
// Both if tests should be optimized out.
|
||||
if (unsigned_value <= INT_MAX) {
|
||||
return static_cast<int32>(unsigned_value);
|
||||
}
|
||||
// The C++ spec allows an architecture where this test is required.
|
||||
if (unsigned_value >= INT_MIN) {
|
||||
return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
|
||||
}
|
||||
return 0; // This should never occur.
|
||||
}
|
||||
|
||||
// Reads a single varint32 from a byte array.
|
||||
// It is the caller's responsibility to ensure that there is enough
|
||||
// space in the buffer.
|
||||
// The ok value will be set to false if the buffer does not contain
|
||||
// a valid varint.
|
||||
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
|
||||
uint64* value);
|
||||
|
||||
// Reads a single varint32 from a byte array.
|
||||
// It is the caller's responsibility to ensure that there is enough
|
||||
// space in the buffer.
|
||||
// The ok value will be set to false if the buffer does not contain
|
||||
// a valid varint.
|
||||
// This is slightly less efficient than the private version in
|
||||
// coded_stream.cc but we duplicate less code by calling
|
||||
// the 64 bit version instead of copying the code.
|
||||
inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
|
||||
uint32* value) {
|
||||
uint64 tmp;
|
||||
const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp);
|
||||
*value = tmp & 0xffffffff;
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Reads a single proto field value from a byte array into an array.
|
||||
// The array is part of a Tensor that was allocated by the caller
|
||||
// with type TensorType, while DeclaredType is the proto field type.
|
||||
template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
|
||||
const uint8* ReadFromArray(const uint8* buf, TensorType* value);
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
|
||||
const uint8* buf, int32* value) {
|
||||
uint32 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
|
||||
*value = static_cast<int32>(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
|
||||
const uint8* buf, int64* value) {
|
||||
uint64 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
|
||||
*value = WrapUnsignedAsSigned64(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>(
|
||||
const uint8* buf, int64* value) {
|
||||
uint32 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
|
||||
*value = temp;
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_UINT32>(
|
||||
const uint8* buf, int32* value) {
|
||||
uint32 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
|
||||
*value = WrapUnsignedAsSigned32(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT64>(
|
||||
const uint8* buf, int64* value) {
|
||||
uint64 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
|
||||
*value = static_cast<int64>(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
|
||||
const uint8* buf, int32* value) {
|
||||
uint32 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
|
||||
*value = WireFormatLite::ZigZagDecode32(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
|
||||
const uint8* buf, int64* value) {
|
||||
uint64 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
|
||||
*value = WireFormatLite::ZigZagDecode64(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>(
|
||||
const uint8* buf, int64* value) {
|
||||
uint32 temp;
|
||||
buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
|
||||
WireFormatLite::TYPE_FIXED32>(
|
||||
buf, &temp);
|
||||
*value = temp;
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>(
|
||||
const uint8* buf, int32* value) {
|
||||
uint32 temp;
|
||||
buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
|
||||
WireFormatLite::TYPE_FIXED32>(
|
||||
buf, &temp);
|
||||
*value = WrapUnsignedAsSigned32(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>(
|
||||
const uint8* buf, int64* value) {
|
||||
protobuf_uint64 temp;
|
||||
buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
|
||||
WireFormatLite::TYPE_FIXED64>(
|
||||
buf, &temp);
|
||||
*value = WrapUnsignedAsSigned64(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
|
||||
const uint8* buf, int32* value) {
|
||||
return WireFormatLite::ReadPrimitiveFromArray<int32,
|
||||
WireFormatLite::TYPE_SFIXED32>(
|
||||
buf, value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>(
|
||||
const uint8* buf, int64* value) {
|
||||
protobuf_int64 temp;
|
||||
buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
|
||||
WireFormatLite::TYPE_SFIXED64>(
|
||||
buf, &temp);
|
||||
*value = temp;
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
|
||||
const uint8* buf, float* value) {
|
||||
return WireFormatLite::ReadPrimitiveFromArray<float,
|
||||
WireFormatLite::TYPE_FLOAT>(
|
||||
buf, value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
|
||||
const uint8* buf, double* value) {
|
||||
return WireFormatLite::ReadPrimitiveFromArray<double,
|
||||
WireFormatLite::TYPE_DOUBLE>(
|
||||
buf, value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
|
||||
const uint8* buf, bool* value) {
|
||||
uint64 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
|
||||
*value = temp != 0;
|
||||
return buf;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
|
||||
const uint8* buf, int* value) {
|
||||
uint32 temp;
|
||||
bool unused_ok; // The Counting pass would have failed if this were corrupt.
|
||||
buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
|
||||
*value = static_cast<int>(temp);
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Reads packed values from an array.
|
||||
// Stride is set to 1 for repeated fields, and 0 for non-repeated fields
|
||||
// (where any value overwrites previous values).
|
||||
template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
|
||||
inline int ReadPackedPrimitives(const void* bufp, const size_t len,
|
||||
const int index, const int stride,
|
||||
void* datap) {
|
||||
const uint8* buf = reinterpret_cast<const uint8*>(bufp);
|
||||
const uint8* bound = buf + len;
|
||||
TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
|
||||
int count;
|
||||
|
||||
// This could overrun the bound by stride-1. This is defended
|
||||
// against in the caller, where it ensures that the input buffer
|
||||
// contains complete values.
|
||||
for (count = 0; buf < bound; count += stride) {
|
||||
buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
// Reads a primitive value field from a serialized proto.
|
||||
// The value is parsed from the serialized format, then static_cast
|
||||
// to the desired type for TensorFlow and stored.
|
||||
template <class ValueType, class TensorType,
|
||||
enum WireFormatLite::FieldType DeclaredType>
|
||||
inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
|
||||
ValueType v;
|
||||
if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
|
||||
return errors::DataLoss("Failed reading primitive");
|
||||
}
|
||||
|
||||
reinterpret_cast<TensorType*>(data)[index] = v;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reads a string, submessage, or other variable-length field from a
|
||||
// serialized proto.
|
||||
// May read all or part of a repeated field.
|
||||
inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
|
||||
string* data = reinterpret_cast<string*>(datap) + index;
|
||||
if (!WireFormatLite::ReadBytes(input, data)) {
|
||||
return errors::DataLoss("Failed reading bytes");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
|
||||
// as a bytestring.
|
||||
inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
|
||||
int index, void* datap) {
|
||||
// WireFormatLite::SkipField has an option to emit the
|
||||
// skipped bytes to an output stream. We could do better by implementing our
|
||||
// own scanner but this is simpler for now.
|
||||
// TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
|
||||
// on input->IsFlat() == true and using input->GetDirectBufferPointer()
|
||||
// with input->CurrentPosition().
|
||||
string* data = reinterpret_cast<string*>(datap) + index;
|
||||
StringOutputStream string_stream(data);
|
||||
CodedOutputStream out(&string_stream);
|
||||
if (!WireFormatLite::SkipField(
|
||||
input,
|
||||
WireFormatLite::MakeTag(field_number,
|
||||
WireFormatLite::WIRETYPE_START_GROUP),
|
||||
&out)) {
|
||||
return errors::DataLoss("Failed reading group");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reads a single field value from a CodedInputStream into a tensor.
|
||||
inline Status ReadValue(CodedInputStream* input,
|
||||
WireFormatLite::FieldType field_type, int field_number,
|
||||
DataType dtype, int index, void* datap) {
|
||||
// Dispatch to the appropriately typed field reader based on the
|
||||
// schema type.
|
||||
switch (field_type) {
|
||||
case WireFormatLite::TYPE_DOUBLE:
|
||||
return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
|
||||
input, index, datap);
|
||||
case WireFormatLite::TYPE_FLOAT:
|
||||
if (dtype == DataType::DT_FLOAT) {
|
||||
return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
|
||||
input, index, datap);
|
||||
}
|
||||
if (dtype == DataType::DT_DOUBLE) {
|
||||
return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
|
||||
input, index, datap);
|
||||
}
|
||||
// Any case that reaches this point should have triggered an error
|
||||
// already.
|
||||
return errors::DataLoss("Failed reading TYPE_FLOAT");
|
||||
case WireFormatLite::TYPE_INT64:
|
||||
return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
|
||||
input, index, datap);
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
return ReadPrimitive<protobuf_uint64, int64, WireFormatLite::TYPE_UINT64>(
|
||||
input, index, datap);
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
|
||||
input, index, datap);
|
||||
case WireFormatLite::TYPE_FIXED64:
|
||||
return ReadPrimitive<protobuf_uint64, int64,
|
||||
WireFormatLite::TYPE_FIXED64>(input, index, datap);
|
||||
case WireFormatLite::TYPE_FIXED32:
|
||||
if (dtype == DataType::DT_INT64) {
|
||||
return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_FIXED32>(
|
||||
input, index, datap);
|
||||
}
|
||||
if (dtype == DataType::DT_INT32) {
|
||||
return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_FIXED32>(
|
||||
input, index, datap);
|
||||
}
|
||||
// Any case that reaches this point should have triggered an error
|
||||
// already.
|
||||
return errors::DataLoss("Failed reading TYPE_FIXED32");
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
|
||||
datap);
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
return ReadBytes(input, index, datap);
|
||||
case WireFormatLite::TYPE_GROUP:
|
||||
return ReadGroupBytes(input, field_number, index, datap);
|
||||
case WireFormatLite::TYPE_MESSAGE:
|
||||
return ReadBytes(input, index, datap);
|
||||
case WireFormatLite::TYPE_BYTES:
|
||||
return ReadBytes(input, index, datap);
|
||||
case WireFormatLite::TYPE_UINT32:
|
||||
if (dtype == DataType::DT_INT64) {
|
||||
return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_UINT32>(
|
||||
input, index, datap);
|
||||
}
|
||||
if (dtype == DataType::DT_INT32) {
|
||||
return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_UINT32>(
|
||||
input, index, datap);
|
||||
}
|
||||
// Any case that reaches this point should have triggered an error
|
||||
// already.
|
||||
return errors::DataLoss("Failed reading TYPE_UINT32");
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
|
||||
input, index, datap);
|
||||
case WireFormatLite::TYPE_SFIXED32:
|
||||
return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
|
||||
input, index, datap);
|
||||
case WireFormatLite::TYPE_SFIXED64:
|
||||
return ReadPrimitive<protobuf_int64, int64,
|
||||
WireFormatLite::TYPE_SFIXED64>(input, index, datap);
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
|
||||
input, index, datap);
|
||||
case WireFormatLite::TYPE_SINT64:
|
||||
return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
|
||||
input, index, datap);
|
||||
// default: intentionally omitted in order to enable static checking.
|
||||
}
|
||||
// Unreachable.
|
||||
return errors::DataLoss("Failed reading unknown wire type");
|
||||
}
|
||||
|
||||
// Reads and stores a length-delimited list of values.
|
||||
inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
|
||||
const WireFormatLite::FieldType field_type,
|
||||
const int field_number, const DataType dtype,
|
||||
const int stride, int* index, void* data) {
|
||||
// Dispatch to the appropriately typed field reader based on the
|
||||
// schema type.
|
||||
switch (field_type) {
|
||||
case WireFormatLite::TYPE_DOUBLE:
|
||||
*index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
case WireFormatLite::TYPE_FLOAT:
|
||||
*index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
case WireFormatLite::TYPE_INT64:
|
||||
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT64>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
*index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
case WireFormatLite::TYPE_FIXED64:
|
||||
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED64>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
case WireFormatLite::TYPE_FIXED32:
|
||||
if (dtype == DataType::DT_INT64) {
|
||||
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED32>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
}
|
||||
if (dtype == DataType::DT_INT32) {
|
||||
*index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_FIXED32>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
}
|
||||
// Any case that reaches this point should have triggered an error
|
||||
// already.
|
||||
return errors::DataLoss("Failed reading TYPE_FIXED32");
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
*index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
case WireFormatLite::TYPE_GROUP:
|
||||
case WireFormatLite::TYPE_MESSAGE:
|
||||
case WireFormatLite::TYPE_BYTES:
|
||||
return errors::DataLoss("Non-primitive type encountered as packed");
|
||||
case WireFormatLite::TYPE_UINT32:
|
||||
if (dtype == DataType::DT_INT64) {
|
||||
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT32>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
}
|
||||
if (dtype == DataType::DT_INT32) {
|
||||
*index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_UINT32>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
}
|
||||
// Any case that reaches this point should have triggered an error
|
||||
// already.
|
||||
return errors::DataLoss("Failed reading TYPE_UINT32");
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
*index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
case WireFormatLite::TYPE_SFIXED32:
|
||||
*index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
|
||||
case WireFormatLite::TYPE_SFIXED64:
|
||||
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
*index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
|
||||
case WireFormatLite::TYPE_SINT64:
|
||||
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
|
||||
buf, buf_size, *index, stride, data);
|
||||
return Status::OK();
|
||||
// default: intentionally omitted in order to enable static checking.
|
||||
}
|
||||
// Unreachable.
|
||||
return errors::DataLoss("Failed reading unknown wire type");
|
||||
}
|
||||
|
||||
// Reads a varint from the given buffer, write it to *value, and return the
|
||||
// new buffer pointer.
|
||||
// This was copied from coded_stream.cc where it is private.
|
||||
// Important: This routine may read as much as kMaxVarintBytes from
|
||||
// the buffer. It is the caller's responsibility to make sure that there is
|
||||
// enough space in the buffer.
|
||||
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
|
||||
uint64* value) {
|
||||
const uint8* ptr = buffer;
|
||||
uint32 b;
|
||||
|
||||
// Splitting into 32-bit pieces gives better performance on 32-bit
|
||||
// processors.
|
||||
uint32 part0 = 0, part1 = 0, part2 = 0;
|
||||
|
||||
b = *(ptr++);
|
||||
part0 = b;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part0 -= 0x80;
|
||||
b = *(ptr++);
|
||||
part0 += b << 7;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part0 -= 0x80 << 7;
|
||||
b = *(ptr++);
|
||||
part0 += b << 14;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part0 -= 0x80 << 14;
|
||||
b = *(ptr++);
|
||||
part0 += b << 21;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part0 -= 0x80 << 21;
|
||||
b = *(ptr++);
|
||||
part1 = b;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part1 -= 0x80;
|
||||
b = *(ptr++);
|
||||
part1 += b << 7;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part1 -= 0x80 << 7;
|
||||
b = *(ptr++);
|
||||
part1 += b << 14;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part1 -= 0x80 << 14;
|
||||
b = *(ptr++);
|
||||
part1 += b << 21;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part1 -= 0x80 << 21;
|
||||
b = *(ptr++);
|
||||
part2 = b;
|
||||
if (!(b & 0x80)) goto done;
|
||||
part2 -= 0x80;
|
||||
b = *(ptr++);
|
||||
part2 += b << 7;
|
||||
if (!(b & 0x80)) goto done;
|
||||
// "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.
|
||||
|
||||
// We have overrun the maximum size of a varint (10 bytes). Assume
|
||||
// the data is corrupt.
|
||||
*ok = false;
|
||||
return ptr;
|
||||
|
||||
done:
|
||||
*ok = true;
|
||||
*value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
|
||||
(static_cast<uint64>(part2) << 56);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
|
45
tensorflow/core/util/proto/descriptor_pool_registry.cc
Normal file
45
tensorflow/core/util/proto/descriptor_pool_registry.cc
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DescriptorPoolRegistry* DescriptorPoolRegistry::Global() {
|
||||
static DescriptorPoolRegistry* registry = new DescriptorPoolRegistry;
|
||||
return registry;
|
||||
}
|
||||
|
||||
DescriptorPoolRegistry::DescriptorPoolFn* DescriptorPoolRegistry::Get(
|
||||
const string& source) {
|
||||
auto found = fns_.find(source);
|
||||
if (found == fns_.end()) return nullptr;
|
||||
return &found->second;
|
||||
}
|
||||
|
||||
void DescriptorPoolRegistry::Register(
|
||||
const string& source,
|
||||
const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) {
|
||||
auto existing = Get(source);
|
||||
CHECK_EQ(existing, nullptr)
|
||||
<< "descriptor pool for source: " << source << " already registered";
|
||||
fns_.insert(std::pair<const string&, DescriptorPoolFn>(source, pool_fn));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
76
tensorflow/core/util/proto/descriptor_pool_registry.h
Normal file
76
tensorflow/core/util/proto/descriptor_pool_registry.h
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
|
||||
#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class DescriptorPoolRegistry {
|
||||
public:
|
||||
typedef std::function<Status(
|
||||
tensorflow::protobuf::DescriptorPool const** desc_pool,
|
||||
std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool)>
|
||||
DescriptorPoolFn;
|
||||
|
||||
// Returns a pointer to a global DescriptorPoolRegistry object.
|
||||
static DescriptorPoolRegistry* Global();
|
||||
|
||||
// Returns a pointer to a descriptor pool function for the given source.
|
||||
DescriptorPoolFn* Get(const string& source);
|
||||
|
||||
// Registers a descriptor pool factory.
|
||||
void Register(const string& source, const DescriptorPoolFn& pool_fn);
|
||||
|
||||
private:
|
||||
std::map<string, DescriptorPoolFn> fns_;
|
||||
};
|
||||
|
||||
namespace descriptor_pool_registration {
|
||||
|
||||
class DescriptorPoolRegistration {
|
||||
public:
|
||||
DescriptorPoolRegistration(
|
||||
const string& source,
|
||||
const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) {
|
||||
DescriptorPoolRegistry::Global()->Register(source, pool_fn);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace descriptor_pool_registration
|
||||
|
||||
#define REGISTER_DESCRIPTOR_POOL(source, pool_fn) \
|
||||
REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(__COUNTER__, source, pool_fn)
|
||||
|
||||
#define REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(ctr, source, pool_fn) \
|
||||
REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn)
|
||||
|
||||
#define REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn) \
|
||||
static descriptor_pool_registration::DescriptorPoolRegistration \
|
||||
descriptor_pool_registration_fn_##ctr(source, pool_fn)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
|
43
tensorflow/core/util/proto/descriptor_pool_registry_test.cc
Normal file
43
tensorflow/core/util/proto/descriptor_pool_registry_test.cc
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
struct Value {
|
||||
static Status Function(
|
||||
tensorflow::protobuf::DescriptorPool const** desc_pool,
|
||||
std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_DESCRIPTOR_POOL("TEST POOL 1", Value::Function);
|
||||
REGISTER_DESCRIPTOR_POOL("TEST POOL 2", Value::Function);
|
||||
} // namespace
|
||||
|
||||
TEST(DescriptorPoolRegistryTest, TestBasic) {
|
||||
EXPECT_EQ(DescriptorPoolRegistry::Global()->Get("NON-EXISTENT"), nullptr);
|
||||
auto pool1 = DescriptorPoolRegistry::Global()->Get("TEST POOL 1");
|
||||
EXPECT_NE(pool1, nullptr);
|
||||
auto pool2 = DescriptorPoolRegistry::Global()->Get("TEST POOL 2");
|
||||
EXPECT_NE(pool2, nullptr);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
85
tensorflow/core/util/proto/descriptors.cc
Normal file
85
tensorflow/core/util/proto/descriptors.cc
Normal file
@ -0,0 +1,85 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/reader_op_kernel.h"
|
||||
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
|
||||
|
||||
#include "tensorflow/core/util/proto/descriptors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Build a `DescriptorPool` from the named file or URI. The file or URI
|
||||
// must be available to the current TensorFlow environment.
|
||||
//
|
||||
// The file must contiain a serialized `FileDescriptorSet`. See
|
||||
// `GetDescriptorPool()` for more information.
|
||||
Status GetDescriptorPoolFromFile(
|
||||
tensorflow::Env* env, const string& filename,
|
||||
std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
|
||||
Status st = env->FileExists(filename);
|
||||
if (!st.ok()) {
|
||||
return st;
|
||||
}
|
||||
|
||||
// Read and parse the FileDescriptorSet.
|
||||
tensorflow::protobuf::FileDescriptorSet descs;
|
||||
std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf;
|
||||
st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
|
||||
if (!st.ok()) {
|
||||
return st;
|
||||
}
|
||||
if (!descs.ParseFromArray(buf->data(), buf->length())) {
|
||||
return errors::InvalidArgument(
|
||||
"descriptor_source contains invalid FileDescriptorSet: ", filename);
|
||||
}
|
||||
|
||||
// Build a DescriptorPool from the FileDescriptorSet.
|
||||
owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool());
|
||||
for (const auto& filedesc : descs.file()) {
|
||||
if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"Problem loading FileDescriptorProto (missing dependencies?): ",
|
||||
filename);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status GetDescriptorPool(
|
||||
tensorflow::Env* env, string const& descriptor_source,
|
||||
tensorflow::protobuf::DescriptorPool const** desc_pool,
|
||||
std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
|
||||
// Attempt to lookup the pool in the registry.
|
||||
auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
|
||||
if (pool_fn != nullptr) {
|
||||
return (*pool_fn)(desc_pool, owned_desc_pool);
|
||||
}
|
||||
|
||||
// If there is no pool function registered for the given source, let the
|
||||
// runtime find the file or URL.
|
||||
Status status =
|
||||
GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
|
||||
if (status.ok()) {
|
||||
*desc_pool = owned_desc_pool->get();
|
||||
}
|
||||
*desc_pool = owned_desc_pool->get();
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
42
tensorflow/core/util/proto/descriptors.h
Normal file
42
tensorflow/core/util/proto/descriptors.h
Normal file
@ -0,0 +1,42 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
|
||||
#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class Env;
|
||||
class Status;
|
||||
|
||||
// Get a `DescriptorPool` object from the named `descriptor_source`.
|
||||
// `descriptor_source` may be a path to a file accessible to TensorFlow, in
|
||||
// which case it is parsed as a `FileDescriptorSet` and used to build the
|
||||
// `DescriptorPool`.
|
||||
//
|
||||
// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if
|
||||
// the caller should take ownership.
|
||||
extern tensorflow::Status GetDescriptorPool(
|
||||
tensorflow::Env* env, string const& descriptor_source,
|
||||
tensorflow::protobuf::DescriptorPool const** desc_pool,
|
||||
std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
|
@ -0,0 +1,39 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
struct LocalDescriptorPool {
|
||||
static Status Function(
|
||||
tensorflow::protobuf::DescriptorPool const** desc_pool,
|
||||
std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
|
||||
*desc_pool = ::tensorflow::protobuf::DescriptorPool::generated_pool();
|
||||
if (*desc_pool == nullptr) {
|
||||
return errors::InvalidArgument("Problem loading protobuf generated_pool");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_DESCRIPTOR_POOL("", LocalDescriptorPool::Function);
|
||||
REGISTER_DESCRIPTOR_POOL("local://", LocalDescriptorPool::Function);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
48
tensorflow/core/util/rpc/BUILD
Normal file
48
tensorflow/core/util/rpc/BUILD
Normal file
@ -0,0 +1,48 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
cc_library(
|
||||
name = "call_container",
|
||||
hdrs = ["call_container.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rpc_factory",
|
||||
srcs = ["rpc_factory.cc"],
|
||||
hdrs = ["rpc_factory.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rpc_factory_registry",
|
||||
srcs = ["rpc_factory_registry.cc"],
|
||||
hdrs = ["rpc_factory_registry.h"],
|
||||
deps = [
|
||||
":rpc_factory",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "rpc_factory_registry_test",
|
||||
srcs = ["rpc_factory_registry_test.cc"],
|
||||
deps = [
|
||||
":rpc_factory_registry",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
90
tensorflow/core/util/rpc/call_container.h
Normal file
90
tensorflow/core/util/rpc/call_container.h
Normal file
@ -0,0 +1,90 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
|
||||
#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
|
||||
|
||||
#include <list>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/util/reffed_status_callback.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
template <typename Call>
|
||||
class CallContainer {
|
||||
public:
|
||||
explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
|
||||
bool try_rpc, AsyncOpKernel::DoneCallback done,
|
||||
CancellationToken token)
|
||||
: ctx_(ctx),
|
||||
done_(std::move(done)),
|
||||
token_(token),
|
||||
fail_fast_(fail_fast),
|
||||
try_rpc_(try_rpc) {
|
||||
CHECK_GT(num_calls, 0);
|
||||
|
||||
// This will run when all RPCs are finished.
|
||||
reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
|
||||
ctx_->cancellation_manager()->DeregisterCallback(token_);
|
||||
ctx_->SetStatus(s);
|
||||
done_();
|
||||
delete this;
|
||||
});
|
||||
|
||||
// Subtract reference count from the initial creation.
|
||||
core::ScopedUnref unref(reffed_status_callback_);
|
||||
|
||||
for (int i = 0; i < num_calls; ++i) {
|
||||
// Increase the reference on the callback for each new RPC.
|
||||
reffed_status_callback_->Ref();
|
||||
}
|
||||
}
|
||||
|
||||
std::list<Call>* calls() { return &calls_; }
|
||||
|
||||
void StartCancel() {
|
||||
// Once this loop is done, can no longer assume anything is valid
|
||||
// because "delete this" may have been immediately called.
|
||||
// Nothing should run after this loop.
|
||||
for (auto& call : calls_) {
|
||||
call.StartCancel();
|
||||
}
|
||||
}
|
||||
|
||||
void Done(const Status& s, int index) {
|
||||
if (!try_rpc_) {
|
||||
reffed_status_callback_->UpdateStatus(s);
|
||||
}
|
||||
reffed_status_callback_->Unref();
|
||||
}
|
||||
|
||||
private:
|
||||
OpKernelContext* ctx_;
|
||||
std::list<Call> calls_;
|
||||
const AsyncOpKernel::DoneCallback done_;
|
||||
const CancellationToken token_;
|
||||
const bool fail_fast_;
|
||||
const bool try_rpc_;
|
||||
|
||||
// Performs its own reference counting.
|
||||
ReffedStatusCallback* reffed_status_callback_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
|
53
tensorflow/core/util/rpc/rpc_factory.cc
Normal file
53
tensorflow/core/util/rpc/rpc_factory.cc
Normal file
@ -0,0 +1,53 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
|
||||
#include "tensorflow/core/util/rpc/rpc_factory.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
template <>
|
||||
bool GetEnvVar(const char* key, const string& default_value, string* value) {
|
||||
const char* env_value = std::getenv(key);
|
||||
if (!env_value || env_value[0] == '\0') {
|
||||
*value = default_value;
|
||||
} else {
|
||||
*value = env_value;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool GetEnvVar(const char* key, const int64& default_value, int64* value) {
|
||||
const char* env_value = std::getenv(key);
|
||||
if (!env_value || env_value[0] == '\0') {
|
||||
*value = default_value;
|
||||
return true;
|
||||
}
|
||||
return strings::safe_strto64(env_value, value);
|
||||
}
|
||||
|
||||
template <>
|
||||
bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) {
|
||||
const char* env_value = std::getenv(key);
|
||||
if (!env_value || env_value[0] == '\0') {
|
||||
*value = default_value;
|
||||
return true;
|
||||
}
|
||||
return strings::safe_strtou64(env_value, value);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
70
tensorflow/core/util/rpc/rpc_factory.h
Normal file
70
tensorflow/core/util/rpc/rpc_factory.h
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
|
||||
#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Return the environment variable `key`. If the variable is not set,
|
||||
// use the default value. If it is set but could not be parsed,
|
||||
// return `false`. Otherwise set `value` and return `true`.
|
||||
template <typename T>
|
||||
bool GetEnvVar(const char* key, const T& default_value, T* value);
|
||||
|
||||
class RPCFactory {
|
||||
public:
|
||||
RPCFactory() {}
|
||||
virtual ~RPCFactory() {}
|
||||
|
||||
// Start a Call() to methods `method_t` at addresses `address_t` with
|
||||
// request strings from `request_t`. Any of these may be scalar
|
||||
// Tensors, in which case the operands are broadcasted.
|
||||
// Upon completion of all requests, `response_t` will be populated.
|
||||
//
|
||||
// If `try_rpc` is `true`, then `status_message_t` and
|
||||
// `status_code_t` will be populated as well.
|
||||
//
|
||||
// If `try_rpc` is `false`, then `status_message_t` and
|
||||
// `status_code_t` are ignored (and may be nullptr). Instead, the
|
||||
// status of any failed call will be propagated to the op.
|
||||
//
|
||||
// REQUIRES:
|
||||
// - `response_t` is not null, and is a string Tensor with the same shape as
|
||||
// `request_t`.
|
||||
//
|
||||
// If `try_rpc` is `true`:
|
||||
// - `status_code_t` and `status_message_t` are not null.
|
||||
// - `status_code_t` is an int32 Tensor with the same shape as
|
||||
// `request_t`.
|
||||
// - `status_message_t` is a string Tensor with the same shape as
|
||||
// `request_t`.
|
||||
virtual void Call(OpKernelContext* ctx, int64 num_elements,
|
||||
const Tensor& address_t, const Tensor& method_t,
|
||||
const Tensor& request_t, const bool try_rpc,
|
||||
Tensor* response_t, Tensor* status_code_t,
|
||||
Tensor* status_message_t,
|
||||
AsyncOpKernel::DoneCallback done) = 0;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
|
44
tensorflow/core/util/rpc/rpc_factory_registry.cc
Normal file
44
tensorflow/core/util/rpc/rpc_factory_registry.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/util/rpc/rpc_factory.h"
|
||||
|
||||
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
RPCFactoryRegistry* RPCFactoryRegistry::Global() {
|
||||
static RPCFactoryRegistry* registry = new RPCFactoryRegistry;
|
||||
return registry;
|
||||
}
|
||||
|
||||
RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get(
|
||||
const string& protocol) {
|
||||
auto found = fns_.find(protocol);
|
||||
if (found == fns_.end()) return nullptr;
|
||||
return &found->second;
|
||||
}
|
||||
|
||||
void RPCFactoryRegistry::Register(const string& protocol,
|
||||
const RPCFactoryFn& factory_fn) {
|
||||
auto existing = Get(protocol);
|
||||
CHECK_EQ(existing, nullptr)
|
||||
<< "RPC factory for protocol: " << protocol << " already registered";
|
||||
fns_.insert(std::pair<const string&, RPCFactoryFn>(protocol, factory_fn));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
72
tensorflow/core/util/rpc/rpc_factory_registry.h
Normal file
72
tensorflow/core/util/rpc/rpc_factory_registry.h
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
|
||||
#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/util/rpc/rpc_factory.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RPCFactoryRegistry {
|
||||
public:
|
||||
typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast,
|
||||
int64 timeout_in_ms)>
|
||||
RPCFactoryFn;
|
||||
|
||||
// Returns a pointer to a global RPCFactoryRegistry object.
|
||||
static RPCFactoryRegistry* Global();
|
||||
|
||||
// Returns a pointer to an function that creates an RPC factory for the given
|
||||
// protocol.
|
||||
RPCFactoryFn* Get(const string& protocol);
|
||||
|
||||
// Registers a function that creates and RPC factory for the given protocol.
|
||||
// The function should transfer the ownership of the factory to its caller.
|
||||
void Register(const string& protocol, const RPCFactoryFn& factory_fn);
|
||||
|
||||
private:
|
||||
std::map<string, RPCFactoryFn> fns_;
|
||||
};
|
||||
|
||||
namespace rpc_factory_registration {
|
||||
|
||||
class RPCFactoryRegistration {
|
||||
public:
|
||||
RPCFactoryRegistration(const string& protocol,
|
||||
const RPCFactoryRegistry::RPCFactoryFn& factory_fn) {
|
||||
RPCFactoryRegistry::Global()->Register(protocol, factory_fn);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace rpc_factory_registration
|
||||
|
||||
#define REGISTER_RPC_FACTORY(protocol, factory_fn) \
|
||||
REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn)
|
||||
|
||||
#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \
|
||||
REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn)
|
||||
|
||||
#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \
|
||||
static rpc_factory_registration::RPCFactoryRegistration \
|
||||
rpc_factory_registration_fn_##ctr(protocol, factory_fn)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
|
41
tensorflow/core/util/rpc/rpc_factory_registry_test.cc
Normal file
41
tensorflow/core/util/rpc/rpc_factory_registry_test.cc
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
struct Value {
|
||||
static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
|
||||
int64 timeout_in_ms) {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function);
|
||||
REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function);
|
||||
} // namespace
|
||||
|
||||
TEST(RPCFactoryRegistryTest, TestBasic) {
|
||||
EXPECT_EQ(RPCFactoryRegistry::Global()->Get("NON-EXISTENT"), nullptr);
|
||||
auto factory1 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 1");
|
||||
EXPECT_NE(factory1, nullptr);
|
||||
auto factory2 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 2");
|
||||
EXPECT_NE(factory2, nullptr);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -3370,6 +3370,7 @@ tf_py_wrap_cc(
|
||||
"//tensorflow/c:python_api",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
|
Loading…
x
Reference in New Issue
Block a user