Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84
as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 214300210
This commit is contained in:
parent
5fbb064ba1
commit
28eeda839f
@ -92,7 +92,7 @@ class ErrorsTest(tf.test.TestCase):
|
|||||||
compiled_fn = ag.to_graph(test_fn)
|
compiled_fn = ag.to_graph(test_fn)
|
||||||
|
|
||||||
with self.assertRaises(ag.TfRuntimeError) as error:
|
with self.assertRaises(ag.TfRuntimeError) as error:
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = compiled_fn(tf.constant([4, 8]))
|
x = compiled_fn(tf.constant([4, 8]))
|
||||||
with ag.improved_errors(compiled_fn):
|
with ag.improved_errors(compiled_fn):
|
||||||
sess.run(x)
|
sess.run(x)
|
||||||
@ -134,7 +134,7 @@ class ErrorsTest(tf.test.TestCase):
|
|||||||
# frame with "g" as the function name but because we don't yet add
|
# frame with "g" as the function name but because we don't yet add
|
||||||
# try/except blocks to inner functions the name is "tf__g".
|
# try/except blocks to inner functions the name is "tf__g".
|
||||||
with self.assertRaises(ag.TfRuntimeError) as error:
|
with self.assertRaises(ag.TfRuntimeError) as error:
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = compiled_fn(tf.constant([4, 8]))
|
x = compiled_fn(tf.constant([4, 8]))
|
||||||
with ag.improved_errors(compiled_fn):
|
with ag.improved_errors(compiled_fn):
|
||||||
sess.run(x)
|
sess.run(x)
|
||||||
|
@ -54,7 +54,7 @@ class RuntimeErrorsTest(test.TestCase):
|
|||||||
ops = zero_div_caller()
|
ops = zero_div_caller()
|
||||||
with self.assertRaises(errors.TfRuntimeError) as cm:
|
with self.assertRaises(errors.TfRuntimeError) as cm:
|
||||||
with errors.improved_errors(zero_div_caller):
|
with errors.improved_errors(zero_div_caller):
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(ops)
|
sess.run(ops)
|
||||||
|
|
||||||
for frame in cm.exception.custom_traceback:
|
for frame in cm.exception.custom_traceback:
|
||||||
@ -69,7 +69,7 @@ class RuntimeErrorsTest(test.TestCase):
|
|||||||
ops = zero_div_caller()
|
ops = zero_div_caller()
|
||||||
with self.assertRaises(errors.TfRuntimeError) as cm:
|
with self.assertRaises(errors.TfRuntimeError) as cm:
|
||||||
with errors.improved_errors(zero_div_caller):
|
with errors.improved_errors(zero_div_caller):
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(ops)
|
sess.run(ops)
|
||||||
|
|
||||||
all_function_names = set()
|
all_function_names = set()
|
||||||
@ -86,7 +86,7 @@ class RuntimeErrorsTest(test.TestCase):
|
|||||||
ops = zero_div_caller()
|
ops = zero_div_caller()
|
||||||
with self.assertRaises(tf_errors.InvalidArgumentError):
|
with self.assertRaises(tf_errors.InvalidArgumentError):
|
||||||
with errors.improved_errors(zero_div_caller):
|
with errors.improved_errors(zero_div_caller):
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(ops)
|
sess.run(ops)
|
||||||
|
|
||||||
def test_improved_errors_validation(self):
|
def test_improved_errors_validation(self):
|
||||||
|
@ -55,7 +55,7 @@ class ApiTest(test.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
@ -75,7 +75,7 @@ class ApiTest(test.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
@ -96,7 +96,7 @@ class ApiTest(test.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
@ -122,7 +122,7 @@ class ApiTest(test.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
@ -145,7 +145,7 @@ class ApiTest(test.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
@ -185,7 +185,7 @@ class ApiTest(test.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
@ -202,7 +202,7 @@ class ApiTest(test.TestCase):
|
|||||||
return -x
|
return -x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = api.converted_call(test_fn, api.ConversionOptions.new(),
|
x = api.converted_call(test_fn, api.ConversionOptions.new(),
|
||||||
constant_op.constant(-1))
|
constant_op.constant(-1))
|
||||||
self.assertEqual(1, sess.run(x))
|
self.assertEqual(1, sess.run(x))
|
||||||
@ -219,7 +219,7 @@ class ApiTest(test.TestCase):
|
|||||||
return -self.x
|
return -self.x
|
||||||
return self.x
|
return self.x
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
tc = TestClass(constant_op.constant(-1))
|
tc = TestClass(constant_op.constant(-1))
|
||||||
x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc)
|
x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc)
|
||||||
self.assertEqual(1, sess.run(x))
|
self.assertEqual(1, sess.run(x))
|
||||||
@ -236,7 +236,7 @@ class ApiTest(test.TestCase):
|
|||||||
return -self.x
|
return -self.x
|
||||||
return self.x
|
return self.x
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
tc = TestClass(constant_op.constant(-1))
|
tc = TestClass(constant_op.constant(-1))
|
||||||
x = api.converted_call(
|
x = api.converted_call(
|
||||||
TestClass.test_method,
|
TestClass.test_method,
|
||||||
@ -255,7 +255,7 @@ class ApiTest(test.TestCase):
|
|||||||
return -self.x
|
return -self.x
|
||||||
return self.x
|
return self.x
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
tc = TestClass(constant_op.constant(-1))
|
tc = TestClass(constant_op.constant(-1))
|
||||||
x = api.converted_call(tc, api.ConversionOptions.new())
|
x = api.converted_call(tc, api.ConversionOptions.new())
|
||||||
self.assertEqual(1, sess.run(x))
|
self.assertEqual(1, sess.run(x))
|
||||||
@ -272,7 +272,7 @@ class ApiTest(test.TestCase):
|
|||||||
return -self.x
|
return -self.x
|
||||||
return self.x
|
return self.x
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
tc = api.converted_call(TestClass, api.ConversionOptions.new(),
|
tc = api.converted_call(TestClass, api.ConversionOptions.new(),
|
||||||
constant_op.constant(-1))
|
constant_op.constant(-1))
|
||||||
# tc is now a converted object.
|
# tc is now a converted object.
|
||||||
@ -284,7 +284,7 @@ class ApiTest(test.TestCase):
|
|||||||
def f(x):
|
def f(x):
|
||||||
return x == 0
|
return x == 0
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = api.converted_call(f, api.ConversionOptions.new(),
|
x = api.converted_call(f, api.ConversionOptions.new(),
|
||||||
constant_op.constant(0))
|
constant_op.constant(0))
|
||||||
self.assertTrue(sess.run(x))
|
self.assertTrue(sess.run(x))
|
||||||
@ -303,7 +303,7 @@ class ApiTest(test.TestCase):
|
|||||||
|
|
||||||
compiled_fn = api.to_graph(test_fn)
|
compiled_fn = api.to_graph(test_fn)
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
||||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ class SpecialFunctionsTest(test.TestCase):
|
|||||||
|
|
||||||
l = special_functions.tensor_list(elements)
|
l = special_functions.tensor_list(elements)
|
||||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||||
|
|
||||||
def test_tensor_list_array_from_elements(self):
|
def test_tensor_list_array_from_elements(self):
|
||||||
@ -41,7 +41,7 @@ class SpecialFunctionsTest(test.TestCase):
|
|||||||
|
|
||||||
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
||||||
sl = l.stack()
|
sl = l.stack()
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||||
|
|
||||||
def test_stack(self):
|
def test_stack(self):
|
||||||
|
@ -36,7 +36,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
|
|
||||||
def test_abs(self):
|
def test_abs(self):
|
||||||
self.assertEqual(py_builtins.abs_(-1), 1)
|
self.assertEqual(py_builtins.abs_(-1), 1)
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = py_builtins.abs_(constant_op.constant(-1))
|
t = py_builtins.abs_(constant_op.constant(-1))
|
||||||
self.assertEqual(sess.run(t), 1)
|
self.assertEqual(sess.run(t), 1)
|
||||||
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
|
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
|
||||||
@ -45,7 +45,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
def test_float(self):
|
def test_float(self):
|
||||||
self.assertEqual(py_builtins.float_(10), 10.0)
|
self.assertEqual(py_builtins.float_(10), 10.0)
|
||||||
self.assertEqual(py_builtins.float_('10.0'), 10.0)
|
self.assertEqual(py_builtins.float_('10.0'), 10.0)
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
|
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
|
||||||
self.assertEqual(sess.run(t), 1.0)
|
self.assertEqual(sess.run(t), 1.0)
|
||||||
st = py_builtins.float_(constant_op.constant('1.0'))
|
st = py_builtins.float_(constant_op.constant('1.0'))
|
||||||
@ -54,7 +54,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
def test_int(self):
|
def test_int(self):
|
||||||
self.assertEqual(py_builtins.int_(10.0), 10)
|
self.assertEqual(py_builtins.int_(10.0), 10)
|
||||||
self.assertEqual(py_builtins.int_('11', 2), 3)
|
self.assertEqual(py_builtins.int_('11', 2), 3)
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
|
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
|
||||||
self.assertEqual(sess.run(t), 1)
|
self.assertEqual(sess.run(t), 1)
|
||||||
st = py_builtins.int_(constant_op.constant('1'))
|
st = py_builtins.int_(constant_op.constant('1'))
|
||||||
@ -69,7 +69,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
|
|
||||||
def test_len(self):
|
def test_len(self):
|
||||||
self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
|
self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
|
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
|
||||||
self.assertEqual(t, 3)
|
self.assertEqual(t, 3)
|
||||||
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
|
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
|
||||||
@ -82,7 +82,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
py_builtins.len_(constant_op.constant(1))
|
py_builtins.len_(constant_op.constant(1))
|
||||||
|
|
||||||
def test_len_dynamic_shape(self):
|
def test_len_dynamic_shape(self):
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
|
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
|
||||||
t = py_builtins.len_(p)
|
t = py_builtins.len_(p)
|
||||||
self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
|
self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
|
||||||
@ -95,7 +95,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
try:
|
try:
|
||||||
out_capturer = six.StringIO()
|
out_capturer = six.StringIO()
|
||||||
sys.stdout = out_capturer
|
sys.stdout = out_capturer
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
|
sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
|
||||||
self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
|
self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
|
||||||
finally:
|
finally:
|
||||||
@ -105,7 +105,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
try:
|
try:
|
||||||
out_capturer = six.StringIO()
|
out_capturer = six.StringIO()
|
||||||
sys.stdout = out_capturer
|
sys.stdout = out_capturer
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(
|
sess.run(
|
||||||
py_builtins.print_(constant_op.constant('test message'), [1, 2]))
|
py_builtins.print_(constant_op.constant('test message'), [1, 2]))
|
||||||
self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
|
self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
|
||||||
@ -118,7 +118,7 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
|
self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
|
||||||
|
|
||||||
def test_range_tensor(self):
|
def test_range_tensor(self):
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
r = py_builtins.range_(constant_op.constant(3))
|
r = py_builtins.range_(constant_op.constant(3))
|
||||||
self.assertAllEqual(sess.run(r), [0, 1, 2])
|
self.assertAllEqual(sess.run(r), [0, 1, 2])
|
||||||
r = py_builtins.range_(1, constant_op.constant(3))
|
r = py_builtins.range_(1, constant_op.constant(3))
|
||||||
|
@ -51,14 +51,14 @@ class SlicesTest(test.TestCase):
|
|||||||
t = slices.get_item(initial_str, 1,
|
t = slices.get_item(initial_str, 1,
|
||||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(t), b'b')
|
self.assertEqual(sess.run(t), b'b')
|
||||||
|
|
||||||
initial_list_str = constant_op.constant(['abcd', 'bcde'])
|
initial_list_str = constant_op.constant(['abcd', 'bcde'])
|
||||||
t = slices.get_item(initial_list_str, 1,
|
t = slices.get_item(initial_list_str, 1,
|
||||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(t), b'bcde')
|
self.assertEqual(sess.run(t), b'bcde')
|
||||||
|
|
||||||
|
|
||||||
|
@ -1602,7 +1602,7 @@ class FunctionTest(test.TestCase):
|
|||||||
defun_add = function.defun_with_attributes(
|
defun_add = function.defun_with_attributes(
|
||||||
add, attributes={'experimental_3': True, 'experimental_4': 1.0})
|
add, attributes={'experimental_3': True, 'experimental_4': 1.0})
|
||||||
|
|
||||||
with context.graph_mode(), self.test_session():
|
with context.graph_mode(), self.cached_session():
|
||||||
with ops.get_default_graph().as_default():
|
with ops.get_default_graph().as_default():
|
||||||
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
sq = matmul(t, t)
|
sq = matmul(t, t)
|
||||||
@ -1636,7 +1636,7 @@ class FunctionTest(test.TestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
'.*Attribute name is not whitelisted.*'):
|
'.*Attribute name is not whitelisted.*'):
|
||||||
with context.graph_mode(), self.test_session():
|
with context.graph_mode(), self.cached_session():
|
||||||
with ops.get_default_graph().as_default():
|
with ops.get_default_graph().as_default():
|
||||||
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
matmul(t, t)
|
matmul(t, t)
|
||||||
@ -1647,7 +1647,7 @@ class FunctionTest(test.TestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
'.*Unsupported attribute type.*'):
|
'.*Unsupported attribute type.*'):
|
||||||
with context.graph_mode(), self.test_session():
|
with context.graph_mode(), self.cached_session():
|
||||||
with ops.get_default_graph().as_default():
|
with ops.get_default_graph().as_default():
|
||||||
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
add(t, t)
|
add(t, t)
|
||||||
|
@ -915,7 +915,7 @@ class TopologyConstructionTest(test.TestCase):
|
|||||||
|
|
||||||
def test_constant_initializer_with_numpy(self):
|
def test_constant_initializer_with_numpy(self):
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
initializer = keras.initializers.Constant(np.ones((3, 2)))
|
initializer = keras.initializers.Constant(np.ones((3, 2)))
|
||||||
model = keras.models.Sequential()
|
model = keras.models.Sequential()
|
||||||
model.add(keras.layers.Dense(2, input_shape=(3,),
|
model.add(keras.layers.Dense(2, input_shape=(3,),
|
||||||
|
@ -186,7 +186,7 @@ class TestMultiGPUModel(test.TestCase):
|
|||||||
if not check_if_compatible_devices(gpus=gpus):
|
if not check_if_compatible_devices(gpus=gpus):
|
||||||
return
|
return
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
inputs = keras.Input((4, 3))
|
inputs = keras.Input((4, 3))
|
||||||
init_state = keras.Input((3,))
|
init_state = keras.Input((3,))
|
||||||
outputs = keras.layers.SimpleRNN(
|
outputs = keras.layers.SimpleRNN(
|
||||||
|
@ -934,7 +934,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
|
|||||||
For example, this could happen if the final ensemble contains one tree that
|
For example, this could happen if the final ensemble contains one tree that
|
||||||
got pruned up to the root.
|
got pruned up to the root.
|
||||||
"""
|
"""
|
||||||
with self.test_session() as session:
|
with self.cached_session() as session:
|
||||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||||
text_format.Merge(
|
text_format.Merge(
|
||||||
"""
|
"""
|
||||||
@ -990,7 +990,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
|
def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
|
||||||
"""Tests case when, after training, first tree contains only a bias node."""
|
"""Tests case when, after training, first tree contains only a bias node."""
|
||||||
with self.test_session() as session:
|
with self.cached_session() as session:
|
||||||
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||||
text_format.Merge(
|
text_format.Merge(
|
||||||
"""
|
"""
|
||||||
|
@ -78,7 +78,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
|
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
|
||||||
|
|
||||||
def testBasicQuantileBucketsSingleResource(self):
|
def testBasicQuantileBucketsSingleResource(self):
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
quantile_accumulator_handle = self.create_resource("floats", self.eps,
|
quantile_accumulator_handle = self.create_resource("floats", self.eps,
|
||||||
self.max_elements, 2)
|
self.max_elements, 2)
|
||||||
resources.initialize_resources(resources.shared_resources()).run()
|
resources.initialize_resources(resources.shared_resources()).run()
|
||||||
@ -102,7 +102,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
|
self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
|
||||||
|
|
||||||
def testBasicQuantileBucketsMultipleResources(self):
|
def testBasicQuantileBucketsMultipleResources(self):
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
|
quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
|
||||||
self.max_elements)
|
self.max_elements)
|
||||||
quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
|
quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
|
||||||
|
@ -76,7 +76,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
|
|||||||
[1., 1.], is_positive_definite=True, name="A")
|
[1., 1.], is_positive_definite=True, name="A")
|
||||||
op_b = linalg.LinearOperatorDiag(
|
op_b = linalg.LinearOperatorDiag(
|
||||||
[2., 2.], is_positive_definite=True, name="B")
|
[2., 2.], is_positive_definite=True, name="B")
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
op_sum = add_operators([op_a, op_b])
|
op_sum = add_operators([op_a, op_b])
|
||||||
self.assertEqual(1, len(op_sum))
|
self.assertEqual(1, len(op_sum))
|
||||||
op = op_sum[0]
|
op = op_sum[0]
|
||||||
@ -98,7 +98,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
|
|||||||
[2., 2.], is_positive_definite=True, name="op2")
|
[2., 2.], is_positive_definite=True, name="op2")
|
||||||
op3 = linalg.LinearOperatorDiag(
|
op3 = linalg.LinearOperatorDiag(
|
||||||
[3., 3.], is_positive_definite=True, name="op3")
|
[3., 3.], is_positive_definite=True, name="op3")
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
op_sum = add_operators([op1, op2, op3])
|
op_sum = add_operators([op1, op2, op3])
|
||||||
self.assertEqual(1, len(op_sum))
|
self.assertEqual(1, len(op_sum))
|
||||||
op = op_sum[0]
|
op = op_sum[0]
|
||||||
@ -121,7 +121,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
|
|||||||
name="tril")
|
name="tril")
|
||||||
op3 = linalg.LinearOperatorDiag(
|
op3 = linalg.LinearOperatorDiag(
|
||||||
[3., 3.], is_non_singular=True, name="diag_b")
|
[3., 3.], is_non_singular=True, name="diag_b")
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
op_sum = add_operators([op1, op2, op3])
|
op_sum = add_operators([op1, op2, op3])
|
||||||
self.assertEqual(1, len(op_sum))
|
self.assertEqual(1, len(op_sum))
|
||||||
op = op_sum[0]
|
op = op_sum[0]
|
||||||
@ -143,7 +143,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
|
|||||||
op2 = linalg.LinearOperatorLowerTriangular(
|
op2 = linalg.LinearOperatorLowerTriangular(
|
||||||
[[2., 0.], [1.5, 2.]], name="tril")
|
[[2., 0.], [1.5, 2.]], name="tril")
|
||||||
op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
|
op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
|
op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
|
||||||
self.assertEqual(1, len(op_sum))
|
self.assertEqual(1, len(op_sum))
|
||||||
op = op_sum[0]
|
op = op_sum[0]
|
||||||
@ -233,7 +233,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase):
|
|||||||
self.assertEqual(2, len(op_sum))
|
self.assertEqual(2, len(op_sum))
|
||||||
found_diag = False
|
found_diag = False
|
||||||
found_tril = False
|
found_tril = False
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
for op in op_sum:
|
for op in op_sum:
|
||||||
if isinstance(op, linalg.LinearOperatorDiag):
|
if isinstance(op, linalg.LinearOperatorDiag):
|
||||||
found_diag = True
|
found_diag = True
|
||||||
@ -273,7 +273,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
|
|||||||
operator = self._adder.add(id1, id2, "my_operator", hints)
|
operator = self._adder.add(id1, id2, "my_operator", hints)
|
||||||
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
|
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
self.assertAllClose(2 *
|
self.assertAllClose(2 *
|
||||||
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
||||||
operator.to_dense().eval())
|
operator.to_dense().eval())
|
||||||
@ -291,7 +291,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
|
|||||||
operator = self._adder.add(id1, id2, "my_operator", hints)
|
operator = self._adder.add(id1, id2, "my_operator", hints)
|
||||||
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
|
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
self.assertAllClose(3.2 *
|
self.assertAllClose(3.2 *
|
||||||
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
||||||
operator.to_dense().eval())
|
operator.to_dense().eval())
|
||||||
@ -310,7 +310,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
|
|||||||
operator = self._adder.add(id1, id2, "my_operator", hints)
|
operator = self._adder.add(id1, id2, "my_operator", hints)
|
||||||
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
|
self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
self.assertAllClose(1.2 *
|
self.assertAllClose(1.2 *
|
||||||
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
||||||
operator.to_dense().eval())
|
operator.to_dense().eval())
|
||||||
@ -334,7 +334,7 @@ class AddAndReturnDiagTest(test.TestCase):
|
|||||||
operator = self._adder.add(id1, id2, "my_operator", hints)
|
operator = self._adder.add(id1, id2, "my_operator", hints)
|
||||||
self.assertIsInstance(operator, linalg.LinearOperatorDiag)
|
self.assertIsInstance(operator, linalg.LinearOperatorDiag)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
self.assertAllClose(2 *
|
self.assertAllClose(2 *
|
||||||
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
||||||
operator.to_dense().eval())
|
operator.to_dense().eval())
|
||||||
@ -354,7 +354,7 @@ class AddAndReturnDiagTest(test.TestCase):
|
|||||||
operator = self._adder.add(op1, op2, "my_operator", hints)
|
operator = self._adder.add(op1, op2, "my_operator", hints)
|
||||||
self.assertIsInstance(operator, linalg.LinearOperatorDiag)
|
self.assertIsInstance(operator, linalg.LinearOperatorDiag)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
|
linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
|
||||||
operator.to_dense().eval())
|
operator.to_dense().eval())
|
||||||
@ -379,7 +379,7 @@ class AddAndReturnTriLTest(test.TestCase):
|
|||||||
operator = self._adder.add(diag, tril, "my_operator", hints)
|
operator = self._adder.add(diag, tril, "my_operator", hints)
|
||||||
self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
|
self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
|
self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
|
||||||
self.assertTrue(operator.is_positive_definite)
|
self.assertTrue(operator.is_positive_definite)
|
||||||
self.assertTrue(operator.is_non_singular)
|
self.assertTrue(operator.is_non_singular)
|
||||||
@ -401,7 +401,7 @@ class AddAndReturnMatrixTest(test.TestCase):
|
|||||||
operator = self._adder.add(diag1, diag2, "my_operator", hints)
|
operator = self._adder.add(diag1, diag2, "my_operator", hints)
|
||||||
self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
|
self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
|
self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
|
||||||
self.assertFalse(operator.is_positive_definite)
|
self.assertFalse(operator.is_positive_definite)
|
||||||
self.assertFalse(operator.is_non_singular)
|
self.assertFalse(operator.is_non_singular)
|
||||||
|
@ -31,7 +31,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintOneTensorLogInfo(self):
|
def testPrintOneTensorLogInfo(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(
|
print_op = logging_ops.print_v2(
|
||||||
@ -43,7 +43,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintOneTensorLogWarning(self):
|
def testPrintOneTensorLogWarning(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(
|
print_op = logging_ops.print_v2(
|
||||||
@ -55,7 +55,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintOneTensorLogError(self):
|
def testPrintOneTensorLogError(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(
|
print_op = logging_ops.print_v2(
|
||||||
|
@ -69,7 +69,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintOneTensor(self):
|
def testPrintOneTensor(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor)
|
print_op = logging_ops.print_v2(tensor)
|
||||||
@ -80,7 +80,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintOneTensorVarySummarize(self):
|
def testPrintOneTensorVarySummarize(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, summarize=1)
|
print_op = logging_ops.print_v2(tensor, summarize=1)
|
||||||
@ -89,7 +89,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
expected = "[0 ... 9]"
|
expected = "[0 ... 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertTrue((expected + "\n") in printed.contents())
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, summarize=2)
|
print_op = logging_ops.print_v2(tensor, summarize=2)
|
||||||
@ -98,7 +98,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
expected = "[0 1 ... 8 9]"
|
expected = "[0 1 ... 8 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertTrue((expected + "\n") in printed.contents())
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, summarize=3)
|
print_op = logging_ops.print_v2(tensor, summarize=3)
|
||||||
@ -107,7 +107,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
expected = "[0 1 2 ... 7 8 9]"
|
expected = "[0 1 2 ... 7 8 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertTrue((expected + "\n") in printed.contents())
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, summarize=-1)
|
print_op = logging_ops.print_v2(tensor, summarize=-1)
|
||||||
@ -118,7 +118,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintOneVariable(self):
|
def testPrintOneVariable(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
var = variables.Variable(math_ops.range(10))
|
var = variables.Variable(math_ops.range(10))
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
variables.global_variables_initializer().run()
|
variables.global_variables_initializer().run()
|
||||||
@ -130,7 +130,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintTwoVariablesInStructWithAssignAdd(self):
|
def testPrintTwoVariablesInStructWithAssignAdd(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
var_one = variables.Variable(2.14)
|
var_one = variables.Variable(2.14)
|
||||||
plus_one = var_one.assign_add(1.0)
|
plus_one = var_one.assign_add(1.0)
|
||||||
var_two = variables.Variable(math_ops.range(10))
|
var_two = variables.Variable(math_ops.range(10))
|
||||||
@ -145,7 +145,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintTwoTensors(self):
|
def testPrintTwoTensors(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, tensor * 10)
|
print_op = logging_ops.print_v2(tensor, tensor * 10)
|
||||||
@ -155,7 +155,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintPlaceholderGeneration(self):
|
def testPrintPlaceholderGeneration(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
|
print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
|
||||||
@ -165,7 +165,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintNoTensors(self):
|
def testPrintNoTensors(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
|
print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
@ -174,7 +174,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintFloatScalar(self):
|
def testPrintFloatScalar(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = ops.convert_to_tensor(434.43)
|
tensor = ops.convert_to_tensor(434.43)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor)
|
print_op = logging_ops.print_v2(tensor)
|
||||||
@ -184,7 +184,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintStringScalar(self):
|
def testPrintStringScalar(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = ops.convert_to_tensor("scalar")
|
tensor = ops.convert_to_tensor("scalar")
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor)
|
print_op = logging_ops.print_v2(tensor)
|
||||||
@ -194,7 +194,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintComplexTensorStruct(self):
|
def testPrintComplexTensorStruct(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
small_tensor = constant_op.constant([0.3, 12.4, -16.1])
|
small_tensor = constant_op.constant([0.3, 12.4, -16.1])
|
||||||
big_tensor = math_ops.mul(tensor, 10)
|
big_tensor = math_ops.mul(tensor, 10)
|
||||||
@ -214,7 +214,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintSparseTensor(self):
|
def testPrintSparseTensor(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
|
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
|
||||||
val = [0, 10, 13, 4, 14, 32, 33]
|
val = [0, 10, 13, 4, 14, 32, 33]
|
||||||
shape = [5, 6]
|
shape = [5, 6]
|
||||||
@ -238,7 +238,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintSparseTensorInDataStruct(self):
|
def testPrintSparseTensorInDataStruct(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
|
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
|
||||||
val = [0, 10, 13, 4, 14, 32, 33]
|
val = [0, 10, 13, 4, 14, 32, 33]
|
||||||
shape = [5, 6]
|
shape = [5, 6]
|
||||||
@ -262,7 +262,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testPrintOneTensorStdout(self):
|
def testPrintOneTensorStdout(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stdout) as printed:
|
with self.captureWritesToStream(sys.stdout) as printed:
|
||||||
print_op = logging_ops.print_v2(
|
print_op = logging_ops.print_v2(
|
||||||
@ -273,7 +273,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testInvalidOutputStreamRaisesError(self):
|
def testInvalidOutputStreamRaisesError(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
print_op = logging_ops.print_v2(
|
print_op = logging_ops.print_v2(
|
||||||
@ -281,13 +281,13 @@ class PrintV2Test(test.TestCase):
|
|||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
|
|
||||||
def testPrintOpName(self):
|
def testPrintOpName(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
print_op = logging_ops.print_v2(tensor, name="print_name")
|
print_op = logging_ops.print_v2(tensor, name="print_name")
|
||||||
self.assertEqual(print_op.name, "print_name")
|
self.assertEqual(print_op.name, "print_name")
|
||||||
|
|
||||||
def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
|
def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
formatted_string = string_ops.string_format("{}", tensor)
|
formatted_string = string_ops.string_format("{}", tensor)
|
||||||
print_op = logging_ops.print_v2(formatted_string)
|
print_op = logging_ops.print_v2(formatted_string)
|
||||||
@ -298,7 +298,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
self.assertEqual(len(format_ops), 1)
|
self.assertEqual(len(format_ops), 1)
|
||||||
|
|
||||||
def testPrintOneTensorEagerOnOpCreate(self):
|
def testPrintOneTensorEagerOnOpCreate(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
expected = "[0 1 2 ... 7 8 9]"
|
expected = "[0 1 2 ... 7 8 9]"
|
||||||
|
@ -34,14 +34,14 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorOneDim(self):
|
def testFormatOneTensorOneDim(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
format_output = string_ops.string_format("{}", tensor)
|
format_output = string_ops.string_format("{}", tensor)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
expected = "[0 1 2 ... 7 8 9]"
|
expected = "[0 1 2 ... 7 8 9]"
|
||||||
self.assertEqual(compat.as_text(out), expected)
|
self.assertEqual(compat.as_text(out), expected)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
format_output = string_ops.string_format("{}", [tensor])
|
format_output = string_ops.string_format("{}", [tensor])
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -50,7 +50,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneVariableScalar(self):
|
def testFormatOneVariableScalar(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
var = variables.Variable(3.34)
|
var = variables.Variable(3.34)
|
||||||
format_output = string_ops.string_format("{}", [var])
|
format_output = string_ops.string_format("{}", [var])
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
@ -61,7 +61,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneVariableOneDim(self):
|
def testFormatOneVariableOneDim(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
var = variables.Variable(math_ops.range(10))
|
var = variables.Variable(math_ops.range(10))
|
||||||
format_output = string_ops.string_format("{}", [var])
|
format_output = string_ops.string_format("{}", [var])
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
@ -72,7 +72,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatTwoVariablesWithAssignAdd(self):
|
def testFormatTwoVariablesWithAssignAdd(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
var_one = variables.Variable(2.14)
|
var_one = variables.Variable(2.14)
|
||||||
plus_one = var_one.assign_add(1.0)
|
plus_one = var_one.assign_add(1.0)
|
||||||
var_two = variables.Variable(math_ops.range(10))
|
var_two = variables.Variable(math_ops.range(10))
|
||||||
@ -86,7 +86,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorOneDimFloat(self):
|
def testFormatOneTensorOneDimFloat(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
|
tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
|
||||||
format_output = string_ops.string_format("{}", tensor)
|
format_output = string_ops.string_format("{}", tensor)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -95,7 +95,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorOneDimMatchesSummarize(self):
|
def testFormatOneTensorOneDimMatchesSummarize(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(6)
|
tensor = math_ops.range(6)
|
||||||
format_output = string_ops.string_format("{}", tensor, summarize=3)
|
format_output = string_ops.string_format("{}", tensor, summarize=3)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -104,28 +104,28 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorOneDimVarySummarize(self):
|
def testFormatOneTensorOneDimVarySummarize(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(6)
|
tensor = math_ops.range(6)
|
||||||
format_output = string_ops.string_format("{}", tensor, summarize=-1)
|
format_output = string_ops.string_format("{}", tensor, summarize=-1)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
expected = "[0 1 2 3 4 5]"
|
expected = "[0 1 2 3 4 5]"
|
||||||
self.assertEqual(compat.as_text(out), expected)
|
self.assertEqual(compat.as_text(out), expected)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(6)
|
tensor = math_ops.range(6)
|
||||||
format_output = string_ops.string_format("{}", tensor, summarize=1)
|
format_output = string_ops.string_format("{}", tensor, summarize=1)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
expected = "[0 ... 5]"
|
expected = "[0 ... 5]"
|
||||||
self.assertEqual(compat.as_text(out), expected)
|
self.assertEqual(compat.as_text(out), expected)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(6)
|
tensor = math_ops.range(6)
|
||||||
format_output = string_ops.string_format("{}", tensor, summarize=2)
|
format_output = string_ops.string_format("{}", tensor, summarize=2)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
expected = "[0 1 ... 4 5]"
|
expected = "[0 1 ... 4 5]"
|
||||||
self.assertEqual(compat.as_text(out), expected)
|
self.assertEqual(compat.as_text(out), expected)
|
||||||
|
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(6)
|
tensor = math_ops.range(6)
|
||||||
format_output = string_ops.string_format("{}", tensor, summarize=10)
|
format_output = string_ops.string_format("{}", tensor, summarize=10)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -134,7 +134,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorOneDimAlmostSummarize(self):
|
def testFormatOneTensorOneDimAlmostSummarize(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = math_ops.range(5)
|
tensor = math_ops.range(5)
|
||||||
format_output = string_ops.string_format("{}", tensor, summarize=3)
|
format_output = string_ops.string_format("{}", tensor, summarize=3)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -143,7 +143,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorTwoDimLessThanSummarize(self):
|
def testFormatOneTensorTwoDimLessThanSummarize(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(4), [2, 2])
|
tensor = array_ops.reshape(math_ops.range(4), [2, 2])
|
||||||
format_output = string_ops.string_format("{}", tensor, summarize=3)
|
format_output = string_ops.string_format("{}", tensor, summarize=3)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -153,7 +153,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorTwoDim(self):
|
def testFormatOneTensorTwoDim(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
format_output = string_ops.string_format("{}", tensor)
|
format_output = string_ops.string_format("{}", tensor)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -168,7 +168,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorTwoDimSummarizeTwo(self):
|
def testFormatOneTensorTwoDimSummarizeTwo(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
format_output = string_ops.string_format("{}", tensor, summarize=2)
|
format_output = string_ops.string_format("{}", tensor, summarize=2)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -181,7 +181,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorThreeDim(self):
|
def testFormatOneTensorThreeDim(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
|
tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
|
||||||
format_output = string_ops.string_format("{}", tensor)
|
format_output = string_ops.string_format("{}", tensor)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -237,7 +237,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorTemplatePrefix(self):
|
def testFormatOneTensorTemplatePrefix(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
format_output = string_ops.string_format("tensor summary: {}", tensor)
|
format_output = string_ops.string_format("tensor summary: {}", tensor)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -252,7 +252,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorTemplatePrefixAndSuffix(self):
|
def testFormatOneTensorTemplatePrefixAndSuffix(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
format_output = string_ops.string_format("tensor summary: {}, suffix",
|
format_output = string_ops.string_format("tensor summary: {}, suffix",
|
||||||
tensor)
|
tensor)
|
||||||
@ -268,7 +268,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatOneTensorTemplateSuffix(self):
|
def testFormatOneTensorTemplateSuffix(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
format_output = string_ops.string_format("{}, suffix", tensor)
|
format_output = string_ops.string_format("{}, suffix", tensor)
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
@ -283,7 +283,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatNoTensor(self):
|
def testFormatNoTensor(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
format_output = string_ops.string_format("No tensor.", ())
|
format_output = string_ops.string_format("No tensor.", ())
|
||||||
out = self.evaluate(format_output)
|
out = self.evaluate(format_output)
|
||||||
expected = "No tensor."
|
expected = "No tensor."
|
||||||
@ -291,7 +291,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatMultiTensor(self):
|
def testFormatMultiTensor(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
tensor_two = tensor_one * 10
|
tensor_two = tensor_one * 10
|
||||||
format_output = string_ops.string_format("One: {},\nTwo: {}",
|
format_output = string_ops.string_format("One: {},\nTwo: {}",
|
||||||
@ -315,7 +315,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatSummarizeOne(self):
|
def testFormatSummarizeOne(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
format_output = string_ops.string_format("tensor summary: {}", tensor,
|
format_output = string_ops.string_format("tensor summary: {}", tensor,
|
||||||
summarize=1)
|
summarize=1)
|
||||||
@ -327,7 +327,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatSummarizeTwo(self):
|
def testFormatSummarizeTwo(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
format_output = string_ops.string_format("tensor summary: {}", tensor,
|
format_output = string_ops.string_format("tensor summary: {}", tensor,
|
||||||
summarize=2)
|
summarize=2)
|
||||||
@ -341,7 +341,7 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testFormatPlaceholder(self):
|
def testFormatPlaceholder(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
tensor = array_ops.reshape(math_ops.range(100), [10, 10])
|
||||||
format_output = string_ops.string_format("tensor summary: %t%", tensor,
|
format_output = string_ops.string_format("tensor summary: %t%", tensor,
|
||||||
placeholder="%t%")
|
placeholder="%t%")
|
||||||
@ -357,21 +357,21 @@ class StringFormatOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testTensorCountMustMatchPlaceholderCount(self):
|
def testTensorCountMustMatchPlaceholderCount(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, r"2 placeholder\(s\) in template does not match 1 "
|
ValueError, r"2 placeholder\(s\) in template does not match 1 "
|
||||||
r"tensor\(s\) provided as input"):
|
r"tensor\(s\) provided as input"):
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
format_output = string_ops.string_format("{} {}", tensor)
|
format_output = string_ops.string_format("{} {}", tensor)
|
||||||
self.evaluate(format_output)
|
self.evaluate(format_output)
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, r"2 placeholder\(s\) in template does not match 1 "
|
ValueError, r"2 placeholder\(s\) in template does not match 1 "
|
||||||
r"tensor\(s\) provided as input"):
|
r"tensor\(s\) provided as input"):
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
format_output = string_ops.string_format("{} {}", [tensor])
|
format_output = string_ops.string_format("{} {}", [tensor])
|
||||||
self.evaluate(format_output)
|
self.evaluate(format_output)
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, r"1 placeholder\(s\) in template does not match 2 "
|
ValueError, r"1 placeholder\(s\) in template does not match 2 "
|
||||||
r"tensor\(s\) provided as input"):
|
r"tensor\(s\) provided as input"):
|
||||||
|
@ -41,7 +41,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
x = constant_op.constant(2.)
|
x = constant_op.constant(2.)
|
||||||
ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
|
ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
|
||||||
grad = gradients_impl.gradients(ret, [x])
|
grad = gradients_impl.gradients(ret, [x])
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(ret), 16.)
|
self.assertEqual(sess.run(ret), 16.)
|
||||||
self.assertSequenceEqual(sess.run(grad), [32.])
|
self.assertSequenceEqual(sess.run(grad), [32.])
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
# Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
|
# Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
|
||||||
grad = gradients_impl.gradients(ret, [x]) # [2*x*y]
|
grad = gradients_impl.gradients(ret, [x]) # [2*x*y]
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertSequenceEqual(sess.run(ret), [45., 3.])
|
self.assertSequenceEqual(sess.run(ret), [45., 3.])
|
||||||
self.assertSequenceEqual(sess.run(grad), [9.])
|
self.assertSequenceEqual(sess.run(grad), [9.])
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
|
grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
|
||||||
grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
|
grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
|
||||||
grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
|
grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertSequenceEqual(sess.run(ret), [120., 23.])
|
self.assertSequenceEqual(sess.run(ret), [120., 23.])
|
||||||
self.assertSequenceEqual(sess.run(gradx_0), [39.])
|
self.assertSequenceEqual(sess.run(gradx_0), [39.])
|
||||||
self.assertSequenceEqual(sess.run(gradx_1), [4.])
|
self.assertSequenceEqual(sess.run(gradx_1), [4.])
|
||||||
@ -96,7 +96,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1) # x**4
|
ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1) # x**4
|
||||||
grad = gradients_impl.gradients(ret2, [x]) # 4x**3
|
grad = gradients_impl.gradients(ret2, [x]) # 4x**3
|
||||||
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
|
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertSequenceEqual(sess.run(grad), [32.])
|
self.assertSequenceEqual(sess.run(grad), [32.])
|
||||||
self.assertSequenceEqual(sess.run(grad_grad), [48.])
|
self.assertSequenceEqual(sess.run(grad_grad), [48.])
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x]) # x**4
|
ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x]) # x**4
|
||||||
grad = gradients_impl.gradients(ret, [x]) # 4x**3
|
grad = gradients_impl.gradients(ret, [x]) # 4x**3
|
||||||
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
|
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(ret), 16.)
|
self.assertEqual(sess.run(ret), 16.)
|
||||||
self.assertSequenceEqual(sess.run(grad), [32.])
|
self.assertSequenceEqual(sess.run(grad), [32.])
|
||||||
self.assertSequenceEqual(sess.run(grad_grad), [48.])
|
self.assertSequenceEqual(sess.run(grad_grad), [48.])
|
||||||
@ -148,7 +148,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
y = constant_op.constant(1.)
|
y = constant_op.constant(1.)
|
||||||
ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
|
ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
|
||||||
grad = gradients_impl.gradients(ret, [x])
|
grad = gradients_impl.gradients(ret, [x])
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(ret), 18.)
|
self.assertEqual(sess.run(ret), 18.)
|
||||||
self.assertSequenceEqual(sess.run(grad), [9.])
|
self.assertSequenceEqual(sess.run(grad), [9.])
|
||||||
|
|
||||||
@ -157,7 +157,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
y = constant_op.constant(3.)
|
y = constant_op.constant(3.)
|
||||||
ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
|
ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
|
||||||
grad = gradients_impl.gradients(ret, [x])
|
grad = gradients_impl.gradients(ret, [x])
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(ret), 18.)
|
self.assertEqual(sess.run(ret), 18.)
|
||||||
self.assertSequenceEqual(sess.run(grad), [9.])
|
self.assertSequenceEqual(sess.run(grad), [9.])
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
ret = while_loop_v2(Cond, Body, [x, tensor_list])
|
ret = while_loop_v2(Cond, Body, [x, tensor_list])
|
||||||
grad = gradients_impl.gradients(ret[0], x)
|
grad = gradients_impl.gradients(ret[0], x)
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(ret[0]), 16.)
|
self.assertEqual(sess.run(ret[0]), 16.)
|
||||||
self.assertSequenceEqual(sess.run(grad), [32.])
|
self.assertSequenceEqual(sess.run(grad), [32.])
|
||||||
|
|
||||||
@ -212,7 +212,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(accumulator_count, 1)
|
self.assertEqual(accumulator_count, 1)
|
||||||
|
|
||||||
grad = gradients_impl.gradients(ret[0], x)
|
grad = gradients_impl.gradients(ret[0], x)
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(ret[0]), 16.)
|
self.assertEqual(sess.run(ret[0]), 16.)
|
||||||
self.assertSequenceEqual(sess.run(grad), [32.])
|
self.assertSequenceEqual(sess.run(grad), [32.])
|
||||||
|
|
||||||
|
@ -3673,7 +3673,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
|
|||||||
# Note: There are multiple versions of non_max_suppression v2, v3, v4.
|
# Note: There are multiple versions of non_max_suppression v2, v3, v4.
|
||||||
# gen_image_ops.non_max_suppression_v2:
|
# gen_image_ops.non_max_suppression_v2:
|
||||||
for dtype in [np.float16, np.float32]:
|
for dtype in [np.float16, np.float32]:
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
boxes = constant_op.constant(boxes_np, dtype=dtype)
|
boxes = constant_op.constant(boxes_np, dtype=dtype)
|
||||||
scores = constant_op.constant(scores_np, dtype=dtype)
|
scores = constant_op.constant(scores_np, dtype=dtype)
|
||||||
max_output_size = constant_op.constant(max_output_size_np)
|
max_output_size = constant_op.constant(max_output_size_np)
|
||||||
@ -3683,7 +3683,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(selected_indices, [3, 0, 5])
|
self.assertAllClose(selected_indices, [3, 0, 5])
|
||||||
# image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
|
# image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
|
||||||
for dtype in [np.float16, np.float32]:
|
for dtype in [np.float16, np.float32]:
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
boxes = constant_op.constant(boxes_np, dtype=dtype)
|
boxes = constant_op.constant(boxes_np, dtype=dtype)
|
||||||
scores = constant_op.constant(scores_np, dtype=dtype)
|
scores = constant_op.constant(scores_np, dtype=dtype)
|
||||||
max_output_size = constant_op.constant(max_output_size_np)
|
max_output_size = constant_op.constant(max_output_size_np)
|
||||||
@ -3694,7 +3694,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
|
|||||||
# gen_image_ops.non_max_suppression_v4.
|
# gen_image_ops.non_max_suppression_v4.
|
||||||
score_threshold = float('-inf')
|
score_threshold = float('-inf')
|
||||||
for dtype in [np.float16, np.float32]:
|
for dtype in [np.float16, np.float32]:
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
boxes = constant_op.constant(boxes_np, dtype=dtype)
|
boxes = constant_op.constant(boxes_np, dtype=dtype)
|
||||||
scores = constant_op.constant(scores_np, dtype=dtype)
|
scores = constant_op.constant(scores_np, dtype=dtype)
|
||||||
max_output_size = constant_op.constant(max_output_size_np)
|
max_output_size = constant_op.constant(max_output_size_np)
|
||||||
|
@ -218,7 +218,7 @@ class FtrlOptimizerTest(test.TestCase):
|
|||||||
def testFtrlWithL1_L2_L2ShrinkageSparse(self):
|
def testFtrlWithL1_L2_L2ShrinkageSparse(self):
|
||||||
"""Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
|
"""Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
|
||||||
for dtype in [dtypes.half, dtypes.float32]:
|
for dtype in [dtypes.half, dtypes.float32]:
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
|
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
|
||||||
var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
|
var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
|
||||||
grads0 = ops.IndexedSlices(
|
grads0 = ops.IndexedSlices(
|
||||||
@ -252,7 +252,7 @@ class FtrlOptimizerTest(test.TestCase):
|
|||||||
def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
|
def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
|
||||||
"""Verifies that l2 shrinkage in FTRL does not change lr schedule."""
|
"""Verifies that l2 shrinkage in FTRL does not change lr schedule."""
|
||||||
for dtype in [dtypes.half, dtypes.float32]:
|
for dtype in [dtypes.half, dtypes.float32]:
|
||||||
with self.test_session() as sess:
|
with self.cached_session() as sess:
|
||||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = variables.Variable([1.0, 2.0], dtype=dtype)
|
var1 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||||
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
|
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
|
||||||
|
@ -62,7 +62,7 @@ class LRDecayTestV2(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
|
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
|
||||||
|
|
||||||
def testVariables(self):
|
def testVariables(self):
|
||||||
with self.test_session():
|
with self.cached_session():
|
||||||
step = variables.Variable(1)
|
step = variables.Variable(1)
|
||||||
assign_1 = step.assign(1)
|
assign_1 = step.assign(1)
|
||||||
assign_2 = step.assign(2)
|
assign_2 = step.assign(2)
|
||||||
|
Loading…
Reference in New Issue
Block a user