Cleaning up test proto for tensorflow/contrib/rpc.

PiperOrigin-RevId: 204307008
This commit is contained in:
Jiri Simsa 2018-07-12 08:54:19 -07:00 committed by TensorFlower Gardener
parent 34a1b6780b
commit 0ca8c47bfe
4 changed files with 34 additions and 176 deletions

View File

@ -1,5 +1,3 @@
# TODO(b/76425722): Port everything in here to OS (currently excluded).
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
@ -17,7 +15,6 @@ tf_proto_library(
srcs = ["test_example.proto"], srcs = ["test_example.proto"],
has_services = 1, has_services = 1,
cc_api_version = 2, cc_api_version = 2,
protodeps = ["//tensorflow/core:protos_all"],
) )
py_library( py_library(

View File

@ -51,23 +51,23 @@ class RpcOpTestBase(object):
def testScalarHostPortRpc(self): def testScalarHostPortRpc(self):
with self.test_session() as sess: with self.test_session() as sess:
request_tensors = ( request_tensors = (
test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString()) test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors = self.rpc( response_tensors = self.rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=self._address, address=self._address,
request=request_tensors) request=request_tensors)
self.assertEqual(response_tensors.shape, ()) self.assertEqual(response_tensors.shape, ())
response_values = sess.run(response_tensors) response_values = sess.run(response_tensors)
response_message = test_example_pb2.TestCase() response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values)) self.assertTrue(response_message.ParseFromString(response_values))
self.assertAllEqual([2, 3, 4], response_message.shape) self.assertAllEqual([2, 3, 4], response_message.values)
def testScalarHostPortTryRpc(self): def testScalarHostPortTryRpc(self):
with self.test_session() as sess: with self.test_session() as sess:
request_tensors = ( request_tensors = (
test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString()) test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors, status_code, status_message = self.try_rpc( response_tensors, status_code, status_message = self.try_rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=self._address, address=self._address,
request=request_tensors) request=request_tensors)
self.assertEqual(status_code.shape, ()) self.assertEqual(status_code.shape, ())
@ -77,7 +77,7 @@ class RpcOpTestBase(object):
sess.run((response_tensors, status_code, status_message))) sess.run((response_tensors, status_code, status_message)))
response_message = test_example_pb2.TestCase() response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values)) self.assertTrue(response_message.ParseFromString(response_values))
self.assertAllEqual([2, 3, 4], response_message.shape) self.assertAllEqual([2, 3, 4], response_message.values)
# For the base Rpc op, don't expect to get error status back. # For the base Rpc op, don't expect to get error status back.
self.assertEqual(errors.OK, status_code_values) self.assertEqual(errors.OK, status_code_values)
self.assertEqual(b'', status_message_values) self.assertEqual(b'', status_message_values)
@ -86,7 +86,7 @@ class RpcOpTestBase(object):
with self.test_session() as sess: with self.test_session() as sess:
request_tensors = [] request_tensors = []
response_tensors = self.rpc( response_tensors = self.rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=self._address, address=self._address,
request=request_tensors) request=request_tensors)
self.assertAllEqual(response_tensors.shape, [0]) self.assertAllEqual(response_tensors.shape, [0])
@ -95,7 +95,7 @@ class RpcOpTestBase(object):
def testInvalidMethod(self): def testInvalidMethod(self):
for method in [ for method in [
'/InvalidService.IncrementTestShapes', '/InvalidService.Increment',
self.get_method_name('InvalidMethodName') self.get_method_name('InvalidMethodName')
]: ]:
with self.test_session() as sess: with self.test_session() as sess:
@ -115,12 +115,12 @@ class RpcOpTestBase(object):
with self.assertRaises(errors.UnavailableError): with self.assertRaises(errors.UnavailableError):
sess.run( sess.run(
self.rpc( self.rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=address, address=address,
request='')) request=''))
_, status_code_value, status_message_value = sess.run( _, status_code_value, status_message_value = sess.run(
self.try_rpc( self.try_rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=address, address=address,
request='')) request=''))
self.assertEqual(errors.UNAVAILABLE, status_code_value) self.assertEqual(errors.UNAVAILABLE, status_code_value)
@ -182,10 +182,10 @@ class RpcOpTestBase(object):
with self.test_session() as sess: with self.test_session() as sess:
request_tensors = [ request_tensors = [
test_example_pb2.TestCase( test_example_pb2.TestCase(
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
] ]
response_tensors = self.rpc( response_tensors = self.rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=self._address, address=self._address,
request=request_tensors) request=request_tensors)
self.assertEqual(response_tensors.shape, (20,)) self.assertEqual(response_tensors.shape, (20,))
@ -194,17 +194,17 @@ class RpcOpTestBase(object):
for i in range(20): for i in range(20):
response_message = test_example_pb2.TestCase() response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values[i])) self.assertTrue(response_message.ParseFromString(response_values[i]))
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortManyParallelRpcs(self): def testVecHostPortManyParallelRpcs(self):
with self.test_session() as sess: with self.test_session() as sess:
request_tensors = [ request_tensors = [
test_example_pb2.TestCase( test_example_pb2.TestCase(
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
] ]
many_response_tensors = [ many_response_tensors = [
self.rpc( self.rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=self._address, address=self._address,
request=request_tensors) for _ in range(10) request=request_tensors) for _ in range(10)
] ]
@ -216,25 +216,25 @@ class RpcOpTestBase(object):
for i in range(20): for i in range(20):
response_message = test_example_pb2.TestCase() response_message = test_example_pb2.TestCase()
self.assertTrue(response_message.ParseFromString(response_values[i])) self.assertTrue(response_message.ParseFromString(response_values[i]))
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self): def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
with self.test_session() as sess: with self.test_session() as sess:
request_tensors = encode_proto_op.encode_proto( request_tensors = encode_proto_op.encode_proto(
message_type='tensorflow.contrib.rpc.TestCase', message_type='tensorflow.contrib.rpc.TestCase',
field_names=['shape'], field_names=['values'],
sizes=[[3]] * 20, sizes=[[3]] * 20,
values=[ values=[
[[i, i + 1, i + 2] for i in range(20)], [[i, i + 1, i + 2] for i in range(20)],
]) ])
response_tensor_strings = self.rpc( response_tensor_strings = self.rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=self._address, address=self._address,
request=request_tensors) request=request_tensors)
_, (response_shape,) = decode_proto_op.decode_proto( _, (response_shape,) = decode_proto_op.decode_proto(
bytes=response_tensor_strings, bytes=response_tensor_strings,
message_type='tensorflow.contrib.rpc.TestCase', message_type='tensorflow.contrib.rpc.TestCase',
field_names=['shape'], field_names=['values'],
output_types=[dtypes.int32]) output_types=[dtypes.int32])
response_shape_values = sess.run(response_shape) response_shape_values = sess.run(response_shape)
self.assertAllEqual([[i + 1, i + 2, i + 3] self.assertAllEqual([[i + 1, i + 2, i + 3]
@ -285,9 +285,9 @@ class RpcOpTestBase(object):
addresses = flatten([[ addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@' self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)]) ] for _ in range(10)])
request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString() request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
response_tensors, status_code, _ = self.try_rpc( response_tensors, status_code, _ = self.try_rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=addresses, address=addresses,
request=request) request=request)
response_tensors_values, status_code_values = sess.run((response_tensors, response_tensors_values, status_code_values = sess.run((response_tensors,
@ -303,9 +303,9 @@ class RpcOpTestBase(object):
flatten = lambda x: list(itertools.chain.from_iterable(x)) flatten = lambda x: list(itertools.chain.from_iterable(x))
with self.test_session() as sess: with self.test_session() as sess:
methods = flatten( methods = flatten(
[[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName'] [[self.get_method_name('Increment'), 'InvalidMethodName']
for _ in range(10)]) for _ in range(10)])
request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString() request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
response_tensors, status_code, _ = self.try_rpc( response_tensors, status_code, _ = self.try_rpc(
method=methods, address=self._address, request=request) method=methods, address=self._address, request=request)
response_tensors_values, status_code_values = sess.run((response_tensors, response_tensors_values, status_code_values = sess.run((response_tensors,
@ -325,10 +325,10 @@ class RpcOpTestBase(object):
] for _ in range(10)]) ] for _ in range(10)])
requests = [ requests = [
test_example_pb2.TestCase( test_example_pb2.TestCase(
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
] ]
response_tensors, status_code, _ = self.try_rpc( response_tensors, status_code, _ = self.try_rpc(
method=self.get_method_name('IncrementTestShapes'), method=self.get_method_name('Increment'),
address=addresses, address=addresses,
request=requests) request=requests)
response_tensors_values, status_code_values = sess.run((response_tensors, response_tensors_values, status_code_values = sess.run((response_tensors,
@ -343,4 +343,4 @@ class RpcOpTestBase(object):
response_message = test_example_pb2.TestCase() response_message = test_example_pb2.TestCase()
self.assertTrue( self.assertTrue(
response_message.ParseFromString(response_tensors_values[i])) response_message.ParseFromString(response_tensors_values[i]))
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)

View File

@ -30,8 +30,8 @@ from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer): class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
"""Test servicer for RpcOp tests.""" """Test servicer for RpcOp tests."""
def IncrementTestShapes(self, request, context): def Increment(self, request, context):
"""Increment the entries in the shape attribute of request. """Increment the entries in the `values` attribute of request.
Args: Args:
request: input TestCase. request: input TestCase.
@ -40,8 +40,8 @@ class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
Returns: Returns:
output TestCase. output TestCase.
""" """
for i in range(len(request.shape)): for i in range(len(request.values)):
request.shape[i] += 1 request.values[i] += 1
return request return request
def AlwaysFailWithInvalidArgument(self, request, context): def AlwaysFailWithInvalidArgument(self, request, context):

View File

@ -1,29 +1,17 @@
// Test description and protos to work with it. // Test description and protos to work with it.
//
// Many of the protos in this file are for unit tests that haven't been written yet.
syntax = "proto2"; syntax = "proto2";
import "tensorflow/core/framework/types.proto";
package tensorflow.contrib.rpc; package tensorflow.contrib.rpc;
// A TestCase holds a proto and a bunch of assertions // A TestCase holds a sequence of values.
// about how it should decode.
message TestCase { message TestCase {
// A batch of primitives to be serialized and decoded. repeated int32 values = 1;
repeated RepeatedPrimitiveValue primitive = 1;
// The shape of the batch.
repeated int32 shape = 2;
// Expected sizes for each field.
repeated int32 sizes = 3;
// Expected values for each field.
repeated FieldSpec field = 4;
}; };
service TestCaseService { service TestCaseService {
// Copy input, and increment each entry in 'shape' by 1. // Copy input, and increment each entry in 'values' by 1.
rpc IncrementTestShapes(TestCase) returns (TestCase) { rpc Increment(TestCase) returns (TestCase) {
} }
// Sleep forever. // Sleep forever.
@ -42,130 +30,3 @@ service TestCaseService {
rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) { rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
} }
}; };
// FieldSpec describes the expected output for a single field.
message FieldSpec {
optional string name = 1;
optional tensorflow.DataType dtype = 2;
optional RepeatedPrimitiveValue expected = 3;
};
message TestValue {
optional PrimitiveValue primitive_value = 1;
optional EnumValue enum_value = 2;
optional MessageValue message_value = 3;
optional RepeatedMessageValue repeated_message_value = 4;
optional RepeatedPrimitiveValue repeated_primitive_value = 6;
}
message PrimitiveValue {
optional double double_value = 1;
optional float float_value = 2;
optional int64 int64_value = 3;
optional uint64 uint64_value = 4;
optional int32 int32_value = 5;
optional fixed64 fixed64_value = 6;
optional fixed32 fixed32_value = 7;
optional bool bool_value = 8;
optional string string_value = 9;
optional bytes bytes_value = 12;
optional uint32 uint32_value = 13;
optional sfixed32 sfixed32_value = 15;
optional sfixed64 sfixed64_value = 16;
optional sint32 sint32_value = 17;
optional sint64 sint64_value = 18;
}
// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
message RepeatedPrimitiveValue {
repeated double double_value = 1;
repeated float float_value = 2;
repeated int64 int64_value = 3;
repeated uint64 uint64_value = 4;
repeated int32 int32_value = 5;
repeated fixed64 fixed64_value = 6;
repeated fixed32 fixed32_value = 7;
repeated bool bool_value = 8;
repeated string string_value = 9;
repeated bytes bytes_value = 12;
repeated uint32 uint32_value = 13;
repeated sfixed32 sfixed32_value = 15;
repeated sfixed64 sfixed64_value = 16;
repeated sint32 sint32_value = 17;
repeated sint64 sint64_value = 18;
repeated PrimitiveValue message_value = 19;
}
// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
// in the text format, but the binary serializion is different.
// We test the packed representations by loading the same test cases
// using this definition instead of RepeatedPrimitiveValue.
// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
// in every way except the packed=true declaration.
message PackedPrimitiveValue {
repeated double double_value = 1 [packed = true];
repeated float float_value = 2 [packed = true];
repeated int64 int64_value = 3 [packed = true];
repeated uint64 uint64_value = 4 [packed = true];
repeated int32 int32_value = 5 [packed = true];
repeated fixed64 fixed64_value = 6 [packed = true];
repeated fixed32 fixed32_value = 7 [packed = true];
repeated bool bool_value = 8 [packed = true];
repeated string string_value = 9;
repeated bytes bytes_value = 12;
repeated uint32 uint32_value = 13 [packed = true];
repeated sfixed32 sfixed32_value = 15 [packed = true];
repeated sfixed64 sfixed64_value = 16 [packed = true];
repeated sint32 sint32_value = 17 [packed = true];
repeated sint64 sint64_value = 18 [packed = true];
repeated PrimitiveValue message_value = 19;
}
message EnumValue {
enum Color {
RED = 0;
ORANGE = 1;
YELLOW = 2;
GREEN = 3;
BLUE = 4;
INDIGO = 5;
VIOLET = 6;
};
optional Color enum_value = 14;
repeated Color repeated_enum_value = 15;
}
message InnerMessageValue {
optional float float_value = 2;
repeated bytes bytes_values = 8;
}
message MiddleMessageValue {
repeated int32 int32_values = 5;
optional InnerMessageValue message_value = 11;
optional uint32 uint32_value = 13;
}
message MessageValue {
optional double double_value = 1;
optional MiddleMessageValue message_value = 11;
}
message RepeatedMessageValue {
message NestedMessageValue {
optional float float_value = 2;
repeated bytes bytes_values = 8;
}
repeated NestedMessageValue message_values = 11;
}
// Message containing fields with field numbers higher than any field above. An
// instance of this message is prepended to each binary message in the test to
// exercise the code path that handles fields encoded out of order of field
// number.
message ExtraFields {
optional string string_value = 1776;
optional bool bool_value = 1777;
}