decode_proto: Fixed decoding of default values for enum fields.
PiperOrigin-RevId: 271043492
This commit is contained in:
parent
ead8115fe2
commit
f776c50ca3
@ -163,7 +163,6 @@ Status InitDefaultValueFromFieldDescriptor(DataType dtype,
|
|||||||
case WireFormatLite::TYPE_UINT64:
|
case WireFormatLite::TYPE_UINT64:
|
||||||
return InitDefaultValue(dtype, field_desc->default_value_uint64(),
|
return InitDefaultValue(dtype, field_desc->default_value_uint64(),
|
||||||
result);
|
result);
|
||||||
case WireFormatLite::TYPE_ENUM:
|
|
||||||
case WireFormatLite::TYPE_INT32:
|
case WireFormatLite::TYPE_INT32:
|
||||||
case WireFormatLite::TYPE_SINT32:
|
case WireFormatLite::TYPE_SINT32:
|
||||||
case WireFormatLite::TYPE_SFIXED32:
|
case WireFormatLite::TYPE_SFIXED32:
|
||||||
@ -174,6 +173,9 @@ Status InitDefaultValueFromFieldDescriptor(DataType dtype,
|
|||||||
result);
|
result);
|
||||||
case WireFormatLite::TYPE_BOOL:
|
case WireFormatLite::TYPE_BOOL:
|
||||||
return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
|
return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
|
||||||
|
case WireFormatLite::TYPE_ENUM:
|
||||||
|
return InitDefaultValue(dtype, field_desc->default_value_enum()->number(),
|
||||||
|
result);
|
||||||
case WireFormatLite::TYPE_BYTES:
|
case WireFormatLite::TYPE_BYTES:
|
||||||
case WireFormatLite::TYPE_STRING:
|
case WireFormatLite::TYPE_STRING:
|
||||||
// Manipulating default string values as C-style pointers should be OK
|
// Manipulating default string values as C-style pointers should be OK
|
||||||
|
@ -134,7 +134,10 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
|
|||||||
dtypes.uint64:
|
dtypes.uint64:
|
||||||
'uint64_value',
|
'uint64_value',
|
||||||
}
|
}
|
||||||
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
|
if field.name in ['enum_value', 'enum_value_with_default']:
|
||||||
|
tf_field_name = 'enum_value'
|
||||||
|
else:
|
||||||
|
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
|
||||||
if tf_field_name is None:
|
if tf_field_name is None:
|
||||||
self.fail('Unhandled tensorflow type %d' % field.dtype)
|
self.fail('Unhandled tensorflow type %d' % field.dtype)
|
||||||
|
|
||||||
|
@ -131,6 +131,11 @@ class ProtoOpTestBase(test.TestCase):
|
|||||||
field.name = "bytes_value_with_default"
|
field.name = "bytes_value_with_default"
|
||||||
field.dtype = types_pb2.DT_STRING
|
field.dtype = types_pb2.DT_STRING
|
||||||
field.value.string_value.append("a longer default string")
|
field.value.string_value.append("a longer default string")
|
||||||
|
test_case.sizes.append(0)
|
||||||
|
field = test_case.fields.add()
|
||||||
|
field.name = "enum_value_with_default"
|
||||||
|
field.dtype = types_pb2.DT_INT32
|
||||||
|
field.value.enum_value.append(test_example_pb2.Color.GREEN)
|
||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -421,6 +426,7 @@ class ProtoOpTestBase(test.TestCase):
|
|||||||
value = test_case.values.add()
|
value = test_case.values.add()
|
||||||
value.double_value.append(23.5)
|
value.double_value.append(23.5)
|
||||||
value.bool_value.append(True)
|
value.bool_value.append(True)
|
||||||
|
value.enum_value.append(test_example_pb2.Color.INDIGO)
|
||||||
test_case.shapes.append(1)
|
test_case.shapes.append(1)
|
||||||
test_case.sizes.append(1)
|
test_case.sizes.append(1)
|
||||||
field = test_case.fields.add()
|
field = test_case.fields.add()
|
||||||
@ -432,4 +438,9 @@ class ProtoOpTestBase(test.TestCase):
|
|||||||
field.name = "bool_value"
|
field.name = "bool_value"
|
||||||
field.dtype = types_pb2.DT_BOOL
|
field.dtype = types_pb2.DT_BOOL
|
||||||
field.value.bool_value.append(True)
|
field.value.bool_value.append(True)
|
||||||
|
test_case.sizes.append(1)
|
||||||
|
field = test_case.fields.add()
|
||||||
|
field.name = "enum_value"
|
||||||
|
field.dtype = types_pb2.DT_INT32
|
||||||
|
field.value.enum_value.append(test_example_pb2.Color.INDIGO)
|
||||||
return test_case
|
return test_case
|
||||||
|
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
import "tensorflow/core/framework/types.proto";
|
|
||||||
|
|
||||||
package tensorflow.contrib.proto;
|
package tensorflow.contrib.proto;
|
||||||
|
|
||||||
|
import "tensorflow/core/framework/types.proto";
|
||||||
|
|
||||||
// A TestCase holds a proto and assertions about how it should decode.
|
// A TestCase holds a proto and assertions about how it should decode.
|
||||||
message TestCase {
|
message TestCase {
|
||||||
// Batches of primitive values.
|
// Batches of primitive values.
|
||||||
@ -16,14 +16,24 @@ message TestCase {
|
|||||||
repeated int32 sizes = 3;
|
repeated int32 sizes = 3;
|
||||||
// Expected values for each field.
|
// Expected values for each field.
|
||||||
repeated FieldSpec fields = 4;
|
repeated FieldSpec fields = 4;
|
||||||
};
|
}
|
||||||
|
|
||||||
// FieldSpec describes the expected output for a single field.
|
// FieldSpec describes the expected output for a single field.
|
||||||
message FieldSpec {
|
message FieldSpec {
|
||||||
optional string name = 1;
|
optional string name = 1;
|
||||||
optional tensorflow.DataType dtype = 2;
|
optional tensorflow.DataType dtype = 2;
|
||||||
optional TestValue value = 3;
|
optional TestValue value = 3;
|
||||||
};
|
}
|
||||||
|
|
||||||
|
enum Color {
|
||||||
|
RED = 0;
|
||||||
|
ORANGE = 1;
|
||||||
|
YELLOW = 2;
|
||||||
|
GREEN = 3;
|
||||||
|
BLUE = 4;
|
||||||
|
INDIGO = 5;
|
||||||
|
VIOLET = 6;
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: This definition must be kept in sync with PackedTestValue.
|
// NOTE: This definition must be kept in sync with PackedTestValue.
|
||||||
message TestValue {
|
message TestValue {
|
||||||
@ -43,6 +53,7 @@ message TestValue {
|
|||||||
repeated sint32 sint32_value = 17;
|
repeated sint32 sint32_value = 17;
|
||||||
repeated sint64 sint64_value = 18;
|
repeated sint64 sint64_value = 18;
|
||||||
repeated PrimitiveValue message_value = 19;
|
repeated PrimitiveValue message_value = 19;
|
||||||
|
repeated Color enum_value = 35;
|
||||||
|
|
||||||
// Optional fields with explicitly-specified defaults.
|
// Optional fields with explicitly-specified defaults.
|
||||||
optional double double_value_with_default = 20 [default = 1.0];
|
optional double double_value_with_default = 20 [default = 1.0];
|
||||||
@ -61,6 +72,7 @@ message TestValue {
|
|||||||
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
|
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
|
||||||
optional sint32 sint32_value_with_default = 33 [default = 12];
|
optional sint32 sint32_value_with_default = 33 [default = 12];
|
||||||
optional sint64 sint64_value_with_default = 34 [default = 13];
|
optional sint64 sint64_value_with_default = 34 [default = 13];
|
||||||
|
optional Color enum_value_with_default = 36 [default = GREEN];
|
||||||
|
|
||||||
extensions 100 to 199;
|
extensions 100 to 199;
|
||||||
}
|
}
|
||||||
@ -88,6 +100,7 @@ message PackedTestValue {
|
|||||||
repeated sint32 sint32_value = 17 [packed = true];
|
repeated sint32 sint32_value = 17 [packed = true];
|
||||||
repeated sint64 sint64_value = 18 [packed = true];
|
repeated sint64 sint64_value = 18 [packed = true];
|
||||||
repeated PrimitiveValue message_value = 19;
|
repeated PrimitiveValue message_value = 19;
|
||||||
|
repeated Color enum_value = 35;
|
||||||
|
|
||||||
optional double double_value_with_default = 20 [default = 1.0];
|
optional double double_value_with_default = 20 [default = 1.0];
|
||||||
optional float float_value_with_default = 21 [default = 2.0];
|
optional float float_value_with_default = 21 [default = 2.0];
|
||||||
@ -105,6 +118,7 @@ message PackedTestValue {
|
|||||||
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
|
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
|
||||||
optional sint32 sint32_value_with_default = 33 [default = 12];
|
optional sint32 sint32_value_with_default = 33 [default = 12];
|
||||||
optional sint64 sint64_value_with_default = 34 [default = 13];
|
optional sint64 sint64_value_with_default = 34 [default = 13];
|
||||||
|
optional Color enum_value_with_default = 36 [default = GREEN];
|
||||||
}
|
}
|
||||||
|
|
||||||
message PrimitiveValue {
|
message PrimitiveValue {
|
||||||
@ -140,21 +154,6 @@ extend TestValue {
|
|||||||
|
|
||||||
// The messages below are for yet-to-be created tests.
|
// The messages below are for yet-to-be created tests.
|
||||||
|
|
||||||
message EnumValue {
|
|
||||||
enum Color {
|
|
||||||
RED = 0;
|
|
||||||
ORANGE = 1;
|
|
||||||
YELLOW = 2;
|
|
||||||
GREEN = 3;
|
|
||||||
BLUE = 4;
|
|
||||||
INDIGO = 5;
|
|
||||||
VIOLET = 6;
|
|
||||||
};
|
|
||||||
optional Color enum_value = 14;
|
|
||||||
repeated Color repeated_enum_value = 15;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
message InnerMessageValue {
|
message InnerMessageValue {
|
||||||
optional float float_value = 2;
|
optional float float_value = 2;
|
||||||
repeated bytes bytes_values = 8;
|
repeated bytes bytes_values = 8;
|
||||||
|
Loading…
Reference in New Issue
Block a user