Cleaning up test proto for tensorflow/contrib/rpc.

PiperOrigin-RevId: 204307008
This commit is contained in:
Jiri Simsa 2018-07-12 08:54:19 -07:00 committed by TensorFlower Gardener
parent 34a1b6780b
commit 0ca8c47bfe
4 changed files with 34 additions and 176 deletions

View File

@ -1,5 +1,3 @@
# TODO(b/76425722): Port everything in here to OS (currently excluded).
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
@ -17,7 +15,6 @@ tf_proto_library(
srcs = ["test_example.proto"],
has_services = 1,
cc_api_version = 2,
protodeps = ["//tensorflow/core:protos_all"],
)
py_library(

View File

@ -51,23 +51,23 @@ class RpcOpTestBase(object):
def testScalarHostPortRpc(self):
with self.test_session() as sess:
request_tensors = (
test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors = self.rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(response_tensors.shape, ())
response_values = sess.run(response_tensors)
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values))
self.assertAllEqual([2, 3, 4], response_message.shape)
self.assertAllEqual([2, 3, 4], response_message.values)
def testScalarHostPortTryRpc(self):
with self.test_session() as sess:
request_tensors = (
test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors, status_code, status_message = self.try_rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(status_code.shape, ())
@ -77,7 +77,7 @@ class RpcOpTestBase(object):
sess.run((response_tensors, status_code, status_message)))
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values))
self.assertAllEqual([2, 3, 4], response_message.shape)
self.assertAllEqual([2, 3, 4], response_message.values)
# For the base Rpc op, don't expect to get error status back.
self.assertEqual(errors.OK, status_code_values)
self.assertEqual(b'', status_message_values)
@ -86,7 +86,7 @@ class RpcOpTestBase(object):
with self.test_session() as sess:
request_tensors = []
response_tensors = self.rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertAllEqual(response_tensors.shape, [0])
@ -95,7 +95,7 @@ class RpcOpTestBase(object):
def testInvalidMethod(self):
for method in [
'/InvalidService.IncrementTestShapes',
'/InvalidService.Increment',
self.get_method_name('InvalidMethodName')
]:
with self.test_session() as sess:
@ -115,12 +115,12 @@ class RpcOpTestBase(object):
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=address,
request=''))
_, status_code_value, status_message_value = sess.run(
self.try_rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=address,
request=''))
self.assertEqual(errors.UNAVAILABLE, status_code_value)
@ -182,10 +182,10 @@ class RpcOpTestBase(object):
with self.test_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
response_tensors = self.rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
self.assertEqual(response_tensors.shape, (20,))
@ -194,17 +194,17 @@ class RpcOpTestBase(object):
for i in range(20):
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values[i]))
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortManyParallelRpcs(self):
with self.test_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
many_response_tensors = [
self.rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors) for _ in range(10)
]
@ -216,25 +216,25 @@ class RpcOpTestBase(object):
for i in range(20):
response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values[i]))
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
with self.test_session() as sess:
request_tensors = encode_proto_op.encode_proto(
message_type='tensorflow.contrib.rpc.TestCase',
field_names=['shape'],
field_names=['values'],
sizes=[[3]] * 20,
values=[
[[i, i + 1, i + 2] for i in range(20)],
])
response_tensor_strings = self.rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=self._address,
request=request_tensors)
_, (response_shape,) = decode_proto_op.decode_proto(
bytes=response_tensor_strings,
message_type='tensorflow.contrib.rpc.TestCase',
field_names=['shape'],
field_names=['values'],
output_types=[dtypes.int32])
response_shape_values = sess.run(response_shape)
self.assertAllEqual([[i + 1, i + 2, i + 3]
@ -285,9 +285,9 @@ class RpcOpTestBase(object):
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
response_tensors, status_code, _ = self.try_rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=addresses,
request=request)
response_tensors_values, status_code_values = sess.run((response_tensors,
@ -303,9 +303,9 @@ class RpcOpTestBase(object):
flatten = lambda x: list(itertools.chain.from_iterable(x))
with self.test_session() as sess:
methods = flatten(
[[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
[[self.get_method_name('Increment'), 'InvalidMethodName']
for _ in range(10)])
request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
response_tensors, status_code, _ = self.try_rpc(
method=methods, address=self._address, request=request)
response_tensors_values, status_code_values = sess.run((response_tensors,
@ -325,10 +325,10 @@ class RpcOpTestBase(object):
] for _ in range(10)])
requests = [
test_example_pb2.TestCase(
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
]
response_tensors, status_code, _ = self.try_rpc(
method=self.get_method_name('IncrementTestShapes'),
method=self.get_method_name('Increment'),
address=addresses,
request=requests)
response_tensors_values, status_code_values = sess.run((response_tensors,
@ -343,4 +343,4 @@ class RpcOpTestBase(object):
response_message = test_example_pb2.TestCase()
self.assertTrue(
response_message.ParseFromString(response_tensors_values[i]))
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)

View File

@ -30,8 +30,8 @@ from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
"""Test servicer for RpcOp tests."""
def IncrementTestShapes(self, request, context):
"""Increment the entries in the shape attribute of request.
def Increment(self, request, context):
"""Increment the entries in the `values` attribute of request.
Args:
request: input TestCase.
@ -40,8 +40,8 @@ class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
Returns:
output TestCase.
"""
for i in range(len(request.shape)):
request.shape[i] += 1
for i in range(len(request.values)):
request.values[i] += 1
return request
def AlwaysFailWithInvalidArgument(self, request, context):

View File

@ -1,29 +1,17 @@
// Test description and protos to work with it.
//
// Many of the protos in this file are for unit tests that haven't been written yet.
syntax = "proto2";
import "tensorflow/core/framework/types.proto";
package tensorflow.contrib.rpc;
// A TestCase holds a proto and a bunch of assertions
// about how it should decode.
// A TestCase holds a sequence of values.
message TestCase {
// A batch of primitives to be serialized and decoded.
repeated RepeatedPrimitiveValue primitive = 1;
// The shape of the batch.
repeated int32 shape = 2;
// Expected sizes for each field.
repeated int32 sizes = 3;
// Expected values for each field.
repeated FieldSpec field = 4;
repeated int32 values = 1;
};
service TestCaseService {
// Copy input, and increment each entry in 'shape' by 1.
rpc IncrementTestShapes(TestCase) returns (TestCase) {
// Copy input, and increment each entry in 'values' by 1.
rpc Increment(TestCase) returns (TestCase) {
}
// Sleep forever.
@ -42,130 +30,3 @@ service TestCaseService {
rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
}
};
// FieldSpec describes the expected output for a single field.
message FieldSpec {
optional string name = 1;
optional tensorflow.DataType dtype = 2;
optional RepeatedPrimitiveValue expected = 3;
};
message TestValue {
optional PrimitiveValue primitive_value = 1;
optional EnumValue enum_value = 2;
optional MessageValue message_value = 3;
optional RepeatedMessageValue repeated_message_value = 4;
optional RepeatedPrimitiveValue repeated_primitive_value = 6;
}
message PrimitiveValue {
optional double double_value = 1;
optional float float_value = 2;
optional int64 int64_value = 3;
optional uint64 uint64_value = 4;
optional int32 int32_value = 5;
optional fixed64 fixed64_value = 6;
optional fixed32 fixed32_value = 7;
optional bool bool_value = 8;
optional string string_value = 9;
optional bytes bytes_value = 12;
optional uint32 uint32_value = 13;
optional sfixed32 sfixed32_value = 15;
optional sfixed64 sfixed64_value = 16;
optional sint32 sint32_value = 17;
optional sint64 sint64_value = 18;
}
// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
message RepeatedPrimitiveValue {
repeated double double_value = 1;
repeated float float_value = 2;
repeated int64 int64_value = 3;
repeated uint64 uint64_value = 4;
repeated int32 int32_value = 5;
repeated fixed64 fixed64_value = 6;
repeated fixed32 fixed32_value = 7;
repeated bool bool_value = 8;
repeated string string_value = 9;
repeated bytes bytes_value = 12;
repeated uint32 uint32_value = 13;
repeated sfixed32 sfixed32_value = 15;
repeated sfixed64 sfixed64_value = 16;
repeated sint32 sint32_value = 17;
repeated sint64 sint64_value = 18;
repeated PrimitiveValue message_value = 19;
}
// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
// in the text format, but the binary serializion is different.
// We test the packed representations by loading the same test cases
// using this definition instead of RepeatedPrimitiveValue.
// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
// in every way except the packed=true declaration.
message PackedPrimitiveValue {
repeated double double_value = 1 [packed = true];
repeated float float_value = 2 [packed = true];
repeated int64 int64_value = 3 [packed = true];
repeated uint64 uint64_value = 4 [packed = true];
repeated int32 int32_value = 5 [packed = true];
repeated fixed64 fixed64_value = 6 [packed = true];
repeated fixed32 fixed32_value = 7 [packed = true];
repeated bool bool_value = 8 [packed = true];
repeated string string_value = 9;
repeated bytes bytes_value = 12;
repeated uint32 uint32_value = 13 [packed = true];
repeated sfixed32 sfixed32_value = 15 [packed = true];
repeated sfixed64 sfixed64_value = 16 [packed = true];
repeated sint32 sint32_value = 17 [packed = true];
repeated sint64 sint64_value = 18 [packed = true];
repeated PrimitiveValue message_value = 19;
}
message EnumValue {
enum Color {
RED = 0;
ORANGE = 1;
YELLOW = 2;
GREEN = 3;
BLUE = 4;
INDIGO = 5;
VIOLET = 6;
};
optional Color enum_value = 14;
repeated Color repeated_enum_value = 15;
}
message InnerMessageValue {
optional float float_value = 2;
repeated bytes bytes_values = 8;
}
message MiddleMessageValue {
repeated int32 int32_values = 5;
optional InnerMessageValue message_value = 11;
optional uint32 uint32_value = 13;
}
message MessageValue {
optional double double_value = 1;
optional MiddleMessageValue message_value = 11;
}
message RepeatedMessageValue {
message NestedMessageValue {
optional float float_value = 2;
repeated bytes bytes_values = 8;
}
repeated NestedMessageValue message_values = 11;
}
// Message containing fields with field numbers higher than any field above. An
// instance of this message is prepended to each binary message in the test to
// exercise the code path that handles fields encoded out of order of field
// number.
message ExtraFields {
optional string string_value = 1776;
optional bool bool_value = 1777;
}