Cleaning up test proto for tensorflow/contrib/rpc
.
PiperOrigin-RevId: 204307008
This commit is contained in:
parent
34a1b6780b
commit
0ca8c47bfe
@ -1,5 +1,3 @@
|
||||
# TODO(b/76425722): Port everything in here to OS (currently excluded).
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
@ -17,7 +15,6 @@ tf_proto_library(
|
||||
srcs = ["test_example.proto"],
|
||||
has_services = 1,
|
||||
cc_api_version = 2,
|
||||
protodeps = ["//tensorflow/core:protos_all"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -51,23 +51,23 @@ class RpcOpTestBase(object):
|
||||
def testScalarHostPortRpc(self):
|
||||
with self.test_session() as sess:
|
||||
request_tensors = (
|
||||
test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
|
||||
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
|
||||
response_tensors = self.rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=self._address,
|
||||
request=request_tensors)
|
||||
self.assertEqual(response_tensors.shape, ())
|
||||
response_values = sess.run(response_tensors)
|
||||
response_message = test_example_pb2.TestCase()
|
||||
self.assertTrue(response_message.ParseFromString(response_values))
|
||||
self.assertAllEqual([2, 3, 4], response_message.shape)
|
||||
self.assertAllEqual([2, 3, 4], response_message.values)
|
||||
|
||||
def testScalarHostPortTryRpc(self):
|
||||
with self.test_session() as sess:
|
||||
request_tensors = (
|
||||
test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
|
||||
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
|
||||
response_tensors, status_code, status_message = self.try_rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=self._address,
|
||||
request=request_tensors)
|
||||
self.assertEqual(status_code.shape, ())
|
||||
@ -77,7 +77,7 @@ class RpcOpTestBase(object):
|
||||
sess.run((response_tensors, status_code, status_message)))
|
||||
response_message = test_example_pb2.TestCase()
|
||||
self.assertTrue(response_message.ParseFromString(response_values))
|
||||
self.assertAllEqual([2, 3, 4], response_message.shape)
|
||||
self.assertAllEqual([2, 3, 4], response_message.values)
|
||||
# For the base Rpc op, don't expect to get error status back.
|
||||
self.assertEqual(errors.OK, status_code_values)
|
||||
self.assertEqual(b'', status_message_values)
|
||||
@ -86,7 +86,7 @@ class RpcOpTestBase(object):
|
||||
with self.test_session() as sess:
|
||||
request_tensors = []
|
||||
response_tensors = self.rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=self._address,
|
||||
request=request_tensors)
|
||||
self.assertAllEqual(response_tensors.shape, [0])
|
||||
@ -95,7 +95,7 @@ class RpcOpTestBase(object):
|
||||
|
||||
def testInvalidMethod(self):
|
||||
for method in [
|
||||
'/InvalidService.IncrementTestShapes',
|
||||
'/InvalidService.Increment',
|
||||
self.get_method_name('InvalidMethodName')
|
||||
]:
|
||||
with self.test_session() as sess:
|
||||
@ -115,12 +115,12 @@ class RpcOpTestBase(object):
|
||||
with self.assertRaises(errors.UnavailableError):
|
||||
sess.run(
|
||||
self.rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=address,
|
||||
request=''))
|
||||
_, status_code_value, status_message_value = sess.run(
|
||||
self.try_rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=address,
|
||||
request=''))
|
||||
self.assertEqual(errors.UNAVAILABLE, status_code_value)
|
||||
@ -182,10 +182,10 @@ class RpcOpTestBase(object):
|
||||
with self.test_session() as sess:
|
||||
request_tensors = [
|
||||
test_example_pb2.TestCase(
|
||||
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||
]
|
||||
response_tensors = self.rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=self._address,
|
||||
request=request_tensors)
|
||||
self.assertEqual(response_tensors.shape, (20,))
|
||||
@ -194,17 +194,17 @@ class RpcOpTestBase(object):
|
||||
for i in range(20):
|
||||
response_message = test_example_pb2.TestCase()
|
||||
self.assertTrue(response_message.ParseFromString(response_values[i]))
|
||||
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
|
||||
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
|
||||
|
||||
def testVecHostPortManyParallelRpcs(self):
|
||||
with self.test_session() as sess:
|
||||
request_tensors = [
|
||||
test_example_pb2.TestCase(
|
||||
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||
]
|
||||
many_response_tensors = [
|
||||
self.rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=self._address,
|
||||
request=request_tensors) for _ in range(10)
|
||||
]
|
||||
@ -216,25 +216,25 @@ class RpcOpTestBase(object):
|
||||
for i in range(20):
|
||||
response_message = test_example_pb2.TestCase()
|
||||
self.assertTrue(response_message.ParseFromString(response_values[i]))
|
||||
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
|
||||
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
|
||||
|
||||
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
|
||||
with self.test_session() as sess:
|
||||
request_tensors = encode_proto_op.encode_proto(
|
||||
message_type='tensorflow.contrib.rpc.TestCase',
|
||||
field_names=['shape'],
|
||||
field_names=['values'],
|
||||
sizes=[[3]] * 20,
|
||||
values=[
|
||||
[[i, i + 1, i + 2] for i in range(20)],
|
||||
])
|
||||
response_tensor_strings = self.rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=self._address,
|
||||
request=request_tensors)
|
||||
_, (response_shape,) = decode_proto_op.decode_proto(
|
||||
bytes=response_tensor_strings,
|
||||
message_type='tensorflow.contrib.rpc.TestCase',
|
||||
field_names=['shape'],
|
||||
field_names=['values'],
|
||||
output_types=[dtypes.int32])
|
||||
response_shape_values = sess.run(response_shape)
|
||||
self.assertAllEqual([[i + 1, i + 2, i + 3]
|
||||
@ -285,9 +285,9 @@ class RpcOpTestBase(object):
|
||||
addresses = flatten([[
|
||||
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
|
||||
] for _ in range(10)])
|
||||
request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
|
||||
request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
|
||||
response_tensors, status_code, _ = self.try_rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=addresses,
|
||||
request=request)
|
||||
response_tensors_values, status_code_values = sess.run((response_tensors,
|
||||
@ -303,9 +303,9 @@ class RpcOpTestBase(object):
|
||||
flatten = lambda x: list(itertools.chain.from_iterable(x))
|
||||
with self.test_session() as sess:
|
||||
methods = flatten(
|
||||
[[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
|
||||
[[self.get_method_name('Increment'), 'InvalidMethodName']
|
||||
for _ in range(10)])
|
||||
request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
|
||||
request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
|
||||
response_tensors, status_code, _ = self.try_rpc(
|
||||
method=methods, address=self._address, request=request)
|
||||
response_tensors_values, status_code_values = sess.run((response_tensors,
|
||||
@ -325,10 +325,10 @@ class RpcOpTestBase(object):
|
||||
] for _ in range(10)])
|
||||
requests = [
|
||||
test_example_pb2.TestCase(
|
||||
shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
|
||||
]
|
||||
response_tensors, status_code, _ = self.try_rpc(
|
||||
method=self.get_method_name('IncrementTestShapes'),
|
||||
method=self.get_method_name('Increment'),
|
||||
address=addresses,
|
||||
request=requests)
|
||||
response_tensors_values, status_code_values = sess.run((response_tensors,
|
||||
@ -343,4 +343,4 @@ class RpcOpTestBase(object):
|
||||
response_message = test_example_pb2.TestCase()
|
||||
self.assertTrue(
|
||||
response_message.ParseFromString(response_tensors_values[i]))
|
||||
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
|
||||
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
|
||||
|
@ -30,8 +30,8 @@ from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
|
||||
class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
|
||||
"""Test servicer for RpcOp tests."""
|
||||
|
||||
def IncrementTestShapes(self, request, context):
|
||||
"""Increment the entries in the shape attribute of request.
|
||||
def Increment(self, request, context):
|
||||
"""Increment the entries in the `values` attribute of request.
|
||||
|
||||
Args:
|
||||
request: input TestCase.
|
||||
@ -40,8 +40,8 @@ class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
|
||||
Returns:
|
||||
output TestCase.
|
||||
"""
|
||||
for i in range(len(request.shape)):
|
||||
request.shape[i] += 1
|
||||
for i in range(len(request.values)):
|
||||
request.values[i] += 1
|
||||
return request
|
||||
|
||||
def AlwaysFailWithInvalidArgument(self, request, context):
|
||||
|
@ -1,29 +1,17 @@
|
||||
// Test description and protos to work with it.
|
||||
//
|
||||
// Many of the protos in this file are for unit tests that haven't been written yet.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
package tensorflow.contrib.rpc;
|
||||
|
||||
// A TestCase holds a proto and a bunch of assertions
|
||||
// about how it should decode.
|
||||
// A TestCase holds a sequence of values.
|
||||
message TestCase {
|
||||
// A batch of primitives to be serialized and decoded.
|
||||
repeated RepeatedPrimitiveValue primitive = 1;
|
||||
// The shape of the batch.
|
||||
repeated int32 shape = 2;
|
||||
// Expected sizes for each field.
|
||||
repeated int32 sizes = 3;
|
||||
// Expected values for each field.
|
||||
repeated FieldSpec field = 4;
|
||||
repeated int32 values = 1;
|
||||
};
|
||||
|
||||
service TestCaseService {
|
||||
// Copy input, and increment each entry in 'shape' by 1.
|
||||
rpc IncrementTestShapes(TestCase) returns (TestCase) {
|
||||
// Copy input, and increment each entry in 'values' by 1.
|
||||
rpc Increment(TestCase) returns (TestCase) {
|
||||
}
|
||||
|
||||
// Sleep forever.
|
||||
@ -42,130 +30,3 @@ service TestCaseService {
|
||||
rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
|
||||
}
|
||||
};
|
||||
|
||||
// FieldSpec describes the expected output for a single field.
|
||||
message FieldSpec {
|
||||
optional string name = 1;
|
||||
optional tensorflow.DataType dtype = 2;
|
||||
optional RepeatedPrimitiveValue expected = 3;
|
||||
};
|
||||
|
||||
message TestValue {
|
||||
optional PrimitiveValue primitive_value = 1;
|
||||
optional EnumValue enum_value = 2;
|
||||
optional MessageValue message_value = 3;
|
||||
optional RepeatedMessageValue repeated_message_value = 4;
|
||||
optional RepeatedPrimitiveValue repeated_primitive_value = 6;
|
||||
}
|
||||
|
||||
message PrimitiveValue {
|
||||
optional double double_value = 1;
|
||||
optional float float_value = 2;
|
||||
optional int64 int64_value = 3;
|
||||
optional uint64 uint64_value = 4;
|
||||
optional int32 int32_value = 5;
|
||||
optional fixed64 fixed64_value = 6;
|
||||
optional fixed32 fixed32_value = 7;
|
||||
optional bool bool_value = 8;
|
||||
optional string string_value = 9;
|
||||
optional bytes bytes_value = 12;
|
||||
optional uint32 uint32_value = 13;
|
||||
optional sfixed32 sfixed32_value = 15;
|
||||
optional sfixed64 sfixed64_value = 16;
|
||||
optional sint32 sint32_value = 17;
|
||||
optional sint64 sint64_value = 18;
|
||||
}
|
||||
|
||||
// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
|
||||
message RepeatedPrimitiveValue {
|
||||
repeated double double_value = 1;
|
||||
repeated float float_value = 2;
|
||||
repeated int64 int64_value = 3;
|
||||
repeated uint64 uint64_value = 4;
|
||||
repeated int32 int32_value = 5;
|
||||
repeated fixed64 fixed64_value = 6;
|
||||
repeated fixed32 fixed32_value = 7;
|
||||
repeated bool bool_value = 8;
|
||||
repeated string string_value = 9;
|
||||
repeated bytes bytes_value = 12;
|
||||
repeated uint32 uint32_value = 13;
|
||||
repeated sfixed32 sfixed32_value = 15;
|
||||
repeated sfixed64 sfixed64_value = 16;
|
||||
repeated sint32 sint32_value = 17;
|
||||
repeated sint64 sint64_value = 18;
|
||||
repeated PrimitiveValue message_value = 19;
|
||||
}
|
||||
|
||||
// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
|
||||
// in the text format, but the binary serializion is different.
|
||||
// We test the packed representations by loading the same test cases
|
||||
// using this definition instead of RepeatedPrimitiveValue.
|
||||
// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
|
||||
// in every way except the packed=true declaration.
|
||||
message PackedPrimitiveValue {
|
||||
repeated double double_value = 1 [packed = true];
|
||||
repeated float float_value = 2 [packed = true];
|
||||
repeated int64 int64_value = 3 [packed = true];
|
||||
repeated uint64 uint64_value = 4 [packed = true];
|
||||
repeated int32 int32_value = 5 [packed = true];
|
||||
repeated fixed64 fixed64_value = 6 [packed = true];
|
||||
repeated fixed32 fixed32_value = 7 [packed = true];
|
||||
repeated bool bool_value = 8 [packed = true];
|
||||
repeated string string_value = 9;
|
||||
repeated bytes bytes_value = 12;
|
||||
repeated uint32 uint32_value = 13 [packed = true];
|
||||
repeated sfixed32 sfixed32_value = 15 [packed = true];
|
||||
repeated sfixed64 sfixed64_value = 16 [packed = true];
|
||||
repeated sint32 sint32_value = 17 [packed = true];
|
||||
repeated sint64 sint64_value = 18 [packed = true];
|
||||
repeated PrimitiveValue message_value = 19;
|
||||
}
|
||||
|
||||
message EnumValue {
|
||||
enum Color {
|
||||
RED = 0;
|
||||
ORANGE = 1;
|
||||
YELLOW = 2;
|
||||
GREEN = 3;
|
||||
BLUE = 4;
|
||||
INDIGO = 5;
|
||||
VIOLET = 6;
|
||||
};
|
||||
optional Color enum_value = 14;
|
||||
repeated Color repeated_enum_value = 15;
|
||||
}
|
||||
|
||||
|
||||
message InnerMessageValue {
|
||||
optional float float_value = 2;
|
||||
repeated bytes bytes_values = 8;
|
||||
}
|
||||
|
||||
message MiddleMessageValue {
|
||||
repeated int32 int32_values = 5;
|
||||
optional InnerMessageValue message_value = 11;
|
||||
optional uint32 uint32_value = 13;
|
||||
}
|
||||
|
||||
message MessageValue {
|
||||
optional double double_value = 1;
|
||||
optional MiddleMessageValue message_value = 11;
|
||||
}
|
||||
|
||||
message RepeatedMessageValue {
|
||||
message NestedMessageValue {
|
||||
optional float float_value = 2;
|
||||
repeated bytes bytes_values = 8;
|
||||
}
|
||||
|
||||
repeated NestedMessageValue message_values = 11;
|
||||
}
|
||||
|
||||
// Message containing fields with field numbers higher than any field above. An
|
||||
// instance of this message is prepended to each binary message in the test to
|
||||
// exercise the code path that handles fields encoded out of order of field
|
||||
// number.
|
||||
message ExtraFields {
|
||||
optional string string_value = 1776;
|
||||
optional bool bool_value = 1777;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user