Automated rollback of commit 1fdd7c7408

PiperOrigin-RevId: 222434204
This commit is contained in:
Alexandre Passos 2018-11-21 11:09:47 -08:00 committed by TensorFlower Gardener
parent c16394423c
commit f6ce9fd485
216 changed files with 1980 additions and 2004 deletions

View File

@ -61,7 +61,7 @@ class CategoricalTest(xla_test.XLATestCase):
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32)
d = self.evaluate(op)
d = sess.run(op)
batch_size, num_classes = logits.shape
freqs_mat = []
@ -86,9 +86,9 @@ class CategoricalTest(xla_test.XLATestCase):
# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
y = self.evaluate(x)
z = self.evaluate(x)
w = self.evaluate(x)
y = sess.run(x)
z = sess.run(x)
w = sess.run(x)
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
@ -113,7 +113,7 @@ class CategoricalTest(xla_test.XLATestCase):
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
output_dtype=output_dtype)
y = self.evaluate(x)
y = sess.run(x)
self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000)

View File

@ -337,7 +337,7 @@ class ConcatOffsetTest(xla_test.XLATestCase):
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
ans = self.evaluate(off)
ans = sess.run(off)
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
@ -350,7 +350,7 @@ class PackTest(xla_test.XLATestCase):
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = self.evaluate(packed)
ans = sess.run(packed)
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
def testScalars(self):
@ -360,7 +360,7 @@ class PackTest(xla_test.XLATestCase):
s1 = constant_op.constant(3, dtypes.int32)
s2 = constant_op.constant(5, dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = self.evaluate(packed)
ans = sess.run(packed)
self.assertAllEqual(ans, [2, 3, 5])
def testEmpty(self):
@ -370,7 +370,7 @@ class PackTest(xla_test.XLATestCase):
s1 = constant_op.constant([[]], dtypes.int32)
s2 = constant_op.constant([[]], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = self.evaluate(packed)
ans = sess.run(packed)
self.assertAllEqual(ans, [[[]], [[]], [[]]])

View File

@ -106,7 +106,7 @@ class EagerTest(xla_test.XLATestCase):
three = constant_op.constant(3)
five = constant_op.constant(5)
product = three * five
self.assertAllEqual(15, self.evaluate(product))
self.assertAllEqual(15, sess.run(product))
def testDegenerateSlices(self):
with self.test_scope():

View File

@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_f = Foo(a, b)
result = self.evaluate(call_f)
result = sess.run(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
def testNestedFunctions(self):
@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_g = Foo(a, b)
result = self.evaluate(call_g)
result = sess.run(call_g)
self.assertAllClose(result, expected, rtol=1e-3)
def testFunctionMultipleRetvals(self):
@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_f = Foo(a, b)
result = self.evaluate(call_f)
result = sess.run(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
def testCompileTimeConstantsInDefun(self):

View File

@ -88,7 +88,7 @@ class LSTMTest(test.TestCase):
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
# Initialize variables and run the unrolled LSTM step.
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
return sess.run([m, c])
def testLSTMCell(self):
@ -173,7 +173,7 @@ class LSTMTest(test.TestCase):
(basename, m_init_scalar, c_init_scalar, pad_scalar))
# Initialize variables and run the unrolled LSTM layer.
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
return sess.run(out_seq)
def testLSTMLayer(self):

View File

@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase):
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
sess.run(variables.variables_initializer([v]))
self.assertEqual(8.0, self.evaluate(out))
self.assertEqual(8.0, sess.run(out))
def test_placeholder_with_default_fed(self):
with self.cached_session() as sess, self.test_scope():

View File

@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase):
# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
y = self.evaluate(x)
z = self.evaluate(x)
w = self.evaluate(x)
y = sess.run(x)
z = sess.run(x)
w = sess.run(x)
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = random_ops.random_uniform(
shape=[1000], dtype=dtype, minval=-2, maxval=33)
y = self.evaluate(x)
y = sess.run(x)
self.assertTrue((y >= -2).sum() == 1000)
self.assertTrue((y < 33).sum() == 1000)
@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = self.evaluate(x)
y = sess.run(x)
def normal_cdf(x):
return .5 * math.erfc(-x / math.sqrt(2))
@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
result = self.evaluate(shuffle)
result = sess.run(shuffle)
expected = range(1 << 16)
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = array_ops.diag(math_ops.range(20))
shuffle = random_ops.random_shuffle(x)
result = self.evaluate(shuffle)
result = sess.run(shuffle)
expected = np.diag(range(20)).flatten()
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.

View File

@ -505,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase):
[-0.5, 1.5], # read(0) gradient
[20.0, 30.0, 40.0, 50.0], # concat gradient
])
grad_vals = self.evaluate(grad_r) # 2 + 2 entries
grad_vals = sess.run(grad_r) # 2 + 2 entries
self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])

View File

@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(read), [[3], [7]])
self.assertAllEqual(sess.run(read), [[3], [7]])
def testScatterSub(self):
with self.test_session() as sess, self.test_scope():
@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_sub(
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(read), [[4], [-1]])
self.assertAllEqual(sess.run(read), [[4], [-1]])
def testScatterMul(self):
with self.test_session() as sess, self.test_scope():
@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_mul(
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[5]])
self.assertEqual(sess.run(read), [[5]])
def testScatterDiv(self):
with self.test_session() as sess, self.test_scope():
@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_div(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(read), [[2]])
self.assertAllEqual(sess.run(read), [[2]])
def testScatterMin(self):
with self.test_session() as sess, self.test_scope():
@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_min(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
self.assertEqual(sess.run(read), [[3]])
def testScatterMax(self):
with self.test_session() as sess, self.test_scope():
@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_max(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[6]])
self.assertEqual(sess.run(read), [[6]])
def testScatterUpdate(self):
with self.test_session() as sess, self.test_scope():
@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_update(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
self.assertEqual(sess.run(read), [[3]])
def testScatterAddScalar(self):
with self.test_session() as sess, self.test_scope():
@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
self.assertEqual(sess.run(read), [[3]])
def testScatterSubScalar(self):
with self.test_session() as sess, self.test_scope():
@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_sub(
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[-1]])
self.assertEqual(sess.run(read), [[-1]])
def testScatterMulScalar(self):
with self.test_session() as sess, self.test_scope():
@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_mul(
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[5]])
self.assertEqual(sess.run(read), [[5]])
def testScatterDivScalar(self):
with self.test_session() as sess, self.test_scope():
@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_div(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[2]])
self.assertEqual(sess.run(read), [[2]])
def testScatterMinScalar(self):
with self.test_session() as sess, self.test_scope():
@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_min(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
self.assertEqual(sess.run(read), [[3]])
def testScatterMaxScalar(self):
with self.test_session() as sess, self.test_scope():
@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_max(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[6]])
self.assertEqual(sess.run(read), [[6]])
def testScatterNdAddOps(self):
with self.test_session() as sess, self.test_scope():
@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase):
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.float32)
self.assertAllClose(expected, self.evaluate(read))
self.assertAllClose(expected, sess.run(read))
def testScatterNdUpdateAddOps(self):
with self.test_session() as sess, self.test_scope():
@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase):
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.float32)
self.assertAllClose(expected, self.evaluate(read))
self.assertAllClose(expected, sess.run(read))
class StridedSliceAssignChecker(object):

View File

@ -96,7 +96,7 @@ class KerasTest(tf.test.TestCase):
sess.run(init)
sample_input = tf.random_uniform((1, 10, 10, 1))
output = model(sample_input) # pylint: disable=not-callable
self.assertEqual(self.evaluate(output).shape, (1, 3))
self.assertEqual(sess.run(output).shape, (1, 3))
if __name__ == '__main__':

View File

@ -34,7 +34,7 @@ class ListLiteralsTest(tf.test.TestCase):
result = converted()
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(result), [1, 2, 3])
self.assertAllEqual(sess.run(result), [1, 2, 3])
if __name__ == '__main__':

View File

@ -35,7 +35,7 @@ class InputDataTest(test.TestCase):
with self.cached_session() as sess:
sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = self.evaluate(wav_encoder)
wav_data = sess.run(wav_encoder)
return wav_data
def _saveTestWavFile(self, filename, wav_data):

View File

@ -33,7 +33,7 @@ class LabelWavTest(test.TestCase):
with self.cached_session() as sess:
sample_data = tf.zeros([1000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = self.evaluate(wav_encoder)
wav_data = sess.run(wav_encoder)
return wav_data
def _saveTestWavFile(self, filename, wav_data):

View File

@ -33,7 +33,7 @@ class WavToFeaturesTest(test.TestCase):
with self.cached_session() as sess:
sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = self.evaluate(wav_encoder)
wav_data = sess.run(wav_encoder)
return wav_data
def _saveTestWavFile(self, filename, wav_data):

View File

@ -113,7 +113,7 @@ class CallTreesTest(converter_testing.TestCase):
with self.compiled(node, ns) as result:
with self.cached_session() as sess:
result_tensor = result.test_fn(constant_op.constant(1))
self.assertEquals(self.evaluate(result_tensor), 3)
self.assertEquals(sess.run(result_tensor), 3)
def test_call_to_decorated_function(self):

View File

@ -68,7 +68,7 @@ class ListTest(converter_testing.TestCase):
with self.cached_session() as sess:
tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
self.assertAllEqual(sess.run(r), [1, 2, 3])
def test_list_pop(self):
@ -91,8 +91,8 @@ class ListTest(converter_testing.TestCase):
with self.cached_session() as sess:
ts, tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(self.evaluate(r), [1, 2])
self.assertAllEqual(self.evaluate(ts), 3)
self.assertAllEqual(sess.run(r), [1, 2])
self.assertAllEqual(sess.run(ts), 3)
def test_double_list_pop(self):

View File

@ -48,12 +48,12 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
sess.run(v.initializer)
sess.run(result.test_fn(v))
# TODO(mdan): Add support for this use case.
# Right now the variable `a` is not conditioned on the `assign` because
# there's no way to add control dependencies to a variable object.
self.assertEqual(2, self.evaluate(v))
self.assertEqual(2, sess.run(v))
def test_side_effect_on_used_variable(self):
@ -69,11 +69,11 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
sess.run(v.initializer)
sess.run(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
# Right now it's 3 or 4 based on whether the read is synchronized.
self.assertEqual(3, self.evaluate(v))
self.assertEqual(3, sess.run(v))
def test_side_effect_on_tensor(self):
@ -109,10 +109,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign_add) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
sess.run(v.initializer)
sess.run(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(4, self.evaluate(v))
self.assertEqual(4, sess.run(v))
def test_multiline_nested_block(self):
@ -130,10 +130,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
sess.run(v.initializer)
sess.run(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(3, self.evaluate(v))
self.assertEqual(3, sess.run(v))
def test_multiline_block_unsafe(self):
@ -153,10 +153,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
state_ops.assign_add) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
sess.run(v.initializer)
sess.run(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(4, self.evaluate(v))
self.assertEqual(4, sess.run(v))
if __name__ == '__main__':

View File

@ -49,7 +49,7 @@ class SliceTest(converter_testing.TestCase):
tl = list_ops.tensor_list_from_tensor(
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
y = result.test_fn(tl)
self.assertEqual(2, self.evaluate(y))
self.assertEqual(2, sess.run(y))
def test_index_access_multiple_definitions(self):

View File

@ -63,7 +63,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_decorator_does_not_recurse(self):
@ -83,7 +83,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_decorator_calls_unconverted_graph(self):
@ -104,7 +104,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_decorator_calls_unconverted_py_func(self):
@ -130,7 +130,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_decorator_calls_decorated(self):
@ -153,7 +153,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_decorator_preserves_argspec(self):
@ -192,7 +192,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_converted_call_builtin(self):
x = api.converted_call(range, None, converter.ConversionOptions(), 3)
@ -208,7 +208,7 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
x = api.converted_call(test_fn, None, converter.ConversionOptions(),
constant_op.constant(-1))
self.assertEqual(1, self.evaluate(x))
self.assertEqual(1, sess.run(x))
def test_converted_call_method_explicit_owner(self):
# TODO(mdan): Implement.
@ -234,7 +234,7 @@ class ApiTest(test.TestCase):
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc.test_method, None,
converter.ConversionOptions(), tc)
self.assertEqual(1, self.evaluate(x))
self.assertEqual(1, sess.run(x))
def test_converted_call_method_by_class(self):
@ -252,7 +252,7 @@ class ApiTest(test.TestCase):
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(TestClass.test_method, None,
converter.ConversionOptions(), tc)
self.assertEqual(1, self.evaluate(x))
self.assertEqual(1, sess.run(x))
def test_converted_call_callable_object(self):
@ -269,7 +269,7 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc, None, converter.ConversionOptions())
self.assertEqual(1, self.evaluate(x))
self.assertEqual(1, sess.run(x))
def test_converted_call_constructor(self):
@ -288,7 +288,7 @@ class ApiTest(test.TestCase):
constant_op.constant(-1))
# tc is now a converted object.
x = tc.test_method()
self.assertEqual(1, self.evaluate(x))
self.assertEqual(1, sess.run(x))
def test_converted_call_already_converted(self):
@ -298,12 +298,12 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
x = api.converted_call(f, None, converter.ConversionOptions(),
constant_op.constant(0))
self.assertTrue(self.evaluate(x))
self.assertTrue(sess.run(x))
converted_f = api.to_graph(f)
x = api.converted_call(converted_f, None, converter.ConversionOptions(),
constant_op.constant(0))
self.assertTrue(self.evaluate(x))
self.assertTrue(sess.run(x))
def test_converted_call_no_user_code(self):
@ -334,8 +334,8 @@ class ApiTest(test.TestCase):
constant_op.constant([[0.0]]), training=True)
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
sess.run(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
def test_converted_call_whitelisted_method_extra_self(self):
@ -349,8 +349,8 @@ class ApiTest(test.TestCase):
model, constant_op.constant([[0.0]]), training=True)
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
sess.run(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
def test_converted_call_whitelisted_method_via_owner(self):
@ -364,8 +364,8 @@ class ApiTest(test.TestCase):
constant_op.constant([[0.0]]), training=True)
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
sess.run(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
def test_converted_call_lambda(self):
@ -376,8 +376,8 @@ class ApiTest(test.TestCase):
x = api.converted_call(l, None, opts, constant_op.constant(0))
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(True, self.evaluate(x))
sess.run(variables.global_variables_initializer())
self.assertAllEqual(True, sess.run(x))
def test_to_graph_basic(self):
@ -390,7 +390,7 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
x = compiled_fn(constant_op.constant([4, 8]), 4)
self.assertListEqual([1, 2], self.evaluate(x).tolist())
self.assertListEqual([1, 2], sess.run(x).tolist())
def test_to_graph_with_defaults(self):
@ -405,7 +405,7 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
x = compiled_fn(constant_op.constant([4, 8]))
self.assertListEqual([1, 2], self.evaluate(x).tolist())
self.assertListEqual([1, 2], sess.run(x).tolist())
def test_to_code_basic(self):

View File

@ -36,7 +36,7 @@ class SpecialFunctionsTest(test.TestCase):
python_one = special_functions.match_staging_level(1, 1)
with self.cached_session() as sess:
self.assertTrue(tensor_util.is_tensor(tensor_one))
self.assertAllEqual(self.evaluate(tensor_one), 1)
self.assertAllEqual(sess.run(tensor_one), 1)
self.assertEqual(python_one, 1)
def test_tensor_list_empty_list(self):
@ -45,21 +45,21 @@ class SpecialFunctionsTest(test.TestCase):
element_shape=())
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(sl), [])
self.assertAllEqual(sess.run(sl), [])
l = special_functions.tensor_list((),
element_dtype=dtypes.int32,
element_shape=())
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(sl), [])
self.assertAllEqual(sess.run(sl), [])
def test_tensor_list_tensor(self):
l = special_functions.tensor_list(
constant_op.constant([], dtype=dtypes.int32))
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(sl), [])
self.assertAllEqual(sess.run(sl), [])
def test_tensor_list_unsupported_initializer(self):
with self.assertRaisesRegexp(ValueError, 'unknown type'):
@ -76,7 +76,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
@ -84,7 +84,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_stack(self):
self.assertEqual(special_functions.stack(1, strict=False), 1)

View File

@ -35,7 +35,7 @@ class ForLoopTest(test.TestCase):
body=lambda i, s: (s + i,),
init_state=(0,))
with self.cached_session() as sess:
self.assertEqual((10,), self.evaluate(s))
self.assertEqual((10,), sess.run(s))
def test_python(self):
s = control_flow.for_stmt(
@ -53,7 +53,7 @@ class ForLoopTest(test.TestCase):
body=lambda i, s: (s + i,),
init_state=(0,))
with self.cached_session() as sess:
self.assertEqual((10,), self.evaluate(s))
self.assertEqual((10,), sess.run(s))
class WhileLoopTest(test.TestCase):
@ -66,7 +66,7 @@ class WhileLoopTest(test.TestCase):
init_state=(0, 0),
extra_deps=(n,))
with self.cached_session() as sess:
self.assertEqual((5, 10), self.evaluate(results))
self.assertEqual((5, 10), sess.run(results))
def test_python(self):
n = 5
@ -90,9 +90,9 @@ class IfStmtTest(test.TestCase):
def test_tensor(self):
with self.cached_session() as sess:
t = self.single_return_if_stmt(constant_op.constant(True))
self.assertEqual(1, self.evaluate(t))
self.assertEqual(1, sess.run(t))
t = self.single_return_if_stmt(constant_op.constant(False))
self.assertEqual(-1, self.evaluate(t))
self.assertEqual(-1, sess.run(t))
def test_python(self):
self.assertEqual(1, self.single_return_if_stmt(True))
@ -101,9 +101,9 @@ class IfStmtTest(test.TestCase):
def test_tensor_multiple_returns(self):
with self.cached_session() as sess:
t = self.multi_return_if_stmt(constant_op.constant(True))
self.assertAllEqual([1, 2], self.evaluate(t))
self.assertAllEqual([1, 2], sess.run(t))
t = self.multi_return_if_stmt(constant_op.constant(False))
self.assertAllEqual([-1, -2], self.evaluate(t))
self.assertAllEqual([-1, -2], sess.run(t))
def test_python_multiple_returns(self):
self.assertEqual((1, 2), self.multi_return_if_stmt(True))

View File

@ -43,7 +43,7 @@ class ListTest(test.TestCase):
l = data_structures.tf_tensor_list_new([3, 4, 5])
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
self.assertAllEqual(sess.run(t), [3, 4, 5])
def test_tf_tensor_list_new_empty(self):
l = data_structures.tf_tensor_list_new([],
@ -51,13 +51,13 @@ class ListTest(test.TestCase):
element_shape=())
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(t), [])
self.assertAllEqual(sess.run(t), [])
def test_tf_tensor_list_new_from_tensor(self):
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
self.assertAllEqual(sess.run(t), [3, 4, 5])
def test_tf_tensor_list_new_illegal_input(self):
with self.assertRaises(ValueError):
@ -77,7 +77,7 @@ class ListTest(test.TestCase):
l = data_structures.tf_tensor_array_new([3, 4, 5])
t = l.stack()
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
self.assertAllEqual(sess.run(t), [3, 4, 5])
def test_tf_tensor_array_new_illegal_input(self):
with self.assertRaises(ValueError):
@ -102,7 +102,7 @@ class ListTest(test.TestCase):
t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
self.assertAllEqual(sess.run(t), [[1, 2, 3]])
def test_append_tensorarray(self):
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
@ -131,10 +131,10 @@ class ListTest(test.TestCase):
with self.cached_session() as sess:
l, x = data_structures.list_pop(l, None, opts)
self.assertAllEqual(self.evaluate(x), [3, 4])
self.assertAllEqual(sess.run(x), [3, 4])
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
self.assertAllEqual(self.evaluate(t), [[1, 2]])
self.assertAllEqual(sess.run(t), [[1, 2]])
def test_pop_python(self):
l = [1, 2, 3]
@ -152,7 +152,7 @@ class ListTest(test.TestCase):
with self.cached_session() as sess:
t = data_structures.list_stack(l, opts)
self.assertAllEqual(sess.run(t), self.evaluate(initial_list))
self.assertAllEqual(sess.run(t), sess.run(initial_list))
def test_stack_tensor_list_empty(self):
l = list_ops.empty_tensor_list(

View File

@ -45,11 +45,11 @@ class LogicalOperatorsTest(test.TestCase):
def test_and_tf(self):
with self.cached_session() as sess:
t = logical.and_(self._tf_true, self._tf_true)
self.assertEqual(self.evaluate(t), True)
self.assertEqual(sess.run(t), True)
t = logical.and_(self._tf_true, lambda: True)
self.assertEqual(self.evaluate(t), True)
self.assertEqual(sess.run(t), True)
t = logical.and_(self._tf_false, lambda: True)
self.assertEqual(self.evaluate(t), False)
self.assertEqual(sess.run(t), False)
# TODO(mdan): Add a test for ops with side effects.
def test_or_python(self):
@ -63,11 +63,11 @@ class LogicalOperatorsTest(test.TestCase):
def test_or_tf(self):
with self.cached_session() as sess:
t = logical.or_(self._tf_false, self._tf_true)
self.assertEqual(self.evaluate(t), True)
self.assertEqual(sess.run(t), True)
t = logical.or_(self._tf_false, lambda: True)
self.assertEqual(self.evaluate(t), True)
self.assertEqual(sess.run(t), True)
t = logical.or_(self._tf_true, lambda: True)
self.assertEqual(self.evaluate(t), True)
self.assertEqual(sess.run(t), True)
# TODO(mdan): Add a test for ops with side effects.
def test_not_python(self):
@ -78,7 +78,7 @@ class LogicalOperatorsTest(test.TestCase):
def test_not_tf(self):
with self.cached_session() as sess:
t = logical.not_(self._tf_false())
self.assertEqual(self.evaluate(t), True)
self.assertEqual(sess.run(t), True)
if __name__ == '__main__':

View File

@ -38,29 +38,29 @@ class PyBuiltinsTest(test.TestCase):
self.assertEqual(py_builtins.abs_(-1), 1)
with self.cached_session() as sess:
t = py_builtins.abs_(constant_op.constant(-1))
self.assertEqual(self.evaluate(t), 1)
self.assertEqual(sess.run(t), 1)
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
self.assertAllEqual(self.evaluate(t), [1, 2, 3])
self.assertAllEqual(sess.run(t), [1, 2, 3])
def test_float(self):
self.assertEqual(py_builtins.float_(10), 10.0)
self.assertEqual(py_builtins.float_('10.0'), 10.0)
with self.cached_session() as sess:
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
self.assertEqual(self.evaluate(t), 1.0)
self.assertEqual(sess.run(t), 1.0)
st = py_builtins.float_(constant_op.constant('1.0'))
self.assertEqual(self.evaluate(st), 1.0)
self.assertEqual(sess.run(st), 1.0)
def test_int(self):
self.assertEqual(py_builtins.int_(10.0), 10)
self.assertEqual(py_builtins.int_('11', 2), 3)
with self.cached_session() as sess:
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
self.assertEqual(self.evaluate(t), 1)
self.assertEqual(sess.run(t), 1)
st = py_builtins.int_(constant_op.constant('1'))
self.assertEqual(self.evaluate(st), 1)
self.assertEqual(sess.run(st), 1)
st = py_builtins.int_(constant_op.constant('1'), 10)
self.assertEqual(self.evaluate(st), 1)
self.assertEqual(sess.run(st), 1)
def test_int_unsupported_base(self):
t = constant_op.constant(1, dtype=dtypes.float64)
@ -73,9 +73,9 @@ class PyBuiltinsTest(test.TestCase):
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
self.assertEqual(t, 3)
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
self.assertEqual(self.evaluate(ta), 5)
self.assertEqual(sess.run(ta), 5)
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
self.assertEqual(self.evaluate(tl), 3)
self.assertEqual(sess.run(tl), 3)
def test_len_scalar(self):
with self.assertRaises(ValueError):
@ -120,18 +120,18 @@ class PyBuiltinsTest(test.TestCase):
def test_range_tensor(self):
with self.cached_session() as sess:
r = py_builtins.range_(constant_op.constant(3))
self.assertAllEqual(self.evaluate(r), [0, 1, 2])
self.assertAllEqual(sess.run(r), [0, 1, 2])
r = py_builtins.range_(1, constant_op.constant(3))
self.assertAllEqual(self.evaluate(r), [1, 2])
self.assertAllEqual(sess.run(r), [1, 2])
r = py_builtins.range_(2, 0, constant_op.constant(-1))
self.assertAllEqual(self.evaluate(r), [2, 1])
self.assertAllEqual(sess.run(r), [2, 1])
def test_range_tensor_empty_range(self):
with self.session() as sess:
r = py_builtins.range_(constant_op.constant(-3))
self.assertAllEqual(self.evaluate(r), [])
self.assertAllEqual(sess.run(r), [])
r = py_builtins.range_(5, constant_op.constant(2))
self.assertAllEqual(self.evaluate(r), [])
self.assertAllEqual(sess.run(r), [])
if __name__ == '__main__':

View File

@ -34,7 +34,7 @@ class SlicesTest(test.TestCase):
with self.cached_session() as sess:
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
def test_get_item_tensor_list(self):
initial_list = constant_op.constant([[1, 2], [3, 4]])
@ -44,7 +44,7 @@ class SlicesTest(test.TestCase):
l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(t), [3, 4])
self.assertAllEqual(sess.run(t), [3, 4])
def test_get_item_tensor_string(self):
initial_str = constant_op.constant('abcd')
@ -52,14 +52,14 @@ class SlicesTest(test.TestCase):
slices.GetItemOpts(element_dtype=initial_str.dtype))
with self.cached_session() as sess:
self.assertEqual(self.evaluate(t), b'b')
self.assertEqual(sess.run(t), b'b')
initial_list_str = constant_op.constant(['abcd', 'bcde'])
t = slices.get_item(initial_list_str, 1,
slices.GetItemOpts(element_dtype=initial_str.dtype))
with self.cached_session() as sess:
self.assertEqual(self.evaluate(t), b'bcde')
self.assertEqual(sess.run(t), b'bcde')
if __name__ == '__main__':

View File

@ -32,7 +32,7 @@ class MiscTest(test.TestCase):
new_a = alias_tensors(a)
self.assertFalse(new_a is a)
with self.cached_session() as sess:
self.assertEqual(1, self.evaluate(new_a))
self.assertEqual(1, sess.run(new_a))
def test_alias_tensors(self):
a = constant(1)
@ -47,7 +47,7 @@ class MiscTest(test.TestCase):
self.assertTrue(new_s is s)
self.assertTrue(new_l is l)
with self.cached_session() as sess:
self.assertEqual(1, self.evaluate(new_a))
self.assertEqual(1, sess.run(new_a))
if __name__ == '__main__':

View File

@ -34,13 +34,13 @@ class PyFuncTest(test.TestCase):
with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(1, constant_op.constant(1), 1))
self.assertEqual(3, self.evaluate(result))
self.assertEqual(3, sess.run(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
self.assertEqual(3, self.evaluate(result))
self.assertEqual(3, sess.run(result))
result = py_func.wrap_py_func(
test_fn, dtypes.int64,
(constant_op.constant(1), 1, constant_op.constant(1)))
self.assertEqual(3, self.evaluate(result))
self.assertEqual(3, sess.run(result))
def test_wrap_py_func_complex_args(self):
@ -54,10 +54,10 @@ class PyFuncTest(test.TestCase):
with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
self.assertEqual(35, self.evaluate(result))
self.assertEqual(35, sess.run(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(constant_op.constant(7), TestClass()))
self.assertEqual(35, self.evaluate(result))
self.assertEqual(35, sess.run(result))
def test_wrap_py_func_kwargs(self):
@ -74,13 +74,13 @@ class PyFuncTest(test.TestCase):
'c': 11,
'd': TestClass(13)
})
self.assertEqual(178, self.evaluate(result))
self.assertEqual(178, sess.run(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(constant_op.constant(7), TestClass(5)), {
'c': constant_op.constant(11),
'd': TestClass(13)
})
self.assertEqual(178, self.evaluate(result))
self.assertEqual(178, sess.run(result))
def test_wrap_py_func_dummy_return(self):
@ -91,11 +91,11 @@ class PyFuncTest(test.TestCase):
with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
self.assertEqual(1, self.evaluate(result))
self.assertEqual(1, sess.run(result))
self.assertEqual([1], side_counter)
result = py_func.wrap_py_func(
test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
self.assertEqual(1, self.evaluate(result))
self.assertEqual(1, sess.run(result))
self.assertEqual([2], side_counter)

View File

@ -43,13 +43,13 @@ class TensorListTest(test.TestCase):
l = tl.dynamic_list_append(l, 1)
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(s), [1])
self.assertAllEqual(sess.run(s), [1])
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l = tl.dynamic_list_append(l, 1)
s = l.stack()
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(s), [1])
self.assertAllEqual(sess.run(s), [1])
l = tl.TensorList(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)

View File

@ -62,7 +62,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
const = constant_op.constant(17)
sess = session.Session(server1.target, config=config)
output = self.evaluate(const)
output = sess.run(const)
self.assertEqual(17, output)
def testClusterSpecPropagationWorker2Placement(self):
@ -106,7 +106,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
const = constant_op.constant(17)
sess = session.Session(server1.target, config=config, graph=g)
output = self.evaluate(const)
output = sess.run(const)
self.assertEqual(17, output)
def testCanonicalDeviceNames(self):
@ -208,7 +208,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
with ops.device('/job:worker/task:0/cpu:0'):
sum3 = sum1 + sum2
sess = session.Session(server1.target, config=config, graph=g)
output = self.evaluate(sum3)
output = sess.run(sum3)
self.assertEqual(40, output)
def testLegacyDeviceNames(self):

View File

@ -147,7 +147,7 @@ class TimelineTest(test.TestCase):
num2 = variables.Variable(2.0, name='num2')
with ops.device('/cpu:2'):
result = num1 + num2 + num1 * num2
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
sess.run(result, options=run_options, run_metadata=run_metadata)
self.assertTrue(run_metadata.HasField('step_stats'))
@ -176,7 +176,7 @@ class TimelineTest(test.TestCase):
num2 = variables.Variable(2.0, name='num2')
with ops.device('/cpu:2'):
result = num1 + num2 + num1 * num2
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
sess.run(result, options=run_options, run_metadata=run_metadata)
self.assertTrue(run_metadata.HasField('step_stats'))
step_stats = run_metadata.step_stats

View File

@ -216,7 +216,7 @@ class VirtualGpuTest(test_util.TensorFlowTestCase):
for d in self._util.devices:
with ops.device(d):
var = variables.Variable(random_ops.random_uniform(mat_shape))
self.evaluate(var.initializer)
sess.run(var.initializer)
data.append(var)
s = data[0]
for i in range(1, len(data)):

View File

@ -53,10 +53,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for start in range(0, len(components), 4):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual([[i, j]
for i, c in enumerate(components[start:start + 4])
for j in range(c)], results.indices)
@ -81,10 +81,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for start in range(0, len(components), 4):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual([[i, j, z]
for i, c in enumerate(components[start:start + 4])
for j in range(c)
@ -141,7 +141,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
for i in range(4):
self.assertEqual(i, self.evaluate(next_elem))
self.assertEqual(i, sess.run(next_elem))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_elem)
@ -159,7 +159,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, self.evaluate(op))
self.assertEqual((i,) * 3, sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
@ -179,7 +179,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
@ -198,7 +198,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
st_row = self.evaluate(next_element)
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
@ -219,7 +219,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = self.evaluate(next_element)
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
@ -241,7 +241,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, self.evaluate(op))
self.assertEqual(((i,),) * 3, sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
@ -354,7 +354,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
num_batches = (28 * 7) // 14
for i in range(num_batches):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
@ -369,12 +369,12 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
# We expect (num_batches - 1) full-sized batches.
num_batches = int(math.ceil((14 * 7) / 8))
for i in range(num_batches - 1):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
@ -408,10 +408,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -423,9 +423,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -439,7 +439,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(5):
got = self.evaluate(elements)
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
@ -459,7 +459,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(4):
got = self.evaluate(elements)
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
@ -480,9 +480,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(2):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
@ -524,7 +524,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
sess.run(get_next)
@ -576,8 +576,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)],
self.evaluate(get_next))
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
self.assertAllEqual(
[threshold // 10 * 10 + j for j in range(threshold % 10)],
@ -610,8 +609,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)],
self.evaluate(get_next))
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
class UnbatchDatasetBenchmark(test.Benchmark):

View File

@ -300,7 +300,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
while True:
output = self.evaluate(batch)
output = sess.run(batch)
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
tuple(output.values))
all_sparse_tensors.add(sprs_tensor)

View File

@ -57,7 +57,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -82,7 +82,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -108,7 +108,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -134,7 +134,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -160,7 +160,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, self.evaluate(next_element))
self.assertEqual({"a": i}, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -186,7 +186,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, self.evaluate(next_element))
self.assertEqual({"a": i}, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -217,7 +217,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = self.evaluate(next_element)
actual = sess.run(next_element)
self.assertAllEqual([i], actual.values)
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
@ -251,7 +251,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = self.evaluate(next_element)
actual = sess.run(next_element)
self.assertAllEqual([i], actual.values)
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
@ -271,9 +271,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -290,9 +290,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -323,9 +323,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(10):
x, y, z = self.evaluate(next_element)
x, y, z = sess.run(next_element)
self.assertEqual(i**2, x)
self.assertEqual(float(i**2), y)
self.assertEqual(util_compat.as_bytes(str(i)), z)
@ -345,8 +345,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -363,8 +363,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -381,8 +381,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -399,8 +399,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -420,9 +420,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -447,12 +447,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -477,12 +477,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -499,12 +499,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -521,12 +521,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -553,7 +553,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(3):
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
@ -562,7 +562,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
for _ in range(2):
self.assertFalse(self.evaluate(elem_has_value_t))
self.assertFalse(sess.run(elem_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_value_t)

View File

@ -38,13 +38,13 @@ class CounterTest(test_base.DatasetTestBase):
negative_get_next = negative_iterator.get_next()
with self.cached_session() as sess:
self.assertEqual(3, self.evaluate(get_next))
self.assertEqual(3 + 4, self.evaluate(get_next))
self.assertEqual(3 + 2 * 4, self.evaluate(get_next))
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))
self.assertEqual(0, self.evaluate(negative_get_next))
self.assertEqual(-1, self.evaluate(negative_get_next))
self.assertEqual(-2, self.evaluate(negative_get_next))
self.assertEqual(0, sess.run(negative_get_next))
self.assertEqual(-1, sess.run(negative_get_next))
self.assertEqual(-2, sess.run(negative_get_next))
if __name__ == "__main__":

View File

@ -41,10 +41,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for start in range(0, len(components), 4):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual([[i, j]
for i, c in enumerate(components[start:start + 4])
for j in range(c)], results.indices)
@ -69,10 +69,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for start in range(0, len(components), 4):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual([[i, j, z]
for i, c in enumerate(components[start:start + 4])
for j in range(c)

View File

@ -40,10 +40,10 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for _ in range(100):
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -107,7 +107,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i in choice_array:
self.assertEqual(words[i], self.evaluate(next_element))
self.assertEqual(words[i], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

View File

@ -44,9 +44,9 @@ class EnumerateDatasetTest(test_base.DatasetTestBase):
[t.shape for t in get_next[1]])
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), self.evaluate(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), self.evaluate(get_next))
sess.run(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -94,18 +94,18 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
device0, device1)
with self.test_session(config=worker_config) as sess:
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
self.evaluate(destroy_op)
sess.run(destroy_op)
def testSameDeviceCPU(self):
self._prefetch_fn_helper_one_shot("same_device_cpu",
@ -135,35 +135,35 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
ds, ds_iterator, "reinit", device0, device1)
with self.test_session(config=worker_config) as sess:
self.evaluate(ds_iterator.initializer)
elem = self.evaluate(prefetch_op)
sess.run(ds_iterator.initializer)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
# Lets reset the function buffering resource and reinitialize the
# iterator. Should be able to go through this again.
self._event.clear()
self.evaluate(reset_op)
self.evaluate(ds_iterator.initializer)
elem = self.evaluate(prefetch_op)
sess.run(reset_op)
sess.run(ds_iterator.initializer)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
self.evaluate(destroy_op)
sess.run(destroy_op)
def testReinitializationOutOfRange(self):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
@ -175,30 +175,30 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
ds, ds_iterator, "reinit", device0, device1)
with self.test_session(config=worker_config) as sess:
self.evaluate(ds_iterator.initializer)
sess.run(ds_iterator.initializer)
for i in range(1, 10):
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [float(i)])
# Try fetching after its over twice to test out end of sequence.
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(prefetch_op)
sess.run(prefetch_op)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(prefetch_op)
sess.run(prefetch_op)
# Now reset everything and try it out again.
self._event.clear()
self.evaluate(reset_op)
self.evaluate(ds_iterator.initializer)
sess.run(reset_op)
sess.run(ds_iterator.initializer)
for i in range(1, 10):
elem = self.evaluate(prefetch_op)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [float(i)])
# Try fetching after its over twice to test out end of sequence.
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(prefetch_op)
sess.run(prefetch_op)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(prefetch_op)
sess.run(prefetch_op)
self.evaluate(destroy_op)
sess.run(destroy_op)
def testStringsGPU(self):
if not test_util.is_gpu_available():
@ -235,13 +235,13 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
buffer_resource_handle, ignore_lookup_error=True)
with self.cached_session() as sess:
self.assertEqual([b"a"], self.evaluate(prefetch_op))
self.assertEqual([b"b"], self.evaluate(prefetch_op))
self.assertEqual([b"c"], self.evaluate(prefetch_op))
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(prefetch_op)
sess.run(prefetch_op)
self.evaluate(destroy_op)
sess.run(destroy_op)
if __name__ == "__main__":

View File

@ -39,7 +39,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
for expected in values:
got = self.evaluate(get_next)
got = sess.run(get_next)
self.assertEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -127,7 +127,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
x, y = self.evaluate(get_next)
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
with self.assertRaises(errors.OutOfRangeError):
@ -190,7 +190,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
x, y = self.evaluate(get_next)
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)

View File

@ -68,9 +68,9 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
which_bucket, bucketed_values = self.evaluate(get_next)
which_bucket, bucketed_values = sess.run(get_next)
self.assertEqual(0, which_bucket)
@ -103,11 +103,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
which_bucket_even, bucketed_values_even = self.evaluate(get_next)
which_bucket_odd, bucketed_values_odd = self.evaluate(get_next)
which_bucket_even, bucketed_values_even = sess.run(get_next)
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
# Count number of bucket_tensors.
self.assertEqual(3, len(bucketed_values_even))
@ -174,11 +174,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
which_bucket0, bucketed_values_even0 = self.evaluate(get_next)
which_bucket1, bucketed_values_even1 = self.evaluate(get_next)
which_bucket0, bucketed_values_even0 = sess.run(get_next)
which_bucket1, bucketed_values_even1 = sess.run(get_next)
# Ensure that bucket 1 was completely filtered out
self.assertAllEqual(0, which_bucket0)
@ -207,11 +207,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
while True:
result = self.evaluate(get_next)
result = sess.run(get_next)
is_even = all(x % 2 == 0 for x in result)
is_odd = all(x % 2 == 1 for x in result)
self.assertTrue(is_even or is_odd)
@ -232,11 +232,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
result = self.evaluate(get_next)
result = sess.run(get_next)
self.assertTrue(
all(x % 2 == 0
for x in result) or all(x % 2 == 1)
@ -259,16 +259,16 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
# 2. Different buckets can produce output at different rates, and
# 3. For deterministic input, the output is deterministic.
for _ in range(3):
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next))
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
@ -280,13 +280,13 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
# The small outputs at the end are deterministically produced in key
# order.
self.assertAllEqual([0, 0, 0], self.evaluate(get_next))
self.assertAllEqual([1], self.evaluate(get_next))
self.assertAllEqual([0, 0, 0], sess.run(get_next))
self.assertAllEqual([1], sess.run(get_next))
def testEmpty(self):
iterator = (
@ -297,7 +297,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Window size must be greater than zero, but got 0."):
@ -323,7 +323,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@ -351,11 +351,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
tight_result, multiple_of_10_result = self.evaluate(get_next)
tight_result, multiple_of_10_result = sess.run(get_next)
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
self.assertAllEqual(tight_result,
multiple_of_10_result[:, :tight_result.shape[1]])

View File

@ -47,9 +47,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, self.evaluate(get_next))
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -65,9 +65,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, self.evaluate(get_next))
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -93,9 +93,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
# All of the files are present.
self.evaluate(init_op)
sess.run(init_op)
for filename in filenames:
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -104,9 +104,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
# Attempting to read filenames[0] will fail, but ignore_errors()
# will catch the error.
self.evaluate(init_op)
sess.run(init_op)
for filename in filenames[1:]:
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -53,7 +53,7 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
materialized = ds.materialize()
with self.cached_session() as sess:
self.evaluate(materialized.initializer)
sess.run(materialized.initializer)
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
for i in range(16):
output = sess.run(
@ -68,9 +68,9 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
itr = ds.make_initializable_iterator()
n = itr.get_next()
with self.cached_session() as sess:
self.evaluate(itr.initializer)
sess.run(itr.initializer)
for i in range(16):
output = self.evaluate(n)
output = sess.run(n)
self.assertEqual(i, output)
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)

View File

@ -112,10 +112,10 @@ class MakeBatchedFeaturesDatasetTest(
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = self.evaluate(next_element)
actual_batch = sess.run(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
self.assertAllEqual(record_batch, actual_batch["record"])
with self.assertRaises(errors.OutOfRangeError):

View File

@ -90,7 +90,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
batch_size,
num_epochs,
):
actual_features = self.evaluate(nxt)
actual_features = sess.run(nxt)
if label_name is not None:
expected_labels = expected_features.pop(label_name)

View File

@ -105,7 +105,7 @@ class MakeTFRecordDatasetTest(
for expected_batch in self._next_expected_batch(
file_indices, batch_size, num_epochs, interleave_cycle_length,
drop_final_batch, use_parser_fn):
actual_batch = self.evaluate(outputs)
actual_batch = sess.run(outputs)
self.assertAllEqual(expected_batch, actual_batch)
def _read_test(self, batch_size, num_epochs, file_index=None,
@ -188,7 +188,7 @@ class MakeTFRecordDatasetTest(
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
first_batches = []
try:
while True:
@ -196,7 +196,7 @@ class MakeTFRecordDatasetTest(
except errors.OutOfRangeError:
pass
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
second_batches = []
try:
while True:

View File

@ -89,7 +89,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
num_batches = (28 * 7) // 14
for i in range(num_batches):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
@ -104,12 +104,12 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
# We expect (num_batches - 1) full-sized batches.
num_batches = int(math.ceil((14 * 7) / 8))
for i in range(num_batches - 1):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
@ -152,10 +152,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -177,9 +177,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -201,7 +201,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(5):
got = self.evaluate(elements)
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
@ -230,7 +230,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(4):
got = self.evaluate(elements)
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
@ -261,9 +261,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(2):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
@ -321,7 +321,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
sess.run(get_next)
@ -393,8 +393,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)],
self.evaluate(get_next))
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
self.assertAllEqual(
[threshold // 10 * 10 + j for j in range(threshold % 10)],
@ -443,8 +442,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)],
self.evaluate(get_next))
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
@parameterized.named_parameters(
("Identity", None, lambda x: x, None),
@ -464,7 +462,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
else:
expected = map_fn(
sess.run(self.structuredElement(structure, shape=[10])))
self.assertAllEqual(expected, self.evaluate(get_next))
self.assertAllEqual(expected, sess.run(get_next))
def testShortCircuitCapturedInput(self):
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
@ -475,7 +473,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={captured_t: 42})
self.assertAllEqual([42] * 10, self.evaluate(get_next))
self.assertAllEqual([42] * 10, sess.run(get_next))
@parameterized.named_parameters(
("Normal", False),

View File

@ -218,7 +218,7 @@ class MapDefunTest(test_base.DatasetTestBase):
def _assert_op_cancelled(self, sess, map_defun_op):
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
self.evaluate(map_defun_op)
sess.run(map_defun_op)
def testMapDefunWithParentCancellation(self):
# Checks that a cancellation of the parent graph is threaded through to

View File

@ -72,7 +72,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
thread_ids = []
try:
while True:

View File

@ -637,11 +637,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
for j in range(2):
expected = [i, 0] if j % 2 == 0 else [0, -i]
self.assertAllEqual(expected, self.evaluate(get_next))
self.assertAllEqual(expected, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -796,7 +796,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for _ in range(2):
elements = []
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
try:
while True:
elements.extend(sess.run(next_element))

View File

@ -57,7 +57,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -87,7 +87,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -117,7 +117,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, self.evaluate(next_element))
self.assertEqual({"a": i}, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -150,7 +150,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = self.evaluate(next_element)
actual = sess.run(next_element)
self.assertAllEqual([i], actual.values)
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
@ -170,7 +170,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -199,12 +199,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -220,12 +220,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

View File

@ -60,7 +60,7 @@ class ScanTest(test_base.DatasetTestBase):
feed_dict={start: start_val, step: step_val, take: take_val})
for expected, _ in zip(
itertools.count(start_val, step_val), range(take_val)):
self.assertEqual(expected, self.evaluate(next_element))
self.assertEqual(expected, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -110,7 +110,7 @@ class ScanTest(test_base.DatasetTestBase):
feed_dict={start: start_val, step: step_val, take: take_val})
for expected, _ in zip(
itertools.count(start_val, step_val), range(take_val)):
self.assertEqual(expected, self.evaluate(next_element).values[0])
self.assertEqual(expected, sess.run(next_element).values[0])
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -136,7 +136,7 @@ class ScanTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i in range(5):
(longer_vector_val, larger_rank_val), _ = self.evaluate(next_element)
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
self.assertAllEqual([0] * (2**i), longer_vector_val)
self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
with self.assertRaises(errors.OutOfRangeError):

View File

@ -71,19 +71,19 @@ class RangeDatasetSerializationTest(
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(init_op)
self.evaluate(restore_op)
sess.run(init_op)
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -91,14 +91,14 @@ class RangeDatasetSerializationTest(
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.evaluate(restore_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -62,7 +62,7 @@ class SerializationIntegrationTest(test.TestCase):
with self.session(graph=g) as sess:
sess.run(init_ops)
for _ in range(break_point):
output = self.evaluate(get_next_ops)
output = sess.run(get_next_ops)
for i in range(num_pipelines):
all_outputs[i].append(output[i])
saver.save(sess, self._ckpt_path())
@ -73,7 +73,7 @@ class SerializationIntegrationTest(test.TestCase):
with self.session(graph=g) as sess:
saver.restore(sess, self._ckpt_path())
for _ in range(num_outputs - break_point):
output = self.evaluate(get_next_ops)
output = sess.run(get_next_ops)
for i in range(num_pipelines):
all_outputs[i].append(output[i])

View File

@ -108,7 +108,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
shuffle_ops.shuffle_and_repeat(buffer_size=21))
get_next_op = ds.make_one_shot_iterator().get_next()
with self.session(graph=g) as sess:
self.evaluate(get_next_op)
sess.run(get_next_op)
if __name__ == "__main__":

View File

@ -38,10 +38,10 @@ class SleepTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
start_time = time.time()
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
end_time = time.time()
self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
with self.assertRaises(errors.OutOfRangeError):

View File

@ -39,9 +39,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
for _ in range(2): # Dataset is repeated. See setUp.
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"),
self.evaluate(get_next))
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -59,8 +58,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ON students.first_name = people.first_name "
"AND students.last_name = people.last_name"
})
self.assertEqual((b"John", b"California", b"Hi!"),
self.evaluate(get_next))
self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -77,9 +75,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"SELECT first_name, last_name, favorite_nonsense_word "
"FROM students ORDER BY first_name DESC"
})
self.assertEqual((b"John", b"Doe", b"n\0nsense"), self.evaluate(get_next))
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"),
self.evaluate(get_next))
self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -96,8 +93,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, last_name, motto FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), self.evaluate(get_next))
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
sess.run(
@ -106,8 +103,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, last_name, state FROM people "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", b"Doe", b"California"),
self.evaluate(get_next))
self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@ -216,8 +212,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -234,7 +230,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"FROM students "
"WHERE first_name = 'John' ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
self.assertEqual((b"John", 0, -2), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -250,9 +246,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"SELECT desk_number, favorite_negative_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((9, -2), self.evaluate(get_next))
self.assertEqual((9, -2), sess.run(get_next))
# Max and min values of int8
self.assertEqual((127, -128), self.evaluate(get_next))
self.assertEqual((127, -128), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -267,8 +263,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -285,7 +281,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"FROM students "
"WHERE first_name = 'John' ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
self.assertEqual((b"John", 0, -2), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -301,9 +297,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"FROM students ORDER BY first_name DESC"
})
# Max value of int16
self.assertEqual((b"John", 32767), self.evaluate(get_next))
self.assertEqual((b"John", 32767), sess.run(get_next))
# Min value of int16
self.assertEqual((b"Jane", -32768), self.evaluate(get_next))
self.assertEqual((b"Jane", -32768), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -318,8 +314,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int32` tensor.
@ -332,8 +328,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, income FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0), self.evaluate(get_next))
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
self.assertEqual((b"John", 0), sess.run(get_next))
self.assertEqual((b"Jane", -20000), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -349,9 +345,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
# Max value of int32
self.assertEqual((b"John", 2147483647), self.evaluate(get_next))
self.assertEqual((b"John", 2147483647), sess.run(get_next))
# Min value of int32
self.assertEqual((b"Jane", -2147483648), self.evaluate(get_next))
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -366,8 +362,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, school_id FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 123), self.evaluate(get_next))
self.assertEqual((b"Jane", 1000), self.evaluate(get_next))
self.assertEqual((b"John", 123), sess.run(get_next))
self.assertEqual((b"Jane", 1000), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -382,8 +378,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -398,8 +394,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, income FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0), self.evaluate(get_next))
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
self.assertEqual((b"John", 0), sess.run(get_next))
self.assertEqual((b"Jane", -20000), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -416,9 +412,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
# Max value of int64
self.assertEqual((b"John", 9223372036854775807), self.evaluate(get_next))
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
# Min value of int64
self.assertEqual((b"Jane", -9223372036854775808), self.evaluate(get_next))
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -433,8 +429,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -450,9 +446,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
# Min value of uint8
self.assertEqual((b"John", 0), self.evaluate(get_next))
self.assertEqual((b"John", 0), sess.run(get_next))
# Max value of uint8
self.assertEqual((b"Jane", 255), self.evaluate(get_next))
self.assertEqual((b"Jane", 255), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -467,8 +463,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -484,9 +480,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
# Min value of uint16
self.assertEqual((b"John", 0), self.evaluate(get_next))
self.assertEqual((b"John", 0), sess.run(get_next))
# Max value of uint16
self.assertEqual((b"Jane", 65535), self.evaluate(get_next))
self.assertEqual((b"Jane", 65535), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -503,8 +499,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"SELECT first_name, registration_complete FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", True), self.evaluate(get_next))
self.assertEqual((b"Jane", False), self.evaluate(get_next))
self.assertEqual((b"John", True), sess.run(get_next))
self.assertEqual((b"Jane", False), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -519,8 +515,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, favorite_medium_sized_number "
"FROM students ORDER BY first_name DESC"
})
self.assertEqual((b"John", True), self.evaluate(get_next))
self.assertEqual((b"Jane", True), self.evaluate(get_next))
self.assertEqual((b"John", True), sess.run(get_next))
self.assertEqual((b"Jane", True), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -537,9 +533,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"SELECT first_name, last_name, victories FROM townspeople "
"ORDER BY first_name"
})
self.assertEqual((b"George", b"Washington", 20.0),
self.evaluate(get_next))
self.assertEqual((b"John", b"Adams", -19.95), self.evaluate(get_next))
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -74,18 +74,18 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
expected_sum = 0.0
for i in range(100):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), sess.run(next_element))
summary_str = self.evaluate(summary_t)
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
expected_sum += i * 8.0
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
summary_str = self.evaluate(summary_t)
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
@ -99,15 +99,14 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 100.0)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
def testPrefetchBufferUtilization(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -119,11 +118,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(100):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), sess.run(next_element))
summary_str = self.evaluate(summary_t)
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
float(i + 1))
self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
@ -132,7 +131,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
0, 1)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
summary_str = self.evaluate(summary_t)
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
100)
@ -146,11 +145,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(10):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), sess.run(next_element))
summary_str = self.evaluate(summary_t)
summary_str = sess.run(summary_t)
self._assertSummaryHasScalarValue(summary_str,
"Prefetch::buffer_capacity", 0)
self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
@ -168,9 +167,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.test_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(34):
self.assertEqual(i * 3, self.evaluate(next_element))
self.assertEqual(i * 3, sess.run(next_element))
if i is not 0:
self._assertSummaryHasScalarValue(
sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
@ -262,9 +261,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.cached_session() as sess:
for j in range(5):
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
with self.assertRaises(errors.OutOfRangeError):
@ -279,9 +278,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -296,17 +295,16 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float(i + 1))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency_2", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 100.0)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency_2", 100.0)
@ -321,15 +319,14 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 200.0)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -344,13 +341,12 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.cached_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, self.evaluate(next_element))
self.assertEqual(i * 2, sess.run(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 200.0)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -368,7 +364,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.test_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, self.evaluate(next_element))
self.assertEqual(i * 2, sess.run(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "dataset1_record_latency", float(i + 1))
self._assertSummaryHasCount(
@ -425,7 +421,7 @@ class FeatureStatsDatasetTest(
summary_t = aggregator.get_summary()
with self.test_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for _ in range(num_output):
sess.run(next_element)

View File

@ -50,7 +50,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
for i in range(4):
self.assertEqual(i, self.evaluate(next_elem))
self.assertEqual(i, sess.run(next_elem))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_elem)
@ -68,7 +68,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, self.evaluate(op))
self.assertEqual((i,) * 3, sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
@ -88,7 +88,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
@ -107,7 +107,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
st_row = self.evaluate(next_element)
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
@ -128,7 +128,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = self.evaluate(next_element)
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
@ -150,7 +150,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, self.evaluate(op))
self.assertEqual(((i,),) * 3, sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)

View File

@ -49,11 +49,11 @@ class UniqueTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for element in expected:
if dtype == dtypes.string:
element = compat.as_bytes(element)
self.assertAllEqual(element, self.evaluate(next_element))
self.assertAllEqual(element, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

View File

@ -93,13 +93,13 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
})
num_full_batches = (count * 7) // batch_size
for i in range(num_full_batches):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(batch_size):
self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
result_component[j])
if not drop_remainder and (count * 7) % batch_size > 0:
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((count * 7) % batch_size):
self.assertAllEqual(
@ -128,9 +128,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(2):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
@ -155,9 +155,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(2):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
expected_indices = []
expected_values = []
for j in range(5):
@ -185,8 +185,8 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
actual = self.evaluate(get_next)
sess.run(init_op)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
@ -211,7 +211,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'Cannot batch tensors with different shapes in component 0. '
@ -271,7 +271,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
num_full_batches = len(seq_lens) // batch_size
for i in range(num_full_batches):
result = self.evaluate(get_next)
result = sess.run(get_next)
padded_len = padded_shapes[0]
if padded_len is None or padded_len == -1:
padded_len = np.max(result) if result.size > 0 else 0
@ -283,7 +283,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
[0] * (padded_len - seq_len))
if not drop_remainder and len(seq_lens) % batch_size > 0:
result = self.evaluate(get_next)
result = sess.run(get_next)
padded_len = np.max(result) if result.size > 0 else 0
self.assertEqual((len(seq_lens) % batch_size, padded_len),
result.shape)
@ -315,7 +315,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
result = self.evaluate(get_next)
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -347,7 +347,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
seq_lens: random_seq_lens
})
for i in range(8):
result = self.evaluate(get_next)
result = sess.run(get_next)
padded_len = np.max(result[0])
self.assertEqual((4, padded_len), result[0].shape)
self.assertEqual((4, padded_len), result[1].shape)

View File

@ -71,7 +71,7 @@ class FileCacheDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
# First run without caching to collect the "ground truth".
self.evaluate(init_fifo_op)
sess.run(init_fifo_op)
elements = []
for _ in range(20):
elements.append(sess.run(get_next))
@ -220,14 +220,14 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
self.evaluate(repeat_count.initializer)
self.evaluate(cached_iterator.initializer)
self.evaluate(uncached_iterator.initializer)
sess.run(repeat_count.initializer)
sess.run(cached_iterator.initializer)
sess.run(uncached_iterator.initializer)
for i in range(3):
for _ in range(10):
self.assertEqual(self.evaluate(cached_next), i)
self.assertEqual(self.evaluate(uncached_next), i)
self.assertEqual(sess.run(cached_next), i)
self.assertEqual(sess.run(uncached_next), i)
sess.run(repeat_count.assign(0))
@ -238,7 +238,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
# The cached iterator replays from cache.
for i in range(3):
for _ in range(10):
self.assertEqual(self.evaluate(cached_next), i)
self.assertEqual(sess.run(cached_next), i)
# The cached iterator should now be empty.
with self.assertRaises(errors.OutOfRangeError):
@ -280,7 +280,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
i2 = d2.make_initializable_iterator()
with self.cached_session() as sess:
self.evaluate(i1.initializer)
sess.run(i1.initializer)
self.assertEqual(1, sess.run(i1.get_next()))
self.assertEqual(2, sess.run(i1.get_next()))
@ -307,7 +307,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i, expected in enumerate(expected_values):
self.assertEqual(expected, self.evaluate(n),
self.assertEqual(expected, sess.run(n),
"Unexpected value at index %s" % i)
with self.assertRaises(errors.OutOfRangeError):

View File

@ -51,9 +51,9 @@ class ConcatenateDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(9):
result = self.evaluate(get_next)
result = sess.run(get_next)
if i < 4:
for component, result_component in zip(input_components, result):
self.assertAllEqual(component[i], result_component)
@ -85,9 +85,9 @@ class ConcatenateDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(9):
result = self.evaluate(get_next)
result = sess.run(get_next)
if i < 4:
for component, result_component in zip(input_components, result):
self.assertAllEqual(component[i], result_component)

View File

@ -52,8 +52,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
[t.shape for t in get_next])
with self.cached_session() as sess:
self.evaluate(init_op)
results = self.evaluate(get_next)
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -81,8 +81,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
[shape for shape in iterator.output_shapes])
with self.cached_session() as sess:
self.evaluate(init_op)
results = self.evaluate(get_next)
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertSparseValuesEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -112,8 +112,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
], [shape for shape in iterator.output_shapes])
with self.cached_session() as sess:
self.evaluate(init_op)
results = self.evaluate(get_next)
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
if sparse_tensor.is_sparse(component):
self.assertSparseValuesEqual(component, result_component)
@ -139,9 +139,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
[t.shape for t in get_next])
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(4):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component[i], result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -169,7 +169,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
[shape for shape in iterator.output_shapes])
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
expected = [
(sparse_tensor.SparseTensorValue(
indices=np.array([[0]]),
@ -197,7 +197,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
dense_shape=np.array([3]))),
]
for i in range(3):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(expected[i], results):
self.assertSparseValuesEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -229,7 +229,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
], [shape for shape in iterator.output_shapes])
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
expected = [
(sparse_tensor.SparseTensorValue(
indices=np.array([[0]]),
@ -257,7 +257,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
dense_shape=np.array([3]))),
]
for i in range(3):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(
(list(zip(*components[:3]))[i] + expected[i]), results):
if sparse_tensor.is_sparse(component):
@ -280,9 +280,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
self.assertEqual((1,), iterator.output_shapes["bar"])
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(3):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertEqual(components["foo"][i], results["foo"])
self.assertEqual(components["bar"][i], results["bar"])
with self.assertRaises(errors.OutOfRangeError):
@ -308,7 +308,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
dense_shape)
sess.run(init_op, feed_dict={st: sparse_feed})
for i, s in enumerate(slices):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual(s, results.values)
expected_indices = np.array(
[[j] for j in range(len(slices[i]))]).reshape([-1, 1])
@ -474,15 +474,15 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
with ops.device("/cpu:0"):
var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
dataset = dataset.map(lambda x: x + var_0.read_value())
self.evaluate(var_0.initializer)
sess.run(var_0.initializer)
with ops.device("/cpu:1"):
var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
dataset = dataset.map(lambda x: x + var_1.read_value())
self.evaluate(var_1.initializer)
sess.run(var_1.initializer)
iterator = dataset.make_initializable_iterator()
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.FailedPreconditionError,
@ -506,7 +506,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
next_element = iterator.get_next()
with session.Session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
# Run one whole epoch to burn in the computation.
for _ in range(input_size // batch_size):
sess.run(next_element)
@ -543,7 +543,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
next_element = iterator.get_next()
with session.Session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
get_next_element = sess.make_callable(next_element)
# Run one whole epoch to burn in the computation.
for _ in range(input_size // batch_size):
@ -582,7 +582,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
next_element = iterator.get_next()
with session.Session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
get_next_element = sess.make_callable(next_element)
# Run one whole epoch to burn in the computation.
for _ in range(input_size // batch_size):
@ -620,7 +620,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
next_element = iterator.get_next()
with session.Session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
get_next_element = sess.make_callable(next_element)
# Run one whole epoch to burn in the computation.
for _ in range(input_size // batch_size):

View File

@ -47,10 +47,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for _ in range(2): # Run twice to test reinitialization.
self.evaluate(init_op)
sess.run(init_op)
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, self.evaluate(get_next))
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -65,7 +65,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, self.evaluate(get_next))
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -133,10 +133,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for _ in range(num_inner_repeats * num_outer_repeats):
for elem in input_list:
val0, val1 = self.evaluate(get_next)
val0, val1 = sess.run(get_next)
self.assertAllEqual(elem[0], val0)
self.assertAllEqual(elem[1], val1)
with self.assertRaises(errors.OutOfRangeError):
@ -192,10 +192,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for elem in [0, 1]:
for _ in range(num_parallel_iterators):
self.assertAllEqual(elem, self.evaluate(get_next))
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -215,9 +215,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
self.assertEqual(dtype, get_next.dtype)
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for expected in [[1], [2], [3]]:
next_val = self.evaluate(get_next)
next_val = sess.run(get_next)
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
self.assertAllEqual(expected, next_val)
with self.assertRaises(errors.OutOfRangeError):
@ -236,9 +236,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for expected in [b"foo", b"bar", b"baz"]:
next_val = self.evaluate(get_next)
next_val = sess.run(get_next)
self.assertAllEqual(expected, next_val)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -257,12 +257,12 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError("The expected type was int64"):
sess.run(get_next)
self.assertAllEqual([7, 8, 9], self.evaluate(get_next))
self.assertAllEqual([7, 8, 9], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -280,12 +280,12 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
sess.run(get_next)
self.assertAllEqual([11, 12, 13], self.evaluate(get_next))
self.assertAllEqual([11, 12, 13], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -304,16 +304,16 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertEqual((1, 2), self.evaluate(get_next))
self.assertEqual((3, 4), self.evaluate(get_next))
sess.run(init_op)
self.assertEqual((1, 2), sess.run(get_next))
self.assertEqual((3, 4), sess.run(get_next))
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
sess.run(get_next)
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
sess.run(get_next)
self.assertEqual((9, 10), self.evaluate(get_next))
self.assertEqual((9, 10), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -329,9 +329,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertAllEqual(1, self.evaluate(get_next))
self.assertAllEqual([2, 3], self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual(1, sess.run(get_next))
self.assertAllEqual([2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -349,9 +349,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertAllEqual(0, self.evaluate(get_next))
self.assertAllEqual(1, self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual(0, sess.run(get_next))
self.assertAllEqual(1, sess.run(get_next))
def testFromGeneratorDestructorCalled(self):
# Use an `Event` to signal that the generator has been deleted.
@ -378,9 +378,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with session.Session() as sess:
self.evaluate(init_op)
self.assertAllEqual(42, self.evaluate(get_next))
self.assertAllEqual(42, self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual(42, sess.run(get_next))
self.assertAllEqual(42, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `GeneratorWrapper` object is destroyed when the
@ -407,10 +407,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
for x in expected:
self.assertEqual(x, self.evaluate(get_next))
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -436,13 +436,13 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
expected = [(0, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
for x in expected:
self.assertEqual(x, self.evaluate(get_next))
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -470,9 +470,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertAllEqual(37, self.evaluate(get_next))
self.assertAllEqual(37, self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual(37, sess.run(get_next))
self.assertAllEqual(37, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertTrue(event.is_set())

View File

@ -67,7 +67,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
for _ in range(count_val):
for i in [x for x in range(7) if x**2 % modulus_val == 0]:
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -86,9 +86,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.assertEqual(0, self.evaluate(get_next))
self.assertEqual(1, self.evaluate(get_next))
self.assertEqual(3, self.evaluate(get_next))
self.assertEqual(0, sess.run(get_next))
self.assertEqual(1, sess.run(get_next))
self.assertEqual(3, sess.run(get_next))
def testFilterDict(self):
iterator = (dataset_ops.Dataset.range(10)
@ -100,10 +100,10 @@ class FilterDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
if (i ** 2) % 2 == 0:
self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -125,8 +125,8 @@ class FilterDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertAllEqual(input_data[0], self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual(input_data[0], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -148,9 +148,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(5):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0])
with self.assertRaises(errors.OutOfRangeError):
@ -166,9 +166,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
self.assertEqual((i, True), self.evaluate(get_next))
self.assertEqual((i, True), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -178,7 +178,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
next_elements = [iterator.get_next() for iterator in iterators]
with self.cached_session() as sess:
self.assertEqual([0 for _ in range(10)], self.evaluate(next_elements))
self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
class FilterDatasetBenchmark(test.Benchmark):

View File

@ -45,10 +45,10 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in repeats:
for _ in range(i):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -64,11 +64,11 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for row in repeats:
for i in row:
for _ in range(i):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -94,12 +94,12 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
with session.Session(server.target) as sess2:
for _ in range(3):
sess = random.choice([sess1, sess2])
self.evaluate(init_op)
sess.run(init_op)
for row in repeats:
for i in row:
for _ in range(i):
sess = random.choice([sess1, sess2])
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess = random.choice([sess1, sess2])
@ -115,10 +115,10 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
for _ in range(i ** 2):
self.assertEqual(i * 2, self.evaluate(get_next))
self.assertEqual(i * 2, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# pylint: enable=g-long-lambda
@ -139,11 +139,11 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
for j in range(2):
expected = [i, 0] if j % 2 == 0 else [0, -i]
self.assertAllEqual(expected, self.evaluate(get_next))
self.assertAllEqual(expected, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -196,7 +196,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for expected_element in _interleave(
_repeat(input_values, count), cycle_length, block_length):
self.assertEqual(expected_element, self.evaluate(get_next))
self.assertEqual(expected_element, sess.run(get_next))
for _ in range(2):
with self.assertRaises(errors.OutOfRangeError):
@ -231,7 +231,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
else:
self.assertEqual(value, self.evaluate(get_next))
self.assertEqual(value, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -254,7 +254,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(10):
for j in range(2):
expected = [i, 0] if j % 2 == 0 else [0, -i]
self.assertAllEqual(expected, self.evaluate(get_next))
self.assertAllEqual(expected, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -308,7 +308,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for element in elements:
coordination_events[element].set()
self.assertEqual(element * element, self.evaluate(get_next))
self.assertEqual(element * element, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -57,7 +57,7 @@ class IteratorClusterTest(test.TestCase):
with session.Session(worker[0].target) as sess:
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(get_next_op)
sess.run(get_next_op)
def _testRemoteIteratorHelper(self, device0, device1, target):
with ops.device(device1):
@ -134,12 +134,12 @@ class IteratorClusterTest(test.TestCase):
get_next = iterator.get_next()
with session.Session(worker[0].target) as sess:
self.evaluate(table.initializer)
self.evaluate(init_op)
self.assertAllEqual([0, 0, -1, 1, 2], self.evaluate(get_next))
sess.run(table.initializer)
sess.run(init_op)
self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))
with session.Session(worker[0].target) as sess:
self.assertAllEqual([2, 0], self.evaluate(get_next))
self.assertAllEqual([2, 0], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -166,7 +166,7 @@ class IteratorClusterTest(test.TestCase):
get_next = iterator.get_next()
with session.Session(worker[0].target) as sess:
self.evaluate(init_op)
sess.run(init_op)
for _ in range(3):
sess.run(get_next)

View File

@ -97,7 +97,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
with self.cached_session() as sess:
for _ in range(14):
for i in range(7):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -123,7 +123,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
with self.cached_session() as sess:
for _ in range(14):
for i in range(7):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -159,7 +159,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
for _ in range(14):
for i in range(7):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -175,7 +175,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
config = config_pb2.ConfigProto(
inter_op_parallelism_threads=1, use_per_session_threads=True)
with session.Session(config=config) as sess:
self.assertAllEqual([1, 4, 9], self.evaluate(next_element))
self.assertAllEqual([1, 4, 9], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -254,15 +254,15 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
get_next = iterator.get_next()
with session.Session(server.target) as sess:
self.evaluate(init_op)
results = self.evaluate(get_next)
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Re-initialize the iterator in the first session.
self.evaluate(init_op)
sess.run(init_op)
with ops.Graph().as_default():
# Re-define the iterator manually, without defining any of the
@ -277,7 +277,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
with session.Session(server.target) as sess:
# Use the iterator without re-initializing in the second session.
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -317,20 +317,20 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
sess.run(get_next)
# Initialize with one dataset.
self.evaluate(dataset_3_init_op)
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
sess.run(dataset_3_init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Initialize with a different dataset.
self.evaluate(dataset_4_init_op)
self.assertAllEqual([4, 5, 6, 7], self.evaluate(get_next))
sess.run(dataset_4_init_op)
self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Reinitialize with the first dataset.
self.evaluate(dataset_3_init_op)
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
sess.run(dataset_3_init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -348,7 +348,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
g, output_types=dtypes.int64)
sess.run(iterator.make_initializer(dataset_1))
for expected in range(10):
self.assertEqual(expected, self.evaluate(next_element))
self.assertEqual(expected, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -356,7 +356,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
g, output_types=dtypes.int64)
sess.run(iterator.make_initializer(dataset_2))
for expected in range(10):
self.assertEqual(expected, self.evaluate(next_element))
self.assertEqual(expected, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -679,10 +679,10 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
n = itr.get_next()
with session.Session(s3.target, config=config) as sess:
self.evaluate(itr.initializer)
sess.run(itr.initializer)
expected_values = worker_devices
for expected in expected_values:
self.assertEqual((compat.as_bytes(expected),), self.evaluate(n))
self.assertEqual((compat.as_bytes(expected),), sess.run(n))
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
@ -786,8 +786,8 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
with ops.Graph().as_default() as g:
init_op, _, save_op, _ = _build_range_dataset_graph()
with self.session(graph=g) as sess:
self.evaluate(init_op)
self.evaluate(save_op)
sess.run(init_op)
sess.run(save_op)
# Attempt to restore the saved iterator into an IteratorResource of
# incompatible type. An iterator of RangeDataset has output type int64,
@ -798,7 +798,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
_, _, _, restore_op = _build_reader_dataset_graph()
with self.session(graph=g) as sess:
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(restore_op)
sess.run(restore_op)
def testRepeatedGetNextWarning(self):
iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
@ -949,7 +949,7 @@ class IteratorCheckpointingTest(test.TestCase):
checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory)).initialize_or_restore(sess)
for j in range(2):
self.assertEqual(i * 2 + j, self.evaluate(get_next))
self.assertEqual(i * 2 + j, sess.run(get_next))
checkpoint.save(file_prefix=checkpoint_prefix)

View File

@ -102,7 +102,7 @@ class ListFilesDatasetOpTest(test_base.DatasetTestBase):
all_produced_filenames = []
for _ in range(3):
produced_filenames = []
self.evaluate(itr.initializer)
sess.run(itr.initializer)
try:
while True:
produced_filenames.append(sess.run(next_element))

View File

@ -114,7 +114,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 14})
for _ in range(14):
for i in range(7):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -185,7 +185,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
output_buffer_size: output_buffer_size_val})
for _ in range(14):
for i in range(7):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -242,7 +242,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@ -257,7 +257,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@ -272,7 +272,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
@ -293,7 +293,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
@ -325,10 +325,10 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with ops.Graph().as_default() as g:
captured_init_op, init_op, get_next = _build_graph()
with self.session(graph=g) as sess:
self.evaluate(captured_init_op)
self.evaluate(init_op)
sess.run(captured_init_op)
sess.run(init_op)
for i in range(10):
self.assertEqual(i * i, self.evaluate(get_next))
self.assertEqual(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -353,8 +353,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(table.initializer)
self.evaluate(init_op)
sess.run(table.initializer)
sess.run(init_op)
sess.run(get_next)
sess.run(get_next)
with self.assertRaises(errors.OutOfRangeError):
@ -371,11 +371,11 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(enqueue_op)
self.evaluate(close_op)
self.evaluate(init_op)
sess.run(enqueue_op)
sess.run(close_op)
sess.run(init_op)
for element in elements:
self.assertEqual(element, self.evaluate(get_next))
self.assertEqual(element, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -396,9 +396,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(enqueue_op)
self.evaluate(close_op)
self.evaluate(init_op)
sess.run(enqueue_op)
sess.run(close_op)
sess.run(init_op)
for i in range(100):
self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]),
sorted(sess.run(get_next)))
@ -415,15 +415,15 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(counter_var.initializer)
self.evaluate(init_op)
sess.run(counter_var.initializer)
sess.run(init_op)
for i in range(10):
self.assertEqual(i, self.evaluate(counter_var))
self.assertEqual(i + 1, self.evaluate(get_next))
self.assertEqual(10, self.evaluate(counter_var))
self.assertEqual(i, sess.run(counter_var))
self.assertEqual(i + 1, sess.run(get_next))
self.assertEqual(10, sess.run(counter_var))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertEqual(10, self.evaluate(counter_var))
self.assertEqual(10, sess.run(counter_var))
def testCaptureUninitializedVariableError(self):
counter_var = variable_scope.get_variable(
@ -435,7 +435,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
with self.assertRaises(errors.NotFoundError):
sess.run(get_next)
@ -447,14 +447,14 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
random_values = []
with self.assertRaises(errors.OutOfRangeError):
while True:
random_values.extend(sess.run(get_next))
self.assertEqual(10, len(random_values))
self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
self.evaluate(init_op)
sess.run(init_op)
random_values_2 = []
with self.assertRaises(errors.OutOfRangeError):
while True:
@ -473,8 +473,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
random_values = self.evaluate(get_next)
sess.run(init_op)
random_values = sess.run(get_next)
# Assert that one of the next 99 batches yielded by the iterator is
# different from the first.
@ -500,15 +500,15 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(counter_var.initializer)
self.evaluate(init_op)
sess.run(counter_var.initializer)
sess.run(init_op)
for i in range(10):
self.assertEqual(i, self.evaluate(counter_var))
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(10, self.evaluate(counter_var))
self.assertEqual(i, sess.run(counter_var))
self.assertEqual(i, sess.run(get_next))
self.assertEqual(10, sess.run(counter_var))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertEqual(10, self.evaluate(counter_var))
self.assertEqual(10, sess.run(counter_var))
def testMapDict(self):
iterator = (dataset_ops.Dataset.range(10)
@ -519,9 +519,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -569,8 +569,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
self.assertAllEqual(row**2, self.evaluate(get_next))
sess.run(init_op)
self.assertAllEqual(row ** 2, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -611,7 +611,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
row = np.arange(6)
for num in [2, 3, 4]:
init_op, get_next = build_dataset(row, num)
self.evaluate(init_op)
sess.run(init_op)
for i in range(6):
self.assertEqual(
(i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2,
@ -652,7 +652,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
row = np.arange(6)
for num in [2, 3, 4]:
init_op, get_next = build_dataset(row, num)
self.evaluate(init_op)
sess.run(init_op)
self.assertAllEqual(
[x // 2 if (num == 2 or num == 3) else x * 2 for x in row],
sess.run(get_next))
@ -697,7 +697,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
self.assertAllEqual([(x // 2 if x % 2 else x * 2) if
(num == 2 or num == 3) else x * 2 for x in row],
sess.run(get_next))
@ -735,7 +735,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for buffer_size in [1, 10, 100, 1000]:
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
for i in range(100):
self.assertEqual(i * i, self.evaluate(get_next))
self.assertEqual(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -753,10 +753,10 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
for i in range(event_will_be_set_after_consuming):
self.assertFalse(ev.is_set())
self.assertEqual(i * i, self.evaluate(get_next))
self.assertEqual(i * i, sess.run(get_next))
ev.wait()
for i in range(event_will_be_set_after_consuming, 100):
self.assertEqual(i * i, self.evaluate(get_next))
self.assertEqual(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -768,9 +768,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), self.evaluate(get_next))
self.assertEqual((i, 37.0), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -789,9 +789,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), self.evaluate(get_next))
self.assertEqual((i, 37.0), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -810,9 +810,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _sparse(i))
with self.assertRaises(errors.OutOfRangeError):
@ -837,9 +837,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
with self.assertRaises(errors.OutOfRangeError):
@ -861,9 +861,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(100):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -875,9 +875,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
for i in range(10):
self.assertEqual((i, b"hello", 10), self.evaluate(get_next))
self.assertEqual((i, b"hello", 10), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -945,7 +945,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
# pylint: disable=g-long-lambda
@parameterized.named_parameters(
@ -972,7 +972,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
tids = self.evaluate(get_next)
tids = sess.run(get_next)
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
@ -996,7 +996,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected = map_fn(*sess.run(self.structuredElement(structure)))
else:
expected = map_fn(sess.run(self.structuredElement(structure)))
self.assertEqual(expected, self.evaluate(get_next))
self.assertEqual(expected, sess.run(get_next))
@parameterized.named_parameters(
("Sequential", None),
@ -1011,7 +1011,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={captured_t: 42})
self.assertEqual(42, self.evaluate(get_next))
self.assertEqual(42, sess.run(get_next))
@parameterized.named_parameters(
("1", 1, 1),
@ -1030,7 +1030,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session(config=config) as sess:
for i in range(num_elements):
coordination_events[i].set()
self.assertEqual(i * i, self.evaluate(get_next))
self.assertEqual(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -1052,7 +1052,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
for element in elements:
coordination_events[element].set()
self.assertEqual(element * element, self.evaluate(get_next))
self.assertEqual(element * element, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -40,7 +40,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
def testBasic(self):
dataset = dataset_ops.Dataset.range(10)
@ -50,10 +50,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
@ -67,10 +67,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
@ -85,12 +85,12 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 20, 4):
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(i + 2, self.evaluate(elem_on_3))
self.assertEqual(i + 3, self.evaluate(elem_on_4))
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(i + 2, sess.run(elem_on_3))
self.assertEqual(i + 3, sess.run(elem_on_4))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
@ -105,11 +105,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 8, 2):
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(8, self.evaluate(elem_on_1))
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(8, sess.run(elem_on_1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
@ -126,7 +126,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 8, 2):
elem_on_1_has_value, elem_on_1_value = sess.run(
[elem_on_1_has_value_t, elem_on_1_t])
@ -140,8 +140,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
[elem_on_1_has_value_t, elem_on_1_t])
self.assertTrue(elem_on_1_has_value)
self.assertEqual(8, elem_on_1_value)
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
self.assertFalse(sess.run(elem_on_1_has_value_t))
self.assertFalse(sess.run(elem_on_2_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_on_1_t)
with self.assertRaises(errors.InvalidArgumentError):
@ -155,11 +155,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i, sess.run(elem_on_1))
for i in range(0, 10, 2):
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(i + 1, sess.run(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
@ -192,10 +192,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
@ -211,11 +211,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i, sess.run(elem_on_1))
for i in range(0, 10, 2):
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(i + 1, sess.run(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
@ -235,7 +235,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 8, 2):
elem_on_1_has_value, elem_on_1_value = sess.run(
[elem_on_1_has_value_t, elem_on_1_t])
@ -249,8 +249,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
[elem_on_1_has_value_t, elem_on_1_t])
self.assertTrue(elem_on_1_has_value)
self.assertEqual(8, elem_on_1_value)
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
self.assertFalse(sess.run(elem_on_1_has_value_t))
self.assertFalse(sess.run(elem_on_2_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_on_1_t)
with self.assertRaises(errors.InvalidArgumentError):
@ -272,10 +272,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
self.evaluate(multi_device_iterator.initializer)
sess.run(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)

View File

@ -227,7 +227,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
for _ in range(3):
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
@ -236,7 +236,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
for _ in range(2):
self.assertFalse(self.evaluate(elem_has_value_t))
self.assertFalse(sess.run(elem_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_value_t)

View File

@ -40,7 +40,7 @@ class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
for m in range(10):
self.assertEqual(m, self.evaluate(get_next))
self.assertEqual(m, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -124,19 +124,19 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(init_op)
self.evaluate(restore_op)
sess.run(init_op)
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -144,14 +144,14 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.evaluate(restore_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -175,14 +175,14 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for _ in range(break_epoch):
for i in range(start, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
for i in range(start, break_point):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
with ops.Graph().as_default() as g:
# Create an empty IteratorResource and restore the Iterator into it.
@ -193,12 +193,12 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
restore_op = self._restore_op(iterator._iterator_resource)
get_next = iterator.get_next()
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
for _ in range(break_epoch + 1, num_epochs):
for i in range(start, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -221,20 +221,20 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
with ops.Graph().as_default() as g:
# Intentionally build a graph with a different value for stop to make sure
# the original dataset graph is actually getting loaded.
init_op, get_next, _, restore_op = _build_graph(start, stop_1)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -259,19 +259,19 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(init_op)
self.evaluate(restore_op)
sess.run(init_op)
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -294,27 +294,27 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point1):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for i in range(break_point1, break_point2):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
break_point2 = 7
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for i in range(break_point2, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -338,28 +338,28 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
init_op, get_next, save_op, restore_op = _build_graph(
start, stop, num_epochs)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
self.evaluate(restore_op)
sess.run(restore_op)
for _ in range(break_epoch - 1):
for i in range(start, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
for i in range(start, break_range):
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for i in range(break_range, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
for _ in range(break_epoch, num_epochs):
for i in range(start, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -381,23 +381,23 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
init_op, get_next, save_op, restore_op = _build_graph(
start, stop, num_epochs)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
sess.run(variables.global_variables_initializer())
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
self.evaluate(restore_op)
sess.run(restore_op)
for _ in range(num_epochs):
for i in range(start, stop):
self.assertEqual(i, self.evaluate(get_next))
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(save_op)
sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -107,7 +107,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
init_op, feed_dict={filenames: [test_filenames[0]],
num_epochs: 1})
for i in range(5):
self.assertEqual(self._lineText(0, i), self.evaluate(get_next))
self.assertEqual(self._lineText(0, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -116,7 +116,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
init_op, feed_dict={filenames: [test_filenames[1]],
num_epochs: 1})
for i in range(5):
self.assertEqual(self._lineText(1, i), self.evaluate(get_next))
self.assertEqual(self._lineText(1, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -124,7 +124,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
for j in range(2):
for i in range(5):
self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
self.assertEqual(self._lineText(j, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -133,7 +133,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
for _ in range(10):
for j in range(2):
for i in range(5):
self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
self.assertEqual(self._lineText(j, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -267,7 +267,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
init_op, feed_dict={filenames: [test_filenames[0]],
num_epochs: 1})
for i in range(self._num_records):
self.assertEqual(self._record(0, i), self.evaluate(get_next))
self.assertEqual(self._record(0, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -276,7 +276,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
init_op, feed_dict={filenames: [test_filenames[1]],
num_epochs: 1})
for i in range(self._num_records):
self.assertEqual(self._record(1, i), self.evaluate(get_next))
self.assertEqual(self._record(1, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -284,7 +284,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
for j in range(self._num_files):
for i in range(self._num_records):
self.assertEqual(self._record(j, i), self.evaluate(get_next))
self.assertEqual(self._record(j, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -293,7 +293,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
for _ in range(10):
for j in range(self._num_files):
for i in range(self._num_records):
self.assertEqual(self._record(j, i), self.evaluate(get_next))
self.assertEqual(self._record(j, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -405,19 +405,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(init_op)
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
self.evaluate(restore_op)
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
if (epoch == epoch_break and f == file_break and
r == record_break):
self.evaluate(save_op)
sess.run(save_op)
break
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
else:
continue
break
@ -426,13 +426,13 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
break
else:
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
@ -441,9 +441,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
(epoch == epoch_break and f == file_break and
r < record_break)):
continue
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
def testInitThenRestore(self):
# Note: Calling init_op before restore_op is redundant. This test just makes
@ -458,19 +458,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(init_op)
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
self.evaluate(restore_op)
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
if (epoch == epoch_break and f == file_break and
r == record_break):
self.evaluate(save_op)
sess.run(save_op)
break
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
else:
continue
break
@ -479,14 +479,14 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
break
else:
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(init_op)
self.evaluate(restore_op)
sess.run(init_op)
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
@ -495,9 +495,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
(epoch == epoch_break and f == file_break and
r < record_break)):
continue
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
def testRestoreInModifiedGraph(self):
num_epochs = 10
@ -510,19 +510,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(init_op)
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
self.evaluate(restore_op)
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
if (epoch == epoch_break and f == file_break and
r == record_break):
self.evaluate(save_op)
sess.run(save_op)
break
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
else:
continue
break
@ -531,13 +531,13 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
break
else:
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs_1)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
@ -546,9 +546,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
(epoch == epoch_break and f == file_break and
r < record_break)):
continue
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
def testRestoreWithoutBuildingDatasetGraph(self):
num_epochs = 10
@ -560,19 +560,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(init_op)
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
self.evaluate(restore_op)
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
if (epoch == epoch_break and f == file_break and
r == record_break):
self.evaluate(save_op)
sess.run(save_op)
break
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
else:
continue
break
@ -581,12 +581,12 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
break
else:
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
with ops.Graph().as_default() as g:
restore_op, get_next_op = self._restore_iterator()
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
@ -595,9 +595,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
(epoch == epoch_break and f == file_break and
r < record_break)):
continue
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
def testRestoreUnusedIterator(self):
num_epochs = 10
@ -605,22 +605,22 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(init_op)
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
self.evaluate(restore_op)
sess.run(restore_op)
# Save unused iterator.
self.evaluate(save_op)
sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
for _ in range(num_epochs * self._num_files * self._num_records):
self.evaluate(get_next_op)
sess.run(get_next_op)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
def testRestoreExhaustedIterator(self):
num_epochs = 10
@ -629,26 +629,26 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(init_op)
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
self.evaluate(restore_op)
sess.run(restore_op)
for _ in range(num_epochs):
for f in range(self._num_files):
for r in range(self._num_records):
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
self.assertEqual(self._record(f, r), sess.run(get_next_op))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
self.evaluate(save_op)
sess.run(get_next_op)
sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
with self.session(graph=g) as sess:
self.evaluate(restore_op)
sess.run(restore_op)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next_op)
sess.run(get_next_op)
class TFRecordDatasetTest(test_base.DatasetTestBase):
@ -807,7 +807,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
self.assertAllEqual(self._record(j, i), sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -819,7 +819,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
self.assertAllEqual(self._record(j, i), sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

View File

@ -36,7 +36,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = dataset_ops.Dataset.range(1, i + 1)
result = ds.reduce(np.int64(0), lambda x, y: x + y)
with self.cached_session() as sess:
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
self.assertEqual(((i + 1) * i) // 2, sess.run(result))
def testSumTuple(self):
@ -49,7 +49,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = dataset_ops.Dataset.zip((ds, ds))
result = ds.reduce(np.int64(0), reduce_fn)
with self.cached_session() as sess:
self.assertEqual(((i + 1) * i), self.evaluate(result))
self.assertEqual(((i + 1) * i), sess.run(result))
def testSumAndCount(self):
@ -61,7 +61,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = dataset_ops.Dataset.range(1, i + 1)
result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
with self.cached_session() as sess:
s, c = self.evaluate(result)
s, c = sess.run(result)
self.assertEqual(((i + 1) * i) // 2, s)
self.assertEqual(i, c)
@ -93,8 +93,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
result = ds.reduce(make_sparse_fn(0), reduce_fn)
with self.cached_session() as sess:
self.assertSparseValuesEqual(
make_sparse_fn(i + 1), self.evaluate(result))
self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result))
def testNested(self):
@ -117,7 +116,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
result = ds.reduce(map_fn(0), reduce_fn)
with self.cached_session() as sess:
result = self.evaluate(result)
result = sess.run(result)
self.assertEqual(((i + 1) * i) // 2, result["dense"])
self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])

View File

@ -49,7 +49,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
# Test a finite repetition.
sess.run(init_op, feed_dict={count_placeholder: 3})
for _ in range(3):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
@ -59,7 +59,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
# Test a different finite repetition.
sess.run(init_op, feed_dict={count_placeholder: 7})
for _ in range(7):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -75,7 +75,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
# actually is infinite.
sess.run(init_op, feed_dict={count_placeholder: -1})
for _ in range(17):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
@ -95,7 +95,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
# Take fewer than input size
sess.run(init_op, feed_dict={count_placeholder: 4})
for i in range(4):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
@ -104,7 +104,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
# Take more than input size
sess.run(init_op, feed_dict={count_placeholder: 25})
for i in range(10):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
@ -113,7 +113,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
# Take all of input
sess.run(init_op, feed_dict={count_placeholder: -1})
for i in range(10):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
@ -142,7 +142,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
# the first 4 elements and then read the rest.
sess.run(init_op, feed_dict={count_placeholder: 4})
for i in range(4, 10):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -165,7 +165,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
# Skip nothing
sess.run(init_op, feed_dict={count_placeholder: 0})
for i in range(0, 10):
results = self.evaluate(get_next)
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -187,7 +187,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
for _ in range(7 * 14):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
@ -201,7 +201,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -66,7 +66,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
# First run without shuffling to collect the "ground truth".
self.evaluate(init_fifo_op)
sess.run(init_fifo_op)
unshuffled_elements = []
for _ in range(20):
unshuffled_elements.append(sess.run(get_next))
@ -159,7 +159,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
for elem in elems:
self.assertEqual(elem, self.evaluate(get_next))
self.assertEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -188,9 +188,9 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
next_element = iterator.get_next()
with self.cached_session() as sess:
initial_permutation = self.evaluate(next_element)
self.assertAllEqual(initial_permutation, self.evaluate(next_element))
self.assertAllEqual(initial_permutation, self.evaluate(next_element))
initial_permutation = sess.run(next_element)
self.assertAllEqual(initial_permutation, sess.run(next_element))
self.assertAllEqual(initial_permutation, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -261,7 +261,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.session(graph=g) as sess:
for iterator in iterators:
if initializable:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
next_element = iterator.get_next()
run_results = []
for _ in range(300):

View File

@ -102,7 +102,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
num_full_batches = max(
0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
for i in range(num_full_batches):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(size):
self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
@ -111,7 +111,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
num_partial_batches = (count * 7) // shift + (
(count * 7) % shift > 0) - num_full_batches
for i in range(num_partial_batches):
result = self.evaluate(get_next)
result = sess.run(get_next)
for component, result_component in zip(components, result):
remaining = (count * 7) - ((num_full_batches + i) * shift)
num_elements = remaining // stride + ((remaining % stride) > 0)
@ -164,10 +164,10 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
@ -193,10 +193,10 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
actual = self.evaluate(get_next)
actual = sess.run(get_next)
expected_indices = []
expected_values = []
for j in range(5):
@ -227,9 +227,9 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(init_op)
sess.run(init_op)
# Slide: 1st batch.
actual = self.evaluate(get_next)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
@ -239,7 +239,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
# Slide: 2nd batch.
actual = self.evaluate(get_next)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
@ -265,7 +265,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
next_element = iterator.get_next()
with self.cached_session() as sess:
self.evaluate(iterator.initializer)
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r"Cannot batch tensors with different shapes in component 0. "
@ -281,8 +281,8 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
self.assertAllEqual(np.float32([1., 2.]), self.evaluate(get_next))
self.assertAllEqual(np.float32([2., 3.]), self.evaluate(get_next))
self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next))
self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -55,7 +55,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
component_placeholders, equal_length_components)})
for i in range(4):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(
equal_length_components, results):
self.assertAllEqual(component[i], result_component)
@ -66,7 +66,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
component_placeholders, variable_length_components)})
for i in range(2):
results = self.evaluate(get_next)
results = sess.run(get_next)
for component, result_component in zip(
variable_length_components, results):
self.assertAllEqual(component[i], result_component)
@ -103,7 +103,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
component_placeholders, equal_length_components)})
for i in range(4):
result1, (result2, result3) = self.evaluate(get_next)
result1, (result2, result3) = sess.run(get_next)
self.assertAllEqual(equal_length_components[0][i], result1)
self.assertAllEqual(equal_length_components[1][i], result2)
self.assertAllEqual(equal_length_components[2][i], result3)

View File

@ -31,24 +31,24 @@ class ConvertTest(test.TestCase):
def testInteger(self):
resp = convert.optional_param_to_tensor("foo", 3)
with self.cached_session() as sess:
self.assertEqual(3, self.evaluate(resp))
self.assertEqual(3, sess.run(resp))
def testIntegerDefault(self):
resp = convert.optional_param_to_tensor("foo", None)
with self.cached_session() as sess:
self.assertEqual(0, self.evaluate(resp))
self.assertEqual(0, sess.run(resp))
def testStringDefault(self):
resp = convert.optional_param_to_tensor("bar", None, "default",
dtypes.string)
with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("default"), self.evaluate(resp))
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
def testString(self):
resp = convert.optional_param_to_tensor("bar", "value", "default",
dtypes.string)
with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("value"), self.evaluate(resp))
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
def testPartialShapeToTensorKnownDimension(self):
with self.cached_session() as sess:

View File

@ -1583,7 +1583,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
x = variables.VariableV1([1, 3, 3, 7], name="x")
_, idx = array_ops.unique(x, name="x_unique")
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
self.evaluate(x.initializer)
sess.run(x.initializer)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(

View File

@ -126,8 +126,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
u = variables.Variable([12.0], name="u")
v = variables.Variable([30.0], name="v")
w = math_ops.add(u, v, name="w")
self.evaluate(u.initializer)
self.evaluate(v.initializer)
sess.run(u.initializer)
sess.run(v.initializer)
self._compareOriginalAndReconstructedGraphDefs(
sess, w, expected_output=[42.0])
@ -139,7 +139,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
b = math_ops.add(a, a, name="b")
with ops.control_dependencies([a, b]):
c = math_ops.multiply(b, b, name="c")
self.evaluate(a.initializer)
sess.run(a.initializer)
self._compareOriginalAndReconstructedGraphDefs(
sess, c, expected_output=400.0)
@ -150,8 +150,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
y = variables.Variable(20.0, name="y")
cond = control_flow_ops.cond(
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
self.evaluate(x.initializer)
self.evaluate(y.initializer)
sess.run(x.initializer)
sess.run(y.initializer)
self._compareOriginalAndReconstructedGraphDefs(
sess, cond, expected_output=21.0)
@ -173,8 +173,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
toy_loss = x * (u - v)
train_op = gradient_descent.GradientDescentOptimizer(
learning_rate=0.1).minimize(toy_loss, name="train_op")
self.evaluate(u.initializer)
self.evaluate(v.initializer)
sess.run(u.initializer)
sess.run(v.initializer)
self._compareOriginalAndReconstructedGraphDefs(sess, train_op)

View File

@ -131,8 +131,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
with session.Session(
config=self.session_config, graph=graph,
target=self.server_target) as sess:
self.evaluate(self.a.initializer)
self.evaluate(self.b.initializer)
sess.run(self.a.initializer)
sess.run(self.b.initializer)
run_options = config_pb2.RunOptions()
debug_utils.watch_graph(
@ -198,8 +198,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
with session.Session(
config=self.session_config, graph=graph,
target=self.server_target) as sess:
self.evaluate(self.a.initializer)
self.evaluate(self.b.initializer)
sess.run(self.a.initializer)
sess.run(self.b.initializer)
def watch_fn(feeds, fetch_keys):
del feeds, fetch_keys

View File

@ -67,7 +67,7 @@ class SessionDebugMultiGPUTest(test_util.TensorFlowTestCase):
u1 = math_ops.multiply(v, v, name="u1")
w = math_ops.subtract(u1, u0, name="w")
self.evaluate(v.initializer)
sess.run(v.initializer)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(run_options, sess.graph,

View File

@ -109,8 +109,8 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
self.w = math_ops.matmul(self.u, self.v, name="w")
self.w_line_number = line_number_above()
self.evaluate(self.u.initializer)
self.evaluate(self.v.initializer)
sess.run(self.u.initializer)
sess.run(self.v.initializer)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(

View File

@ -92,7 +92,7 @@ class AutoShardDatasetTest(test.TestCase):
with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@ -205,11 +205,10 @@ class AutoShardDatasetTest(test.TestCase):
with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(self._record(r, f), self.evaluate(next_element))
self.assertAllEqual(self._record(r, f), sess.run(next_element))
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(
self._text_line(r, f), self.evaluate(next_element))
self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

View File

@ -149,9 +149,9 @@ class DefFunctionTest(test.TestCase):
result = fn(3.0)
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
self.assertAllEqual(sess.run(state[0]), 2.0)
self.assertAllEqual(self.evaluate(result), 6.0)
self.assertAllEqual(sess.run(result), 6.0)
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
with ops.Graph().as_default(), self.test_session() as sess:
@ -168,9 +168,9 @@ class DefFunctionTest(test.TestCase):
result = fn(3.0)
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
self.assertAllEqual(sess.run(state[0]), 6.0)
self.assertAllEqual(self.evaluate(result), 18.0)
self.assertAllEqual(sess.run(result), 18.0)
def testLegacyGraphModeInputDependentInitializerFails(self):
with ops.Graph().as_default():

View File

@ -78,7 +78,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
c = constant_op.constant([[2.]])
f_c = f(c)
g, = gradients_impl.gradients(f_c, c)
self.assertAllEqual(self.evaluate(g).values, [[1.0]])
self.assertAllEqual(sess.run(g).values, [[1.0]])
def testNoSymGradNestedDefun(self):

View File

@ -564,7 +564,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
variables.global_variables_initializer().run()
call = def_function.function(o.call)
op = call()
self.assertAllEqual(self.evaluate(op), 2.0)
self.assertAllEqual(sess.run(op), 2.0)
def testGraphModeManyFunctions(self):
with ops.Graph().as_default(), self.cached_session():
@ -1732,7 +1732,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
function.register(cpu_boost, x)
y = gpu_boost(x)
y_value = self.evaluate(y)
y_value = sess.run(y)
if test.is_gpu_available():
self.assertEqual(y_value, 5.0)

View File

@ -1027,7 +1027,7 @@ class CrossedColumnTest(test.TestCase):
outputs = _transform_features(features, [price_cross_wire])
output = outputs[price_cross_wire]
with self.cached_session() as sess:
output_val = self.evaluate(output)
output_val = sess.run(output)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
for val in output_val.values:
@ -1886,8 +1886,7 @@ class LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
self.evaluate(net))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
price = fc._numeric_column('price')
@ -2526,8 +2525,7 @@ class _LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
self.evaluate(net))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
price = fc._numeric_column('price')

View File

@ -1188,7 +1188,7 @@ class CrossedColumnTest(test.TestCase):
outputs = fc._transform_features_v2(features, [price_cross_wire], None)
output = outputs[price_cross_wire]
with self.cached_session() as sess:
output_val = self.evaluate(output)
output_val = sess.run(output)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
for val in output_val.values:
@ -2088,8 +2088,7 @@ class LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]],
self.evaluate(net))
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
coord.request_stop()
coord.join(threads)
@ -2125,8 +2124,7 @@ class LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
self.evaluate(net))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
price = fc.numeric_column('price')
@ -2845,8 +2843,7 @@ class OldLinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
self.evaluate(net))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
price = fc.numeric_column('price')

View File

@ -102,7 +102,7 @@ class FunctionTest(test.TestCase):
call = MyIdentityFunc([18.0])
self.assertEqual("MyIdentity", call.op.name)
with session.Session() as sess:
self.assertAllEqual([18.0], self.evaluate(call))
self.assertAllEqual([18.0], sess.run(call))
def testIdentityImplicitDeref(self):
@ -116,8 +116,8 @@ class FunctionTest(test.TestCase):
self.assertEqual("MyIdentity", call.op.name)
for cfg in _OptimizerOptions():
with session.Session(config=cfg) as sess:
self.evaluate(var.initializer)
self.assertAllEqual([18.0], self.evaluate(call))
sess.run(var.initializer)
self.assertAllEqual([18.0], sess.run(call))
def testIdentityOutputName(self):
@ -130,7 +130,7 @@ class FunctionTest(test.TestCase):
call = MyIdentityFunc([18.0])
self.assertEqual("MyIdentity", call.op.name)
with session.Session() as sess:
self.assertAllEqual([18.0], self.evaluate(call))
self.assertAllEqual([18.0], sess.run(call))
def testTooManyOutputNames(self):
@ -158,7 +158,7 @@ class FunctionTest(test.TestCase):
call = APlus2B([1.0], [2.0])
self.assertEqual("APlus2B", call.op.name)
with session.Session() as sess:
self.assertAllEqual([5.0], self.evaluate(call))
self.assertAllEqual([5.0], sess.run(call))
def testFunctionWithNoOutput(self):
@ -187,7 +187,7 @@ class FunctionTest(test.TestCase):
call = APlus2B([1.0], [2.0])
self.assertEqual("APlus2B", call.op.name)
with session.Session() as sess:
self.assertAllEqual([5.0], self.evaluate(call))
self.assertAllEqual([5.0], sess.run(call))
def testDefineFunctionDuplicateOutputs(self):
@ -224,8 +224,8 @@ class FunctionTest(test.TestCase):
call_g = XSquarePlusOneGrad([2.0], [0.1])
with session.Session() as sess:
self.assertAllClose([5.0], self.evaluate(call_f))
self.assertAllClose([0.4], self.evaluate(call_g))
self.assertAllClose([5.0], sess.run(call_f))
self.assertAllClose([0.4], sess.run(call_g))
def testTanhSymGrad(self):
@ -387,7 +387,7 @@ class FunctionTest(test.TestCase):
call = AConstant()
self.assertEqual("AConstant", call.op.name)
with session.Session() as sess:
self.assertAllEqual([42], self.evaluate(call))
self.assertAllEqual([42], sess.run(call))
def testDefineFunctionNames(self):
@ -468,7 +468,7 @@ class FunctionTest(test.TestCase):
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
ans = self.evaluate(loop)
ans = sess.run(loop)
self.assertAllClose(ans, 131072.)
def testControlFlowStrictness(self):
@ -650,8 +650,8 @@ class FunctionTest(test.TestCase):
# pylint: enable=unexpected-keyword-arg
self.assertEqual("next", call2.op.name)
with session.Session() as sess:
self.assertAllEqual([1], self.evaluate(call1))
self.assertAllEqual([0], self.evaluate(call2))
self.assertAllEqual([1], sess.run(call1))
self.assertAllEqual([0], sess.run(call2))
def testNestedFunction(self):
@ -794,7 +794,7 @@ class FunctionTest(test.TestCase):
y = Foo()
with self.session(graph=g) as sess:
self.assertEqual(self.evaluate(y), 10)
self.assertEqual(sess.run(y), 10)
def testCaptureInCond(self):
g = ops.Graph()
@ -809,8 +809,8 @@ class FunctionTest(test.TestCase):
z = Foo(False)
with self.session(graph=g) as sess:
self.assertEqual(self.evaluate(y), 1)
self.assertEqual(self.evaluate(z), 2)
self.assertEqual(sess.run(y), 1)
self.assertEqual(sess.run(z), 2)
def testStableName(self):
@ -900,7 +900,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(global_vars[0].name, "linear/w:0")
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
output_val = sess.run(
output_op, feed_dict={input_op: np.random.rand(32, 100)})
self.assertEqual(output_val.shape, (32, 100))
@ -928,7 +928,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(global_vars[0].name, "vs1/var:0")
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
out1, out2 = sess.run(
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
self.assertAllEqual(out1, np.linspace(2, 11, 10))
@ -991,8 +991,8 @@ class FunctionTest(test.TestCase):
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
with session.Session() as sess:
self.assertEqual(4.0, self.evaluate(result_1))
self.assertEqual(100, self.evaluate(result_2))
self.assertEqual(4.0, sess.run(result_1))
self.assertEqual(100, sess.run(result_2))
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
def testStatefulFunction(self):
@ -1052,8 +1052,8 @@ class FunctionTest(test.TestCase):
for config in _OptimizerOptions():
config.device_count["CPU"] = 2
with session.Session(config=config) as sess:
self.assertEqual(42.0, self.evaluate(f_0))
self.assertEqual(44.0, self.evaluate(f_1))
self.assertEqual(42.0, sess.run(f_0))
self.assertEqual(44.0, sess.run(f_1))
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
def testGuaranteedConstsAreCaptured(self):
@ -1076,7 +1076,7 @@ class FunctionTest(test.TestCase):
return output
with self.session(use_gpu=False) as sess:
self.evaluate(var.initializer)
sess.run(var.initializer)
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
def testSameFunctionDifferentGrads(self):
@ -1651,8 +1651,8 @@ class ModuleFunctionTest(test.TestCase):
y = LinearWithCApi(a, b, c)
z = Linear2WithCApi(a, b, c, d, e)
with session.Session() as sess:
self.assertAllEqual([[1]], self.evaluate(y))
self.assertAllEqual([[5]], self.evaluate(z))
self.assertAllEqual([[1]], sess.run(y))
self.assertAllEqual([[5]], sess.run(z))
class VariableHoistingTest(test.TestCase):
@ -1704,7 +1704,7 @@ class VariableHoistingTest(test.TestCase):
self.assertEqual("Foo/b", b.op.name)
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
self.assertAllEqual(w.shape, (64, 64))

View File

@ -211,7 +211,7 @@ class DeviceFunctionsTest(test.TestCase):
with session.Session() as sess:
init = variables.variables_initializer([variable_node])
sess.run(init)
output = self.evaluate(output_node)
output = sess.run(output_node)
self.assertNear(4.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
@ -242,8 +242,8 @@ class DeviceFunctionsTest(test.TestCase):
output_node = math_ops_lib.multiply(
variable_node, 2.0, name="output_node")
with session.Session() as sess:
self.evaluate(variable_node.initializer)
output = self.evaluate(output_node)
sess.run(variable_node.initializer)
output = sess.run(output_node)
self.assertNear(2.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
# First get the constant_graph_def when variable_names_whitelist is
@ -256,7 +256,7 @@ class DeviceFunctionsTest(test.TestCase):
# Then initialize the unused variable, and get another
# constant_graph_def when variable_names_whitelist is not set.
self.evaluate(another_variable.initializer)
sess.run(another_variable.initializer)
constant_graph_def_without_variable_whitelist = (
graph_util.convert_variables_to_constants(
sess, variable_graph_def, ["output_node"]))
@ -295,7 +295,7 @@ class DeviceFunctionsTest(test.TestCase):
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
with session.Session() as sess:
output_node = sess.graph.get_tensor_by_name("output_node:0")
output = self.evaluate(output_node)
output = sess.run(output_node)
self.assertNear(2.0, output, 0.00001)
def create_node_def(self, op, name, inputs):

View File

@ -398,10 +398,10 @@ class ImportGraphDefTest(test.TestCase):
# TODO(b/76173421): make this work (currently DCHECKS)
# with self.cached_session() as sess:
# sess.run(imported_init)
# self.assertEqual(self.evaluate(imported_var), 1.0)
# self.assertEqual(self.evaluate(imported_assign), 2.0)
# self.assertEqual(list(self.evaluate(imported_shape)), [])
# self.assertEqual(list(self.evaluate(new_var_shape)), [])
# self.assertEqual(sess.run(imported_var), 1.0)
# self.assertEqual(sess.run(imported_assign), 2.0)
# self.assertEqual(list(sess.run(imported_shape)), [])
# self.assertEqual(list(sess.run(new_var_shape)), [])
def testWhileLoop(self):
# Produce GraphDef containing while loop.
@ -418,7 +418,7 @@ class ImportGraphDefTest(test.TestCase):
return_elements=[r.name])
self.assertEqual(imported_r.name, "import/" + r.name)
with self.cached_session() as sess:
self.assertEqual(self.evaluate(imported_r), 10)
self.assertEqual(sess.run(imported_r), 10)
def testImportWhileLoopInCond(self):
# Produce GraphDef containing while loop.
@ -458,7 +458,7 @@ class ImportGraphDefTest(test.TestCase):
lambda i: i < 2, ImportFn, [0],
shape_invariants=[tensor_shape.TensorShape(None)])
with self.cached_session() as sess:
self.assertEqual(self.evaluate(out), 10)
self.assertEqual(sess.run(out), 10)
def testTypeMismatchInGraphDef(self):
# TODO(skyewm): improve error message

View File

@ -492,8 +492,8 @@ class ScopedMetaGraphTest(test.TestCase):
init_op = variables.global_variables_initializer()
grad = gradients_impl.gradients([output], [var])
with session.Session() as sess:
self.evaluate(init_op)
expected_grad_value = self.evaluate(grad)
sess.run(init_op)
expected_grad_value = sess.run(grad)
# Restore the MetaGraphDef into a new Graph with an import scope.
with ops.Graph().as_default():
@ -518,8 +518,8 @@ class ScopedMetaGraphTest(test.TestCase):
init_op = variables.global_variables_initializer()
with session.Session() as sess:
self.evaluate(init_op)
actual_grad_value = self.evaluate(grad)
sess.run(init_op)
actual_grad_value = sess.run(grad)
self.assertEqual(expected_grad_value, actual_grad_value)
def testImportWhileLoopInWhileLoop(self):
@ -544,7 +544,7 @@ class ScopedMetaGraphTest(test.TestCase):
_, x = control_flow_ops.while_loop(lambda i, x: i < 2, body, [0, 0.0],
name="")
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
sess.run(variables.global_variables_initializer())
sess.run(x)
def testScopedImportUnderNameScope(self):
@ -869,7 +869,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
initializer = variables.local_variables_initializer()
sess.run(initializer)
self.evaluate(update_op)
sess.run(update_op)
meta_graph.export_scoped_meta_graph(
filename=meta_graph_filename, graph=graph)

View File

@ -517,21 +517,21 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEquals(x.consumers(), [])
self.assertEquals(y.consumers(), [z.op, z.op])
with session.Session(graph=g) as sess:
self.assertEquals(self.evaluate(z), 4)
self.assertEquals(sess.run(z), 4)
z.op._update_input(0, x) # pylint: disable=protected-access
self.assertEquals(list(z.op.inputs), [x, y])
self.assertEquals(x.consumers(), [z.op])
self.assertEquals(y.consumers(), [z.op])
with session.Session(graph=g) as sess:
self.assertEquals(self.evaluate(z), 3)
self.assertEquals(sess.run(z), 3)
z.op._update_input(1, y) # pylint: disable=protected-access
self.assertEquals(list(z.op.inputs), [x, y])
self.assertEquals(x.consumers(), [z.op])
self.assertEquals(y.consumers(), [z.op])
with session.Session(graph=g) as sess:
self.assertEquals(self.evaluate(z), 3)
self.assertEquals(sess.run(z), 3)
def testUpdateInputGraphError(self):
g_0 = ops.Graph()

View File

@ -109,8 +109,8 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
exclusive=True)
with session.Session() as sess:
# No feed_dict necessary
self.assertEqual(self.evaluate(y), 1)
self.assertEqual(self.evaluate(z), 1)
self.assertEqual(sess.run(y), 1)
self.assertEqual(sess.run(z), 1)
def testFalse(self):
conditions = [(False, raise_exception)]
@ -121,8 +121,8 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
default=lambda: constant_op.constant(1),
exclusive=True)
with session.Session() as sess:
self.assertEqual(self.evaluate(y), 1)
self.assertEqual(self.evaluate(z), 1)
self.assertEqual(sess.run(y), 1)
self.assertEqual(sess.run(z), 1)
def testMix(self):
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])

Some files were not shown because too many files have changed in this diff Show More