decode_proto: Fixed decoding of default values for enum fields.

PiperOrigin-RevId: 271043492
This commit is contained in:
A. Unique TensorFlower 2019-09-24 20:45:56 -07:00 committed by TensorFlower Gardener
parent ead8115fe2
commit f776c50ca3
4 changed files with 36 additions and 21 deletions

View File

@ -163,7 +163,6 @@ Status InitDefaultValueFromFieldDescriptor(DataType dtype,
case WireFormatLite::TYPE_UINT64:
return InitDefaultValue(dtype, field_desc->default_value_uint64(),
result);
case WireFormatLite::TYPE_ENUM:
case WireFormatLite::TYPE_INT32:
case WireFormatLite::TYPE_SINT32:
case WireFormatLite::TYPE_SFIXED32:
@ -174,6 +173,9 @@ Status InitDefaultValueFromFieldDescriptor(DataType dtype,
result);
case WireFormatLite::TYPE_BOOL:
return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
case WireFormatLite::TYPE_ENUM:
return InitDefaultValue(dtype, field_desc->default_value_enum()->number(),
result);
case WireFormatLite::TYPE_BYTES:
case WireFormatLite::TYPE_STRING:
// Manipulating default string values as C-style pointers should be OK

View File

@ -134,7 +134,10 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
dtypes.uint64:
'uint64_value',
}
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
if field.name in ['enum_value', 'enum_value_with_default']:
tf_field_name = 'enum_value'
else:
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
if tf_field_name is None:
self.fail('Unhandled tensorflow type %d' % field.dtype)

View File

@ -131,6 +131,11 @@ class ProtoOpTestBase(test.TestCase):
field.name = "bytes_value_with_default"
field.dtype = types_pb2.DT_STRING
field.value.string_value.append("a longer default string")
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "enum_value_with_default"
field.dtype = types_pb2.DT_INT32
field.value.enum_value.append(test_example_pb2.Color.GREEN)
return test_case
@staticmethod
@ -421,6 +426,7 @@ class ProtoOpTestBase(test.TestCase):
value = test_case.values.add()
value.double_value.append(23.5)
value.bool_value.append(True)
value.enum_value.append(test_example_pb2.Color.INDIGO)
test_case.shapes.append(1)
test_case.sizes.append(1)
field = test_case.fields.add()
@ -432,4 +438,9 @@ class ProtoOpTestBase(test.TestCase):
field.name = "bool_value"
field.dtype = types_pb2.DT_BOOL
field.value.bool_value.append(True)
test_case.sizes.append(1)
field = test_case.fields.add()
field.name = "enum_value"
field.dtype = types_pb2.DT_INT32
field.value.enum_value.append(test_example_pb2.Color.INDIGO)
return test_case

View File

@ -2,10 +2,10 @@
syntax = "proto2";
import "tensorflow/core/framework/types.proto";
package tensorflow.contrib.proto;
import "tensorflow/core/framework/types.proto";
// A TestCase holds a proto and assertions about how it should decode.
message TestCase {
// Batches of primitive values.
@ -16,14 +16,24 @@ message TestCase {
repeated int32 sizes = 3;
// Expected values for each field.
repeated FieldSpec fields = 4;
};
}
// FieldSpec describes the expected output for a single field.
message FieldSpec {
optional string name = 1;
optional tensorflow.DataType dtype = 2;
optional TestValue value = 3;
};
}
enum Color {
RED = 0;
ORANGE = 1;
YELLOW = 2;
GREEN = 3;
BLUE = 4;
INDIGO = 5;
VIOLET = 6;
}
// NOTE: This definition must be kept in sync with PackedTestValue.
message TestValue {
@ -43,6 +53,7 @@ message TestValue {
repeated sint32 sint32_value = 17;
repeated sint64 sint64_value = 18;
repeated PrimitiveValue message_value = 19;
repeated Color enum_value = 35;
// Optional fields with explicitly-specified defaults.
optional double double_value_with_default = 20 [default = 1.0];
@ -61,6 +72,7 @@ message TestValue {
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
optional sint32 sint32_value_with_default = 33 [default = 12];
optional sint64 sint64_value_with_default = 34 [default = 13];
optional Color enum_value_with_default = 36 [default = GREEN];
extensions 100 to 199;
}
@ -88,6 +100,7 @@ message PackedTestValue {
repeated sint32 sint32_value = 17 [packed = true];
repeated sint64 sint64_value = 18 [packed = true];
repeated PrimitiveValue message_value = 19;
repeated Color enum_value = 35;
optional double double_value_with_default = 20 [default = 1.0];
optional float float_value_with_default = 21 [default = 2.0];
@ -105,6 +118,7 @@ message PackedTestValue {
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
optional sint32 sint32_value_with_default = 33 [default = 12];
optional sint64 sint64_value_with_default = 34 [default = 13];
optional Color enum_value_with_default = 36 [default = GREEN];
}
message PrimitiveValue {
@ -140,21 +154,6 @@ extend TestValue {
// The messages below are for yet-to-be created tests.
message EnumValue {
enum Color {
RED = 0;
ORANGE = 1;
YELLOW = 2;
GREEN = 3;
BLUE = 4;
INDIGO = 5;
VIOLET = 6;
};
optional Color enum_value = 14;
repeated Color repeated_enum_value = 15;
}
message InnerMessageValue {
optional float float_value = 2;
repeated bytes bytes_values = 8;