parent
c16394423c
commit
f6ce9fd485
@ -61,7 +61,7 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
random_seed.set_random_seed(1618)
|
||||
op = random_ops.multinomial(logits, num_samples,
|
||||
output_dtype=dtypes.int32)
|
||||
d = self.evaluate(op)
|
||||
d = sess.run(op)
|
||||
|
||||
batch_size, num_classes = logits.shape
|
||||
freqs_mat = []
|
||||
@ -86,9 +86,9 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
|
||||
# The random-number generator, if working correctly, should produce the
|
||||
# same output multiple times with low probability.
|
||||
y = self.evaluate(x)
|
||||
z = self.evaluate(x)
|
||||
w = self.evaluate(x)
|
||||
y = sess.run(x)
|
||||
z = sess.run(x)
|
||||
w = sess.run(x)
|
||||
|
||||
# We use exact equality here. If the random-number generator is producing
|
||||
# deterministic output, all three outputs will be bitwise identical.
|
||||
@ -113,7 +113,7 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
x = random_ops.multinomial(
|
||||
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
|
||||
output_dtype=output_dtype)
|
||||
y = self.evaluate(x)
|
||||
y = sess.run(x)
|
||||
self.assertTrue((y >= 0).sum() == 1000)
|
||||
self.assertTrue((y < 20).sum() == 1000)
|
||||
|
||||
|
@ -337,7 +337,7 @@ class ConcatOffsetTest(xla_test.XLATestCase):
|
||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
|
||||
ans = self.evaluate(off)
|
||||
ans = sess.run(off)
|
||||
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
|
||||
|
||||
|
||||
@ -350,7 +350,7 @@ class PackTest(xla_test.XLATestCase):
|
||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = self.evaluate(packed)
|
||||
ans = sess.run(packed)
|
||||
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
|
||||
|
||||
def testScalars(self):
|
||||
@ -360,7 +360,7 @@ class PackTest(xla_test.XLATestCase):
|
||||
s1 = constant_op.constant(3, dtypes.int32)
|
||||
s2 = constant_op.constant(5, dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = self.evaluate(packed)
|
||||
ans = sess.run(packed)
|
||||
self.assertAllEqual(ans, [2, 3, 5])
|
||||
|
||||
def testEmpty(self):
|
||||
@ -370,7 +370,7 @@ class PackTest(xla_test.XLATestCase):
|
||||
s1 = constant_op.constant([[]], dtypes.int32)
|
||||
s2 = constant_op.constant([[]], dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = self.evaluate(packed)
|
||||
ans = sess.run(packed)
|
||||
self.assertAllEqual(ans, [[[]], [[]], [[]]])
|
||||
|
||||
|
||||
|
@ -106,7 +106,7 @@ class EagerTest(xla_test.XLATestCase):
|
||||
three = constant_op.constant(3)
|
||||
five = constant_op.constant(5)
|
||||
product = three * five
|
||||
self.assertAllEqual(15, self.evaluate(product))
|
||||
self.assertAllEqual(15, sess.run(product))
|
||||
|
||||
def testDegenerateSlices(self):
|
||||
with self.test_scope():
|
||||
|
@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_f = Foo(a, b)
|
||||
result = self.evaluate(call_f)
|
||||
result = sess.run(call_f)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testNestedFunctions(self):
|
||||
@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_g = Foo(a, b)
|
||||
result = self.evaluate(call_g)
|
||||
result = sess.run(call_g)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testFunctionMultipleRetvals(self):
|
||||
@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_f = Foo(a, b)
|
||||
result = self.evaluate(call_f)
|
||||
result = sess.run(call_f)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testCompileTimeConstantsInDefun(self):
|
||||
|
@ -88,7 +88,7 @@ class LSTMTest(test.TestCase):
|
||||
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
|
||||
|
||||
# Initialize variables and run the unrolled LSTM step.
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
return sess.run([m, c])
|
||||
|
||||
def testLSTMCell(self):
|
||||
@ -173,7 +173,7 @@ class LSTMTest(test.TestCase):
|
||||
(basename, m_init_scalar, c_init_scalar, pad_scalar))
|
||||
|
||||
# Initialize variables and run the unrolled LSTM layer.
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
return sess.run(out_seq)
|
||||
|
||||
def testLSTMLayer(self):
|
||||
|
@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase):
|
||||
ph = array_ops.placeholder_with_default(v, shape=[])
|
||||
out = ph * 2
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
self.assertEqual(8.0, self.evaluate(out))
|
||||
self.assertEqual(8.0, sess.run(out))
|
||||
|
||||
def test_placeholder_with_default_fed(self):
|
||||
with self.cached_session() as sess, self.test_scope():
|
||||
|
@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
|
||||
# The random-number generator, if working correctly, should produce the
|
||||
# same output multiple times with low probability.
|
||||
y = self.evaluate(x)
|
||||
z = self.evaluate(x)
|
||||
w = self.evaluate(x)
|
||||
y = sess.run(x)
|
||||
z = sess.run(x)
|
||||
w = sess.run(x)
|
||||
|
||||
# We use exact equality here. If the random-number generator is producing
|
||||
# deterministic output, all three outputs will be bitwise identical.
|
||||
@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = random_ops.random_uniform(
|
||||
shape=[1000], dtype=dtype, minval=-2, maxval=33)
|
||||
y = self.evaluate(x)
|
||||
y = sess.run(x)
|
||||
self.assertTrue((y >= -2).sum() == 1000)
|
||||
self.assertTrue((y < 33).sum() == 1000)
|
||||
|
||||
@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.cached_session() as sess:
|
||||
with self.test_scope():
|
||||
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
||||
y = self.evaluate(x)
|
||||
y = sess.run(x)
|
||||
|
||||
def normal_cdf(x):
|
||||
return .5 * math.erfc(-x / math.sqrt(2))
|
||||
@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = math_ops.range(1 << 16)
|
||||
shuffle = random_ops.random_shuffle(x)
|
||||
result = self.evaluate(shuffle)
|
||||
result = sess.run(shuffle)
|
||||
expected = range(1 << 16)
|
||||
# Compare sets to avoid randomness behavior changes but make sure still
|
||||
# have all the values.
|
||||
@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = array_ops.diag(math_ops.range(20))
|
||||
shuffle = random_ops.random_shuffle(x)
|
||||
result = self.evaluate(shuffle)
|
||||
result = sess.run(shuffle)
|
||||
expected = np.diag(range(20)).flatten()
|
||||
# Compare sets to avoid randomness behavior changes but make sure still
|
||||
# have all the values.
|
||||
|
@ -505,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase):
|
||||
[-0.5, 1.5], # read(0) gradient
|
||||
[20.0, 30.0, 40.0, 50.0], # concat gradient
|
||||
])
|
||||
grad_vals = self.evaluate(grad_r) # 2 + 2 entries
|
||||
grad_vals = sess.run(grad_r) # 2 + 2 entries
|
||||
|
||||
self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
|
||||
self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
|
||||
|
@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_add(
|
||||
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(self.evaluate(read), [[3], [7]])
|
||||
self.assertAllEqual(sess.run(read), [[3], [7]])
|
||||
|
||||
def testScatterSub(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_sub(
|
||||
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(self.evaluate(read), [[4], [-1]])
|
||||
self.assertAllEqual(sess.run(read), [[4], [-1]])
|
||||
|
||||
def testScatterMul(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_mul(
|
||||
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[5]])
|
||||
self.assertEqual(sess.run(read), [[5]])
|
||||
|
||||
def testScatterDiv(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_div(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(self.evaluate(read), [[2]])
|
||||
self.assertAllEqual(sess.run(read), [[2]])
|
||||
|
||||
def testScatterMin(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_min(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
|
||||
def testScatterMax(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_max(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[6]])
|
||||
self.assertEqual(sess.run(read), [[6]])
|
||||
|
||||
def testScatterUpdate(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_update(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
|
||||
def testScatterAddScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_add(
|
||||
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
|
||||
def testScatterSubScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_sub(
|
||||
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[-1]])
|
||||
self.assertEqual(sess.run(read), [[-1]])
|
||||
|
||||
def testScatterMulScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_mul(
|
||||
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[5]])
|
||||
self.assertEqual(sess.run(read), [[5]])
|
||||
|
||||
def testScatterDivScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_div(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[2]])
|
||||
self.assertEqual(sess.run(read), [[2]])
|
||||
|
||||
def testScatterMinScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_min(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
|
||||
def testScatterMaxScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_max(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(self.evaluate(read), [[6]])
|
||||
self.assertEqual(sess.run(read), [[6]])
|
||||
|
||||
def testScatterNdAddOps(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
|
||||
read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.float32)
|
||||
self.assertAllClose(expected, self.evaluate(read))
|
||||
self.assertAllClose(expected, sess.run(read))
|
||||
|
||||
def testScatterNdUpdateAddOps(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
|
||||
read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.float32)
|
||||
self.assertAllClose(expected, self.evaluate(read))
|
||||
self.assertAllClose(expected, sess.run(read))
|
||||
|
||||
|
||||
class StridedSliceAssignChecker(object):
|
||||
|
@ -96,7 +96,7 @@ class KerasTest(tf.test.TestCase):
|
||||
sess.run(init)
|
||||
sample_input = tf.random_uniform((1, 10, 10, 1))
|
||||
output = model(sample_input) # pylint: disable=not-callable
|
||||
self.assertEqual(self.evaluate(output).shape, (1, 3))
|
||||
self.assertEqual(sess.run(output).shape, (1, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,7 +34,7 @@ class ListLiteralsTest(tf.test.TestCase):
|
||||
result = converted()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(result), [1, 2, 3])
|
||||
self.assertAllEqual(sess.run(result), [1, 2, 3])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -35,7 +35,7 @@ class InputDataTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([32000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -33,7 +33,7 @@ class LabelWavTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([1000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -33,7 +33,7 @@ class WavToFeaturesTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([32000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -113,7 +113,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
with self.compiled(node, ns) as result:
|
||||
with self.cached_session() as sess:
|
||||
result_tensor = result.test_fn(constant_op.constant(1))
|
||||
self.assertEquals(self.evaluate(result_tensor), 3)
|
||||
self.assertEquals(sess.run(result_tensor), 3)
|
||||
|
||||
def test_call_to_decorated_function(self):
|
||||
|
||||
|
@ -68,7 +68,7 @@ class ListTest(converter_testing.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
tl = result.test_fn()
|
||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
|
||||
self.assertAllEqual(sess.run(r), [1, 2, 3])
|
||||
|
||||
def test_list_pop(self):
|
||||
|
||||
@ -91,8 +91,8 @@ class ListTest(converter_testing.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
ts, tl = result.test_fn()
|
||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||
self.assertAllEqual(self.evaluate(ts), 3)
|
||||
self.assertAllEqual(sess.run(r), [1, 2])
|
||||
self.assertAllEqual(sess.run(ts), 3)
|
||||
|
||||
def test_double_list_pop(self):
|
||||
|
||||
|
@ -48,12 +48,12 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Add support for this use case.
|
||||
# Right now the variable `a` is not conditioned on the `assign` because
|
||||
# there's no way to add control dependencies to a variable object.
|
||||
self.assertEqual(2, self.evaluate(v))
|
||||
self.assertEqual(2, sess.run(v))
|
||||
|
||||
def test_side_effect_on_used_variable(self):
|
||||
|
||||
@ -69,11 +69,11 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
# Right now it's 3 or 4 based on whether the read is synchronized.
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
self.assertEqual(3, sess.run(v))
|
||||
|
||||
def test_side_effect_on_tensor(self):
|
||||
|
||||
@ -109,10 +109,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign_add) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(4, self.evaluate(v))
|
||||
self.assertEqual(4, sess.run(v))
|
||||
|
||||
def test_multiline_nested_block(self):
|
||||
|
||||
@ -130,10 +130,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
self.assertEqual(3, sess.run(v))
|
||||
|
||||
def test_multiline_block_unsafe(self):
|
||||
|
||||
@ -153,10 +153,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
state_ops.assign_add) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(4, self.evaluate(v))
|
||||
self.assertEqual(4, sess.run(v))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -49,7 +49,7 @@ class SliceTest(converter_testing.TestCase):
|
||||
tl = list_ops.tensor_list_from_tensor(
|
||||
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
|
||||
y = result.test_fn(tl)
|
||||
self.assertEqual(2, self.evaluate(y))
|
||||
self.assertEqual(2, sess.run(y))
|
||||
|
||||
def test_index_access_multiple_definitions(self):
|
||||
|
||||
|
@ -63,7 +63,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
|
||||
def test_decorator_does_not_recurse(self):
|
||||
|
||||
@ -83,7 +83,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
|
||||
def test_decorator_calls_unconverted_graph(self):
|
||||
|
||||
@ -104,7 +104,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
|
||||
def test_decorator_calls_unconverted_py_func(self):
|
||||
|
||||
@ -130,7 +130,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
|
||||
def test_decorator_calls_decorated(self):
|
||||
|
||||
@ -153,7 +153,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
|
||||
def test_decorator_preserves_argspec(self):
|
||||
|
||||
@ -192,7 +192,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
|
||||
def test_converted_call_builtin(self):
|
||||
x = api.converted_call(range, None, converter.ConversionOptions(), 3)
|
||||
@ -208,7 +208,7 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
x = api.converted_call(test_fn, None, converter.ConversionOptions(),
|
||||
constant_op.constant(-1))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
self.assertEqual(1, sess.run(x))
|
||||
|
||||
def test_converted_call_method_explicit_owner(self):
|
||||
# TODO(mdan): Implement.
|
||||
@ -234,7 +234,7 @@ class ApiTest(test.TestCase):
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(tc.test_method, None,
|
||||
converter.ConversionOptions(), tc)
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
self.assertEqual(1, sess.run(x))
|
||||
|
||||
def test_converted_call_method_by_class(self):
|
||||
|
||||
@ -252,7 +252,7 @@ class ApiTest(test.TestCase):
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(TestClass.test_method, None,
|
||||
converter.ConversionOptions(), tc)
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
self.assertEqual(1, sess.run(x))
|
||||
|
||||
def test_converted_call_callable_object(self):
|
||||
|
||||
@ -269,7 +269,7 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(tc, None, converter.ConversionOptions())
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
self.assertEqual(1, sess.run(x))
|
||||
|
||||
def test_converted_call_constructor(self):
|
||||
|
||||
@ -288,7 +288,7 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant(-1))
|
||||
# tc is now a converted object.
|
||||
x = tc.test_method()
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
self.assertEqual(1, sess.run(x))
|
||||
|
||||
def test_converted_call_already_converted(self):
|
||||
|
||||
@ -298,12 +298,12 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
x = api.converted_call(f, None, converter.ConversionOptions(),
|
||||
constant_op.constant(0))
|
||||
self.assertTrue(self.evaluate(x))
|
||||
self.assertTrue(sess.run(x))
|
||||
|
||||
converted_f = api.to_graph(f)
|
||||
x = api.converted_call(converted_f, None, converter.ConversionOptions(),
|
||||
constant_op.constant(0))
|
||||
self.assertTrue(self.evaluate(x))
|
||||
self.assertTrue(sess.run(x))
|
||||
|
||||
def test_converted_call_no_user_code(self):
|
||||
|
||||
@ -334,8 +334,8 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
|
||||
def test_converted_call_whitelisted_method_extra_self(self):
|
||||
|
||||
@ -349,8 +349,8 @@ class ApiTest(test.TestCase):
|
||||
model, constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
|
||||
def test_converted_call_whitelisted_method_via_owner(self):
|
||||
|
||||
@ -364,8 +364,8 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
|
||||
def test_converted_call_lambda(self):
|
||||
|
||||
@ -376,8 +376,8 @@ class ApiTest(test.TestCase):
|
||||
x = api.converted_call(l, None, opts, constant_op.constant(0))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual(True, self.evaluate(x))
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual(True, sess.run(x))
|
||||
|
||||
def test_to_graph_basic(self):
|
||||
|
||||
@ -390,7 +390,7 @@ class ApiTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
||||
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||
|
||||
def test_to_graph_with_defaults(self):
|
||||
|
||||
@ -405,7 +405,7 @@ class ApiTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
x = compiled_fn(constant_op.constant([4, 8]))
|
||||
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||
|
||||
def test_to_code_basic(self):
|
||||
|
||||
|
@ -36,7 +36,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
python_one = special_functions.match_staging_level(1, 1)
|
||||
with self.cached_session() as sess:
|
||||
self.assertTrue(tensor_util.is_tensor(tensor_one))
|
||||
self.assertAllEqual(self.evaluate(tensor_one), 1)
|
||||
self.assertAllEqual(sess.run(tensor_one), 1)
|
||||
self.assertEqual(python_one, 1)
|
||||
|
||||
def test_tensor_list_empty_list(self):
|
||||
@ -45,21 +45,21 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
l = special_functions.tensor_list((),
|
||||
element_dtype=dtypes.int32,
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
def test_tensor_list_tensor(self):
|
||||
l = special_functions.tensor_list(
|
||||
constant_op.constant([], dtype=dtypes.int32))
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
|
||||
def test_tensor_list_unsupported_initializer(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown type'):
|
||||
@ -76,7 +76,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
l = special_functions.tensor_list(elements)
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_tensor_list_array_from_elements(self):
|
||||
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
|
||||
@ -84,7 +84,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
||||
sl = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_stack(self):
|
||||
self.assertEqual(special_functions.stack(1, strict=False), 1)
|
||||
|
@ -35,7 +35,7 @@ class ForLoopTest(test.TestCase):
|
||||
body=lambda i, s: (s + i,),
|
||||
init_state=(0,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((10,), self.evaluate(s))
|
||||
self.assertEqual((10,), sess.run(s))
|
||||
|
||||
def test_python(self):
|
||||
s = control_flow.for_stmt(
|
||||
@ -53,7 +53,7 @@ class ForLoopTest(test.TestCase):
|
||||
body=lambda i, s: (s + i,),
|
||||
init_state=(0,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((10,), self.evaluate(s))
|
||||
self.assertEqual((10,), sess.run(s))
|
||||
|
||||
|
||||
class WhileLoopTest(test.TestCase):
|
||||
@ -66,7 +66,7 @@ class WhileLoopTest(test.TestCase):
|
||||
init_state=(0, 0),
|
||||
extra_deps=(n,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((5, 10), self.evaluate(results))
|
||||
self.assertEqual((5, 10), sess.run(results))
|
||||
|
||||
def test_python(self):
|
||||
n = 5
|
||||
@ -90,9 +90,9 @@ class IfStmtTest(test.TestCase):
|
||||
def test_tensor(self):
|
||||
with self.cached_session() as sess:
|
||||
t = self.single_return_if_stmt(constant_op.constant(True))
|
||||
self.assertEqual(1, self.evaluate(t))
|
||||
self.assertEqual(1, sess.run(t))
|
||||
t = self.single_return_if_stmt(constant_op.constant(False))
|
||||
self.assertEqual(-1, self.evaluate(t))
|
||||
self.assertEqual(-1, sess.run(t))
|
||||
|
||||
def test_python(self):
|
||||
self.assertEqual(1, self.single_return_if_stmt(True))
|
||||
@ -101,9 +101,9 @@ class IfStmtTest(test.TestCase):
|
||||
def test_tensor_multiple_returns(self):
|
||||
with self.cached_session() as sess:
|
||||
t = self.multi_return_if_stmt(constant_op.constant(True))
|
||||
self.assertAllEqual([1, 2], self.evaluate(t))
|
||||
self.assertAllEqual([1, 2], sess.run(t))
|
||||
t = self.multi_return_if_stmt(constant_op.constant(False))
|
||||
self.assertAllEqual([-1, -2], self.evaluate(t))
|
||||
self.assertAllEqual([-1, -2], sess.run(t))
|
||||
|
||||
def test_python_multiple_returns(self):
|
||||
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
|
||||
|
@ -43,7 +43,7 @@ class ListTest(test.TestCase):
|
||||
l = data_structures.tf_tensor_list_new([3, 4, 5])
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_list_new_empty(self):
|
||||
l = data_structures.tf_tensor_list_new([],
|
||||
@ -51,13 +51,13 @@ class ListTest(test.TestCase):
|
||||
element_shape=())
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(t), [])
|
||||
self.assertAllEqual(sess.run(t), [])
|
||||
|
||||
def test_tf_tensor_list_new_from_tensor(self):
|
||||
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_list_new_illegal_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -77,7 +77,7 @@ class ListTest(test.TestCase):
|
||||
l = data_structures.tf_tensor_array_new([3, 4, 5])
|
||||
t = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_array_new_illegal_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -102,7 +102,7 @@ class ListTest(test.TestCase):
|
||||
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
|
||||
self.assertAllEqual(sess.run(t), [[1, 2, 3]])
|
||||
|
||||
def test_append_tensorarray(self):
|
||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||
@ -131,10 +131,10 @@ class ListTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
l, x = data_structures.list_pop(l, None, opts)
|
||||
self.assertAllEqual(self.evaluate(x), [3, 4])
|
||||
self.assertAllEqual(sess.run(x), [3, 4])
|
||||
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
||||
self.assertAllEqual(self.evaluate(t), [[1, 2]])
|
||||
self.assertAllEqual(sess.run(t), [[1, 2]])
|
||||
|
||||
def test_pop_python(self):
|
||||
l = [1, 2, 3]
|
||||
@ -152,7 +152,7 @@ class ListTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
t = data_structures.list_stack(l, opts)
|
||||
self.assertAllEqual(sess.run(t), self.evaluate(initial_list))
|
||||
self.assertAllEqual(sess.run(t), sess.run(initial_list))
|
||||
|
||||
def test_stack_tensor_list_empty(self):
|
||||
l = list_ops.empty_tensor_list(
|
||||
|
@ -45,11 +45,11 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_and_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.and_(self._tf_true, self._tf_true)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
t = logical.and_(self._tf_true, lambda: True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
t = logical.and_(self._tf_false, lambda: True)
|
||||
self.assertEqual(self.evaluate(t), False)
|
||||
self.assertEqual(sess.run(t), False)
|
||||
# TODO(mdan): Add a test for ops with side effects.
|
||||
|
||||
def test_or_python(self):
|
||||
@ -63,11 +63,11 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_or_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.or_(self._tf_false, self._tf_true)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
t = logical.or_(self._tf_false, lambda: True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
t = logical.or_(self._tf_true, lambda: True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
# TODO(mdan): Add a test for ops with side effects.
|
||||
|
||||
def test_not_python(self):
|
||||
@ -78,7 +78,7 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_not_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.not_(self._tf_false())
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -38,29 +38,29 @@ class PyBuiltinsTest(test.TestCase):
|
||||
self.assertEqual(py_builtins.abs_(-1), 1)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.abs_(constant_op.constant(-1))
|
||||
self.assertEqual(self.evaluate(t), 1)
|
||||
self.assertEqual(sess.run(t), 1)
|
||||
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
|
||||
self.assertAllEqual(self.evaluate(t), [1, 2, 3])
|
||||
self.assertAllEqual(sess.run(t), [1, 2, 3])
|
||||
|
||||
def test_float(self):
|
||||
self.assertEqual(py_builtins.float_(10), 10.0)
|
||||
self.assertEqual(py_builtins.float_('10.0'), 10.0)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
|
||||
self.assertEqual(self.evaluate(t), 1.0)
|
||||
self.assertEqual(sess.run(t), 1.0)
|
||||
st = py_builtins.float_(constant_op.constant('1.0'))
|
||||
self.assertEqual(self.evaluate(st), 1.0)
|
||||
self.assertEqual(sess.run(st), 1.0)
|
||||
|
||||
def test_int(self):
|
||||
self.assertEqual(py_builtins.int_(10.0), 10)
|
||||
self.assertEqual(py_builtins.int_('11', 2), 3)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
|
||||
self.assertEqual(self.evaluate(t), 1)
|
||||
self.assertEqual(sess.run(t), 1)
|
||||
st = py_builtins.int_(constant_op.constant('1'))
|
||||
self.assertEqual(self.evaluate(st), 1)
|
||||
self.assertEqual(sess.run(st), 1)
|
||||
st = py_builtins.int_(constant_op.constant('1'), 10)
|
||||
self.assertEqual(self.evaluate(st), 1)
|
||||
self.assertEqual(sess.run(st), 1)
|
||||
|
||||
def test_int_unsupported_base(self):
|
||||
t = constant_op.constant(1, dtype=dtypes.float64)
|
||||
@ -73,9 +73,9 @@ class PyBuiltinsTest(test.TestCase):
|
||||
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
|
||||
self.assertEqual(t, 3)
|
||||
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
|
||||
self.assertEqual(self.evaluate(ta), 5)
|
||||
self.assertEqual(sess.run(ta), 5)
|
||||
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
|
||||
self.assertEqual(self.evaluate(tl), 3)
|
||||
self.assertEqual(sess.run(tl), 3)
|
||||
|
||||
def test_len_scalar(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -120,18 +120,18 @@ class PyBuiltinsTest(test.TestCase):
|
||||
def test_range_tensor(self):
|
||||
with self.cached_session() as sess:
|
||||
r = py_builtins.range_(constant_op.constant(3))
|
||||
self.assertAllEqual(self.evaluate(r), [0, 1, 2])
|
||||
self.assertAllEqual(sess.run(r), [0, 1, 2])
|
||||
r = py_builtins.range_(1, constant_op.constant(3))
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||
self.assertAllEqual(sess.run(r), [1, 2])
|
||||
r = py_builtins.range_(2, 0, constant_op.constant(-1))
|
||||
self.assertAllEqual(self.evaluate(r), [2, 1])
|
||||
self.assertAllEqual(sess.run(r), [2, 1])
|
||||
|
||||
def test_range_tensor_empty_range(self):
|
||||
with self.session() as sess:
|
||||
r = py_builtins.range_(constant_op.constant(-3))
|
||||
self.assertAllEqual(self.evaluate(r), [])
|
||||
self.assertAllEqual(sess.run(r), [])
|
||||
r = py_builtins.range_(5, constant_op.constant(2))
|
||||
self.assertAllEqual(self.evaluate(r), [])
|
||||
self.assertAllEqual(sess.run(r), [])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,7 +34,7 @@ class SlicesTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
||||
self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
|
||||
self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
|
||||
|
||||
def test_get_item_tensor_list(self):
|
||||
initial_list = constant_op.constant([[1, 2], [3, 4]])
|
||||
@ -44,7 +44,7 @@ class SlicesTest(test.TestCase):
|
||||
l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4])
|
||||
self.assertAllEqual(sess.run(t), [3, 4])
|
||||
|
||||
def test_get_item_tensor_string(self):
|
||||
initial_str = constant_op.constant('abcd')
|
||||
@ -52,14 +52,14 @@ class SlicesTest(test.TestCase):
|
||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(self.evaluate(t), b'b')
|
||||
self.assertEqual(sess.run(t), b'b')
|
||||
|
||||
initial_list_str = constant_op.constant(['abcd', 'bcde'])
|
||||
t = slices.get_item(initial_list_str, 1,
|
||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(self.evaluate(t), b'bcde')
|
||||
self.assertEqual(sess.run(t), b'bcde')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -32,7 +32,7 @@ class MiscTest(test.TestCase):
|
||||
new_a = alias_tensors(a)
|
||||
self.assertFalse(new_a is a)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(1, self.evaluate(new_a))
|
||||
self.assertEqual(1, sess.run(new_a))
|
||||
|
||||
def test_alias_tensors(self):
|
||||
a = constant(1)
|
||||
@ -47,7 +47,7 @@ class MiscTest(test.TestCase):
|
||||
self.assertTrue(new_s is s)
|
||||
self.assertTrue(new_l is l)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(1, self.evaluate(new_a))
|
||||
self.assertEqual(1, sess.run(new_a))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,13 +34,13 @@ class PyFuncTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(1, constant_op.constant(1), 1))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
result = py_func.wrap_py_func(
|
||||
test_fn, dtypes.int64,
|
||||
(constant_op.constant(1), 1, constant_op.constant(1)))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
|
||||
def test_wrap_py_func_complex_args(self):
|
||||
|
||||
@ -54,10 +54,10 @@ class PyFuncTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
|
||||
self.assertEqual(35, self.evaluate(result))
|
||||
self.assertEqual(35, sess.run(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(constant_op.constant(7), TestClass()))
|
||||
self.assertEqual(35, self.evaluate(result))
|
||||
self.assertEqual(35, sess.run(result))
|
||||
|
||||
def test_wrap_py_func_kwargs(self):
|
||||
|
||||
@ -74,13 +74,13 @@ class PyFuncTest(test.TestCase):
|
||||
'c': 11,
|
||||
'd': TestClass(13)
|
||||
})
|
||||
self.assertEqual(178, self.evaluate(result))
|
||||
self.assertEqual(178, sess.run(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(constant_op.constant(7), TestClass(5)), {
|
||||
'c': constant_op.constant(11),
|
||||
'd': TestClass(13)
|
||||
})
|
||||
self.assertEqual(178, self.evaluate(result))
|
||||
self.assertEqual(178, sess.run(result))
|
||||
|
||||
def test_wrap_py_func_dummy_return(self):
|
||||
|
||||
@ -91,11 +91,11 @@ class PyFuncTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
|
||||
self.assertEqual(1, self.evaluate(result))
|
||||
self.assertEqual(1, sess.run(result))
|
||||
self.assertEqual([1], side_counter)
|
||||
result = py_func.wrap_py_func(
|
||||
test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
|
||||
self.assertEqual(1, self.evaluate(result))
|
||||
self.assertEqual(1, sess.run(result))
|
||||
self.assertEqual([2], side_counter)
|
||||
|
||||
|
||||
|
@ -43,13 +43,13 @@ class TensorListTest(test.TestCase):
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(s), [1])
|
||||
self.assertAllEqual(sess.run(s), [1])
|
||||
|
||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
s = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(s), [1])
|
||||
self.assertAllEqual(sess.run(s), [1])
|
||||
|
||||
l = tl.TensorList(self._shape(()), dtypes.int32)
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
|
@ -62,7 +62,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
const = constant_op.constant(17)
|
||||
sess = session.Session(server1.target, config=config)
|
||||
output = self.evaluate(const)
|
||||
output = sess.run(const)
|
||||
self.assertEqual(17, output)
|
||||
|
||||
def testClusterSpecPropagationWorker2Placement(self):
|
||||
@ -106,7 +106,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
|
||||
const = constant_op.constant(17)
|
||||
sess = session.Session(server1.target, config=config, graph=g)
|
||||
output = self.evaluate(const)
|
||||
output = sess.run(const)
|
||||
self.assertEqual(17, output)
|
||||
|
||||
def testCanonicalDeviceNames(self):
|
||||
@ -208,7 +208,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
with ops.device('/job:worker/task:0/cpu:0'):
|
||||
sum3 = sum1 + sum2
|
||||
sess = session.Session(server1.target, config=config, graph=g)
|
||||
output = self.evaluate(sum3)
|
||||
output = sess.run(sum3)
|
||||
self.assertEqual(40, output)
|
||||
|
||||
def testLegacyDeviceNames(self):
|
||||
|
@ -147,7 +147,7 @@ class TimelineTest(test.TestCase):
|
||||
num2 = variables.Variable(2.0, name='num2')
|
||||
with ops.device('/cpu:2'):
|
||||
result = num1 + num2 + num1 * num2
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(result, options=run_options, run_metadata=run_metadata)
|
||||
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
@ -176,7 +176,7 @@ class TimelineTest(test.TestCase):
|
||||
num2 = variables.Variable(2.0, name='num2')
|
||||
with ops.device('/cpu:2'):
|
||||
result = num1 + num2 + num1 * num2
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(result, options=run_options, run_metadata=run_metadata)
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
step_stats = run_metadata.step_stats
|
||||
|
@ -216,7 +216,7 @@ class VirtualGpuTest(test_util.TensorFlowTestCase):
|
||||
for d in self._util.devices:
|
||||
with ops.device(d):
|
||||
var = variables.Variable(random_ops.random_uniform(mat_shape))
|
||||
self.evaluate(var.initializer)
|
||||
sess.run(var.initializer)
|
||||
data.append(var)
|
||||
s = data[0]
|
||||
for i in range(1, len(data)):
|
||||
|
@ -53,10 +53,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual([[i, j]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)], results.indices)
|
||||
@ -81,10 +81,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual([[i, j, z]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)
|
||||
@ -141,7 +141,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
||||
for i in range(4):
|
||||
self.assertEqual(i, self.evaluate(next_elem))
|
||||
self.assertEqual(i, sess.run(next_elem))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_elem)
|
||||
|
||||
@ -159,7 +159,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i,) * 3, self.evaluate(op))
|
||||
self.assertEqual((i,) * 3, sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -179,7 +179,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -198,7 +198,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
st_row = self.evaluate(next_element)
|
||||
st_row = sess.run(next_element)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
self.assertEqual([10], st_row.dense_shape)
|
||||
@ -219,7 +219,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
dense_elem, st_row = self.evaluate(next_element)
|
||||
dense_elem, st_row = sess.run(next_element)
|
||||
self.assertEqual(i, dense_elem)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
@ -241,7 +241,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(((i,),) * 3, self.evaluate(op))
|
||||
self.assertEqual(((i,),) * 3, sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -354,7 +354,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||
num_batches = (28 * 7) // 14
|
||||
for i in range(num_batches):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
||||
@ -369,12 +369,12 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# We expect (num_batches - 1) full-sized batches.
|
||||
num_batches = int(math.ceil((14 * 7) / 8))
|
||||
for i in range(num_batches - 1):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(8):
|
||||
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((14 * 7) % 8):
|
||||
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
||||
@ -408,10 +408,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
if not drop_remainder:
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -423,9 +423,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -439,7 +439,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
got = self.evaluate(elements)
|
||||
got = sess.run(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
@ -459,7 +459,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(4):
|
||||
got = self.evaluate(elements)
|
||||
got = sess.run(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
@ -480,9 +480,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(2):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||
@ -524,7 +524,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"number of elements does not match"):
|
||||
sess.run(get_next)
|
||||
@ -576,8 +576,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(threshold // 10):
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)],
|
||||
self.evaluate(get_next))
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
|
||||
if threshold % 10 != 0:
|
||||
self.assertAllEqual(
|
||||
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
||||
@ -610,8 +609,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(10):
|
||||
self.assertAllEqual([element for _ in range(10)],
|
||||
self.evaluate(get_next))
|
||||
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
|
||||
|
||||
|
||||
class UnbatchDatasetBenchmark(test.Benchmark):
|
||||
|
@ -300,7 +300,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
output = self.evaluate(batch)
|
||||
output = sess.run(batch)
|
||||
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
|
||||
tuple(output.values))
|
||||
all_sparse_tensors.add(sprs_tensor)
|
||||
|
@ -57,7 +57,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -82,7 +82,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -108,7 +108,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -134,7 +134,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -160,7 +160,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -186,7 +186,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -217,7 +217,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = self.evaluate(next_element)
|
||||
actual = sess.run(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
@ -251,7 +251,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = self.evaluate(next_element)
|
||||
actual = sess.run(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
@ -271,9 +271,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -290,9 +290,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -323,9 +323,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
x, y, z = self.evaluate(next_element)
|
||||
x, y, z = sess.run(next_element)
|
||||
self.assertEqual(i**2, x)
|
||||
self.assertEqual(float(i**2), y)
|
||||
self.assertEqual(util_compat.as_bytes(str(i)), z)
|
||||
@ -345,8 +345,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -363,8 +363,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -381,8 +381,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -399,8 +399,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -420,9 +420,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -447,12 +447,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -477,12 +477,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -499,12 +499,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -521,12 +521,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -553,7 +553,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
# For each element of the dataset, assert that the optional evaluates to
|
||||
# the expected value.
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(3):
|
||||
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
||||
self.assertTrue(elem_has_value)
|
||||
@ -562,7 +562,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
||||
# false, and attempting to get the value will fail.
|
||||
for _ in range(2):
|
||||
self.assertFalse(self.evaluate(elem_has_value_t))
|
||||
self.assertFalse(sess.run(elem_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_value_t)
|
||||
|
||||
|
@ -38,13 +38,13 @@ class CounterTest(test_base.DatasetTestBase):
|
||||
negative_get_next = negative_iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(3, self.evaluate(get_next))
|
||||
self.assertEqual(3 + 4, self.evaluate(get_next))
|
||||
self.assertEqual(3 + 2 * 4, self.evaluate(get_next))
|
||||
self.assertEqual(3, sess.run(get_next))
|
||||
self.assertEqual(3 + 4, sess.run(get_next))
|
||||
self.assertEqual(3 + 2 * 4, sess.run(get_next))
|
||||
|
||||
self.assertEqual(0, self.evaluate(negative_get_next))
|
||||
self.assertEqual(-1, self.evaluate(negative_get_next))
|
||||
self.assertEqual(-2, self.evaluate(negative_get_next))
|
||||
self.assertEqual(0, sess.run(negative_get_next))
|
||||
self.assertEqual(-1, sess.run(negative_get_next))
|
||||
self.assertEqual(-2, sess.run(negative_get_next))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -41,10 +41,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual([[i, j]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)], results.indices)
|
||||
@ -69,10 +69,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual([[i, j, z]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)
|
||||
|
@ -40,10 +40,10 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for _ in range(100):
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -107,7 +107,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in choice_array:
|
||||
self.assertEqual(words[i], self.evaluate(next_element))
|
||||
self.assertEqual(words[i], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -44,9 +44,9 @@ class EnumerateDatasetTest(test_base.DatasetTestBase):
|
||||
[t.shape for t in get_next[1]])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertEqual((20, (b"a", 1, 37.0)), self.evaluate(get_next))
|
||||
self.assertEqual((21, (b"b", 2, 38.0)), self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
|
||||
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
@ -94,18 +94,18 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
self.evaluate(destroy_op)
|
||||
sess.run(destroy_op)
|
||||
|
||||
def testSameDeviceCPU(self):
|
||||
self._prefetch_fn_helper_one_shot("same_device_cpu",
|
||||
@ -135,35 +135,35 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
ds, ds_iterator, "reinit", device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
sess.run(ds_iterator.initializer)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
# Lets reset the function buffering resource and reinitialize the
|
||||
# iterator. Should be able to go through this again.
|
||||
self._event.clear()
|
||||
self.evaluate(reset_op)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
sess.run(reset_op)
|
||||
sess.run(ds_iterator.initializer)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
self.evaluate(destroy_op)
|
||||
sess.run(destroy_op)
|
||||
|
||||
def testReinitializationOutOfRange(self):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
@ -175,30 +175,30 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
ds, ds_iterator, "reinit", device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
sess.run(ds_iterator.initializer)
|
||||
for i in range(1, 10):
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [float(i)])
|
||||
# Try fetching after its over twice to test out end of sequence.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(prefetch_op)
|
||||
sess.run(prefetch_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(prefetch_op)
|
||||
sess.run(prefetch_op)
|
||||
|
||||
# Now reset everything and try it out again.
|
||||
self._event.clear()
|
||||
self.evaluate(reset_op)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
sess.run(reset_op)
|
||||
sess.run(ds_iterator.initializer)
|
||||
for i in range(1, 10):
|
||||
elem = self.evaluate(prefetch_op)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.assertEqual(elem, [float(i)])
|
||||
# Try fetching after its over twice to test out end of sequence.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(prefetch_op)
|
||||
sess.run(prefetch_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(prefetch_op)
|
||||
sess.run(prefetch_op)
|
||||
|
||||
self.evaluate(destroy_op)
|
||||
sess.run(destroy_op)
|
||||
|
||||
def testStringsGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -235,13 +235,13 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
buffer_resource_handle, ignore_lookup_error=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual([b"a"], self.evaluate(prefetch_op))
|
||||
self.assertEqual([b"b"], self.evaluate(prefetch_op))
|
||||
self.assertEqual([b"c"], self.evaluate(prefetch_op))
|
||||
self.assertEqual([b"a"], sess.run(prefetch_op))
|
||||
self.assertEqual([b"b"], sess.run(prefetch_op))
|
||||
self.assertEqual([b"c"], sess.run(prefetch_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(prefetch_op)
|
||||
sess.run(prefetch_op)
|
||||
|
||||
self.evaluate(destroy_op)
|
||||
sess.run(destroy_op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -39,7 +39,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
with self.cached_session() as sess:
|
||||
for expected in values:
|
||||
got = self.evaluate(get_next)
|
||||
got = sess.run(get_next)
|
||||
self.assertEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -127,7 +127,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
x, y = self.evaluate(get_next)
|
||||
x, y = sess.run(get_next)
|
||||
self.assertAllEqual([0] * (2**i), x)
|
||||
self.assertAllEqual(np.array(1, ndmin=i), y)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -190,7 +190,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
with self.cached_session() as sess:
|
||||
x, y = self.evaluate(get_next)
|
||||
x, y = sess.run(get_next)
|
||||
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
|
||||
self.assertEqual(y, 45)
|
||||
|
||||
|
@ -68,9 +68,9 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
|
||||
which_bucket, bucketed_values = self.evaluate(get_next)
|
||||
which_bucket, bucketed_values = sess.run(get_next)
|
||||
|
||||
self.assertEqual(0, which_bucket)
|
||||
|
||||
@ -103,11 +103,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
|
||||
# Get two minibatches (one containing even values, one containing odds)
|
||||
which_bucket_even, bucketed_values_even = self.evaluate(get_next)
|
||||
which_bucket_odd, bucketed_values_odd = self.evaluate(get_next)
|
||||
which_bucket_even, bucketed_values_even = sess.run(get_next)
|
||||
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
|
||||
|
||||
# Count number of bucket_tensors.
|
||||
self.assertEqual(3, len(bucketed_values_even))
|
||||
@ -174,11 +174,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
|
||||
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
||||
which_bucket0, bucketed_values_even0 = self.evaluate(get_next)
|
||||
which_bucket1, bucketed_values_even1 = self.evaluate(get_next)
|
||||
which_bucket0, bucketed_values_even0 = sess.run(get_next)
|
||||
which_bucket1, bucketed_values_even1 = sess.run(get_next)
|
||||
|
||||
# Ensure that bucket 1 was completely filtered out
|
||||
self.assertAllEqual(0, which_bucket0)
|
||||
@ -207,11 +207,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
batches = 0
|
||||
while True:
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
is_even = all(x % 2 == 0 for x in result)
|
||||
is_odd = all(x % 2 == 1 for x in result)
|
||||
self.assertTrue(is_even or is_odd)
|
||||
@ -232,11 +232,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
counts = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
self.assertTrue(
|
||||
all(x % 2 == 0
|
||||
for x in result) or all(x % 2 == 1)
|
||||
@ -259,16 +259,16 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
# The input is infinite, so this test demonstrates that:
|
||||
# 1. We produce output without having to consume the entire input,
|
||||
# 2. Different buckets can produce output at different rates, and
|
||||
# 3. For deterministic input, the output is deterministic.
|
||||
for _ in range(3):
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
|
||||
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
||||
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
|
||||
def testSmallGroups(self):
|
||||
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
||||
@ -280,13 +280,13 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
||||
# The small outputs at the end are deterministically produced in key
|
||||
# order.
|
||||
self.assertAllEqual([0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1], self.evaluate(get_next))
|
||||
self.assertAllEqual([0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1], sess.run(get_next))
|
||||
|
||||
def testEmpty(self):
|
||||
iterator = (
|
||||
@ -297,7 +297,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Window size must be greater than zero, but got 0."):
|
||||
@ -323,7 +323,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -351,11 +351,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
counts = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
tight_result, multiple_of_10_result = self.evaluate(get_next)
|
||||
tight_result, multiple_of_10_result = sess.run(get_next)
|
||||
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
|
||||
self.assertAllEqual(tight_result,
|
||||
multiple_of_10_result[:, :tight_result.shape[1]])
|
||||
|
@ -47,9 +47,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for x in [1., 2., 3., 5.]:
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -65,9 +65,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for x in [1., 2., 3., 5.]:
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -93,9 +93,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# All of the files are present.
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for filename in filenames:
|
||||
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
|
||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -104,9 +104,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
|
||||
# Attempting to read filenames[0] will fail, but ignore_errors()
|
||||
# will catch the error.
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for filename in filenames[1:]:
|
||||
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
|
||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -53,7 +53,7 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
||||
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
|
||||
materialized = ds.materialize()
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(materialized.initializer)
|
||||
sess.run(materialized.initializer)
|
||||
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
|
||||
for i in range(16):
|
||||
output = sess.run(
|
||||
@ -68,9 +68,9 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
||||
itr = ds.make_initializable_iterator()
|
||||
n = itr.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(itr.initializer)
|
||||
sess.run(itr.initializer)
|
||||
for i in range(16):
|
||||
output = self.evaluate(n)
|
||||
output = sess.run(n)
|
||||
self.assertEqual(i, output)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
|
@ -112,10 +112,10 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
|
||||
range(self._num_files), 2, 10):
|
||||
actual_batch = self.evaluate(next_element)
|
||||
actual_batch = sess.run(next_element)
|
||||
self.assertAllEqual(file_batch, actual_batch["file"])
|
||||
self.assertAllEqual(record_batch, actual_batch["record"])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -90,7 +90,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
batch_size,
|
||||
num_epochs,
|
||||
):
|
||||
actual_features = self.evaluate(nxt)
|
||||
actual_features = sess.run(nxt)
|
||||
|
||||
if label_name is not None:
|
||||
expected_labels = expected_features.pop(label_name)
|
||||
|
@ -105,7 +105,7 @@ class MakeTFRecordDatasetTest(
|
||||
for expected_batch in self._next_expected_batch(
|
||||
file_indices, batch_size, num_epochs, interleave_cycle_length,
|
||||
drop_final_batch, use_parser_fn):
|
||||
actual_batch = self.evaluate(outputs)
|
||||
actual_batch = sess.run(outputs)
|
||||
self.assertAllEqual(expected_batch, actual_batch)
|
||||
|
||||
def _read_test(self, batch_size, num_epochs, file_index=None,
|
||||
@ -188,7 +188,7 @@ class MakeTFRecordDatasetTest(
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
first_batches = []
|
||||
try:
|
||||
while True:
|
||||
@ -196,7 +196,7 @@ class MakeTFRecordDatasetTest(
|
||||
except errors.OutOfRangeError:
|
||||
pass
|
||||
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
second_batches = []
|
||||
try:
|
||||
while True:
|
||||
|
@ -89,7 +89,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||
num_batches = (28 * 7) // 14
|
||||
for i in range(num_batches):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
||||
@ -104,12 +104,12 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# We expect (num_batches - 1) full-sized batches.
|
||||
num_batches = int(math.ceil((14 * 7) / 8))
|
||||
for i in range(num_batches - 1):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(8):
|
||||
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((14 * 7) % 8):
|
||||
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
||||
@ -152,10 +152,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
if not drop_remainder:
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -177,9 +177,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -201,7 +201,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
got = self.evaluate(elements)
|
||||
got = sess.run(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
@ -230,7 +230,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(4):
|
||||
got = self.evaluate(elements)
|
||||
got = sess.run(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
@ -261,9 +261,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(2):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||
@ -321,7 +321,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"number of elements does not match"):
|
||||
sess.run(get_next)
|
||||
@ -393,8 +393,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(threshold // 10):
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)],
|
||||
self.evaluate(get_next))
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
|
||||
if threshold % 10 != 0:
|
||||
self.assertAllEqual(
|
||||
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
||||
@ -443,8 +442,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(10):
|
||||
self.assertAllEqual([element for _ in range(10)],
|
||||
self.evaluate(get_next))
|
||||
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Identity", None, lambda x: x, None),
|
||||
@ -464,7 +462,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
else:
|
||||
expected = map_fn(
|
||||
sess.run(self.structuredElement(structure, shape=[10])))
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
|
||||
def testShortCircuitCapturedInput(self):
|
||||
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
@ -475,7 +473,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
||||
self.assertAllEqual([42] * 10, self.evaluate(get_next))
|
||||
self.assertAllEqual([42] * 10, sess.run(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
|
@ -218,7 +218,7 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
|
||||
def _assert_op_cancelled(self, sess, map_defun_op):
|
||||
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
|
||||
self.evaluate(map_defun_op)
|
||||
sess.run(map_defun_op)
|
||||
|
||||
def testMapDefunWithParentCancellation(self):
|
||||
# Checks that a cancellation of the parent graph is threaded through to
|
||||
|
@ -72,7 +72,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
thread_ids = []
|
||||
try:
|
||||
while True:
|
||||
|
@ -637,11 +637,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
for j in range(2):
|
||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -796,7 +796,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(2):
|
||||
elements = []
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
try:
|
||||
while True:
|
||||
elements.extend(sess.run(next_element))
|
||||
|
@ -57,7 +57,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -87,7 +87,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -117,7 +117,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -150,7 +150,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = self.evaluate(next_element)
|
||||
actual = sess.run(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
@ -170,7 +170,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -199,12 +199,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -220,12 +220,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -60,7 +60,7 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
feed_dict={start: start_val, step: step_val, take: take_val})
|
||||
for expected, _ in zip(
|
||||
itertools.count(start_val, step_val), range(take_val)):
|
||||
self.assertEqual(expected, self.evaluate(next_element))
|
||||
self.assertEqual(expected, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -110,7 +110,7 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
feed_dict={start: start_val, step: step_val, take: take_val})
|
||||
for expected, _ in zip(
|
||||
itertools.count(start_val, step_val), range(take_val)):
|
||||
self.assertEqual(expected, self.evaluate(next_element).values[0])
|
||||
self.assertEqual(expected, sess.run(next_element).values[0])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -136,7 +136,7 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
(longer_vector_val, larger_rank_val), _ = self.evaluate(next_element)
|
||||
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
|
||||
self.assertAllEqual([0] * (2**i), longer_vector_val)
|
||||
self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -71,19 +71,19 @@ class RangeDatasetSerializationTest(
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -91,14 +91,14 @@ class RangeDatasetSerializationTest(
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.evaluate(restore_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -62,7 +62,7 @@ class SerializationIntegrationTest(test.TestCase):
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_ops)
|
||||
for _ in range(break_point):
|
||||
output = self.evaluate(get_next_ops)
|
||||
output = sess.run(get_next_ops)
|
||||
for i in range(num_pipelines):
|
||||
all_outputs[i].append(output[i])
|
||||
saver.save(sess, self._ckpt_path())
|
||||
@ -73,7 +73,7 @@ class SerializationIntegrationTest(test.TestCase):
|
||||
with self.session(graph=g) as sess:
|
||||
saver.restore(sess, self._ckpt_path())
|
||||
for _ in range(num_outputs - break_point):
|
||||
output = self.evaluate(get_next_ops)
|
||||
output = sess.run(get_next_ops)
|
||||
for i in range(num_pipelines):
|
||||
all_outputs[i].append(output[i])
|
||||
|
||||
|
@ -108,7 +108,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
shuffle_ops.shuffle_and_repeat(buffer_size=21))
|
||||
get_next_op = ds.make_one_shot_iterator().get_next()
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -38,10 +38,10 @@ class SleepTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
start_time = time.time()
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
end_time = time.time()
|
||||
self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -39,9 +39,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
for _ in range(2): # Dataset is repeated. See setUp.
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -59,8 +58,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ON students.first_name = people.first_name "
|
||||
"AND students.last_name = people.last_name"
|
||||
})
|
||||
self.assertEqual((b"John", b"California", b"Hi!"),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -77,9 +75,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, last_name, favorite_nonsense_word "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"n\0nsense"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -96,8 +93,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, last_name, motto FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
sess.run(
|
||||
@ -106,8 +103,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, last_name, state FROM people "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"California"),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
|
||||
self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
|
||||
sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -216,8 +212,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -234,7 +230,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students "
|
||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -250,9 +246,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT desk_number, favorite_negative_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((9, -2), self.evaluate(get_next))
|
||||
self.assertEqual((9, -2), sess.run(get_next))
|
||||
# Max and min values of int8
|
||||
self.assertEqual((127, -128), self.evaluate(get_next))
|
||||
self.assertEqual((127, -128), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -267,8 +263,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -285,7 +281,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students "
|
||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -301,9 +297,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int16
|
||||
self.assertEqual((b"John", 32767), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 32767), sess.run(get_next))
|
||||
# Min value of int16
|
||||
self.assertEqual((b"Jane", -32768), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -32768), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -318,8 +314,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int32` tensor.
|
||||
@ -332,8 +328,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, income FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -349,9 +345,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int32
|
||||
self.assertEqual((b"John", 2147483647), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
||||
# Min value of int32
|
||||
self.assertEqual((b"Jane", -2147483648), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -366,8 +362,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, school_id FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 123), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 1000), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 123), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 1000), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -382,8 +378,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -398,8 +394,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, income FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -416,9 +412,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int64
|
||||
self.assertEqual((b"John", 9223372036854775807), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
|
||||
# Min value of int64
|
||||
self.assertEqual((b"Jane", -9223372036854775808), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -433,8 +429,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -450,9 +446,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Min value of uint8
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
# Max value of uint8
|
||||
self.assertEqual((b"Jane", 255), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 255), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -467,8 +463,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -484,9 +480,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Min value of uint16
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
# Max value of uint16
|
||||
self.assertEqual((b"Jane", 65535), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 65535), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -503,8 +499,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, registration_complete FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", True), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", False), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", True), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", False), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -519,8 +515,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", True), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", True), self.evaluate(get_next))
|
||||
self.assertEqual((b"John", True), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", True), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -537,9 +533,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, last_name, victories FROM townspeople "
|
||||
"ORDER BY first_name"
|
||||
})
|
||||
self.assertEqual((b"George", b"Washington", 20.0),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"John", b"Adams", -19.95), self.evaluate(get_next))
|
||||
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -74,18 +74,18 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
expected_sum = 0.0
|
||||
for i in range(100):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = self.evaluate(summary_t)
|
||||
summary_str = sess.run(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
|
||||
expected_sum += i * 8.0
|
||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
summary_str = sess.run(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
|
||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||
|
||||
@ -99,15 +99,14 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 100.0)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
|
||||
|
||||
def testPrefetchBufferUtilization(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -119,11 +118,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = self.evaluate(summary_t)
|
||||
summary_str = sess.run(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||
float(i + 1))
|
||||
self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
|
||||
@ -132,7 +131,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
0, 1)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
summary_str = sess.run(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||
100)
|
||||
|
||||
@ -146,11 +145,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = self.evaluate(summary_t)
|
||||
summary_str = sess.run(summary_t)
|
||||
self._assertSummaryHasScalarValue(summary_str,
|
||||
"Prefetch::buffer_capacity", 0)
|
||||
self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
|
||||
@ -168,9 +167,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(34):
|
||||
self.assertEqual(i * 3, self.evaluate(next_element))
|
||||
self.assertEqual(i * 3, sess.run(next_element))
|
||||
if i is not 0:
|
||||
self._assertSummaryHasScalarValue(
|
||||
sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
|
||||
@ -262,9 +261,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for j in range(5):
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -279,9 +278,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -296,17 +295,16 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(i + 1))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency_2", float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 100.0)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency_2", 100.0)
|
||||
|
||||
@ -321,15 +319,14 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 200.0)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
|
||||
|
||||
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -344,13 +341,12 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run([iterator_0.initializer, iterator_1.initializer])
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, self.evaluate(next_element))
|
||||
self.assertEqual(i * 2, sess.run(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 200.0)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
|
||||
|
||||
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -368,7 +364,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
with self.test_session() as sess:
|
||||
sess.run([iterator_0.initializer, iterator_1.initializer])
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, self.evaluate(next_element))
|
||||
self.assertEqual(i * 2, sess.run(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "dataset1_record_latency", float(i + 1))
|
||||
self._assertSummaryHasCount(
|
||||
@ -425,7 +421,7 @@ class FeatureStatsDatasetTest(
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for _ in range(num_output):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -50,7 +50,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
||||
for i in range(4):
|
||||
self.assertEqual(i, self.evaluate(next_elem))
|
||||
self.assertEqual(i, sess.run(next_elem))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_elem)
|
||||
|
||||
@ -68,7 +68,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i,) * 3, self.evaluate(op))
|
||||
self.assertEqual((i,) * 3, sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -88,7 +88,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
@ -107,7 +107,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
st_row = self.evaluate(next_element)
|
||||
st_row = sess.run(next_element)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
self.assertEqual([10], st_row.dense_shape)
|
||||
@ -128,7 +128,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
dense_elem, st_row = self.evaluate(next_element)
|
||||
dense_elem, st_row = sess.run(next_element)
|
||||
self.assertEqual(i, dense_elem)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
@ -150,7 +150,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(((i,),) * 3, self.evaluate(op))
|
||||
self.assertEqual(((i,),) * 3, sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
|
@ -49,11 +49,11 @@ class UniqueTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for test_case, expected in test_cases:
|
||||
current_test_case = test_case
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for element in expected:
|
||||
if dtype == dtypes.string:
|
||||
element = compat.as_bytes(element)
|
||||
self.assertAllEqual(element, self.evaluate(next_element))
|
||||
self.assertAllEqual(element, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -93,13 +93,13 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
})
|
||||
num_full_batches = (count * 7) // batch_size
|
||||
for i in range(num_full_batches):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(batch_size):
|
||||
self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
|
||||
result_component[j])
|
||||
if not drop_remainder and (count * 7) % batch_size > 0:
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((count * 7) % batch_size):
|
||||
self.assertAllEqual(
|
||||
@ -128,9 +128,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(2):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||
@ -155,9 +155,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(2):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
expected_indices = []
|
||||
expected_values = []
|
||||
for j in range(5):
|
||||
@ -185,8 +185,8 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
actual = self.evaluate(get_next)
|
||||
sess.run(init_op)
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
|
||||
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
|
||||
@ -211,7 +211,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
r'Cannot batch tensors with different shapes in component 0. '
|
||||
@ -271,7 +271,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
num_full_batches = len(seq_lens) // batch_size
|
||||
|
||||
for i in range(num_full_batches):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
padded_len = padded_shapes[0]
|
||||
if padded_len is None or padded_len == -1:
|
||||
padded_len = np.max(result) if result.size > 0 else 0
|
||||
@ -283,7 +283,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
[0] * (padded_len - seq_len))
|
||||
|
||||
if not drop_remainder and len(seq_lens) % batch_size > 0:
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
padded_len = np.max(result) if result.size > 0 else 0
|
||||
self.assertEqual((len(seq_lens) % batch_size, padded_len),
|
||||
result.shape)
|
||||
@ -315,7 +315,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
self.assertAllEqual([[], [], [], []], result)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -347,7 +347,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
seq_lens: random_seq_lens
|
||||
})
|
||||
for i in range(8):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
padded_len = np.max(result[0])
|
||||
self.assertEqual((4, padded_len), result[0].shape)
|
||||
self.assertEqual((4, padded_len), result[1].shape)
|
||||
|
@ -71,7 +71,7 @@ class FileCacheDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# First run without caching to collect the "ground truth".
|
||||
self.evaluate(init_fifo_op)
|
||||
sess.run(init_fifo_op)
|
||||
elements = []
|
||||
for _ in range(20):
|
||||
elements.append(sess.run(get_next))
|
||||
@ -220,14 +220,14 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
|
||||
self.evaluate(repeat_count.initializer)
|
||||
self.evaluate(cached_iterator.initializer)
|
||||
self.evaluate(uncached_iterator.initializer)
|
||||
sess.run(repeat_count.initializer)
|
||||
sess.run(cached_iterator.initializer)
|
||||
sess.run(uncached_iterator.initializer)
|
||||
|
||||
for i in range(3):
|
||||
for _ in range(10):
|
||||
self.assertEqual(self.evaluate(cached_next), i)
|
||||
self.assertEqual(self.evaluate(uncached_next), i)
|
||||
self.assertEqual(sess.run(cached_next), i)
|
||||
self.assertEqual(sess.run(uncached_next), i)
|
||||
|
||||
sess.run(repeat_count.assign(0))
|
||||
|
||||
@ -238,7 +238,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
||||
# The cached iterator replays from cache.
|
||||
for i in range(3):
|
||||
for _ in range(10):
|
||||
self.assertEqual(self.evaluate(cached_next), i)
|
||||
self.assertEqual(sess.run(cached_next), i)
|
||||
|
||||
# The cached iterator should now be empty.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -280,7 +280,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
||||
i2 = d2.make_initializable_iterator()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(i1.initializer)
|
||||
sess.run(i1.initializer)
|
||||
|
||||
self.assertEqual(1, sess.run(i1.get_next()))
|
||||
self.assertEqual(2, sess.run(i1.get_next()))
|
||||
@ -307,7 +307,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i, expected in enumerate(expected_values):
|
||||
self.assertEqual(expected, self.evaluate(n),
|
||||
self.assertEqual(expected, sess.run(n),
|
||||
"Unexpected value at index %s" % i)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -51,9 +51,9 @@ class ConcatenateDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(9):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
if i < 4:
|
||||
for component, result_component in zip(input_components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
@ -85,9 +85,9 @@ class ConcatenateDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(9):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
if i < 4:
|
||||
for component, result_component in zip(input_components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
|
@ -52,8 +52,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
results = self.evaluate(get_next)
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -81,8 +81,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
[shape for shape in iterator.output_shapes])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
results = self.evaluate(get_next)
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -112,8 +112,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
], [shape for shape in iterator.output_shapes])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
results = self.evaluate(get_next)
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
if sparse_tensor.is_sparse(component):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
@ -139,9 +139,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(4):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -169,7 +169,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
[shape for shape in iterator.output_shapes])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
expected = [
|
||||
(sparse_tensor.SparseTensorValue(
|
||||
indices=np.array([[0]]),
|
||||
@ -197,7 +197,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
dense_shape=np.array([3]))),
|
||||
]
|
||||
for i in range(3):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(expected[i], results):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -229,7 +229,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
], [shape for shape in iterator.output_shapes])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
expected = [
|
||||
(sparse_tensor.SparseTensorValue(
|
||||
indices=np.array([[0]]),
|
||||
@ -257,7 +257,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
dense_shape=np.array([3]))),
|
||||
]
|
||||
for i in range(3):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(
|
||||
(list(zip(*components[:3]))[i] + expected[i]), results):
|
||||
if sparse_tensor.is_sparse(component):
|
||||
@ -280,9 +280,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
self.assertEqual((1,), iterator.output_shapes["bar"])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(3):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertEqual(components["foo"][i], results["foo"])
|
||||
self.assertEqual(components["bar"][i], results["bar"])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -308,7 +308,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
dense_shape)
|
||||
sess.run(init_op, feed_dict={st: sparse_feed})
|
||||
for i, s in enumerate(slices):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(s, results.values)
|
||||
expected_indices = np.array(
|
||||
[[j] for j in range(len(slices[i]))]).reshape([-1, 1])
|
||||
@ -474,15 +474,15 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
with ops.device("/cpu:0"):
|
||||
var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
|
||||
dataset = dataset.map(lambda x: x + var_0.read_value())
|
||||
self.evaluate(var_0.initializer)
|
||||
sess.run(var_0.initializer)
|
||||
|
||||
with ops.device("/cpu:1"):
|
||||
var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
|
||||
dataset = dataset.map(lambda x: x + var_1.read_value())
|
||||
self.evaluate(var_1.initializer)
|
||||
sess.run(var_1.initializer)
|
||||
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
errors.FailedPreconditionError,
|
||||
@ -506,7 +506,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
# Run one whole epoch to burn in the computation.
|
||||
for _ in range(input_size // batch_size):
|
||||
sess.run(next_element)
|
||||
@ -543,7 +543,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
get_next_element = sess.make_callable(next_element)
|
||||
# Run one whole epoch to burn in the computation.
|
||||
for _ in range(input_size // batch_size):
|
||||
@ -582,7 +582,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
get_next_element = sess.make_callable(next_element)
|
||||
# Run one whole epoch to burn in the computation.
|
||||
for _ in range(input_size // batch_size):
|
||||
@ -620,7 +620,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
get_next_element = sess.make_callable(next_element)
|
||||
# Run one whole epoch to burn in the computation.
|
||||
for _ in range(input_size // batch_size):
|
||||
|
@ -47,10 +47,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(2): # Run twice to test reinitialization.
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for _ in range(num_repeats):
|
||||
for elem in elem_sequence:
|
||||
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -65,7 +65,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(num_repeats):
|
||||
for elem in elem_sequence:
|
||||
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -133,10 +133,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for _ in range(num_inner_repeats * num_outer_repeats):
|
||||
for elem in input_list:
|
||||
val0, val1 = self.evaluate(get_next)
|
||||
val0, val1 = sess.run(get_next)
|
||||
self.assertAllEqual(elem[0], val0)
|
||||
self.assertAllEqual(elem[1], val1)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -192,10 +192,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for elem in [0, 1]:
|
||||
for _ in range(num_parallel_iterators):
|
||||
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -215,9 +215,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
self.assertEqual(dtype, get_next.dtype)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for expected in [[1], [2], [3]]:
|
||||
next_val = self.evaluate(get_next)
|
||||
next_val = sess.run(get_next)
|
||||
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
|
||||
self.assertAllEqual(expected, next_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -236,9 +236,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for expected in [b"foo", b"bar", b"baz"]:
|
||||
next_val = self.evaluate(get_next)
|
||||
next_val = sess.run(get_next)
|
||||
self.assertAllEqual(expected, next_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -257,12 +257,12 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||
self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
||||
with self.assertRaisesOpError("The expected type was int64"):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual([7, 8, 9], self.evaluate(get_next))
|
||||
self.assertAllEqual([7, 8, 9], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -280,12 +280,12 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||
self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
||||
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual([11, 12, 13], self.evaluate(get_next))
|
||||
self.assertAllEqual([11, 12, 13], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -304,16 +304,16 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertEqual((1, 2), self.evaluate(get_next))
|
||||
self.assertEqual((3, 4), self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertEqual((1, 2), sess.run(get_next))
|
||||
self.assertEqual((3, 4), sess.run(get_next))
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
sess.run(get_next)
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
sess.run(get_next)
|
||||
self.assertEqual((9, 10), self.evaluate(get_next))
|
||||
self.assertEqual((9, 10), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -329,9 +329,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(1, self.evaluate(get_next))
|
||||
self.assertAllEqual([2, 3], self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(1, sess.run(get_next))
|
||||
self.assertAllEqual([2, 3], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -349,9 +349,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(0, self.evaluate(get_next))
|
||||
self.assertAllEqual(1, self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(0, sess.run(get_next))
|
||||
self.assertAllEqual(1, sess.run(get_next))
|
||||
|
||||
def testFromGeneratorDestructorCalled(self):
|
||||
# Use an `Event` to signal that the generator has been deleted.
|
||||
@ -378,9 +378,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(42, self.evaluate(get_next))
|
||||
self.assertAllEqual(42, self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(42, sess.run(get_next))
|
||||
self.assertAllEqual(42, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
# Test that `GeneratorWrapper` object is destroyed when the
|
||||
@ -407,10 +407,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
|
||||
for x in expected:
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -436,13 +436,13 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
expected = [(0, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
|
||||
for x in expected:
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -470,9 +470,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(37, self.evaluate(get_next))
|
||||
self.assertAllEqual(37, self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(37, sess.run(get_next))
|
||||
self.assertAllEqual(37, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertTrue(event.is_set())
|
||||
|
@ -67,7 +67,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
|
||||
for _ in range(count_val):
|
||||
for i in [x for x in range(7) if x**2 % modulus_val == 0]:
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -86,9 +86,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(0, self.evaluate(get_next))
|
||||
self.assertEqual(1, self.evaluate(get_next))
|
||||
self.assertEqual(3, self.evaluate(get_next))
|
||||
self.assertEqual(0, sess.run(get_next))
|
||||
self.assertEqual(1, sess.run(get_next))
|
||||
self.assertEqual(3, sess.run(get_next))
|
||||
|
||||
def testFilterDict(self):
|
||||
iterator = (dataset_ops.Dataset.range(10)
|
||||
@ -100,10 +100,10 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
if (i ** 2) % 2 == 0:
|
||||
self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
|
||||
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -125,8 +125,8 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(input_data[0], self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(input_data[0], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -148,9 +148,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(5):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
|
||||
self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -166,9 +166,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, True), self.evaluate(get_next))
|
||||
self.assertEqual((i, True), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -178,7 +178,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
||||
iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
|
||||
next_elements = [iterator.get_next() for iterator in iterators]
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual([0 for _ in range(10)], self.evaluate(next_elements))
|
||||
self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
|
||||
|
||||
|
||||
class FilterDatasetBenchmark(test.Benchmark):
|
||||
|
@ -45,10 +45,10 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in repeats:
|
||||
for _ in range(i):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -64,11 +64,11 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for row in repeats:
|
||||
for i in row:
|
||||
for _ in range(i):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -94,12 +94,12 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
with session.Session(server.target) as sess2:
|
||||
for _ in range(3):
|
||||
sess = random.choice([sess1, sess2])
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for row in repeats:
|
||||
for i in row:
|
||||
for _ in range(i):
|
||||
sess = random.choice([sess1, sess2])
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess = random.choice([sess1, sess2])
|
||||
@ -115,10 +115,10 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
for _ in range(i ** 2):
|
||||
self.assertEqual(i * 2, self.evaluate(get_next))
|
||||
self.assertEqual(i * 2, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
# pylint: enable=g-long-lambda
|
||||
@ -139,11 +139,11 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
for j in range(2):
|
||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -196,7 +196,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for expected_element in _interleave(
|
||||
_repeat(input_values, count), cycle_length, block_length):
|
||||
self.assertEqual(expected_element, self.evaluate(get_next))
|
||||
self.assertEqual(expected_element, sess.run(get_next))
|
||||
|
||||
for _ in range(2):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -231,7 +231,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
else:
|
||||
self.assertEqual(value, self.evaluate(get_next))
|
||||
self.assertEqual(value, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -254,7 +254,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(10):
|
||||
for j in range(2):
|
||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -308,7 +308,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
for element in elements:
|
||||
coordination_events[element].set()
|
||||
self.assertEqual(element * element, self.evaluate(get_next))
|
||||
self.assertEqual(element * element, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -57,7 +57,7 @@ class IteratorClusterTest(test.TestCase):
|
||||
|
||||
with session.Session(worker[0].target) as sess:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
def _testRemoteIteratorHelper(self, device0, device1, target):
|
||||
with ops.device(device1):
|
||||
@ -134,12 +134,12 @@ class IteratorClusterTest(test.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session(worker[0].target) as sess:
|
||||
self.evaluate(table.initializer)
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([0, 0, -1, 1, 2], self.evaluate(get_next))
|
||||
sess.run(table.initializer)
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))
|
||||
|
||||
with session.Session(worker[0].target) as sess:
|
||||
self.assertAllEqual([2, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([2, 0], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -166,7 +166,7 @@ class IteratorClusterTest(test.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session(worker[0].target) as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -97,7 +97,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -123,7 +123,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -159,7 +159,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -175,7 +175,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
config = config_pb2.ConfigProto(
|
||||
inter_op_parallelism_threads=1, use_per_session_threads=True)
|
||||
with session.Session(config=config) as sess:
|
||||
self.assertAllEqual([1, 4, 9], self.evaluate(next_element))
|
||||
self.assertAllEqual([1, 4, 9], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -254,15 +254,15 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session(server.target) as sess:
|
||||
self.evaluate(init_op)
|
||||
results = self.evaluate(get_next)
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Re-initialize the iterator in the first session.
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
# Re-define the iterator manually, without defining any of the
|
||||
@ -277,7 +277,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
with session.Session(server.target) as sess:
|
||||
# Use the iterator without re-initializing in the second session.
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -317,20 +317,20 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
sess.run(get_next)
|
||||
|
||||
# Initialize with one dataset.
|
||||
self.evaluate(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||
sess.run(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Initialize with a different dataset.
|
||||
self.evaluate(dataset_4_init_op)
|
||||
self.assertAllEqual([4, 5, 6, 7], self.evaluate(get_next))
|
||||
sess.run(dataset_4_init_op)
|
||||
self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Reinitialize with the first dataset.
|
||||
self.evaluate(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||
sess.run(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -348,7 +348,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
g, output_types=dtypes.int64)
|
||||
sess.run(iterator.make_initializer(dataset_1))
|
||||
for expected in range(10):
|
||||
self.assertEqual(expected, self.evaluate(next_element))
|
||||
self.assertEqual(expected, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -356,7 +356,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
g, output_types=dtypes.int64)
|
||||
sess.run(iterator.make_initializer(dataset_2))
|
||||
for expected in range(10):
|
||||
self.assertEqual(expected, self.evaluate(next_element))
|
||||
self.assertEqual(expected, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -679,10 +679,10 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
n = itr.get_next()
|
||||
|
||||
with session.Session(s3.target, config=config) as sess:
|
||||
self.evaluate(itr.initializer)
|
||||
sess.run(itr.initializer)
|
||||
expected_values = worker_devices
|
||||
for expected in expected_values:
|
||||
self.assertEqual((compat.as_bytes(expected),), self.evaluate(n))
|
||||
self.assertEqual((compat.as_bytes(expected),), sess.run(n))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
@ -786,8 +786,8 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, _, save_op, _ = _build_range_dataset_graph()
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(save_op)
|
||||
sess.run(init_op)
|
||||
sess.run(save_op)
|
||||
|
||||
# Attempt to restore the saved iterator into an IteratorResource of
|
||||
# incompatible type. An iterator of RangeDataset has output type int64,
|
||||
@ -798,7 +798,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
||||
_, _, _, restore_op = _build_reader_dataset_graph()
|
||||
with self.session(graph=g) as sess:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
|
||||
def testRepeatedGetNextWarning(self):
|
||||
iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
|
||||
@ -949,7 +949,7 @@ class IteratorCheckpointingTest(test.TestCase):
|
||||
checkpoint.restore(checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory)).initialize_or_restore(sess)
|
||||
for j in range(2):
|
||||
self.assertEqual(i * 2 + j, self.evaluate(get_next))
|
||||
self.assertEqual(i * 2 + j, sess.run(get_next))
|
||||
checkpoint.save(file_prefix=checkpoint_prefix)
|
||||
|
||||
|
||||
|
@ -102,7 +102,7 @@ class ListFilesDatasetOpTest(test_base.DatasetTestBase):
|
||||
all_produced_filenames = []
|
||||
for _ in range(3):
|
||||
produced_filenames = []
|
||||
self.evaluate(itr.initializer)
|
||||
sess.run(itr.initializer)
|
||||
try:
|
||||
while True:
|
||||
produced_filenames.append(sess.run(next_element))
|
||||
|
@ -114,7 +114,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={count: 14})
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -185,7 +185,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
output_buffer_size: output_buffer_size_val})
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -242,7 +242,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -257,7 +257,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -272,7 +272,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
||||
@ -293,7 +293,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
||||
@ -325,10 +325,10 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with ops.Graph().as_default() as g:
|
||||
captured_init_op, init_op, get_next = _build_graph()
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(captured_init_op)
|
||||
self.evaluate(init_op)
|
||||
sess.run(captured_init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -353,8 +353,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(table.initializer)
|
||||
self.evaluate(init_op)
|
||||
sess.run(table.initializer)
|
||||
sess.run(init_op)
|
||||
sess.run(get_next)
|
||||
sess.run(get_next)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -371,11 +371,11 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(enqueue_op)
|
||||
self.evaluate(close_op)
|
||||
self.evaluate(init_op)
|
||||
sess.run(enqueue_op)
|
||||
sess.run(close_op)
|
||||
sess.run(init_op)
|
||||
for element in elements:
|
||||
self.assertEqual(element, self.evaluate(get_next))
|
||||
self.assertEqual(element, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -396,9 +396,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(enqueue_op)
|
||||
self.evaluate(close_op)
|
||||
self.evaluate(init_op)
|
||||
sess.run(enqueue_op)
|
||||
sess.run(close_op)
|
||||
sess.run(init_op)
|
||||
for i in range(100):
|
||||
self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]),
|
||||
sorted(sess.run(get_next)))
|
||||
@ -415,15 +415,15 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(counter_var.initializer)
|
||||
self.evaluate(init_op)
|
||||
sess.run(counter_var.initializer)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(counter_var))
|
||||
self.assertEqual(i + 1, self.evaluate(get_next))
|
||||
self.assertEqual(10, self.evaluate(counter_var))
|
||||
self.assertEqual(i, sess.run(counter_var))
|
||||
self.assertEqual(i + 1, sess.run(get_next))
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertEqual(10, self.evaluate(counter_var))
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
|
||||
def testCaptureUninitializedVariableError(self):
|
||||
counter_var = variable_scope.get_variable(
|
||||
@ -435,7 +435,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -447,14 +447,14 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
random_values = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
random_values.extend(sess.run(get_next))
|
||||
self.assertEqual(10, len(random_values))
|
||||
self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
random_values_2 = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
@ -473,8 +473,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
random_values = self.evaluate(get_next)
|
||||
sess.run(init_op)
|
||||
random_values = sess.run(get_next)
|
||||
|
||||
# Assert that one of the next 99 batches yielded by the iterator is
|
||||
# different from the first.
|
||||
@ -500,15 +500,15 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(counter_var.initializer)
|
||||
self.evaluate(init_op)
|
||||
sess.run(counter_var.initializer)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, self.evaluate(counter_var))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(10, self.evaluate(counter_var))
|
||||
self.assertEqual(i, sess.run(counter_var))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertEqual(10, self.evaluate(counter_var))
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
|
||||
def testMapDict(self):
|
||||
iterator = (dataset_ops.Dataset.range(10)
|
||||
@ -519,9 +519,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
|
||||
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -569,8 +569,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual(row**2, self.evaluate(get_next))
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(row ** 2, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -611,7 +611,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
row = np.arange(6)
|
||||
for num in [2, 3, 4]:
|
||||
init_op, get_next = build_dataset(row, num)
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(6):
|
||||
self.assertEqual(
|
||||
(i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2,
|
||||
@ -652,7 +652,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
row = np.arange(6)
|
||||
for num in [2, 3, 4]:
|
||||
init_op, get_next = build_dataset(row, num)
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(
|
||||
[x // 2 if (num == 2 or num == 3) else x * 2 for x in row],
|
||||
sess.run(get_next))
|
||||
@ -697,7 +697,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([(x // 2 if x % 2 else x * 2) if
|
||||
(num == 2 or num == 3) else x * 2 for x in row],
|
||||
sess.run(get_next))
|
||||
@ -735,7 +735,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for buffer_size in [1, 10, 100, 1000]:
|
||||
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
|
||||
for i in range(100):
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -753,10 +753,10 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
|
||||
for i in range(event_will_be_set_after_consuming):
|
||||
self.assertFalse(ev.is_set())
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
ev.wait()
|
||||
for i in range(event_will_be_set_after_consuming, 100):
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -768,9 +768,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, 37.0), self.evaluate(get_next))
|
||||
self.assertEqual((i, 37.0), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -789,9 +789,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, 37.0), self.evaluate(get_next))
|
||||
self.assertEqual((i, 37.0), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -810,9 +810,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
|
||||
self.assertSparseValuesEqual(actual, _sparse(i))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -837,9 +837,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
|
||||
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -861,9 +861,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -875,9 +875,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual((i, b"hello", 10), self.evaluate(get_next))
|
||||
self.assertEqual((i, b"hello", 10), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -945,7 +945,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
@parameterized.named_parameters(
|
||||
@ -972,7 +972,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
tids = self.evaluate(get_next)
|
||||
tids = sess.run(get_next)
|
||||
self.assertTrue(all(tids[0] == tid for tid in tids))
|
||||
# pylint: enable=g-long-lambda
|
||||
|
||||
@ -996,7 +996,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected = map_fn(*sess.run(self.structuredElement(structure)))
|
||||
else:
|
||||
expected = map_fn(sess.run(self.structuredElement(structure)))
|
||||
self.assertEqual(expected, self.evaluate(get_next))
|
||||
self.assertEqual(expected, sess.run(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Sequential", None),
|
||||
@ -1011,7 +1011,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
||||
self.assertEqual(42, self.evaluate(get_next))
|
||||
self.assertEqual(42, sess.run(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", 1, 1),
|
||||
@ -1030,7 +1030,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session(config=config) as sess:
|
||||
for i in range(num_elements):
|
||||
coordination_events[i].set()
|
||||
self.assertEqual(i * i, self.evaluate(get_next))
|
||||
self.assertEqual(i * i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -1052,7 +1052,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
for element in elements:
|
||||
coordination_events[element].set()
|
||||
self.assertEqual(element * element, self.evaluate(get_next))
|
||||
self.assertEqual(element * element, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -40,7 +40,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
|
||||
def testBasic(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -50,10 +50,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -67,10 +67,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -85,12 +85,12 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 20, 4):
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i + 2, self.evaluate(elem_on_3))
|
||||
self.assertEqual(i + 3, self.evaluate(elem_on_4))
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i + 2, sess.run(elem_on_3))
|
||||
self.assertEqual(i + 3, sess.run(elem_on_4))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -105,11 +105,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(8, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(8, sess.run(elem_on_1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -126,7 +126,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
elem_on_1_has_value, elem_on_1_value = sess.run(
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
@ -140,8 +140,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
self.assertTrue(elem_on_1_has_value)
|
||||
self.assertEqual(8, elem_on_1_value)
|
||||
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
|
||||
self.assertFalse(sess.run(elem_on_1_has_value_t))
|
||||
self.assertFalse(sess.run(elem_on_2_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_on_1_t)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
@ -155,11 +155,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -192,10 +192,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -211,11 +211,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
@ -235,7 +235,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
elem_on_1_has_value, elem_on_1_value = sess.run(
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
@ -249,8 +249,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
self.assertTrue(elem_on_1_has_value)
|
||||
self.assertEqual(8, elem_on_1_value)
|
||||
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
|
||||
self.assertFalse(sess.run(elem_on_1_has_value_t))
|
||||
self.assertFalse(sess.run(elem_on_2_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_on_1_t)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
@ -272,10 +272,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
|
@ -227,7 +227,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
# For each element of the dataset, assert that the optional evaluates to
|
||||
# the expected value.
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
for _ in range(3):
|
||||
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
||||
self.assertTrue(elem_has_value)
|
||||
@ -236,7 +236,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
||||
# false, and attempting to get the value will fail.
|
||||
for _ in range(2):
|
||||
self.assertFalse(self.evaluate(elem_has_value_t))
|
||||
self.assertFalse(sess.run(elem_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_value_t)
|
||||
|
||||
|
@ -40,7 +40,7 @@ class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
|
||||
for m in range(10):
|
||||
self.assertEqual(m, self.evaluate(get_next))
|
||||
self.assertEqual(m, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -124,19 +124,19 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -144,14 +144,14 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.evaluate(restore_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -175,14 +175,14 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for _ in range(break_epoch):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
# Create an empty IteratorResource and restore the Iterator into it.
|
||||
@ -193,12 +193,12 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
get_next = iterator.get_next()
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
for _ in range(break_epoch + 1, num_epochs):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -221,20 +221,20 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
# Intentionally build a graph with a different value for stop to make sure
|
||||
# the original dataset graph is actually getting loaded.
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop_1)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -259,19 +259,19 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -294,27 +294,27 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_point1):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point1, break_point2):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
|
||||
break_point2 = 7
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_point2, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -338,28 +338,28 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
init_op, get_next, save_op, restore_op = _build_graph(
|
||||
start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for _ in range(break_epoch - 1):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
for i in range(start, break_range):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for i in range(break_range, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
for _ in range(break_epoch, num_epochs):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -381,23 +381,23 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
||||
init_op, get_next, save_op, restore_op = _build_graph(
|
||||
start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for _ in range(num_epochs):
|
||||
for i in range(start, stop):
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(save_op)
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -107,7 +107,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
||||
init_op, feed_dict={filenames: [test_filenames[0]],
|
||||
num_epochs: 1})
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(0, i), self.evaluate(get_next))
|
||||
self.assertEqual(self._lineText(0, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -116,7 +116,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
||||
init_op, feed_dict={filenames: [test_filenames[1]],
|
||||
num_epochs: 1})
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(1, i), self.evaluate(get_next))
|
||||
self.assertEqual(self._lineText(1, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -124,7 +124,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
|
||||
for j in range(2):
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
|
||||
self.assertEqual(self._lineText(j, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -133,7 +133,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
||||
for _ in range(10):
|
||||
for j in range(2):
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
|
||||
self.assertEqual(self._lineText(j, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -267,7 +267,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, feed_dict={filenames: [test_filenames[0]],
|
||||
num_epochs: 1})
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(0, i), self.evaluate(get_next))
|
||||
self.assertEqual(self._record(0, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -276,7 +276,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, feed_dict={filenames: [test_filenames[1]],
|
||||
num_epochs: 1})
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(1, i), self.evaluate(get_next))
|
||||
self.assertEqual(self._record(1, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -284,7 +284,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(j, i), self.evaluate(get_next))
|
||||
self.assertEqual(self._record(j, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -293,7 +293,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
for _ in range(10):
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(j, i), self.evaluate(get_next))
|
||||
self.assertEqual(self._record(j, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -405,19 +405,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
if (epoch == epoch_break and f == file_break and
|
||||
r == record_break):
|
||||
self.evaluate(save_op)
|
||||
sess.run(save_op)
|
||||
break
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
else:
|
||||
continue
|
||||
break
|
||||
@ -426,13 +426,13 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
break
|
||||
else:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
@ -441,9 +441,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
(epoch == epoch_break and f == file_break and
|
||||
r < record_break)):
|
||||
continue
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
def testInitThenRestore(self):
|
||||
# Note: Calling init_op before restore_op is redundant. This test just makes
|
||||
@ -458,19 +458,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
if (epoch == epoch_break and f == file_break and
|
||||
r == record_break):
|
||||
self.evaluate(save_op)
|
||||
sess.run(save_op)
|
||||
break
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
else:
|
||||
continue
|
||||
break
|
||||
@ -479,14 +479,14 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
break
|
||||
else:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
@ -495,9 +495,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
(epoch == epoch_break and f == file_break and
|
||||
r < record_break)):
|
||||
continue
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
def testRestoreInModifiedGraph(self):
|
||||
num_epochs = 10
|
||||
@ -510,19 +510,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
if (epoch == epoch_break and f == file_break and
|
||||
r == record_break):
|
||||
self.evaluate(save_op)
|
||||
sess.run(save_op)
|
||||
break
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
else:
|
||||
continue
|
||||
break
|
||||
@ -531,13 +531,13 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
break
|
||||
else:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs_1)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
@ -546,9 +546,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
(epoch == epoch_break and f == file_break and
|
||||
r < record_break)):
|
||||
continue
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
def testRestoreWithoutBuildingDatasetGraph(self):
|
||||
num_epochs = 10
|
||||
@ -560,19 +560,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
if (epoch == epoch_break and f == file_break and
|
||||
r == record_break):
|
||||
self.evaluate(save_op)
|
||||
sess.run(save_op)
|
||||
break
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
else:
|
||||
continue
|
||||
break
|
||||
@ -581,12 +581,12 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
break
|
||||
else:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
restore_op, get_next_op = self._restore_iterator()
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for epoch in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
@ -595,9 +595,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
(epoch == epoch_break and f == file_break and
|
||||
r < record_break)):
|
||||
continue
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
def testRestoreUnusedIterator(self):
|
||||
num_epochs = 10
|
||||
@ -605,22 +605,22 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
# Save unused iterator.
|
||||
self.evaluate(save_op)
|
||||
sess.run(save_op)
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for _ in range(num_epochs * self._num_files * self._num_records):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
def testRestoreExhaustedIterator(self):
|
||||
num_epochs = 10
|
||||
@ -629,26 +629,26 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||
# raised.
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
for _ in range(num_epochs):
|
||||
for f in range(self._num_files):
|
||||
for r in range(self._num_records):
|
||||
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
self.evaluate(save_op)
|
||||
sess.run(get_next_op)
|
||||
sess.run(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||
num_epochs=num_epochs)
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(restore_op)
|
||||
sess.run(restore_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next_op)
|
||||
sess.run(get_next_op)
|
||||
|
||||
|
||||
class TFRecordDatasetTest(test_base.DatasetTestBase):
|
||||
@ -807,7 +807,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
|
||||
self.assertAllEqual(self._record(j, i), sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -819,7 +819,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
|
||||
self.assertAllEqual(self._record(j, i), sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -36,7 +36,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.range(1, i + 1)
|
||||
result = ds.reduce(np.int64(0), lambda x, y: x + y)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
||||
self.assertEqual(((i + 1) * i) // 2, sess.run(result))
|
||||
|
||||
def testSumTuple(self):
|
||||
|
||||
@ -49,7 +49,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.zip((ds, ds))
|
||||
result = ds.reduce(np.int64(0), reduce_fn)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(((i + 1) * i), self.evaluate(result))
|
||||
self.assertEqual(((i + 1) * i), sess.run(result))
|
||||
|
||||
def testSumAndCount(self):
|
||||
|
||||
@ -61,7 +61,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.range(1, i + 1)
|
||||
result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
|
||||
with self.cached_session() as sess:
|
||||
s, c = self.evaluate(result)
|
||||
s, c = sess.run(result)
|
||||
self.assertEqual(((i + 1) * i) // 2, s)
|
||||
self.assertEqual(i, c)
|
||||
|
||||
@ -93,8 +93,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
|
||||
result = ds.reduce(make_sparse_fn(0), reduce_fn)
|
||||
with self.cached_session() as sess:
|
||||
self.assertSparseValuesEqual(
|
||||
make_sparse_fn(i + 1), self.evaluate(result))
|
||||
self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result))
|
||||
|
||||
def testNested(self):
|
||||
|
||||
@ -117,7 +116,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
|
||||
result = ds.reduce(map_fn(0), reduce_fn)
|
||||
with self.cached_session() as sess:
|
||||
result = self.evaluate(result)
|
||||
result = sess.run(result)
|
||||
self.assertEqual(((i + 1) * i) // 2, result["dense"])
|
||||
self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
|
||||
|
||||
|
@ -49,7 +49,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Test a finite repetition.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 3})
|
||||
for _ in range(3):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
|
||||
@ -59,7 +59,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Test a different finite repetition.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 7})
|
||||
for _ in range(7):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -75,7 +75,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# actually is infinite.
|
||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||
for _ in range(17):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
|
||||
@ -95,7 +95,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Take fewer than input size
|
||||
sess.run(init_op, feed_dict={count_placeholder: 4})
|
||||
for i in range(4):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -104,7 +104,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Take more than input size
|
||||
sess.run(init_op, feed_dict={count_placeholder: 25})
|
||||
for i in range(10):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -113,7 +113,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Take all of input
|
||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||
for i in range(10):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -142,7 +142,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# the first 4 elements and then read the rest.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 4})
|
||||
for i in range(4, 10):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -165,7 +165,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
# Skip nothing
|
||||
sess.run(init_op, feed_dict={count_placeholder: 0})
|
||||
for i in range(0, 10):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
@ -187,7 +187,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
|
||||
for _ in range(7 * 14):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
@ -201,7 +201,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -66,7 +66,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# First run without shuffling to collect the "ground truth".
|
||||
self.evaluate(init_fifo_op)
|
||||
sess.run(init_fifo_op)
|
||||
unshuffled_elements = []
|
||||
for _ in range(20):
|
||||
unshuffled_elements.append(sess.run(get_next))
|
||||
@ -159,7 +159,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
|
||||
for elem in elems:
|
||||
self.assertEqual(elem, self.evaluate(get_next))
|
||||
self.assertEqual(elem, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -188,9 +188,9 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
initial_permutation = self.evaluate(next_element)
|
||||
self.assertAllEqual(initial_permutation, self.evaluate(next_element))
|
||||
self.assertAllEqual(initial_permutation, self.evaluate(next_element))
|
||||
initial_permutation = sess.run(next_element)
|
||||
self.assertAllEqual(initial_permutation, sess.run(next_element))
|
||||
self.assertAllEqual(initial_permutation, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -261,7 +261,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.session(graph=g) as sess:
|
||||
for iterator in iterators:
|
||||
if initializable:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
next_element = iterator.get_next()
|
||||
run_results = []
|
||||
for _ in range(300):
|
||||
|
@ -102,7 +102,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
num_full_batches = max(
|
||||
0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
|
||||
for i in range(num_full_batches):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(size):
|
||||
self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
|
||||
@ -111,7 +111,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
num_partial_batches = (count * 7) // shift + (
|
||||
(count * 7) % shift > 0) - num_full_batches
|
||||
for i in range(num_partial_batches):
|
||||
result = self.evaluate(get_next)
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
remaining = (count * 7) - ((num_full_batches + i) * shift)
|
||||
num_elements = remaining // stride + ((remaining % stride) > 0)
|
||||
@ -164,10 +164,10 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
num_batches = (10 - 5) // 3 + 1
|
||||
for i in range(num_batches):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
|
||||
@ -193,10 +193,10 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
num_batches = (10 - 5) // 3 + 1
|
||||
for i in range(num_batches):
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
expected_indices = []
|
||||
expected_values = []
|
||||
for j in range(5):
|
||||
@ -227,9 +227,9 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(init_op)
|
||||
sess.run(init_op)
|
||||
# Slide: 1st batch.
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
|
||||
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
|
||||
@ -239,7 +239,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||
self.assertSparseValuesEqual(actual, expected)
|
||||
# Slide: 2nd batch.
|
||||
actual = self.evaluate(get_next)
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
|
||||
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
|
||||
@ -265,7 +265,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate(iterator.initializer)
|
||||
sess.run(iterator.initializer)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
r"Cannot batch tensors with different shapes in component 0. "
|
||||
@ -281,8 +281,8 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(np.float32([1., 2.]), self.evaluate(get_next))
|
||||
self.assertAllEqual(np.float32([2., 3.]), self.evaluate(get_next))
|
||||
self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next))
|
||||
self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -55,7 +55,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, equal_length_components)})
|
||||
for i in range(4):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(
|
||||
equal_length_components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
@ -66,7 +66,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, variable_length_components)})
|
||||
for i in range(2):
|
||||
results = self.evaluate(get_next)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(
|
||||
variable_length_components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
@ -103,7 +103,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, equal_length_components)})
|
||||
for i in range(4):
|
||||
result1, (result2, result3) = self.evaluate(get_next)
|
||||
result1, (result2, result3) = sess.run(get_next)
|
||||
self.assertAllEqual(equal_length_components[0][i], result1)
|
||||
self.assertAllEqual(equal_length_components[1][i], result2)
|
||||
self.assertAllEqual(equal_length_components[2][i], result3)
|
||||
|
@ -31,24 +31,24 @@ class ConvertTest(test.TestCase):
|
||||
def testInteger(self):
|
||||
resp = convert.optional_param_to_tensor("foo", 3)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(3, self.evaluate(resp))
|
||||
self.assertEqual(3, sess.run(resp))
|
||||
|
||||
def testIntegerDefault(self):
|
||||
resp = convert.optional_param_to_tensor("foo", None)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(0, self.evaluate(resp))
|
||||
self.assertEqual(0, sess.run(resp))
|
||||
|
||||
def testStringDefault(self):
|
||||
resp = convert.optional_param_to_tensor("bar", None, "default",
|
||||
dtypes.string)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(compat.as_bytes("default"), self.evaluate(resp))
|
||||
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
|
||||
|
||||
def testString(self):
|
||||
resp = convert.optional_param_to_tensor("bar", "value", "default",
|
||||
dtypes.string)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(compat.as_bytes("value"), self.evaluate(resp))
|
||||
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
|
||||
|
||||
def testPartialShapeToTensorKnownDimension(self):
|
||||
with self.cached_session() as sess:
|
||||
|
@ -1583,7 +1583,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
x = variables.VariableV1([1, 3, 3, 7], name="x")
|
||||
_, idx = array_ops.unique(x, name="x_unique")
|
||||
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
|
||||
self.evaluate(x.initializer)
|
||||
sess.run(x.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(
|
||||
|
@ -126,8 +126,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
u = variables.Variable([12.0], name="u")
|
||||
v = variables.Variable([30.0], name="v")
|
||||
w = math_ops.add(u, v, name="w")
|
||||
self.evaluate(u.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(u.initializer)
|
||||
sess.run(v.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, w, expected_output=[42.0])
|
||||
@ -139,7 +139,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
b = math_ops.add(a, a, name="b")
|
||||
with ops.control_dependencies([a, b]):
|
||||
c = math_ops.multiply(b, b, name="c")
|
||||
self.evaluate(a.initializer)
|
||||
sess.run(a.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, c, expected_output=400.0)
|
||||
@ -150,8 +150,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
y = variables.Variable(20.0, name="y")
|
||||
cond = control_flow_ops.cond(
|
||||
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
|
||||
self.evaluate(x.initializer)
|
||||
self.evaluate(y.initializer)
|
||||
sess.run(x.initializer)
|
||||
sess.run(y.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, cond, expected_output=21.0)
|
||||
@ -173,8 +173,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
toy_loss = x * (u - v)
|
||||
train_op = gradient_descent.GradientDescentOptimizer(
|
||||
learning_rate=0.1).minimize(toy_loss, name="train_op")
|
||||
self.evaluate(u.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(u.initializer)
|
||||
sess.run(v.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(sess, train_op)
|
||||
|
||||
|
@ -131,8 +131,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
||||
with session.Session(
|
||||
config=self.session_config, graph=graph,
|
||||
target=self.server_target) as sess:
|
||||
self.evaluate(self.a.initializer)
|
||||
self.evaluate(self.b.initializer)
|
||||
sess.run(self.a.initializer)
|
||||
sess.run(self.b.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions()
|
||||
debug_utils.watch_graph(
|
||||
@ -198,8 +198,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
||||
with session.Session(
|
||||
config=self.session_config, graph=graph,
|
||||
target=self.server_target) as sess:
|
||||
self.evaluate(self.a.initializer)
|
||||
self.evaluate(self.b.initializer)
|
||||
sess.run(self.a.initializer)
|
||||
sess.run(self.b.initializer)
|
||||
|
||||
def watch_fn(feeds, fetch_keys):
|
||||
del feeds, fetch_keys
|
||||
|
@ -67,7 +67,7 @@ class SessionDebugMultiGPUTest(test_util.TensorFlowTestCase):
|
||||
u1 = math_ops.multiply(v, v, name="u1")
|
||||
w = math_ops.subtract(u1, u0, name="w")
|
||||
|
||||
self.evaluate(v.initializer)
|
||||
sess.run(v.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(run_options, sess.graph,
|
||||
|
@ -109,8 +109,8 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
|
||||
self.w = math_ops.matmul(self.u, self.v, name="w")
|
||||
self.w_line_number = line_number_above()
|
||||
|
||||
self.evaluate(self.u.initializer)
|
||||
self.evaluate(self.v.initializer)
|
||||
sess.run(self.u.initializer)
|
||||
sess.run(self.v.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(
|
||||
|
@ -92,7 +92,7 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
|
||||
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
@ -205,11 +205,10 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(self._record(r, f), self.evaluate(next_element))
|
||||
self.assertAllEqual(self._record(r, f), sess.run(next_element))
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(
|
||||
self._text_line(r, f), self.evaluate(next_element))
|
||||
self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
|
@ -149,9 +149,9 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
result = fn(3.0)
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual(sess.run(state[0]), 2.0)
|
||||
self.assertAllEqual(self.evaluate(result), 6.0)
|
||||
self.assertAllEqual(sess.run(result), 6.0)
|
||||
|
||||
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
@ -168,9 +168,9 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
result = fn(3.0)
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual(sess.run(state[0]), 6.0)
|
||||
self.assertAllEqual(self.evaluate(result), 18.0)
|
||||
self.assertAllEqual(sess.run(result), 18.0)
|
||||
|
||||
def testLegacyGraphModeInputDependentInitializerFails(self):
|
||||
with ops.Graph().as_default():
|
||||
|
@ -78,7 +78,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
||||
c = constant_op.constant([[2.]])
|
||||
f_c = f(c)
|
||||
g, = gradients_impl.gradients(f_c, c)
|
||||
self.assertAllEqual(self.evaluate(g).values, [[1.0]])
|
||||
self.assertAllEqual(sess.run(g).values, [[1.0]])
|
||||
|
||||
def testNoSymGradNestedDefun(self):
|
||||
|
||||
|
@ -564,7 +564,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
variables.global_variables_initializer().run()
|
||||
call = def_function.function(o.call)
|
||||
op = call()
|
||||
self.assertAllEqual(self.evaluate(op), 2.0)
|
||||
self.assertAllEqual(sess.run(op), 2.0)
|
||||
|
||||
def testGraphModeManyFunctions(self):
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
@ -1732,7 +1732,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
function.register(cpu_boost, x)
|
||||
y = gpu_boost(x)
|
||||
y_value = self.evaluate(y)
|
||||
y_value = sess.run(y)
|
||||
|
||||
if test.is_gpu_available():
|
||||
self.assertEqual(y_value, 5.0)
|
||||
|
@ -1027,7 +1027,7 @@ class CrossedColumnTest(test.TestCase):
|
||||
outputs = _transform_features(features, [price_cross_wire])
|
||||
output = outputs[price_cross_wire]
|
||||
with self.cached_session() as sess:
|
||||
output_val = self.evaluate(output)
|
||||
output_val = sess.run(output)
|
||||
self.assertAllEqual(
|
||||
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
||||
for val in output_val.values:
|
||||
@ -1886,8 +1886,7 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc._numeric_column('price')
|
||||
@ -2526,8 +2525,7 @@ class _LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc._numeric_column('price')
|
||||
|
@ -1188,7 +1188,7 @@ class CrossedColumnTest(test.TestCase):
|
||||
outputs = fc._transform_features_v2(features, [price_cross_wire], None)
|
||||
output = outputs[price_cross_wire]
|
||||
with self.cached_session() as sess:
|
||||
output_val = self.evaluate(output)
|
||||
output_val = sess.run(output)
|
||||
self.assertAllEqual(
|
||||
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
||||
for val in output_val.values:
|
||||
@ -2088,8 +2088,7 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
@ -2125,8 +2124,7 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column('price')
|
||||
@ -2845,8 +2843,7 @@ class OldLinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column('price')
|
||||
|
@ -102,7 +102,7 @@ class FunctionTest(test.TestCase):
|
||||
call = MyIdentityFunc([18.0])
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
|
||||
def testIdentityImplicitDeref(self):
|
||||
|
||||
@ -116,8 +116,8 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
for cfg in _OptimizerOptions():
|
||||
with session.Session(config=cfg) as sess:
|
||||
self.evaluate(var.initializer)
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
sess.run(var.initializer)
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
|
||||
def testIdentityOutputName(self):
|
||||
|
||||
@ -130,7 +130,7 @@ class FunctionTest(test.TestCase):
|
||||
call = MyIdentityFunc([18.0])
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
|
||||
def testTooManyOutputNames(self):
|
||||
|
||||
@ -158,7 +158,7 @@ class FunctionTest(test.TestCase):
|
||||
call = APlus2B([1.0], [2.0])
|
||||
self.assertEqual("APlus2B", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([5.0], self.evaluate(call))
|
||||
self.assertAllEqual([5.0], sess.run(call))
|
||||
|
||||
def testFunctionWithNoOutput(self):
|
||||
|
||||
@ -187,7 +187,7 @@ class FunctionTest(test.TestCase):
|
||||
call = APlus2B([1.0], [2.0])
|
||||
self.assertEqual("APlus2B", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([5.0], self.evaluate(call))
|
||||
self.assertAllEqual([5.0], sess.run(call))
|
||||
|
||||
def testDefineFunctionDuplicateOutputs(self):
|
||||
|
||||
@ -224,8 +224,8 @@ class FunctionTest(test.TestCase):
|
||||
call_g = XSquarePlusOneGrad([2.0], [0.1])
|
||||
|
||||
with session.Session() as sess:
|
||||
self.assertAllClose([5.0], self.evaluate(call_f))
|
||||
self.assertAllClose([0.4], self.evaluate(call_g))
|
||||
self.assertAllClose([5.0], sess.run(call_f))
|
||||
self.assertAllClose([0.4], sess.run(call_g))
|
||||
|
||||
def testTanhSymGrad(self):
|
||||
|
||||
@ -387,7 +387,7 @@ class FunctionTest(test.TestCase):
|
||||
call = AConstant()
|
||||
self.assertEqual("AConstant", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([42], self.evaluate(call))
|
||||
self.assertAllEqual([42], sess.run(call))
|
||||
|
||||
def testDefineFunctionNames(self):
|
||||
|
||||
@ -468,7 +468,7 @@ class FunctionTest(test.TestCase):
|
||||
|
||||
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
|
||||
|
||||
ans = self.evaluate(loop)
|
||||
ans = sess.run(loop)
|
||||
self.assertAllClose(ans, 131072.)
|
||||
|
||||
def testControlFlowStrictness(self):
|
||||
@ -650,8 +650,8 @@ class FunctionTest(test.TestCase):
|
||||
# pylint: enable=unexpected-keyword-arg
|
||||
self.assertEqual("next", call2.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([1], self.evaluate(call1))
|
||||
self.assertAllEqual([0], self.evaluate(call2))
|
||||
self.assertAllEqual([1], sess.run(call1))
|
||||
self.assertAllEqual([0], sess.run(call2))
|
||||
|
||||
def testNestedFunction(self):
|
||||
|
||||
@ -794,7 +794,7 @@ class FunctionTest(test.TestCase):
|
||||
y = Foo()
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
self.assertEqual(self.evaluate(y), 10)
|
||||
self.assertEqual(sess.run(y), 10)
|
||||
|
||||
def testCaptureInCond(self):
|
||||
g = ops.Graph()
|
||||
@ -809,8 +809,8 @@ class FunctionTest(test.TestCase):
|
||||
z = Foo(False)
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
self.assertEqual(self.evaluate(y), 1)
|
||||
self.assertEqual(self.evaluate(z), 2)
|
||||
self.assertEqual(sess.run(y), 1)
|
||||
self.assertEqual(sess.run(z), 2)
|
||||
|
||||
def testStableName(self):
|
||||
|
||||
@ -900,7 +900,7 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual(global_vars[0].name, "linear/w:0")
|
||||
|
||||
with session.Session() as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
output_val = sess.run(
|
||||
output_op, feed_dict={input_op: np.random.rand(32, 100)})
|
||||
self.assertEqual(output_val.shape, (32, 100))
|
||||
@ -928,7 +928,7 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual(global_vars[0].name, "vs1/var:0")
|
||||
|
||||
with session.Session() as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
out1, out2 = sess.run(
|
||||
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
|
||||
self.assertAllEqual(out1, np.linspace(2, 11, 10))
|
||||
@ -991,8 +991,8 @@ class FunctionTest(test.TestCase):
|
||||
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
|
||||
|
||||
with session.Session() as sess:
|
||||
self.assertEqual(4.0, self.evaluate(result_1))
|
||||
self.assertEqual(100, self.evaluate(result_2))
|
||||
self.assertEqual(4.0, sess.run(result_1))
|
||||
self.assertEqual(100, sess.run(result_2))
|
||||
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
|
||||
|
||||
def testStatefulFunction(self):
|
||||
@ -1052,8 +1052,8 @@ class FunctionTest(test.TestCase):
|
||||
for config in _OptimizerOptions():
|
||||
config.device_count["CPU"] = 2
|
||||
with session.Session(config=config) as sess:
|
||||
self.assertEqual(42.0, self.evaluate(f_0))
|
||||
self.assertEqual(44.0, self.evaluate(f_1))
|
||||
self.assertEqual(42.0, sess.run(f_0))
|
||||
self.assertEqual(44.0, sess.run(f_1))
|
||||
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
|
||||
|
||||
def testGuaranteedConstsAreCaptured(self):
|
||||
@ -1076,7 +1076,7 @@ class FunctionTest(test.TestCase):
|
||||
return output
|
||||
|
||||
with self.session(use_gpu=False) as sess:
|
||||
self.evaluate(var.initializer)
|
||||
sess.run(var.initializer)
|
||||
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
|
||||
|
||||
def testSameFunctionDifferentGrads(self):
|
||||
@ -1651,8 +1651,8 @@ class ModuleFunctionTest(test.TestCase):
|
||||
y = LinearWithCApi(a, b, c)
|
||||
z = Linear2WithCApi(a, b, c, d, e)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([[1]], self.evaluate(y))
|
||||
self.assertAllEqual([[5]], self.evaluate(z))
|
||||
self.assertAllEqual([[1]], sess.run(y))
|
||||
self.assertAllEqual([[5]], sess.run(z))
|
||||
|
||||
|
||||
class VariableHoistingTest(test.TestCase):
|
||||
@ -1704,7 +1704,7 @@ class VariableHoistingTest(test.TestCase):
|
||||
self.assertEqual("Foo/b", b.op.name)
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
|
||||
|
||||
self.assertAllEqual(w.shape, (64, 64))
|
||||
|
@ -211,7 +211,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
with session.Session() as sess:
|
||||
init = variables.variables_initializer([variable_node])
|
||||
sess.run(init)
|
||||
output = self.evaluate(output_node)
|
||||
output = sess.run(output_node)
|
||||
self.assertNear(4.0, output, 0.00001)
|
||||
variable_graph_def = sess.graph.as_graph_def()
|
||||
|
||||
@ -242,8 +242,8 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
output_node = math_ops_lib.multiply(
|
||||
variable_node, 2.0, name="output_node")
|
||||
with session.Session() as sess:
|
||||
self.evaluate(variable_node.initializer)
|
||||
output = self.evaluate(output_node)
|
||||
sess.run(variable_node.initializer)
|
||||
output = sess.run(output_node)
|
||||
self.assertNear(2.0, output, 0.00001)
|
||||
variable_graph_def = sess.graph.as_graph_def()
|
||||
# First get the constant_graph_def when variable_names_whitelist is
|
||||
@ -256,7 +256,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
|
||||
# Then initialize the unused variable, and get another
|
||||
# constant_graph_def when variable_names_whitelist is not set.
|
||||
self.evaluate(another_variable.initializer)
|
||||
sess.run(another_variable.initializer)
|
||||
constant_graph_def_without_variable_whitelist = (
|
||||
graph_util.convert_variables_to_constants(
|
||||
sess, variable_graph_def, ["output_node"]))
|
||||
@ -295,7 +295,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
with session.Session() as sess:
|
||||
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
||||
output = self.evaluate(output_node)
|
||||
output = sess.run(output_node)
|
||||
self.assertNear(2.0, output, 0.00001)
|
||||
|
||||
def create_node_def(self, op, name, inputs):
|
||||
|
@ -398,10 +398,10 @@ class ImportGraphDefTest(test.TestCase):
|
||||
# TODO(b/76173421): make this work (currently DCHECKS)
|
||||
# with self.cached_session() as sess:
|
||||
# sess.run(imported_init)
|
||||
# self.assertEqual(self.evaluate(imported_var), 1.0)
|
||||
# self.assertEqual(self.evaluate(imported_assign), 2.0)
|
||||
# self.assertEqual(list(self.evaluate(imported_shape)), [])
|
||||
# self.assertEqual(list(self.evaluate(new_var_shape)), [])
|
||||
# self.assertEqual(sess.run(imported_var), 1.0)
|
||||
# self.assertEqual(sess.run(imported_assign), 2.0)
|
||||
# self.assertEqual(list(sess.run(imported_shape)), [])
|
||||
# self.assertEqual(list(sess.run(new_var_shape)), [])
|
||||
|
||||
def testWhileLoop(self):
|
||||
# Produce GraphDef containing while loop.
|
||||
@ -418,7 +418,7 @@ class ImportGraphDefTest(test.TestCase):
|
||||
return_elements=[r.name])
|
||||
self.assertEqual(imported_r.name, "import/" + r.name)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(self.evaluate(imported_r), 10)
|
||||
self.assertEqual(sess.run(imported_r), 10)
|
||||
|
||||
def testImportWhileLoopInCond(self):
|
||||
# Produce GraphDef containing while loop.
|
||||
@ -458,7 +458,7 @@ class ImportGraphDefTest(test.TestCase):
|
||||
lambda i: i < 2, ImportFn, [0],
|
||||
shape_invariants=[tensor_shape.TensorShape(None)])
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(self.evaluate(out), 10)
|
||||
self.assertEqual(sess.run(out), 10)
|
||||
|
||||
def testTypeMismatchInGraphDef(self):
|
||||
# TODO(skyewm): improve error message
|
||||
|
@ -492,8 +492,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
init_op = variables.global_variables_initializer()
|
||||
grad = gradients_impl.gradients([output], [var])
|
||||
with session.Session() as sess:
|
||||
self.evaluate(init_op)
|
||||
expected_grad_value = self.evaluate(grad)
|
||||
sess.run(init_op)
|
||||
expected_grad_value = sess.run(grad)
|
||||
|
||||
# Restore the MetaGraphDef into a new Graph with an import scope.
|
||||
with ops.Graph().as_default():
|
||||
@ -518,8 +518,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
init_op = variables.global_variables_initializer()
|
||||
|
||||
with session.Session() as sess:
|
||||
self.evaluate(init_op)
|
||||
actual_grad_value = self.evaluate(grad)
|
||||
sess.run(init_op)
|
||||
actual_grad_value = sess.run(grad)
|
||||
self.assertEqual(expected_grad_value, actual_grad_value)
|
||||
|
||||
def testImportWhileLoopInWhileLoop(self):
|
||||
@ -544,7 +544,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
_, x = control_flow_ops.while_loop(lambda i, x: i < 2, body, [0, 0.0],
|
||||
name="")
|
||||
with session.Session() as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(x)
|
||||
|
||||
def testScopedImportUnderNameScope(self):
|
||||
@ -869,7 +869,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
|
||||
|
||||
initializer = variables.local_variables_initializer()
|
||||
sess.run(initializer)
|
||||
self.evaluate(update_op)
|
||||
sess.run(update_op)
|
||||
|
||||
meta_graph.export_scoped_meta_graph(
|
||||
filename=meta_graph_filename, graph=graph)
|
||||
|
@ -517,21 +517,21 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
self.assertEquals(x.consumers(), [])
|
||||
self.assertEquals(y.consumers(), [z.op, z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(self.evaluate(z), 4)
|
||||
self.assertEquals(sess.run(z), 4)
|
||||
|
||||
z.op._update_input(0, x) # pylint: disable=protected-access
|
||||
self.assertEquals(list(z.op.inputs), [x, y])
|
||||
self.assertEquals(x.consumers(), [z.op])
|
||||
self.assertEquals(y.consumers(), [z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(self.evaluate(z), 3)
|
||||
self.assertEquals(sess.run(z), 3)
|
||||
|
||||
z.op._update_input(1, y) # pylint: disable=protected-access
|
||||
self.assertEquals(list(z.op.inputs), [x, y])
|
||||
self.assertEquals(x.consumers(), [z.op])
|
||||
self.assertEquals(y.consumers(), [z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(self.evaluate(z), 3)
|
||||
self.assertEquals(sess.run(z), 3)
|
||||
|
||||
def testUpdateInputGraphError(self):
|
||||
g_0 = ops.Graph()
|
||||
|
@ -109,8 +109,8 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
|
||||
exclusive=True)
|
||||
with session.Session() as sess:
|
||||
# No feed_dict necessary
|
||||
self.assertEqual(self.evaluate(y), 1)
|
||||
self.assertEqual(self.evaluate(z), 1)
|
||||
self.assertEqual(sess.run(y), 1)
|
||||
self.assertEqual(sess.run(z), 1)
|
||||
|
||||
def testFalse(self):
|
||||
conditions = [(False, raise_exception)]
|
||||
@ -121,8 +121,8 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
|
||||
default=lambda: constant_op.constant(1),
|
||||
exclusive=True)
|
||||
with session.Session() as sess:
|
||||
self.assertEqual(self.evaluate(y), 1)
|
||||
self.assertEqual(self.evaluate(z), 1)
|
||||
self.assertEqual(sess.run(y), 1)
|
||||
self.assertEqual(sess.run(z), 1)
|
||||
|
||||
def testMix(self):
|
||||
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user