Replace a few calls of Session run
with evaluate
In order to support tests running in eager mode we need to avoid unnecessary use of Sessions in tests. This moves to remove some of the uses of the `run` function in favor of `evaluate`. PiperOrigin-RevId: 222013881
This commit is contained in:
parent
1aaa68d93c
commit
1fdd7c7408
@ -60,7 +60,7 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
random_seed.set_random_seed(1618)
|
||||
op = random_ops.multinomial(logits, num_samples,
|
||||
output_dtype=dtypes.int32)
|
||||
d = sess.run(op)
|
||||
d = self.evaluate(op)
|
||||
|
||||
batch_size, num_classes = logits.shape
|
||||
freqs_mat = []
|
||||
@ -85,9 +85,9 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
|
||||
# The random-number generator, if working correctly, should produce the
|
||||
# same output multiple times with low probability.
|
||||
y = sess.run(x)
|
||||
z = sess.run(x)
|
||||
w = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
z = self.evaluate(x)
|
||||
w = self.evaluate(x)
|
||||
|
||||
# We use exact equality here. If the random-number generator is producing
|
||||
# deterministic output, all three outputs will be bitwise identical.
|
||||
@ -112,7 +112,7 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
x = random_ops.multinomial(
|
||||
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
|
||||
output_dtype=output_dtype)
|
||||
y = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
self.assertTrue((y >= 0).sum() == 1000)
|
||||
self.assertTrue((y < 20).sum() == 1000)
|
||||
|
||||
|
@ -337,7 +337,7 @@ class ConcatOffsetTest(xla_test.XLATestCase):
|
||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
|
||||
ans = sess.run(off)
|
||||
ans = self.evaluate(off)
|
||||
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
|
||||
|
||||
|
||||
@ -350,7 +350,7 @@ class PackTest(xla_test.XLATestCase):
|
||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
ans = self.evaluate(packed)
|
||||
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
|
||||
|
||||
def testScalars(self):
|
||||
@ -360,7 +360,7 @@ class PackTest(xla_test.XLATestCase):
|
||||
s1 = constant_op.constant(3, dtypes.int32)
|
||||
s2 = constant_op.constant(5, dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
ans = self.evaluate(packed)
|
||||
self.assertAllEqual(ans, [2, 3, 5])
|
||||
|
||||
def testEmpty(self):
|
||||
@ -370,7 +370,7 @@ class PackTest(xla_test.XLATestCase):
|
||||
s1 = constant_op.constant([[]], dtypes.int32)
|
||||
s2 = constant_op.constant([[]], dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
ans = self.evaluate(packed)
|
||||
self.assertAllEqual(ans, [[[]], [[]], [[]]])
|
||||
|
||||
|
||||
|
@ -106,7 +106,7 @@ class EagerTest(xla_test.XLATestCase):
|
||||
three = constant_op.constant(3)
|
||||
five = constant_op.constant(5)
|
||||
product = three * five
|
||||
self.assertAllEqual(15, sess.run(product))
|
||||
self.assertAllEqual(15, self.evaluate(product))
|
||||
|
||||
def testDegenerateSlices(self):
|
||||
with self.test_scope():
|
||||
|
@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_f = Foo(a, b)
|
||||
result = sess.run(call_f)
|
||||
result = self.evaluate(call_f)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testNestedFunctions(self):
|
||||
@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_g = Foo(a, b)
|
||||
result = sess.run(call_g)
|
||||
result = self.evaluate(call_g)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testFunctionMultipleRetvals(self):
|
||||
@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_f = Foo(a, b)
|
||||
result = sess.run(call_f)
|
||||
result = self.evaluate(call_f)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testCompileTimeConstantsInDefun(self):
|
||||
|
@ -88,7 +88,7 @@ class LSTMTest(test.TestCase):
|
||||
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
|
||||
|
||||
# Initialize variables and run the unrolled LSTM step.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
return sess.run([m, c])
|
||||
|
||||
def testLSTMCell(self):
|
||||
@ -173,7 +173,7 @@ class LSTMTest(test.TestCase):
|
||||
(basename, m_init_scalar, c_init_scalar, pad_scalar))
|
||||
|
||||
# Initialize variables and run the unrolled LSTM layer.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
return sess.run(out_seq)
|
||||
|
||||
def testLSTMLayer(self):
|
||||
|
@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase):
|
||||
ph = array_ops.placeholder_with_default(v, shape=[])
|
||||
out = ph * 2
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
self.assertEqual(8.0, sess.run(out))
|
||||
self.assertEqual(8.0, self.evaluate(out))
|
||||
|
||||
def test_placeholder_with_default_fed(self):
|
||||
with self.cached_session() as sess, self.test_scope():
|
||||
|
@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
|
||||
# The random-number generator, if working correctly, should produce the
|
||||
# same output multiple times with low probability.
|
||||
y = sess.run(x)
|
||||
z = sess.run(x)
|
||||
w = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
z = self.evaluate(x)
|
||||
w = self.evaluate(x)
|
||||
|
||||
# We use exact equality here. If the random-number generator is producing
|
||||
# deterministic output, all three outputs will be bitwise identical.
|
||||
@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = random_ops.random_uniform(
|
||||
shape=[1000], dtype=dtype, minval=-2, maxval=33)
|
||||
y = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
self.assertTrue((y >= -2).sum() == 1000)
|
||||
self.assertTrue((y < 33).sum() == 1000)
|
||||
|
||||
@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.cached_session() as sess:
|
||||
with self.test_scope():
|
||||
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
||||
y = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
|
||||
def normal_cdf(x):
|
||||
return .5 * math.erfc(-x / math.sqrt(2))
|
||||
@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = math_ops.range(1 << 16)
|
||||
shuffle = random_ops.random_shuffle(x)
|
||||
result = sess.run(shuffle)
|
||||
result = self.evaluate(shuffle)
|
||||
expected = range(1 << 16)
|
||||
# Compare sets to avoid randomness behavior changes but make sure still
|
||||
# have all the values.
|
||||
@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = array_ops.diag(math_ops.range(20))
|
||||
shuffle = random_ops.random_shuffle(x)
|
||||
result = sess.run(shuffle)
|
||||
result = self.evaluate(shuffle)
|
||||
expected = np.diag(range(20)).flatten()
|
||||
# Compare sets to avoid randomness behavior changes but make sure still
|
||||
# have all the values.
|
||||
|
@ -505,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase):
|
||||
[-0.5, 1.5], # read(0) gradient
|
||||
[20.0, 30.0, 40.0, 50.0], # concat gradient
|
||||
])
|
||||
grad_vals = sess.run(grad_r) # 2 + 2 entries
|
||||
grad_vals = self.evaluate(grad_r) # 2 + 2 entries
|
||||
|
||||
self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
|
||||
self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
|
||||
|
@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_add(
|
||||
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(sess.run(read), [[3], [7]])
|
||||
self.assertAllEqual(self.evaluate(read), [[3], [7]])
|
||||
|
||||
def testScatterSub(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_sub(
|
||||
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(sess.run(read), [[4], [-1]])
|
||||
self.assertAllEqual(self.evaluate(read), [[4], [-1]])
|
||||
|
||||
def testScatterMul(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_mul(
|
||||
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[5]])
|
||||
self.assertEqual(self.evaluate(read), [[5]])
|
||||
|
||||
def testScatterDiv(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_div(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(sess.run(read), [[2]])
|
||||
self.assertAllEqual(self.evaluate(read), [[2]])
|
||||
|
||||
def testScatterMin(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_min(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
def testScatterMax(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_max(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[6]])
|
||||
self.assertEqual(self.evaluate(read), [[6]])
|
||||
|
||||
def testScatterUpdate(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_update(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
def testScatterAddScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_add(
|
||||
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
def testScatterSubScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_sub(
|
||||
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[-1]])
|
||||
self.assertEqual(self.evaluate(read), [[-1]])
|
||||
|
||||
def testScatterMulScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_mul(
|
||||
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[5]])
|
||||
self.assertEqual(self.evaluate(read), [[5]])
|
||||
|
||||
def testScatterDivScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_div(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[2]])
|
||||
self.assertEqual(self.evaluate(read), [[2]])
|
||||
|
||||
def testScatterMinScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_min(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
def testScatterMaxScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_max(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[6]])
|
||||
self.assertEqual(self.evaluate(read), [[6]])
|
||||
|
||||
def testScatterNdAddOps(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
|
||||
read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.float32)
|
||||
self.assertAllClose(expected, sess.run(read))
|
||||
self.assertAllClose(expected, self.evaluate(read))
|
||||
|
||||
def testScatterNdUpdateAddOps(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
|
||||
read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.float32)
|
||||
self.assertAllClose(expected, sess.run(read))
|
||||
self.assertAllClose(expected, self.evaluate(read))
|
||||
|
||||
|
||||
class StridedSliceAssignChecker(object):
|
||||
|
@ -96,7 +96,7 @@ class KerasTest(tf.test.TestCase):
|
||||
sess.run(init)
|
||||
sample_input = tf.random_uniform((1, 10, 10, 1))
|
||||
output = model(sample_input) # pylint: disable=not-callable
|
||||
self.assertEqual(sess.run(output).shape, (1, 3))
|
||||
self.assertEqual(self.evaluate(output).shape, (1, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,7 +34,7 @@ class ListLiteralsTest(tf.test.TestCase):
|
||||
result = converted()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(result), [1, 2, 3])
|
||||
self.assertAllEqual(self.evaluate(result), [1, 2, 3])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -35,7 +35,7 @@ class InputDataTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([32000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -33,7 +33,7 @@ class LabelWavTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([1000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -33,7 +33,7 @@ class WavToFeaturesTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([32000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -113,7 +113,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
with self.compiled(node, ns) as result:
|
||||
with self.cached_session() as sess:
|
||||
result_tensor = result.test_fn(constant_op.constant(1))
|
||||
self.assertEquals(sess.run(result_tensor), 3)
|
||||
self.assertEquals(self.evaluate(result_tensor), 3)
|
||||
|
||||
def test_call_to_decorated_function(self):
|
||||
|
||||
|
@ -68,7 +68,7 @@ class ListTest(converter_testing.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
tl = result.test_fn()
|
||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||
self.assertAllEqual(sess.run(r), [1, 2, 3])
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
|
||||
|
||||
def test_list_pop(self):
|
||||
|
||||
@ -91,8 +91,8 @@ class ListTest(converter_testing.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
ts, tl = result.test_fn()
|
||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||
self.assertAllEqual(sess.run(r), [1, 2])
|
||||
self.assertAllEqual(sess.run(ts), 3)
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||
self.assertAllEqual(self.evaluate(ts), 3)
|
||||
|
||||
def test_double_list_pop(self):
|
||||
|
||||
|
@ -48,12 +48,12 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Add support for this use case.
|
||||
# Right now the variable `a` is not conditioned on the `assign` because
|
||||
# there's no way to add control dependencies to a variable object.
|
||||
self.assertEqual(2, sess.run(v))
|
||||
self.assertEqual(2, self.evaluate(v))
|
||||
|
||||
def test_side_effect_on_used_variable(self):
|
||||
|
||||
@ -69,11 +69,11 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
# Right now it's 3 or 4 based on whether the read is synchronized.
|
||||
self.assertEqual(3, sess.run(v))
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
|
||||
def test_side_effect_on_tensor(self):
|
||||
|
||||
@ -109,10 +109,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign_add) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(4, sess.run(v))
|
||||
self.assertEqual(4, self.evaluate(v))
|
||||
|
||||
def test_multiline_nested_block(self):
|
||||
|
||||
@ -130,10 +130,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(3, sess.run(v))
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
|
||||
def test_multiline_block_unsafe(self):
|
||||
|
||||
@ -153,10 +153,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
state_ops.assign_add) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(4, sess.run(v))
|
||||
self.assertEqual(4, self.evaluate(v))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -49,7 +49,7 @@ class SliceTest(converter_testing.TestCase):
|
||||
tl = list_ops.tensor_list_from_tensor(
|
||||
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
|
||||
y = result.test_fn(tl)
|
||||
self.assertEqual(2, sess.run(y))
|
||||
self.assertEqual(2, self.evaluate(y))
|
||||
|
||||
def test_index_access_multiple_definitions(self):
|
||||
|
||||
|
@ -63,7 +63,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_does_not_recurse(self):
|
||||
|
||||
@ -83,7 +83,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_calls_unconverted_graph(self):
|
||||
|
||||
@ -104,7 +104,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_calls_unconverted_py_func(self):
|
||||
|
||||
@ -130,7 +130,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_calls_decorated(self):
|
||||
|
||||
@ -153,7 +153,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_preserves_argspec(self):
|
||||
|
||||
@ -192,7 +192,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_converted_call_builtin(self):
|
||||
x = api.converted_call(range, None, converter.ConversionOptions(), 3)
|
||||
@ -208,7 +208,7 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
x = api.converted_call(test_fn, None, converter.ConversionOptions(),
|
||||
constant_op.constant(-1))
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_method_explicit_owner(self):
|
||||
# TODO(mdan): Implement.
|
||||
@ -234,7 +234,7 @@ class ApiTest(test.TestCase):
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(tc.test_method, None,
|
||||
converter.ConversionOptions(), tc)
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_method_by_class(self):
|
||||
|
||||
@ -252,7 +252,7 @@ class ApiTest(test.TestCase):
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(TestClass.test_method, None,
|
||||
converter.ConversionOptions(), tc)
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_callable_object(self):
|
||||
|
||||
@ -269,7 +269,7 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(tc, None, converter.ConversionOptions())
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_constructor(self):
|
||||
|
||||
@ -288,7 +288,7 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant(-1))
|
||||
# tc is now a converted object.
|
||||
x = tc.test_method()
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_already_converted(self):
|
||||
|
||||
@ -298,12 +298,12 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
x = api.converted_call(f, None, converter.ConversionOptions(),
|
||||
constant_op.constant(0))
|
||||
self.assertTrue(sess.run(x))
|
||||
self.assertTrue(self.evaluate(x))
|
||||
|
||||
converted_f = api.to_graph(f)
|
||||
x = api.converted_call(converted_f, None, converter.ConversionOptions(),
|
||||
constant_op.constant(0))
|
||||
self.assertTrue(sess.run(x))
|
||||
self.assertTrue(self.evaluate(x))
|
||||
|
||||
def test_converted_call_no_user_code(self):
|
||||
|
||||
@ -334,8 +334,8 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
|
||||
def test_converted_call_whitelisted_method_extra_self(self):
|
||||
|
||||
@ -349,8 +349,8 @@ class ApiTest(test.TestCase):
|
||||
model, constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
|
||||
def test_converted_call_whitelisted_method_via_owner(self):
|
||||
|
||||
@ -364,8 +364,8 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
|
||||
def test_converted_call_lambda(self):
|
||||
|
||||
@ -376,8 +376,8 @@ class ApiTest(test.TestCase):
|
||||
x = api.converted_call(l, None, opts, constant_op.constant(0))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual(True, sess.run(x))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual(True, self.evaluate(x))
|
||||
|
||||
def test_to_graph_basic(self):
|
||||
|
||||
@ -390,7 +390,7 @@ class ApiTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||
|
||||
def test_to_graph_with_defaults(self):
|
||||
|
||||
@ -405,7 +405,7 @@ class ApiTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
x = compiled_fn(constant_op.constant([4, 8]))
|
||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||
|
||||
def test_to_code_basic(self):
|
||||
|
||||
|
@ -36,7 +36,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
python_one = special_functions.match_staging_level(1, 1)
|
||||
with self.cached_session() as sess:
|
||||
self.assertTrue(tensor_util.is_tensor(tensor_one))
|
||||
self.assertAllEqual(sess.run(tensor_one), 1)
|
||||
self.assertAllEqual(self.evaluate(tensor_one), 1)
|
||||
self.assertEqual(python_one, 1)
|
||||
|
||||
def test_tensor_list_empty_list(self):
|
||||
@ -45,21 +45,21 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
|
||||
l = special_functions.tensor_list((),
|
||||
element_dtype=dtypes.int32,
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
|
||||
def test_tensor_list_tensor(self):
|
||||
l = special_functions.tensor_list(
|
||||
constant_op.constant([], dtype=dtypes.int32))
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
|
||||
def test_tensor_list_unsupported_initializer(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown type'):
|
||||
@ -76,7 +76,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
l = special_functions.tensor_list(elements)
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_tensor_list_array_from_elements(self):
|
||||
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
|
||||
@ -84,7 +84,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
||||
sl = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_stack(self):
|
||||
self.assertEqual(special_functions.stack(1, strict=False), 1)
|
||||
|
@ -35,7 +35,7 @@ class ForLoopTest(test.TestCase):
|
||||
body=lambda i, s: (s + i,),
|
||||
init_state=(0,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((10,), sess.run(s))
|
||||
self.assertEqual((10,), self.evaluate(s))
|
||||
|
||||
def test_python(self):
|
||||
s = control_flow.for_stmt(
|
||||
@ -53,7 +53,7 @@ class ForLoopTest(test.TestCase):
|
||||
body=lambda i, s: (s + i,),
|
||||
init_state=(0,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((10,), sess.run(s))
|
||||
self.assertEqual((10,), self.evaluate(s))
|
||||
|
||||
|
||||
class WhileLoopTest(test.TestCase):
|
||||
@ -66,7 +66,7 @@ class WhileLoopTest(test.TestCase):
|
||||
init_state=(0, 0),
|
||||
extra_deps=(n,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((5, 10), sess.run(results))
|
||||
self.assertEqual((5, 10), self.evaluate(results))
|
||||
|
||||
def test_python(self):
|
||||
n = 5
|
||||
@ -90,9 +90,9 @@ class IfStmtTest(test.TestCase):
|
||||
def test_tensor(self):
|
||||
with self.cached_session() as sess:
|
||||
t = self.single_return_if_stmt(constant_op.constant(True))
|
||||
self.assertEqual(1, sess.run(t))
|
||||
self.assertEqual(1, self.evaluate(t))
|
||||
t = self.single_return_if_stmt(constant_op.constant(False))
|
||||
self.assertEqual(-1, sess.run(t))
|
||||
self.assertEqual(-1, self.evaluate(t))
|
||||
|
||||
def test_python(self):
|
||||
self.assertEqual(1, self.single_return_if_stmt(True))
|
||||
@ -101,9 +101,9 @@ class IfStmtTest(test.TestCase):
|
||||
def test_tensor_multiple_returns(self):
|
||||
with self.cached_session() as sess:
|
||||
t = self.multi_return_if_stmt(constant_op.constant(True))
|
||||
self.assertAllEqual([1, 2], sess.run(t))
|
||||
self.assertAllEqual([1, 2], self.evaluate(t))
|
||||
t = self.multi_return_if_stmt(constant_op.constant(False))
|
||||
self.assertAllEqual([-1, -2], sess.run(t))
|
||||
self.assertAllEqual([-1, -2], self.evaluate(t))
|
||||
|
||||
def test_python_multiple_returns(self):
|
||||
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
|
||||
|
@ -43,7 +43,7 @@ class ListTest(test.TestCase):
|
||||
l = data_structures.tf_tensor_list_new([3, 4, 5])
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_list_new_empty(self):
|
||||
l = data_structures.tf_tensor_list_new([],
|
||||
@ -51,13 +51,13 @@ class ListTest(test.TestCase):
|
||||
element_shape=())
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [])
|
||||
self.assertAllEqual(self.evaluate(t), [])
|
||||
|
||||
def test_tf_tensor_list_new_from_tensor(self):
|
||||
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_list_new_illegal_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -77,7 +77,7 @@ class ListTest(test.TestCase):
|
||||
l = data_structures.tf_tensor_array_new([3, 4, 5])
|
||||
t = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_array_new_illegal_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -102,7 +102,7 @@ class ListTest(test.TestCase):
|
||||
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [[1, 2, 3]])
|
||||
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
|
||||
|
||||
def test_append_tensorarray(self):
|
||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||
@ -131,10 +131,10 @@ class ListTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
l, x = data_structures.list_pop(l, None, opts)
|
||||
self.assertAllEqual(sess.run(x), [3, 4])
|
||||
self.assertAllEqual(self.evaluate(x), [3, 4])
|
||||
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
||||
self.assertAllEqual(sess.run(t), [[1, 2]])
|
||||
self.assertAllEqual(self.evaluate(t), [[1, 2]])
|
||||
|
||||
def test_pop_python(self):
|
||||
l = [1, 2, 3]
|
||||
@ -152,7 +152,7 @@ class ListTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
t = data_structures.list_stack(l, opts)
|
||||
self.assertAllEqual(sess.run(t), sess.run(initial_list))
|
||||
self.assertAllEqual(sess.run(t), self.evaluate(initial_list))
|
||||
|
||||
def test_stack_tensor_list_empty(self):
|
||||
l = list_ops.empty_tensor_list(
|
||||
|
@ -45,11 +45,11 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_and_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.and_(self._tf_true, self._tf_true)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
t = logical.and_(self._tf_true, lambda: True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
t = logical.and_(self._tf_false, lambda: True)
|
||||
self.assertEqual(sess.run(t), False)
|
||||
self.assertEqual(self.evaluate(t), False)
|
||||
# TODO(mdan): Add a test for ops with side effects.
|
||||
|
||||
def test_or_python(self):
|
||||
@ -63,11 +63,11 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_or_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.or_(self._tf_false, self._tf_true)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
t = logical.or_(self._tf_false, lambda: True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
t = logical.or_(self._tf_true, lambda: True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
# TODO(mdan): Add a test for ops with side effects.
|
||||
|
||||
def test_not_python(self):
|
||||
@ -78,7 +78,7 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_not_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.not_(self._tf_false())
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -38,29 +38,29 @@ class PyBuiltinsTest(test.TestCase):
|
||||
self.assertEqual(py_builtins.abs_(-1), 1)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.abs_(constant_op.constant(-1))
|
||||
self.assertEqual(sess.run(t), 1)
|
||||
self.assertEqual(self.evaluate(t), 1)
|
||||
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
|
||||
self.assertAllEqual(sess.run(t), [1, 2, 3])
|
||||
self.assertAllEqual(self.evaluate(t), [1, 2, 3])
|
||||
|
||||
def test_float(self):
|
||||
self.assertEqual(py_builtins.float_(10), 10.0)
|
||||
self.assertEqual(py_builtins.float_('10.0'), 10.0)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
|
||||
self.assertEqual(sess.run(t), 1.0)
|
||||
self.assertEqual(self.evaluate(t), 1.0)
|
||||
st = py_builtins.float_(constant_op.constant('1.0'))
|
||||
self.assertEqual(sess.run(st), 1.0)
|
||||
self.assertEqual(self.evaluate(st), 1.0)
|
||||
|
||||
def test_int(self):
|
||||
self.assertEqual(py_builtins.int_(10.0), 10)
|
||||
self.assertEqual(py_builtins.int_('11', 2), 3)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
|
||||
self.assertEqual(sess.run(t), 1)
|
||||
self.assertEqual(self.evaluate(t), 1)
|
||||
st = py_builtins.int_(constant_op.constant('1'))
|
||||
self.assertEqual(sess.run(st), 1)
|
||||
self.assertEqual(self.evaluate(st), 1)
|
||||
st = py_builtins.int_(constant_op.constant('1'), 10)
|
||||
self.assertEqual(sess.run(st), 1)
|
||||
self.assertEqual(self.evaluate(st), 1)
|
||||
|
||||
def test_int_unsupported_base(self):
|
||||
t = constant_op.constant(1, dtype=dtypes.float64)
|
||||
@ -73,9 +73,9 @@ class PyBuiltinsTest(test.TestCase):
|
||||
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
|
||||
self.assertEqual(t, 3)
|
||||
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
|
||||
self.assertEqual(sess.run(ta), 5)
|
||||
self.assertEqual(self.evaluate(ta), 5)
|
||||
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
|
||||
self.assertEqual(sess.run(tl), 3)
|
||||
self.assertEqual(self.evaluate(tl), 3)
|
||||
|
||||
def test_len_scalar(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -120,18 +120,18 @@ class PyBuiltinsTest(test.TestCase):
|
||||
def test_range_tensor(self):
|
||||
with self.cached_session() as sess:
|
||||
r = py_builtins.range_(constant_op.constant(3))
|
||||
self.assertAllEqual(sess.run(r), [0, 1, 2])
|
||||
self.assertAllEqual(self.evaluate(r), [0, 1, 2])
|
||||
r = py_builtins.range_(1, constant_op.constant(3))
|
||||
self.assertAllEqual(sess.run(r), [1, 2])
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||
r = py_builtins.range_(2, 0, constant_op.constant(-1))
|
||||
self.assertAllEqual(sess.run(r), [2, 1])
|
||||
self.assertAllEqual(self.evaluate(r), [2, 1])
|
||||
|
||||
def test_range_tensor_empty_range(self):
|
||||
with self.session() as sess:
|
||||
r = py_builtins.range_(constant_op.constant(-3))
|
||||
self.assertAllEqual(sess.run(r), [])
|
||||
self.assertAllEqual(self.evaluate(r), [])
|
||||
r = py_builtins.range_(5, constant_op.constant(2))
|
||||
self.assertAllEqual(sess.run(r), [])
|
||||
self.assertAllEqual(self.evaluate(r), [])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,7 +34,7 @@ class SlicesTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
||||
self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
|
||||
self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
|
||||
|
||||
def test_get_item_tensor_list(self):
|
||||
initial_list = constant_op.constant([[1, 2], [3, 4]])
|
||||
@ -44,7 +44,7 @@ class SlicesTest(test.TestCase):
|
||||
l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4])
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4])
|
||||
|
||||
def test_get_item_tensor_string(self):
|
||||
initial_str = constant_op.constant('abcd')
|
||||
@ -52,14 +52,14 @@ class SlicesTest(test.TestCase):
|
||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(t), b'b')
|
||||
self.assertEqual(self.evaluate(t), b'b')
|
||||
|
||||
initial_list_str = constant_op.constant(['abcd', 'bcde'])
|
||||
t = slices.get_item(initial_list_str, 1,
|
||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(t), b'bcde')
|
||||
self.assertEqual(self.evaluate(t), b'bcde')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -32,7 +32,7 @@ class MiscTest(test.TestCase):
|
||||
new_a = alias_tensors(a)
|
||||
self.assertFalse(new_a is a)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(1, sess.run(new_a))
|
||||
self.assertEqual(1, self.evaluate(new_a))
|
||||
|
||||
def test_alias_tensors(self):
|
||||
a = constant(1)
|
||||
@ -47,7 +47,7 @@ class MiscTest(test.TestCase):
|
||||
self.assertTrue(new_s is s)
|
||||
self.assertTrue(new_l is l)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(1, sess.run(new_a))
|
||||
self.assertEqual(1, self.evaluate(new_a))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,13 +34,13 @@ class PyFuncTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(1, constant_op.constant(1), 1))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(
|
||||
test_fn, dtypes.int64,
|
||||
(constant_op.constant(1), 1, constant_op.constant(1)))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
|
||||
def test_wrap_py_func_complex_args(self):
|
||||
|
||||
@ -54,10 +54,10 @@ class PyFuncTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
|
||||
self.assertEqual(35, sess.run(result))
|
||||
self.assertEqual(35, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(constant_op.constant(7), TestClass()))
|
||||
self.assertEqual(35, sess.run(result))
|
||||
self.assertEqual(35, self.evaluate(result))
|
||||
|
||||
def test_wrap_py_func_kwargs(self):
|
||||
|
||||
@ -74,13 +74,13 @@ class PyFuncTest(test.TestCase):
|
||||
'c': 11,
|
||||
'd': TestClass(13)
|
||||
})
|
||||
self.assertEqual(178, sess.run(result))
|
||||
self.assertEqual(178, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(constant_op.constant(7), TestClass(5)), {
|
||||
'c': constant_op.constant(11),
|
||||
'd': TestClass(13)
|
||||
})
|
||||
self.assertEqual(178, sess.run(result))
|
||||
self.assertEqual(178, self.evaluate(result))
|
||||
|
||||
def test_wrap_py_func_dummy_return(self):
|
||||
|
||||
@ -91,11 +91,11 @@ class PyFuncTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
|
||||
self.assertEqual(1, sess.run(result))
|
||||
self.assertEqual(1, self.evaluate(result))
|
||||
self.assertEqual([1], side_counter)
|
||||
result = py_func.wrap_py_func(
|
||||
test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
|
||||
self.assertEqual(1, sess.run(result))
|
||||
self.assertEqual(1, self.evaluate(result))
|
||||
self.assertEqual([2], side_counter)
|
||||
|
||||
|
||||
|
@ -43,13 +43,13 @@ class TensorListTest(test.TestCase):
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(s), [1])
|
||||
self.assertAllEqual(self.evaluate(s), [1])
|
||||
|
||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
s = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(s), [1])
|
||||
self.assertAllEqual(self.evaluate(s), [1])
|
||||
|
||||
l = tl.TensorList(self._shape(()), dtypes.int32)
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
|
@ -62,7 +62,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
const = constant_op.constant(17)
|
||||
sess = session.Session(server1.target, config=config)
|
||||
output = sess.run(const)
|
||||
output = self.evaluate(const)
|
||||
self.assertEqual(17, output)
|
||||
|
||||
def testClusterSpecPropagationWorker2Placement(self):
|
||||
@ -106,7 +106,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
|
||||
const = constant_op.constant(17)
|
||||
sess = session.Session(server1.target, config=config, graph=g)
|
||||
output = sess.run(const)
|
||||
output = self.evaluate(const)
|
||||
self.assertEqual(17, output)
|
||||
|
||||
def testCanonicalDeviceNames(self):
|
||||
@ -208,7 +208,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
with ops.device('/job:worker/task:0/cpu:0'):
|
||||
sum3 = sum1 + sum2
|
||||
sess = session.Session(server1.target, config=config, graph=g)
|
||||
output = sess.run(sum3)
|
||||
output = self.evaluate(sum3)
|
||||
self.assertEqual(40, output)
|
||||
|
||||
def testLegacyDeviceNames(self):
|
||||
|
@ -147,7 +147,7 @@ class TimelineTest(test.TestCase):
|
||||
num2 = variables.Variable(2.0, name='num2')
|
||||
with ops.device('/cpu:2'):
|
||||
result = num1 + num2 + num1 * num2
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(result, options=run_options, run_metadata=run_metadata)
|
||||
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
@ -176,7 +176,7 @@ class TimelineTest(test.TestCase):
|
||||
num2 = variables.Variable(2.0, name='num2')
|
||||
with ops.device('/cpu:2'):
|
||||
result = num1 + num2 + num1 * num2
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(result, options=run_options, run_metadata=run_metadata)
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
step_stats = run_metadata.step_stats
|
||||
|
@ -216,7 +216,7 @@ class VirtualGpuTest(test_util.TensorFlowTestCase):
|
||||
for d in self._util.devices:
|
||||
with ops.device(d):
|
||||
var = variables.Variable(random_ops.random_uniform(mat_shape))
|
||||
sess.run(var.initializer)
|
||||
self.evaluate(var.initializer)
|
||||
data.append(var)
|
||||
s = data[0]
|
||||
for i in range(1, len(data)):
|
||||
|
@ -53,10 +53,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual([[i, j]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)], results.indices)
|
||||
@ -81,10 +81,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual([[i, j, z]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)
|
||||
@ -141,7 +141,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
||||
for i in range(4):
|
||||
self.assertEqual(i, sess.run(next_elem))
|
||||
self.assertEqual(i, self.evaluate(next_elem))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_elem)
|
||||
|
||||
@ -159,7 +159,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i,) * 3, sess.run(op))
|
||||
self.assertEqual((i,) * 3, self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -179,7 +179,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -198,7 +198,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
st_row = sess.run(next_element)
|
||||
st_row = self.evaluate(next_element)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
self.assertEqual([10], st_row.dense_shape)
|
||||
@ -219,7 +219,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
dense_elem, st_row = sess.run(next_element)
|
||||
dense_elem, st_row = self.evaluate(next_element)
|
||||
self.assertEqual(i, dense_elem)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
@ -241,7 +241,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(((i,),) * 3, sess.run(op))
|
||||
self.assertEqual(((i,),) * 3, self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -354,7 +354,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||
num_batches = (28 * 7) // 14
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
||||
@ -369,12 +369,12 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# We expect (num_batches - 1) full-sized batches.
|
||||
num_batches = int(math.ceil((14 * 7) / 8))
|
||||
for i in range(num_batches - 1):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(8):
|
||||
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((14 * 7) % 8):
|
||||
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
||||
@ -408,10 +408,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
if not drop_remainder:
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -423,9 +423,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -439,7 +439,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
got = sess.run(elements)
|
||||
got = self.evaluate(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
@ -459,7 +459,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(4):
|
||||
got = sess.run(elements)
|
||||
got = self.evaluate(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
@ -480,9 +480,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(2):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||
@ -524,7 +524,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"number of elements does not match"):
|
||||
sess.run(get_next)
|
||||
@ -576,7 +576,8 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(threshold // 10):
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)],
|
||||
self.evaluate(get_next))
|
||||
if threshold % 10 != 0:
|
||||
self.assertAllEqual(
|
||||
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
||||
@ -609,7 +610,8 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(10):
|
||||
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
|
||||
self.assertAllEqual([element for _ in range(10)],
|
||||
self.evaluate(get_next))
|
||||
|
||||
|
||||
class UnbatchDatasetBenchmark(test.Benchmark):
|
||||
|
@ -300,7 +300,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
output = sess.run(batch)
|
||||
output = self.evaluate(batch)
|
||||
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
|
||||
tuple(output.values))
|
||||
all_sparse_tensors.add(sprs_tensor)
|
||||
|
@ -57,7 +57,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -82,7 +82,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -108,7 +108,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -134,7 +134,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -160,7 +160,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -186,7 +186,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -217,7 +217,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = sess.run(next_element)
|
||||
actual = self.evaluate(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
@ -251,7 +251,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = sess.run(next_element)
|
||||
actual = self.evaluate(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
@ -271,9 +271,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -290,9 +290,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -323,9 +323,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
x, y, z = sess.run(next_element)
|
||||
x, y, z = self.evaluate(next_element)
|
||||
self.assertEqual(i**2, x)
|
||||
self.assertEqual(float(i**2), y)
|
||||
self.assertEqual(util_compat.as_bytes(str(i)), z)
|
||||
@ -345,8 +345,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -363,8 +363,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -381,8 +381,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -399,8 +399,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -420,9 +420,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -447,12 +447,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -477,12 +477,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -499,12 +499,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -521,12 +521,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -553,7 +553,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
# For each element of the dataset, assert that the optional evaluates to
|
||||
# the expected value.
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(3):
|
||||
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
||||
self.assertTrue(elem_has_value)
|
||||
@ -562,7 +562,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
||||
# false, and attempting to get the value will fail.
|
||||
for _ in range(2):
|
||||
self.assertFalse(sess.run(elem_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_value_t)
|
||||
|
||||
|
@ -38,13 +38,13 @@ class CounterTest(test_base.DatasetTestBase):
|
||||
negative_get_next = negative_iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(3, sess.run(get_next))
|
||||
self.assertEqual(3 + 4, sess.run(get_next))
|
||||
self.assertEqual(3 + 2 * 4, sess.run(get_next))
|
||||
self.assertEqual(3, self.evaluate(get_next))
|
||||
self.assertEqual(3 + 4, self.evaluate(get_next))
|
||||
self.assertEqual(3 + 2 * 4, self.evaluate(get_next))
|
||||
|
||||
self.assertEqual(0, sess.run(negative_get_next))
|
||||
self.assertEqual(-1, sess.run(negative_get_next))
|
||||
self.assertEqual(-2, sess.run(negative_get_next))
|
||||
self.assertEqual(0, self.evaluate(negative_get_next))
|
||||
self.assertEqual(-1, self.evaluate(negative_get_next))
|
||||
self.assertEqual(-2, self.evaluate(negative_get_next))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -41,10 +41,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual([[i, j]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)], results.indices)
|
||||
@ -69,10 +69,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual([[i, j, z]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)
|
||||
|
@ -40,10 +40,10 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for _ in range(100):
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -107,7 +107,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in choice_array:
|
||||
self.assertEqual(words[i], sess.run(next_element))
|
||||
self.assertEqual(words[i], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -44,9 +44,9 @@ class EnumerateDatasetTest(test_base.DatasetTestBase):
|
||||
[t.shape for t in get_next[1]])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
|
||||
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertEqual((20, (b"a", 1, 37.0)), self.evaluate(get_next))
|
||||
self.assertEqual((21, (b"b", 2, 38.0)), self.evaluate(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
@ -94,18 +94,18 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
sess.run(destroy_op)
|
||||
self.evaluate(destroy_op)
|
||||
|
||||
def testSameDeviceCPU(self):
|
||||
self._prefetch_fn_helper_one_shot("same_device_cpu",
|
||||
@ -135,35 +135,35 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
ds, ds_iterator, "reinit", device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(ds_iterator.initializer)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
# Lets reset the function buffering resource and reinitialize the
|
||||
# iterator. Should be able to go through this again.
|
||||
self._event.clear()
|
||||
sess.run(reset_op)
|
||||
sess.run(ds_iterator.initializer)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.evaluate(reset_op)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
sess.run(destroy_op)
|
||||
self.evaluate(destroy_op)
|
||||
|
||||
def testReinitializationOutOfRange(self):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
@ -175,30 +175,30 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
ds, ds_iterator, "reinit", device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(ds_iterator.initializer)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
for i in range(1, 10):
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [float(i)])
|
||||
# Try fetching after its over twice to test out end of sequence.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
|
||||
# Now reset everything and try it out again.
|
||||
self._event.clear()
|
||||
sess.run(reset_op)
|
||||
sess.run(ds_iterator.initializer)
|
||||
self.evaluate(reset_op)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
for i in range(1, 10):
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [float(i)])
|
||||
# Try fetching after its over twice to test out end of sequence.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
|
||||
sess.run(destroy_op)
|
||||
self.evaluate(destroy_op)
|
||||
|
||||
def testStringsGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -235,13 +235,13 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
buffer_resource_handle, ignore_lookup_error=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual([b"a"], sess.run(prefetch_op))
|
||||
self.assertEqual([b"b"], sess.run(prefetch_op))
|
||||
self.assertEqual([b"c"], sess.run(prefetch_op))
|
||||
self.assertEqual([b"a"], self.evaluate(prefetch_op))
|
||||
self.assertEqual([b"b"], self.evaluate(prefetch_op))
|
||||
self.assertEqual([b"c"], self.evaluate(prefetch_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
|
||||
sess.run(destroy_op)
|
||||
self.evaluate(destroy_op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -39,7 +39,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
with self.cached_session() as sess:
|
||||
for expected in values:
|
||||
got = sess.run(get_next)
|
||||
got = self.evaluate(get_next)
|
||||
self.assertEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -127,7 +127,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
x, y = sess.run(get_next)
|
||||
x, y = self.evaluate(get_next)
|
||||
self.assertAllEqual([0] * (2**i), x)
|
||||
self.assertAllEqual(np.array(1, ndmin=i), y)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -190,7 +190,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
with self.cached_session() as sess:
|
||||
x, y = sess.run(get_next)
|
||||
x, y = self.evaluate(get_next)
|
||||
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
|
||||
self.assertEqual(y, 45)
|
||||
|
||||
|
@ -68,9 +68,9 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
which_bucket, bucketed_values = sess.run(get_next)
|
||||
which_bucket, bucketed_values = self.evaluate(get_next)
|
||||
|
||||
self.assertEqual(0, which_bucket)
|
||||
|
||||
@ -103,11 +103,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
# Get two minibatches (one containing even values, one containing odds)
|
||||
which_bucket_even, bucketed_values_even = sess.run(get_next)
|
||||
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
|
||||
which_bucket_even, bucketed_values_even = self.evaluate(get_next)
|
||||
which_bucket_odd, bucketed_values_odd = self.evaluate(get_next)
|
||||
|
||||
# Count number of bucket_tensors.
|
||||
self.assertEqual(3, len(bucketed_values_even))
|
||||
@ -174,11 +174,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
||||
which_bucket0, bucketed_values_even0 = sess.run(get_next)
|
||||
which_bucket1, bucketed_values_even1 = sess.run(get_next)
|
||||
which_bucket0, bucketed_values_even0 = self.evaluate(get_next)
|
||||
which_bucket1, bucketed_values_even1 = self.evaluate(get_next)
|
||||
|
||||
# Ensure that bucket 1 was completely filtered out
|
||||
self.assertAllEqual(0, which_bucket0)
|
||||
@ -207,11 +207,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
batches = 0
|
||||
while True:
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
is_even = all(x % 2 == 0 for x in result)
|
||||
is_odd = all(x % 2 == 1 for x in result)
|
||||
self.assertTrue(is_even or is_odd)
|
||||
@ -232,11 +232,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
counts = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
self.assertTrue(
|
||||
all(x % 2 == 0
|
||||
for x in result) or all(x % 2 == 1)
|
||||
@ -259,16 +259,16 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# The input is infinite, so this test demonstrates that:
|
||||
# 1. We produce output without having to consume the entire input,
|
||||
# 2. Different buckets can produce output at different rates, and
|
||||
# 3. For deterministic input, the output is deterministic.
|
||||
for _ in range(3):
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
||||
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
|
||||
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
|
||||
def testSmallGroups(self):
|
||||
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
||||
@ -280,13 +280,13 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
|
||||
# The small outputs at the end are deterministically produced in key
|
||||
# order.
|
||||
self.assertAllEqual([0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1], sess.run(get_next))
|
||||
self.assertAllEqual([0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1], self.evaluate(get_next))
|
||||
|
||||
def testEmpty(self):
|
||||
iterator = (
|
||||
@ -297,7 +297,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Window size must be greater than zero, but got 0."):
|
||||
@ -323,7 +323,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -351,11 +351,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
counts = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
tight_result, multiple_of_10_result = sess.run(get_next)
|
||||
tight_result, multiple_of_10_result = self.evaluate(get_next)
|
||||
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
|
||||
self.assertAllEqual(tight_result,
|
||||
multiple_of_10_result[:, :tight_result.shape[1]])
|
||||
|
@ -47,9 +47,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for x in [1., 2., 3., 5.]:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -65,9 +65,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for x in [1., 2., 3., 5.]:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -93,9 +93,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# All of the files are present.
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for filename in filenames:
|
||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
||||
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -104,9 +104,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
|
||||
# Attempting to read filenames[0] will fail, but ignore_errors()
|
||||
# will catch the error.
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for filename in filenames[1:]:
|
||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
||||
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -53,7 +53,7 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
||||
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
|
||||
materialized = ds.materialize()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(materialized.initializer)
|
||||
self.evaluate(materialized.initializer)
|
||||
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
|
||||
for i in range(16):
|
||||
output = sess.run(
|
||||
@ -68,9 +68,9 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
||||
itr = ds.make_initializable_iterator()
|
||||
n = itr.get_next()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(itr.initializer)
|
||||
self.evaluate(itr.initializer)
|
||||
for i in range(16):
|
||||
output = sess.run(n)
|
||||
output = self.evaluate(n)
|
||||
self.assertEqual(i, output)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
|
@ -112,10 +112,10 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
|
||||
range(self._num_files), 2, 10):
|
||||
actual_batch = sess.run(next_element)
|
||||
actual_batch = self.evaluate(next_element)
|
||||
self.assertAllEqual(file_batch, actual_batch["file"])
|
||||
self.assertAllEqual(record_batch, actual_batch["record"])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -90,7 +90,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
batch_size,
|
||||
num_epochs,
|
||||
):
|
||||
actual_features = sess.run(nxt)
|
||||
actual_features = self.evaluate(nxt)
|
||||
|
||||
if label_name is not None:
|
||||
expected_labels = expected_features.pop(label_name)
|
||||
|
@ -105,7 +105,7 @@ class MakeTFRecordDatasetTest(
|
||||
for expected_batch in self._next_expected_batch(
|
||||
file_indices, batch_size, num_epochs, interleave_cycle_length,
|
||||
drop_final_batch, use_parser_fn):
|
||||
actual_batch = sess.run(outputs)
|
||||
actual_batch = self.evaluate(outputs)
|
||||
self.assertAllEqual(expected_batch, actual_batch)
|
||||
|
||||
def _read_test(self, batch_size, num_epochs, file_index=None,
|
||||
@ -188,7 +188,7 @@ class MakeTFRecordDatasetTest(
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
first_batches = []
|
||||
try:
|
||||
while True:
|
||||
@ -196,7 +196,7 @@ class MakeTFRecordDatasetTest(
|
||||
except errors.OutOfRangeError:
|
||||
pass
|
||||
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
second_batches = []
|
||||
try:
|
||||
while True:
|
||||
|
@ -89,7 +89,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||
num_batches = (28 * 7) // 14
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
||||
@ -104,12 +104,12 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# We expect (num_batches - 1) full-sized batches.
|
||||
num_batches = int(math.ceil((14 * 7) / 8))
|
||||
for i in range(num_batches - 1):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(8):
|
||||
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((14 * 7) % 8):
|
||||
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
||||
@ -152,10 +152,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
if not drop_remainder:
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -177,9 +177,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -201,7 +201,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
got = sess.run(elements)
|
||||
got = self.evaluate(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
@ -230,7 +230,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(4):
|
||||
got = sess.run(elements)
|
||||
got = self.evaluate(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
@ -261,9 +261,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(2):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||
@ -321,7 +321,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"number of elements does not match"):
|
||||
sess.run(get_next)
|
||||
@ -393,7 +393,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(threshold // 10):
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)],
|
||||
self.evaluate(get_next))
|
||||
if threshold % 10 != 0:
|
||||
self.assertAllEqual(
|
||||
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
||||
@ -442,7 +443,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(10):
|
||||
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
|
||||
self.assertAllEqual([element for _ in range(10)],
|
||||
self.evaluate(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Identity", None, lambda x: x, None),
|
||||
@ -462,7 +464,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
else:
|
||||
expected = map_fn(
|
||||
sess.run(self.structuredElement(structure, shape=[10])))
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
|
||||
def testShortCircuitCapturedInput(self):
|
||||
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
@ -473,7 +475,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
||||
self.assertAllEqual([42] * 10, sess.run(get_next))
|
||||
self.assertAllEqual([42] * 10, self.evaluate(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
|
@ -218,7 +218,7 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
|
||||
def _assert_op_cancelled(self, sess, map_defun_op):
|
||||
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
|
||||
sess.run(map_defun_op)
|
||||
self.evaluate(map_defun_op)
|
||||
|
||||
def testMapDefunWithParentCancellation(self):
|
||||
# Checks that a cancellation of the parent graph is threaded through to
|
||||
|
@ -72,7 +72,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
thread_ids = []
|
||||
try:
|
||||
while True:
|
||||
|
@ -637,11 +637,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
for j in range(2):
|
||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -796,7 +796,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(2):
|
||||
elements = []
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
try:
|
||||
while True:
|
||||
elements.extend(sess.run(next_element))
|
||||
|
@ -57,7 +57,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -87,7 +87,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -117,7 +117,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -150,7 +150,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = sess.run(next_element)
|
||||
actual = self.evaluate(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
@ -170,7 +170,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -199,12 +199,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -220,12 +220,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -60,7 +60,7 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
feed_dict={start: start_val, step: step_val, take: take_val})
|
||||
for expected, _ in zip(
|
||||
itertools.count(start_val, step_val), range(take_val)):
|
||||
self.assertEqual(expected, sess.run(next_element))
|
||||
self.assertEqual(expected, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -110,7 +110,7 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
feed_dict={start: start_val, step: step_val, take: take_val})
|
||||
for expected, _ in zip(
|
||||
itertools.count(start_val, step_val), range(take_val)):
|
||||
self.assertEqual(expected, sess.run(next_element).values[0])
|
||||
self.assertEqual(expected, self.evaluate(next_element).values[0])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -136,7 +136,7 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
|
||||
(longer_vector_val, larger_rank_val), _ = self.evaluate(next_element)
|
||||
self.assertAllEqual([0] * (2**i), longer_vector_val)
|
||||
self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -71,19 +71,19 @@ class RangeDatasetSerializationTest(
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -91,14 +91,14 @@ class RangeDatasetSerializationTest(
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
sess.run(restore_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -62,7 +62,7 @@ class SerializationIntegrationTest(test.TestCase):
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_ops)
|
||||
for _ in range(break_point):
|
||||
output = sess.run(get_next_ops)
|
||||
output = self.evaluate(get_next_ops)
|
||||
for i in range(num_pipelines):
|
||||
all_outputs[i].append(output[i])
|
||||
saver.save(sess, self._ckpt_path())
|
||||
@ -73,7 +73,7 @@ class SerializationIntegrationTest(test.TestCase):
|
||||
with self.session(graph=g) as sess:
|
||||
saver.restore(sess, self._ckpt_path())
|
||||
for _ in range(num_outputs - break_point):
|
||||
output = sess.run(get_next_ops)
|
||||
output = self.evaluate(get_next_ops)
|
||||
for i in range(num_pipelines):
|
||||
all_outputs[i].append(output[i])
|
||||
|
||||
|
@ -108,7 +108,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
shuffle_ops.shuffle_and_repeat(buffer_size=21))
|
||||
get_next_op = ds.make_one_shot_iterator().get_next()
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -38,10 +38,10 @@ class SleepTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
start_time = time.time()
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
end_time = time.time()
|
||||
self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -39,8 +39,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
for _ in range(2): # Dataset is repeated. See setUp.
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"),
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -58,7 +59,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ON students.first_name = people.first_name "
|
||||
"AND students.last_name = people.last_name"
|
||||
})
|
||||
self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"California", b"Hi!"),
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -75,8 +77,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, last_name, favorite_nonsense_word "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"n\0nsense"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"),
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -93,8 +96,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, last_name, motto FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
sess.run(
|
||||
@ -103,7 +106,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, last_name, state FROM people "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"California"),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
|
||||
sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -212,8 +216,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -230,7 +234,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students "
|
||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -246,9 +250,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT desk_number, favorite_negative_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((9, -2), sess.run(get_next))
|
||||
self.assertEqual((9, -2), self.evaluate(get_next))
|
||||
# Max and min values of int8
|
||||
self.assertEqual((127, -128), sess.run(get_next))
|
||||
self.assertEqual((127, -128), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -263,8 +267,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -281,7 +285,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students "
|
||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -297,9 +301,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int16
|
||||
self.assertEqual((b"John", 32767), sess.run(get_next))
|
||||
self.assertEqual((b"John", 32767), self.evaluate(get_next))
|
||||
# Min value of int16
|
||||
self.assertEqual((b"Jane", -32768), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -32768), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -314,8 +318,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int32` tensor.
|
||||
@ -328,8 +332,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, income FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -345,9 +349,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int32
|
||||
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
||||
self.assertEqual((b"John", 2147483647), self.evaluate(get_next))
|
||||
# Min value of int32
|
||||
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -2147483648), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -362,8 +366,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, school_id FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 123), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 1000), sess.run(get_next))
|
||||
self.assertEqual((b"John", 123), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 1000), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -378,8 +382,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -394,8 +398,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, income FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -412,9 +416,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int64
|
||||
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9223372036854775807), self.evaluate(get_next))
|
||||
# Min value of int64
|
||||
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -9223372036854775808), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -429,8 +433,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -446,9 +450,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Min value of uint8
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
# Max value of uint8
|
||||
self.assertEqual((b"Jane", 255), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 255), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -463,8 +467,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -480,9 +484,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Min value of uint16
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
# Max value of uint16
|
||||
self.assertEqual((b"Jane", 65535), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 65535), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -499,8 +503,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, registration_complete FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", True), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", False), sess.run(get_next))
|
||||
self.assertEqual((b"John", True), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", False), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -515,8 +519,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", True), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", True), sess.run(get_next))
|
||||
self.assertEqual((b"John", True), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", True), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -533,8 +537,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, last_name, victories FROM townspeople "
|
||||
"ORDER BY first_name"
|
||||
})
|
||||
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
|
||||
self.assertEqual((b"George", b"Washington", 20.0),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"John", b"Adams", -19.95), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -74,18 +74,18 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
expected_sum = 0.0
|
||||
for i in range(100):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = sess.run(summary_t)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
|
||||
expected_sum += i * 8.0
|
||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
summary_str = sess.run(summary_t)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
|
||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||
|
||||
@ -99,14 +99,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 100.0)
|
||||
|
||||
def testPrefetchBufferUtilization(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -118,11 +119,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = sess.run(summary_t)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||
float(i + 1))
|
||||
self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
|
||||
@ -131,7 +132,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
0, 1)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
summary_str = sess.run(summary_t)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||
100)
|
||||
|
||||
@ -145,11 +146,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = sess.run(summary_t)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasScalarValue(summary_str,
|
||||
"Prefetch::buffer_capacity", 0)
|
||||
self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
|
||||
@ -167,9 +168,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(34):
|
||||
self.assertEqual(i * 3, sess.run(next_element))
|
||||
self.assertEqual(i * 3, self.evaluate(next_element))
|
||||
if i is not 0:
|
||||
self._assertSummaryHasScalarValue(
|
||||
sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
|
||||
@ -261,9 +262,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for j in range(5):
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -278,9 +279,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -295,16 +296,17 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(i + 1))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency_2", float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 100.0)
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency_2", 100.0)
|
||||
|
||||
@ -319,14 +321,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 200.0)
|
||||
|
||||
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -341,12 +344,13 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run([iterator_0.initializer, iterator_1.initializer])
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, sess.run(next_element))
|
||||
self.assertEqual(i * 2, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 200.0)
|
||||
|
||||
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -364,7 +368,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
with self.test_session() as sess:
|
||||
sess.run([iterator_0.initializer, iterator_1.initializer])
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, sess.run(next_element))
|
||||
self.assertEqual(i * 2, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "dataset1_record_latency", float(i + 1))
|
||||
self._assertSummaryHasCount(
|
||||
@ -421,7 +425,7 @@ class FeatureStatsDatasetTest(
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for _ in range(num_output):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -50,7 +50,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
||||
for i in range(4):
|
||||
self.assertEqual(i, sess.run(next_elem))
|
||||
self.assertEqual(i, self.evaluate(next_elem))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_elem)
|
||||
|
||||
@ -68,7 +68,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i,) * 3, sess.run(op))
|
||||
self.assertEqual((i,) * 3, self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -88,7 +88,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -107,7 +107,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
st_row = sess.run(next_element)
|
||||
st_row = self.evaluate(next_element)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
self.assertEqual([10], st_row.dense_shape)
|
||||
@ -128,7 +128,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
dense_elem, st_row = sess.run(next_element)
|
||||
dense_elem, st_row = self.evaluate(next_element)
|
||||
self.assertEqual(i, dense_elem)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
@ -150,7 +150,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(((i,),) * 3, sess.run(op))
|
||||
self.assertEqual(((i,),) * 3, self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
|
@ -49,11 +49,11 @@ class UniqueTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for test_case, expected in test_cases:
|
||||
current_test_case = test_case
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for element in expected:
|
||||
if dtype == dtypes.string:
|
||||
element = compat.as_bytes(element)
|
||||
self.assertAllEqual(element, sess.run(next_element))
|
||||
self.assertAllEqual(element, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -93,13 +93,13 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
})
|
||||
num_full_batches = (count * 7) // batch_size
|
||||
for i in range(num_full_batches):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(batch_size):
|
||||
self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
|
||||
result_component[j])
|
||||
if not drop_remainder and (count * 7) % batch_size > 0:
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((count * 7) % batch_size):
|
||||
self.assertAllEqual(
|
||||
@ -128,9 +128,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(2):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||
@ -155,9 +155,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(2):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected_indices = []
|
||||
expected_values = []
|
||||
for j in range(5):
|
||||
@ -185,8 +185,8 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
actual = sess.run(get_next)
|
||||
self.evaluate(init_op)
|
||||
actual = self.evaluate(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
|
||||
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
|
||||
@ -211,7 +211,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
r'Cannot batch tensors with different shapes in component 0. '
|
||||
@ -271,7 +271,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
num_full_batches = len(seq_lens) // batch_size
|
||||
|
||||
for i in range(num_full_batches):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
padded_len = padded_shapes[0]
|
||||
if padded_len is None or padded_len == -1:
|
||||
padded_len = np.max(result) if result.size > 0 else 0
|
||||
@ -283,7 +283,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
[0] * (padded_len - seq_len))
|
||||
|
||||
if not drop_remainder and len(seq_lens) % batch_size > 0:
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
padded_len = np.max(result) if result.size > 0 else 0
|
||||
self.assertEqual((len(seq_lens) % batch_size, padded_len),
|
||||
result.shape)
|
||||
@ -315,7 +315,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
self.assertAllEqual([[], [], [], []], result)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -347,7 +347,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
seq_lens: random_seq_lens
|
||||
})
|
||||
for i in range(8):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
padded_len = np.max(result[0])
|
||||
self.assertEqual((4, padded_len), result[0].shape)
|
||||
self.assertEqual((4, padded_len), result[1].shape)
|
||||
|
@ -71,7 +71,7 @@ class FileCacheDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# First run without caching to collect the "ground truth".
|
||||
sess.run(init_fifo_op)
|
||||
self.evaluate(init_fifo_op)
|
||||
elements = []
|
||||
for _ in range(20):
|
||||
elements.append(sess.run(get_next))
|
||||
@ -220,14 +220,14 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
|
||||
sess.run(repeat_count.initializer)
|
||||
sess.run(cached_iterator.initializer)
|
||||
sess.run(uncached_iterator.initializer)
|
||||
self.evaluate(repeat_count.initializer)
|
||||
self.evaluate(cached_iterator.initializer)
|
||||
self.evaluate(uncached_iterator.initializer)
|
||||
|
||||
for i in range(3):
|
||||
for _ in range(10):
|
||||
self.assertEqual(sess.run(cached_next), i)
|
||||
self.assertEqual(sess.run(uncached_next), i)
|
||||
self.assertEqual(self.evaluate(cached_next), i)
|
||||
self.assertEqual(self.evaluate(uncached_next), i)
|
||||
|
||||
sess.run(repeat_count.assign(0))
|
||||
|
||||
@ -238,7 +238,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
||||
# The cached iterator replays from cache.
|
||||
for i in range(3):
|
||||
for _ in range(10):
|
||||
self.assertEqual(sess.run(cached_next), i)
|
||||
self.assertEqual(self.evaluate(cached_next), i)
|
||||
|
||||
# The cached iterator should now be empty.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -280,7 +280,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
||||
i2 = d2.make_initializable_iterator()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(i1.initializer)
|
||||
self.evaluate(i1.initializer)
|
||||
|
||||
self.assertEqual(1, sess.run(i1.get_next()))
|
||||
self.assertEqual(2, sess.run(i1.get_next()))
|
||||
@ -307,7 +307,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i, expected in enumerate(expected_values):
|
||||
self.assertEqual(expected, sess.run(n),
|
||||
self.assertEqual(expected, self.evaluate(n),
|
||||
"Unexpected value at index %s" % i)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -51,9 +51,9 @@ class ConcatenateDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(9):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
if i < 4:
|
||||
for component, result_component in zip(input_components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
@ -85,9 +85,9 @@ class ConcatenateDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(9):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
if i < 4:
|
||||
for component, result_component in zip(input_components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
|
@ -52,8 +52,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
self.evaluate(init_op)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -81,8 +81,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
[shape for shape in iterator.output_shapes])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
self.evaluate(init_op)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -112,8 +112,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
], [shape for shape in iterator.output_shapes])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
self.evaluate(init_op)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
if sparse_tensor.is_sparse(component):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
@ -139,9 +139,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -169,7 +169,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
[shape for shape in iterator.output_shapes])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
expected = [
|
||||
(sparse_tensor.SparseTensorValue(
|
||||
indices=np.array([[0]]),
|
||||
@ -197,7 +197,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
dense_shape=np.array([3]))),
|
||||
]
|
||||
for i in range(3):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(expected[i], results):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -229,7 +229,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
], [shape for shape in iterator.output_shapes])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
expected = [
|
||||
(sparse_tensor.SparseTensorValue(
|
||||
indices=np.array([[0]]),
|
||||
@ -257,7 +257,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
dense_shape=np.array([3]))),
|
||||
]
|
||||
for i in range(3):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(
|
||||
(list(zip(*components[:3]))[i] + expected[i]), results):
|
||||
if sparse_tensor.is_sparse(component):
|
||||
@ -280,9 +280,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
self.assertEqual((1,), iterator.output_shapes["bar"])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(3):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertEqual(components["foo"][i], results["foo"])
|
||||
self.assertEqual(components["bar"][i], results["bar"])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -308,7 +308,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
dense_shape)
|
||||
sess.run(init_op, feed_dict={st: sparse_feed})
|
||||
for i, s in enumerate(slices):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual(s, results.values)
|
||||
expected_indices = np.array(
|
||||
[[j] for j in range(len(slices[i]))]).reshape([-1, 1])
|
||||
@ -474,15 +474,15 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
with ops.device("/cpu:0"):
|
||||
var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
|
||||
dataset = dataset.map(lambda x: x + var_0.read_value())
|
||||
sess.run(var_0.initializer)
|
||||
self.evaluate(var_0.initializer)
|
||||
|
||||
with ops.device("/cpu:1"):
|
||||
var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
|
||||
dataset = dataset.map(lambda x: x + var_1.read_value())
|
||||
sess.run(var_1.initializer)
|
||||
self.evaluate(var_1.initializer)
|
||||
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
errors.FailedPreconditionError,
|
||||
@ -506,7 +506,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
# Run one whole epoch to burn in the computation.
|
||||
for _ in range(input_size // batch_size):
|
||||
sess.run(next_element)
|
||||
@ -543,7 +543,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
get_next_element = sess.make_callable(next_element)
|
||||
# Run one whole epoch to burn in the computation.
|
||||
for _ in range(input_size // batch_size):
|
||||
@ -582,7 +582,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
get_next_element = sess.make_callable(next_element)
|
||||
# Run one whole epoch to burn in the computation.
|
||||
for _ in range(input_size // batch_size):
|
||||
@ -620,7 +620,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
get_next_element = sess.make_callable(next_element)
|
||||
# Run one whole epoch to burn in the computation.
|
||||
for _ in range(input_size // batch_size):
|
||||
|
@ -47,10 +47,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(2): # Run twice to test reinitialization.
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for _ in range(num_repeats):
|
||||
for elem in elem_sequence:
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -65,7 +65,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(num_repeats):
|
||||
for elem in elem_sequence:
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -133,10 +133,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for _ in range(num_inner_repeats * num_outer_repeats):
|
||||
for elem in input_list:
|
||||
val0, val1 = sess.run(get_next)
|
||||
val0, val1 = self.evaluate(get_next)
|
||||
self.assertAllEqual(elem[0], val0)
|
||||
self.assertAllEqual(elem[1], val1)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -192,10 +192,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for elem in [0, 1]:
|
||||
for _ in range(num_parallel_iterators):
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -215,9 +215,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
self.assertEqual(dtype, get_next.dtype)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for expected in [[1], [2], [3]]:
|
||||
next_val = sess.run(get_next)
|
||||
next_val = self.evaluate(get_next)
|
||||
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
|
||||
self.assertAllEqual(expected, next_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -236,9 +236,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for expected in [b"foo", b"bar", b"baz"]:
|
||||
next_val = sess.run(get_next)
|
||||
next_val = self.evaluate(get_next)
|
||||
self.assertAllEqual(expected, next_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -257,12 +257,12 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||
self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
|
||||
with self.assertRaisesOpError("The expected type was int64"):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual([7, 8, 9], sess.run(get_next))
|
||||
self.assertAllEqual([7, 8, 9], self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -280,12 +280,12 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||
self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
|
||||
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual([11, 12, 13], sess.run(get_next))
|
||||
self.assertAllEqual([11, 12, 13], self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -304,16 +304,16 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertEqual((1, 2), sess.run(get_next))
|
||||
self.assertEqual((3, 4), sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertEqual((1, 2), self.evaluate(get_next))
|
||||
self.assertEqual((3, 4), self.evaluate(get_next))
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
sess.run(get_next)
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
sess.run(get_next)
|
||||
self.assertEqual((9, 10), sess.run(get_next))
|
||||
self.assertEqual((9, 10), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -329,9 +329,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(1, sess.run(get_next))
|
||||
self.assertAllEqual([2, 3], sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(1, self.evaluate(get_next))
|
||||
self.assertAllEqual([2, 3], self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -349,9 +349,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(0, sess.run(get_next))
|
||||
self.assertAllEqual(1, sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(0, self.evaluate(get_next))
|
||||
self.assertAllEqual(1, self.evaluate(get_next))
|
||||
|
||||
def testFromGeneratorDestructorCalled(self):
|
||||
# Use an `Event` to signal that the generator has been deleted.
|
||||
@ -378,9 +378,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(42, sess.run(get_next))
|
||||
self.assertAllEqual(42, sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(42, self.evaluate(get_next))
|
||||
self.assertAllEqual(42, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
# Test that `GeneratorWrapper` object is destroyed when the
|
||||
@ -407,10 +407,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
|
||||
for x in expected:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -436,13 +436,13 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
expected = [(0, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
|
||||
for x in expected:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -470,9 +470,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(37, sess.run(get_next))
|
||||
self.assertAllEqual(37, sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(37, self.evaluate(get_next))
|
||||
self.assertAllEqual(37, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertTrue(event.is_set())
|
||||
|
@ -67,7 +67,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
|
||||
for _ in range(count_val):
|
||||
for i in [x for x in range(7) if x**2 % modulus_val == 0]:
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -86,9 +86,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(0, sess.run(get_next))
|
||||
self.assertEqual(1, sess.run(get_next))
|
||||
self.assertEqual(3, sess.run(get_next))
|
||||
self.assertEqual(0, self.evaluate(get_next))
|
||||
self.assertEqual(1, self.evaluate(get_next))
|
||||
self.assertEqual(3, self.evaluate(get_next))
|
||||
|
||||
def testFilterDict(self):
|
||||
iterator = (dataset_ops.Dataset.range(10)
|
||||
@ -100,10 +100,10 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
if (i ** 2) % 2 == 0:
|
||||
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
|
||||
self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -125,8 +125,8 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(input_data[0], sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(input_data[0], self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -148,9 +148,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(5):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
|
||||
self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -166,9 +166,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, True), sess.run(get_next))
|
||||
self.assertEqual((i, True), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -178,7 +178,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
|
||||
next_elements = [iterator.get_next() for iterator in iterators]
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
|
||||
self.assertEqual([0 for _ in range(10)], self.evaluate(next_elements))
|
||||
|
||||
|
||||
class FilterDatasetBenchmark(test.Benchmark):
|
||||
|
@ -45,10 +45,10 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in repeats:
|
||||
for _ in range(i):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -64,11 +64,11 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for row in repeats:
|
||||
for i in row:
|
||||
for _ in range(i):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -94,12 +94,12 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
with session.Session(server.target) as sess2:
|
||||
for _ in range(3):
|
||||
sess = random.choice([sess1, sess2])
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for row in repeats:
|
||||
for i in row:
|
||||
for _ in range(i):
|
||||
sess = random.choice([sess1, sess2])
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess = random.choice([sess1, sess2])
|
||||
@ -115,10 +115,10 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
for _ in range(i ** 2):
|
||||
self.assertEqual(i * 2, sess.run(get_next))
|
||||
self.assertEqual(i * 2, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
# pylint: enable=g-long-lambda
|
||||
@ -139,11 +139,11 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
for j in range(2):
|
||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -196,7 +196,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for expected_element in _interleave(
|
||||
_repeat(input_values, count), cycle_length, block_length):
|
||||
self.assertEqual(expected_element, sess.run(get_next))
|
||||
self.assertEqual(expected_element, self.evaluate(get_next))
|
||||
|
||||
for _ in range(2):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -231,7 +231,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
else:
|
||||
self.assertEqual(value, sess.run(get_next))
|
||||
self.assertEqual(value, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -254,7 +254,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(10):
|
||||
for j in range(2):
|
||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -308,7 +308,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
for element in elements:
|
||||
coordination_events[element].set()
|
||||
self.assertEqual(element * element, sess.run(get_next))
|
||||
self.assertEqual(element * element, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -57,7 +57,7 @@ class IteratorClusterTest(test.TestCase):
|
||||
|
||||
with session.Session(worker[0].target) as sess:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
def _testRemoteIteratorHelper(self, device0, device1, target):
|
||||
with ops.device(device1):
|
||||
@ -134,12 +134,12 @@ class IteratorClusterTest(test.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session(worker[0].target) as sess:
|
||||
sess.run(table.initializer)
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))
|
||||
self.evaluate(table.initializer)
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([0, 0, -1, 1, 2], self.evaluate(get_next))
|
||||
|
||||
with session.Session(worker[0].target) as sess:
|
||||
self.assertAllEqual([2, 0], sess.run(get_next))
|
||||
self.assertAllEqual([2, 0], self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -166,7 +166,7 @@ class IteratorClusterTest(test.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session(worker[0].target) as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -97,7 +97,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -123,7 +123,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -159,7 +159,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -175,7 +175,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
config = config_pb2.ConfigProto(
|
||||
inter_op_parallelism_threads=1, use_per_session_threads=True)
|
||||
with session.Session(config=config) as sess:
|
||||
self.assertAllEqual([1, 4, 9], sess.run(next_element))
|
||||
self.assertAllEqual([1, 4, 9], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -254,15 +254,15 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session(server.target) as sess:
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
self.evaluate(init_op)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Re-initialize the iterator in the first session.
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
# Re-define the iterator manually, without defining any of the
|
||||
@ -277,7 +277,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
with session.Session(server.target) as sess:
|
||||
# Use the iterator without re-initializing in the second session.
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -317,20 +317,20 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
sess.run(get_next)
|
||||
|
||||
# Initialize with one dataset.
|
||||
sess.run(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
self.evaluate(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Initialize with a different dataset.
|
||||
sess.run(dataset_4_init_op)
|
||||
self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
|
||||
self.evaluate(dataset_4_init_op)
|
||||
self.assertAllEqual([4, 5, 6, 7], self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Reinitialize with the first dataset.
|
||||
sess.run(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
self.evaluate(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -348,7 +348,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
g, output_types=dtypes.int64)
|
||||
sess.run(iterator.make_initializer(dataset_1))
|
||||
for expected in range(10):
|
||||
self.assertEqual(expected, sess.run(next_element))
|
||||
self.assertEqual(expected, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -356,7 +356,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
g, output_types=dtypes.int64)
|
||||
sess.run(iterator.make_initializer(dataset_2))
|
||||
for expected in range(10):
|
||||
self.assertEqual(expected, sess.run(next_element))
|
||||
self.assertEqual(expected, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -679,10 +679,10 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
n = itr.get_next()
|
||||
|
||||
with session.Session(s3.target, config=config) as sess:
|
||||
sess.run(itr.initializer)
|
||||
self.evaluate(itr.initializer)
|
||||
expected_values = worker_devices
|
||||
for expected in expected_values:
|
||||
self.assertEqual((compat.as_bytes(expected),), sess.run(n))
|
||||
self.assertEqual((compat.as_bytes(expected),), self.evaluate(n))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
@ -786,8 +786,8 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, _, save_op, _ = _build_range_dataset_graph()
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(save_op)
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(save_op)
|
||||
|
||||
# Attempt to restore the saved iterator into an IteratorResource of
|
||||
# incompatible type. An iterator of RangeDataset has output type int64,
|
||||
@ -798,7 +798,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
_, _, _, restore_op = _build_reader_dataset_graph()
|
||||
with self.session(graph=g) as sess:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
|
||||
def testRepeatedGetNextWarning(self):
|
||||
iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
|
||||
@ -949,7 +949,7 @@ class IteratorCheckpointingTest(test.TestCase):
|
||||
checkpoint.restore(checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory)).initialize_or_restore(sess)
|
||||
for j in range(2):
|
||||
self.assertEqual(i * 2 + j, sess.run(get_next))
|
||||
self.assertEqual(i * 2 + j, self.evaluate(get_next))
|
||||
checkpoint.save(file_prefix=checkpoint_prefix)
|
||||
|
||||
|
||||
|
@ -102,7 +102,7 @@ class ListFilesDatasetOpTest(test_base.DatasetTestBase):
|
||||
all_produced_filenames = []
|
||||
for _ in range(3):
|
||||
produced_filenames = []
|
||||
sess.run(itr.initializer)
|
||||
self.evaluate(itr.initializer)
|
||||
try:
|
||||
while True:
|
||||
produced_filenames.append(sess.run(next_element))
|
||||
|
@ -114,7 +114,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={count: 14})
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -185,7 +185,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
output_buffer_size: output_buffer_size_val})
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -242,7 +242,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -257,7 +257,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -272,7 +272,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
||||
@ -293,7 +293,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
||||
@ -325,10 +325,10 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with ops.Graph().as_default() as g:
|
||||
captured_init_op, init_op, get_next = _build_graph()
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(captured_init_op)
|
||||
sess.run(init_op)
|
||||
self.evaluate(captured_init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -353,8 +353,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(table.initializer)
|
||||
sess.run(init_op)
|
||||
self.evaluate(table.initializer)
|
||||
self.evaluate(init_op)
|
||||
sess.run(get_next)
|
||||
sess.run(get_next)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -371,11 +371,11 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(enqueue_op)
|
||||
sess.run(close_op)
|
||||
sess.run(init_op)
|
||||
self.evaluate(enqueue_op)
|
||||
self.evaluate(close_op)
|
||||
self.evaluate(init_op)
|
||||
for element in elements:
|
||||
self.assertEqual(element, sess.run(get_next))
|
||||
self.assertEqual(element, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -396,9 +396,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(enqueue_op)
|
||||
sess.run(close_op)
|
||||
sess.run(init_op)
|
||||
self.evaluate(enqueue_op)
|
||||
self.evaluate(close_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(100):
|
||||
self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]),
|
||||
sorted(sess.run(get_next)))
|
||||
@ -415,15 +415,15 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(counter_var.initializer)
|
||||
sess.run(init_op)
|
||||
self.evaluate(counter_var.initializer)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(counter_var))
|
||||
self.assertEqual(i + 1, sess.run(get_next))
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
self.assertEqual(i, self.evaluate(counter_var))
|
||||
self.assertEqual(i + 1, self.evaluate(get_next))
|
||||
self.assertEqual(10, self.evaluate(counter_var))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
self.assertEqual(10, self.evaluate(counter_var))
|
||||
|
||||
def testCaptureUninitializedVariableError(self):
|
||||
counter_var = variable_scope.get_variable(
|
||||
@ -435,7 +435,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -447,14 +447,14 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
random_values = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
random_values.extend(sess.run(get_next))
|
||||
self.assertEqual(10, len(random_values))
|
||||
self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
random_values_2 = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
@ -473,8 +473,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
random_values = sess.run(get_next)
|
||||
self.evaluate(init_op)
|
||||
random_values = self.evaluate(get_next)
|
||||
|
||||
# Assert that one of the next 99 batches yielded by the iterator is
|
||||
# different from the first.
|
||||
@ -500,15 +500,15 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(counter_var.initializer)
|
||||
sess.run(init_op)
|
||||
self.evaluate(counter_var.initializer)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(counter_var))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
self.assertEqual(i, self.evaluate(counter_var))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(10, self.evaluate(counter_var))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
self.assertEqual(10, self.evaluate(counter_var))
|
||||
|
||||
def testMapDict(self):
|
||||
iterator = (dataset_ops.Dataset.range(10)
|
||||
@ -519,9 +519,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
|
||||
self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -569,8 +569,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(row ** 2, sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(row**2, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -611,7 +611,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
row = np.arange(6)
|
||||
for num in [2, 3, 4]:
|
||||
init_op, get_next = build_dataset(row, num)
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(6):
|
||||
self.assertEqual(
|
||||
(i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2,
|
||||
@ -652,7 +652,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
row = np.arange(6)
|
||||
for num in [2, 3, 4]:
|
||||
init_op, get_next = build_dataset(row, num)
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(
|
||||
[x // 2 if (num == 2 or num == 3) else x * 2 for x in row],
|
||||
sess.run(get_next))
|
||||
@ -697,7 +697,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([(x // 2 if x % 2 else x * 2) if
|
||||
(num == 2 or num == 3) else x * 2 for x in row],
|
||||
sess.run(get_next))
|
||||
@ -735,7 +735,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for buffer_size in [1, 10, 100, 1000]:
|
||||
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
|
||||
for i in range(100):
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -753,10 +753,10 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
|
||||
for i in range(event_will_be_set_after_consuming):
|
||||
self.assertFalse(ev.is_set())
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
ev.wait()
|
||||
for i in range(event_will_be_set_after_consuming, 100):
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -768,9 +768,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, 37.0), sess.run(get_next))
|
||||
self.assertEqual((i, 37.0), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -789,9 +789,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, 37.0), sess.run(get_next))
|
||||
self.assertEqual((i, 37.0), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -810,9 +810,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
|
||||
self.assertSparseValuesEqual(actual, _sparse(i))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -837,9 +837,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
|
||||
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -861,9 +861,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -875,9 +875,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, b"hello", 10), sess.run(get_next))
|
||||
self.assertEqual((i, b"hello", 10), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -945,7 +945,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
@parameterized.named_parameters(
|
||||
@ -972,7 +972,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
tids = sess.run(get_next)
|
||||
tids = self.evaluate(get_next)
|
||||
self.assertTrue(all(tids[0] == tid for tid in tids))
|
||||
# pylint: enable=g-long-lambda
|
||||
|
||||
@ -996,7 +996,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected = map_fn(*sess.run(self.structuredElement(structure)))
|
||||
else:
|
||||
expected = map_fn(sess.run(self.structuredElement(structure)))
|
||||
self.assertEqual(expected, sess.run(get_next))
|
||||
self.assertEqual(expected, self.evaluate(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Sequential", None),
|
||||
@ -1011,7 +1011,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
||||
self.assertEqual(42, sess.run(get_next))
|
||||
self.assertEqual(42, self.evaluate(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", 1, 1),
|
||||
@ -1030,7 +1030,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session(config=config) as sess:
|
||||
for i in range(num_elements):
|
||||
coordination_events[i].set()
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -1052,7 +1052,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
for element in elements:
|
||||
coordination_events[element].set()
|
||||
self.assertEqual(element * element, sess.run(get_next))
|
||||
self.assertEqual(element * element, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -40,7 +40,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
|
||||
def testBasic(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -50,10 +50,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -67,10 +67,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -85,12 +85,12 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 20, 4):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i + 2, sess.run(elem_on_3))
|
||||
self.assertEqual(i + 3, sess.run(elem_on_4))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i + 2, self.evaluate(elem_on_3))
|
||||
self.assertEqual(i + 3, self.evaluate(elem_on_4))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -105,11 +105,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(8, sess.run(elem_on_1))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(8, self.evaluate(elem_on_1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -126,7 +126,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
elem_on_1_has_value, elem_on_1_value = sess.run(
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
@ -140,8 +140,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
self.assertTrue(elem_on_1_has_value)
|
||||
self.assertEqual(8, elem_on_1_value)
|
||||
self.assertFalse(sess.run(elem_on_1_has_value_t))
|
||||
self.assertFalse(sess.run(elem_on_2_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_on_1_t)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
@ -155,11 +155,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -192,10 +192,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -211,11 +211,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -235,7 +235,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
elem_on_1_has_value, elem_on_1_value = sess.run(
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
@ -249,8 +249,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
self.assertTrue(elem_on_1_has_value)
|
||||
self.assertEqual(8, elem_on_1_value)
|
||||
self.assertFalse(sess.run(elem_on_1_has_value_t))
|
||||
self.assertFalse(sess.run(elem_on_2_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_on_1_t)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
@ -272,10 +272,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
|
@ -227,7 +227,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
# For each element of the dataset, assert that the optional evaluates to
|
||||
# the expected value.
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for _ in range(3):
|
||||
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
||||
self.assertTrue(elem_has_value)
|
||||
@ -236,7 +236,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
||||
# false, and attempting to get the value will fail.
|
||||
for _ in range(2):
|
||||
self.assertFalse(sess.run(elem_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_value_t)
|
||||
|
||||
|
@ -40,7 +40,7 @@ class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
|
||||
for m in range(10):
|
||||
self.assertEqual(m, sess.run(get_next))
|
||||
self.assertEqual(m, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -124,19 +124,19 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -144,14 +144,14 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
sess.run(restore_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -175,14 +175,14 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for _ in range(break_epoch):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
# Create an empty IteratorResource and restore the Iterator into it.
|
||||
@ -193,12 +193,12 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
get_next = iterator.get_next()
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
for _ in range(break_epoch + 1, num_epochs):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -221,20 +221,20 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
# Intentionally build a graph with a different value for stop to make sure
|
||||
# the original dataset graph is actually getting loaded.
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop_1)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -259,19 +259,19 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -294,27 +294,27 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point1):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point1, break_point2):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
break_point2 = 7
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point2, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -338,28 +338,28 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
init_op, get_next, save_op, restore_op = _build_graph(
|
||||
start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for _ in range(break_epoch - 1):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
for i in range(start, break_range):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_range, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
for _ in range(break_epoch, num_epochs):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -381,23 +381,23 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
init_op, get_next, save_op, restore_op = _build_graph(
|
||||
start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for _ in range(num_epochs):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
sess.run(save_op)
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -107,7 +107,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
||||
init_op, feed_dict={filenames: [test_filenames[0]],
|
||||
num_epochs: 1})
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(0, i), sess.run(get_next))
|
||||
self.assertEqual(self._lineText(0, i), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -116,7 +116,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
||||
init_op, feed_dict={filenames: [test_filenames[1]],
|
||||
num_epochs: 1})
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(1, i), sess.run(get_next))
|
||||
self.assertEqual(self._lineText(1, i), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -124,7 +124,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
|
||||
for j in range(2):
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(j, i), sess.run(get_next))
|
||||
self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -133,7 +133,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
||||
for _ in range(10):
|
||||
for j in range(2):
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(j, i), sess.run(get_next))
|
||||
self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -267,7 +267,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, feed_dict={filenames: [test_filenames[0]],
|
||||
num_epochs: 1})
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(0, i), sess.run(get_next))
|
||||
self.assertEqual(self._record(0, i), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -276,7 +276,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, feed_dict={filenames: [test_filenames[1]],
|
||||
num_epochs: 1})
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(1, i), sess.run(get_next))
|
||||
self.assertEqual(self._record(1, i), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -284,7 +284,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(j, i), sess.run(get_next))
|
||||
self.assertEqual(self._record(j, i), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -293,7 +293,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
for _ in range(10):
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(j, i), sess.run(get_next))
|
||||
self.assertEqual(self._record(j, i), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -405,19 +405,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
if (epoch == epoch_break and f == file_break and
|
||||
r == record_break):
|
||||
sess.run(save_op)
|
||||
self.evaluate(save_op)
|
||||
break
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
else:
|
||||
continue
|
||||
break
|
||||
@ -426,13 +426,13 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
break
|
||||
else:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
@ -441,9 +441,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
(epoch == epoch_break and f == file_break and
|
||||
r < record_break)):
|
||||
continue
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
def testInitThenRestore(self):
|
||||
# Note: Calling init_op before restore_op is redundant. This test just makes
|
||||
@ -458,19 +458,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
if (epoch == epoch_break and f == file_break and
|
||||
r == record_break):
|
||||
sess.run(save_op)
|
||||
self.evaluate(save_op)
|
||||
break
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
else:
|
||||
continue
|
||||
break
|
||||
@ -479,14 +479,14 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
break
|
||||
else:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
@ -495,9 +495,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
(epoch == epoch_break and f == file_break and
|
||||
r < record_break)):
|
||||
continue
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
def testRestoreInModifiedGraph(self):
|
||||
num_epochs = 10
|
||||
@ -510,19 +510,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
if (epoch == epoch_break and f == file_break and
|
||||
r == record_break):
|
||||
sess.run(save_op)
|
||||
self.evaluate(save_op)
|
||||
break
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
else:
|
||||
continue
|
||||
break
|
||||
@ -531,13 +531,13 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
break
|
||||
else:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs_1)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
@ -546,9 +546,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
(epoch == epoch_break and f == file_break and
|
||||
r < record_break)):
|
||||
continue
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
def testRestoreWithoutBuildingDatasetGraph(self):
|
||||
num_epochs = 10
|
||||
@ -560,19 +560,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
if (epoch == epoch_break and f == file_break and
|
||||
r == record_break):
|
||||
sess.run(save_op)
|
||||
self.evaluate(save_op)
|
||||
break
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
else:
|
||||
continue
|
||||
break
|
||||
@ -581,12 +581,12 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
break
|
||||
else:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
restore_op, get_next_op = self._restore_iterator()
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
@ -595,9 +595,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
(epoch == epoch_break and f == file_break and
|
||||
r < record_break)):
|
||||
continue
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
def testRestoreUnusedIterator(self):
|
||||
num_epochs = 10
|
||||
@ -605,22 +605,22 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
# Save unused iterator.
|
||||
sess.run(save_op)
|
||||
self.evaluate(save_op)
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for _ in range(num_epochs * self._num_files * self._num_records):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
def testRestoreExhaustedIterator(self):
|
||||
num_epochs = 10
|
||||
@ -629,26 +629,26 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
for _ in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
sess.run(save_op)
|
||||
self.evaluate(get_next_op)
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
self.evaluate(restore_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
|
||||
class TFRecordDatasetTest(test_base.DatasetTestBase):
|
||||
@ -807,7 +807,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(j, i), sess.run(next_element))
|
||||
self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -819,7 +819,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(j, i), sess.run(next_element))
|
||||
self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -36,7 +36,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.range(1, i + 1)
|
||||
result = ds.reduce(np.int64(0), lambda x, y: x + y)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(((i + 1) * i) // 2, sess.run(result))
|
||||
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
||||
|
||||
def testSumTuple(self):
|
||||
|
||||
@ -49,7 +49,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.zip((ds, ds))
|
||||
result = ds.reduce(np.int64(0), reduce_fn)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(((i + 1) * i), sess.run(result))
|
||||
self.assertEqual(((i + 1) * i), self.evaluate(result))
|
||||
|
||||
def testSumAndCount(self):
|
||||
|
||||
@ -61,7 +61,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.range(1, i + 1)
|
||||
result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
|
||||
with self.cached_session() as sess:
|
||||
s, c = sess.run(result)
|
||||
s, c = self.evaluate(result)
|
||||
self.assertEqual(((i + 1) * i) // 2, s)
|
||||
self.assertEqual(i, c)
|
||||
|
||||
@ -93,7 +93,8 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
|
||||
result = ds.reduce(make_sparse_fn(0), reduce_fn)
|
||||
with self.cached_session() as sess:
|
||||
self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result))
|
||||
self.assertSparseValuesEqual(
|
||||
make_sparse_fn(i + 1), self.evaluate(result))
|
||||
|
||||
def testNested(self):
|
||||
|
||||
@ -116,7 +117,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
|
||||
result = ds.reduce(map_fn(0), reduce_fn)
|
||||
with self.cached_session() as sess:
|
||||
result = sess.run(result)
|
||||
result = self.evaluate(result)
|
||||
self.assertEqual(((i + 1) * i) // 2, result["dense"])
|
||||
self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
|
||||
|
||||
|
@ -49,7 +49,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Test a finite repetition.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 3})
|
||||
for _ in range(3):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
|
||||
@ -59,7 +59,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Test a different finite repetition.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 7})
|
||||
for _ in range(7):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -75,7 +75,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# actually is infinite.
|
||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||
for _ in range(17):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
|
||||
@ -95,7 +95,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Take fewer than input size
|
||||
sess.run(init_op, feed_dict={count_placeholder: 4})
|
||||
for i in range(4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -104,7 +104,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Take more than input size
|
||||
sess.run(init_op, feed_dict={count_placeholder: 25})
|
||||
for i in range(10):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -113,7 +113,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Take all of input
|
||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||
for i in range(10):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -142,7 +142,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# the first 4 elements and then read the rest.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 4})
|
||||
for i in range(4, 10):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -165,7 +165,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Skip nothing
|
||||
sess.run(init_op, feed_dict={count_placeholder: 0})
|
||||
for i in range(0, 10):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -187,7 +187,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
|
||||
for _ in range(7 * 14):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -201,7 +201,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -66,7 +66,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# First run without shuffling to collect the "ground truth".
|
||||
sess.run(init_fifo_op)
|
||||
self.evaluate(init_fifo_op)
|
||||
unshuffled_elements = []
|
||||
for _ in range(20):
|
||||
unshuffled_elements.append(sess.run(get_next))
|
||||
@ -159,7 +159,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
|
||||
for elem in elems:
|
||||
self.assertEqual(elem, sess.run(get_next))
|
||||
self.assertEqual(elem, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -188,9 +188,9 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
initial_permutation = sess.run(next_element)
|
||||
self.assertAllEqual(initial_permutation, sess.run(next_element))
|
||||
self.assertAllEqual(initial_permutation, sess.run(next_element))
|
||||
initial_permutation = self.evaluate(next_element)
|
||||
self.assertAllEqual(initial_permutation, self.evaluate(next_element))
|
||||
self.assertAllEqual(initial_permutation, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -261,7 +261,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.session(graph=g) as sess:
|
||||
for iterator in iterators:
|
||||
if initializable:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
next_element = iterator.get_next()
|
||||
run_results = []
|
||||
for _ in range(300):
|
||||
|
@ -102,7 +102,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
num_full_batches = max(
|
||||
0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
|
||||
for i in range(num_full_batches):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(size):
|
||||
self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
|
||||
@ -111,7 +111,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
num_partial_batches = (count * 7) // shift + (
|
||||
(count * 7) % shift > 0) - num_full_batches
|
||||
for i in range(num_partial_batches):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
remaining = (count * 7) - ((num_full_batches + i) * shift)
|
||||
num_elements = remaining // stride + ((remaining % stride) > 0)
|
||||
@ -164,10 +164,10 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
num_batches = (10 - 5) // 3 + 1
|
||||
for i in range(num_batches):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
|
||||
@ -193,10 +193,10 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
num_batches = (10 - 5) // 3 + 1
|
||||
for i in range(num_batches):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected_indices = []
|
||||
expected_values = []
|
||||
for j in range(5):
|
||||
@ -227,9 +227,9 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# Slide: 1st batch.
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
|
||||
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
|
||||
@ -239,7 +239,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||
self.assertSparseValuesEqual(actual, expected)
|
||||
# Slide: 2nd batch.
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
|
||||
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
|
||||
@ -265,7 +265,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
r"Cannot batch tensors with different shapes in component 0. "
|
||||
@ -281,8 +281,8 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next))
|
||||
self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next))
|
||||
self.assertAllEqual(np.float32([1., 2.]), self.evaluate(get_next))
|
||||
self.assertAllEqual(np.float32([2., 3.]), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -55,7 +55,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, equal_length_components)})
|
||||
for i in range(4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(
|
||||
equal_length_components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
@ -66,7 +66,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, variable_length_components)})
|
||||
for i in range(2):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
for component, result_component in zip(
|
||||
variable_length_components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
@ -103,7 +103,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, equal_length_components)})
|
||||
for i in range(4):
|
||||
result1, (result2, result3) = sess.run(get_next)
|
||||
result1, (result2, result3) = self.evaluate(get_next)
|
||||
self.assertAllEqual(equal_length_components[0][i], result1)
|
||||
self.assertAllEqual(equal_length_components[1][i], result2)
|
||||
self.assertAllEqual(equal_length_components[2][i], result3)
|
||||
|
@ -31,24 +31,24 @@ class ConvertTest(test.TestCase):
|
||||
def testInteger(self):
|
||||
resp = convert.optional_param_to_tensor("foo", 3)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(3, sess.run(resp))
|
||||
self.assertEqual(3, self.evaluate(resp))
|
||||
|
||||
def testIntegerDefault(self):
|
||||
resp = convert.optional_param_to_tensor("foo", None)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(0, sess.run(resp))
|
||||
self.assertEqual(0, self.evaluate(resp))
|
||||
|
||||
def testStringDefault(self):
|
||||
resp = convert.optional_param_to_tensor("bar", None, "default",
|
||||
dtypes.string)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
|
||||
self.assertEqual(compat.as_bytes("default"), self.evaluate(resp))
|
||||
|
||||
def testString(self):
|
||||
resp = convert.optional_param_to_tensor("bar", "value", "default",
|
||||
dtypes.string)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
|
||||
self.assertEqual(compat.as_bytes("value"), self.evaluate(resp))
|
||||
|
||||
def testPartialShapeToTensorKnownDimension(self):
|
||||
with self.cached_session() as sess:
|
||||
|
@ -1583,7 +1583,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
x = variables.VariableV1([1, 3, 3, 7], name="x")
|
||||
_, idx = array_ops.unique(x, name="x_unique")
|
||||
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
|
||||
sess.run(x.initializer)
|
||||
self.evaluate(x.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(
|
||||
|
@ -126,8 +126,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
u = variables.Variable([12.0], name="u")
|
||||
v = variables.Variable([30.0], name="v")
|
||||
w = math_ops.add(u, v, name="w")
|
||||
sess.run(u.initializer)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(u.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, w, expected_output=[42.0])
|
||||
@ -139,7 +139,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
b = math_ops.add(a, a, name="b")
|
||||
with ops.control_dependencies([a, b]):
|
||||
c = math_ops.multiply(b, b, name="c")
|
||||
sess.run(a.initializer)
|
||||
self.evaluate(a.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, c, expected_output=400.0)
|
||||
@ -150,8 +150,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
y = variables.Variable(20.0, name="y")
|
||||
cond = control_flow_ops.cond(
|
||||
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
|
||||
sess.run(x.initializer)
|
||||
sess.run(y.initializer)
|
||||
self.evaluate(x.initializer)
|
||||
self.evaluate(y.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, cond, expected_output=21.0)
|
||||
@ -173,8 +173,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
toy_loss = x * (u - v)
|
||||
train_op = gradient_descent.GradientDescentOptimizer(
|
||||
learning_rate=0.1).minimize(toy_loss, name="train_op")
|
||||
sess.run(u.initializer)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(u.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(sess, train_op)
|
||||
|
||||
|
@ -131,8 +131,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
||||
with session.Session(
|
||||
config=self.session_config, graph=graph,
|
||||
target=self.server_target) as sess:
|
||||
sess.run(self.a.initializer)
|
||||
sess.run(self.b.initializer)
|
||||
self.evaluate(self.a.initializer)
|
||||
self.evaluate(self.b.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions()
|
||||
debug_utils.watch_graph(
|
||||
@ -198,8 +198,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
||||
with session.Session(
|
||||
config=self.session_config, graph=graph,
|
||||
target=self.server_target) as sess:
|
||||
sess.run(self.a.initializer)
|
||||
sess.run(self.b.initializer)
|
||||
self.evaluate(self.a.initializer)
|
||||
self.evaluate(self.b.initializer)
|
||||
|
||||
def watch_fn(feeds, fetch_keys):
|
||||
del feeds, fetch_keys
|
||||
|
@ -67,7 +67,7 @@ class SessionDebugMultiGPUTest(test_util.TensorFlowTestCase):
|
||||
u1 = math_ops.multiply(v, v, name="u1")
|
||||
w = math_ops.subtract(u1, u0, name="w")
|
||||
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(run_options, sess.graph,
|
||||
|
@ -109,8 +109,8 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
|
||||
self.w = math_ops.matmul(self.u, self.v, name="w")
|
||||
self.w_line_number = line_number_above()
|
||||
|
||||
sess.run(self.u.initializer)
|
||||
sess.run(self.v.initializer)
|
||||
self.evaluate(self.u.initializer)
|
||||
self.evaluate(self.v.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(
|
||||
|
@ -235,7 +235,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
|
||||
result = math_ops.add_n(xs)
|
||||
|
||||
variables.global_variables_initializer().run()
|
||||
result_value = sess.run(result)
|
||||
result_value = self.evaluate(result)
|
||||
self.assertEqual(result_value, expected)
|
||||
if result_value == expected:
|
||||
self._result_correct += 1
|
||||
@ -294,7 +294,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
|
||||
if len(uninit_vars) == 0:
|
||||
break
|
||||
|
||||
sess.run(train_op)
|
||||
self.evaluate(train_op)
|
||||
|
||||
# Synchronize workers after one step to make sure they all have finished
|
||||
# training.
|
||||
@ -327,7 +327,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
|
||||
|
||||
# The monitored session will run init or ready ops.
|
||||
with monitored_session.MonitoredSession() as sess:
|
||||
sess.run(train_op)
|
||||
self.evaluate(train_op)
|
||||
|
||||
# Synchronize workers after one step to make sure they all have finished
|
||||
# training.
|
||||
|
@ -92,7 +92,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
|
||||
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -205,10 +205,11 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(self._record(r, f), sess.run(next_element))
|
||||
self.assertAllEqual(self._record(r, f), self.evaluate(next_element))
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
|
||||
self.assertAllEqual(
|
||||
self._text_line(r, f), self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -149,9 +149,9 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
result = fn(3.0)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual(sess.run(state[0]), 2.0)
|
||||
self.assertAllEqual(sess.run(result), 6.0)
|
||||
self.assertAllEqual(self.evaluate(result), 6.0)
|
||||
|
||||
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
@ -168,9 +168,9 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
result = fn(3.0)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual(sess.run(state[0]), 6.0)
|
||||
self.assertAllEqual(sess.run(result), 18.0)
|
||||
self.assertAllEqual(self.evaluate(result), 18.0)
|
||||
|
||||
def testLegacyGraphModeInputDependentInitializerFails(self):
|
||||
with ops.Graph().as_default():
|
||||
|
@ -78,7 +78,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
||||
c = constant_op.constant([[2.]])
|
||||
f_c = f(c)
|
||||
g, = gradients_impl.gradients(f_c, c)
|
||||
self.assertAllEqual(sess.run(g).values, [[1.0]])
|
||||
self.assertAllEqual(self.evaluate(g).values, [[1.0]])
|
||||
|
||||
def testNoSymGradNestedDefun(self):
|
||||
|
||||
|
@ -564,7 +564,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
variables.global_variables_initializer().run()
|
||||
call = def_function.function(o.call)
|
||||
op = call()
|
||||
self.assertAllEqual(sess.run(op), 2.0)
|
||||
self.assertAllEqual(self.evaluate(op), 2.0)
|
||||
|
||||
def testGraphModeManyFunctions(self):
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
@ -1733,7 +1733,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
function.register(cpu_boost, x)
|
||||
y = gpu_boost(x)
|
||||
y_value = sess.run(y)
|
||||
y_value = self.evaluate(y)
|
||||
|
||||
if test.is_gpu_available():
|
||||
self.assertEqual(y_value, 5.0)
|
||||
|
@ -1026,7 +1026,7 @@ class CrossedColumnTest(test.TestCase):
|
||||
outputs = _transform_features(features, [price_cross_wire])
|
||||
output = outputs[price_cross_wire]
|
||||
with self.cached_session() as sess:
|
||||
output_val = sess.run(output)
|
||||
output_val = self.evaluate(output)
|
||||
self.assertAllEqual(
|
||||
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
||||
for val in output_val.values:
|
||||
@ -1880,7 +1880,8 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column('price')
|
||||
@ -2514,7 +2515,8 @@ class _LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column('price')
|
||||
|
@ -1190,7 +1190,7 @@ class CrossedColumnTest(test.TestCase):
|
||||
outputs = fc._transform_features(features, [price_cross_wire], None)
|
||||
output = outputs[price_cross_wire]
|
||||
with self.cached_session() as sess:
|
||||
output_val = sess.run(output)
|
||||
output_val = self.evaluate(output)
|
||||
self.assertAllEqual(
|
||||
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
||||
for val in output_val.values:
|
||||
@ -2091,7 +2091,8 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
@ -2127,7 +2128,8 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column_v2('price')
|
||||
@ -2849,7 +2851,8 @@ class OldLinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column_v2('price')
|
||||
|
@ -102,7 +102,7 @@ class FunctionTest(test.TestCase):
|
||||
call = MyIdentityFunc([18.0])
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
|
||||
def testIdentityImplicitDeref(self):
|
||||
|
||||
@ -116,8 +116,8 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
for cfg in _OptimizerOptions():
|
||||
with session.Session(config=cfg) as sess:
|
||||
sess.run(var.initializer)
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
self.evaluate(var.initializer)
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
|
||||
def testIdentityOutputName(self):
|
||||
|
||||
@ -130,7 +130,7 @@ class FunctionTest(test.TestCase):
|
||||
call = MyIdentityFunc([18.0])
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
|
||||
def testTooManyOutputNames(self):
|
||||
|
||||
@ -158,7 +158,7 @@ class FunctionTest(test.TestCase):
|
||||
call = APlus2B([1.0], [2.0])
|
||||
self.assertEqual("APlus2B", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([5.0], sess.run(call))
|
||||
self.assertAllEqual([5.0], self.evaluate(call))
|
||||
|
||||
def testFunctionWithNoOutput(self):
|
||||
|
||||
@ -187,7 +187,7 @@ class FunctionTest(test.TestCase):
|
||||
call = APlus2B([1.0], [2.0])
|
||||
self.assertEqual("APlus2B", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([5.0], sess.run(call))
|
||||
self.assertAllEqual([5.0], self.evaluate(call))
|
||||
|
||||
def testDefineFunctionDuplicateOutputs(self):
|
||||
|
||||
@ -224,8 +224,8 @@ class FunctionTest(test.TestCase):
|
||||
call_g = XSquarePlusOneGrad([2.0], [0.1])
|
||||
|
||||
with session.Session() as sess:
|
||||
self.assertAllClose([5.0], sess.run(call_f))
|
||||
self.assertAllClose([0.4], sess.run(call_g))
|
||||
self.assertAllClose([5.0], self.evaluate(call_f))
|
||||
self.assertAllClose([0.4], self.evaluate(call_g))
|
||||
|
||||
def testTanhSymGrad(self):
|
||||
|
||||
@ -387,7 +387,7 @@ class FunctionTest(test.TestCase):
|
||||
call = AConstant()
|
||||
self.assertEqual("AConstant", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([42], sess.run(call))
|
||||
self.assertAllEqual([42], self.evaluate(call))
|
||||
|
||||
def testDefineFunctionNames(self):
|
||||
|
||||
@ -468,7 +468,7 @@ class FunctionTest(test.TestCase):
|
||||
|
||||
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
|
||||
|
||||
ans = sess.run(loop)
|
||||
ans = self.evaluate(loop)
|
||||
self.assertAllClose(ans, 131072.)
|
||||
|
||||
def testControlFlowStrictness(self):
|
||||
@ -650,8 +650,8 @@ class FunctionTest(test.TestCase):
|
||||
# pylint: enable=unexpected-keyword-arg
|
||||
self.assertEqual("next", call2.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([1], sess.run(call1))
|
||||
self.assertAllEqual([0], sess.run(call2))
|
||||
self.assertAllEqual([1], self.evaluate(call1))
|
||||
self.assertAllEqual([0], self.evaluate(call2))
|
||||
|
||||
def testNestedFunction(self):
|
||||
|
||||
@ -794,7 +794,7 @@ class FunctionTest(test.TestCase):
|
||||
y = Foo()
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
self.assertEqual(sess.run(y), 10)
|
||||
self.assertEqual(self.evaluate(y), 10)
|
||||
|
||||
def testCaptureInCond(self):
|
||||
g = ops.Graph()
|
||||
@ -809,8 +809,8 @@ class FunctionTest(test.TestCase):
|
||||
z = Foo(False)
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
self.assertEqual(sess.run(y), 1)
|
||||
self.assertEqual(sess.run(z), 2)
|
||||
self.assertEqual(self.evaluate(y), 1)
|
||||
self.assertEqual(self.evaluate(z), 2)
|
||||
|
||||
def testStableName(self):
|
||||
|
||||
@ -900,7 +900,7 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual(global_vars[0].name, "linear/w:0")
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
output_val = sess.run(
|
||||
output_op, feed_dict={input_op: np.random.rand(32, 100)})
|
||||
self.assertEqual(output_val.shape, (32, 100))
|
||||
@ -928,7 +928,7 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual(global_vars[0].name, "vs1/var:0")
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
out1, out2 = sess.run(
|
||||
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
|
||||
self.assertAllEqual(out1, np.linspace(2, 11, 10))
|
||||
@ -991,8 +991,8 @@ class FunctionTest(test.TestCase):
|
||||
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
|
||||
|
||||
with session.Session() as sess:
|
||||
self.assertEqual(4.0, sess.run(result_1))
|
||||
self.assertEqual(100, sess.run(result_2))
|
||||
self.assertEqual(4.0, self.evaluate(result_1))
|
||||
self.assertEqual(100, self.evaluate(result_2))
|
||||
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
|
||||
|
||||
def testStatefulFunction(self):
|
||||
@ -1052,8 +1052,8 @@ class FunctionTest(test.TestCase):
|
||||
for config in _OptimizerOptions():
|
||||
config.device_count["CPU"] = 2
|
||||
with session.Session(config=config) as sess:
|
||||
self.assertEqual(42.0, sess.run(f_0))
|
||||
self.assertEqual(44.0, sess.run(f_1))
|
||||
self.assertEqual(42.0, self.evaluate(f_0))
|
||||
self.assertEqual(44.0, self.evaluate(f_1))
|
||||
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
|
||||
|
||||
def testGuaranteedConstsAreCaptured(self):
|
||||
@ -1076,7 +1076,7 @@ class FunctionTest(test.TestCase):
|
||||
return output
|
||||
|
||||
with self.session(use_gpu=False) as sess:
|
||||
sess.run(var.initializer)
|
||||
self.evaluate(var.initializer)
|
||||
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
|
||||
|
||||
def testSameFunctionDifferentGrads(self):
|
||||
@ -1651,8 +1651,8 @@ class ModuleFunctionTest(test.TestCase):
|
||||
y = LinearWithCApi(a, b, c)
|
||||
z = Linear2WithCApi(a, b, c, d, e)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([[1]], sess.run(y))
|
||||
self.assertAllEqual([[5]], sess.run(z))
|
||||
self.assertAllEqual([[1]], self.evaluate(y))
|
||||
self.assertAllEqual([[5]], self.evaluate(z))
|
||||
|
||||
|
||||
class VariableHoistingTest(test.TestCase):
|
||||
@ -1704,7 +1704,7 @@ class VariableHoistingTest(test.TestCase):
|
||||
self.assertEqual("Foo/b", b.op.name)
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
|
||||
|
||||
self.assertAllEqual(w.shape, (64, 64))
|
||||
|
@ -211,7 +211,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
with session.Session() as sess:
|
||||
init = variables.variables_initializer([variable_node])
|
||||
sess.run(init)
|
||||
output = sess.run(output_node)
|
||||
output = self.evaluate(output_node)
|
||||
self.assertNear(4.0, output, 0.00001)
|
||||
variable_graph_def = sess.graph.as_graph_def()
|
||||
|
||||
@ -242,8 +242,8 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
output_node = math_ops_lib.multiply(
|
||||
variable_node, 2.0, name="output_node")
|
||||
with session.Session() as sess:
|
||||
sess.run(variable_node.initializer)
|
||||
output = sess.run(output_node)
|
||||
self.evaluate(variable_node.initializer)
|
||||
output = self.evaluate(output_node)
|
||||
self.assertNear(2.0, output, 0.00001)
|
||||
variable_graph_def = sess.graph.as_graph_def()
|
||||
# First get the constant_graph_def when variable_names_whitelist is
|
||||
@ -256,7 +256,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
|
||||
# Then initialize the unused variable, and get another
|
||||
# constant_graph_def when variable_names_whitelist is not set.
|
||||
sess.run(another_variable.initializer)
|
||||
self.evaluate(another_variable.initializer)
|
||||
constant_graph_def_without_variable_whitelist = (
|
||||
graph_util.convert_variables_to_constants(
|
||||
sess, variable_graph_def, ["output_node"]))
|
||||
@ -295,7 +295,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
with session.Session() as sess:
|
||||
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
||||
output = sess.run(output_node)
|
||||
output = self.evaluate(output_node)
|
||||
self.assertNear(2.0, output, 0.00001)
|
||||
|
||||
def create_node_def(self, op, name, inputs):
|
||||
|
@ -398,10 +398,10 @@ class ImportGraphDefTest(test.TestCase):
|
||||
# TODO(b/76173421): make this work (currently DCHECKS)
|
||||
# with self.cached_session() as sess:
|
||||
# sess.run(imported_init)
|
||||
# self.assertEqual(sess.run(imported_var), 1.0)
|
||||
# self.assertEqual(sess.run(imported_assign), 2.0)
|
||||
# self.assertEqual(list(sess.run(imported_shape)), [])
|
||||
# self.assertEqual(list(sess.run(new_var_shape)), [])
|
||||
# self.assertEqual(self.evaluate(imported_var), 1.0)
|
||||
# self.assertEqual(self.evaluate(imported_assign), 2.0)
|
||||
# self.assertEqual(list(self.evaluate(imported_shape)), [])
|
||||
# self.assertEqual(list(self.evaluate(new_var_shape)), [])
|
||||
|
||||
def testWhileLoop(self):
|
||||
# Produce GraphDef containing while loop.
|
||||
@ -418,7 +418,7 @@ class ImportGraphDefTest(test.TestCase):
|
||||
return_elements=[r.name])
|
||||
self.assertEqual(imported_r.name, "import/" + r.name)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(imported_r), 10)
|
||||
self.assertEqual(self.evaluate(imported_r), 10)
|
||||
|
||||
def testImportWhileLoopInCond(self):
|
||||
# Produce GraphDef containing while loop.
|
||||
@ -458,7 +458,7 @@ class ImportGraphDefTest(test.TestCase):
|
||||
lambda i: i < 2, ImportFn, [0],
|
||||
shape_invariants=[tensor_shape.TensorShape(None)])
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(out), 10)
|
||||
self.assertEqual(self.evaluate(out), 10)
|
||||
|
||||
def testTypeMismatchInGraphDef(self):
|
||||
# TODO(skyewm): improve error message
|
||||
|
@ -492,8 +492,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
init_op = variables.global_variables_initializer()
|
||||
grad = gradients_impl.gradients([output], [var])
|
||||
with session.Session() as sess:
|
||||
sess.run(init_op)
|
||||
expected_grad_value = sess.run(grad)
|
||||
self.evaluate(init_op)
|
||||
expected_grad_value = self.evaluate(grad)
|
||||
|
||||
# Restore the MetaGraphDef into a new Graph with an import scope.
|
||||
with ops.Graph().as_default():
|
||||
@ -518,8 +518,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
init_op = variables.global_variables_initializer()
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(init_op)
|
||||
actual_grad_value = sess.run(grad)
|
||||
self.evaluate(init_op)
|
||||
actual_grad_value = self.evaluate(grad)
|
||||
self.assertEqual(expected_grad_value, actual_grad_value)
|
||||
|
||||
def testImportWhileLoopInWhileLoop(self):
|
||||
@ -544,7 +544,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
_, x = control_flow_ops.while_loop(lambda i, x: i < 2, body, [0, 0.0],
|
||||
name="")
|
||||
with session.Session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(x)
|
||||
|
||||
def testScopedImportUnderNameScope(self):
|
||||
@ -869,7 +869,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
|
||||
|
||||
initializer = variables.local_variables_initializer()
|
||||
sess.run(initializer)
|
||||
sess.run(update_op)
|
||||
self.evaluate(update_op)
|
||||
|
||||
meta_graph.export_scoped_meta_graph(
|
||||
filename=meta_graph_filename, graph=graph)
|
||||
|
@ -517,21 +517,21 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
self.assertEquals(x.consumers(), [])
|
||||
self.assertEquals(y.consumers(), [z.op, z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 4)
|
||||
self.assertEquals(self.evaluate(z), 4)
|
||||
|
||||
z.op._update_input(0, x) # pylint: disable=protected-access
|
||||
self.assertEquals(list(z.op.inputs), [x, y])
|
||||
self.assertEquals(x.consumers(), [z.op])
|
||||
self.assertEquals(y.consumers(), [z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 3)
|
||||
self.assertEquals(self.evaluate(z), 3)
|
||||
|
||||
z.op._update_input(1, y) # pylint: disable=protected-access
|
||||
self.assertEquals(list(z.op.inputs), [x, y])
|
||||
self.assertEquals(x.consumers(), [z.op])
|
||||
self.assertEquals(y.consumers(), [z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 3)
|
||||
self.assertEquals(self.evaluate(z), 3)
|
||||
|
||||
def testUpdateInputGraphError(self):
|
||||
g_0 = ops.Graph()
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user