decode_proto: Fixed decoding of default values for enum fields.
PiperOrigin-RevId: 271043492
This commit is contained in:
parent
ead8115fe2
commit
f776c50ca3
@ -163,7 +163,6 @@ Status InitDefaultValueFromFieldDescriptor(DataType dtype,
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
return InitDefaultValue(dtype, field_desc->default_value_uint64(),
|
||||
result);
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
case WireFormatLite::TYPE_SFIXED32:
|
||||
@ -174,6 +173,9 @@ Status InitDefaultValueFromFieldDescriptor(DataType dtype,
|
||||
result);
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
return InitDefaultValue(dtype, field_desc->default_value_enum()->number(),
|
||||
result);
|
||||
case WireFormatLite::TYPE_BYTES:
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
// Manipulating default string values as C-style pointers should be OK
|
||||
|
@ -134,7 +134,10 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
|
||||
dtypes.uint64:
|
||||
'uint64_value',
|
||||
}
|
||||
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
|
||||
if field.name in ['enum_value', 'enum_value_with_default']:
|
||||
tf_field_name = 'enum_value'
|
||||
else:
|
||||
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
|
||||
if tf_field_name is None:
|
||||
self.fail('Unhandled tensorflow type %d' % field.dtype)
|
||||
|
||||
|
@ -131,6 +131,11 @@ class ProtoOpTestBase(test.TestCase):
|
||||
field.name = "bytes_value_with_default"
|
||||
field.dtype = types_pb2.DT_STRING
|
||||
field.value.string_value.append("a longer default string")
|
||||
test_case.sizes.append(0)
|
||||
field = test_case.fields.add()
|
||||
field.name = "enum_value_with_default"
|
||||
field.dtype = types_pb2.DT_INT32
|
||||
field.value.enum_value.append(test_example_pb2.Color.GREEN)
|
||||
return test_case
|
||||
|
||||
@staticmethod
|
||||
@ -421,6 +426,7 @@ class ProtoOpTestBase(test.TestCase):
|
||||
value = test_case.values.add()
|
||||
value.double_value.append(23.5)
|
||||
value.bool_value.append(True)
|
||||
value.enum_value.append(test_example_pb2.Color.INDIGO)
|
||||
test_case.shapes.append(1)
|
||||
test_case.sizes.append(1)
|
||||
field = test_case.fields.add()
|
||||
@ -432,4 +438,9 @@ class ProtoOpTestBase(test.TestCase):
|
||||
field.name = "bool_value"
|
||||
field.dtype = types_pb2.DT_BOOL
|
||||
field.value.bool_value.append(True)
|
||||
test_case.sizes.append(1)
|
||||
field = test_case.fields.add()
|
||||
field.name = "enum_value"
|
||||
field.dtype = types_pb2.DT_INT32
|
||||
field.value.enum_value.append(test_example_pb2.Color.INDIGO)
|
||||
return test_case
|
||||
|
@ -2,10 +2,10 @@
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
package tensorflow.contrib.proto;
|
||||
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
// A TestCase holds a proto and assertions about how it should decode.
|
||||
message TestCase {
|
||||
// Batches of primitive values.
|
||||
@ -16,14 +16,24 @@ message TestCase {
|
||||
repeated int32 sizes = 3;
|
||||
// Expected values for each field.
|
||||
repeated FieldSpec fields = 4;
|
||||
};
|
||||
}
|
||||
|
||||
// FieldSpec describes the expected output for a single field.
|
||||
message FieldSpec {
|
||||
optional string name = 1;
|
||||
optional tensorflow.DataType dtype = 2;
|
||||
optional TestValue value = 3;
|
||||
};
|
||||
}
|
||||
|
||||
enum Color {
|
||||
RED = 0;
|
||||
ORANGE = 1;
|
||||
YELLOW = 2;
|
||||
GREEN = 3;
|
||||
BLUE = 4;
|
||||
INDIGO = 5;
|
||||
VIOLET = 6;
|
||||
}
|
||||
|
||||
// NOTE: This definition must be kept in sync with PackedTestValue.
|
||||
message TestValue {
|
||||
@ -43,6 +53,7 @@ message TestValue {
|
||||
repeated sint32 sint32_value = 17;
|
||||
repeated sint64 sint64_value = 18;
|
||||
repeated PrimitiveValue message_value = 19;
|
||||
repeated Color enum_value = 35;
|
||||
|
||||
// Optional fields with explicitly-specified defaults.
|
||||
optional double double_value_with_default = 20 [default = 1.0];
|
||||
@ -61,6 +72,7 @@ message TestValue {
|
||||
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
|
||||
optional sint32 sint32_value_with_default = 33 [default = 12];
|
||||
optional sint64 sint64_value_with_default = 34 [default = 13];
|
||||
optional Color enum_value_with_default = 36 [default = GREEN];
|
||||
|
||||
extensions 100 to 199;
|
||||
}
|
||||
@ -88,6 +100,7 @@ message PackedTestValue {
|
||||
repeated sint32 sint32_value = 17 [packed = true];
|
||||
repeated sint64 sint64_value = 18 [packed = true];
|
||||
repeated PrimitiveValue message_value = 19;
|
||||
repeated Color enum_value = 35;
|
||||
|
||||
optional double double_value_with_default = 20 [default = 1.0];
|
||||
optional float float_value_with_default = 21 [default = 2.0];
|
||||
@ -105,6 +118,7 @@ message PackedTestValue {
|
||||
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
|
||||
optional sint32 sint32_value_with_default = 33 [default = 12];
|
||||
optional sint64 sint64_value_with_default = 34 [default = 13];
|
||||
optional Color enum_value_with_default = 36 [default = GREEN];
|
||||
}
|
||||
|
||||
message PrimitiveValue {
|
||||
@ -140,21 +154,6 @@ extend TestValue {
|
||||
|
||||
// The messages below are for yet-to-be created tests.
|
||||
|
||||
message EnumValue {
|
||||
enum Color {
|
||||
RED = 0;
|
||||
ORANGE = 1;
|
||||
YELLOW = 2;
|
||||
GREEN = 3;
|
||||
BLUE = 4;
|
||||
INDIGO = 5;
|
||||
VIOLET = 6;
|
||||
};
|
||||
optional Color enum_value = 14;
|
||||
repeated Color repeated_enum_value = 15;
|
||||
}
|
||||
|
||||
|
||||
message InnerMessageValue {
|
||||
optional float float_value = 2;
|
||||
repeated bytes bytes_values = 8;
|
||||
|
Loading…
Reference in New Issue
Block a user