Make tf.contrib.proto.* TF2-friendly.
This included fixing a bug where shape inference caught an incorrect shape, but since eager mode doesn't run shape inference the core code caused a segfault. PiperOrigin-RevId: 237316781
This commit is contained in:
parent
17baa62ff4
commit
add7a1a911
@ -296,14 +296,13 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
|
||||
field_names = ['sizes']
|
||||
field_types = [dtypes.int32]
|
||||
|
||||
with self.cached_session() as sess:
|
||||
ctensor, vtensor = self._decode_module.decode_proto(
|
||||
with self.assertRaisesRegexp(
|
||||
errors.DataLossError, 'Unable to parse binary protobuf'
|
||||
'|Failed to consume entire buffer'):
|
||||
self.evaluate(
|
||||
self._decode_module.decode_proto(
|
||||
batch,
|
||||
message_type=msg_type,
|
||||
field_names=field_names,
|
||||
output_types=field_types,
|
||||
sanitize=sanitize)
|
||||
with self.assertRaisesRegexp(errors.DataLossError,
|
||||
'Unable to parse binary protobuf'
|
||||
'|Failed to consume entire buffer'):
|
||||
_ = sess.run([ctensor] + vtensor)
|
||||
sanitize=sanitize))
|
||||
|
@ -30,7 +30,9 @@ from google.protobuf import text_format
|
||||
|
||||
from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
|
||||
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
|
||||
@ -50,30 +52,53 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
|
||||
self._decode_module = decode_module
|
||||
self._encode_module = encode_module
|
||||
|
||||
def testBadSizesShape(self):
|
||||
if context.executing_eagerly():
|
||||
expected_error = (errors.InvalidArgumentError,
|
||||
r'Invalid shape for field double_value.')
|
||||
else:
|
||||
expected_error = (ValueError,
|
||||
r'Shape must be at least rank 2 but is rank 0')
|
||||
with self.assertRaisesRegexp(*expected_error):
|
||||
self.evaluate(
|
||||
self._encode_module.encode_proto(
|
||||
sizes=1,
|
||||
values=[np.double(1.0)],
|
||||
message_type='tensorflow.contrib.proto.TestValue',
|
||||
field_names=['double_value']))
|
||||
|
||||
def testBadInputs(self):
|
||||
# Invalid field name
|
||||
with self.cached_session():
|
||||
with self.assertRaisesOpError('Unknown field: non_existent_field'):
|
||||
self.evaluate(
|
||||
self._encode_module.encode_proto(
|
||||
sizes=[[1]],
|
||||
values=[np.array([[0.0]], dtype=np.int32)],
|
||||
message_type='tensorflow.contrib.proto.TestValue',
|
||||
field_names=['non_existent_field']).eval()
|
||||
field_names=['non_existent_field']))
|
||||
|
||||
# Incorrect types.
|
||||
with self.cached_session():
|
||||
with self.assertRaisesOpError(
|
||||
'Incompatible type for field double_value.'):
|
||||
with self.assertRaisesOpError('Incompatible type for field double_value.'):
|
||||
self.evaluate(
|
||||
self._encode_module.encode_proto(
|
||||
sizes=[[1]],
|
||||
values=[np.array([[0.0]], dtype=np.int32)],
|
||||
message_type='tensorflow.contrib.proto.TestValue',
|
||||
field_names=['double_value']).eval()
|
||||
field_names=['double_value']))
|
||||
|
||||
# Incorrect shapes of sizes.
|
||||
with self.cached_session():
|
||||
for sizes_value in 1, np.array([[[0, 0]]]):
|
||||
with self.assertRaisesOpError(
|
||||
r'sizes should be batch_size \+ \[len\(field_names\)\]'):
|
||||
if context.executing_eagerly():
|
||||
self.evaluate(
|
||||
self._encode_module.encode_proto(
|
||||
sizes=sizes_value,
|
||||
values=[np.array([[0.0]])],
|
||||
message_type='tensorflow.contrib.proto.TestValue',
|
||||
field_names=['double_value']))
|
||||
else:
|
||||
with self.cached_session():
|
||||
sizes = array_ops.placeholder(dtypes.int32)
|
||||
values = array_ops.placeholder(dtypes.float64)
|
||||
self._encode_module.encode_proto(
|
||||
@ -81,15 +106,22 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
|
||||
values=[values],
|
||||
message_type='tensorflow.contrib.proto.TestValue',
|
||||
field_names=['double_value']).eval(feed_dict={
|
||||
sizes: [[[0, 0]]],
|
||||
sizes: sizes_value,
|
||||
values: [[0.0]]
|
||||
})
|
||||
|
||||
# Inconsistent shapes of values.
|
||||
with self.assertRaisesOpError('Values must match up to the last dimension'):
|
||||
if context.executing_eagerly():
|
||||
self.evaluate(
|
||||
self._encode_module.encode_proto(
|
||||
sizes=[[1, 1]],
|
||||
values=[np.array([[0.0]]),
|
||||
np.array([[0], [0]])],
|
||||
message_type='tensorflow.contrib.proto.TestValue',
|
||||
field_names=['double_value', 'int32_value']))
|
||||
else:
|
||||
with self.cached_session():
|
||||
with self.assertRaisesOpError(
|
||||
'Values must match up to the last dimension'):
|
||||
sizes = array_ops.placeholder(dtypes.int32)
|
||||
values1 = array_ops.placeholder(dtypes.float64)
|
||||
values2 = array_ops.placeholder(dtypes.int32)
|
||||
(self._encode_module.encode_proto(
|
||||
|
@ -525,11 +525,16 @@ class EncodeProtoOp : public OpKernel {
|
||||
ctx,
|
||||
proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
|
||||
errors::InvalidArgument(
|
||||
"Incompatible type for field " + field_names_[i] +
|
||||
". Saw dtype: ",
|
||||
DataTypeString(v.dtype()),
|
||||
"Incompatible type for field ", field_names_[i],
|
||||
". Saw dtype: ", DataTypeString(v.dtype()),
|
||||
" but field type is: ", field_descs_[i]->type_name()));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()),
|
||||
errors::InvalidArgument("Invalid shape for field ", field_names_[i],
|
||||
". Saw shape ", v.shape().DebugString(),
|
||||
" but it should be at least a matrix."));
|
||||
|
||||
// All value tensors must have the same shape prefix (i.e. batch size).
|
||||
TensorShape shape_prefix = v.shape();
|
||||
shape_prefix.RemoveDim(shape_prefix.dims() - 1);
|
||||
|
Loading…
Reference in New Issue
Block a user