Change all compiler tests to use self.session

The session returned by cached_session uses soft placement, which we don't
want for XLA_* devices.  With soft placement, ops lacking XLA kernels silently
fall back to the CPU, misleading us into thinking we have more test coverage
than we actually do.  With this change some tests (rightly) start failing
because they were testing ops with dtypes the XLA kernels do not support.  I've
removed these dtypes from the tests.
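
For reference, the relevant difference is the soft-placement setting in the
session config.  A minimal, illustrative sketch of the distinction (this is
not the xla_test.py helper itself; the device string and op are just an
example):

    import tensorflow as tf  # TF 1.x graph-mode API

    # cached_session() builds its session with allow_soft_placement=True, so
    # an op without an XLA kernel quietly runs on the CPU instead of failing.
    # Disabling soft placement turns that silent fallback into an error:
    config = tf.ConfigProto(allow_soft_placement=False)
    with tf.Session(config=config) as sess:
      with tf.device("/device:XLA_CPU:0"):
        y = tf.square(tf.constant([1.0, 2.0]))
      print(sess.run(y))  # raises if Square had no kernel on the XLA device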

This CL partially addresses b/132430685.  It stubs out "cached_session" and
"test_session" to raise errors, so we have more confidence that the compiler is
being exercised.  However, we still use XLA_* devices to exercise XLA, which has
a different code path than xla.compile and tpu.rewrite.  This needs to be
incrementally fixed.
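
The stubbing is roughly the following (a sketch of the idea only; the exact
exception type and message in xla_test.py may differ, and the session() and
test_scope() helpers are omitted):

    from tensorflow.python.platform import test

    class XLATestCase(test.TestCase):

      def cached_session(self, *args, **kwargs):
        raise NotImplementedError(
            "cached_session is not allowed in XLA tests; use session() instead")

      def test_session(self, *args, **kwargs):
        raise NotImplementedError(
            "test_session is not allowed in XLA tests; use session() instead")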

PiperOrigin-RevId: 248437673
Sanjoy Das 2019-05-15 17:23:34 -07:00 committed by TensorFlower Gardener
parent e7d6770051
commit 6762ca15c4
84 changed files with 464 additions and 435 deletions

View File

@@ -41,7 +41,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase):
     all_lr = [1.0, 0.5, 0.1]
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         for grad in all_grad:
           for lr in all_lr:
             var0_init = [1.0, 2.0]

View File

@@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):

   def testAdagradDAWithoutRegularizationBasic1(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         global_step = resource_variable_ops.ResourceVariable(
             0, dtype=dtypes.int64)
         var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):

   def testAdagradDAwithoutRegularizationBasic2(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         global_step = resource_variable_ops.ResourceVariable(
             0, dtype=dtypes.int64)
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
@@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):

   def testAdagradDAWithL1(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         global_step = resource_variable_ops.ResourceVariable(
             0, dtype=dtypes.int64)
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
@@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):

   def testAdagradDAWithL1_L2(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         global_step = resource_variable_ops.ResourceVariable(
             0, dtype=dtypes.int64)
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)

View File

@@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):

   def testBasic(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -59,7 +59,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):

   def testTensorLearningRate(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -87,7 +87,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):

   def testSharing(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)

View File

@@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
       # TODO: test fails for float16 due to excessive precision requirements.
       if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
         continue
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         variable_scope.get_variable_scope().set_use_resource(True)

         # Initialize variables for numpy implementation.
@@ -99,7 +99,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
       # TODO: test fails for float16 due to excessive precision requirements.
       if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
         continue
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         variable_scope.get_variable_scope().set_use_resource(True)

         # Initialize variables for numpy implementation.
@@ -142,7 +142,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
      # TODO: test fails for float16 due to excessive precision requirements.
       if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
         continue
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         variable_scope.get_variable_scope().set_use_resource(True)

         # Initialize variables for numpy implementation.

View File

@@ -49,7 +49,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):

   def testBasic(self):
     for i, dtype in enumerate(self.float_types):
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         variable_scope.get_variable_scope().set_use_resource(True)
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -103,7 +103,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):

   def testTensorLearningRate(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         variable_scope.get_variable_scope().set_use_resource(True)
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0

View File

@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
 class XlaAddNTest(xla_test.XLATestCase):

   def testAddTensorLists(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       l1 = list_ops.tensor_list_reserve(
           element_shape=[], element_dtype=dtypes.float32, num_elements=3)
       l2 = list_ops.tensor_list_reserve(
@@ -44,7 +44,7 @@ class XlaAddNTest(xla_test.XLATestCase):
           [5.0, 0.0, 10.0])

   def testAddTensorListsFailsIfLeadingDimsMismatch(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       l1 = list_ops.tensor_list_reserve(
           element_shape=[], element_dtype=dtypes.float32, num_elements=2)
       l2 = list_ops.tensor_list_reserve(
@@ -56,7 +56,7 @@ class XlaAddNTest(xla_test.XLATestCase):
         list_ops.tensor_list_stack(l, element_dtype=dtypes.float32).eval()

   def testAddTensorListsFailsIfElementShapesMismatch(self):
-    with self.cached_session() as session, self.test_scope():
+    with self.session() as session, self.test_scope():
       # Use placeholders instead of constant values for shapes to prevent TF's
       # shape inference from catching this early.
       l1_element_shape = array_ops.placeholder(dtype=dtypes.int32)

View File

@@ -63,7 +63,7 @@ class AddSignTest(xla_test.XLATestCase):
                  alpha=1.0,
                  beta=0.9):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         # Initialize variables for numpy implementation.
         m0, m1 = 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype)

View File

@@ -40,7 +40,7 @@ class ArgMinMaxTest(xla_test.XLATestCase):
       op_input: numpy input array to use as input to 'op'.
       expected: numpy array representing the expected output of 'op'.
     """
-    with self.cached_session() as session:
+    with self.session() as session:
       with self.test_scope():
         pinp = array_ops.placeholder(
             dtypes.as_dtype(op_input.dtype), op_input.shape, name="a")

View File

@@ -39,7 +39,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
   """Test cases for binary operators."""

   def _testBinary(self, op, a, b, expected, equality_test=None):
-    with self.cached_session() as session:
+    with self.session() as session:
       with self.test_scope():
         pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
         pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")

View File

@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
 class BucketizationOpTest(xla_test.XLATestCase):

   def testInt(self):
-    with self.cached_session() as sess:
+    with self.session() as sess:
       p = array_ops.placeholder(dtypes.int32)
       with self.test_scope():
         op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11])
@@ -38,7 +38,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
           sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]}))

   def testFloat(self):
-    with self.cached_session() as sess:
+    with self.session() as sess:
       p = array_ops.placeholder(dtypes.float32)
       with self.test_scope():
         op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.])
@@ -48,7 +48,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
           sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]}))

   def test2DInput(self):
-    with self.cached_session() as sess:
+    with self.session() as sess:
       p = array_ops.placeholder(dtypes.float32)
       with self.test_scope():
         op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11])
@@ -58,7 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
           {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))

   def testInvalidBoundariesOrder(self):
-    with self.cached_session() as sess:
+    with self.session() as sess:
       p = array_ops.placeholder(dtypes.int32)
       with self.test_scope():
         op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11])
@@ -67,7 +67,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
         sess.run(op, {p: [-5, 0]})

   def testBoundariesNotList(self):
-    with self.cached_session():
+    with self.session():
       with self.assertRaisesRegexp(TypeError, "Expected list.*"):
         p = array_ops.placeholder(dtypes.int32)
         with self.test_scope():

View File

@@ -57,7 +57,7 @@ class CategoricalTest(xla_test.XLATestCase):
     Returns:
       Frequencies from sampled classes; shape [batch_size, num_classes].
     """
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       random_seed.set_random_seed(1618)
       op = random_ops.multinomial(logits, num_samples,
                                   output_dtype=dtypes.int32)
@@ -80,7 +80,7 @@ class CategoricalTest(xla_test.XLATestCase):

   def _testRngIsNotConstant(self, rng, dtype, output_dtype):
     # Tests that 'rng' does not always return the same value.
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         x = rng(dtype, output_dtype)
@@ -108,7 +108,7 @@ class CategoricalTest(xla_test.XLATestCase):
   def testCategoricalIsInRange(self):
     for dtype in self.float_types:
       for output_dtype in self.output_dtypes():
-        with self.cached_session():
+        with self.session():
           with self.test_scope():
             x = random_ops.multinomial(
                 array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
@@ -140,9 +140,10 @@ class CategoricalTest(xla_test.XLATestCase):
       self.assertLess(chi2, 1e-3)

   def testStatelessMultinomialIsInRange(self):
-    for dtype in self.float_types:
+    for dtype in self.float_types.intersection(
+        [dtypes.float32, dtypes.bfloat16]):
       for output_dtype in self.output_dtypes():
-        with self.cached_session() as sess:
+        with self.session() as sess:
           with self.test_scope():
             seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
             x = stateless_random_ops.stateless_multinomial(
@@ -157,7 +158,7 @@ class CategoricalTest(xla_test.XLATestCase):
   def testDeterminismMultinomial(self):
     # Stateless values should be equal iff the seeds are equal (roughly)
     num_samples = 10
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
       seeds = [(x, y) for x in range(5) for y in range(5)] * 3
       for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
@@ -170,7 +171,7 @@ class CategoricalTest(xla_test.XLATestCase):
         self.assertEqual(s0 == s1, np.all(v0 == v1))

   def testEmpty(self):
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         x = random_ops.multinomial(
             array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
@@ -178,7 +179,7 @@ class CategoricalTest(xla_test.XLATestCase):
       self.assertEqual(y.shape, (42, 0))

   def testEmptyStateless(self):
-    with self.cached_session() as sess:
+    with self.session() as sess:
       with self.test_scope():
         seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
         x = stateless_random_ops.stateless_multinomial(

View File

@@ -54,7 +54,7 @@ class CholeskyOpTest(xla_test.XLATestCase):

   def _verifyCholesky(self, x, atol=1e-6):
     # Verify that LL^T == x.
-    with self.cached_session() as sess:
+    with self.session() as sess:
       placeholder = array_ops.placeholder(
           dtypes.as_dtype(x.dtype), shape=x.shape)
       with self.test_scope():

View File

@@ -38,7 +38,7 @@ class ClusteringTest(xla_test.XLATestCase):
     val1 = np.array([4, 3, 2, 1], dtype=np.float32)
     val2 = np.array([5, 6, 7, 8], dtype=np.float32)
     expected = val1 + val2
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         input1 = constant_op.constant(val1, name="const1")
         input2 = constant_op.constant(val2, name="const2")
@@ -50,7 +50,7 @@ class ClusteringTest(xla_test.XLATestCase):
     val1 = np.array([4, 3, 2, 1]).astype(np.float32)
     val2 = np.array([5, 6, 7, 8]).astype(np.float32)
     expected = val1 + val2
-    with self.cached_session():
+    with self.session():
       with ops.device(CPU_DEVICE):
         input1 = constant_op.constant(val1, name="const1")
         input2 = constant_op.constant(val2, name="const2")
@@ -68,7 +68,7 @@ class ClusteringTest(xla_test.XLATestCase):
     # where x and z are placed on the CPU and y and w are placed on the XLA
     # device. If y and w are clustered for compilation, then the graph will
     # deadlock since the clustered graph will contain a self-loop.
-    with self.cached_session() as sess:
+    with self.session() as sess:
       with ops.device(CPU_DEVICE):
         x = array_ops.placeholder(dtypes.float32, [2])
       with self.test_scope():
@@ -81,7 +81,7 @@ class ClusteringTest(xla_test.XLATestCase):
       self.assertAllClose(result, [12., 2.], rtol=1e-3)

   def testHostMemory(self):
-    with self.cached_session() as sess:
+    with self.session() as sess:
       x = array_ops.placeholder(dtypes.int32)
       with self.test_scope():
         y = x + 1

View File

@@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest
 class ConcatTest(xla_test.XLATestCase):

   def testHStack(self):
-    with self.cached_session():
+    with self.session():
       p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
       p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
       with self.test_scope():
@@ -49,7 +49,7 @@ class ConcatTest(xla_test.XLATestCase):
       self.assertAllEqual(result[4:, :], params[p2])

   def testVStack(self):
-    with self.cached_session():
+    with self.session():
       p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
       p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
       with self.test_scope():
@@ -65,7 +65,7 @@ class ConcatTest(xla_test.XLATestCase):
       self.assertAllEqual(result[:, 4:], params[p2])

   def testInt32(self):
-    with self.cached_session():
+    with self.session():
       p1 = np.random.rand(2, 3).astype("i")
       p2 = np.random.rand(2, 3).astype("i")
       x1 = constant_op.constant(p1)
@@ -88,7 +88,7 @@ class ConcatTest(xla_test.XLATestCase):
       dtype_feed = dtypes.float32
     else:
       dtype_feed = dtype
-    with self.cached_session():
+    with self.session():
       p = []
       for i in np.arange(num_tensors):
         input_shape = shape
@@ -130,7 +130,7 @@ class ConcatTest(xla_test.XLATestCase):
     self._testRandom(dtypes.int32)

   def _testGradientsSimple(self):
-    with self.cached_session():
+    with self.session():
       inp = []
       inp_tensors = []
       with self.test_scope():
@@ -157,7 +157,7 @@ class ConcatTest(xla_test.XLATestCase):
     self._testGradientsSimple()

   def _testGradientsFirstDim(self):
-    with self.cached_session():
+    with self.session():
       inp = []
       inp_tensors = []
       with self.test_scope():
@@ -185,7 +185,7 @@ class ConcatTest(xla_test.XLATestCase):
     self._testGradientsFirstDim()

   def _testGradientsLastDim(self):
-    with self.cached_session():
+    with self.session():
       inp = []
       inp_tensors = []
       with self.test_scope():
@@ -220,7 +220,7 @@ class ConcatTest(xla_test.XLATestCase):
     # Random dim to concat on
     concat_dim = np.random.randint(5)
     concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
-    with self.cached_session():
+    with self.session():
       inp = []
       inp_tensors = []
       with self.test_scope():
@@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase):
   def DISABLED_testZeroSize(self):
     # Verify that concat doesn't crash and burn for zero size inputs
     np.random.seed(7)
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         for shape0 in (), (2,):
           axis = len(shape0)
@@ -276,14 +276,14 @@ class ConcatTest(xla_test.XLATestCase):
   def testConcatTuple(self):
     c1 = np.random.rand(4, 4).astype(np.float32)
     c2 = np.random.rand(4, 4).astype(np.float32)
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         concat_list_t = array_ops.concat([c1, c2], 0)
         concat_tuple_t = array_ops.concat((c1, c2), 0)
       self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t))

   def testConcatNoScalars(self):
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         scalar = constant_op.constant(7)
         dim = array_ops.placeholder(dtypes.int32)
@@ -297,7 +297,7 @@ class ConcatTest(xla_test.XLATestCase):
     if "CPU" in self.device:
       self.skipTest("This test can time out on CPU, so we will just allow "
                     "other backends to catch this specific error.")
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         for concat_dim in range(2):
           params = {}
@@ -333,7 +333,7 @@ class ConcatTest(xla_test.XLATestCase):
 class ConcatOffsetTest(xla_test.XLATestCase):

   def testBasic(self):
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         cdim = constant_op.constant(1, dtypes.int32)
         s0 = constant_op.constant([2, 3, 5], dtypes.int32)
@@ -347,7 +347,7 @@ class ConcatOffsetTest(xla_test.XLATestCase):
 class PackTest(xla_test.XLATestCase):

   def testBasic(self):
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         s0 = constant_op.constant([2, 3, 5], dtypes.int32)
         s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -357,7 +357,7 @@ class PackTest(xla_test.XLATestCase):
       self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])

   def testScalars(self):
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         s0 = constant_op.constant(2, dtypes.int32)
         s1 = constant_op.constant(3, dtypes.int32)
@@ -367,7 +367,7 @@ class PackTest(xla_test.XLATestCase):
       self.assertAllEqual(ans, [2, 3, 5])

   def testEmpty(self):
-    with self.cached_session():
+    with self.session():
       with self.test_scope():
         s0 = constant_op.constant([[]], dtypes.int32)
         s1 = constant_op.constant([[]], dtypes.int32)

View File

@@ -33,7 +33,7 @@ from tensorflow.python.platform import test
 class CondTest(xla_test.XLATestCase):

   def testCondAndTensorArrayInDefun(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
@@ -52,7 +52,7 @@ class CondTest(xla_test.XLATestCase):
     xla_context.Exit()

   def testCondConstPropagation(self):
-    with self.cached_session() as sess, self.test_scope():
+    with self.session() as sess, self.test_scope():
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()

View File

@@ -87,7 +87,7 @@ class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase):
       dilations = test_utils.PermuteDimsBetweenDataFormats(
           dilations, data_format_src, data_format_dst)

-    with self.cached_session() as sess:
+    with self.session() as sess:
       t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
       t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
       with self.test_scope():
@@ -288,7 +288,7 @@ class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase):
       dilations = test_utils.PermuteDimsBetweenDataFormats(
           dilations, data_format_src, data_format_dst)

-    with self.cached_session() as sess:
+    with self.session() as sess:
       t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
       t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
       with self.test_scope():
@@ -586,7 +586,7 @@ class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase):
       dilations = test_utils.PermuteDimsBetweenDataFormats(
           dilations, data_format_src, data_format_dst)

-    with self.cached_session() as sess:
+    with self.session() as sess:
       t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
       t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
       with self.test_scope():

View File

@@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest
 class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):

   def testGradient(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       for padding in ["SAME", "VALID"]:
         for stride in [1, 2]:
           np.random.seed(1)
@@ -69,7 +69,7 @@ class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):
 class Conv3DTransposeTest(xla_test.XLATestCase):

   def testConv3DTransposeSingleStride(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       strides = [1, 1, 1, 1, 1]

       # Input, output: [batch, depth, height, width, channel]
@@ -119,7 +119,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
               self.assertAllClose(target, value[n, d, h, w, k])

   def testConv3DTransposeSame(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       strides = [1, 2, 2, 2, 1]

       # Input, output: [batch, depth, height, width, depth]
@@ -157,7 +157,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
               self.assertAllClose(target, value[n, d, h, w, k])

   def testConv3DTransposeValid(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       strides = [1, 2, 2, 2, 1]

       # Input, output: [batch, depth, height, width, depth]
@@ -217,7 +217,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
     np.random.seed(1)  # Make it reproducible.
     x_val = np.random.random_sample(x_shape).astype(np.float64)
     f_val = np.random.random_sample(f_shape).astype(np.float64)
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
       f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
       output = nn_ops.conv3d_transpose(

View File

@@ -92,7 +92,7 @@ class DenseLayerTest(test.TestCase):
     XlaCompile/XlaRun op pair by XLA.
     """

-    with self.cached_session() as sess:
+    with self.session() as sess:
       x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32)
       with jit_scope():
         y = layers.dense(x, 3)
@@ -115,7 +115,7 @@ class DenseLayerTest(test.TestCase):
     """Tests that the dense layer node is properly compiled in jit scope.
     """

-    with self.cached_session() as sess:
+    with self.session() as sess:
       x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
       with jit_scope():
         y = layers.dense(x, 3)

View File

@@ -151,7 +151,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
                   dtype=data_type).reshape(tensor_in_sizes)
     x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
                   dtype=data_type).reshape(filter_in_sizes)
-    with self.cached_session() as sess:
+    with self.session() as sess:
       if data_type == np.float32:
         tolerance = 1e-4
       else:
@@ -247,7 +247,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
                   dtype=np.float32).reshape(tensor_in_sizes)
     x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
                   dtype=np.float32).reshape(filter_in_sizes)
-    with self.cached_session() as sess:
+    with self.session() as sess:
       t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32)
       t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32)
       with self.test_scope():
@@ -321,7 +321,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
     x2 = np.random.rand(*output_sizes).astype(np.float32)

     def _GetVal(use_xla):
-      with self.cached_session():
+      with self.session():
         t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
         t1 = array_ops.placeholder(np.float32, shape=filter_sizes)
         t2 = array_ops.placeholder(np.float32, shape=output_sizes)
@@ -361,7 +361,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
     x2 = np.random.rand(*output_sizes).astype(np.float32)

     def _GetVal(use_xla):
-      with self.cached_session():
+      with self.session():
         t0 = array_ops.placeholder(np.float32, shape=input_sizes)
         t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
         t2 = array_ops.placeholder(np.float32, shape=output_sizes)

View File

@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
 class DynamicUpdateSliceOpsTest(xla_test.XLATestCase):

   def _assertOpOutputMatchesExpected(self, op, args, expected):
-    with self.cached_session() as session:
+    with self.session() as session:
       with self.test_scope():
         placeholders = [
             array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)

View File

@@ -30,7 +30,7 @@ from tensorflow.python.platform import googletest
 class DynamicStitchTest(xla_test.XLATestCase):

   def _AssertDynamicStitchResultIs(self, indices, data, expected):
-    with self.cached_session() as session:
+    with self.session() as session:
       index_placeholders = [
           array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
       ]

View File

@@ -104,7 +104,7 @@ class EagerTest(xla_test.XLATestCase):
     self.assertAllEqual(15, product)

     # Run some ops graphly
-    with context.graph_mode(), self.cached_session():
+    with context.graph_mode(), self.session():
       with self.test_scope():
         three = constant_op.constant(3)
         five = constant_op.constant(5)

View File

@@ -44,7 +44,7 @@ class ExtractImagePatches(xla_test.XLATestCase):
     strides = [1] + strides + [1]
     rates = [1] + rates + [1]

-    with self.cached_session():
+    with self.session():
       image_placeholder = array_ops.placeholder(dtypes.float32)
       with self.test_scope():
         out_tensor = array_ops.extract_image_patches(

View File

@@ -107,7 +107,7 @@ class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase):
         ],
         dtype=np.float32)

-    with self.cached_session() as session:
+    with self.session() as session:
       with self.test_scope():
         input_placeholder = array_ops.placeholder(
             dtypes.float32, inputs.shape, name="inputs")
@@ -198,7 +198,7 @@ class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase):
         [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
         dtype=np.float32)

-    with self.cached_session() as session:
+    with self.session() as session:
       with self.test_scope():
         gradient_placeholder = array_ops.placeholder(
             dtypes.float32, gradients.shape, name="gradients")
@@ -306,7 +306,7 @@ class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase):
         ],
         dtype=np.float32)

-    with self.cached_session() as session:
+    with self.session() as session:
       with self.test_scope():
         input_placeholder = array_ops.placeholder(
             dtypes.float32, inputs.shape, name="inputs")
@@ -406,7 +406,7 @@ class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase):
     expected_backprops_wrt_min = 1.0 + 2.0
     expected_backprops_wrt_max = 10.0 + 11.0

-    with self.cached_session() as session:
+    with self.session() as session:
       with self.test_scope():
         gradient_placeholder = array_ops.placeholder(
             dtypes.float32, gradients.shape, name="gradients")

View File

@@ -70,7 +70,7 @@ class FFTTest(xla_test.XLATestCase):
       data = np.reshape(data.astype(np.float32).view(np.complex64), shape)
     data = to_32bit(complex_to_input(data))
     expected = to_32bit(input_to_expected(data))
-    with self.cached_session() as sess:
+    with self.session() as sess:
       with self.test_scope():
         ph = array_ops.placeholder(
             dtypes.as_dtype(data.dtype), shape=data.shape)
@@ -92,7 +92,7 @@ class FFTTest(xla_test.XLATestCase):
         data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2]
     expected = np.swapaxes(expected, -1, -2)
     expected *= window.sum()  # scipy divides by window sum
-    with self.cached_session() as sess:
+    with self.session() as sess:
       with self.test_scope():
         ph = array_ops.placeholder(
             dtypes.as_dtype(data.dtype), shape=data.shape)

View File

@@ -31,13 +31,13 @@ from tensorflow.python.platform import test
 class FIFOQueueTest(xla_test.XLATestCase):

   def testEnqueue(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       enqueue_op = q.enqueue((10.0,))
       enqueue_op.run()

   def testEnqueueWithShape(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
       enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
       enqueue_correct_op.run()
@@ -46,7 +46,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
       self.assertEqual(1, q.size().eval())

   def testMultipleDequeues(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
       self.evaluate(q.enqueue([1]))
       self.evaluate(q.enqueue([2]))
@@ -55,7 +55,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
       self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))

   def testQueuesDontShare(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
       self.evaluate(q.enqueue(1))
       q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
@@ -64,13 +64,13 @@ class FIFOQueueTest(xla_test.XLATestCase):
       self.assertAllEqual(self.evaluate(q.dequeue()), 1)

   def testEnqueueDictWithoutNames(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       with self.assertRaisesRegexp(ValueError, "must have names"):
         q.enqueue({"a": 12.0})

   def testParallelEnqueue(self):
-    with self.cached_session() as sess, self.test_scope():
+    with self.session() as sess, self.test_scope():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -95,7 +95,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
       self.assertItemsEqual(elems, results)

   def testParallelDequeue(self):
-    with self.cached_session() as sess, self.test_scope():
+    with self.session() as sess, self.test_scope():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -119,7 +119,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
       self.assertItemsEqual(elems, results)

   def testDequeue(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -133,7 +133,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
         self.assertEqual([elems[i]], vals)

   def testEnqueueAndBlockingDequeue(self):
-    with self.cached_session() as sess, self.test_scope():
+    with self.session() as sess, self.test_scope():
       q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -163,7 +163,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
         self.assertEqual([elem], result)

   def testMultiEnqueueAndDequeue(self):
-    with self.cached_session() as sess, self.test_scope():
+    with self.session() as sess, self.test_scope():
       q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
       elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
       enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
@@ -179,12 +179,12 @@ class FIFOQueueTest(xla_test.XLATestCase):
         self.assertEqual([y], y_val)

   def testQueueSizeEmpty(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       self.assertEqual([0], q.size().eval())

   def testQueueSizeAfterEnqueueAndDequeue(self):
-    with self.cached_session(), self.test_scope():
+    with self.session(), self.test_scope():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       enqueue_op = q.enqueue((10.0,))
       dequeued_t = q.dequeue()

View File

@@ -111,7 +111,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

   def testFtrlwithoutRegularization(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -145,7 +145,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

   def testFtrlwithoutRegularization2(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -178,7 +178,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

   def testFtrlWithL1(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -212,7 +212,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

   def testFtrlWithL1_L2(self):
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -250,7 +250,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
     weights will tend to have smaller magnitudes with this parameter set.
     """
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -284,7 +284,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
   def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
     """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
     for dtype in self.float_types:
-      with self.test_session(), self.test_scope():
+      with self.session(), self.test_scope():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -331,9 +331,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
   def testEquivAdagradwithoutRegularization(self):
     steps = 5
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype)
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype)

       self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2)
@@ -342,9 +342,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

   def testEquivGradientDescentwithoutRegularization(self):
     steps = 5
     for dtype in self.float_types:
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype)
-      with self.cached_session(), self.test_scope():
+      with self.session(), self.test_scope():
         val2, val3 = self.equivGradientDescentTest_GradientDescentPart(
             steps, dtype)

View File

@@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
     expected = APlus2B(aval, bval)

-    with self.cached_session():
+    with self.session():

       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
     expected = APlus2B(aval, bval)

-    with self.cached_session():
+    with self.session():

       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
     expected = Func(aval, bval)

-    with self.cached_session():
+    with self.session():

       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -105,7 +105,7 @@ class FunctionTest(xla_test.XLATestCase):

   def testCompileTimeConstantsInDefun(self):
     """Tests that XLA handles compile-time constants in defuns."""
-    with self.cached_session() as sess:
+    with self.session() as sess:

       @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32)
       def Foo(a, c, d):
@@ -140,7 +140,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
     expected = aval + bval * 2

-    with self.cached_session() as sess:
+    with self.session() as sess:
       with self.test_scope():
         a = array_ops.placeholder(dtypes.float32, name="a")
         b = array_ops.placeholder(dtypes.float32, name="b")

@ -85,7 +85,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
y_ref, mean_ref, var_ref, _ = self._reference_training( y_ref, mean_ref, var_ref, _ = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format_src) x_val, scale_val, offset_val, epsilon, data_format_src)
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
# To avoid constant folding # To avoid constant folding
x_val_converted = test_utils.ConvertBetweenDataFormats( x_val_converted = test_utils.ConvertBetweenDataFormats(
x_val, data_format_src, data_format) x_val, data_format_src, data_format)
@ -130,7 +130,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
y_ref, mean_ref, _, var_ref_corr = self._reference_training( y_ref, mean_ref, _, var_ref_corr = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format_src) x_val, scale_val, offset_val, epsilon, data_format_src)
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
# To avoid constant folding # To avoid constant folding
x_val_converted = test_utils.ConvertBetweenDataFormats( x_val_converted = test_utils.ConvertBetweenDataFormats(
x_val, data_format_src, data_format) x_val, data_format_src, data_format)
@ -213,7 +213,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src)
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val_converted = test_utils.ConvertBetweenDataFormats(
grad_val, data_format_src, data_format) grad_val, data_format_src, data_format)
x_val_converted = test_utils.ConvertBetweenDataFormats( x_val_converted = test_utils.ConvertBetweenDataFormats(
@ -266,7 +266,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
var_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32)
data_format_src = "NHWC" data_format_src = "NHWC"
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val_converted = test_utils.ConvertBetweenDataFormats(
grad_val, data_format_src, data_format) grad_val, data_format_src, data_format)
x_val_converted = test_utils.ConvertBetweenDataFormats( x_val_converted = test_utils.ConvertBetweenDataFormats(

@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class GatherNdTest(xla_test.XLATestCase): class GatherNdTest(xla_test.XLATestCase):
def _runGather(self, params, indices): def _runGather(self, params, indices):
with self.cached_session(): with self.session():
paramsp = array_ops.placeholder(params.dtype) paramsp = array_ops.placeholder(params.dtype)
indicesp = array_ops.placeholder(indices.dtype) indicesp = array_ops.placeholder(indices.dtype)
with self.test_scope(): with self.test_scope():
@ -46,7 +46,7 @@ class GatherNdTest(xla_test.XLATestCase):
np.array([[4], [4], [0]], np.int32))) np.array([[4], [4], [0]], np.int32)))
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
with self.cached_session(): with self.session():
params = np.ones((3, 3), dtype=np.float32) params = np.ones((3, 3), dtype=np.float32)
indices_empty = np.empty((0, 2), dtype=np.int32) indices_empty = np.empty((0, 2), dtype=np.int32)

@ -42,7 +42,7 @@ class GatherTest(xla_test.XLATestCase):
return data return data
def testScalar1D(self): def testScalar1D(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
data = np.array([0, 1, 2, 3, 7, 5]) data = np.array([0, 1, 2, 3, 7, 5])
for dtype in self.all_tf_types: for dtype in self.all_tf_types:
for indices in 4, [4], [1, 2, 2, 4, 5]: for indices in 4, [4], [1, 2, 2, 4, 5]:
@ -55,7 +55,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(np_val, gather_val) self.assertAllEqual(np_val, gather_val)
def testScalar2D(self): def testScalar2D(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]]) [12, 13, 14]])
for dtype in self.all_tf_types: for dtype in self.all_tf_types:
@ -70,7 +70,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(expected, gather_val) self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self): def testSimpleTwoD32(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]]) [12, 13, 14]])
for dtype in self.all_tf_types: for dtype in self.all_tf_types:
@ -89,7 +89,7 @@ class GatherTest(xla_test.XLATestCase):
if np.int64 not in self.int_types: if np.int64 not in self.int_types:
return return
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]]) [12, 13, 14]])
# The indices must be in bounds for any axis. # The indices must be in bounds for any axis.
@ -117,7 +117,7 @@ class GatherTest(xla_test.XLATestCase):
for axis in 0, 1, 2, 3, -1, -2: for axis in 0, 1, 2, 3, -1, -2:
params = self._buildParams(np.random.randn(*shape), dtype) params = self._buildParams(np.random.randn(*shape), dtype)
indices = np.random.randint(shape[axis], size=indices_shape) indices = np.random.randint(shape[axis], size=indices_shape)
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
tf_params = array_ops.placeholder(dtype=dtype) tf_params = array_ops.placeholder(dtype=dtype)
tf_indices = constant_op.constant(indices, dtype=dtypes.int32) tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis) gather = array_ops.gather(tf_params, tf_indices, axis=axis)
@ -127,7 +127,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(gather_np, gather_value) self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self): def testIndicesWithDifferentDimensions(self):
with self.cached_session(): with self.session():
for dtype in self.numeric_tf_types: for dtype in self.numeric_tf_types:
params = array_ops.placeholder(dtype=dtype) params = array_ops.placeholder(dtype=dtype)
indices = array_ops.placeholder(dtype=np.int32) indices = array_ops.placeholder(dtype=np.int32)
@ -141,7 +141,7 @@ class GatherTest(xla_test.XLATestCase):
[[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]}))
def testGatherPrecision(self): def testGatherPrecision(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0],
[0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]])
indices = np.array([1, 2, 3, 1]) indices = np.array([1, 2, 3, 1])

@ -53,7 +53,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
inp = GenerateNumpyRandomRGB(shape).astype(nptype) inp = GenerateNumpyRandomRGB(shape).astype(nptype)
# Convert to HSV and back, as a batch and individually # Convert to HSV and back, as a batch and individually
with self.cached_session() as sess: with self.session() as sess:
batch0 = array_ops.placeholder(nptype, shape=shape) batch0 = array_ops.placeholder(nptype, shape=shape)
with self.test_scope(): with self.test_scope():
batch1 = image_ops.rgb_to_hsv(batch0) batch1 = image_ops.rgb_to_hsv(batch0)
@ -77,7 +77,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
for nptype in self.float_types: for nptype in self.float_types:
rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255. rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255.
with self.cached_session(): with self.session():
placeholder = array_ops.placeholder(nptype) placeholder = array_ops.placeholder(nptype)
with self.test_scope(): with self.test_scope():
hsv = image_ops.rgb_to_hsv(placeholder) hsv = image_ops.rgb_to_hsv(placeholder)
@ -96,7 +96,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
for r, g, b in rgb_flat for r, g, b in rgb_flat
]) ])
hsv_np = hsv_np.reshape(4, 4, 4, 3) hsv_np = hsv_np.reshape(4, 4, 4, 3)
with self.cached_session(): with self.session():
placeholder = array_ops.placeholder(nptype) placeholder = array_ops.placeholder(nptype)
with self.test_scope(): with self.test_scope():
hsv_op = image_ops.rgb_to_hsv(placeholder) hsv_op = image_ops.rgb_to_hsv(placeholder)
@ -107,7 +107,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
class AdjustContrastTest(xla_test.XLATestCase): class AdjustContrastTest(xla_test.XLATestCase):
def _testContrast(self, x_np, y_np, contrast_factor): def _testContrast(self, x_np, y_np, contrast_factor):
with self.cached_session(): with self.session():
x = array_ops.placeholder(x_np.dtype, shape=x_np.shape) x = array_ops.placeholder(x_np.dtype, shape=x_np.shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32) flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope(): with self.test_scope():
@ -145,7 +145,7 @@ class AdjustContrastTest(xla_test.XLATestCase):
return y_np return y_np
def _adjustContrastTf(self, x_np, contrast_factor): def _adjustContrastTf(self, x_np, contrast_factor):
with self.cached_session(): with self.session():
x = array_ops.placeholder(np.float32) x = array_ops.placeholder(np.float32)
with self.test_scope(): with self.test_scope():
y = image_ops.adjust_contrast(x, contrast_factor) y = image_ops.adjust_contrast(x, contrast_factor)
@ -179,7 +179,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1] y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
with self.cached_session(): with self.session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape) x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32) flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope(): with self.test_scope():
@ -197,7 +197,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
with self.cached_session(): with self.session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape) x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32) flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope(): with self.test_scope():
@ -215,7 +215,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
with self.cached_session(): with self.session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape) x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32) flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope(): with self.test_scope():
@ -243,7 +243,7 @@ class AdjustHueTest(xla_test.XLATestCase):
return y_v.reshape(x_np.shape) return y_v.reshape(x_np.shape)
def _adjustHueTf(self, x_np, delta_h): def _adjustHueTf(self, x_np, delta_h):
with self.cached_session(): with self.session():
x = array_ops.placeholder(dtypes.float32) x = array_ops.placeholder(dtypes.float32)
with self.test_scope(): with self.test_scope():
y = gen_image_ops.adjust_hue(x, delta_h) y = gen_image_ops.adjust_hue(x, delta_h)
@ -323,7 +323,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128] y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128]
y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape) y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape)
with self.cached_session(): with self.session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape) x = array_ops.placeholder(x_np.dtype, shape=x_shape)
y = self._adjust_saturation(x, saturation_factor) y = self._adjust_saturation(x, saturation_factor)
y_tf = y.eval({x: x_np}) y_tf = y.eval({x: x_np})
@ -338,7 +338,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0] y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
with self.cached_session(): with self.session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape) x = array_ops.placeholder(x_np.dtype, shape=x_shape)
y = self._adjust_saturation(x, saturation_factor) y = self._adjust_saturation(x, saturation_factor)
y_tf = y.eval({x: x_np}) y_tf = y.eval({x: x_np})
@ -377,7 +377,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
"gb_same", "gb_same",
"rgb_same", "rgb_same",
] ]
with self.cached_session(): with self.session():
for x_shape in x_shapes: for x_shape in x_shapes:
for test_style in test_styles: for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255. x_np = np.random.rand(*x_shape) * 255.
@ -416,7 +416,7 @@ class ResizeNearestNeighborTest(xla_test.XLATestCase):
align_corners=True): align_corners=True):
if expected is None: if expected is None:
self.fail("expected must be specified") self.fail("expected must be specified")
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
image = array_ops.placeholder(image_np.dtype) image = array_ops.placeholder(image_np.dtype)
resized = gen_image_ops.resize_nearest_neighbor( resized = gen_image_ops.resize_nearest_neighbor(
image, target_shape, align_corners=align_corners) image, target_shape, align_corners=align_corners)
@ -524,7 +524,7 @@ class ResizeBilinearTest(xla_test.XLATestCase):
align_corners=True): align_corners=True):
if expected is None: if expected is None:
self.fail("expected must be specified") self.fail("expected must be specified")
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
image = array_ops.placeholder(image_np.dtype) image = array_ops.placeholder(image_np.dtype)
resized = gen_image_ops.resize_bilinear( resized = gen_image_ops.resize_bilinear(
image, target_shape, align_corners=align_corners) image, target_shape, align_corners=align_corners)
@ -544,7 +544,7 @@ class ResizeBilinearTest(xla_test.XLATestCase):
self.fail("input_shape must be specified") self.fail("input_shape must be specified")
if expected is None: if expected is None:
self.fail("expected must be specified") self.fail("expected must be specified")
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
dtype = dtype or np.float32 dtype = dtype or np.float32
grads = array_ops.placeholder(np.float32) grads = array_ops.placeholder(np.float32)
resized = gen_image_ops.resize_bilinear_grad( resized = gen_image_ops.resize_bilinear_grad(
@ -722,7 +722,7 @@ class ResizeBilinearTest(xla_test.XLATestCase):
for dtype in self.float_types: for dtype in self.float_types:
input_image = np.array(input_data, dtype=dtype) input_image = np.array(input_data, dtype=dtype)
expected = np.array(expected_data, dtype=dtype) expected = np.array(expected_data, dtype=dtype)
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
image = array_ops.placeholder(input_image.dtype) image = array_ops.placeholder(input_image.dtype)
resized = gen_image_ops.resize_bilinear( resized = gen_image_ops.resize_bilinear(
image, [6, 4], align_corners=False) image, [6, 4], align_corners=False)
@ -741,7 +741,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32) iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.0, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32)
with self.cached_session() as sess: with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
@ -779,7 +779,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32) iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.0, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32)
with self.cached_session() as sess: with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
@ -821,7 +821,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32) iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.4, dtype=np.float32) score_threshold_np = np.array(0.4, dtype=np.float32)
with self.cached_session() as sess: with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
@ -864,7 +864,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32) iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.4, dtype=np.float32) score_threshold_np = np.array(0.4, dtype=np.float32)
with self.cached_session() as sess: with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
@ -905,7 +905,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32) iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.1, dtype=np.float32) score_threshold_np = np.array(0.1, dtype=np.float32)
with self.cached_session() as sess: with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,

@ -33,7 +33,7 @@ class ListDiffTest(xla_test.XLATestCase):
def _testListDiff(self, x, y, out, idx): def _testListDiff(self, x, y, out, idx):
for dtype in [dtypes.int32, dtypes.int64]: for dtype in [dtypes.int32, dtypes.int64]:
for index_dtype in [dtypes.int32, dtypes.int64]: for index_dtype in [dtypes.int32, dtypes.int64]:
with self.cached_session(): with self.session():
x_tensor = ops.convert_to_tensor(x, dtype=dtype) x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype) y_tensor = ops.convert_to_tensor(y, dtype=dtype)
with self.test_scope(): with self.test_scope():

@ -58,7 +58,7 @@ class LRNTest(xla_test.XLATestCase):
return output return output
def _RunAndVerify(self, dtype): def _RunAndVerify(self, dtype):
with self.cached_session(): with self.session():
# random shape # random shape
shape = np.random.randint(1, 16, size=4) shape = np.random.randint(1, 16, size=4)
# Make depth at least 2 to make it meaningful # Make depth at least 2 to make it meaningful
@ -110,7 +110,7 @@ class LRNTest(xla_test.XLATestCase):
alpha = 1.0 * np.random.rand() alpha = 1.0 * np.random.rand()
beta = 1.0 * np.random.rand() beta = 1.0 * np.random.rand()
with self.cached_session(): with self.session():
in_image = constant_op.constant(in_image_vals, shape=shape) in_image = constant_op.constant(in_image_vals, shape=shape)
out_image = constant_op.constant(out_image_vals, shape=shape) out_image = constant_op.constant(out_image_vals, shape=shape)
out_grads = constant_op.constant(out_grads_vals, shape=shape) out_grads = constant_op.constant(out_grads_vals, shape=shape)

@ -73,7 +73,7 @@ class LSTMTest(test.TestCase):
def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar, def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar,
pad_scalar): pad_scalar):
with self.cached_session() as sess: with self.session() as sess:
num_inputs = 1 num_inputs = 1
num_nodes = 1 num_nodes = 1
@ -156,7 +156,7 @@ class LSTMTest(test.TestCase):
def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar, def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar,
pad_scalar): pad_scalar):
with self.cached_session() as sess: with self.session() as sess:
num_inputs = 1 num_inputs = 1
num_nodes = 1 num_nodes = 1
seq_length = 3 seq_length = 3

@ -173,7 +173,7 @@ class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase):
]: ]:
pass pass
for dtype in self.float_types: for dtype in self.float_types:
with self.cached_session(): with self.session():
mat = np.ones(batch_shape + [rows, cols]).astype(dtype) mat = np.ones(batch_shape + [rows, cols]).astype(dtype)
batch_mat = np.tile(mat, batch_shape + [1, 1]) batch_mat = np.tile(mat, batch_shape + [1, 1])
for lower in -1, 0, 1, rows - 1: for lower in -1, 0, 1, rows - 1:

@ -54,7 +54,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
clean_a = np.tril(a) if lower else np.triu(a) clean_a = np.tril(a) if lower else np.triu(a)
with self.cached_session() as sess: with self.session() as sess:
placeholder_a = MakePlaceholder(a) placeholder_a = MakePlaceholder(a)
placeholder_ca = MakePlaceholder(clean_a) placeholder_ca = MakePlaceholder(clean_a)
placeholder_b = MakePlaceholder(b) placeholder_b = MakePlaceholder(b)

@ -41,7 +41,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testBasic(self): def testBasic(self):
for dtype in self.float_types: for dtype in self.float_types:
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@ -101,7 +101,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testNesterovMomentum(self): def testNesterovMomentum(self):
for dtype in self.float_types: for dtype in self.float_types:
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype) var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype)
var0_np = np.array([0.1, 0.2], dtype=dtype) var0_np = np.array([0.1, 0.2], dtype=dtype)
@ -126,7 +126,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRateAndMomentum(self): def testTensorLearningRateAndMomentum(self):
for dtype in self.float_types: for dtype in self.float_types:
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)

@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
class NAryOpsTest(xla_test.XLATestCase): class NAryOpsTest(xla_test.XLATestCase):
def _testNAry(self, op, args, expected, equality_fn=None): def _testNAry(self, op, args, expected, equality_fn=None):
with self.cached_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
placeholders = [ placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@ -126,7 +126,7 @@ class NAryOpsTest(xla_test.XLATestCase):
[[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32)) [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))
def testOneHot(self): def testOneHot(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32)) indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32))
op = array_ops.one_hot(indices, op = array_ops.one_hot(indices,
np.int32(4), np.int32(4),
@ -148,7 +148,7 @@ class NAryOpsTest(xla_test.XLATestCase):
self.assertAllEqual(output, expected) self.assertAllEqual(output, expected)
def testSplitV(self): def testSplitV(self):
with self.cached_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
output = session.run( output = session.run(
array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]], array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],

@ -29,14 +29,14 @@ from tensorflow.python.platform import googletest
class NullaryOpsTest(xla_test.XLATestCase): class NullaryOpsTest(xla_test.XLATestCase):
def _testNullary(self, op, expected): def _testNullary(self, op, expected):
with self.cached_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
output = op() output = op()
result = session.run(output) result = session.run(output)
self.assertAllClose(result, expected, rtol=1e-3) self.assertAllClose(result, expected, rtol=1e-3)
def testNoOp(self): def testNoOp(self):
with self.cached_session(): with self.session():
with self.test_scope(): with self.test_scope():
output = control_flow_ops.no_op() output = control_flow_ops.no_op()
# This should not crash. # This should not crash.

@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class XlaPermuteOpTest(xla_test.XLATestCase): class XlaPermuteOpTest(xla_test.XLATestCase):
def _runPermuteAndCompare(self, x, src_format, dst_format, expected): def _runPermuteAndCompare(self, x, src_format, dst_format, expected):
with self.cached_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape) placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape)
param = {placeholder: x} param = {placeholder: x}

@ -28,7 +28,7 @@ from tensorflow.python.platform import googletest
class PlaceholderTest(xla_test.XLATestCase): class PlaceholderTest(xla_test.XLATestCase):
def test_placeholder_with_default_default(self): def test_placeholder_with_default_default(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(4.0) v = resource_variable_ops.ResourceVariable(4.0)
ph = array_ops.placeholder_with_default(v, shape=[]) ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2 out = ph * 2
@ -36,7 +36,7 @@ class PlaceholderTest(xla_test.XLATestCase):
self.assertEqual(8.0, self.evaluate(out)) self.assertEqual(8.0, self.evaluate(out))
def test_placeholder_with_default_fed(self): def test_placeholder_with_default_fed(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(4.0) v = resource_variable_ops.ResourceVariable(4.0)
ph = array_ops.placeholder_with_default(v, shape=[]) ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2 out = ph * 2

@ -62,7 +62,7 @@ class Pooling3DTest(xla_test.XLATestCase):
# numbers from 1. # numbers from 1.
x = np.arange(1.0, total_size + 1, dtype=np.float32) x = np.arange(1.0, total_size + 1, dtype=np.float32)
x = x.reshape(input_sizes) x = x.reshape(input_sizes)
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
inputs = array_ops.placeholder(dtypes.float32) inputs = array_ops.placeholder(dtypes.float32)
t = pool_func( t = pool_func(
inputs, inputs,
@ -210,7 +210,7 @@ class Pooling3DTest(xla_test.XLATestCase):
strides = [1] + strides + [1] strides = [1] + strides + [1]
total_size = np.prod(input_sizes) total_size = np.prod(input_sizes)
x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
with self.cached_session() as sess: with self.session() as sess:
# Use the forward pool function to compute some corresponding outputs # Use the forward pool function to compute some corresponding outputs
# (needed for the CPU device, and we need the shape in both cases). # (needed for the CPU device, and we need the shape in both cases).
with ops.device("CPU"): with ops.device("CPU"):

@ -89,7 +89,7 @@ class PoolingTest(xla_test.XLATestCase):
# numbers from 1. # numbers from 1.
x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32) x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32)
x = x.reshape(input_sizes) x = x.reshape(input_sizes)
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
inputs = array_ops.placeholder(dtypes.float32) inputs = array_ops.placeholder(dtypes.float32)
t = inputs t = inputs
@ -324,7 +324,7 @@ class PoolGradTest(xla_test.XLATestCase):
# TODO(b/74222344): Fix nan handling for max pool grad. # TODO(b/74222344): Fix nan handling for max pool grad.
# x[np.random.choice(total_size)] = np.nan # x[np.random.choice(total_size)] = np.nan
x = x.reshape(input_sizes) x = x.reshape(input_sizes)
with self.cached_session() as sess: with self.session() as sess:
# Use the forward pool function to compute some corresponding outputs # Use the forward pool function to compute some corresponding outputs
# (needed for the CPU device, and we need the shape in both cases). # (needed for the CPU device, and we need the shape in both cases).
with ops.device(self.CPU_DEVICE): with ops.device(self.CPU_DEVICE):

@ -64,7 +64,7 @@ class PowerSignTest(xla_test.XLATestCase):
base=math.e, base=math.e,
beta=0.9): beta=0.9):
for dtype in self.float_types: for dtype in self.float_types:
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
# Initialize variables for numpy implementation. # Initialize variables for numpy implementation.
m0, m1 = 0.0, 0.0 m0, m1 = 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype) var0_np = np.array([1.0, 2.0], dtype=dtype)

@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_adagrad
class ProximalAdagradOptimizerTest(xla_test.XLATestCase): class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
def testResourceProximalAdagradwithoutRegularization(self): def testResourceProximalAdagradwithoutRegularization(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
grads0 = constant_op.constant([0.1, 0.2]) grads0 = constant_op.constant([0.1, 0.2])
@ -62,7 +62,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertEqual(2, len(opt_vars)) self.assertEqual(2, len(opt_vars))
def testProximalAdagradwithoutRegularization2(self): def testProximalAdagradwithoutRegularization2(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2]) grads0 = constant_op.constant([0.1, 0.2])
@ -86,7 +86,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.715679, 2.433051]), self.evaluate(var1)) self.assertAllClose(np.array([3.715679, 2.433051]), self.evaluate(var1))
def testProximalAdagradWithL1(self): def testProximalAdagradWithL1(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2]) grads0 = constant_op.constant([0.1, 0.2])
@ -110,7 +110,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([2.959304, 1.029232]), self.evaluate(var1)) self.assertAllClose(np.array([2.959304, 1.029232]), self.evaluate(var1))
def testProximalAdagradWithL1_L2(self): def testProximalAdagradWithL1_L2(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2]) grads0 = constant_op.constant([0.1, 0.2])
@ -153,7 +153,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
return self.evaluate(var0), self.evaluate(var1) return self.evaluate(var0), self.evaluate(var1)
def testEquivAdagradwithoutRegularization(self): def testEquivAdagradwithoutRegularization(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
val0, val1 = self.applyOptimizer( val0, val1 = self.applyOptimizer(
proximal_adagrad.ProximalAdagradOptimizer( proximal_adagrad.ProximalAdagradOptimizer(
3.0, 3.0,
@ -161,7 +161,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
l1_regularization_strength=0.0, l1_regularization_strength=0.0,
l2_regularization_strength=0.0)) l2_regularization_strength=0.0))
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
val2, val3 = self.applyOptimizer( val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer( adagrad.AdagradOptimizer(
3.0, initial_accumulator_value=0.1)) 3.0, initial_accumulator_value=0.1))

@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_gradient_descent
class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
def testResourceProximalGradientDescentwithoutRegularization(self): def testResourceProximalGradientDescentwithoutRegularization(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
grads0 = constant_op.constant([0.1, 0.2]) grads0 = constant_op.constant([0.1, 0.2])
@ -53,7 +53,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([-0.09, -0.18]), self.evaluate(var1)) self.assertAllClose(np.array([-0.09, -0.18]), self.evaluate(var1))
def testProximalGradientDescentwithoutRegularization2(self): def testProximalGradientDescentwithoutRegularization2(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2]) grads0 = constant_op.constant([0.1, 0.2])
@ -75,7 +75,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.91, 2.82]), self.evaluate(var1)) self.assertAllClose(np.array([3.91, 2.82]), self.evaluate(var1))
def testProximalGradientDescentWithL1(self): def testProximalGradientDescentWithL1(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2]) grads0 = constant_op.constant([0.1, 0.2])
@ -97,7 +97,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.67, 2.37]), self.evaluate(var1)) self.assertAllClose(np.array([3.67, 2.37]), self.evaluate(var1))
def testProximalGradientDescentWithL1_L2(self): def testProximalGradientDescentWithL1_L2(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2]) grads0 = constant_op.constant([0.1, 0.2])
@ -137,14 +137,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
return self.evaluate(var0), self.evaluate(var1) return self.evaluate(var0), self.evaluate(var1)
def testEquivGradientDescentwithoutRegularization(self): def testEquivGradientDescentwithoutRegularization(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
val0, val1 = self.applyOptimizer( val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer( proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0, 3.0,
l1_regularization_strength=0.0, l1_regularization_strength=0.0,
l2_regularization_strength=0.0)) l2_regularization_strength=0.0))
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
val2, val3 = self.applyOptimizer( val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0)) gradient_descent.GradientDescentOptimizer(3.0))

@ -71,7 +71,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
x_np = np.random.uniform( x_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
with self.cached_session() as sess: with self.session() as sess:
x_tf = array_ops.placeholder(dtype) x_tf = array_ops.placeholder(dtype)
with self.test_scope(): with self.test_scope():
q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)

@ -36,7 +36,7 @@ class QuantizedOpsTest(xla_test.XLATestCase):
# Verify that quantized types can be clustered by XLA. # Verify that quantized types can be clustered by XLA.
def testQuantizedTypeRoundtrip(self): def testQuantizedTypeRoundtrip(self):
with self.cached_session() as session: with self.session() as session:
for dtype in self.quantized_tf_types: for dtype in self.quantized_tf_types:
in_values = np.array([1, 2, 3, 4, 5, 6]) in_values = np.array([1, 2, 3, 4, 5, 6])
expected = [[1, 2], [3, 4], [5, 6]] expected = [[1, 2], [3, 4], [5, 6]]
@ -82,7 +82,7 @@ class DeuantizedOpsTest(xla_test.XLATestCase):
num_rows = 100 num_rows = 100
num_columns = 3547 num_columns = 3547
random_input = np.random.normal(128.0, 10.0, [num_rows, num_columns]) random_input = np.random.normal(128.0, 10.0, [num_rows, num_columns])
with self.cached_session() as session: with self.session() as session:
with ops.device("CPU"): with ops.device("CPU"):
test_input = ops.convert_to_tensor(random_input, dtype=dtypes.float32) test_input = ops.convert_to_tensor(random_input, dtype=dtypes.float32)
transposed_input = array_ops.transpose(test_input, [1, 0]) transposed_input = array_ops.transpose(test_input, [1, 0])
@ -95,7 +95,7 @@ class DeuantizedOpsTest(xla_test.XLATestCase):
quantized_output = array_ops.slice(transposed_quantized_output, [0, 0], quantized_output = array_ops.slice(transposed_quantized_output, [0, 0],
[num_rows, num_columns]) [num_rows, num_columns])
value = session.run(quantized_output) value = session.run(quantized_output)
self.assertAllClose(value, random_input, 1.0) self.assertAllClose(value, random_input, 1.0)

@ -40,7 +40,7 @@ class RandomOpsTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype): def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value. # Tests that 'rng' does not always return the same value.
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
x = rng(dtype) x = rng(dtype)
@ -74,7 +74,7 @@ class RandomOpsTest(xla_test.XLATestCase):
def testRandomNormalMean(self): def testRandomNormalMean(self):
for dtype in self._random_types() & self.float_types: for dtype in self._random_types() & self.float_types:
with self.cached_session(): with self.session():
with self.test_scope(): with self.test_scope():
normal = random_ops.random_normal([1024], normal = random_ops.random_normal([1024],
dtype=dtype, dtype=dtype,
@ -86,7 +86,7 @@ class RandomOpsTest(xla_test.XLATestCase):
def testRandomNormalVariance(self): def testRandomNormalVariance(self):
for dtype in self._random_types() & self.float_types: for dtype in self._random_types() & self.float_types:
with self.cached_session(): with self.session():
with self.test_scope(): with self.test_scope():
normal = random_ops.random_normal([1024], normal = random_ops.random_normal([1024],
dtype=dtype, dtype=dtype,
@ -103,7 +103,7 @@ class RandomOpsTest(xla_test.XLATestCase):
if (self.device in ["XLA_GPU", "XLA_CPU" if (self.device in ["XLA_GPU", "XLA_CPU"
]) and (dtype in [dtypes.bfloat16, dtypes.half]): ]) and (dtype in [dtypes.bfloat16, dtypes.half]):
continue continue
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
x = random_ops.random_uniform( x = random_ops.random_uniform(
shape=[1000], dtype=dtype, minval=-2, maxval=33) shape=[1000], dtype=dtype, minval=-2, maxval=33)
@ -116,14 +116,13 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype): def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype) return random_ops.truncated_normal(shape=[2], dtype=dtype)
for dtype in self._random_types() & self.float_types: self._testRngIsNotConstant(rng, dtypes.float32)
self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self): def testTruncatedNormalIsInRange(self):
count = 10000000 count = 10000000
# TODO(b/34339814): make this test work with 16 bit float types. # TODO(b/34339814): make this test work with 16 bit float types.
for dtype in self._random_types() & {dtypes.float32, dtypes.float64}: for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype) x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = self.evaluate(x) y = self.evaluate(x)
@ -168,7 +167,7 @@ class RandomOpsTest(xla_test.XLATestCase):
self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3) self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testShuffle1d(self): def testShuffle1d(self):
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
x = math_ops.range(1 << 16) x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x) shuffle = random_ops.random_shuffle(x)
@ -179,7 +178,7 @@ class RandomOpsTest(xla_test.XLATestCase):
self.assertAllEqual(set(result), set(expected)) self.assertAllEqual(set(result), set(expected))
def testShuffle2d(self): def testShuffle2d(self):
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
x = array_ops.diag(math_ops.range(20)) x = array_ops.diag(math_ops.range(20))
shuffle = random_ops.random_shuffle(x) shuffle = random_ops.random_shuffle(x)

@ -45,7 +45,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
"""Tests that the output of 'tf_reduce_fn' matches numpy's output.""" """Tests that the output of 'tf_reduce_fn' matches numpy's output."""
for test_input in test_inputs: for test_input in test_inputs:
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
a = array_ops.placeholder(dtype) a = array_ops.placeholder(dtype)
index = array_ops.placeholder(index_dtype) index = array_ops.placeholder(index_dtype)
@ -190,7 +190,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase):
""" """
for test_input in test_inputs: for test_input in test_inputs:
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
a = array_ops.placeholder(dtype) a = array_ops.placeholder(dtype)
index = array_ops.placeholder(dtypes.int32) index = array_ops.placeholder(dtypes.int32)

@ -32,7 +32,7 @@ class ReduceWindowTest(xla_test.XLATestCase):
"""Test cases for xla.reduce_window.""" """Test cases for xla.reduce_window."""
def _reduce_window(self, operand, init, reducer, **kwargs): def _reduce_window(self, operand, init, reducer, **kwargs):
with self.cached_session(): with self.session():
placeholder = array_ops.placeholder(operand.dtype) placeholder = array_ops.placeholder(operand.dtype)
with self.test_scope(): with self.test_scope():
output = xla.reduce_window(placeholder, init, reducer, **kwargs) output = xla.reduce_window(placeholder, init, reducer, **kwargs)

@ -33,7 +33,7 @@ class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase):
('64_bit_index', dtypes.int64)) ('64_bit_index', dtypes.int64))
def testBasic(self, index_dtype): def testBasic(self, index_dtype):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[2, 3]) i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope(): with self.test_scope():
shape = constant_op.constant([3, 2], dtype=index_dtype) shape = constant_op.constant([3, 2], dtype=index_dtype)

@ -51,7 +51,7 @@ class ReverseOpsTest(xla_test.XLATestCase):
def _AssertReverseEqual(self, revdims, shape): def _AssertReverseEqual(self, revdims, shape):
np.random.seed(120) np.random.seed(120)
pval = np.random.randint(0, 100, size=shape).astype(float) pval = np.random.randint(0, 100, size=shape).astype(float)
with self.cached_session(): with self.session():
with self.test_scope(): with self.test_scope():
p = array_ops.placeholder(dtypes.int32, shape=shape) p = array_ops.placeholder(dtypes.int32, shape=shape)
axis = constant_op.constant( axis = constant_op.constant(

@ -35,7 +35,7 @@ class ReverseSequenceTest(xla_test.XLATestCase):
seq_lengths, seq_lengths,
truth, truth,
expected_err_re=None): expected_err_re=None):
with self.cached_session(): with self.session():
p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
with self.test_scope(): with self.test_scope():

@ -55,7 +55,7 @@ class RmspropTest(xla_test.XLATestCase):
def testBasic(self): def testBasic(self):
for dtype in self.float_types: for dtype in self.float_types:
for centered in [False, True]: for centered in [False, True]:
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
# Initialize variables for numpy implementation. # Initialize variables for numpy implementation.
var0_np = np.array([1.0, 2.0], dtype=dtype) var0_np = np.array([1.0, 2.0], dtype=dtype)
grads0_np = np.array([0.1, 0.1], dtype=dtype) grads0_np = np.array([0.1, 0.1], dtype=dtype)

@ -78,7 +78,7 @@ class CumsumTest(xla_test.XLATestCase):
def _compare(self, x, axis, exclusive, reverse): def _compare(self, x, axis, exclusive, reverse):
np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) np_out = handle_options(np.cumsum, x, axis, exclusive, reverse)
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
p = array_ops.placeholder(x.dtype) p = array_ops.placeholder(x.dtype)
tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval(
feed_dict={p: x}) feed_dict={p: x})
@ -100,7 +100,7 @@ class CumsumTest(xla_test.XLATestCase):
for dtype in self.valid_dtypes: for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype) x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in self.axis_dtypes(): for axis_dtype in self.axis_dtypes():
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
p = array_ops.placeholder(x.dtype) p = array_ops.placeholder(x.dtype)
axis = constant_op.constant(0, axis_dtype) axis = constant_op.constant(0, axis_dtype)
math_ops.cumsum(p, axis).eval(feed_dict={p: x}) math_ops.cumsum(p, axis).eval(feed_dict={p: x})
@ -131,7 +131,7 @@ class CumsumTest(xla_test.XLATestCase):
def testInvalidAxis(self): def testInvalidAxis(self):
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
input_tensor = ops.convert_to_tensor(x) input_tensor = ops.convert_to_tensor(x)
with self.assertRaisesWithPredicateMatch( with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, errors_impl.InvalidArgumentError,
@ -156,7 +156,7 @@ class CumprodTest(xla_test.XLATestCase):
def _compare(self, x, axis, exclusive, reverse): def _compare(self, x, axis, exclusive, reverse):
np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) np_out = handle_options(np.cumprod, x, axis, exclusive, reverse)
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
p = array_ops.placeholder(x.dtype) p = array_ops.placeholder(x.dtype)
prod = math_ops.cumprod(p, axis, exclusive, reverse) prod = math_ops.cumprod(p, axis, exclusive, reverse)
tf_out = prod.eval(feed_dict={p: x}) tf_out = prod.eval(feed_dict={p: x})
@ -178,7 +178,7 @@ class CumprodTest(xla_test.XLATestCase):
for dtype in self.valid_dtypes: for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype) x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in self.axis_dtypes(): for axis_dtype in self.axis_dtypes():
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
p = array_ops.placeholder(x.dtype) p = array_ops.placeholder(x.dtype)
axis = constant_op.constant(0, axis_dtype) axis = constant_op.constant(0, axis_dtype)
math_ops.cumprod(x, axis).eval(feed_dict={p: x}) math_ops.cumprod(x, axis).eval(feed_dict={p: x})
@ -209,7 +209,7 @@ class CumprodTest(xla_test.XLATestCase):
def testInvalidAxis(self): def testInvalidAxis(self):
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
input_tensor = ops.convert_to_tensor(x) input_tensor = ops.convert_to_tensor(x)
with self.assertRaisesWithPredicateMatch( with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, errors_impl.InvalidArgumentError,

@ -119,7 +119,7 @@ class ScatterNdTest(xla_test.XLATestCase):
self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)
def _runScatterNd(self, indices, updates, shape): def _runScatterNd(self, indices, updates, shape):
with self.cached_session(): with self.session():
updates_placeholder = array_ops.placeholder(updates.dtype) updates_placeholder = array_ops.placeholder(updates.dtype)
indices_placeholder = array_ops.placeholder(indices.dtype) indices_placeholder = array_ops.placeholder(indices.dtype)
with self.test_scope(): with self.test_scope():

@ -32,7 +32,7 @@ class SegmentReductionOpsTest(xla_test.XLATestCase):
"""Test cases for segment reduction ops.""" """Test cases for segment reduction ops."""
def _segmentReduction(self, op, data, indices, num_segments): def _segmentReduction(self, op, data, indices, num_segments):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
d = array_ops.placeholder(data.dtype, shape=data.shape) d = array_ops.placeholder(data.dtype, shape=data.shape)
if isinstance(indices, int): if isinstance(indices, int):
i = array_ops.placeholder(np.int32, shape=[]) i = array_ops.placeholder(np.int32, shape=[])


@ -38,7 +38,7 @@ class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase):
n = shape[-1] n = shape[-1]
e_np, _ = np.linalg.eigh(x_np) e_np, _ = np.linalg.eigh(x_np)
with self.cached_session() as sess: with self.session() as sess:
x_tf = array_ops.placeholder(dtype) x_tf = array_ops.placeholder(dtype)
with self.test_scope(): with self.test_scope():
e, v = linalg_ops.self_adjoint_eig(x_tf) e, v = linalg_ops.self_adjoint_eig(x_tf)


@ -29,7 +29,7 @@ class SliceTest(xla_test.XLATestCase):
def test1D(self): def test1D(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[10]) i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope(): with self.test_scope():
o = array_ops.slice(i, [2], [4]) o = array_ops.slice(i, [2], [4])
@ -42,7 +42,7 @@ class SliceTest(xla_test.XLATestCase):
def testZeroSlice(self): def testZeroSlice(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[2]) i = array_ops.placeholder(dtype, shape=[2])
with self.test_scope(): with self.test_scope():
o = array_ops.slice(i, [0], [0]) o = array_ops.slice(i, [0], [0])
@ -55,7 +55,7 @@ class SliceTest(xla_test.XLATestCase):
def test3D(self): def test3D(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10]) i = array_ops.placeholder(dtype, shape=[3, 3, 10])
with self.test_scope(): with self.test_scope():
o = array_ops.slice(i, [1, 2, 2], [1, 1, 4]) o = array_ops.slice(i, [1, 2, 2], [1, 1, 4])
@ -77,7 +77,7 @@ class SliceTest(xla_test.XLATestCase):
def test3DWithDynamicBegin(self): def test3DWithDynamicBegin(self):
"""Tests a slice where the start offset is not known at compile time.""" """Tests a slice where the start offset is not known at compile time."""
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10]) i = array_ops.placeholder(dtype, shape=[3, 3, 10])
begin = array_ops.placeholder(dtypes.int32, shape=[3]) begin = array_ops.placeholder(dtypes.int32, shape=[3])
with self.test_scope(): with self.test_scope():
@ -101,7 +101,7 @@ class SliceTest(xla_test.XLATestCase):
def test3DWithDynamicBeginAndNegativeSize(self): def test3DWithDynamicBeginAndNegativeSize(self):
"""Tests a slice where `begin` is fed dynamically and `size` contains -1.""" """Tests a slice where `begin` is fed dynamically and `size` contains -1."""
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10]) i = array_ops.placeholder(dtype, shape=[3, 3, 10])
begin = array_ops.placeholder(dtypes.int32, shape=[3]) begin = array_ops.placeholder(dtypes.int32, shape=[3])
with self.test_scope(): with self.test_scope():
@ -127,7 +127,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test1D(self): def test1D(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[10]) i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope(): with self.test_scope():
o = array_ops.strided_slice(i, [2], [6], [2]) o = array_ops.strided_slice(i, [2], [6], [2])
@ -140,7 +140,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test1DNegativeStride(self): def test1DNegativeStride(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[10]) i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope(): with self.test_scope():
o = array_ops.strided_slice(i, [6], [2], [-2]) o = array_ops.strided_slice(i, [6], [2], [-2])
@ -153,7 +153,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test2DDegenerate(self): def test2DDegenerate(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[2, 3]) i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope(): with self.test_scope():
o = array_ops.strided_slice(i, [-1, 0], [0, 3]) o = array_ops.strided_slice(i, [-1, 0], [0, 3])
@ -167,7 +167,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test2DDegenerateNegativeStride(self): def test2DDegenerateNegativeStride(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[2, 3]) i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope(): with self.test_scope():
o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1]) o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1])
@ -181,7 +181,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test3D(self): def test3D(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10]) i = array_ops.placeholder(dtype, shape=[3, 3, 10])
with self.test_scope(): with self.test_scope():
o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2]) o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2])
@ -202,7 +202,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test3DNegativeStride(self): def test3DNegativeStride(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.cached_session(): with self.session():
i = array_ops.placeholder(dtype, shape=[3, 4, 10]) i = array_ops.placeholder(dtype, shape=[3, 4, 10])
with self.test_scope(): with self.test_scope():
o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2]) o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2])


@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class XlaSortOpTest(xla_test.XLATestCase): class XlaSortOpTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected): def _assertOpOutputMatchesExpected(self, op, args, expected):
with self.cached_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
placeholders = [ placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@ -134,7 +134,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
if bfloat16 not in self.numeric_types: if bfloat16 not in self.numeric_types:
return return
with self.cached_session() as sess: with self.session() as sess:
p = array_ops.placeholder(dtypes.bfloat16) p = array_ops.placeholder(dtypes.bfloat16)
with self.test_scope(): with self.test_scope():
topk = nn_ops.top_k(p, k=4) topk = nn_ops.top_k(p, k=4)
@ -152,7 +152,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
if bfloat16 not in self.numeric_types: if bfloat16 not in self.numeric_types:
return return
with self.cached_session() as sess: with self.session() as sess:
p = array_ops.placeholder(dtypes.bfloat16) p = array_ops.placeholder(dtypes.bfloat16)
with self.test_scope(): with self.test_scope():
topk = nn_ops.top_k(p, k=6) topk = nn_ops.top_k(p, k=6)


@ -72,7 +72,7 @@ class SpaceToBatchTest(xla_test.XLATestCase):
"""Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
def _testPad(self, inputs, paddings, block_size, outputs): def _testPad(self, inputs, paddings, block_size, outputs):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
for dtype in self.float_types: for dtype in self.float_types:
# outputs = space_to_batch(inputs) # outputs = space_to_batch(inputs)
placeholder = array_ops.placeholder(dtype) placeholder = array_ops.placeholder(dtype)
@ -155,7 +155,7 @@ class SpaceToBatchNDTest(xla_test.XLATestCase):
def _testPad(self, inputs, block_shape, paddings, outputs): def _testPad(self, inputs, block_shape, paddings, outputs):
block_shape = np.array(block_shape) block_shape = np.array(block_shape)
paddings = np.array(paddings).reshape((len(block_shape), 2)) paddings = np.array(paddings).reshape((len(block_shape), 2))
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
for dtype in self.float_types: for dtype in self.float_types:
# TODO(b/68813416): Skip bfloat16's as the input type for direct is # TODO(b/68813416): Skip bfloat16's as the input type for direct is
# float32 and results in a mismatch, while making testDirect provide the # float32 and results in a mismatch, while making testDirect provide the


@ -45,32 +45,32 @@ def _SparseToDense(sparse_indices,
class SparseToDenseTest(xla_test.XLATestCase): class SparseToDenseTest(xla_test.XLATestCase):
def testInt(self): def testInt(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1, 0) tf_ans = _SparseToDense([1, 3], [5], 1, 0)
np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
def testFloat(self): def testFloat(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0) tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0)
np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
def testSetValue(self): def testSetValue(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1) tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1)
np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
def testSetSingleValue(self): def testSetSingleValue(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1, -1) tf_ans = _SparseToDense([1, 3], [5], 1, -1)
np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
def test2d(self): def test2d(self):
# pylint: disable=bad-whitespace # pylint: disable=bad-whitespace
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1) tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1)
np_ans = np.array([[-1, -1, -1, -1], np_ans = np.array([[-1, -1, -1, -1],
[-1, -1, -1, 1], [-1, -1, -1, 1],
@ -78,12 +78,12 @@ class SparseToDenseTest(xla_test.XLATestCase):
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
def testZeroDefault(self): def testZeroDefault(self):
with self.cached_session(): with self.session():
x = sparse_ops.sparse_to_dense(2, [4], 7).eval() x = sparse_ops.sparse_to_dense(2, [4], 7).eval()
self.assertAllEqual(x, [0, 0, 7, 0]) self.assertAllEqual(x, [0, 0, 7, 0])
def test3d(self): def test3d(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1) tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1)
np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1
np_ans[1, 3, 0] = 1 np_ans[1, 3, 0] = 1
@ -91,31 +91,31 @@ class SparseToDenseTest(xla_test.XLATestCase):
self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans)
def testDegenerateIndexMatrix(self): def testDegenerateIndexMatrix(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
tf_ans = _SparseToDense([[2], [3], [4], [5], [6], [7], [8], [9]], [10], tf_ans = _SparseToDense([[2], [3], [4], [5], [6], [7], [8], [9]], [10],
[1, 2, 3, 4, 5, 6, 7, 8], -1) [1, 2, 3, 4, 5, 6, 7, 8], -1)
self.assertAllClose([-1, -1, 1, 2, 3, 4, 5, 6, 7, 8], tf_ans) self.assertAllClose([-1, -1, 1, 2, 3, 4, 5, 6, 7, 8], tf_ans)
def testBadShape(self): def testBadShape(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
_SparseToDense([1, 3], [[5], [3]], 1, -1) _SparseToDense([1, 3], [[5], [3]], 1, -1)
def testBadValue(self): def testBadValue(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[2,1\], " r"sparse_values has incorrect shape \[2,1\], "
r"should be \[\] or \[2\]"): r"should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [[5], [3]], -1) _SparseToDense([1, 3], [5], [[5], [3]], -1)
def testBadNumValues(self): def testBadNumValues(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [1, 2, 3], -1) _SparseToDense([1, 3], [5], [1, 2, 3], -1)
def testBadDefault(self): def testBadDefault(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
with self.assertRaisesOpError("default_value should be a scalar"): with self.assertRaisesOpError("default_value should be a scalar"):
_SparseToDense([1, 3], [5], [1, 2], [0]) _SparseToDense([1, 3], [5], [1, 2], [0])


@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class StackOpTest(xla_test.XLATestCase): class StackOpTest(xla_test.XLATestCase):
def testStackPushPop(self): def testStackPushPop(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
v = array_ops.placeholder(dtypes.float32) v = array_ops.placeholder(dtypes.float32)
@ -47,7 +47,7 @@ class StackOpTest(xla_test.XLATestCase):
xla.compile(fn)[0].eval({v: [[4.0, 5.0]]})) xla.compile(fn)[0].eval({v: [[4.0, 5.0]]}))
def testStackPushPopSwap(self): def testStackPushPopSwap(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
a = np.arange(2000) a = np.arange(2000)
x = array_ops.placeholder(dtypes.float32) x = array_ops.placeholder(dtypes.float32)
@ -60,7 +60,7 @@ class StackOpTest(xla_test.XLATestCase):
self.assertAllClose(a, xla.compile(fn)[0].eval({x: a})) self.assertAllClose(a, xla.compile(fn)[0].eval({x: a}))
def testMultiStack(self): def testMultiStack(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
v = array_ops.placeholder(dtypes.float32) v = array_ops.placeholder(dtypes.float32)
def fn(): def fn():
@ -78,7 +78,7 @@ class StackOpTest(xla_test.XLATestCase):
def testSameNameStacks(self): def testSameNameStacks(self):
"""Different stacks with the same name do not interfere.""" """Different stacks with the same name do not interfere."""
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v1 = array_ops.placeholder(dtypes.float32) v1 = array_ops.placeholder(dtypes.float32)
v2 = array_ops.placeholder(dtypes.float32) v2 = array_ops.placeholder(dtypes.float32)
@ -100,7 +100,7 @@ class StackOpTest(xla_test.XLATestCase):
self.assertAllClose(out2, 5.0) self.assertAllClose(out2, 5.0)
def testCloseStack(self): def testCloseStack(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
def fn(): def fn():
h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
@ -109,7 +109,7 @@ class StackOpTest(xla_test.XLATestCase):
sess.run(xla.compile(fn)) sess.run(xla.compile(fn))
def testPushCloseStack(self): def testPushCloseStack(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = array_ops.placeholder(dtypes.float32) v = array_ops.placeholder(dtypes.float32)
def fn(): def fn():


@ -33,14 +33,14 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateless random-number generator operators.""" """Test cases for stateless random-number generator operators."""
def _random_types(self, include_int=False): def _random_types(self, include_int=False):
allowed_types = {dtypes.float32, dtypes.float64, dtypes.bfloat16} allowed_types = {dtypes.float32, dtypes.bfloat16}
if include_int: if include_int:
allowed_types.update({dtypes.int32, dtypes.int64}) allowed_types.update({dtypes.int32, dtypes.int64})
return self.all_tf_types & allowed_types return self.all_tf_types & allowed_types
def testDeterminism(self): def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly) # Stateless values should be equal iff the seeds are equal (roughly)
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension
for stateless_op in [ for stateless_op in [
@ -62,7 +62,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertEqual(s0 == s1, np.all(v0 == v1)) self.assertEqual(s0 == s1, np.all(v0 == v1))
def testRandomUniformIsInRange(self): def testRandomUniformIsInRange(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
for dtype in self._random_types(include_int=True): for dtype in self._random_types(include_int=True):
maxval = 1 maxval = 1
if dtype.is_integer: if dtype.is_integer:
@ -76,7 +76,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testDistributionOfStatelessRandomUniform(self): def testDistributionOfStatelessRandomUniform(self):
"""Use Pearson's Chi-squared test to test for uniformity.""" """Use Pearson's Chi-squared test to test for uniformity."""
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
for dtype in self._random_types(include_int=True): for dtype in self._random_types(include_int=True):
seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 1000 n = 1000
@ -96,7 +96,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertLess(random_test_util.chi_squared(y, 10), 16.92) self.assertLess(random_test_util.chi_squared(y, 10), 16.92)
def testRandomNormalIsFinite(self): def testRandomNormalIsFinite(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
for dtype in self._random_types(): for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless.stateless_random_normal( x = stateless.stateless_random_normal(
@ -106,7 +106,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testDistributionOfStatelessRandomNormal(self): def testDistributionOfStatelessRandomNormal(self):
"""Use Anderson-Darling test to test distribution appears normal.""" """Use Anderson-Darling test to test distribution appears normal."""
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
for dtype in self._random_types(): for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 1000 n = 1000
@ -121,7 +121,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testTruncatedNormal(self): def testTruncatedNormal(self):
for dtype in self._random_types(): for dtype in self._random_types():
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000 n = 10000000
x = stateless.stateless_truncated_normal( x = stateless.stateless_truncated_normal(


@ -46,7 +46,7 @@ class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase):
x_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype) x_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
m, n = shape[-2], shape[-1] m, n = shape[-2], shape[-1]
_, s_np, _ = np.linalg.svd(x_np) _, s_np, _ = np.linalg.svd(x_np)
with self.cached_session() as sess: with self.session() as sess:
x_tf = array_ops.placeholder(dtype) x_tf = array_ops.placeholder(dtype)
with self.test_scope(): with self.test_scope():
s, u, v = linalg_ops.svd(x_tf, full_matrices=True) s, u, v = linalg_ops.svd(x_tf, full_matrices=True)


@ -54,7 +54,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("Tries to evaluate flow") @test_util.disable_control_flow_v2("Tries to evaluate flow")
def testTensorArrayWriteRead(self): def testTensorArrayWriteRead(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -77,7 +77,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([], flow_val.shape) self.assertAllEqual([], flow_val.shape)
def _testTensorArrayWritePack(self, tf_dtype): def _testTensorArrayWritePack(self, tf_dtype):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
convert = _make_converter(tf_dtype) convert = _make_converter(tf_dtype)
def fn(): def fn():
@ -99,7 +99,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWritePack(dtype) self._testTensorArrayWritePack(dtype)
def testEmptyTensorArrayPack(self): def testEmptyTensorArrayPack(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -115,7 +115,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([3, 0, 1], self.evaluate(xla.compile(fn)[0]).shape) self.assertAllEqual([3, 0, 1], self.evaluate(xla.compile(fn)[0]).shape)
def _testTensorArrayWriteConcat(self, tf_dtype): def _testTensorArrayWriteConcat(self, tf_dtype):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
convert = _make_converter(tf_dtype) convert = _make_converter(tf_dtype)
def fn(): def fn():
@ -139,7 +139,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWriteConcat(dtype) self._testTensorArrayWriteConcat(dtype)
def _testTensorArrayUnpackRead(self, tf_dtype): def _testTensorArrayUnpackRead(self, tf_dtype):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
convert = _make_converter(tf_dtype) convert = _make_converter(tf_dtype)
def fn(): def fn():
@ -202,7 +202,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayUnpackReadMaybeLegacy() self._testTensorArrayUnpackReadMaybeLegacy()
def _testTensorArraySplitRead(self, tf_dtype): def _testTensorArraySplitRead(self, tf_dtype):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
convert = _make_converter(tf_dtype) convert = _make_converter(tf_dtype)
def fn(): def fn():
@ -265,7 +265,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2")
def testTensorGradArrayWriteRead(self): def testTensorGradArrayWriteRead(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -301,7 +301,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2")
def testTensorGradArrayDynamicWriteRead(self): def testTensorGradArrayDynamicWriteRead(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -342,7 +342,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2")
def testTensorGradAccessTwiceReceiveSameObject(self): def testTensorGradAccessTwiceReceiveSameObject(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
ta_out = {} ta_out = {}
def fn(): def fn():
@ -382,7 +382,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("b/124334470") @test_util.disable_control_flow_v2("b/124334470")
def testTensorArrayWriteWrongIndexOrDataTypeFails(self): def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -407,7 +407,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# the first type, but try to read the other type. # the first type, but try to read the other type.
if len(self.float_types) > 1: if len(self.float_types) > 1:
dtype1, dtype2 = list(self.float_types)[:2] dtype1, dtype2 = list(self.float_types)[:2]
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -436,7 +436,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("b/122315872 (split)") @test_util.disable_control_flow_v2("b/122315872 (split)")
def testTensorArraySplitIncompatibleShapesFails(self): def testTensorArraySplitIncompatibleShapesFails(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -489,7 +489,7 @@ class TensorArrayTest(xla_test.XLATestCase):
xla.compile(fn)[0].eval() xla.compile(fn)[0].eval()
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype) c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype)
def fn(): def fn():
@ -534,7 +534,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWriteGradientAddMultipleAdds(dtype) self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
def testMultiTensorArray(self): def testMultiTensorArray(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
h1 = tensor_array_ops.TensorArray( h1 = tensor_array_ops.TensorArray(
@ -552,7 +552,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllClose(9.0, self.evaluate(xla.compile(fn)[0])) self.assertAllClose(9.0, self.evaluate(xla.compile(fn)[0]))
def _testTensorArrayGradientWriteReadType(self, dtype): def _testTensorArrayGradientWriteReadType(self, dtype):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
c = lambda x: np.array(x, dtype=dtype) c = lambda x: np.array(x, dtype=dtype)
def fn(): def fn():
@ -610,7 +610,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayGradientWriteReadType(dtype) self._testTensorArrayGradientWriteReadType(dtype)
def _testTensorArrayGradientWritePackConcatAndRead(self): def _testTensorArrayGradientWritePackConcatAndRead(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -649,7 +649,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayGradientWritePackConcatAndRead() self._testTensorArrayGradientWritePackConcatAndRead()
def testTensorArrayReadTwice(self): def testTensorArrayReadTwice(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
@ -669,7 +669,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([1.0, -1.0], self.evaluate(xla.compile(fn))[0]) self.assertAllEqual([1.0, -1.0], self.evaluate(xla.compile(fn))[0])
def _testTensorArrayGradientUnpackRead(self): def _testTensorArrayGradientUnpackRead(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -701,7 +701,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("b/122315751(concat), b/122315872(split)") @test_util.disable_control_flow_v2("b/122315751(concat), b/122315872(split)")
def testTensorArrayGradientSplitConcat(self): def testTensorArrayGradientSplitConcat(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -728,7 +728,7 @@ class TensorArrayTest(xla_test.XLATestCase):
grad_vals[0]) grad_vals[0])
def testCloseTensorArray(self): def testCloseTensorArray(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -739,7 +739,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.evaluate(xla.compile(fn)[0]) self.evaluate(xla.compile(fn)[0])
def testSizeTensorArray(self): def testSizeTensorArray(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -749,7 +749,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(3, self.evaluate(xla.compile(fn))[0]) self.assertAllEqual(3, self.evaluate(xla.compile(fn))[0])
def testWriteCloseTensorArray(self): def testWriteCloseTensorArray(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -767,7 +767,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# TODO(phawkins): implement while loops. # TODO(phawkins): implement while loops.
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
# np_dtype = dtype.as_numpy_dtype # np_dtype = dtype.as_numpy_dtype
# with self.cached_session() as session, self.test_scope(): # with self.session() as session, self.test_scope():
# v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)) # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5))
# var = variables.Variable(np.arange(100, 105, dtype=np_dtype)) # var = variables.Variable(np.arange(100, 105, dtype=np_dtype))
# state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype)) # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype))
@ -851,7 +851,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# dynamic_size=True, dtype=dtypes.float32) # dynamic_size=True, dtype=dtypes.float32)
# def testGradSerialTwoLoops(self): # def testGradSerialTwoLoops(self):
# with self.cached_session(), self.test_scope(): # with self.session(), self.test_scope():
# num_steps = 100 # num_steps = 100
# acc = tensor_array_ops.TensorArray( # acc = tensor_array_ops.TensorArray(
# dtype=dtypes.float32, # dtype=dtypes.float32,
@ -884,7 +884,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# self.assertAllClose(31.0, self.evaluate(grad)) # self.assertAllClose(31.0, self.evaluate(grad))
def testSumOfTwoReadVariablesWithoutRepeatGrad(self): def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
g0 = -(np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1) g0 = -(np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1)
def fn(): def fn():
@ -918,7 +918,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(joint_grad_b_t, g0) self.assertAllEqual(joint_grad_b_t, g0)
def testWriteShape(self): def testWriteShape(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -960,7 +960,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.evaluate(xla.compile(fn)) self.evaluate(xla.compile(fn))
def _testGradientWhenNotAllComponentsRead(self): def _testGradientWhenNotAllComponentsRead(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
@ -977,7 +977,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testGradientWhenNotAllComponentsRead() self._testGradientWhenNotAllComponentsRead()
def _testTensorArrayEvalEmpty(self): def _testTensorArrayEvalEmpty(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -994,7 +994,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayEvalEmpty() self._testTensorArrayEvalEmpty()
def _testTensorArrayEvalEmptyWithDefault(self): def _testTensorArrayEvalEmptyWithDefault(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -1023,7 +1023,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayEvalEmptyWithDefault() self._testTensorArrayEvalEmptyWithDefault()
def _testTensorArrayScatterRead(self, tf_dtype): def _testTensorArrayScatterRead(self, tf_dtype):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
convert = _make_converter(tf_dtype) convert = _make_converter(tf_dtype)
id0 = array_ops.placeholder(dtypes.int32) id0 = array_ops.placeholder(dtypes.int32)
id1 = array_ops.placeholder(dtypes.int32) id1 = array_ops.placeholder(dtypes.int32)
@ -1054,7 +1054,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("b/122315734 (scatter)") @test_util.disable_control_flow_v2("b/122315734 (scatter)")
def testTensorArrayScatterReadAndGradients(self): def testTensorArrayScatterReadAndGradients(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
id0 = array_ops.placeholder(dtypes.int32) id0 = array_ops.placeholder(dtypes.int32)
id1 = array_ops.placeholder(dtypes.int32) id1 = array_ops.placeholder(dtypes.int32)
@ -1088,7 +1088,7 @@ class TensorArrayTest(xla_test.XLATestCase):
@test_util.disable_control_flow_v2("b/122315378 (gather)") @test_util.disable_control_flow_v2("b/122315378 (gather)")
def testTensorArrayWriteGatherAndGradients(self): def testTensorArrayWriteGatherAndGradients(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
def fn(): def fn():
ta = tensor_array_ops.TensorArray( ta = tensor_array_ops.TensorArray(
@ -1118,7 +1118,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(expected_grad, grad_vals[0]) self.assertAllEqual(expected_grad, grad_vals[0])
def testTensorArrayIdentity(self): def testTensorArrayIdentity(self):
with self.cached_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
tensor_arrays = {} tensor_arrays = {}
v0 = resource_variable_ops.ResourceVariable(0.0) v0 = resource_variable_ops.ResourceVariable(0.0)


@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class ListOpsTest(xla_test.XLATestCase): class ListOpsTest(xla_test.XLATestCase):
def testElementShape(self): def testElementShape(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
dim = array_ops.placeholder(dtypes.int32) dim = array_ops.placeholder(dtypes.int32)
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_shape=(dim, 15), element_shape=(dim, 15),
@ -44,7 +44,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15)) self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15))
def testPushPop(self): def testPushPop(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_shape=(7, 15), element_shape=(7, 15),
element_dtype=dtypes.float32, element_dtype=dtypes.float32,
@ -59,7 +59,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15))) self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15)))
def testDoNotConstantFoldVariants(self): def testDoNotConstantFoldVariants(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
val = array_ops.placeholder(dtype=dtypes.float32) val = array_ops.placeholder(dtype=dtypes.float32)
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_shape=(7, 15), element_shape=(7, 15),
@ -78,7 +78,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15))) self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15)))
def testPushPopSeparateLists(self): def testPushPopSeparateLists(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_shape=[], element_shape=[],
element_dtype=dtypes.float32, element_dtype=dtypes.float32,
@ -95,7 +95,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
def testEmptyTensorListNoMax(self): def testEmptyTensorListNoMax(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_shape=(7, 15), element_dtype=dtypes.float32) element_shape=(7, 15), element_dtype=dtypes.float32)
l = list_ops.tensor_list_push_back( l = list_ops.tensor_list_push_back(
@ -106,7 +106,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15))) self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15)))
def testEmptyTensorListMax(self): def testEmptyTensorListMax(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_shape=(10, 15), element_dtype=dtypes.float32, element_shape=(10, 15), element_dtype=dtypes.float32,
max_num_elements=2) max_num_elements=2)
@ -116,7 +116,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15))) self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15)))
def testListFromTensor(self): def testListFromTensor(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
t = constant_op.constant([1.0, 2.0]) t = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=[]) l = list_ops.tensor_list_from_tensor(t, element_shape=[])
e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
@ -128,7 +128,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(list_ops.tensor_list_length(l), 2) self.assertAllEqual(list_ops.tensor_list_length(l), 2)
def testGetSet(self): def testGetSet(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
t = constant_op.constant([1.0, 2.0]) t = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=[]) l = list_ops.tensor_list_from_tensor(t, element_shape=[])
e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
@ -138,7 +138,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(t, [3.0, 2.0]) self.assertAllEqual(t, [3.0, 2.0])
def testSetDoesNotUpdatePushIndex(self): def testSetDoesNotUpdatePushIndex(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_shape=[], element_dtype=dtypes.float32, max_num_elements=2) element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
# SetItem should not change the push index. # SetItem should not change the push index.
@ -149,7 +149,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(t, [5., 7.]) self.assertAllEqual(t, [5., 7.])
def testGetSetReserved(self): def testGetSetReserved(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
l = list_ops.tensor_list_reserve( l = list_ops.tensor_list_reserve(
element_dtype=dtypes.float32, element_shape=[], num_elements=2) element_dtype=dtypes.float32, element_shape=[], num_elements=2)
e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
@ -159,7 +159,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(t, [3.0, 0.0]) self.assertAllEqual(t, [3.0, 0.0])
def testSetStackReservedUnknownElementShape(self): def testSetStackReservedUnknownElementShape(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
l = list_ops.tensor_list_reserve( l = list_ops.tensor_list_reserve(
element_dtype=dtypes.float32, element_shape=None, num_elements=2) element_dtype=dtypes.float32, element_shape=None, num_elements=2)
l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0]) l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0])
@ -167,7 +167,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]]) self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]])
def testPushInEmptyListWithUnknownElementShape(self): def testPushInEmptyListWithUnknownElementShape(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
l = list_ops.tensor_list_push_back(l, [3.0, 4.0]) l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
@ -178,7 +178,7 @@ class ListOpsTest(xla_test.XLATestCase):
list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
def testGetSetReservedNonScalar(self): def testGetSetReservedNonScalar(self):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
l = list_ops.tensor_list_reserve( l = list_ops.tensor_list_reserve(
element_dtype=dtypes.float32, element_dtype=dtypes.float32,
element_shape=(7, 15), element_shape=(7, 15),
@ -191,7 +191,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(sess.run(e2), np.zeros((7, 15))) self.assertAllEqual(sess.run(e2), np.zeros((7, 15)))
def testStack(self): def testStack(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_dtype=dtypes.float32,
element_shape=[], element_shape=[],
@ -205,14 +205,14 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(t, [1.0, 2.0]) self.assertAllEqual(t, [1.0, 2.0])
def testStackWithUninitializedTensors(self): def testStackWithUninitializedTensors(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
l = list_ops.tensor_list_reserve( l = list_ops.tensor_list_reserve(
element_dtype=dtypes.float32, element_shape=[], num_elements=3) element_dtype=dtypes.float32, element_shape=[], num_elements=3)
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(t, [0., 0., 0.]) self.assertAllEqual(t, [0., 0., 0.])
def testZerosLikeForTensorList(self): def testZerosLikeForTensorList(self):
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
l = list_ops.empty_tensor_list( l = list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_dtype=dtypes.float32,
element_shape=[], element_shape=[],


@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest
class TernaryOpsTest(xla_test.XLATestCase): class TernaryOpsTest(xla_test.XLATestCase):
def _testTernary(self, op, a, b, c, expected): def _testTernary(self, op, a, b, c, expected):
with self.cached_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")


@ -65,7 +65,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
rtol: relative tolerance for equality test. rtol: relative tolerance for equality test.
atol: absolute tolerance for equality test. atol: absolute tolerance for equality test.
""" """
with self.cached_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
pinp = array_ops.placeholder( pinp = array_ops.placeholder(
dtypes.as_dtype(inp.dtype), inp.shape, name="a") dtypes.as_dtype(inp.dtype), inp.shape, name="a")
@ -200,7 +200,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
# Disable float16 testing for now # Disable float16 testing for now
if dtype != np.float16: if dtype != np.float16:
x = np.arange(-10, 10, 1).astype(dtype) x = np.arange(-10, 10, 1).astype(dtype)
with self.cached_session() as session: with self.session() as session:
erf_x = session.run(math_ops.erf(x)) erf_x = session.run(math_ops.erf(x))
erfc_x = session.run(math_ops.erfc(x)) erfc_x = session.run(math_ops.erfc(x))


@ -44,7 +44,7 @@ class VariableOpsTest(xla_test.XLATestCase):
# Verifies that we can pass an uninitialized variable with an empty shape, # Verifies that we can pass an uninitialized variable with an empty shape,
# assign it a value, and successfully return it. # assign it a value, and successfully return it.
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
zeros = np.zeros([3, 0], dtype=dtype) zeros = np.zeros([3, 0], dtype=dtype)
v = resource_variable_ops.ResourceVariable(zeros) v = resource_variable_ops.ResourceVariable(zeros)
p = array_ops.placeholder(dtype) p = array_ops.placeholder(dtype)
@ -58,7 +58,7 @@ class VariableOpsTest(xla_test.XLATestCase):
# output and one variable update were mishandled. # output and one variable update were mishandled.
for dtype in self.numeric_types: for dtype in self.numeric_types:
init = np.array([[1, 2j], [3, 4]]).astype(dtype) init = np.array([[1, 2j], [3, 4]]).astype(dtype)
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init) v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v])) sess.run(variables.variables_initializer([v]))
p = array_ops.placeholder(dtype) p = array_ops.placeholder(dtype)
@ -72,7 +72,7 @@ class VariableOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types: for dtype in self.numeric_types:
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10, init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10,
11]]).astype(dtype) 11]]).astype(dtype)
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init) v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v])) sess.run(variables.variables_initializer([v]))
x = v.sparse_read(2) x = v.sparse_read(2)
@ -83,7 +83,7 @@ class VariableOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types: for dtype in self.numeric_types:
init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10, init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10,
11]]).astype(dtype) 11]]).astype(dtype)
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init) v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v])) sess.run(variables.variables_initializer([v]))
x = v.sparse_read([2, 1]) x = v.sparse_read([2, 1])
@ -95,7 +95,7 @@ class VariableOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types: for dtype in self.numeric_types:
init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10, init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10,
11]]).astype(dtype) 11]]).astype(dtype)
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init) v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v])) sess.run(variables.variables_initializer([v]))
x = v.sparse_read([[2, 1], [0, 2]]) x = v.sparse_read([[2, 1], [0, 2]])
@ -109,7 +109,7 @@ class VariableOpsTest(xla_test.XLATestCase):
init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]], init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
[[20, 21, 22], [23, 24j, 25]], [[20, 21, 22], [23, 24j, 25]],
[[30, 31, 32], [33, 34, 35]]]).astype(dtype) [[30, 31, 32], [33, 34, 35]]]).astype(dtype)
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init) v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v])) sess.run(variables.variables_initializer([v]))
x = v.sparse_read([[2, 1], [3, 0]]) x = v.sparse_read([[2, 1], [3, 0]])
@ -122,7 +122,7 @@ class VariableOpsTest(xla_test.XLATestCase):
def testShape(self): def testShape(self):
for dtype in self.numeric_types: for dtype in self.numeric_types:
init = np.ones([2, 3]).astype(dtype) init = np.ones([2, 3]).astype(dtype)
with self.test_session() as session, self.test_scope(): with self.session() as session, self.test_scope():
v = resource_variable_ops.ResourceVariable(init) v = resource_variable_ops.ResourceVariable(init)
session.run(variables.variables_initializer([v])) session.run(variables.variables_initializer([v]))
h = v.handle h = v.handle
@ -138,7 +138,7 @@ class VariableOpsTest(xla_test.XLATestCase):
def testReadWrite(self): def testReadWrite(self):
"""Tests initialization, reading, and writing a resource variable.""" """Tests initialization, reading, and writing a resource variable."""
for dtype in self.numeric_types: for dtype in self.numeric_types:
with self.test_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
with variable_scope.variable_scope("ascope", use_resource=True): with variable_scope.variable_scope("ascope", use_resource=True):
x = variable_scope.get_variable( x = variable_scope.get_variable(
@ -166,7 +166,7 @@ class VariableOpsTest(xla_test.XLATestCase):
def testTraining(self): def testTraining(self):
"""Tests a gradient descent step for a simple model.""" """Tests a gradient descent step for a simple model."""
with self.test_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
with variable_scope.variable_scope("ascope", use_resource=True): with variable_scope.variable_scope("ascope", use_resource=True):
w = variable_scope.get_variable( w = variable_scope.get_variable(
@ -203,7 +203,7 @@ class VariableOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types: for dtype in self.numeric_types:
init = np.array([[1, 2j], [3, 4]]).astype(dtype) init = np.array([[1, 2j], [3, 4]]).astype(dtype)
update = np.array([[7, 1j], [2, 11]]).astype(dtype) update = np.array([[7, 1j], [2, 11]]).astype(dtype)
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init) v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v])) sess.run(variables.variables_initializer([v]))
p = array_ops.placeholder(dtype) p = array_ops.placeholder(dtype)
@ -219,7 +219,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertAllClose(update, result[2]) self.assertAllClose(update, result[2])
def testScatterAdd(self): def testScatterAdd(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[2, 1]) dtype=dtypes.int32, shape=[2, 1])
sess.run( sess.run(
@ -232,7 +232,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertAllEqual(self.evaluate(read), [[3], [7]]) self.assertAllEqual(self.evaluate(read), [[3], [7]])
def testScatterSub(self): def testScatterSub(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[2, 1]) dtype=dtypes.int32, shape=[2, 1])
sess.run( sess.run(
@ -245,7 +245,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertAllEqual(self.evaluate(read), [[4], [-1]]) self.assertAllEqual(self.evaluate(read), [[4], [-1]])
def testScatterMul(self): def testScatterMul(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -258,7 +258,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[5]]) self.assertEqual(self.evaluate(read), [[5]])
def testScatterDiv(self): def testScatterDiv(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -271,7 +271,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertAllEqual(self.evaluate(read), [[2]]) self.assertAllEqual(self.evaluate(read), [[2]])
def testScatterMin(self): def testScatterMin(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -284,7 +284,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[3]]) self.assertEqual(self.evaluate(read), [[3]])
def testScatterMax(self): def testScatterMax(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -297,7 +297,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[6]]) self.assertEqual(self.evaluate(read), [[6]])
def testScatterUpdate(self): def testScatterUpdate(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -310,7 +310,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[3]]) self.assertEqual(self.evaluate(read), [[3]])
def testScatterAddScalar(self): def testScatterAddScalar(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -323,7 +323,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[3]]) self.assertEqual(self.evaluate(read), [[3]])
def testScatterSubScalar(self): def testScatterSubScalar(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -336,7 +336,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[-1]]) self.assertEqual(self.evaluate(read), [[-1]])
def testScatterMulScalar(self): def testScatterMulScalar(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -349,7 +349,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[5]]) self.assertEqual(self.evaluate(read), [[5]])
def testScatterDivScalar(self): def testScatterDivScalar(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -362,7 +362,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[2]]) self.assertEqual(self.evaluate(read), [[2]])
def testScatterMinScalar(self): def testScatterMinScalar(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -375,7 +375,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[3]]) self.assertEqual(self.evaluate(read), [[3]])
def testScatterMaxScalar(self): def testScatterMaxScalar(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1]) dtype=dtypes.int32, shape=[1, 1])
sess.run( sess.run(
@ -388,7 +388,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertEqual(self.evaluate(read), [[6]]) self.assertEqual(self.evaluate(read), [[6]])
def testScatterNdAddOps(self): def testScatterNdAddOps(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.float32, shape=[8]) dtype=dtypes.float32, shape=[8])
sess.run( sess.run(
@ -403,7 +403,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertAllClose(expected, self.evaluate(read)) self.assertAllClose(expected, self.evaluate(read))
def testScatterNdUpdateAddOps(self): def testScatterNdUpdateAddOps(self):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
handle = resource_variable_ops.var_handle_op( handle = resource_variable_ops.var_handle_op(
dtype=dtypes.float32, shape=[8]) dtype=dtypes.float32, shape=[8])
sess.run( sess.run(
@ -433,7 +433,7 @@ class StridedSliceAssignChecker(object):
self.which_mode = 1 - self.which_mode self.which_mode = 1 - self.which_mode
value = np.array(value).astype(self.dtype) value = np.array(value).astype(self.dtype)
with self.test.test_session() as sess, self.test.test_scope(): with self.test.session() as sess, self.test.test_scope():
x = constant_op.constant(self.x_np, dtype=self.dtype) x = constant_op.constant(self.x_np, dtype=self.dtype)
var = resource_variable_ops.ResourceVariable(x) var = resource_variable_ops.ResourceVariable(x)
sess.run(variables.variables_initializer([var])) sess.run(variables.variables_initializer([var]))
@ -487,7 +487,7 @@ class SliceAssignTest(xla_test.XLATestCase):
def testUninitialized(self): def testUninitialized(self):
with self.assertRaisesRegexp(errors.FailedPreconditionError, with self.assertRaisesRegexp(errors.FailedPreconditionError,
"uninitialized variable"): "uninitialized variable"):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable([1, 2]) v = resource_variable_ops.ResourceVariable([1, 2])
sess.run(v[:].assign([1, 2])) sess.run(v[:].assign([1, 2]))
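
For illustration, a hypothetical scatter test written in the updated style (not part of this change). It assumes the modules already imported by the test file above (dtypes, constant_op, resource_variable_ops) and would live inside the VariableOpsTest case:

    # Sketch only: explicit self.session() plus test_scope(), so an op with
    # no XLA kernel fails instead of silently running on the CPU.
    def testScatterAddSketch(self):
      with self.session() as sess, self.test_scope():
        handle = resource_variable_ops.var_handle_op(
            dtype=dtypes.int32, shape=[2, 1])
        sess.run(resource_variable_ops.assign_variable_op(
            handle, constant_op.constant([[1], [6]], dtype=dtypes.int32)))
        sess.run(resource_variable_ops.resource_scatter_add(
            handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
        read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
        self.assertAllEqual([[3], [6]], self.evaluate(read))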


@ -49,7 +49,7 @@ class WhileTest(xla_test.XLATestCase):
def loop_cond(step): def loop_cond(step):
return step < 10 return step < 10
with self.cached_session() as sess: with self.session() as sess:
init_index = array_ops.placeholder(dtypes.int32, []) init_index = array_ops.placeholder(dtypes.int32, [])
with self.test_scope(): with self.test_scope():
loop_outputs = xla.while_loop([init_index], loop_cond, loop_body) loop_outputs = xla.while_loop([init_index], loop_cond, loop_body)
@ -71,7 +71,7 @@ class WhileTest(xla_test.XLATestCase):
del rsum del rsum
return step < 10 return step < 10
with self.cached_session() as sess: with self.session() as sess:
init_index = array_ops.placeholder(dtypes.int32, []) init_index = array_ops.placeholder(dtypes.int32, [])
init_sum = array_ops.placeholder(dtypes.float32, []) init_sum = array_ops.placeholder(dtypes.float32, [])
with self.test_scope(): with self.test_scope():
@ -97,7 +97,7 @@ class WhileTest(xla_test.XLATestCase):
del rsum del rsum
return step < 10 return step < 10
with self.cached_session() as sess: with self.session() as sess:
init_index = array_ops.placeholder(dtypes.int32, []) init_index = array_ops.placeholder(dtypes.int32, [])
init_sum = array_ops.placeholder(dtypes.complex64, []) init_sum = array_ops.placeholder(dtypes.complex64, [])
with self.test_scope(): with self.test_scope():
@ -123,7 +123,7 @@ class WhileTest(xla_test.XLATestCase):
del x del x
return step < 10 return step < 10
with self.cached_session() as sess: with self.session() as sess:
init_index = array_ops.placeholder(dtypes.int32, []) init_index = array_ops.placeholder(dtypes.int32, [])
with self.test_scope(): with self.test_scope():
loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body) loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body)
@ -134,7 +134,7 @@ class WhileTest(xla_test.XLATestCase):
def _testMaxItersSimple(self): def _testMaxItersSimple(self):
if is_compile_on_demand(): if is_compile_on_demand():
self.skipTest("list_ops are not supported in cpu_ondemand") self.skipTest("list_ops are not supported in cpu_ondemand")
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext() xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter() xla_context.Enter()
v = constant_op.constant(1.0) v = constant_op.constant(1.0)
@ -168,7 +168,7 @@ class WhileTest(xla_test.XLATestCase):
def _testNestedWhileLoopWithMaxItersFromOuterContext(self): def _testNestedWhileLoopWithMaxItersFromOuterContext(self):
if is_compile_on_demand(): if is_compile_on_demand():
self.skipTest("list_ops are not supported in cpu_ondemand") self.skipTest("list_ops are not supported in cpu_ondemand")
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext() xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter() xla_context.Enter()
v = constant_op.constant(1.0) v = constant_op.constant(1.0)
@ -229,7 +229,7 @@ class WhileTest(xla_test.XLATestCase):
def testMap(self): def testMap(self):
if is_compile_on_demand(): if is_compile_on_demand():
self.skipTest("list_ops are not supported in cpu_ondemand") self.skipTest("list_ops are not supported in cpu_ondemand")
with self.cached_session(), self.test_scope(): with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext() xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter() xla_context.Enter()
nums = [1, 2, 3, 4, 5, 6] nums = [1, 2, 3, 4, 5, 6]
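
For illustration, a hypothetical while-loop test in the updated style (not one of the tests in this change; assumes the array_ops, dtypes and xla imports used by while_test.py above):

    # Sketch only: the loop is built under test_scope(), the feed stays outside.
    def testCountToTenSketch(self):

      def loop_body(step):
        return step + 1

      def loop_cond(step):
        return step < 10

      with self.session() as sess:
        init_index = array_ops.placeholder(dtypes.int32, [])
        with self.test_scope():
          loop_outputs = xla.while_loop([init_index], loop_cond, loop_body)
        self.assertAllEqual([10], sess.run(loop_outputs, {init_index: 0}))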


@ -37,7 +37,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
[16384, 1], [1, 16384], [1, 20000, 1, 1]] [16384, 1], [1, 16384], [1, 20000, 1, 1]]
for dtype in self.numeric_types: for dtype in self.numeric_types:
for shape in shapes: for shape in shapes:
with self.cached_session() as sess: with self.session() as sess:
with ops.device("CPU"): with ops.device("CPU"):
x = array_ops.placeholder(dtype, shape) x = array_ops.placeholder(dtype, shape)
with self.test_scope(): with self.test_scope():
@ -58,7 +58,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
]) ])
shape = (10, 10) shape = (10, 10)
for unsupported_dtype in test_types - self.all_types: for unsupported_dtype in test_types - self.all_types:
with self.cached_session() as sess: with self.session() as sess:
with ops.device("CPU"): with ops.device("CPU"):
x = array_ops.placeholder(unsupported_dtype, shape) x = array_ops.placeholder(unsupported_dtype, shape)
with self.test_scope(): with self.test_scope():
@ -78,7 +78,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
pass pass
def testControlTrigger(self): def testControlTrigger(self):
with self.cached_session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
x = gen_control_flow_ops.control_trigger() x = gen_control_flow_ops.control_trigger()
self.evaluate(x) self.evaluate(x)
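
For illustration, a hypothetical host/device round trip in the updated style (not part of this change; assumes the ops, array_ops and dtypes imports used by xla_device_test.py above):

    # Sketch only: stage data on the host CPU, compute on the XLA device under
    # test, then copy back -- with soft placement disabled throughout.
    def testHostDeviceRoundTripSketch(self):
      with self.session() as sess:
        with ops.device("CPU"):
          x = array_ops.placeholder(dtypes.float32, [3])
        with self.test_scope():
          y = x * 2.0
        with ops.device("CPU"):
          z = y + 1.0
        self.assertAllClose([3.0, 5.0, 7.0], sess.run(z, {x: [1.0, 2.0, 3.0]}))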


@ -35,7 +35,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected, def _assertOpOutputMatchesExpected(self, op, args, expected,
equality_fn=None): equality_fn=None):
with self.test_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
placeholders = [ placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@ -310,7 +310,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
dtype=dtype)) dtype=dtype))
def testDynamicSliceWithIncorrectStartIndicesShape(self): def testDynamicSliceWithIncorrectStartIndicesShape(self):
with self.test_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
output = xla.dynamic_slice( output = xla.dynamic_slice(
np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
@ -323,7 +323,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
r'but input rank is 3 and start_indices has shape \[2\].*')) r'but input rank is 3 and start_indices has shape \[2\].*'))
def testDynamicSliceWithIncorrectSizeIndicesShape(self): def testDynamicSliceWithIncorrectSizeIndicesShape(self):
with self.test_session() as session: with self.session() as session:
with self.test_scope(): with self.test_scope():
output = xla.dynamic_slice( output = xla.dynamic_slice(
np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),


@ -26,6 +26,7 @@ import re
import numpy as np import numpy as np
from tensorflow.contrib.compiler import jit from tensorflow.contrib.compiler import jit
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
@ -199,10 +200,10 @@ class XLATestCase(test.TestCase):
logging.info('End test case: %s', self._testMethodName) logging.info('End test case: %s', self._testMethodName)
@contextlib.contextmanager @contextlib.contextmanager
def test_session(self): def session(self):
"""Custom implementation of test_session() for XLA tests. """Custom implementation of session() for XLA tests.
We override the standard Tensorflow test_session() since it is too We override the standard Tensorflow session() since it is too
specific to CPU and GPU tests. In particular, we want to disable soft specific to CPU and GPU tests. In particular, we want to disable soft
placement and explicitly assign ops to devices under test. placement and explicitly assign ops to devices under test.
@ -210,9 +211,25 @@ class XLATestCase(test.TestCase):
A session to use when running a test case. A session to use when running a test case.
""" """
graph = ops.Graph() graph = ops.Graph()
with session.Session(graph=graph) as sess, graph.as_default(): config = config_pb2.ConfigProto()
# Grappler can constant fold TensorListFromTensor ops into DT_VARIANT
# constants which XLA does not understand. So disable constant folding in
# these tests.
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
with session.Session(
graph=graph, config=config) as sess, graph.as_default():
yield sess yield sess
def cached_session(self):
raise NotImplementedError(
'cached_session not supported on XLATestCase, please use session')
def test_session(self):
raise NotImplementedError(
'test_session not supported on XLATestCase, please use session')
@contextlib.contextmanager @contextlib.contextmanager
def test_scope(self): def test_scope(self):
"""Test scope that runs tests on a Tensorflow/XLA device. """Test scope that runs tests on a Tensorflow/XLA device.
@ -268,6 +285,7 @@ def Benchmark(tf_bench,
for fetch in fetches: for fetch in fetches:
targets.append(array_ops.identity(fetch).op) targets.append(array_ops.identity(fetch).op)
# TODO(b/132430685): Should we allow soft placement here?
config = config_pb2.ConfigProto(allow_soft_placement=True) config = config_pb2.ConfigProto(allow_soft_placement=True)
with session.Session(config=config) as sess: with session.Session(config=config) as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
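
For illustration, a hypothetical test case against the reworked XLATestCase (names and values are made up; assumes array_ops and dtypes are imported):

    # Sketch only: session() is the single supported entry point; the legacy
    # entry points now raise instead of silently enabling soft placement.
    class ExampleTest(xla_test.XLATestCase):

      def testAddCompiles(self):
        with self.session() as sess, self.test_scope():
          x = array_ops.placeholder(dtypes.float32, [2])
          self.assertAllClose([2.0, 6.0], sess.run(x + x, {x: [1.0, 3.0]}))

      def testOldEntryPointsAreRejected(self):
        with self.assertRaises(NotImplementedError):
          self.cached_session()
        with self.assertRaises(NotImplementedError):
          self.test_session()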


@ -87,7 +87,6 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx)
: XlaArgMinMaxOp(ctx, /*is_min=*/false) {} : XlaArgMinMaxOp(ctx, /*is_min=*/false) {}
REGISTER_XLA_OP(Name("ArgMax") REGISTER_XLA_OP(Name("ArgMax")
.Device(DEVICE_GPU_XLA_JIT)
.CompileTimeConstantInput("dimension"), .CompileTimeConstantInput("dimension"),
XlaArgMaxOp); XlaArgMaxOp);
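
With the DEVICE_GPU_XLA_JIT constraint dropped, the ArgMax kernel registers for every XLA backend rather than only the GPU JIT. A hypothetical compiler-test check of that (not part of this change; assumes constant_op and math_ops are imported):

    # Sketch only: ArgMax should now compile on whichever XLA device the test
    # targets; "dimension" is still a compile-time constant (axis=1 below).
    def testArgMaxCompilesSketch(self):
      with self.session() as sess, self.test_scope():
        x = constant_op.constant([[1.0, 3.0, 2.0]])
        self.assertAllEqual([1], sess.run(math_ops.argmax(x, axis=1)))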


@ -237,36 +237,35 @@ class FusedConv2DBiasActivationTest(object):
# This is to guarantee that there are always negative values after # This is to guarantee that there are always negative values after
# bias add so that we can test whether relu works correctly. # bias add so that we can test whether relu works correctly.
x3 = bias x3 = bias
with self.cached_session(use_gpu=True), self.test_scope(): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) fused_t2 = t2
fused_t2 = t2 if filter_format == "OIHW":
if filter_format == "OIHW": fused_t2 = _HwioToOihw(t2)
fused_t2 = _HwioToOihw(t2) t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype)
t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype) strides = [1] + strides + [1]
strides = [1] + strides + [1] if data_format == "NCHW":
if data_format == "NCHW": t1 = test_util.NHWCToNCHW(t1)
t1 = test_util.NHWCToNCHW(t1) strides = test_util.NHWCToNCHW(strides)
strides = test_util.NHWCToNCHW(strides) output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( t1,
t1, fused_t2,
fused_t2, t3,
t3, strides=strides,
strides=strides, padding=padding,
padding=padding, data_format=data_format,
data_format=data_format, filter_format=filter_format,
filter_format=filter_format, activation_mode=activation_mode)
activation_mode=activation_mode) ref_conv_output = nn_ops.conv2d(
ref_conv_output = nn_ops.conv2d( t1, t2, strides=strides, padding=padding, data_format=data_format)
t1, t2, strides=strides, padding=padding, data_format=data_format) ref_bias_output = nn_ops.bias_add(
ref_bias_output = nn_ops.bias_add( ref_conv_output, t3, data_format=data_format)
ref_conv_output, t3, data_format=data_format) ref_output = nn_ops.relu(ref_bias_output)
ref_output = nn_ops.relu(ref_bias_output) if data_format == "NCHW":
if data_format == "NCHW": output = test_util.NCHWToNHWC(output)
output = test_util.NCHWToNHWC(output) ref_output = test_util.NCHWToNHWC(ref_output)
ref_output = test_util.NCHWToNHWC(ref_output)
return output, ref_output return output, ref_output
def CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides, def CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides,
padding): padding):
@ -285,62 +284,62 @@ class FusedConv2DBiasActivationTest(object):
x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
def _SetupVal(data_format, use_gpu): def _SetupVal(data_format, use_gpu):
with self.cached_session(use_gpu=use_gpu), self.test_scope(): t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t1 = constant_op.constant(x1, shape=tensor_in_sizes) t2 = constant_op.constant(x2, shape=filter_in_sizes)
t2 = constant_op.constant(x2, shape=filter_in_sizes) t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) strides = [1] + conv_strides + [1]
strides = [1] + conv_strides + [1] if data_format == "NCHW":
if data_format == "NCHW": t1 = test_util.NHWCToNCHW(t1)
t1 = test_util.NHWCToNCHW(t1) strides = test_util.NHWCToNCHW(strides)
strides = test_util.NHWCToNCHW(strides) output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( t1,
t1, t2,
t2, t3,
t3, strides=strides,
strides=strides, padding=padding,
padding=padding, data_format=data_format,
data_format=data_format, activation_mode="Relu")
activation_mode="Relu")
if data_format == "NCHW": if data_format == "NCHW":
output = test_util.NCHWToNHWC(output) output = test_util.NCHWToNHWC(output)
return output return output
tensors = [] with self.session() as sess, self.test_scope():
for (data_format, use_gpu) in _GetTestConfigs(): tensors = []
tensors.append(_SetupVal(data_format, use_gpu)) for (data_format, use_gpu) in _GetTestConfigs():
with self.cached_session() as sess, self.test_scope(): tensors.append(_SetupVal(data_format, use_gpu))
values = sess.run(tensors) values = sess.run(tensors)
for i in range(1, len(values)): for i in range(1, len(values)):
self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3) self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3)
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides, def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides,
padding): padding):
tensors = [] with self.session() as sess, self.test_scope():
ref_tensors = [] tensors = []
for (data_format, use_gpu) in _GetTestConfigs(): ref_tensors = []
for dtype in self._DtypesToTest(use_gpu): for (data_format, use_gpu) in _GetTestConfigs():
for filter_format in self._FilterFormatsToTest(use_gpu): for dtype in self._DtypesToTest(use_gpu):
result, expected = self._SetupValuesForDevice( for filter_format in self._FilterFormatsToTest(use_gpu):
tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu", result, expected = self._SetupValuesForDevice(
data_format, filter_format, dtype) tensor_in_sizes, filter_in_sizes, bias, strides, padding,
tensors.append(result) "Relu", data_format, filter_format, dtype)
ref_tensors.append(expected) tensors.append(result)
with self.cached_session() as sess, self.test_scope(): ref_tensors.append(expected)
values = sess.run(tensors)
ref_values = sess.run(ref_tensors) values = sess.run(tensors)
for i in range(len(tensors)): ref_values = sess.run(ref_tensors)
conv = tensors[i] for i in range(len(tensors)):
value = values[i] conv = tensors[i]
ref_value = ref_values[i] value = values[i]
tf_logging.info("expected = %s", ref_value) ref_value = ref_values[i]
tf_logging.info("actual = %s", value) tf_logging.info("expected = %s", ref_value)
tol = 1e-5 tf_logging.info("actual = %s", value)
if value.dtype == np.float16: tol = 1e-5
tol = 1e-3 if value.dtype == np.float16:
self.assertAllClose( tol = 1e-3
np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol) self.assertAllClose(
self.assertShapeEqual(value, conv) np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol)
self.assertShapeEqual(value, conv)
def testConv2D1x1Filter(self, gpu_only=True): def testConv2D1x1Filter(self, gpu_only=True):
if gpu_only and not test.is_gpu_available(): if gpu_only and not test.is_gpu_available():
@ -537,7 +536,7 @@ class FusedConv2DBiasActivationTest(object):
if gpu_only and not test.is_gpu_available(): if gpu_only and not test.is_gpu_available():
tf_logging.info("Skipping OpEdgeCases tests.") tf_logging.info("Skipping OpEdgeCases tests.")
return return
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
# Illegal strides. # Illegal strides.
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
errors_impl.UnimplementedError, errors_impl.UnimplementedError,
@ -904,7 +903,7 @@ class FusedConvInt8CPUTests(object):
bias_scale = test_param["bias_scale"] bias_scale = test_param["bias_scale"]
padding_type = test_param["padding_type"] padding_type = test_param["padding_type"]
with self.cached_session(use_gpu=False) as sess, self.test_scope(): with self.session() as sess, self.test_scope():
conv_input, _, _ = gen_array_ops.quantize_v2( conv_input, _, _ = gen_array_ops.quantize_v2(
random_ops.random_uniform( random_ops.random_uniform(
[batch_size, input_height, input_width, input_channels], [batch_size, input_height, input_width, input_channels],
@ -995,7 +994,7 @@ class FusedConvInt8CorrespondenceTests(object):
bias_scale = test_param["bias_scale"] bias_scale = test_param["bias_scale"]
padding_type = test_param["padding_type"] padding_type = test_param["padding_type"]
with self.cached_session(use_gpu=True) as sess, self.test_scope(): with self.session() as sess, self.test_scope():
conv_input, _, _ = gen_array_ops.quantize_v2( conv_input, _, _ = gen_array_ops.quantize_v2(
random_ops.random_uniform( random_ops.random_uniform(
[batch_size, input_channels // 4, input_height, input_width, 4], [batch_size, input_channels // 4, input_height, input_width, 4],
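
The reshuffling in this file follows one pattern: the _Setup* helpers now only build graph nodes, and each caller opens a single self.session() plus self.test_scope() around all configurations before one sess.run. A hypothetical reduction of that pattern (not from this change; assumes constant_op and nn_ops are imported):

    # Sketch only: helpers construct tensors, the caller owns the session.
    def _BuildOutputSketch(self, x):
      # Graph construction only -- no session or device scope in here.
      return nn_ops.relu(x * 2.0)

    def testCompareConfigsSketch(self):
      with self.session() as sess, self.test_scope():
        inputs = [constant_op.constant([-1.0, 2.0]),
                  constant_op.constant([3.0, -4.0])]
        outputs = [self._BuildOutputSketch(x) for x in inputs]
        values = sess.run(outputs)  # one run covers every configuration
        self.assertAllClose([0.0, 4.0], values[0])
        self.assertAllClose([6.0, 0.0], values[1])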


@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class ResamplerOpsTest(xla_test.XLATestCase): class ResamplerOpsTest(xla_test.XLATestCase):
def _assertForwardOpMatchesExpected(self, image_np, warp_np, expected): def _assertForwardOpMatchesExpected(self, image_np, warp_np, expected):
with self.test_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
input_image = array_ops.placeholder(image_np.dtype) input_image = array_ops.placeholder(image_np.dtype)
warp = array_ops.placeholder(warp_np.dtype) warp = array_ops.placeholder(warp_np.dtype)
resampled = resampler.resampler(input_image, warp, name='resampler') resampled = resampler.resampler(input_image, warp, name='resampler')
@ -41,7 +41,7 @@ class ResamplerOpsTest(xla_test.XLATestCase):
def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np, def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np,
expected_grad_data, expected_grad_warp): expected_grad_data, expected_grad_warp):
with self.cached_session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
input_image = array_ops.placeholder(input_np.dtype) input_image = array_ops.placeholder(input_np.dtype)
warp = array_ops.placeholder(warp_np.dtype) warp = array_ops.placeholder(warp_np.dtype)
grad_output = array_ops.placeholder(grad_output_np.dtype) grad_output = array_ops.placeholder(grad_output_np.dtype)
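
For illustration, a hypothetical call into the forward helper above (values chosen so the expectation is trivially all zeros; assumes numpy is imported as np):

    # Sketch only: an all-zero image samples to zero for any warp, so the
    # expected array is safe without redoing the bilinear-sampling math.
    def testZeroImageSketch(self):
      self._assertForwardOpMatchesExpected(
          np.zeros([1, 2, 2, 1], dtype=np.float32),
          np.zeros([1, 1, 2], dtype=np.float32),
          np.zeros([1, 1, 1], dtype=np.float32))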


@ -1110,6 +1110,8 @@ Status ColocationGraph::GetDevicesForNode(
"Could not satisfy explicit device specification '", "Could not satisfy explicit device specification '",
node->requested_device(), "' because no supported kernel for ", node->requested_device(), "' because no supported kernel for ",
specified_device_name.type, " devices is available.", debug_info, specified_device_name.type, " devices is available.", debug_info,
"\nOp: ", node->type_string(),
"\nNode attrs: ", node->attrs().DebugString(),
"\nRegistered kernels:\n", "\nRegistered kernels:\n",
KernelsRegisteredForOp(node->type_string())); KernelsRegisteredForOp(node->type_string()));
} else { } else {
@ -1302,17 +1304,10 @@ Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
for (Device* d : device_set_.devices()) { for (Device* d : device_set_.devices()) {
registered_device_types.insert(d->device_type()); registered_device_types.insert(d->device_type());
} }
std::vector<string> attr_key_vals;
for (const auto& it : node.attrs()) {
const string& name = it.first;
const AttrValue& attr_value = it.second;
attr_key_vals.push_back(
strings::StrCat(name, "=", SummarizeAttrValue(attr_value)));
}
return errors::InvalidArgument( return errors::InvalidArgument(
"No OpKernel was registered to support Op '", node.type_string(), "No OpKernel was registered to support Op '", node.type_string(),
"' used by ", errors::FormatNodeNameForError(node.name()), "' used by ", errors::FormatNodeNameForError(node.name()),
"with these attrs: [", str_util::Join(attr_key_vals, ", "), "with these attrs: [", node.attrs().DebugString(),
"]\n" "]\n"
"Registered devices: [", "Registered devices: [",
str_util::Join(registered_device_types, ", "), "]\n", str_util::Join(registered_device_types, ", "), "]\n",


@ -19,6 +19,8 @@ limitations under the License.
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
@ -82,6 +84,18 @@ string AttrSlice::SummarizeNode() const {
"[", SummarizeAttrsHelper(*this, StringPiece()), "]"); "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
} }
string AttrSlice::DebugString() const {
std::vector<string> attr_key_vals;
attr_key_vals.reserve(attrs_->size());
for (const auto& it : *this) {
const string& name = it.first;
const AttrValue& attr_value = it.second;
attr_key_vals.push_back(
absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
}
return absl::StrJoin(attr_key_vals, ", ");
}
string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); } string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
string SummarizeNodeDef(const NodeDef& node_def) { string SummarizeNodeDef(const NodeDef& node_def) {


@ -173,6 +173,8 @@ class AttrSlice {
AttrValueMap::const_iterator begin() const { return attrs_->begin(); } AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
AttrValueMap::const_iterator end() const { return attrs_->end(); } AttrValueMap::const_iterator end() const { return attrs_->end(); }
string DebugString() const;
private: private:
const NodeDef* ndef_; const NodeDef* ndef_;
const AttrValueMap* attrs_; const AttrValueMap* attrs_;


@ -196,6 +196,7 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
// the conversion from uint8 to quint8. // the conversion from uint8 to quint8.
REGISTER_KERNEL(CPU, quint8); REGISTER_KERNEL(CPU, quint8);
REGISTER_KERNEL(CPU, quint16); REGISTER_KERNEL(CPU, quint16);
REGISTER_KERNEL(CPU, uint32);
#undef REGISTER_CPU_KERNEL #undef REGISTER_CPU_KERNEL
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL


@ -137,6 +137,7 @@ struct FillFunctor<Eigen::ThreadPoolDevice, T> {
TF_CALL_ALL_TYPES(DEFINE_FILL_CPU); TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
DEFINE_FILL_CPU(quint8); DEFINE_FILL_CPU(quint8);
DEFINE_FILL_CPU(quint16); DEFINE_FILL_CPU(quint16);
DEFINE_FILL_CPU(uint32);
#undef DEFINE_FILL_CPU #undef DEFINE_FILL_CPU
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
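
Together with the uint32 Const registration above, this gives Fill a CPU uint32 kernel, so the host-side half of an XLA device test can stage uint32 data without soft placement. A hypothetical smoke check (not part of this change; assumes array_ops, constant_op, dtypes and ops are imported):

    # Sketch only: uint32 Const and Fill now resolve to CPU kernels.
    def testUint32FillOnCpuSketch(self):
      with self.session() as sess:
        with ops.device("CPU"):
          filled = array_ops.fill(
              [2, 2], constant_op.constant(7, dtype=dtypes.uint32))
        self.assertAllEqual([[7, 7], [7, 7]], sess.run(filled))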