Make tests compatible with TF2
Handle TensorShapeV2 changes as well as assertion behavior changes. PiperOrigin-RevId: 235935411
This commit is contained in:
parent
0c8deb2f91
commit
6d36f1b408
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.client import session
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -60,7 +59,7 @@ class ConstantFoldingTest(test.TestCase):
|
|||||||
loop_vars=[0, init_y],
|
loop_vars=[0, init_y],
|
||||||
back_prop=False,
|
back_prop=False,
|
||||||
parallel_iterations=1)
|
parallel_iterations=1)
|
||||||
with session.Session() as sess:
|
|
||||||
y_v = self.evaluate(y)
|
y_v = self.evaluate(y)
|
||||||
self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
|
self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
|
||||||
|
|
||||||
|
@ -193,12 +193,12 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(None, gather_t.shape)
|
self.assertEqual(None, gather_t.shape)
|
||||||
|
|
||||||
def testBadIndicesCPU(self):
|
def testBadIndicesCPU(self):
|
||||||
with self.session(use_gpu=False):
|
with test_util.force_cpu():
|
||||||
params = [[0, 1, 2], [3, 4, 5]]
|
params = [[0, 1, 2], [3, 4, 5]]
|
||||||
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
|
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
|
||||||
array_ops.gather(params, [[7]], axis=0).eval()
|
self.evaluate(array_ops.gather(params, [[7]], axis=0))
|
||||||
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
|
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
|
||||||
array_ops.gather(params, [[7]], axis=1).eval()
|
self.evaluate(array_ops.gather(params, [[7]], axis=1))
|
||||||
|
|
||||||
def _disabledTestBadIndicesGPU(self):
|
def _disabledTestBadIndicesGPU(self):
|
||||||
# TODO disabled due to different behavior on GPU and CPU
|
# TODO disabled due to different behavior on GPU and CPU
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -106,7 +107,7 @@ class SquareLinearOperatorFullMatrixTest(
|
|||||||
matrix = [[1., 1.], [1., 1.]]
|
matrix = [[1., 1.], [1., 1.]]
|
||||||
operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=True)
|
operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=True)
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
with self.assertRaisesOpError("Cholesky decomposition was not success"):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
operator.assert_positive_definite().run()
|
operator.assert_positive_definite().run()
|
||||||
|
|
||||||
|
|
||||||
|
@ -589,7 +589,7 @@ def random_positive_definite_matrix(shape, dtype, force_well_conditioned=False):
|
|||||||
if not tensor_util.is_tensor(shape):
|
if not tensor_util.is_tensor(shape):
|
||||||
shape = tensor_shape.TensorShape(shape)
|
shape = tensor_shape.TensorShape(shape)
|
||||||
# Matrix must be square.
|
# Matrix must be square.
|
||||||
shape[-1].assert_is_compatible_with(shape[-2])
|
shape.dims[-1].assert_is_compatible_with(shape.dims[-2])
|
||||||
|
|
||||||
with ops.name_scope("random_positive_definite_matrix"):
|
with ops.name_scope("random_positive_definite_matrix"):
|
||||||
tril = random_tril_matrix(
|
tril = random_tril_matrix(
|
||||||
|
Loading…
Reference in New Issue
Block a user