remove v1 only decorator
PiperOrigin-RevId: 323821989 Change-Id: Id0cb6221990644f70170891a6a9dbaa6b45e64a8
This commit is contained in:
parent
50caa9d728
commit
6951872858
@ -22,6 +22,7 @@ import pickle
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -51,41 +52,41 @@ class TensorSpecTest(test_util.TensorFlowTestCase):
|
||||
desc = tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
|
||||
self.assertEqual(desc.shape, tensor_shape.TensorShape(None))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapeCompatibility(self):
|
||||
unknown = array_ops.placeholder(dtypes.int64)
|
||||
partial = array_ops.placeholder(dtypes.int64, shape=[None, 1])
|
||||
full = array_ops.placeholder(dtypes.int64, shape=[2, 3])
|
||||
rank3 = array_ops.placeholder(dtypes.int64, shape=[4, 5, 6])
|
||||
# This test needs a placeholder which means we need to construct a graph.
|
||||
with ops.Graph().as_default():
|
||||
unknown = array_ops.placeholder(dtypes.int64)
|
||||
partial = array_ops.placeholder(dtypes.int64, shape=[None, 1])
|
||||
full = array_ops.placeholder(dtypes.int64, shape=[2, 3])
|
||||
rank3 = array_ops.placeholder(dtypes.int64, shape=[4, 5, 6])
|
||||
|
||||
desc_unknown = tensor_spec.TensorSpec(None, dtypes.int64)
|
||||
self.assertTrue(desc_unknown.is_compatible_with(unknown))
|
||||
self.assertTrue(desc_unknown.is_compatible_with(partial))
|
||||
self.assertTrue(desc_unknown.is_compatible_with(full))
|
||||
self.assertTrue(desc_unknown.is_compatible_with(rank3))
|
||||
desc_unknown = tensor_spec.TensorSpec(None, dtypes.int64)
|
||||
self.assertTrue(desc_unknown.is_compatible_with(unknown))
|
||||
self.assertTrue(desc_unknown.is_compatible_with(partial))
|
||||
self.assertTrue(desc_unknown.is_compatible_with(full))
|
||||
self.assertTrue(desc_unknown.is_compatible_with(rank3))
|
||||
|
||||
desc_partial = tensor_spec.TensorSpec([2, None], dtypes.int64)
|
||||
self.assertTrue(desc_partial.is_compatible_with(unknown))
|
||||
self.assertTrue(desc_partial.is_compatible_with(partial))
|
||||
self.assertTrue(desc_partial.is_compatible_with(full))
|
||||
self.assertFalse(desc_partial.is_compatible_with(rank3))
|
||||
desc_partial = tensor_spec.TensorSpec([2, None], dtypes.int64)
|
||||
self.assertTrue(desc_partial.is_compatible_with(unknown))
|
||||
self.assertTrue(desc_partial.is_compatible_with(partial))
|
||||
self.assertTrue(desc_partial.is_compatible_with(full))
|
||||
self.assertFalse(desc_partial.is_compatible_with(rank3))
|
||||
|
||||
desc_full = tensor_spec.TensorSpec([2, 3], dtypes.int64)
|
||||
self.assertTrue(desc_full.is_compatible_with(unknown))
|
||||
self.assertFalse(desc_full.is_compatible_with(partial))
|
||||
self.assertTrue(desc_full.is_compatible_with(full))
|
||||
self.assertFalse(desc_full.is_compatible_with(rank3))
|
||||
desc_full = tensor_spec.TensorSpec([2, 3], dtypes.int64)
|
||||
self.assertTrue(desc_full.is_compatible_with(unknown))
|
||||
self.assertFalse(desc_full.is_compatible_with(partial))
|
||||
self.assertTrue(desc_full.is_compatible_with(full))
|
||||
self.assertFalse(desc_full.is_compatible_with(rank3))
|
||||
|
||||
desc_rank3 = tensor_spec.TensorSpec([4, 5, 6], dtypes.int64)
|
||||
self.assertTrue(desc_rank3.is_compatible_with(unknown))
|
||||
self.assertFalse(desc_rank3.is_compatible_with(partial))
|
||||
self.assertFalse(desc_rank3.is_compatible_with(full))
|
||||
self.assertTrue(desc_rank3.is_compatible_with(rank3))
|
||||
desc_rank3 = tensor_spec.TensorSpec([4, 5, 6], dtypes.int64)
|
||||
self.assertTrue(desc_rank3.is_compatible_with(unknown))
|
||||
self.assertFalse(desc_rank3.is_compatible_with(partial))
|
||||
self.assertFalse(desc_rank3.is_compatible_with(full))
|
||||
self.assertTrue(desc_rank3.is_compatible_with(rank3))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTypeCompatibility(self):
|
||||
floats = array_ops.placeholder(dtypes.float32, shape=[10, 10])
|
||||
ints = array_ops.placeholder(dtypes.int32, shape=[10, 10])
|
||||
floats = constant_op.constant(1, dtype=dtypes.float32, shape=[10, 10])
|
||||
ints = constant_op.constant(1, dtype=dtypes.int32, shape=[10, 10])
|
||||
desc = tensor_spec.TensorSpec(shape=(10, 10), dtype=dtypes.float32)
|
||||
self.assertTrue(desc.is_compatible_with(floats))
|
||||
self.assertFalse(desc.is_compatible_with(ints))
|
||||
@ -118,28 +119,31 @@ class TensorSpecTest(test_util.TensorFlowTestCase):
|
||||
spec_2 = tensor_spec.TensorSpec.from_spec(spec_1)
|
||||
self.assertEqual(spec_1, spec_2)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromTensor(self):
|
||||
zero = constant_op.constant(0)
|
||||
spec = tensor_spec.TensorSpec.from_tensor(zero)
|
||||
self.assertEqual(spec.dtype, dtypes.int32)
|
||||
self.assertEqual(spec.shape, [])
|
||||
self.assertEqual(spec.name, "Const")
|
||||
# Tensor.name is meaningless when eager execution is enabled.
|
||||
if not context.executing_eagerly():
|
||||
self.assertEqual(spec.name, "Const")
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromPlaceholder(self):
|
||||
unknown = array_ops.placeholder(dtypes.int64, name="unknown")
|
||||
partial = array_ops.placeholder(dtypes.float32,
|
||||
shape=[None, 1],
|
||||
name="partial")
|
||||
spec_1 = tensor_spec.TensorSpec.from_tensor(unknown)
|
||||
self.assertEqual(spec_1.dtype, dtypes.int64)
|
||||
self.assertEqual(spec_1.shape, None)
|
||||
self.assertEqual(spec_1.name, "unknown")
|
||||
spec_2 = tensor_spec.TensorSpec.from_tensor(partial)
|
||||
self.assertEqual(spec_2.dtype, dtypes.float32)
|
||||
self.assertEqual(spec_2.shape.as_list(), [None, 1])
|
||||
self.assertEqual(spec_2.name, "partial")
|
||||
# This test needs a placeholder which means we need to construct a graph.
|
||||
with ops.Graph().as_default():
|
||||
unknown = array_ops.placeholder(dtypes.int64, name="unknown")
|
||||
partial = array_ops.placeholder(dtypes.float32,
|
||||
shape=[None, 1],
|
||||
name="partial")
|
||||
|
||||
spec_1 = tensor_spec.TensorSpec.from_tensor(unknown)
|
||||
self.assertEqual(spec_1.dtype, dtypes.int64)
|
||||
self.assertEqual(spec_1.shape, None)
|
||||
self.assertEqual(spec_1.name, "unknown")
|
||||
spec_2 = tensor_spec.TensorSpec.from_tensor(partial)
|
||||
self.assertEqual(spec_2.dtype, dtypes.float32)
|
||||
self.assertEqual(spec_2.shape.as_list(), [None, 1])
|
||||
self.assertEqual(spec_2.name, "partial")
|
||||
|
||||
def testFromBoundedTensorSpec(self):
|
||||
bounded_spec = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, 0, 1)
|
||||
|
Loading…
Reference in New Issue
Block a user