Make tf.contrib.proto.* TF2-friendly.
This included fixing a bug where shape inference caught an incorrect shape, but since eager mode doesn't run shape inference the core code caused a segfault. PiperOrigin-RevId: 237316781
This commit is contained in:
parent
17baa62ff4
commit
add7a1a911
@ -296,14 +296,13 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
|
|||||||
field_names = ['sizes']
|
field_names = ['sizes']
|
||||||
field_types = [dtypes.int32]
|
field_types = [dtypes.int32]
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.assertRaisesRegexp(
|
||||||
ctensor, vtensor = self._decode_module.decode_proto(
|
errors.DataLossError, 'Unable to parse binary protobuf'
|
||||||
|
'|Failed to consume entire buffer'):
|
||||||
|
self.evaluate(
|
||||||
|
self._decode_module.decode_proto(
|
||||||
batch,
|
batch,
|
||||||
message_type=msg_type,
|
message_type=msg_type,
|
||||||
field_names=field_names,
|
field_names=field_names,
|
||||||
output_types=field_types,
|
output_types=field_types,
|
||||||
sanitize=sanitize)
|
sanitize=sanitize))
|
||||||
with self.assertRaisesRegexp(errors.DataLossError,
|
|
||||||
'Unable to parse binary protobuf'
|
|
||||||
'|Failed to consume entire buffer'):
|
|
||||||
_ = sess.run([ctensor] + vtensor)
|
|
||||||
|
@ -30,7 +30,9 @@ from google.protobuf import text_format
|
|||||||
|
|
||||||
from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
|
from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base
|
||||||
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
|
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
|
||||||
|
|
||||||
@ -50,30 +52,53 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
|
|||||||
self._decode_module = decode_module
|
self._decode_module = decode_module
|
||||||
self._encode_module = encode_module
|
self._encode_module = encode_module
|
||||||
|
|
||||||
|
def testBadSizesShape(self):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
expected_error = (errors.InvalidArgumentError,
|
||||||
|
r'Invalid shape for field double_value.')
|
||||||
|
else:
|
||||||
|
expected_error = (ValueError,
|
||||||
|
r'Shape must be at least rank 2 but is rank 0')
|
||||||
|
with self.assertRaisesRegexp(*expected_error):
|
||||||
|
self.evaluate(
|
||||||
|
self._encode_module.encode_proto(
|
||||||
|
sizes=1,
|
||||||
|
values=[np.double(1.0)],
|
||||||
|
message_type='tensorflow.contrib.proto.TestValue',
|
||||||
|
field_names=['double_value']))
|
||||||
|
|
||||||
def testBadInputs(self):
|
def testBadInputs(self):
|
||||||
# Invalid field name
|
# Invalid field name
|
||||||
with self.cached_session():
|
|
||||||
with self.assertRaisesOpError('Unknown field: non_existent_field'):
|
with self.assertRaisesOpError('Unknown field: non_existent_field'):
|
||||||
|
self.evaluate(
|
||||||
self._encode_module.encode_proto(
|
self._encode_module.encode_proto(
|
||||||
sizes=[[1]],
|
sizes=[[1]],
|
||||||
values=[np.array([[0.0]], dtype=np.int32)],
|
values=[np.array([[0.0]], dtype=np.int32)],
|
||||||
message_type='tensorflow.contrib.proto.TestValue',
|
message_type='tensorflow.contrib.proto.TestValue',
|
||||||
field_names=['non_existent_field']).eval()
|
field_names=['non_existent_field']))
|
||||||
|
|
||||||
# Incorrect types.
|
# Incorrect types.
|
||||||
with self.cached_session():
|
with self.assertRaisesOpError('Incompatible type for field double_value.'):
|
||||||
with self.assertRaisesOpError(
|
self.evaluate(
|
||||||
'Incompatible type for field double_value.'):
|
|
||||||
self._encode_module.encode_proto(
|
self._encode_module.encode_proto(
|
||||||
sizes=[[1]],
|
sizes=[[1]],
|
||||||
values=[np.array([[0.0]], dtype=np.int32)],
|
values=[np.array([[0.0]], dtype=np.int32)],
|
||||||
message_type='tensorflow.contrib.proto.TestValue',
|
message_type='tensorflow.contrib.proto.TestValue',
|
||||||
field_names=['double_value']).eval()
|
field_names=['double_value']))
|
||||||
|
|
||||||
# Incorrect shapes of sizes.
|
# Incorrect shapes of sizes.
|
||||||
with self.cached_session():
|
for sizes_value in 1, np.array([[[0, 0]]]):
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
r'sizes should be batch_size \+ \[len\(field_names\)\]'):
|
r'sizes should be batch_size \+ \[len\(field_names\)\]'):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
self.evaluate(
|
||||||
|
self._encode_module.encode_proto(
|
||||||
|
sizes=sizes_value,
|
||||||
|
values=[np.array([[0.0]])],
|
||||||
|
message_type='tensorflow.contrib.proto.TestValue',
|
||||||
|
field_names=['double_value']))
|
||||||
|
else:
|
||||||
|
with self.cached_session():
|
||||||
sizes = array_ops.placeholder(dtypes.int32)
|
sizes = array_ops.placeholder(dtypes.int32)
|
||||||
values = array_ops.placeholder(dtypes.float64)
|
values = array_ops.placeholder(dtypes.float64)
|
||||||
self._encode_module.encode_proto(
|
self._encode_module.encode_proto(
|
||||||
@ -81,15 +106,22 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
|
|||||||
values=[values],
|
values=[values],
|
||||||
message_type='tensorflow.contrib.proto.TestValue',
|
message_type='tensorflow.contrib.proto.TestValue',
|
||||||
field_names=['double_value']).eval(feed_dict={
|
field_names=['double_value']).eval(feed_dict={
|
||||||
sizes: [[[0, 0]]],
|
sizes: sizes_value,
|
||||||
values: [[0.0]]
|
values: [[0.0]]
|
||||||
})
|
})
|
||||||
|
|
||||||
# Inconsistent shapes of values.
|
# Inconsistent shapes of values.
|
||||||
|
with self.assertRaisesOpError('Values must match up to the last dimension'):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
self.evaluate(
|
||||||
|
self._encode_module.encode_proto(
|
||||||
|
sizes=[[1, 1]],
|
||||||
|
values=[np.array([[0.0]]),
|
||||||
|
np.array([[0], [0]])],
|
||||||
|
message_type='tensorflow.contrib.proto.TestValue',
|
||||||
|
field_names=['double_value', 'int32_value']))
|
||||||
|
else:
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
with self.assertRaisesOpError(
|
|
||||||
'Values must match up to the last dimension'):
|
|
||||||
sizes = array_ops.placeholder(dtypes.int32)
|
|
||||||
values1 = array_ops.placeholder(dtypes.float64)
|
values1 = array_ops.placeholder(dtypes.float64)
|
||||||
values2 = array_ops.placeholder(dtypes.int32)
|
values2 = array_ops.placeholder(dtypes.int32)
|
||||||
(self._encode_module.encode_proto(
|
(self._encode_module.encode_proto(
|
||||||
|
@ -525,11 +525,16 @@ class EncodeProtoOp : public OpKernel {
|
|||||||
ctx,
|
ctx,
|
||||||
proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
|
proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Incompatible type for field " + field_names_[i] +
|
"Incompatible type for field ", field_names_[i],
|
||||||
". Saw dtype: ",
|
". Saw dtype: ", DataTypeString(v.dtype()),
|
||||||
DataTypeString(v.dtype()),
|
|
||||||
" but field type is: ", field_descs_[i]->type_name()));
|
" but field type is: ", field_descs_[i]->type_name()));
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()),
|
||||||
|
errors::InvalidArgument("Invalid shape for field ", field_names_[i],
|
||||||
|
". Saw shape ", v.shape().DebugString(),
|
||||||
|
" but it should be at least a matrix."));
|
||||||
|
|
||||||
// All value tensors must have the same shape prefix (i.e. batch size).
|
// All value tensors must have the same shape prefix (i.e. batch size).
|
||||||
TensorShape shape_prefix = v.shape();
|
TensorShape shape_prefix = v.shape();
|
||||||
shape_prefix.RemoveDim(shape_prefix.dims() - 1);
|
shape_prefix.RemoveDim(shape_prefix.dims() - 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user