remove v1 decorator
PiperOrigin-RevId: 324043175 Change-Id: I0bc404d55da56b77c5a54a8235bb5601ad5e70d9
This commit is contained in:
parent
cdbd96f307
commit
43154abec0
@ -27,6 +27,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import func_graph
|
from tensorflow.python.framework import func_graph
|
||||||
from tensorflow.python.framework import indexed_slices
|
from tensorflow.python.framework import indexed_slices
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
@ -816,15 +817,15 @@ class ConstantValueTest(test.TestCase):
|
|||||||
tf_val = constant_op.constant(np_val)
|
tf_val = constant_op.constant(np_val)
|
||||||
self.assertAllClose(np_val, tensor_util.constant_value(tf_val))
|
self.assertAllClose(np_val, tensor_util.constant_value(tf_val))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testUnknown(self):
|
def testUnknown(self):
|
||||||
tf_val = gen_state_ops.variable(
|
with ops.Graph().as_default():
|
||||||
shape=[3, 4, 7],
|
tf_val = gen_state_ops.variable(
|
||||||
dtype=dtypes.float32,
|
shape=[3, 4, 7],
|
||||||
name="tf_val",
|
dtype=dtypes.float32,
|
||||||
container="",
|
name="tf_val",
|
||||||
shared_name="")
|
container="",
|
||||||
self.assertIs(None, tensor_util.constant_value(tf_val))
|
shared_name="")
|
||||||
|
self.assertIs(None, tensor_util.constant_value(tf_val))
|
||||||
|
|
||||||
def testShape(self):
|
def testShape(self):
|
||||||
np_val = np.array([1, 2, 3], dtype=np.int32)
|
np_val = np.array([1, 2, 3], dtype=np.int32)
|
||||||
@ -845,19 +846,17 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertEqual(6, c_val)
|
self.assertEqual(6, c_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testSizeOfScalar(self):
|
def testSizeOfScalar(self):
|
||||||
tf_val = array_ops.size(constant_op.constant(0.0))
|
tf_val = array_ops.size(constant_op.constant(0.0))
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertEqual(1, c_val)
|
self.assertEqual(1, c_val)
|
||||||
self.assertEqual(np.ndarray, type(c_val))
|
self.assertIn(type(c_val), [np.ndarray, np.int32])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRank(self):
|
def testRank(self):
|
||||||
tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
|
tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
|
|
||||||
self.assertEqual(np.ndarray, type(c_val))
|
self.assertIn(type(c_val), [np.ndarray, np.int32])
|
||||||
self.assertEqual((), c_val.shape)
|
self.assertEqual((), c_val.shape)
|
||||||
self.assertEqual(3, c_val)
|
self.assertEqual(3, c_val)
|
||||||
|
|
||||||
@ -868,7 +867,7 @@ class ConstantValueTest(test.TestCase):
|
|||||||
0.0, shape=[1, 2, 3]), optimize=False)
|
0.0, shape=[1, 2, 3]), optimize=False)
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
|
|
||||||
self.assertEqual(np.ndarray, type(c_val))
|
self.assertIn(type(c_val), [np.ndarray, np.int32])
|
||||||
self.assertEqual((), c_val.shape)
|
self.assertEqual((), c_val.shape)
|
||||||
self.assertEqual(3, c_val)
|
self.assertEqual(3, c_val)
|
||||||
self.assertEqual([3], c_val)
|
self.assertEqual([3], c_val)
|
||||||
@ -884,7 +883,6 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertAllClose(np_val.astype(np.float64), c_val)
|
self.assertAllClose(np_val.astype(np.float64), c_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testConcat(self):
|
def testConcat(self):
|
||||||
np_val = np.random.rand(3, 4, 7).astype(np.float32)
|
np_val = np.random.rand(3, 4, 7).astype(np.float32)
|
||||||
tf_val = array_ops.concat(
|
tf_val = array_ops.concat(
|
||||||
@ -892,19 +890,21 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertAllClose(np_val, c_val)
|
self.assertAllClose(np_val, c_val)
|
||||||
|
|
||||||
tf_val = array_ops.concat(
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
[np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]],
|
with ops.Graph().as_default():
|
||||||
array_ops.placeholder(dtypes.int32))
|
tf_val = array_ops.concat(
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
[np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]],
|
||||||
self.assertIs(None, c_val)
|
array_ops.placeholder(dtypes.int32))
|
||||||
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
|
self.assertIs(None, c_val)
|
||||||
|
|
||||||
tf_val = array_ops.concat([
|
tf_val = array_ops.concat([
|
||||||
np_val[0, :, :], array_ops.placeholder(dtypes.float32), np_val[2, :, :]
|
np_val[0, :, :],
|
||||||
], 1)
|
array_ops.placeholder(dtypes.float32), np_val[2, :, :]
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
], 1)
|
||||||
self.assertIs(None, c_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
|
self.assertIs(None, c_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testPack_Axis0(self):
|
def testPack_Axis0(self):
|
||||||
inputs = [np.random.rand(4, 7) for _ in range(3)]
|
inputs = [np.random.rand(4, 7) for _ in range(3)]
|
||||||
np_val = np.array(inputs)
|
np_val = np.array(inputs)
|
||||||
@ -912,72 +912,79 @@ class ConstantValueTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value(tf_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
self.assertAllClose(np_val, c_val)
|
self.assertAllClose(np_val, c_val)
|
||||||
|
|
||||||
tf_val = array_ops.stack(
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
[inputs[0], array_ops.placeholder(dtypes.float32), inputs[2]])
|
with ops.Graph().as_default():
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
tf_val = array_ops.stack(
|
||||||
self.assertIs(None, c_val)
|
[inputs[0],
|
||||||
|
array_ops.placeholder(dtypes.float32), inputs[2]])
|
||||||
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
|
self.assertIs(None, c_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testPack_Axis1(self):
|
def testPack_Axis1(self):
|
||||||
inputs = [np.random.rand(4, 7) for _ in range(3)]
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
tf_val = array_ops.stack(inputs, axis=1)
|
with ops.Graph().as_default():
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
inputs = [np.random.rand(4, 7) for _ in range(3)]
|
||||||
self.assertIsNone(c_val)
|
tf_val = array_ops.stack(inputs, axis=1)
|
||||||
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
|
self.assertIsNone(c_val)
|
||||||
|
|
||||||
tf_val = array_ops.stack(
|
tf_val = array_ops.stack(
|
||||||
[inputs[0], array_ops.placeholder(dtypes.float32), inputs[2]], axis=1)
|
[inputs[0],
|
||||||
c_val = tensor_util.constant_value(tf_val)
|
array_ops.placeholder(dtypes.float32), inputs[2]], axis=1)
|
||||||
self.assertIs(None, c_val)
|
c_val = tensor_util.constant_value(tf_val)
|
||||||
|
self.assertIs(None, c_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testPack_Partial_Axis0(self):
|
def testPack_Partial_Axis0(self):
|
||||||
input_ = np.random.rand(4, 7)
|
input_ = np.random.rand(4, 7)
|
||||||
tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
c_val = tensor_util.constant_value(tf_val, partial=True)
|
with ops.Graph().as_default():
|
||||||
self.assertAllClose(input_, c_val[0])
|
tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
|
||||||
self.assertIsNone(c_val[1])
|
c_val = tensor_util.constant_value(tf_val, partial=True)
|
||||||
|
self.assertAllClose(input_, c_val[0])
|
||||||
|
self.assertIsNone(c_val[1])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testPack_Partial_Axis1(self):
|
def testPack_Partial_Axis1(self):
|
||||||
input_ = np.random.rand(4, 7)
|
input_ = np.random.rand(4, 7)
|
||||||
tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)],
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
axis=1)
|
with ops.Graph().as_default():
|
||||||
c_val = tensor_util.constant_value(tf_val, partial=True)
|
tf_val = array_ops.stack(
|
||||||
self.assertIsNone(c_val)
|
[input_, array_ops.placeholder(dtypes.float32)], axis=1)
|
||||||
|
c_val = tensor_util.constant_value(tf_val, partial=True)
|
||||||
|
self.assertIsNone(c_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testUnpack_Axis0(self):
|
def testUnpack_Axis0(self):
|
||||||
inputs = np.random.rand(3, 4, 7)
|
inputs = np.random.rand(3, 4, 7)
|
||||||
tf_vals = array_ops.unstack(inputs)
|
tf_vals = array_ops.unstack(inputs)
|
||||||
c_vals = [tensor_util.constant_value(x) for x in tf_vals]
|
c_vals = [tensor_util.constant_value(x) for x in tf_vals]
|
||||||
self.assertAllClose(inputs, c_vals)
|
self.assertAllClose(inputs, c_vals)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testUnpack_Partial_Axis0(self):
|
def testUnpack_Partial_Axis0(self):
|
||||||
input_ = np.random.rand(4, 7)
|
input_ = np.random.rand(4, 7)
|
||||||
packed = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
tf_vals = array_ops.unstack(packed)
|
with ops.Graph().as_default():
|
||||||
c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals]
|
packed = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
|
||||||
self.assertAllClose(input_, c_vals[0])
|
tf_vals = array_ops.unstack(packed)
|
||||||
self.assertIsNone(c_vals[1])
|
c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals]
|
||||||
|
self.assertAllClose(input_, c_vals[0])
|
||||||
|
self.assertIsNone(c_vals[1])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testSplit_Axis0(self):
|
def testSplit_Axis0(self):
|
||||||
inputs = np.random.rand(6, 5, 7)
|
inputs = np.random.rand(6, 5, 7)
|
||||||
tf_vals = array_ops.split(inputs, 3)
|
tf_vals = array_ops.split(inputs, 3)
|
||||||
c_vals = [tensor_util.constant_value(x) for x in tf_vals]
|
c_vals = [tensor_util.constant_value(x) for x in tf_vals]
|
||||||
self.assertAllClose(np.split(inputs, 3), c_vals)
|
self.assertAllClose(np.split(inputs, 3), c_vals)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testSplit_Partial_Axis0(self):
|
def testSplit_Partial_Axis0(self):
|
||||||
input_ = np.random.rand(4, 7)
|
input_ = np.random.rand(4, 7)
|
||||||
placeholder = array_ops.placeholder(dtypes.float32, shape=(4, 7))
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
# it'd be better to use concat here, but concat doesn't support partial
|
with ops.Graph().as_default():
|
||||||
packed = array_ops.stack([input_, placeholder])
|
placeholder = array_ops.placeholder(dtypes.float32, shape=(4, 7))
|
||||||
tf_vals = array_ops.split(packed, 2)
|
# it'd be better to use concat here, but concat doesn't support partial
|
||||||
c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals]
|
packed = array_ops.stack([input_, placeholder])
|
||||||
self.assertAllClose(input_, c_vals[0][0])
|
tf_vals = array_ops.split(packed, 2)
|
||||||
self.assertIsNone(c_vals[1][0])
|
c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals]
|
||||||
|
self.assertAllClose(input_, c_vals[0][0])
|
||||||
|
self.assertIsNone(c_vals[1][0])
|
||||||
|
|
||||||
def testEqual(self):
|
def testEqual(self):
|
||||||
# Scalar inputs.
|
# Scalar inputs.
|
||||||
@ -1079,32 +1086,35 @@ class ConstantValueAsShapeTest(test.TestCase):
|
|||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
self.assertEqual([None, 1, None], c_val.as_list())
|
self.assertEqual([None, 1, None], c_val.as_list())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testPack(self):
|
def testPack(self):
|
||||||
tf_val = array_ops.stack(
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
[constant_op.constant(16), 37, array_ops.placeholder(dtypes.int32)])
|
with ops.Graph().as_default():
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
tf_val = array_ops.stack(
|
||||||
self.assertEqual([16, 37, None], c_val.as_list())
|
[constant_op.constant(16), 37,
|
||||||
|
array_ops.placeholder(dtypes.int32)])
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([16, 37, None], c_val.as_list())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testConcat(self):
|
def testConcat(self):
|
||||||
tf_val = array_ops.concat(
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
[[16, 37], array_ops.placeholder(
|
with ops.Graph().as_default():
|
||||||
dtypes.int32, shape=(2,))], 0)
|
tf_val = array_ops.concat(
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
[[16, 37], array_ops.placeholder(dtypes.int32, shape=(2,))], 0)
|
||||||
self.assertEqual([16, 37, None, None], c_val.as_list())
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([16, 37, None, None], c_val.as_list())
|
||||||
|
|
||||||
tf_val = array_ops.concat(
|
tf_val = array_ops.concat(
|
||||||
[[16, 37], array_ops.placeholder(
|
[[16, 37],
|
||||||
dtypes.int32, shape=(1,)), [48]], 0)
|
array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0)
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
self.assertEqual([16, 37, None, 48], c_val.as_list())
|
self.assertEqual([16, 37, None, 48], c_val.as_list())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testSlice(self):
|
def testSlice(self):
|
||||||
tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2]
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
with ops.Graph().as_default():
|
||||||
self.assertEqual([None, None], c_val.as_list())
|
tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([None, None], c_val.as_list())
|
||||||
|
|
||||||
# begin:end
|
# begin:end
|
||||||
tf_val = constant_op.constant([10, 20, 30])[1:3]
|
tf_val = constant_op.constant([10, 20, 30])[1:3]
|
||||||
@ -1118,65 +1128,67 @@ class ConstantValueAsShapeTest(test.TestCase):
|
|||||||
self.assertEqual([20], c_val.as_list())
|
self.assertEqual([20], c_val.as_list())
|
||||||
|
|
||||||
# [1, 2, 16, 37, None, 48]
|
# [1, 2, 16, 37, None, 48]
|
||||||
tf_val_orig = array_ops.concat(
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
[[1, 2, 16, 37], array_ops.placeholder(
|
with ops.Graph().as_default():
|
||||||
dtypes.int32, shape=(1,)), [48]], 0)
|
tf_val_orig = array_ops.concat(
|
||||||
|
[[1, 2, 16, 37],
|
||||||
|
array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0)
|
||||||
|
|
||||||
# begin: no end
|
# begin: no end
|
||||||
tf_val = tf_val_orig[2:]
|
tf_val = tf_val_orig[2:]
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([16, 37, None, 48], c_val.as_list())
|
|
||||||
|
|
||||||
# begin::negative slice
|
|
||||||
tf_val = tf_val_orig[2::-1]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([16, 2, 1], c_val.as_list())
|
|
||||||
|
|
||||||
# :end:negative slice
|
|
||||||
tf_val = tf_val_orig[:1:-2]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([48, 37], c_val.as_list())
|
|
||||||
|
|
||||||
# begin:end:negative slice
|
|
||||||
tf_val = tf_val_orig[3:1:-1]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([37, 16], c_val.as_list())
|
|
||||||
|
|
||||||
# begin:negative end:slice
|
|
||||||
tf_val = tf_val_orig[1:-3:1]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([2, 16], c_val.as_list())
|
|
||||||
|
|
||||||
# negative begin::slice
|
|
||||||
tf_val = tf_val_orig[-3::1]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([37, None, 48], c_val.as_list())
|
|
||||||
|
|
||||||
# negative begin::negative slice
|
|
||||||
tf_val = tf_val_orig[-3::-1]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([37, 16, 2, 1], c_val.as_list())
|
|
||||||
|
|
||||||
# negative begin:negative end:negative slice
|
|
||||||
tf_val = tf_val_orig[-3:-5:-1]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([37, 16], c_val.as_list())
|
|
||||||
|
|
||||||
# Do not support shape inference for additional arguments
|
|
||||||
tf_val = constant_op.constant([10, 20, 30])[...]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual([None, None, None], c_val.as_list())
|
|
||||||
|
|
||||||
# Do not support shape inference for tensor slices.
|
|
||||||
tf_val = constant_op.constant([10, 20, 30])[
|
|
||||||
array_ops.placeholder(dtypes.int32, shape=()):]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
|
||||||
self.assertEqual(tensor_shape.unknown_shape(), c_val)
|
|
||||||
|
|
||||||
# Do not support shape inference for higher rank
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
tf_val = constant_op.constant([[10], [20], [30]])[:, 0:]
|
|
||||||
c_val = tensor_util.constant_value_as_shape(tf_val)
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([16, 37, None, 48], c_val.as_list())
|
||||||
|
|
||||||
|
# begin::negative slice
|
||||||
|
tf_val = tf_val_orig[2::-1]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([16, 2, 1], c_val.as_list())
|
||||||
|
|
||||||
|
# :end:negative slice
|
||||||
|
tf_val = tf_val_orig[:1:-2]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([48, 37], c_val.as_list())
|
||||||
|
|
||||||
|
# begin:end:negative slice
|
||||||
|
tf_val = tf_val_orig[3:1:-1]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([37, 16], c_val.as_list())
|
||||||
|
|
||||||
|
# begin:negative end:slice
|
||||||
|
tf_val = tf_val_orig[1:-3:1]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([2, 16], c_val.as_list())
|
||||||
|
|
||||||
|
# negative begin::slice
|
||||||
|
tf_val = tf_val_orig[-3::1]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([37, None, 48], c_val.as_list())
|
||||||
|
|
||||||
|
# negative begin::negative slice
|
||||||
|
tf_val = tf_val_orig[-3::-1]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([37, 16, 2, 1], c_val.as_list())
|
||||||
|
|
||||||
|
# negative begin:negative end:negative slice
|
||||||
|
tf_val = tf_val_orig[-3:-5:-1]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([37, 16], c_val.as_list())
|
||||||
|
|
||||||
|
# Do not support shape inference for additional arguments
|
||||||
|
tf_val = constant_op.constant([10, 20, 30])[...]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual([None, None, None], c_val.as_list())
|
||||||
|
|
||||||
|
# Do not support shape inference for tensor slices.
|
||||||
|
tf_val = constant_op.constant(
|
||||||
|
[10, 20, 30])[array_ops.placeholder(dtypes.int32, shape=()):]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
self.assertEqual(tensor_shape.unknown_shape(), c_val)
|
||||||
|
|
||||||
|
# Do not support shape inference for higher rank
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf_val = constant_op.constant([[10], [20], [30]])[:, 0:]
|
||||||
|
c_val = tensor_util.constant_value_as_shape(tf_val)
|
||||||
|
|
||||||
|
|
||||||
class MaybeSetStaticShapeTest(test.TestCase):
|
class MaybeSetStaticShapeTest(test.TestCase):
|
||||||
@ -1190,24 +1202,23 @@ class MaybeSetStaticShapeTest(test.TestCase):
|
|||||||
finally:
|
finally:
|
||||||
tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = flag_old
|
tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = flag_old
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testMaybeSetStaticShape(self):
|
def testMaybeSetStaticShape(self):
|
||||||
shape = constant_op.constant([2, 5], dtype=dtypes.int32)
|
shape = constant_op.constant([2, 5], dtype=dtypes.int32)
|
||||||
|
|
||||||
def reshape():
|
def reshape():
|
||||||
v = array_ops.zeros([10])
|
v = array_ops.zeros([10])
|
||||||
return array_ops.reshape(v, shape)
|
return array_ops.reshape(v, shape)
|
||||||
|
# This test needs a placeholder which means we need to construct a graph.
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with self.disableSetStaticShape():
|
||||||
|
graph_without_shape_propagation = func_graph.func_graph_from_py_func(
|
||||||
|
"without_shape_propagation", reshape, [], {})
|
||||||
|
graph_with_shape_propagation = func_graph.func_graph_from_py_func(
|
||||||
|
"with_shape_propagation", reshape, [], {})
|
||||||
|
self.assertCountEqual(
|
||||||
|
[op.type for op in graph_without_shape_propagation.get_operations()],
|
||||||
|
[op.type for op in graph_with_shape_propagation.get_operations()])
|
||||||
|
|
||||||
with self.disableSetStaticShape():
|
|
||||||
graph_without_shape_propagation = func_graph.func_graph_from_py_func(
|
|
||||||
"without_shape_propagation", reshape, [], {})
|
|
||||||
graph_with_shape_propagation = func_graph.func_graph_from_py_func(
|
|
||||||
"with_shape_propagation", reshape, [], {})
|
|
||||||
self.assertCountEqual(
|
|
||||||
[op.type for op in graph_without_shape_propagation.get_operations()],
|
|
||||||
[op.type for op in graph_with_shape_propagation.get_operations()])
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testMaybeSetStaticShapeScalarShape(self):
|
def testMaybeSetStaticShapeScalarShape(self):
|
||||||
|
|
||||||
def reshape():
|
def reshape():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user