Porting tests for rpc_op to OS.
PiperOrigin-RevId: 193102564
This commit is contained in:
parent
ba1ea3ff90
commit
d995be2deb
@ -87,6 +87,7 @@ py_library(
|
|||||||
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
|
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
|
||||||
"//tensorflow/contrib/resampler:resampler_py",
|
"//tensorflow/contrib/resampler:resampler_py",
|
||||||
"//tensorflow/contrib/rnn:rnn_py",
|
"//tensorflow/contrib/rnn:rnn_py",
|
||||||
|
"//tensorflow/contrib/rpc",
|
||||||
"//tensorflow/contrib/saved_model:saved_model_py",
|
"//tensorflow/contrib/saved_model:saved_model_py",
|
||||||
"//tensorflow/contrib/seq2seq:seq2seq_py",
|
"//tensorflow/contrib/seq2seq:seq2seq_py",
|
||||||
"//tensorflow/contrib/signal:signal_py",
|
"//tensorflow/contrib/signal:signal_py",
|
||||||
|
|||||||
@ -71,6 +71,7 @@ from tensorflow.contrib import recurrent
|
|||||||
from tensorflow.contrib import reduce_slice_ops
|
from tensorflow.contrib import reduce_slice_ops
|
||||||
from tensorflow.contrib import resampler
|
from tensorflow.contrib import resampler
|
||||||
from tensorflow.contrib import rnn
|
from tensorflow.contrib import rnn
|
||||||
|
from tensorflow.contrib import rpc
|
||||||
from tensorflow.contrib import saved_model
|
from tensorflow.contrib import saved_model
|
||||||
from tensorflow.contrib import seq2seq
|
from tensorflow.contrib import seq2seq
|
||||||
from tensorflow.contrib import signal
|
from tensorflow.contrib import signal
|
||||||
|
|||||||
@ -347,7 +347,8 @@ GENERATE_PYTHON_OP_LIB("random_ops")
|
|||||||
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
|
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
|
||||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py)
|
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py)
|
||||||
GENERATE_PYTHON_OP_LIB("resource_variable_ops")
|
GENERATE_PYTHON_OP_LIB("resource_variable_ops")
|
||||||
GENERATE_PYTHON_OP_LIB("rpc_ops")
|
GENERATE_PYTHON_OP_LIB("rpc_ops"
|
||||||
|
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py)
|
||||||
GENERATE_PYTHON_OP_LIB("script_ops")
|
GENERATE_PYTHON_OP_LIB("script_ops")
|
||||||
GENERATE_PYTHON_OP_LIB("sdca_ops")
|
GENERATE_PYTHON_OP_LIB("sdca_ops")
|
||||||
GENERATE_PYTHON_OP_LIB("set_ops")
|
GENERATE_PYTHON_OP_LIB("set_ops")
|
||||||
|
|||||||
@ -4,6 +4,8 @@ licenses(["notice"]) # Apache 2.0
|
|||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "rpc",
|
name = "rpc",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -11,3 +13,17 @@ py_library(
|
|||||||
],
|
],
|
||||||
deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
|
deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "rpc_pip",
|
||||||
|
data = if_static(
|
||||||
|
[],
|
||||||
|
otherwise = ["//tensorflow/contrib/rpc/python/kernel_tests:libtestexample.so"],
|
||||||
|
),
|
||||||
|
deps = [
|
||||||
|
":rpc",
|
||||||
|
"//tensorflow/contrib/rpc/python/kernel_tests:py_test_deps",
|
||||||
|
"//tensorflow/contrib/rpc/python/kernel_tests:rpc_op_test_base",
|
||||||
|
"//tensorflow/contrib/rpc/python/kernel_tests:rpc_op_test_servicer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
80
tensorflow/contrib/rpc/python/kernel_tests/BUILD
Normal file
80
tensorflow/contrib/rpc/python/kernel_tests/BUILD
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# TODO(b/76425722): Port everything in here to OS (currently excluded).
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
|
||||||
|
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
|
||||||
|
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
|
||||||
|
# Placeholder for loading internal BUILD rule.
|
||||||
|
|
||||||
|
tf_proto_library(
|
||||||
|
name = "test_example_proto",
|
||||||
|
srcs = ["test_example.proto"],
|
||||||
|
has_services = 1,
|
||||||
|
cc_api_version = 2,
|
||||||
|
protodeps = ["//tensorflow/core:protos_all"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "py_test_deps",
|
||||||
|
deps = [":test_example_proto_py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "rpc_op_test_base",
|
||||||
|
srcs = ["rpc_op_test_base.py"],
|
||||||
|
deps = [
|
||||||
|
":test_example_proto_py",
|
||||||
|
"//tensorflow/contrib/proto",
|
||||||
|
"//tensorflow/contrib/rpc",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:errors",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "rpc_op_test_servicer",
|
||||||
|
srcs = ["rpc_op_test_servicer.py"],
|
||||||
|
deps = [
|
||||||
|
":py_test_deps",
|
||||||
|
":rpc_op_test_base",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_shared_object(
|
||||||
|
name = "libtestexample.so",
|
||||||
|
linkstatic = 1,
|
||||||
|
deps = [
|
||||||
|
":test_example_proto_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "rpc_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["rpc_op_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":py_test_deps",
|
||||||
|
":rpc_op_test_base",
|
||||||
|
":rpc_op_test_servicer",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
],
|
||||||
|
data = if_static(
|
||||||
|
[],
|
||||||
|
otherwise = [":libtestexample.so"],
|
||||||
|
),
|
||||||
|
tags = [
|
||||||
|
"no_pip", # TODO(b/78026780)
|
||||||
|
"no_windows", # TODO(b/78028010)
|
||||||
|
],
|
||||||
|
)
|
||||||
71
tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
Normal file
71
tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
"""Tests for RpcOp."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import ctypes as ct
|
||||||
|
import os
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
from grpc.framework.foundation import logging_pool
|
||||||
|
import portpicker
|
||||||
|
|
||||||
|
from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
|
||||||
|
from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_servicer
|
||||||
|
from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
|
||||||
|
_protocol = 'grpc'
|
||||||
|
|
||||||
|
invalid_method_string = 'Method not found'
|
||||||
|
|
||||||
|
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
|
||||||
|
super(RpcOpTest, self).__init__(methodName)
|
||||||
|
lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
|
||||||
|
if os.path.isfile(lib):
|
||||||
|
ct.cdll.LoadLibrary(lib)
|
||||||
|
|
||||||
|
def get_method_name(self, suffix):
|
||||||
|
return '/tensorflow.contrib.rpc.TestCaseService/%s' % suffix
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(RpcOpTest, self).setUp()
|
||||||
|
|
||||||
|
service_port = portpicker.pick_unused_port()
|
||||||
|
|
||||||
|
server = grpc.server(logging_pool.pool(max_workers=25))
|
||||||
|
servicer = rpc_op_test_servicer.RpcOpTestServicer()
|
||||||
|
test_example_pb2_grpc.add_TestCaseServiceServicer_to_server(
|
||||||
|
servicer, server)
|
||||||
|
self._address = 'localhost:%d' % service_port
|
||||||
|
server.add_insecure_port(self._address)
|
||||||
|
server.start()
|
||||||
|
self._server = server
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
# TODO(ebrevdo): Figure out why this sometimes times out.
|
||||||
|
# self._service.ExitLoop()
|
||||||
|
# self._service_thread.join()
|
||||||
|
# self._server.stop()
|
||||||
|
super(RpcOpTest, self).tearDown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
||||||
336
tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
Normal file
336
tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
Normal file
@ -0,0 +1,336 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
"""Base class for RpcOp tests."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.proto.python.ops import decode_proto_op
|
||||||
|
from tensorflow.contrib.proto.python.ops import encode_proto_op
|
||||||
|
from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2
|
||||||
|
from tensorflow.contrib.rpc.python.ops import rpc_op
|
||||||
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
|
||||||
|
__all__ = ['I_WARNED_YOU', 'RpcOpTestBase']
|
||||||
|
|
||||||
|
I_WARNED_YOU = 'I warned you!'
|
||||||
|
|
||||||
|
|
||||||
|
class RpcOpTestBase(object):
|
||||||
|
# pylint: disable=missing-docstring,invalid-name
|
||||||
|
"""Base class for RpcOp tests."""
|
||||||
|
|
||||||
|
def get_method_name(self, suffix):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def rpc(self, *args, **kwargs):
|
||||||
|
return rpc_op.rpc(*args, protocol=self._protocol, **kwargs)
|
||||||
|
|
||||||
|
def try_rpc(self, *args, **kwargs):
|
||||||
|
return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)
|
||||||
|
|
||||||
|
def testScalarHostPortRpc(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
request_tensors = (
|
||||||
|
test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
|
||||||
|
response_tensors = self.rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address=self._address,
|
||||||
|
request=request_tensors)
|
||||||
|
self.assertEqual(response_tensors.shape, ())
|
||||||
|
response_values = sess.run(response_tensors)
|
||||||
|
response_message = test_example_pb2.TestCase()
|
||||||
|
self.assertTrue(response_message.ParseFromString(response_values))
|
||||||
|
self.assertAllEqual([2, 3, 4], response_message.shape)
|
||||||
|
|
||||||
|
def testScalarHostPortTryRpc(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
request_tensors = (
|
||||||
|
test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
|
||||||
|
response_tensors, status_code, status_message = self.try_rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address=self._address,
|
||||||
|
request=request_tensors)
|
||||||
|
self.assertEqual(status_code.shape, ())
|
||||||
|
self.assertEqual(status_message.shape, ())
|
||||||
|
self.assertEqual(response_tensors.shape, ())
|
||||||
|
response_values, status_code_values, status_message_values = (
|
||||||
|
sess.run((response_tensors, status_code, status_message)))
|
||||||
|
response_message = test_example_pb2.TestCase()
|
||||||
|
self.assertTrue(response_message.ParseFromString(response_values))
|
||||||
|
self.assertAllEqual([2, 3, 4], response_message.shape)
|
||||||
|
# For the base Rpc op, don't expect to get error status back.
|
||||||
|
self.assertEqual(errors.OK, status_code_values)
|
||||||
|
self.assertEqual(b'', status_message_values)
|
||||||
|
|
||||||
|
def testEmptyHostPortRpc(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
request_tensors = []
|
||||||
|
response_tensors = self.rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address=self._address,
|
||||||
|
request=request_tensors)
|
||||||
|
self.assertAllEqual(response_tensors.shape, [0])
|
||||||
|
response_values = sess.run(response_tensors)
|
||||||
|
self.assertAllEqual(response_values.shape, [0])
|
||||||
|
|
||||||
|
def testInvalidAddresses(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
with self.assertRaisesOpError(self.invalid_method_string):
|
||||||
|
sess.run(
|
||||||
|
self.rpc(
|
||||||
|
method='/InvalidService.IncrementTestShapes',
|
||||||
|
address=self._address,
|
||||||
|
request=''))
|
||||||
|
|
||||||
|
with self.assertRaisesOpError(self.invalid_method_string):
|
||||||
|
sess.run(
|
||||||
|
self.rpc(
|
||||||
|
method=self.get_method_name('InvalidMethodName'),
|
||||||
|
address=self._address,
|
||||||
|
request=''))
|
||||||
|
|
||||||
|
# This also covers the case of address=''
|
||||||
|
# and address='localhost:293874293874'
|
||||||
|
with self.assertRaises(errors.UnavailableError):
|
||||||
|
sess.run(
|
||||||
|
self.rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@',
|
||||||
|
request=''))
|
||||||
|
|
||||||
|
# Test invalid method with the TryRpc op
|
||||||
|
_, status_code_value, status_message_value = sess.run(
|
||||||
|
self.try_rpc(
|
||||||
|
method=self.get_method_name('InvalidMethodName'),
|
||||||
|
address=self._address,
|
||||||
|
request=''))
|
||||||
|
self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
|
||||||
|
self.assertTrue(
|
||||||
|
self.invalid_method_string in status_message_value.decode('ascii'))
|
||||||
|
|
||||||
|
def testAlwaysFailingMethod(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
response_tensors = self.rpc(
|
||||||
|
method=self.get_method_name('AlwaysFailWithInvalidArgument'),
|
||||||
|
address=self._address,
|
||||||
|
request='')
|
||||||
|
self.assertEqual(response_tensors.shape, ())
|
||||||
|
with self.assertRaisesOpError(I_WARNED_YOU):
|
||||||
|
sess.run(response_tensors)
|
||||||
|
|
||||||
|
def testSometimesFailingMethodWithManyRequests(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
# Fail hard by default.
|
||||||
|
response_tensors = self.rpc(
|
||||||
|
method=self.get_method_name('SometimesFailWithInvalidArgument'),
|
||||||
|
address=self._address,
|
||||||
|
request=[''] * 20)
|
||||||
|
self.assertEqual(response_tensors.shape, (20,))
|
||||||
|
with self.assertRaisesOpError(I_WARNED_YOU):
|
||||||
|
sess.run(response_tensors)
|
||||||
|
|
||||||
|
# Don't fail hard, use TryRpc - return the failing status instead.
|
||||||
|
response_tensors, status_code, status_message = self.try_rpc(
|
||||||
|
method=self.get_method_name('SometimesFailWithInvalidArgument'),
|
||||||
|
address=self._address,
|
||||||
|
request=[''] * 20)
|
||||||
|
self.assertEqual(response_tensors.shape, (20,))
|
||||||
|
self.assertEqual(status_code.shape, (20,))
|
||||||
|
self.assertEqual(status_message.shape, (20,))
|
||||||
|
status_code_values, status_message_values = sess.run((status_code,
|
||||||
|
status_message))
|
||||||
|
self.assertTrue([
|
||||||
|
x in (errors.OK, errors.INVALID_ARGUMENT) for x in status_code_values
|
||||||
|
])
|
||||||
|
expected_message_values = np.where(
|
||||||
|
status_code_values == errors.INVALID_ARGUMENT,
|
||||||
|
I_WARNED_YOU.encode('ascii'), b'')
|
||||||
|
self.assertAllEqual(expected_message_values, status_message_values)
|
||||||
|
|
||||||
|
def testVecHostPortRpc(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
request_tensors = [
|
||||||
|
test_example_pb2.TestCase(
|
||||||
|
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||||
|
]
|
||||||
|
response_tensors = self.rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address=self._address,
|
||||||
|
request=request_tensors)
|
||||||
|
self.assertEqual(response_tensors.shape, (20,))
|
||||||
|
response_values = sess.run(response_tensors)
|
||||||
|
self.assertEqual(response_values.shape, (20,))
|
||||||
|
for i in range(20):
|
||||||
|
response_message = test_example_pb2.TestCase()
|
||||||
|
self.assertTrue(response_message.ParseFromString(response_values[i]))
|
||||||
|
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
|
||||||
|
|
||||||
|
def testVecHostPortManyParallelRpcs(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
request_tensors = [
|
||||||
|
test_example_pb2.TestCase(
|
||||||
|
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||||
|
]
|
||||||
|
many_response_tensors = [
|
||||||
|
self.rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address=self._address,
|
||||||
|
request=request_tensors) for _ in range(10)
|
||||||
|
]
|
||||||
|
# Launch parallel 10 calls to the RpcOp, each containing
|
||||||
|
# 20 rpc requests.
|
||||||
|
many_response_values = sess.run(many_response_tensors)
|
||||||
|
self.assertEqual(10, len(many_response_values))
|
||||||
|
for response_values in many_response_values:
|
||||||
|
self.assertEqual(response_values.shape, (20,))
|
||||||
|
for i in range(20):
|
||||||
|
response_message = test_example_pb2.TestCase()
|
||||||
|
self.assertTrue(response_message.ParseFromString(response_values[i]))
|
||||||
|
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
|
||||||
|
|
||||||
|
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
request_tensors = encode_proto_op.encode_proto(
|
||||||
|
message_type='tensorflow.contrib.rpc.TestCase',
|
||||||
|
field_names=['shape'],
|
||||||
|
sizes=[[3]] * 20,
|
||||||
|
values=[
|
||||||
|
[[i, i + 1, i + 2] for i in range(20)],
|
||||||
|
])
|
||||||
|
response_tensor_strings = self.rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address=self._address,
|
||||||
|
request=request_tensors)
|
||||||
|
_, (response_shape,) = decode_proto_op.decode_proto(
|
||||||
|
bytes=response_tensor_strings,
|
||||||
|
message_type='tensorflow.contrib.rpc.TestCase',
|
||||||
|
field_names=['shape'],
|
||||||
|
output_types=[dtypes.int32])
|
||||||
|
response_shape_values = sess.run(response_shape)
|
||||||
|
self.assertAllEqual([[i + 1, i + 2, i + 3]
|
||||||
|
for i in range(20)], response_shape_values)
|
||||||
|
|
||||||
|
def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
request_tensors = [''] * 25 # This will launch 25 RPC requests.
|
||||||
|
response_tensors = self.rpc(
|
||||||
|
method=self.get_method_name('SleepForever'),
|
||||||
|
address=self._address,
|
||||||
|
request=request_tensors)
|
||||||
|
for timeout_ms in [1, 500, 1000]:
|
||||||
|
options = config_pb2.RunOptions(timeout_in_ms=timeout_ms)
|
||||||
|
with self.assertRaises((errors.UnavailableError,
|
||||||
|
errors.DeadlineExceededError)):
|
||||||
|
sess.run(response_tensors, options=options)
|
||||||
|
|
||||||
|
def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
request_tensors = [''] * 25 # This will launch 25 RPC requests.
|
||||||
|
response_tensors = self.rpc(
|
||||||
|
method=self.get_method_name('SleepForever'),
|
||||||
|
address=self._address,
|
||||||
|
timeout_in_ms=1000,
|
||||||
|
request=request_tensors)
|
||||||
|
with self.assertRaises(errors.DeadlineExceededError):
|
||||||
|
sess.run(response_tensors)
|
||||||
|
|
||||||
|
def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
response_tensors, status_code, status_message = self.try_rpc(
|
||||||
|
method=self.get_method_name('SometimesSleepForever'),
|
||||||
|
timeout_in_ms=1000,
|
||||||
|
address=self._address,
|
||||||
|
request=[''] * 20)
|
||||||
|
self.assertEqual(response_tensors.shape, (20,))
|
||||||
|
self.assertEqual(status_code.shape, (20,))
|
||||||
|
self.assertEqual(status_message.shape, (20,))
|
||||||
|
status_code_values = sess.run(status_code)
|
||||||
|
self.assertTrue([
|
||||||
|
x in (errors.OK, errors.DEADLINE_EXCEEDED) for x in status_code_values
|
||||||
|
])
|
||||||
|
|
||||||
|
def testTryRpcWithMultipleAddressesSingleRequest(self):
|
||||||
|
flatten = lambda x: list(itertools.chain.from_iterable(x))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
addresses = flatten([[
|
||||||
|
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
|
||||||
|
] for _ in range(10)])
|
||||||
|
request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
|
||||||
|
response_tensors, status_code, _ = self.try_rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address=addresses,
|
||||||
|
request=request)
|
||||||
|
response_tensors_values, status_code_values = sess.run((response_tensors,
|
||||||
|
status_code))
|
||||||
|
self.assertAllEqual(
|
||||||
|
flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
|
||||||
|
status_code_values)
|
||||||
|
for i in range(10):
|
||||||
|
self.assertTrue(response_tensors_values[2 * i])
|
||||||
|
self.assertFalse(response_tensors_values[2 * i + 1])
|
||||||
|
|
||||||
|
def testTryRpcWithMultipleMethodsSingleRequest(self):
|
||||||
|
flatten = lambda x: list(itertools.chain.from_iterable(x))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
methods = flatten(
|
||||||
|
[[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
|
||||||
|
for _ in range(10)])
|
||||||
|
request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
|
||||||
|
response_tensors, status_code, _ = self.try_rpc(
|
||||||
|
method=methods, address=self._address, request=request)
|
||||||
|
response_tensors_values, status_code_values = sess.run((response_tensors,
|
||||||
|
status_code))
|
||||||
|
self.assertAllEqual(
|
||||||
|
flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),
|
||||||
|
status_code_values)
|
||||||
|
for i in range(10):
|
||||||
|
self.assertTrue(response_tensors_values[2 * i])
|
||||||
|
self.assertFalse(response_tensors_values[2 * i + 1])
|
||||||
|
|
||||||
|
def testTryRpcWithMultipleAddressesAndRequests(self):
|
||||||
|
flatten = lambda x: list(itertools.chain.from_iterable(x))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
addresses = flatten([[
|
||||||
|
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
|
||||||
|
] for _ in range(10)])
|
||||||
|
requests = [
|
||||||
|
test_example_pb2.TestCase(
|
||||||
|
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||||
|
]
|
||||||
|
response_tensors, status_code, _ = self.try_rpc(
|
||||||
|
method=self.get_method_name('IncrementTestShapes'),
|
||||||
|
address=addresses,
|
||||||
|
request=requests)
|
||||||
|
response_tensors_values, status_code_values = sess.run((response_tensors,
|
||||||
|
status_code))
|
||||||
|
self.assertAllEqual(
|
||||||
|
flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
|
||||||
|
status_code_values)
|
||||||
|
for i in range(20):
|
||||||
|
if i % 2 == 1:
|
||||||
|
self.assertFalse(response_tensors_values[i])
|
||||||
|
else:
|
||||||
|
response_message = test_example_pb2.TestCase()
|
||||||
|
self.assertTrue(
|
||||||
|
response_message.ParseFromString(response_tensors_values[i]))
|
||||||
|
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
|
||||||
@ -0,0 +1,101 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
"""Test servicer for RpcOp tests."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
|
||||||
|
from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
|
class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
|
||||||
|
"""Test servicer for RpcOp tests."""
|
||||||
|
|
||||||
|
def IncrementTestShapes(self, request, context):
|
||||||
|
"""Increment the entries in the shape attribute of request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: input TestCase.
|
||||||
|
context: the rpc context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output TestCase.
|
||||||
|
"""
|
||||||
|
for i in range(len(request.shape)):
|
||||||
|
request.shape[i] += 1
|
||||||
|
return request
|
||||||
|
|
||||||
|
def AlwaysFailWithInvalidArgument(self, request, context):
|
||||||
|
"""Always fails with an InvalidArgument status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: input TestCase.
|
||||||
|
context: the rpc context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output TestCase.
|
||||||
|
"""
|
||||||
|
del request
|
||||||
|
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||||
|
context.set_details(rpc_op_test_base.I_WARNED_YOU)
|
||||||
|
|
||||||
|
def SometimesFailWithInvalidArgument(self, request, context):
|
||||||
|
"""Sometimes fails with an InvalidArgument status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: input TestCase.
|
||||||
|
context: the rpc context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output TestCase.
|
||||||
|
"""
|
||||||
|
if random.randint(0, 1) == 1:
|
||||||
|
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||||
|
context.set_details(rpc_op_test_base.I_WARNED_YOU)
|
||||||
|
return request
|
||||||
|
|
||||||
|
def SleepForever(self, request, context):
|
||||||
|
"""Sleeps forever.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: input TestCase.
|
||||||
|
context: the rpc context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output TestCase.
|
||||||
|
"""
|
||||||
|
# TODO(ebrevdo): Make this async wait like the stubby version.
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
def SometimesSleepForever(self, request, context):
|
||||||
|
"""Sometimes sleeps forever.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: input TestCase.
|
||||||
|
context: the rpc context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output TestCase.
|
||||||
|
"""
|
||||||
|
if random.randint(0, 1) == 1:
|
||||||
|
time.sleep(5)
|
||||||
|
return request
|
||||||
171
tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
Normal file
171
tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
// Test description and protos to work with it.
|
||||||
|
//
|
||||||
|
// Many of the protos in this file are for unit tests that haven't been written yet.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
import "tensorflow/core/framework/types.proto";
|
||||||
|
|
||||||
|
package tensorflow.contrib.rpc;
|
||||||
|
|
||||||
|
// A TestCase holds a proto and a bunch of assertions
|
||||||
|
// about how it should decode.
|
||||||
|
message TestCase {
|
||||||
|
// A batch of primitives to be serialized and decoded.
|
||||||
|
repeated RepeatedPrimitiveValue primitive = 1;
|
||||||
|
// The shape of the batch.
|
||||||
|
repeated int32 shape = 2;
|
||||||
|
// Expected sizes for each field.
|
||||||
|
repeated int32 sizes = 3;
|
||||||
|
// Expected values for each field.
|
||||||
|
repeated FieldSpec field = 4;
|
||||||
|
};
|
||||||
|
|
||||||
|
service TestCaseService {
|
||||||
|
// Copy input, and increment each entry in 'shape' by 1.
|
||||||
|
rpc IncrementTestShapes(TestCase) returns (TestCase) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sleep forever.
|
||||||
|
rpc SleepForever(TestCase) returns (TestCase) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sleep forever 50% of the time, return immediately the other 50%.
|
||||||
|
rpc SometimesSleepForever(TestCase) returns (TestCase) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always fails with InvalidArgument.
|
||||||
|
rpc AlwaysFailWithInvalidArgument(TestCase) returns (TestCase) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fails with InvalidArgument 50% of the time.
|
||||||
|
rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// FieldSpec describes the expected output for a single field.
|
||||||
|
message FieldSpec {
|
||||||
|
optional string name = 1;
|
||||||
|
optional tensorflow.DataType dtype = 2;
|
||||||
|
optional RepeatedPrimitiveValue expected = 3;
|
||||||
|
};
|
||||||
|
|
||||||
|
message TestValue {
|
||||||
|
optional PrimitiveValue primitive_value = 1;
|
||||||
|
optional EnumValue enum_value = 2;
|
||||||
|
optional MessageValue message_value = 3;
|
||||||
|
optional RepeatedMessageValue repeated_message_value = 4;
|
||||||
|
optional RepeatedPrimitiveValue repeated_primitive_value = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PrimitiveValue {
|
||||||
|
optional double double_value = 1;
|
||||||
|
optional float float_value = 2;
|
||||||
|
optional int64 int64_value = 3;
|
||||||
|
optional uint64 uint64_value = 4;
|
||||||
|
optional int32 int32_value = 5;
|
||||||
|
optional fixed64 fixed64_value = 6;
|
||||||
|
optional fixed32 fixed32_value = 7;
|
||||||
|
optional bool bool_value = 8;
|
||||||
|
optional string string_value = 9;
|
||||||
|
optional bytes bytes_value = 12;
|
||||||
|
optional uint32 uint32_value = 13;
|
||||||
|
optional sfixed32 sfixed32_value = 15;
|
||||||
|
optional sfixed64 sfixed64_value = 16;
|
||||||
|
optional sint32 sint32_value = 17;
|
||||||
|
optional sint64 sint64_value = 18;
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
|
||||||
|
message RepeatedPrimitiveValue {
|
||||||
|
repeated double double_value = 1;
|
||||||
|
repeated float float_value = 2;
|
||||||
|
repeated int64 int64_value = 3;
|
||||||
|
repeated uint64 uint64_value = 4;
|
||||||
|
repeated int32 int32_value = 5;
|
||||||
|
repeated fixed64 fixed64_value = 6;
|
||||||
|
repeated fixed32 fixed32_value = 7;
|
||||||
|
repeated bool bool_value = 8;
|
||||||
|
repeated string string_value = 9;
|
||||||
|
repeated bytes bytes_value = 12;
|
||||||
|
repeated uint32 uint32_value = 13;
|
||||||
|
repeated sfixed32 sfixed32_value = 15;
|
||||||
|
repeated sfixed64 sfixed64_value = 16;
|
||||||
|
repeated sint32 sint32_value = 17;
|
||||||
|
repeated sint64 sint64_value = 18;
|
||||||
|
repeated PrimitiveValue message_value = 19;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
|
||||||
|
// in the text format, but the binary serializion is different.
|
||||||
|
// We test the packed representations by loading the same test cases
|
||||||
|
// using this definition instead of RepeatedPrimitiveValue.
|
||||||
|
// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
|
||||||
|
// in every way except the packed=true declaration.
|
||||||
|
message PackedPrimitiveValue {
|
||||||
|
repeated double double_value = 1 [packed = true];
|
||||||
|
repeated float float_value = 2 [packed = true];
|
||||||
|
repeated int64 int64_value = 3 [packed = true];
|
||||||
|
repeated uint64 uint64_value = 4 [packed = true];
|
||||||
|
repeated int32 int32_value = 5 [packed = true];
|
||||||
|
repeated fixed64 fixed64_value = 6 [packed = true];
|
||||||
|
repeated fixed32 fixed32_value = 7 [packed = true];
|
||||||
|
repeated bool bool_value = 8 [packed = true];
|
||||||
|
repeated string string_value = 9;
|
||||||
|
repeated bytes bytes_value = 12;
|
||||||
|
repeated uint32 uint32_value = 13 [packed = true];
|
||||||
|
repeated sfixed32 sfixed32_value = 15 [packed = true];
|
||||||
|
repeated sfixed64 sfixed64_value = 16 [packed = true];
|
||||||
|
repeated sint32 sint32_value = 17 [packed = true];
|
||||||
|
repeated sint64 sint64_value = 18 [packed = true];
|
||||||
|
repeated PrimitiveValue message_value = 19;
|
||||||
|
}
|
||||||
|
|
||||||
|
message EnumValue {
|
||||||
|
enum Color {
|
||||||
|
RED = 0;
|
||||||
|
ORANGE = 1;
|
||||||
|
YELLOW = 2;
|
||||||
|
GREEN = 3;
|
||||||
|
BLUE = 4;
|
||||||
|
INDIGO = 5;
|
||||||
|
VIOLET = 6;
|
||||||
|
};
|
||||||
|
optional Color enum_value = 14;
|
||||||
|
repeated Color repeated_enum_value = 15;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message InnerMessageValue {
|
||||||
|
optional float float_value = 2;
|
||||||
|
repeated bytes bytes_values = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message MiddleMessageValue {
|
||||||
|
repeated int32 int32_values = 5;
|
||||||
|
optional InnerMessageValue message_value = 11;
|
||||||
|
optional uint32 uint32_value = 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
message MessageValue {
|
||||||
|
optional double double_value = 1;
|
||||||
|
optional MiddleMessageValue message_value = 11;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RepeatedMessageValue {
|
||||||
|
message NestedMessageValue {
|
||||||
|
optional float float_value = 2;
|
||||||
|
repeated bytes bytes_values = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
repeated NestedMessageValue message_values = 11;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message containing fields with field numbers higher than any field above. An
|
||||||
|
// instance of this message is prepended to each binary message in the test to
|
||||||
|
// exercise the code path that handles fields encoded out of order of field
|
||||||
|
// number.
|
||||||
|
message ExtraFields {
|
||||||
|
optional string string_value = 1776;
|
||||||
|
optional bool bool_value = 1777;
|
||||||
|
}
|
||||||
@ -1,7 +1,6 @@
|
|||||||
# Platform-specific build configurations.
|
# Platform-specific build configurations.
|
||||||
|
|
||||||
load("@protobuf_archive//:protobuf.bzl", "proto_gen")
|
load("@protobuf_archive//:protobuf.bzl", "proto_gen")
|
||||||
load("@protobuf_archive//:protobuf.bzl", "py_proto_library")
|
|
||||||
load("//tensorflow:tensorflow.bzl", "if_not_mobile")
|
load("//tensorflow:tensorflow.bzl", "if_not_mobile")
|
||||||
load("//tensorflow:tensorflow.bzl", "if_windows")
|
load("//tensorflow:tensorflow.bzl", "if_windows")
|
||||||
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||||
@ -110,6 +109,12 @@ def _proto_cc_srcs(srcs, use_grpc_plugin=False):
|
|||||||
ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
|
ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def _proto_py_outs(srcs, use_grpc_plugin=False):
|
||||||
|
ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
|
||||||
|
if use_grpc_plugin:
|
||||||
|
ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
|
||||||
|
return ret
|
||||||
|
|
||||||
# Re-defined protocol buffer rule to allow building "header only" protocol
|
# Re-defined protocol buffer rule to allow building "header only" protocol
|
||||||
# buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs
|
# buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs
|
||||||
# containing select() statements.
|
# containing select() statements.
|
||||||
@ -212,6 +217,80 @@ def cc_proto_library(
|
|||||||
hdrs=gen_hdrs,
|
hdrs=gen_hdrs,
|
||||||
**kargs)
|
**kargs)
|
||||||
|
|
||||||
|
# Re-defined protocol buffer rule to bring in the change introduced in commit
|
||||||
|
# https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68
|
||||||
|
# which was not part of a stable protobuf release in 04/2018.
|
||||||
|
# TODO(jsimsa): Remove this once the protobuf dependency version is updated
|
||||||
|
# to include the above commit.
|
||||||
|
def py_proto_library(
|
||||||
|
name,
|
||||||
|
srcs=[],
|
||||||
|
deps=[],
|
||||||
|
py_libs=[],
|
||||||
|
py_extra_srcs=[],
|
||||||
|
include=None,
|
||||||
|
default_runtime="@protobuf_archive//:protobuf_python",
|
||||||
|
protoc="@protobuf_archive//:protoc",
|
||||||
|
use_grpc_plugin=False,
|
||||||
|
**kargs):
|
||||||
|
"""Bazel rule to create a Python protobuf library from proto source files
|
||||||
|
|
||||||
|
NOTE: the rule is only an internal workaround to generate protos. The
|
||||||
|
interface may change and the rule may be removed when bazel has introduced
|
||||||
|
the native rule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: the name of the py_proto_library.
|
||||||
|
srcs: the .proto files of the py_proto_library.
|
||||||
|
deps: a list of dependency labels; must be py_proto_library.
|
||||||
|
py_libs: a list of other py_library targets depended by the generated
|
||||||
|
py_library.
|
||||||
|
py_extra_srcs: extra source files that will be added to the output
|
||||||
|
py_library. This attribute is used for internal bootstrapping.
|
||||||
|
include: a string indicating the include path of the .proto files.
|
||||||
|
default_runtime: the implicitly default runtime which will be depended on by
|
||||||
|
the generated py_library target.
|
||||||
|
protoc: the label of the protocol compiler to generate the sources.
|
||||||
|
use_grpc_plugin: a flag to indicate whether to call the Python C++ plugin
|
||||||
|
when processing the proto files.
|
||||||
|
**kargs: other keyword arguments that are passed to cc_library.
|
||||||
|
"""
|
||||||
|
outs = _proto_py_outs(srcs, use_grpc_plugin)
|
||||||
|
|
||||||
|
includes = []
|
||||||
|
if include != None:
|
||||||
|
includes = [include]
|
||||||
|
|
||||||
|
grpc_python_plugin = None
|
||||||
|
if use_grpc_plugin:
|
||||||
|
grpc_python_plugin = "//external:grpc_python_plugin"
|
||||||
|
# Note: Generated grpc code depends on Python grpc module. This dependency
|
||||||
|
# is not explicitly listed in py_libs. Instead, host system is assumed to
|
||||||
|
# have grpc installed.
|
||||||
|
|
||||||
|
proto_gen(
|
||||||
|
name=name + "_genproto",
|
||||||
|
srcs=srcs,
|
||||||
|
deps=[s + "_genproto" for s in deps],
|
||||||
|
includes=includes,
|
||||||
|
protoc=protoc,
|
||||||
|
gen_py=1,
|
||||||
|
outs=outs,
|
||||||
|
visibility=["//visibility:public"],
|
||||||
|
plugin=grpc_python_plugin,
|
||||||
|
plugin_language="grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
if default_runtime and not default_runtime in py_libs + deps:
|
||||||
|
py_libs = py_libs + [default_runtime]
|
||||||
|
|
||||||
|
native.py_library(
|
||||||
|
name=name,
|
||||||
|
srcs=outs+py_extra_srcs,
|
||||||
|
deps=py_libs+deps,
|
||||||
|
imports=includes,
|
||||||
|
**kargs)
|
||||||
|
|
||||||
def tf_proto_library_cc(name, srcs = [], has_services = None,
|
def tf_proto_library_cc(name, srcs = [], has_services = None,
|
||||||
protodeps = [],
|
protodeps = [],
|
||||||
visibility = [], testonly = 0,
|
visibility = [], testonly = 0,
|
||||||
@ -256,8 +335,7 @@ def tf_proto_library_cc(name, srcs = [], has_services = None,
|
|||||||
)
|
)
|
||||||
|
|
||||||
def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
|
def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
|
||||||
testonly=0,
|
testonly=0, srcs_version="PY2AND3", use_grpc_plugin=False):
|
||||||
srcs_version="PY2AND3"):
|
|
||||||
py_proto_library(
|
py_proto_library(
|
||||||
name = name + "_py",
|
name = name + "_py",
|
||||||
srcs = srcs,
|
srcs = srcs,
|
||||||
@ -267,6 +345,7 @@ def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
|
|||||||
default_runtime = "@protobuf_archive//:protobuf_python",
|
default_runtime = "@protobuf_archive//:protobuf_python",
|
||||||
visibility = visibility,
|
visibility = visibility,
|
||||||
testonly = testonly,
|
testonly = testonly,
|
||||||
|
use_grpc_plugin = use_grpc_plugin,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tf_jspb_proto_library(**kwargs):
|
def tf_jspb_proto_library(**kwargs):
|
||||||
@ -305,6 +384,7 @@ def tf_proto_library(name, srcs = [], has_services = None,
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
testonly = testonly,
|
testonly = testonly,
|
||||||
visibility = visibility,
|
visibility = visibility,
|
||||||
|
use_grpc_plugin = has_services,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tf_additional_lib_hdrs(exclude = []):
|
def tf_additional_lib_hdrs(exclude = []):
|
||||||
|
|||||||
@ -76,6 +76,7 @@ COMMON_PIP_DEPS = [
|
|||||||
"//tensorflow/contrib/predictor:predictor_pip",
|
"//tensorflow/contrib/predictor:predictor_pip",
|
||||||
"//tensorflow/contrib/proto:proto_pip",
|
"//tensorflow/contrib/proto:proto_pip",
|
||||||
"//tensorflow/contrib/receptive_field:receptive_field_pip",
|
"//tensorflow/contrib/receptive_field:receptive_field_pip",
|
||||||
|
"//tensorflow/contrib/rpc:rpc_pip",
|
||||||
"//tensorflow/contrib/session_bundle:session_bundle_pip",
|
"//tensorflow/contrib/session_bundle:session_bundle_pip",
|
||||||
"//tensorflow/contrib/signal:signal_py",
|
"//tensorflow/contrib/signal:signal_py",
|
||||||
"//tensorflow/contrib/signal:test_util",
|
"//tensorflow/contrib/signal:test_util",
|
||||||
|
|||||||
@ -752,6 +752,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
|||||||
name = "grpc_cpp_plugin",
|
name = "grpc_cpp_plugin",
|
||||||
actual = "@grpc//:grpc_cpp_plugin",
|
actual = "@grpc//:grpc_cpp_plugin",
|
||||||
)
|
)
|
||||||
|
native.bind(
|
||||||
|
name = "grpc_python_plugin",
|
||||||
|
actual = "@grpc//:grpc_python_plugin",
|
||||||
|
)
|
||||||
|
|
||||||
# gRPC has three empty C++ functions which it wants the user to define
|
# gRPC has three empty C++ functions which it wants the user to define
|
||||||
# at build time. https://github.com/grpc/grpc/issues/13590
|
# at build time. https://github.com/grpc/grpc/issues/13590
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user