Replace a few calls of Session run with evaluate

In order to support tests running in eager mode we need to avoid
unnecessary use of Sessions in tests. This change removes some
of the uses of the `run` function in favor of `evaluate`.

PiperOrigin-RevId: 223009795
This commit is contained in:
Gaurav Jain 2018-11-27 10:05:22 -08:00 committed by TensorFlower Gardener
parent edf88fcda8
commit b17d53c0cd
264 changed files with 3026 additions and 2939 deletions

View File

@ -57,11 +57,11 @@ class CategoricalTest(xla_test.XLATestCase):
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
with self.cached_session() as sess, self.test_scope():
with self.cached_session(), self.test_scope():
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32)
d = sess.run(op)
d = self.evaluate(op)
batch_size, num_classes = logits.shape
freqs_mat = []
@ -80,15 +80,15 @@ class CategoricalTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value.
with self.cached_session() as sess:
with self.cached_session():
with self.test_scope():
x = rng(dtype, output_dtype)
# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
y = sess.run(x)
z = sess.run(x)
w = sess.run(x)
y = self.evaluate(x)
z = self.evaluate(x)
w = self.evaluate(x)
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
@ -108,12 +108,12 @@ class CategoricalTest(xla_test.XLATestCase):
def testCategoricalIsInRange(self):
for dtype in self.float_types:
for output_dtype in self.output_dtypes():
with self.cached_session() as sess:
with self.cached_session():
with self.test_scope():
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
output_dtype=output_dtype)
y = sess.run(x)
y = self.evaluate(x)
self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000)
@ -170,11 +170,11 @@ class CategoricalTest(xla_test.XLATestCase):
self.assertEqual(s0 == s1, np.all(v0 == v1))
def testEmpty(self):
with self.cached_session() as sess:
with self.cached_session():
with self.test_scope():
x = random_ops.multinomial(
array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
y = sess.run(x)
y = self.evaluate(x)
self.assertEqual(y.shape, (42, 0))
def testEmptyStateless(self):

View File

@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase):
def DISABLED_testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7)
with self.cached_session() as sess:
with self.cached_session():
with self.test_scope():
for shape0 in (), (2,):
axis = len(shape0)
@ -270,7 +270,7 @@ class ConcatTest(xla_test.XLATestCase):
self.assertAllEqual(c.eval(), correct)
# Check gradients
dc = np.random.randn(*c.get_shape().as_list())
dxs = sess.run(gradients_impl.gradients(c, xs, dc))
dxs = self.evaluate(gradients_impl.gradients(c, xs, dc))
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
def testConcatTuple(self):
@ -330,47 +330,47 @@ class ConcatTest(xla_test.XLATestCase):
class ConcatOffsetTest(xla_test.XLATestCase):
def testBasic(self):
with self.cached_session() as sess:
with self.cached_session():
with self.test_scope():
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
ans = sess.run(off)
ans = self.evaluate(off)
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
class PackTest(xla_test.XLATestCase):
def testBasic(self):
with self.cached_session() as sess:
with self.cached_session():
with self.test_scope():
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed)
ans = self.evaluate(packed)
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
def testScalars(self):
with self.cached_session() as sess:
with self.cached_session():
with self.test_scope():
s0 = constant_op.constant(2, dtypes.int32)
s1 = constant_op.constant(3, dtypes.int32)
s2 = constant_op.constant(5, dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed)
ans = self.evaluate(packed)
self.assertAllEqual(ans, [2, 3, 5])
def testEmpty(self):
with self.cached_session() as sess:
with self.cached_session():
with self.test_scope():
s0 = constant_op.constant([[]], dtypes.int32)
s1 = constant_op.constant([[]], dtypes.int32)
s2 = constant_op.constant([[]], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed)
ans = self.evaluate(packed)
self.assertAllEqual(ans, [[[]], [[]], [[]]])

View File

@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase):
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables())
self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase):
with jit_scope():
y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables())
self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase):
with jit_scope():
y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables())
self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,

View File

@ -101,12 +101,12 @@ class EagerTest(xla_test.XLATestCase):
self.assertAllEqual(15, product)
# Run some ops graphly
with context.graph_mode(), self.cached_session() as sess:
with context.graph_mode(), self.cached_session():
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
product = three * five
self.assertAllEqual(15, sess.run(product))
self.assertAllEqual(15, self.evaluate(product))
def testDegenerateSlices(self):
with self.test_scope():

View File

@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
with self.cached_session() as sess:
with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_f = Foo(a, b)
result = sess.run(call_f)
result = self.evaluate(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
def testNestedFunctions(self):
@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
with self.cached_session() as sess:
with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_g = Foo(a, b)
result = sess.run(call_g)
result = self.evaluate(call_g)
self.assertAllClose(result, expected, rtol=1e-3)
def testFunctionMultipleRetvals(self):
@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = Func(aval, bval)
with self.cached_session() as sess:
with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
b = constant_op.constant(bval, name="b")
with self.test_scope():
call_f = Foo(a, b)
result = sess.run(call_f)
result = self.evaluate(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
def testCompileTimeConstantsInDefun(self):

View File

@ -33,13 +33,13 @@ class ListDiffTest(xla_test.XLATestCase):
def _testListDiff(self, x, y, out, idx):
for dtype in [dtypes.int32, dtypes.int64]:
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.cached_session() as sess:
with self.cached_session():
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
with self.test_scope():
out_tensor, idx_tensor = array_ops.listdiff(
x_tensor, y_tensor, out_idx=index_dtype)
tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor])
self.assertAllEqual(out, tf_out)
self.assertAllEqual(idx, tf_idx)
self.assertEqual(1, out_tensor.get_shape().ndims)

View File

@ -88,8 +88,8 @@ class LSTMTest(test.TestCase):
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
# Initialize variables and run the unrolled LSTM step.
sess.run(variables.global_variables_initializer())
return sess.run([m, c])
self.evaluate(variables.global_variables_initializer())
return self.evaluate([m, c])
def testLSTMCell(self):
# Run with all-0 weights, no padding.
@ -173,8 +173,8 @@ class LSTMTest(test.TestCase):
(basename, m_init_scalar, c_init_scalar, pad_scalar))
# Initialize variables and run the unrolled LSTM layer.
sess.run(variables.global_variables_initializer())
return sess.run(out_seq)
self.evaluate(variables.global_variables_initializer())
return self.evaluate(out_seq)
def testLSTMLayer(self):
# Run with all-0 weights, no padding.

View File

@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase):
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
sess.run(variables.variables_initializer([v]))
self.assertEqual(8.0, sess.run(out))
self.assertEqual(8.0, self.evaluate(out))
def test_placeholder_with_default_fed(self):
with self.cached_session() as sess, self.test_scope():

View File

@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase):
# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
y = sess.run(x)
z = sess.run(x)
w = sess.run(x)
y = self.evaluate(x)
z = self.evaluate(x)
w = self.evaluate(x)
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = random_ops.random_uniform(
shape=[1000], dtype=dtype, minval=-2, maxval=33)
y = sess.run(x)
y = self.evaluate(x)
self.assertTrue((y >= -2).sum() == 1000)
self.assertTrue((y < 33).sum() == 1000)
@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = sess.run(x)
y = self.evaluate(x)
def normal_cdf(x):
return .5 * math.erfc(-x / math.sqrt(2))
@ -111,7 +111,7 @@ class RandomOpsTest(xla_test.XLATestCase):
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x, sess=sess):
return sess.run(special_math.ndtri(x))
return self.evaluate(special_math.ndtri(x))
a = -2.
b = 2.
@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
result = sess.run(shuffle)
result = self.evaluate(shuffle)
expected = range(1 << 16)
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase):
with self.test_scope():
x = array_ops.diag(math_ops.range(20))
shuffle = random_ops.random_shuffle(x)
result = sess.run(shuffle)
result = self.evaluate(shuffle)
expected = np.diag(range(20)).flatten()
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.

View File

@ -156,7 +156,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x, sess=sess):
return sess.run(special_math.ndtri(x))
return self.evaluate(special_math.ndtri(x))
a = -2.
b = 2.

View File

@ -505,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase):
[-0.5, 1.5], # read(0) gradient
[20.0, 30.0, 40.0, 50.0], # concat gradient
])
grad_vals = sess.run(grad_r) # 2 + 2 entries
grad_vals = self.evaluate(grad_r) # 2 + 2 entries
self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])

View File

@ -77,7 +77,7 @@ class VariableOpsTest(xla_test.XLATestCase):
sess.run(variables.variables_initializer([v]))
x = v.sparse_read(2)
self.assertAllClose(
np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x))
np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x))
def testSparseRead1DIndices(self):
for dtype in self.numeric_types:
@ -89,7 +89,7 @@ class VariableOpsTest(xla_test.XLATestCase):
x = v.sparse_read([2, 1])
self.assertAllClose(
np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
sess.run(x))
self.evaluate(x))
def testSparseRead2DIndices(self):
for dtype in self.numeric_types:
@ -102,7 +102,7 @@ class VariableOpsTest(xla_test.XLATestCase):
self.assertAllClose(
np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
[[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
sess.run(x))
self.evaluate(x))
def testSparseRead2DIndices3DTensor(self):
for dtype in self.numeric_types:
@ -115,9 +115,9 @@ class VariableOpsTest(xla_test.XLATestCase):
x = v.sparse_read([[2, 1], [3, 0]])
self.assertAllClose(
np.array(
[[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]
], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
],).astype(dtype), sess.run(x))
[[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
[[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
],).astype(dtype), self.evaluate(x))
def testShape(self):
for dtype in self.numeric_types:
@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(sess.run(read), [[3], [7]])
self.assertAllEqual(self.evaluate(read), [[3], [7]])
def testScatterSub(self):
with self.test_session() as sess, self.test_scope():
@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_sub(
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(sess.run(read), [[4], [-1]])
self.assertAllEqual(self.evaluate(read), [[4], [-1]])
def testScatterMul(self):
with self.test_session() as sess, self.test_scope():
@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_mul(
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[5]])
self.assertEqual(self.evaluate(read), [[5]])
def testScatterDiv(self):
with self.test_session() as sess, self.test_scope():
@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_div(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertAllEqual(sess.run(read), [[2]])
self.assertAllEqual(self.evaluate(read), [[2]])
def testScatterMin(self):
with self.test_session() as sess, self.test_scope():
@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_min(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[3]])
self.assertEqual(self.evaluate(read), [[3]])
def testScatterMax(self):
with self.test_session() as sess, self.test_scope():
@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_max(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[6]])
self.assertEqual(self.evaluate(read), [[6]])
def testScatterUpdate(self):
with self.test_session() as sess, self.test_scope():
@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_update(
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[3]])
self.assertEqual(self.evaluate(read), [[3]])
def testScatterAddScalar(self):
with self.test_session() as sess, self.test_scope():
@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[3]])
self.assertEqual(self.evaluate(read), [[3]])
def testScatterSubScalar(self):
with self.test_session() as sess, self.test_scope():
@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_sub(
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[-1]])
self.assertEqual(self.evaluate(read), [[-1]])
def testScatterMulScalar(self):
with self.test_session() as sess, self.test_scope():
@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_mul(
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[5]])
self.assertEqual(self.evaluate(read), [[5]])
def testScatterDivScalar(self):
with self.test_session() as sess, self.test_scope():
@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_div(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[2]])
self.assertEqual(self.evaluate(read), [[2]])
def testScatterMinScalar(self):
with self.test_session() as sess, self.test_scope():
@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_min(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[3]])
self.assertEqual(self.evaluate(read), [[3]])
def testScatterMaxScalar(self):
with self.test_session() as sess, self.test_scope():
@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase):
resource_variable_ops.resource_scatter_max(
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(sess.run(read), [[6]])
self.assertEqual(self.evaluate(read), [[6]])
def testScatterNdAddOps(self):
with self.test_session() as sess, self.test_scope():
@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase):
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.float32)
self.assertAllClose(expected, sess.run(read))
self.assertAllClose(expected, self.evaluate(read))
def testScatterNdUpdateAddOps(self):
with self.test_session() as sess, self.test_scope():
@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase):
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.float32)
self.assertAllClose(expected, sess.run(read))
self.assertAllClose(expected, self.evaluate(read))
class StridedSliceAssignChecker(object):

View File

@ -81,7 +81,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
with self.cached_session() as sess:
with self.test_scope():
x = gen_control_flow_ops.control_trigger()
sess.run(x)
self.evaluate(x)
if __name__ == "__main__":

View File

@ -93,10 +93,10 @@ class KerasTest(tf.test.TestCase):
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
self.evaluate(init)
sample_input = tf.random_uniform((1, 10, 10, 1))
output = model(sample_input) # pylint: disable=not-callable
self.assertEqual(sess.run(output).shape, (1, 3))
self.assertEqual(self.evaluate(output).shape, (1, 3))
if __name__ == '__main__':

View File

@ -34,7 +34,7 @@ class ListLiteralsTest(tf.test.TestCase):
result = converted()
with self.cached_session() as sess:
self.assertAllEqual(sess.run(result), [1, 2, 3])
self.assertAllEqual(self.evaluate(result), [1, 2, 3])
if __name__ == '__main__':

View File

@ -35,7 +35,7 @@ class InputDataTest(test.TestCase):
with self.cached_session() as sess:
sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
wav_data = self.evaluate(wav_encoder)
return wav_data
def _saveTestWavFile(self, filename, wav_data):

View File

@ -33,7 +33,7 @@ class LabelWavTest(test.TestCase):
with self.cached_session() as sess:
sample_data = tf.zeros([1000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
wav_data = self.evaluate(wav_encoder)
return wav_data
def _saveTestWavFile(self, filename, wav_data):

View File

@ -33,7 +33,7 @@ class WavToFeaturesTest(test.TestCase):
with self.cached_session() as sess:
sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
wav_data = self.evaluate(wav_encoder)
return wav_data
def _saveTestWavFile(self, filename, wav_data):

View File

@ -111,7 +111,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)
self.evaluate(init)
for _ in range(TRAIN_STEPS):
batch_x, batch_y = self.mnist.train.next_batch(
batch_size=self.batch_size, shuffle=False)

View File

@ -41,7 +41,7 @@ class AssertsTest(converter_testing.TestCase):
op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message'):
sess.run(op)
self.evaluate(op)
if __name__ == '__main__':

View File

@ -94,7 +94,7 @@ class CallTreesTest(converter_testing.TestCase):
dtypes.int64) as result:
with self.cached_session() as sess:
self.assertTrue(isinstance(result.test_fn(), ops.Tensor))
self.assertIn(sess.run(result.test_fn()), (0, 1, 2))
self.assertIn(self.evaluate(result.test_fn()), (0, 1, 2))
def test_uncompiled_modules(self):
@ -113,7 +113,7 @@ class CallTreesTest(converter_testing.TestCase):
with self.compiled(node, ns) as result:
with self.cached_session() as sess:
result_tensor = result.test_fn(constant_op.constant(1))
self.assertEquals(sess.run(result_tensor), 3)
self.assertEquals(self.evaluate(result_tensor), 3)
def test_call_to_decorated_function(self):

View File

@ -68,7 +68,7 @@ class ListTest(converter_testing.TestCase):
with self.cached_session() as sess:
tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(sess.run(r), [1, 2, 3])
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
def test_list_pop(self):
@ -91,8 +91,8 @@ class ListTest(converter_testing.TestCase):
with self.cached_session() as sess:
ts, tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(sess.run(r), [1, 2])
self.assertAllEqual(sess.run(ts), 3)
self.assertAllEqual(self.evaluate(r), [1, 2])
self.assertAllEqual(self.evaluate(ts), 3)
def test_double_list_pop(self):
@ -123,7 +123,7 @@ class ListTest(converter_testing.TestCase):
with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
with self.cached_session() as sess:
self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
# TODO(mdan): Add a test with tf.stack with axis kwarg.

View File

@ -48,12 +48,12 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
self.evaluate(v.initializer)
self.evaluate(result.test_fn(v))
# TODO(mdan): Add support for this use case.
# Right now the variable `a` is not conditioned on the `assign` because
# there's no way to add control dependencies to a variable object.
self.assertEqual(2, sess.run(v))
self.assertEqual(2, self.evaluate(v))
def test_side_effect_on_used_variable(self):
@ -69,11 +69,11 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
self.evaluate(v.initializer)
self.evaluate(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
# Right now it's 3 or 4 based on whether the read is synchronized.
self.assertEqual(3, sess.run(v))
self.assertEqual(3, self.evaluate(v))
def test_side_effect_on_tensor(self):
@ -109,10 +109,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign_add) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
self.evaluate(v.initializer)
self.evaluate(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(4, sess.run(v))
self.assertEqual(4, self.evaluate(v))
def test_multiline_nested_block(self):
@ -130,10 +130,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
self.evaluate(v.initializer)
self.evaluate(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(3, sess.run(v))
self.assertEqual(3, self.evaluate(v))
def test_multiline_block_unsafe(self):
@ -153,10 +153,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
state_ops.assign_add) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
self.evaluate(v.initializer)
self.evaluate(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(4, sess.run(v))
self.assertEqual(4, self.evaluate(v))
if __name__ == '__main__':

View File

@ -49,7 +49,7 @@ class SliceTest(converter_testing.TestCase):
tl = list_ops.tensor_list_from_tensor(
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
y = result.test_fn(tl)
self.assertEqual(2, sess.run(y))
self.assertEqual(2, self.evaluate(y))
def test_index_access_multiple_definitions(self):

View File

@ -55,7 +55,7 @@ class RuntimeErrorsTest(test.TestCase):
with self.assertRaises(errors.TfRuntimeError) as cm:
with errors.improved_errors(zero_div_caller):
with self.cached_session() as sess:
sess.run(ops)
self.evaluate(ops)
for frame in cm.exception.custom_traceback:
_, _, function_name, _ = frame
@ -70,7 +70,7 @@ class RuntimeErrorsTest(test.TestCase):
with self.assertRaises(errors.TfRuntimeError) as cm:
with errors.improved_errors(zero_div_caller):
with self.cached_session() as sess:
sess.run(ops)
self.evaluate(ops)
all_function_names = set()
for frame in cm.exception.custom_traceback:
@ -87,7 +87,7 @@ class RuntimeErrorsTest(test.TestCase):
with self.assertRaises(tf_errors.InvalidArgumentError):
with errors.improved_errors(zero_div_caller):
with self.cached_session() as sess:
sess.run(ops)
self.evaluate(ops)
def test_improved_errors_validation(self):
with self.assertRaisesRegexp(

View File

@ -63,7 +63,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_decorator_does_not_recurse(self):
@ -83,7 +83,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_decorator_calls_unconverted_graph(self):
@ -104,7 +104,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_decorator_calls_unconverted_py_func(self):
@ -130,7 +130,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_decorator_calls_decorated(self):
@ -153,7 +153,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_decorator_preserves_argspec(self):
@ -192,7 +192,7 @@ class ApiTest(test.TestCase):
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_converted_call_builtin(self):
x = api.converted_call(range, None, converter.ConversionOptions(), 3)
@ -208,7 +208,7 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
x = api.converted_call(test_fn, None, converter.ConversionOptions(),
constant_op.constant(-1))
self.assertEqual(1, sess.run(x))
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_explicit_owner(self):
# TODO(mdan): Implement.
@ -234,7 +234,7 @@ class ApiTest(test.TestCase):
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc.test_method, None,
converter.ConversionOptions(), tc)
self.assertEqual(1, sess.run(x))
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_by_class(self):
@ -252,7 +252,7 @@ class ApiTest(test.TestCase):
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(TestClass.test_method, None,
converter.ConversionOptions(), tc)
self.assertEqual(1, sess.run(x))
self.assertEqual(1, self.evaluate(x))
def test_converted_call_callable_object(self):
@ -269,7 +269,7 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc, None, converter.ConversionOptions())
self.assertEqual(1, sess.run(x))
self.assertEqual(1, self.evaluate(x))
def test_converted_call_constructor(self):
@ -288,7 +288,7 @@ class ApiTest(test.TestCase):
constant_op.constant(-1))
# tc is now a converted object.
x = tc.test_method()
self.assertEqual(1, sess.run(x))
self.assertEqual(1, self.evaluate(x))
def test_converted_call_already_converted(self):
@ -298,12 +298,12 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
x = api.converted_call(f, None, converter.ConversionOptions(),
constant_op.constant(0))
self.assertTrue(sess.run(x))
self.assertTrue(self.evaluate(x))
converted_f = api.to_graph(f)
x = api.converted_call(converted_f, None, converter.ConversionOptions(),
constant_op.constant(0))
self.assertTrue(sess.run(x))
self.assertTrue(self.evaluate(x))
def test_converted_call_no_user_code(self):
@ -334,8 +334,8 @@ class ApiTest(test.TestCase):
constant_op.constant([[0.0]]), training=True)
with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
def test_converted_call_whitelisted_method_extra_self(self):
@ -349,8 +349,8 @@ class ApiTest(test.TestCase):
model, constant_op.constant([[0.0]]), training=True)
with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
def test_converted_call_whitelisted_method_via_owner(self):
@ -364,8 +364,8 @@ class ApiTest(test.TestCase):
constant_op.constant([[0.0]]), training=True)
with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
def test_converted_call_lambda(self):
@ -376,8 +376,8 @@ class ApiTest(test.TestCase):
x = api.converted_call(l, None, opts, constant_op.constant(0))
with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllEqual(True, sess.run(x))
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(True, self.evaluate(x))
def test_to_graph_basic(self):
@ -390,7 +390,7 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
x = compiled_fn(constant_op.constant([4, 8]), 4)
self.assertListEqual([1, 2], sess.run(x).tolist())
self.assertListEqual([1, 2], self.evaluate(x).tolist())
def test_to_graph_with_defaults(self):
@ -405,7 +405,7 @@ class ApiTest(test.TestCase):
with self.cached_session() as sess:
x = compiled_fn(constant_op.constant([4, 8]))
self.assertListEqual([1, 2], sess.run(x).tolist())
self.assertListEqual([1, 2], self.evaluate(x).tolist())
def test_to_code_basic(self):

View File

@ -36,7 +36,7 @@ class SpecialFunctionsTest(test.TestCase):
python_one = special_functions.match_staging_level(1, 1)
with self.cached_session() as sess:
self.assertTrue(tensor_util.is_tensor(tensor_one))
self.assertAllEqual(sess.run(tensor_one), 1)
self.assertAllEqual(self.evaluate(tensor_one), 1)
self.assertEqual(python_one, 1)
def test_tensor_list_empty_list(self):
@ -45,21 +45,21 @@ class SpecialFunctionsTest(test.TestCase):
element_shape=())
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [])
self.assertAllEqual(self.evaluate(sl), [])
l = special_functions.tensor_list((),
element_dtype=dtypes.int32,
element_shape=())
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [])
self.assertAllEqual(self.evaluate(sl), [])
def test_tensor_list_tensor(self):
l = special_functions.tensor_list(
constant_op.constant([], dtype=dtypes.int32))
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [])
self.assertAllEqual(self.evaluate(sl), [])
def test_tensor_list_unsupported_initializer(self):
with self.assertRaisesRegexp(ValueError, 'unknown type'):
@ -76,7 +76,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
@ -84,7 +84,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
def test_stack(self):
self.assertEqual(special_functions.stack(1, strict=False), 1)

View File

@ -35,7 +35,7 @@ class ForLoopTest(test.TestCase):
body=lambda i, s: (s + i,),
init_state=(0,))
with self.cached_session() as sess:
self.assertEqual((10,), sess.run(s))
self.assertEqual((10,), self.evaluate(s))
def test_python(self):
s = control_flow.for_stmt(
@ -53,7 +53,7 @@ class ForLoopTest(test.TestCase):
body=lambda i, s: (s + i,),
init_state=(0,))
with self.cached_session() as sess:
self.assertEqual((10,), sess.run(s))
self.assertEqual((10,), self.evaluate(s))
class WhileLoopTest(test.TestCase):
@ -66,7 +66,7 @@ class WhileLoopTest(test.TestCase):
init_state=(0, 0),
extra_deps=(n,))
with self.cached_session() as sess:
self.assertEqual((5, 10), sess.run(results))
self.assertEqual((5, 10), self.evaluate(results))
def test_python(self):
n = 5
@ -90,9 +90,9 @@ class IfStmtTest(test.TestCase):
def test_tensor(self):
with self.cached_session() as sess:
t = self.single_return_if_stmt(constant_op.constant(True))
self.assertEqual(1, sess.run(t))
self.assertEqual(1, self.evaluate(t))
t = self.single_return_if_stmt(constant_op.constant(False))
self.assertEqual(-1, sess.run(t))
self.assertEqual(-1, self.evaluate(t))
def test_python(self):
self.assertEqual(1, self.single_return_if_stmt(True))
@ -101,9 +101,9 @@ class IfStmtTest(test.TestCase):
def test_tensor_multiple_returns(self):
with self.cached_session() as sess:
t = self.multi_return_if_stmt(constant_op.constant(True))
self.assertAllEqual([1, 2], sess.run(t))
self.assertAllEqual([1, 2], self.evaluate(t))
t = self.multi_return_if_stmt(constant_op.constant(False))
self.assertAllEqual([-1, -2], sess.run(t))
self.assertAllEqual([-1, -2], self.evaluate(t))
def test_python_multiple_returns(self):
self.assertEqual((1, 2), self.multi_return_if_stmt(True))

View File

@ -43,7 +43,7 @@ class ListTest(test.TestCase):
l = data_structures.tf_tensor_list_new([3, 4, 5])
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
def test_tf_tensor_list_new_empty(self):
l = data_structures.tf_tensor_list_new([],
@ -51,13 +51,13 @@ class ListTest(test.TestCase):
element_shape=())
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [])
self.assertAllEqual(self.evaluate(t), [])
def test_tf_tensor_list_new_from_tensor(self):
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
def test_tf_tensor_list_new_illegal_input(self):
with self.assertRaises(ValueError):
@ -77,7 +77,7 @@ class ListTest(test.TestCase):
l = data_structures.tf_tensor_array_new([3, 4, 5])
t = l.stack()
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
def test_tf_tensor_array_new_illegal_input(self):
with self.assertRaises(ValueError):
@ -102,15 +102,15 @@ class ListTest(test.TestCase):
t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [[1, 2, 3]])
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
def test_append_tensorarray(self):
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l1 = data_structures.list_append(l, 1)
l2 = data_structures.list_append(l1, 2)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(l1.stack()), [1])
self.assertAllEqual(sess.run(l2.stack()), [1, 2])
self.assertAllEqual(self.evaluate(l1.stack()), [1])
self.assertAllEqual(self.evaluate(l2.stack()), [1, 2])
def test_append_python(self):
l = []
@ -131,10 +131,10 @@ class ListTest(test.TestCase):
with self.cached_session() as sess:
l, x = data_structures.list_pop(l, None, opts)
self.assertAllEqual(sess.run(x), [3, 4])
self.assertAllEqual(self.evaluate(x), [3, 4])
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
self.assertAllEqual(sess.run(t), [[1, 2]])
self.assertAllEqual(self.evaluate(t), [[1, 2]])
def test_pop_python(self):
l = [1, 2, 3]
@ -152,7 +152,7 @@ class ListTest(test.TestCase):
with self.cached_session() as sess:
t = data_structures.list_stack(l, opts)
self.assertAllEqual(sess.run(t), sess.run(initial_list))
self.assertAllEqual(self.evaluate(t), self.evaluate(initial_list))
def test_stack_tensor_list_empty(self):
l = list_ops.empty_tensor_list(

View File

@ -30,7 +30,7 @@ class ExceptionsTest(test.TestCase):
with self.cached_session() as sess:
t = exceptions.assert_stmt(
constant_op.constant(True), lambda: constant_op.constant('ignored'))
sess.run(t)
self.evaluate(t)
def test_assert_tf_triggered(self):
with self.cached_session() as sess:
@ -40,7 +40,7 @@ class ExceptionsTest(test.TestCase):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message'):
sess.run(t)
self.evaluate(t)
def test_assert_tf_multiple_printed_values(self):
two_tensors = [
@ -53,7 +53,7 @@ class ExceptionsTest(test.TestCase):
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message.*another message'):
sess.run(t)
self.evaluate(t)
def test_assert_python_untriggered(self):
side_effect_trace = []

View File

@ -45,11 +45,11 @@ class LogicalOperatorsTest(test.TestCase):
def test_and_tf(self):
with self.cached_session() as sess:
t = logical.and_(self._tf_true, self._tf_true)
self.assertEqual(sess.run(t), True)
self.assertEqual(self.evaluate(t), True)
t = logical.and_(self._tf_true, lambda: True)
self.assertEqual(sess.run(t), True)
self.assertEqual(self.evaluate(t), True)
t = logical.and_(self._tf_false, lambda: True)
self.assertEqual(sess.run(t), False)
self.assertEqual(self.evaluate(t), False)
# TODO(mdan): Add a test for ops with side effects.
def test_or_python(self):
@ -63,11 +63,11 @@ class LogicalOperatorsTest(test.TestCase):
def test_or_tf(self):
with self.cached_session() as sess:
t = logical.or_(self._tf_false, self._tf_true)
self.assertEqual(sess.run(t), True)
self.assertEqual(self.evaluate(t), True)
t = logical.or_(self._tf_false, lambda: True)
self.assertEqual(sess.run(t), True)
self.assertEqual(self.evaluate(t), True)
t = logical.or_(self._tf_true, lambda: True)
self.assertEqual(sess.run(t), True)
self.assertEqual(self.evaluate(t), True)
# TODO(mdan): Add a test for ops with side effects.
def test_not_python(self):
@ -78,7 +78,7 @@ class LogicalOperatorsTest(test.TestCase):
def test_not_tf(self):
with self.cached_session() as sess:
t = logical.not_(self._tf_false())
self.assertEqual(sess.run(t), True)
self.assertEqual(self.evaluate(t), True)
if __name__ == '__main__':

View File

@ -38,29 +38,29 @@ class PyBuiltinsTest(test.TestCase):
self.assertEqual(py_builtins.abs_(-1), 1)
with self.cached_session() as sess:
t = py_builtins.abs_(constant_op.constant(-1))
self.assertEqual(sess.run(t), 1)
self.assertEqual(self.evaluate(t), 1)
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
self.assertAllEqual(sess.run(t), [1, 2, 3])
self.assertAllEqual(self.evaluate(t), [1, 2, 3])
def test_float(self):
self.assertEqual(py_builtins.float_(10), 10.0)
self.assertEqual(py_builtins.float_('10.0'), 10.0)
with self.cached_session() as sess:
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
self.assertEqual(sess.run(t), 1.0)
self.assertEqual(self.evaluate(t), 1.0)
st = py_builtins.float_(constant_op.constant('1.0'))
self.assertEqual(sess.run(st), 1.0)
self.assertEqual(self.evaluate(st), 1.0)
def test_int(self):
self.assertEqual(py_builtins.int_(10.0), 10)
self.assertEqual(py_builtins.int_('11', 2), 3)
with self.cached_session() as sess:
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
self.assertEqual(sess.run(t), 1)
self.assertEqual(self.evaluate(t), 1)
st = py_builtins.int_(constant_op.constant('1'))
self.assertEqual(sess.run(st), 1)
self.assertEqual(self.evaluate(st), 1)
st = py_builtins.int_(constant_op.constant('1'), 10)
self.assertEqual(sess.run(st), 1)
self.assertEqual(self.evaluate(st), 1)
def test_int_unsupported_base(self):
t = constant_op.constant(1, dtype=dtypes.float64)
@ -73,9 +73,9 @@ class PyBuiltinsTest(test.TestCase):
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
self.assertEqual(t, 3)
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
self.assertEqual(sess.run(ta), 5)
self.assertEqual(self.evaluate(ta), 5)
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
self.assertEqual(sess.run(tl), 3)
self.assertEqual(self.evaluate(tl), 3)
def test_len_scalar(self):
with self.assertRaises(ValueError):
@ -120,18 +120,18 @@ class PyBuiltinsTest(test.TestCase):
def test_range_tensor(self):
with self.cached_session() as sess:
r = py_builtins.range_(constant_op.constant(3))
self.assertAllEqual(sess.run(r), [0, 1, 2])
self.assertAllEqual(self.evaluate(r), [0, 1, 2])
r = py_builtins.range_(1, constant_op.constant(3))
self.assertAllEqual(sess.run(r), [1, 2])
self.assertAllEqual(self.evaluate(r), [1, 2])
r = py_builtins.range_(2, 0, constant_op.constant(-1))
self.assertAllEqual(sess.run(r), [2, 1])
self.assertAllEqual(self.evaluate(r), [2, 1])
def test_range_tensor_empty_range(self):
with self.session() as sess:
r = py_builtins.range_(constant_op.constant(-3))
self.assertAllEqual(sess.run(r), [])
self.assertAllEqual(self.evaluate(r), [])
r = py_builtins.range_(5, constant_op.constant(2))
self.assertAllEqual(sess.run(r), [])
self.assertAllEqual(self.evaluate(r), [])
if __name__ == '__main__':

View File

@ -34,7 +34,7 @@ class SlicesTest(test.TestCase):
with self.cached_session() as sess:
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
def test_get_item_tensor_list(self):
initial_list = constant_op.constant([[1, 2], [3, 4]])
@ -44,7 +44,7 @@ class SlicesTest(test.TestCase):
l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4])
self.assertAllEqual(self.evaluate(t), [3, 4])
def test_get_item_tensor_string(self):
initial_str = constant_op.constant('abcd')
@ -52,14 +52,14 @@ class SlicesTest(test.TestCase):
slices.GetItemOpts(element_dtype=initial_str.dtype))
with self.cached_session() as sess:
self.assertEqual(sess.run(t), b'b')
self.assertEqual(self.evaluate(t), b'b')
initial_list_str = constant_op.constant(['abcd', 'bcde'])
t = slices.get_item(initial_list_str, 1,
slices.GetItemOpts(element_dtype=initial_str.dtype))
with self.cached_session() as sess:
self.assertEqual(sess.run(t), b'bcde')
self.assertEqual(self.evaluate(t), b'bcde')
if __name__ == '__main__':

View File

@ -32,7 +32,7 @@ class MiscTest(test.TestCase):
new_a = alias_tensors(a)
self.assertFalse(new_a is a)
with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
self.assertEqual(1, self.evaluate(new_a))
def test_alias_tensors(self):
a = constant(1)
@ -47,7 +47,7 @@ class MiscTest(test.TestCase):
self.assertTrue(new_s is s)
self.assertTrue(new_l is l)
with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
self.assertEqual(1, self.evaluate(new_a))
if __name__ == '__main__':

View File

@ -34,13 +34,13 @@ class PyFuncTest(test.TestCase):
with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(1, constant_op.constant(1), 1))
self.assertEqual(3, sess.run(result))
self.assertEqual(3, self.evaluate(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
self.assertEqual(3, sess.run(result))
self.assertEqual(3, self.evaluate(result))
result = py_func.wrap_py_func(
test_fn, dtypes.int64,
(constant_op.constant(1), 1, constant_op.constant(1)))
self.assertEqual(3, sess.run(result))
self.assertEqual(3, self.evaluate(result))
def test_wrap_py_func_complex_args(self):
@ -54,10 +54,10 @@ class PyFuncTest(test.TestCase):
with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
self.assertEqual(35, sess.run(result))
self.assertEqual(35, self.evaluate(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(constant_op.constant(7), TestClass()))
self.assertEqual(35, sess.run(result))
self.assertEqual(35, self.evaluate(result))
def test_wrap_py_func_kwargs(self):
@ -74,13 +74,13 @@ class PyFuncTest(test.TestCase):
'c': 11,
'd': TestClass(13)
})
self.assertEqual(178, sess.run(result))
self.assertEqual(178, self.evaluate(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(constant_op.constant(7), TestClass(5)), {
'c': constant_op.constant(11),
'd': TestClass(13)
})
self.assertEqual(178, sess.run(result))
self.assertEqual(178, self.evaluate(result))
def test_wrap_py_func_dummy_return(self):
@ -91,11 +91,11 @@ class PyFuncTest(test.TestCase):
with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
self.assertEqual(1, sess.run(result))
self.assertEqual(1, self.evaluate(result))
self.assertEqual([1], side_counter)
result = py_func.wrap_py_func(
test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
self.assertEqual(1, sess.run(result))
self.assertEqual(1, self.evaluate(result))
self.assertEqual([2], side_counter)

View File

@ -43,13 +43,13 @@ class TensorListTest(test.TestCase):
l = tl.dynamic_list_append(l, 1)
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
self.assertAllEqual(self.evaluate(s), [1])
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l = tl.dynamic_list_append(l, 1)
s = l.stack()
with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
self.assertAllEqual(self.evaluate(s), [1])
l = tl.TensorList(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
@ -92,7 +92,7 @@ class TensorListTest(test.TestCase):
a2 = l.pop()
c4 = l.count()
with Session() as sess:
c1, c2, c3, c4, a, a2 = sess.run([c1, c2, c3, c4, a, a2])
c1, c2, c3, c4, a, a2 = self.evaluate([c1, c2, c3, c4, a, a2])
self.assertEqual(c1, 1)
self.assertEqual(c2, 2)
self.assertEqual(c3, 1)
@ -108,7 +108,7 @@ class TensorListTest(test.TestCase):
l[0] = b
l1 = l[0]
with self.cached_session() as sess:
l0, l1, a, b = sess.run([l0, l1, a, b])
l0, l1, a, b = self.evaluate([l0, l1, a, b])
self.assertEqual(l0, a)
self.assertEqual(l1, b)

View File

@ -62,7 +62,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
const = constant_op.constant(17)
sess = session.Session(server1.target, config=config)
output = sess.run(const)
output = self.evaluate(const)
self.assertEqual(17, output)
def testClusterSpecPropagationWorker2Placement(self):
@ -106,7 +106,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
const = constant_op.constant(17)
sess = session.Session(server1.target, config=config, graph=g)
output = sess.run(const)
output = self.evaluate(const)
self.assertEqual(17, output)
def testCanonicalDeviceNames(self):
@ -208,7 +208,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
with ops.device('/job:worker/task:0/cpu:0'):
sum3 = sum1 + sum2
sess = session.Session(server1.target, config=config, graph=g)
output = sess.run(sum3)
output = self.evaluate(sum3)
self.assertEqual(40, output)
def testLegacyDeviceNames(self):

View File

@ -117,7 +117,7 @@ class PartialRunTest(test_util.TensorFlowTestCase):
a = constant_op.constant(2.0, dtypes.float32)
b = a * 2
c = b * 3
r1 = sess.run([b, c])
r1 = self.evaluate([b, c])
h = sess.partial_run_setup([b, c], [])
r2 = sess.partial_run(h, [b, c])
self.assertEqual(r1, r2)

View File

@ -147,7 +147,7 @@ class TimelineTest(test.TestCase):
num2 = variables.Variable(2.0, name='num2')
with ops.device('/cpu:2'):
result = num1 + num2 + num1 * num2
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
sess.run(result, options=run_options, run_metadata=run_metadata)
self.assertTrue(run_metadata.HasField('step_stats'))
@ -176,7 +176,7 @@ class TimelineTest(test.TestCase):
num2 = variables.Variable(2.0, name='num2')
with ops.device('/cpu:2'):
result = num1 + num2 + num1 * num2
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
sess.run(result, options=run_options, run_metadata=run_metadata)
self.assertTrue(run_metadata.HasField('step_stats'))
step_stats = run_metadata.step_stats

View File

@ -216,7 +216,7 @@ class VirtualGpuTest(test_util.TensorFlowTestCase):
for d in self._util.devices:
with ops.device(d):
var = variables.Variable(random_ops.random_uniform(mat_shape))
sess.run(var.initializer)
self.evaluate(var.initializer)
data.append(var)
s = data[0]
for i in range(1, len(data)):

View File

@ -110,9 +110,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
batches = []
for _ in range(4):
batches.append(sess.run(batch))
batches.append(self.evaluate(batch))
with self.assertRaises(errors.OutOfRangeError):
sess.run(batch)
self.evaluate(batch)
batch_sizes_val = []
lengths_val = []
for batch in batches:
@ -160,9 +160,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
batches.append(self.evaluate(batch))
with self.assertRaisesOpError("bucket_boundaries"):
sess.run(batch)
self.evaluate(batch)
batch_sizes_val = []
lengths_val = []
for batch in batches:
@ -197,9 +197,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))
batches.append(self.evaluate(batch))
with self.assertRaises(errors.OutOfRangeError):
sess.run(batch)
self.evaluate(batch)
self.assertAllEqual(batches[0], [[1, 0],
[1, 1]])
@ -300,7 +300,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
while True:
output = sess.run(batch)
output = self.evaluate(batch)
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
tuple(output.values))
all_sparse_tensors.add(sprs_tensor)

View File

@ -57,9 +57,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceInt32(self):
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
@ -82,9 +82,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@ -108,9 +108,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceWithPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10)
@ -134,9 +134,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
@ -160,9 +160,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, sess.run(next_element))
self.assertEqual({"a": i}, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyDictToDeviceWithPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
@ -186,9 +186,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, sess.run(next_element))
self.assertEqual({"a": i}, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopySparseTensorsToDevice(self):
@ -217,12 +217,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = sess.run(next_element)
actual = self.evaluate(next_element)
self.assertAllEqual([i], actual.values)
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopySparseTensorsToDeviceWithPrefetch(self):
@ -251,12 +251,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = sess.run(next_element)
actual = self.evaluate(next_element)
self.assertAllEqual([i], actual.values)
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpu(self):
if not test_util.is_gpu_available():
@ -271,11 +271,11 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpuWithPrefetch(self):
if not test_util.is_gpu_available():
@ -290,11 +290,11 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpuWithMap(self):
if not test_util.is_gpu_available():
@ -323,14 +323,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(10):
x, y, z = sess.run(next_element)
x, y, z = self.evaluate(next_element)
self.assertEqual(i**2, x)
self.assertEqual(float(i**2), y)
self.assertEqual(util_compat.as_bytes(str(i)), z)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpuInt32(self):
if not test_util.is_gpu_available():
@ -345,10 +345,10 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
self.evaluate(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpuInt32AndPrefetch(self):
if not test_util.is_gpu_available():
@ -363,10 +363,10 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
self.evaluate(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpuStrings(self):
if not test_util.is_gpu_available():
@ -381,10 +381,10 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
self.evaluate(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpuStringsAndPrefetch(self):
if not test_util.is_gpu_available():
@ -399,10 +399,10 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
self.evaluate(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDevicePingPongCPUGPU(self):
if not test_util.is_gpu_available():
@ -420,11 +420,11 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10)
@ -447,14 +447,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceWithReInitAndPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10)
@ -477,14 +477,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available():
@ -499,14 +499,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testCopyToDeviceGpuWithReInitAndPrefetch(self):
if not test_util.is_gpu_available():
@ -521,14 +521,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testIteratorGetNextAsOptionalOnGPU(self):
if not test_util.is_gpu_available():
@ -547,24 +547,25 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):
sess.run(elem_has_value_t)
self.evaluate(elem_has_value_t)
with self.assertRaises(errors.FailedPreconditionError):
sess.run(elem_value_t)
self.evaluate(elem_value_t)
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(3):
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
elem_has_value, elem_value = self.evaluate(
[elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
self.assertEqual(i, elem_value)
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
for _ in range(2):
self.assertFalse(sess.run(elem_has_value_t))
self.assertFalse(self.evaluate(elem_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_value_t)
self.evaluate(elem_value_t)
if __name__ == "__main__":

View File

@ -38,13 +38,13 @@ class CounterTest(test_base.DatasetTestBase):
negative_get_next = negative_iterator.get_next()
with self.cached_session() as sess:
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))
self.assertEqual(3, self.evaluate(get_next))
self.assertEqual(3 + 4, self.evaluate(get_next))
self.assertEqual(3 + 2 * 4, self.evaluate(get_next))
self.assertEqual(0, sess.run(negative_get_next))
self.assertEqual(-1, sess.run(negative_get_next))
self.assertEqual(-2, sess.run(negative_get_next))
self.assertEqual(0, self.evaluate(negative_get_next))
self.assertEqual(-1, self.evaluate(negative_get_next))
self.assertEqual(-2, self.evaluate(negative_get_next))
if __name__ == "__main__":

View File

@ -41,10 +41,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
for start in range(0, len(components), 4):
results = sess.run(get_next)
results = self.evaluate(get_next)
self.assertAllEqual([[i, j]
for i, c in enumerate(components[start:start + 4])
for j in range(c)], results.indices)
@ -56,7 +56,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
def testDenseToSparseBatchDatasetWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
@ -69,10 +69,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
for start in range(0, len(components), 4):
results = sess.run(get_next)
results = self.evaluate(get_next)
self.assertAllEqual([[i, j, z]
for i, c in enumerate(components[start:start + 4])
for j in range(c)
@ -89,7 +89,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
], results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
@ -111,13 +111,13 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"incompatible with the row shape"):
sess.run(get_next)
self.evaluate(get_next)
# Initialize with an input tensor that is larger than `row_shape`.
sess.run(init_op, feed_dict={input_tensor: range(13)})
with self.assertRaisesRegexp(errors.DataLossError,
"larger than the row shape"):
sess.run(get_next)
self.evaluate(get_next)
if __name__ == "__main__":

View File

@ -40,12 +40,12 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for _ in range(100):
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def _normalize(self, vec):
return vec / vec.sum()
@ -71,9 +71,9 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
freqs = np.zeros([num_datasets])
for _ in range(num_samples):
freqs[sess.run(next_element)] += 1
freqs[self.evaluate(next_element)] += 1
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
return freqs
@ -107,9 +107,9 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i in choice_array:
self.assertEqual(words[i], sess.run(next_element))
self.assertEqual(words[i], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testErrors(self):
with self.assertRaisesRegexp(ValueError,

View File

@ -44,12 +44,12 @@ class EnumerateDatasetTest(test_base.DatasetTestBase):
[t.shape for t in get_next[1]])
with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
self.evaluate(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), self.evaluate(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
if __name__ == "__main__":

View File

@ -52,12 +52,12 @@ class FilterBenchmark(test.Benchmark):
with session.Session() as sess:
for _ in range(10):
sess.run(next_element.op)
self.evaluate(next_element.op)
deltas = []
for _ in range(100):
start = time.time()
for _ in range(100):
sess.run(next_element.op)
self.evaluate(next_element.op)
end = time.time()
deltas.append(end - start)

View File

@ -94,18 +94,18 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
device0, device1)
with self.test_session(config=worker_config) as sess:
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [5.0])
sess.run(destroy_op)
self.evaluate(destroy_op)
def testSameDeviceCPU(self):
self._prefetch_fn_helper_one_shot("same_device_cpu",
@ -135,35 +135,35 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
ds, ds_iterator, "reinit", device0, device1)
with self.test_session(config=worker_config) as sess:
sess.run(ds_iterator.initializer)
elem = sess.run(prefetch_op)
self.evaluate(ds_iterator.initializer)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [5.0])
# Lets reset the function buffering resource and reinitialize the
# iterator. Should be able to go through this again.
self._event.clear()
sess.run(reset_op)
sess.run(ds_iterator.initializer)
elem = sess.run(prefetch_op)
self.evaluate(reset_op)
self.evaluate(ds_iterator.initializer)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [5.0])
sess.run(destroy_op)
self.evaluate(destroy_op)
def testReinitializationOutOfRange(self):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
@ -175,30 +175,30 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
ds, ds_iterator, "reinit", device0, device1)
with self.test_session(config=worker_config) as sess:
sess.run(ds_iterator.initializer)
self.evaluate(ds_iterator.initializer)
for i in range(1, 10):
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [float(i)])
# Try fetching after its over twice to test out end of sequence.
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
self.evaluate(prefetch_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
self.evaluate(prefetch_op)
# Now reset everything and try it out again.
self._event.clear()
sess.run(reset_op)
sess.run(ds_iterator.initializer)
self.evaluate(reset_op)
self.evaluate(ds_iterator.initializer)
for i in range(1, 10):
elem = sess.run(prefetch_op)
elem = self.evaluate(prefetch_op)
self.assertEqual(elem, [float(i)])
# Try fetching after its over twice to test out end of sequence.
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
self.evaluate(prefetch_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
self.evaluate(prefetch_op)
sess.run(destroy_op)
self.evaluate(destroy_op)
def testStringsGPU(self):
if not test_util.is_gpu_available():
@ -235,13 +235,13 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
buffer_resource_handle, ignore_lookup_error=True)
with self.cached_session() as sess:
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
self.assertEqual([b"a"], self.evaluate(prefetch_op))
self.assertEqual([b"b"], self.evaluate(prefetch_op))
self.assertEqual([b"c"], self.evaluate(prefetch_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
self.evaluate(prefetch_op)
sess.run(destroy_op)
self.evaluate(destroy_op)
if __name__ == "__main__":

View File

@ -39,10 +39,10 @@ class GroupByReducerTest(test_base.DatasetTestBase):
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
got = self.evaluate(get_next)
self.assertEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
def testSum(self):
reducer = grouping.Reducer(
@ -127,11 +127,11 @@ class GroupByReducerTest(test_base.DatasetTestBase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
x, y = sess.run(get_next)
x, y = self.evaluate(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
def testTypeMismatch(self):
reducer = grouping.Reducer(
@ -190,7 +190,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
x, y = sess.run(get_next)
x, y = self.evaluate(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)

View File

@ -68,9 +68,9 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
which_bucket, bucketed_values = sess.run(get_next)
which_bucket, bucketed_values = self.evaluate(get_next)
self.assertEqual(0, which_bucket)
@ -103,11 +103,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
# Get two minibatches (one containing even values, one containing odds)
which_bucket_even, bucketed_values_even = sess.run(get_next)
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
which_bucket_even, bucketed_values_even = self.evaluate(get_next)
which_bucket_odd, bucketed_values_odd = self.evaluate(get_next)
# Count number of bucket_tensors.
self.assertEqual(3, len(bucketed_values_even))
@ -174,11 +174,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
which_bucket0, bucketed_values_even0 = sess.run(get_next)
which_bucket1, bucketed_values_even1 = sess.run(get_next)
which_bucket0, bucketed_values_even0 = self.evaluate(get_next)
which_bucket1, bucketed_values_even1 = self.evaluate(get_next)
# Ensure that bucket 1 was completely filtered out
self.assertAllEqual(0, which_bucket0)
@ -207,11 +207,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
while True:
result = sess.run(get_next)
result = self.evaluate(get_next)
is_even = all(x % 2 == 0 for x in result)
is_odd = all(x % 2 == 1 for x in result)
self.assertTrue(is_even or is_odd)
@ -232,11 +232,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
result = sess.run(get_next)
result = self.evaluate(get_next)
self.assertTrue(
all(x % 2 == 0
for x in result) or all(x % 2 == 1)
@ -259,16 +259,16 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
# 2. Different buckets can produce output at different rates, and
# 3. For deterministic input, the output is deterministic.
for _ in range(3):
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next))
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
@ -280,13 +280,13 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
self.evaluate(init_op)
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
# The small outputs at the end are deterministically produced in key
# order.
self.assertAllEqual([0, 0, 0], sess.run(get_next))
self.assertAllEqual([1], sess.run(get_next))
self.assertAllEqual([0, 0, 0], self.evaluate(get_next))
self.assertAllEqual([1], self.evaluate(get_next))
def testEmpty(self):
iterator = (
@ -297,11 +297,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Window size must be greater than zero, but got 0."):
print(sess.run(get_next))
print(self.evaluate(get_next))
def testReduceFuncError(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@ -323,9 +323,9 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
self.evaluate(get_next)
def testConsumeWindowDatasetMoreThanOnce(self):
components = np.random.randint(50, size=(200,)).astype(np.int64)
@ -351,11 +351,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
tight_result, multiple_of_10_result = sess.run(get_next)
tight_result, multiple_of_10_result = self.evaluate(get_next)
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
self.assertAllEqual(tight_result,
multiple_of_10_result[:, :tight_result.shape[1]])

View File

@ -47,11 +47,11 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
self.assertEqual(x, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
def testParallelMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
@ -65,11 +65,11 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
self.assertEqual(x, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
def testReadFileIgnoreError(self):
@ -93,22 +93,22 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
# All of the files are present.
sess.run(init_op)
self.evaluate(init_op)
for filename in filenames:
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Delete one of the files.
os.remove(filenames[0])
# Attempting to read filenames[0] will fail, but ignore_errors()
# will catch the error.
sess.run(init_op)
self.evaluate(init_op)
for filename in filenames[1:]:
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
if __name__ == "__main__":

View File

@ -46,14 +46,14 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
with self.cached_session() as sess:
sess.run(materialize)
self.evaluate(materialize)
self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
def testIdentityIndexedDataset(self):
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
materialized = ds.materialize()
with self.cached_session() as sess:
sess.run(materialized.initializer)
self.evaluate(materialized.initializer)
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
for i in range(16):
output = sess.run(
@ -68,12 +68,13 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
itr = ds.make_initializable_iterator()
n = itr.get_next()
with self.cached_session() as sess:
sess.run(itr.initializer)
self.evaluate(itr.initializer)
for i in range(16):
output = sess.run(n)
output = self.evaluate(n)
self.assertEqual(i, output)
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
self.evaluate(n)
if __name__ == "__main__":
test.main()

View File

@ -112,14 +112,14 @@ class MakeBatchedFeaturesDatasetTest(
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = sess.run(next_element)
actual_batch = self.evaluate(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
self.assertAllEqual(record_batch, actual_batch["record"])
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testReadWithFusedShuffleRepeatDataset(self):
num_epochs = 5

View File

@ -90,7 +90,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
batch_size,
num_epochs,
):
actual_features = sess.run(nxt)
actual_features = self.evaluate(nxt)
if label_name is not None:
expected_labels = expected_features.pop(label_name)
@ -102,7 +102,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
self.assertAllEqual(expected_features[k], actual_features[k])
with self.assertRaises(errors.OutOfRangeError):
sess.run(nxt)
self.evaluate(nxt)
def _test_dataset(self,
inputs,
@ -607,8 +607,8 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
outputs1 = dataset1.make_one_shot_iterator().get_next()
outputs2 = dataset2.make_one_shot_iterator().get_next()
for _ in range(total_records // batch_size):
batch1 = nest.flatten(sess.run(outputs1))
batch2 = nest.flatten(sess.run(outputs2))
batch1 = nest.flatten(self.evaluate(outputs1))
batch2 = nest.flatten(self.evaluate(outputs2))
for i in range(len(batch1)):
self.assertAllEqual(batch1[i], batch2[i])
@ -639,8 +639,8 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
outputs2 = dataset2.make_one_shot_iterator().get_next()
all_equal = False
for _ in range(total_records // batch_size):
batch1 = nest.flatten(sess.run(outputs1))
batch2 = nest.flatten(sess.run(outputs2))
batch1 = nest.flatten(self.evaluate(outputs1))
batch2 = nest.flatten(self.evaluate(outputs2))
for i in range(len(batch1)):
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal)

View File

@ -105,7 +105,7 @@ class MakeTFRecordDatasetTest(
for expected_batch in self._next_expected_batch(
file_indices, batch_size, num_epochs, interleave_cycle_length,
drop_final_batch, use_parser_fn):
actual_batch = sess.run(outputs)
actual_batch = self.evaluate(outputs)
self.assertAllEqual(expected_batch, actual_batch)
def _read_test(self, batch_size, num_epochs, file_index=None,
@ -135,7 +135,7 @@ class MakeTFRecordDatasetTest(
interleave_cycle_length=num_parallel_reads,
drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
with self.assertRaises(errors.OutOfRangeError):
sess.run(outputs)
self.evaluate(outputs)
def testRead(self):
for batch_size in [1, 2]:
@ -188,19 +188,19 @@ class MakeTFRecordDatasetTest(
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
first_batches = []
try:
while True:
first_batches.append(sess.run(next_element))
first_batches.append(self.evaluate(next_element))
except errors.OutOfRangeError:
pass
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
second_batches = []
try:
while True:
second_batches.append(sess.run(next_element))
second_batches.append(self.evaluate(next_element))
except errors.OutOfRangeError:
pass

View File

@ -89,13 +89,13 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
num_batches = (28 * 7) // 14
for i in range(num_batches):
result = sess.run(get_next)
result = self.evaluate(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Batch of a finite input, where the batch_size does not
# divide the total number of elements.
@ -104,23 +104,23 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
# We expect (num_batches - 1) full-sized batches.
num_batches = int(math.ceil((14 * 7) / 8))
for i in range(num_batches - 1):
result = sess.run(get_next)
result = self.evaluate(get_next)
for component, result_component in zip(components, result):
for j in range(8):
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
result = sess.run(get_next)
result = self.evaluate(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Batch of an empty input should fail straight away.
sess.run(init_op, feed_dict={count: 0, batch_size: 8})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Empty batch should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError):
@ -152,12 +152,12 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
if not drop_remainder:
self.assertAllEqual([[64], [81]], sess.run(next_element))
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
@parameterized.named_parameters(
("Normal", False),
@ -177,11 +177,11 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
@parameterized.named_parameters(
("Normal", False),
@ -201,14 +201,14 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got = self.evaluate(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
self.evaluate(elements)
@parameterized.named_parameters(
("Normal", False),
@ -230,14 +230,14 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got = self.evaluate(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
self.evaluate(elements)
@parameterized.named_parameters(
("Normal", False),
@ -261,9 +261,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
for i in range(2):
actual = sess.run(get_next)
actual = self.evaluate(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
@ -271,7 +271,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
@parameterized.named_parameters(
("Normal", False),
@ -321,10 +321,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
sess.run(get_next)
self.evaluate(get_next)
@parameterized.named_parameters(
("Normal", False),
@ -354,7 +354,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)
self.evaluate(get_next)
@parameterized.named_parameters(
("1", 0, False),
@ -393,13 +393,14 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
self.assertAllEqual([i * 10 + j for j in range(10)],
self.evaluate(get_next))
if threshold % 10 != 0:
self.assertAllEqual(
[threshold // 10 * 10 + j for j in range(threshold % 10)],
sess.run(get_next))
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
@parameterized.named_parameters(
("1", False, dtypes.bool, False),
@ -442,7 +443,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
self.assertAllEqual([element for _ in range(10)],
self.evaluate(get_next))
@parameterized.named_parameters(
("Identity", None, lambda x: x, None),
@ -462,7 +464,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
else:
expected = map_fn(
sess.run(self.structuredElement(structure, shape=[10])))
self.assertAllEqual(expected, sess.run(get_next))
self.assertAllEqual(expected, self.evaluate(get_next))
def testShortCircuitCapturedInput(self):
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
@ -473,7 +475,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={captured_t: 42})
self.assertAllEqual([42] * 10, sess.run(get_next))
self.assertAllEqual([42] * 10, self.evaluate(get_next))
@parameterized.named_parameters(
("Normal", False),
@ -501,13 +503,13 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
print("Case %d" % i)
if i < 5:
self.assertAllEqual([i * 10 + j + 1 for j in range(10)],
sess.run(get_next))
self.evaluate(get_next))
else:
self.assertAllEqual(
[((i * 10) + j) * ((i * 10) + j) for j in range(10)],
sess.run(get_next))
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
if __name__ == "__main__":

View File

@ -218,7 +218,7 @@ class MapDefunTest(test_base.DatasetTestBase):
def _assert_op_cancelled(self, sess, map_defun_op):
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
sess.run(map_defun_op)
self.evaluate(map_defun_op)
def testMapDefunWithParentCancellation(self):
# Checks that a cancellation of the parent graph is threaded through to
@ -260,10 +260,10 @@ class MapDefunBenchmark(test.Benchmark):
with session.Session() as sess:
# Warm up the session
for _ in range(5):
sess.run(op)
self.evaluate(op)
start = time.time()
for _ in range(num_iters):
sess.run(op)
self.evaluate(op)
end = time.time()
mean_us = (end - start) * 1e6 / num_iters
self.report_benchmark(

View File

@ -41,9 +41,9 @@ class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = iterator.get_next()
with self.cached_session() as sess:
self.assertEqual(0, sess.run(get_next))
self.assertEqual(0, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
if __name__ == "__main__":

View File

@ -51,7 +51,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
sess.run(get_next)
self.evaluate(get_next)
# TODO(b/117581999): Add eager coverage for the following tests.
def testSkipEagerOptimizationLargeInputFromTensorSlices(self):
@ -64,7 +64,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
sess.run(get_next)
self.evaluate(get_next)
def testOptimizationNestedDataset(self):

View File

@ -55,11 +55,11 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
thread_ids = []
try:
while True:
thread_ids.append(sess.run(next_element))
thread_ids.append(self.evaluate(next_element))
except errors.OutOfRangeError:
pass
self.assertLen(thread_ids, len(set(thread_ids)))

View File

@ -195,9 +195,9 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
self.write_coordination_events[expected_element].set()
self.assertEqual(expected_element * expected_element,
sess.run(self.next_element))
self.evaluate(self.next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testSingleThreaded(self):
self._testSingleThreaded()
@ -235,10 +235,10 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
for expected_element in self._interleave(
[[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
self.write_coordination_events[expected_element].set()
output = sess.run(self.next_element)
output = self.evaluate(self.next_element)
self.assertEqual(expected_element * expected_element, output)
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
@ -262,7 +262,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.write_coordination_events[expected_element].set()
if done_first_event: # First event starts the worker threads.
self.read_coordination_events[expected_element].acquire()
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
if not done_first_event:
self.read_coordination_events[expected_element].acquire()
done_first_event = True
@ -270,7 +270,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testTwoThreadsNoContention(self):
self._testTwoThreadsNoContention()
@ -309,7 +309,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
if not done_first_event:
done_first_event = True
self.assertTrue(
@ -318,7 +318,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testTwoThreadsNoContentionWithRaces(self):
self._testTwoThreadsNoContentionWithRaces()
@ -348,7 +348,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.write_coordination_events[expected_element].set()
if done_first_event: # First event starts the worker threads.
self.read_coordination_events[expected_element].acquire()
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
if not done_first_event:
done_first_event = True
self.read_coordination_events[expected_element].acquire()
@ -356,7 +356,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testTwoThreadsNoContentionBlockLength(self):
self._testTwoThreadsNoContentionBlockLength()
@ -396,7 +396,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
if not done_first_event:
done_first_event = True
self.assertTrue(
@ -405,7 +405,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking()
@ -428,7 +428,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.prefetch_input_elements: 0,
})
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testEmptyInput(self):
self._testEmptyInput()
@ -451,7 +451,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.prefetch_input_elements: 0,
})
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testNonEmptyInputIntoEmptyOutputs(self):
self._testNonEmptyInputIntoEmptyOutputs()
@ -484,7 +484,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
# presence of finishing iterators.
if done_first_event and not (sloppy and (i in race_indices)):
self.read_coordination_events[expected_element].acquire()
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
if not done_first_event or (sloppy and (i in race_indices)):
done_first_event = True
self.read_coordination_events[expected_element].acquire()
@ -520,10 +520,10 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
]
for element in mis_ordering:
self.write_coordination_events[element].set()
self.assertEqual(element * element, sess.run(self.next_element))
self.assertEqual(element * element, self.evaluate(self.next_element))
self.assertTrue(self.read_coordination_events[element].acquire(False))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testBlockLengthWithContentionSloppy(self):
with self.cached_session() as sess:
@ -549,7 +549,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.write_coordination_events[expected_element].set()
if done_first_event: # First event starts the worker threads.
self.read_coordination_events[expected_element].acquire()
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
if not done_first_event:
self.read_coordination_events[expected_element].acquire()
done_first_event = True
@ -557,7 +557,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block
@ -575,7 +575,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
})
for i in range(4, 7):
self.write_coordination_events[i].set()
elem = sess.run(self.next_element) # Start all workers
elem = self.evaluate(self.next_element) # Start all workers
# Allow the one successful worker to progress beyond the py_func again.
elem = int(math.sqrt(elem))
self.write_coordination_events[elem].set()
@ -608,7 +608,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
output_values = []
for _ in range(30):
output_values.append(sess.run(iterator.get_next()))
output_values.append(self.evaluate(iterator.get_next()))
expected_values = self._interleave(
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
@ -637,13 +637,13 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.evaluate(init_op)
for i in range(10):
for j in range(2):
expected = [i, 0] if j % 2 == 0 else [0, -i]
self.assertAllEqual(expected, sess.run(get_next))
self.assertAllEqual(expected, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
def testErrorsInOutputFn(self):
with self.cached_session() as sess:
@ -668,15 +668,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self.error = ValueError()
self.write_coordination_events[expected_element].set()
with self.assertRaises(errors.InvalidArgumentError):
sess.run(self.next_element)
self.evaluate(self.next_element)
else:
self.write_coordination_events[expected_element].set()
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
self.assertEqual(expected_element * expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testErrorsInInputFn(self):
@ -720,14 +720,14 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
if expected_element == 5:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(self.next_element)
self.evaluate(self.next_element)
else:
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
self.assertEqual(expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testErrorsInInterleaveFn(self):
@ -769,14 +769,14 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
if expected_element == 5:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(self.next_element)
self.evaluate(self.next_element)
else:
actual_element = sess.run(self.next_element)
actual_element = self.evaluate(self.next_element)
self.assertEqual(expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
self.evaluate(self.next_element)
def testShutdownRace(self):
dataset = dataset_ops.Dataset.range(20)
@ -796,10 +796,10 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for _ in range(2):
elements = []
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
try:
while True:
elements.extend(sess.run(next_element))
elements.extend(self.evaluate(next_element))
except errors.OutOfRangeError:
pass
results.append(elements)

View File

@ -57,9 +57,9 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testPrefetchToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@ -87,9 +87,9 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
@ -117,9 +117,9 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, sess.run(next_element))
self.assertEqual({"a": i}, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testPrefetchSparseTensorsToDevice(self):
def make_tensor(i):
@ -150,12 +150,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = sess.run(next_element)
actual = self.evaluate(next_element)
self.assertAllEqual([i], actual.values)
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available():
@ -170,9 +170,9 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testPrefetchToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10)
@ -199,14 +199,14 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testPrefetchToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available():
@ -220,14 +220,14 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
self.assertEqual(i, self.evaluate(next_element))
self.evaluate(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
if __name__ == "__main__":

View File

@ -60,9 +60,9 @@ class ScanTest(test_base.DatasetTestBase):
feed_dict={start: start_val, step: step_val, take: take_val})
for expected, _ in zip(
itertools.count(start_val, step_val), range(take_val)):
self.assertEqual(expected, sess.run(next_element))
self.assertEqual(expected, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
@test_util.run_in_graph_and_eager_modes
def testFibonacci(self):
@ -110,9 +110,9 @@ class ScanTest(test_base.DatasetTestBase):
feed_dict={start: start_val, step: step_val, take: take_val})
for expected, _ in zip(
itertools.count(start_val, step_val), range(take_val)):
self.assertEqual(expected, sess.run(next_element).values[0])
self.assertEqual(expected, self.evaluate(next_element).values[0])
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testChangingStateShape(self):
# Test the fixed-point shape invariant calculations: start with
@ -136,11 +136,11 @@ class ScanTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for i in range(5):
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
(longer_vector_val, larger_rank_val), _ = self.evaluate(next_element)
self.assertAllEqual([0] * (2**i), longer_vector_val)
self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testIncorrectStateType(self):

View File

@ -71,36 +71,36 @@ class RangeDatasetSerializationTest(
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
for i in range(start, break_point):
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
sess.run(init_op)
sess.run(restore_op)
self.evaluate(init_op)
self.evaluate(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, sess.run(get_next))
self.assertEqual(i, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Saving and restoring in same session.
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
self.evaluate(variables.global_variables_initializer())
self.evaluate(init_op)
for i in range(start, break_point):
self.assertEqual(i, sess.run(get_next))
sess.run(save_op)
sess.run(restore_op)
self.assertEqual(i, self.evaluate(get_next))
self.evaluate(save_op)
self.evaluate(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, sess.run(get_next))
self.assertEqual(i, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
def _build_range_dataset(self, start, stop):
return dataset_ops.Dataset.range(start, stop)

View File

@ -60,9 +60,9 @@ class SerializationIntegrationTest(test.TestCase):
init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
num_outputs)
with self.session(graph=g) as sess:
sess.run(init_ops)
self.evaluate(init_ops)
for _ in range(break_point):
output = sess.run(get_next_ops)
output = self.evaluate(get_next_ops)
for i in range(num_pipelines):
all_outputs[i].append(output[i])
saver.save(sess, self._ckpt_path())
@ -73,7 +73,7 @@ class SerializationIntegrationTest(test.TestCase):
with self.session(graph=g) as sess:
saver.restore(sess, self._ckpt_path())
for _ in range(num_outputs - break_point):
output = sess.run(get_next_ops)
output = self.evaluate(get_next_ops)
for i in range(num_pipelines):
all_outputs[i].append(output[i])

View File

@ -138,9 +138,9 @@ class ShuffleDatasetSerializationTest(
saver = saver_lib.Saver(allow_empty=True)
with self.session(graph=g) as sess:
self._save(sess, saver)
expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
expected = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
self._restore(saver, sess)
actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
actual = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
self.match(expected, actual)

View File

@ -38,10 +38,10 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
outputs = []
with self.cached_session() as sess:
for _ in range(num_outputs):
outputs.append(sess.run(get_next))
outputs.append(self.evaluate(get_next))
if verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
return outputs
def testCorrectOutput(self):
@ -108,7 +108,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
shuffle_ops.shuffle_and_repeat(buffer_size=21))
get_next_op = ds.make_one_shot_iterator().get_next()
with self.session(graph=g) as sess:
sess.run(get_next_op)
self.evaluate(get_next_op)
if __name__ == "__main__":

View File

@ -38,14 +38,14 @@ class SleepTest(test_base.DatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
start_time = time.time()
for i in range(10):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
end_time = time.time()
self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
if __name__ == "__main__":

View File

@ -39,10 +39,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
for _ in range(2): # Dataset is repeated. See setUp.
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"),
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that SqlDataset works on a join query.
def testReadResultSetJoinQuery(self):
@ -58,9 +59,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ON students.first_name = people.first_name "
"AND students.last_name = people.last_name"
})
self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
self.assertEqual((b"John", b"California", b"Hi!"),
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that SqlDataset can read a database entry with a null-terminator
# in the middle of the text and place the entry in a `string` tensor.
@ -75,10 +77,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"SELECT first_name, last_name, favorite_nonsense_word "
"FROM students ORDER BY first_name DESC"
})
self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
self.assertEqual((b"John", b"Doe", b"n\0nsense"), self.evaluate(get_next))
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"),
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that SqlDataset works when used on two different queries.
# Because the output types of the dataset must be determined at graph-creation
@ -93,21 +96,22 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, last_name, motto FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, last_name, state FROM people "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
self.assertEqual((b"John", b"Doe", b"California"),
self.evaluate(get_next))
self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
sess.run(get_next))
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that an `OutOfRangeError` is raised on the first call to
# `get_next_str_only` if result set is empty.
@ -122,7 +126,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"WHERE first_name = 'Nonexistent'"
})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that an error is raised when `driver_name` is invalid.
def testReadResultSetWithInvalidDriverName(self):
@ -151,7 +155,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.UnknownError):
sess.run(get_next)
self.evaluate(get_next)
# Test that an error is raised when there is a syntax error in `query`.
def testReadResultSetOfQueryWithSyntaxError(self):
@ -166,7 +170,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.UnknownError):
sess.run(get_next)
self.evaluate(get_next)
# Test that an error is raised when the number of columns in `query`
# does not match the length of `output_types`.
@ -181,7 +185,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
self.evaluate(get_next)
# Test that no results are returned when `query` is an insert query rather
# than a select query. In particular, the error refers to the number of
@ -199,7 +203,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int8` tensor.
@ -212,10 +216,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int8` tensor.
@ -230,9 +234,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"FROM students "
"WHERE first_name = 'John' ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0, -2), sess.run(get_next))
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int8` tensor.
@ -246,11 +250,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"SELECT desk_number, favorite_negative_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((9, -2), sess.run(get_next))
self.assertEqual((9, -2), self.evaluate(get_next))
# Max and min values of int8
self.assertEqual((127, -128), sess.run(get_next))
self.assertEqual((127, -128), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int16` tensor.
@ -263,10 +267,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int16` tensor.
@ -281,9 +285,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"FROM students "
"WHERE first_name = 'John' ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0, -2), sess.run(get_next))
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int16` tensor.
@ -297,11 +301,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"FROM students ORDER BY first_name DESC"
})
# Max value of int16
self.assertEqual((b"John", 32767), sess.run(get_next))
self.assertEqual((b"John", 32767), self.evaluate(get_next))
# Min value of int16
self.assertEqual((b"Jane", -32768), sess.run(get_next))
self.assertEqual((b"Jane", -32768), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int32` tensor.
@ -314,8 +318,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int32` tensor.
@ -328,10 +332,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, income FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0), sess.run(get_next))
self.assertEqual((b"Jane", -20000), sess.run(get_next))
self.assertEqual((b"John", 0), self.evaluate(get_next))
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int32` tensor.
@ -345,11 +349,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
# Max value of int32
self.assertEqual((b"John", 2147483647), sess.run(get_next))
self.assertEqual((b"John", 2147483647), self.evaluate(get_next))
# Min value of int32
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
self.assertEqual((b"Jane", -2147483648), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
# table and place it in an `int32` tensor.
@ -362,10 +366,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, school_id FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 123), sess.run(get_next))
self.assertEqual((b"Jane", 1000), sess.run(get_next))
self.assertEqual((b"John", 123), self.evaluate(get_next))
self.assertEqual((b"Jane", 1000), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in an `int64` tensor.
@ -378,10 +382,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int64` tensor.
@ -394,10 +398,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, income FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0), sess.run(get_next))
self.assertEqual((b"Jane", -20000), sess.run(get_next))
self.assertEqual((b"John", 0), self.evaluate(get_next))
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int64` tensor.
@ -412,11 +416,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
# Max value of int64
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
self.assertEqual((b"John", 9223372036854775807), self.evaluate(get_next))
# Min value of int64
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
self.assertEqual((b"Jane", -9223372036854775808), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in a `uint8` tensor.
@ -429,10 +433,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
# SQLite database table and place them in `uint8` tensors.
@ -446,11 +450,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
# Min value of uint8
self.assertEqual((b"John", 0), sess.run(get_next))
self.assertEqual((b"John", 0), self.evaluate(get_next))
# Max value of uint8
self.assertEqual((b"Jane", 255), sess.run(get_next))
self.assertEqual((b"Jane", 255), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in a `uint16` tensor.
@ -463,10 +467,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
# SQLite database table and place them in `uint16` tensors.
@ -480,11 +484,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name DESC"
})
# Min value of uint16
self.assertEqual((b"John", 0), sess.run(get_next))
self.assertEqual((b"John", 0), self.evaluate(get_next))
# Max value of uint16
self.assertEqual((b"Jane", 65535), sess.run(get_next))
self.assertEqual((b"Jane", 65535), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
# SQLite database table and place them as `True` and `False` respectively
@ -499,10 +503,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"SELECT first_name, registration_complete FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", True), sess.run(get_next))
self.assertEqual((b"Jane", False), sess.run(get_next))
self.assertEqual((b"John", True), self.evaluate(get_next))
self.assertEqual((b"Jane", False), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
# from a SQLite database table and place it as `True` in a `bool` tensor.
@ -515,10 +519,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.query: "SELECT first_name, favorite_medium_sized_number "
"FROM students ORDER BY first_name DESC"
})
self.assertEqual((b"John", True), sess.run(get_next))
self.assertEqual((b"Jane", True), sess.run(get_next))
self.assertEqual((b"John", True), self.evaluate(get_next))
self.assertEqual((b"Jane", True), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table
# and place it in a `float64` tensor.
@ -533,10 +537,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"SELECT first_name, last_name, victories FROM townspeople "
"ORDER BY first_name"
})
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
self.assertEqual((b"George", b"Washington", 20.0),
self.evaluate(get_next))
self.assertEqual((b"John", b"Adams", -19.95), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table beyond
# the precision of 64-bit IEEE, without throwing an error. Test that
@ -555,13 +560,13 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
self.assertEqual(
(b"George", b"Washington",
1331241.321342132321324589798264627463827647382647382643874),
sess.run(get_next))
self.evaluate(get_next))
self.assertEqual(
(b"John", b"Adams",
1331241321342132321324589798264627463827647382647382643874.0),
sess.run(get_next))
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table,
# representing the largest integer representable as a 64-bit IEEE float
@ -579,11 +584,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
"ORDER BY first_name"
})
self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
sess.run(get_next))
self.evaluate(get_next))
self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
sess.run(get_next))
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.evaluate(get_next)
if __name__ == "__main__":

View File

@ -70,18 +70,18 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
expected_sum = 0.0
for i in range(100):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), sess.run(next_element))
summary_str = sess.run(summary_t)
np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
expected_sum += i * 8.0
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
summary_str = sess.run(summary_t)
self.evaluate(next_element)
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
@ -95,14 +95,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float(i + 1))
self.evaluate(summary_t), "record_latency", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
self.evaluate(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 100.0)
def testPrefetchBufferUtilization(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -114,11 +115,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(100):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), sess.run(next_element))
summary_str = sess.run(summary_t)
np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
float(i + 1))
self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
@ -126,8 +127,8 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
0, 1)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
summary_str = sess.run(summary_t)
self.evaluate(next_element)
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
100)
@ -141,17 +142,17 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(10):
self.assertAllEqual(
np.array([i] * i, dtype=np.int64), sess.run(next_element))
summary_str = sess.run(summary_t)
np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
summary_str = self.evaluate(summary_t)
self._assertSummaryHasScalarValue(summary_str,
"Prefetch::buffer_capacity", 0)
self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
0)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testFilteredElementsStats(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -163,20 +164,21 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.test_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(34):
self.assertEqual(i * 3, sess.run(next_element))
self.assertEqual(i * 3, self.evaluate(next_element))
if i is not 0:
self._assertSummaryHasScalarValue(
sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
self.evaluate(summary_t), "Filter::dropped_elements",
float(i * 2))
self._assertSummaryHasScalarValue(
sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
self.evaluate(summary_t), "Filter::filtered_elements", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
self._assertSummaryHasScalarValue(
sess.run(summary_t), "Filter::dropped_elements", 67.0)
self.evaluate(summary_t), "Filter::dropped_elements", 67.0)
self._assertSummaryHasScalarValue(
sess.run(summary_t), "Filter::filtered_elements", 34.0)
self.evaluate(summary_t), "Filter::filtered_elements", 34.0)
def testMapBufferUtilization(self, dataset_transformation):
@ -257,15 +259,16 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.cached_session() as sess:
for j in range(5):
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
self.evaluate(summary_t), "record_latency",
float((j * 100) + i + 1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", (j + 1) * 100.0)
self.evaluate(summary_t), "record_latency", (j + 1) * 100.0)
def testNoAggregatorRegistered(self, dataset_transformation):
dataset = dataset_ops.Dataset.range(100).apply(
@ -274,11 +277,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testMultipleTags(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -291,18 +294,19 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float(i + 1))
self.evaluate(summary_t), "record_latency", float(i + 1))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency_2", float(i + 1))
self.evaluate(summary_t), "record_latency_2", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
self.evaluate(next_element)
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency_2", 100.0)
self.evaluate(summary_t), "record_latency", 100.0)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency_2", 100.0)
def testRepeatedTags(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -315,14 +319,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self.assertEqual(i, self.evaluate(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
self.evaluate(summary_t), "record_latency", float(2 * (i + 1)))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
self.evaluate(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 200.0)
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -335,14 +340,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer])
self.evaluate([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, sess.run(next_element))
self.assertEqual(i * 2, self.evaluate(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
self.evaluate(summary_t), "record_latency", float(2 * (i + 1)))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
self.evaluate(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 200.0)
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@ -358,19 +364,19 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_t = aggregator.get_summary()
with self.test_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer])
self.evaluate([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, sess.run(next_element))
self.assertEqual(i * 2, self.evaluate(next_element))
self._assertSummaryHasCount(
sess.run(summary_t), "dataset1_record_latency", float(i + 1))
self.evaluate(summary_t), "dataset1_record_latency", float(i + 1))
self._assertSummaryHasCount(
sess.run(summary_t), "dataset2_record_latency", float(i + 1))
self.evaluate(summary_t), "dataset2_record_latency", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
self._assertSummaryHasCount(
sess.run(summary_t), "dataset1_record_latency", 100.0)
self.evaluate(summary_t), "dataset1_record_latency", 100.0)
self._assertSummaryHasCount(
sess.run(summary_t), "dataset2_record_latency", 100.0)
self.evaluate(summary_t), "dataset2_record_latency", 100.0)
@parameterized.named_parameters(
@ -417,20 +423,21 @@ class FeatureStatsDatasetTest(
summary_t = aggregator.get_summary()
with self.test_session() as sess:
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for _ in range(num_output):
sess.run(next_element)
self.evaluate(next_element)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
self._assertSummaryHasCount(
sess.run(summary_t), "record_stats_features", total_records)
self.evaluate(summary_t), "record_stats_features", total_records)
self._assertSummaryHasCount(
sess.run(summary_t), "record_stats_feature-values", total_records)
self.evaluate(summary_t), "record_stats_feature-values",
total_records)
self._assertSummaryHasSum(
sess.run(summary_t), "record_stats_features", total_records * 4)
self.evaluate(summary_t), "record_stats_features", total_records * 4)
self._assertSummaryHasSum(
sess.run(summary_t), "record_stats_feature-values",
self.evaluate(summary_t), "record_stats_feature-values",
self._sum_keywords(1) * num_epochs + 3 * total_records)

View File

@ -47,9 +47,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
for i in range(4):
self.assertEqual(i, sess.run(next_elem))
self.assertEqual(i, self.evaluate(next_elem))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_elem)
self.evaluate(next_elem)
def testUnbatchScalarDataset(self):
data = tuple([math_ops.range(10) for _ in range(3)])
@ -65,10 +65,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))
self.assertEqual((i,) * 3, self.evaluate(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
self.evaluate(op)
def testUnbatchDatasetWithStrings(self):
data = tuple([math_ops.range(10) for _ in range(3)])
@ -85,10 +85,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
self.evaluate(op)
def testUnbatchDatasetWithSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
@ -104,12 +104,12 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
st_row = self.evaluate(next_element)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testUnbatchDatasetWithDenseAndSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
@ -125,13 +125,13 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
dense_elem, st_row = self.evaluate(next_element)
self.assertEqual(i, dense_elem)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testUnbatchSingleElementTupleDataset(self):
data = tuple([(math_ops.range(10),) for _ in range(3)])
@ -147,10 +147,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))
self.assertEqual(((i,),) * 3, self.evaluate(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
self.evaluate(op)
def testUnbatchMultiElementTupleDataset(self):
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
@ -168,10 +168,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
self.evaluate(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
self.evaluate(op)
def testUnbatchEmpty(self):
data = dataset_ops.Dataset.from_tensors(
@ -183,7 +183,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testUnbatchStaticShapeMismatch(self):
data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
@ -208,7 +208,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
ph2: np.arange(8).astype(np.int32)
})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(next_element)
self.evaluate(next_element)
# No 0th dimension (i.e. scalar value) for one component.
sess.run(
@ -218,7 +218,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
ph2: 7
})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(next_element)
self.evaluate(next_element)
if __name__ == "__main__":

View File

@ -49,13 +49,13 @@ class UniqueTest(test_base.DatasetTestBase):
with self.cached_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
sess.run(iterator.initializer)
self.evaluate(iterator.initializer)
for element in expected:
if dtype == dtypes.string:
element = compat.as_bytes(element)
self.assertAllEqual(element, sess.run(next_element))
self.assertAllEqual(element, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testSimpleInt(self):
for dtype in [dtypes.int32, dtypes.int64]:

View File

@ -41,7 +41,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
def testBasic(self):
dataset = dataset_ops.Dataset.range(10)
@ -51,13 +51,13 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
def testOneOnSameDevice(self):
with ops.device("/cpu:0"):
@ -68,13 +68,13 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
def testRepeatDevices(self):
with ops.device("/cpu:0"):
@ -86,17 +86,17 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 20, 4):
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(i + 2, sess.run(elem_on_3))
self.assertEqual(i + 3, sess.run(elem_on_4))
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(i + 2, self.evaluate(elem_on_3))
self.assertEqual(i + 3, self.evaluate(elem_on_4))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
sess.run(elem_on_3)
sess.run(elem_on_4)
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
self.evaluate(elem_on_3)
self.evaluate(elem_on_4)
def testNotFullyDivisible(self):
dataset = dataset_ops.Dataset.range(9)
@ -106,14 +106,14 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 8, 2):
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(8, sess.run(elem_on_1))
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(8, self.evaluate(elem_on_1))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
def testGetNextAsOptional(self):
dataset = dataset_ops.Dataset.range(9)
@ -127,7 +127,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 8, 2):
elem_on_1_has_value, elem_on_1_value = sess.run(
[elem_on_1_has_value_t, elem_on_1_t])
@ -141,12 +141,12 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
[elem_on_1_has_value_t, elem_on_1_t])
self.assertTrue(elem_on_1_has_value)
self.assertEqual(8, elem_on_1_value)
self.assertFalse(sess.run(elem_on_1_has_value_t))
self.assertFalse(sess.run(elem_on_2_has_value_t))
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_on_1_t)
self.evaluate(elem_on_1_t)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_on_2_t)
self.evaluate(elem_on_2_t)
def testUneven(self):
dataset = dataset_ops.Dataset.range(10)
@ -156,14 +156,14 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i, self.evaluate(elem_on_1))
for i in range(0, 10, 2):
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
def testMultipleInitializations(self):
with ops.device("/cpu:0"):
@ -180,7 +180,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
with self.test_session(config=config) as sess:
for i in range(1000):
sess.run(init_op, feed_dict={epoch: i})
self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
elem_on_2]))
def testBasicGpu(self):
if not test_util.is_gpu_available():
@ -193,13 +194,13 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
def testUnevenGpu(self):
if not test_util.is_gpu_available():
@ -212,14 +213,14 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i, self.evaluate(elem_on_1))
for i in range(0, 10, 2):
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
def testGetNextAsOptionalGpu(self):
if not test_util.is_gpu_available():
@ -236,7 +237,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 8, 2):
elem_on_1_has_value, elem_on_1_value = sess.run(
[elem_on_1_has_value_t, elem_on_1_t])
@ -250,12 +251,12 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
[elem_on_1_has_value_t, elem_on_1_t])
self.assertTrue(elem_on_1_has_value)
self.assertEqual(8, elem_on_1_value)
self.assertFalse(sess.run(elem_on_1_has_value_t))
self.assertFalse(sess.run(elem_on_2_has_value_t))
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_on_1_t)
self.evaluate(elem_on_1_t)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_on_2_t)
self.evaluate(elem_on_2_t)
def testOptimization(self):
dataset = dataset_ops.Dataset.range(10)
@ -273,13 +274,13 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
config = config_pb2.ConfigProto(device_count={"CPU": 3})
with self.test_session(config=config) as sess:
sess.run(multi_device_iterator.initializer)
self.evaluate(multi_device_iterator.initializer)
for i in range(0, 10, 2):
self.assertEqual(i, sess.run(elem_on_1))
self.assertEqual(i + 1, sess.run(elem_on_2))
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
sess.run(elem_on_1)
sess.run(elem_on_2)
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
if __name__ == "__main__":

View File

@ -30,47 +30,52 @@ class ConvertTest(test.TestCase):
def testInteger(self):
resp = convert.optional_param_to_tensor("foo", 3)
with self.cached_session() as sess:
self.assertEqual(3, sess.run(resp))
self.assertEqual(3, self.evaluate(resp))
def testIntegerDefault(self):
resp = convert.optional_param_to_tensor("foo", None)
with self.cached_session() as sess:
self.assertEqual(0, sess.run(resp))
self.assertEqual(0, self.evaluate(resp))
def testStringDefault(self):
resp = convert.optional_param_to_tensor("bar", None, "default",
dtypes.string)
with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
self.assertEqual(compat.as_bytes("default"), self.evaluate(resp))
def testString(self):
resp = convert.optional_param_to_tensor("bar", "value", "default",
dtypes.string)
with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
self.assertEqual(compat.as_bytes("value"), self.evaluate(resp))
def testPartialShapeToTensorKnownDimension(self):
with self.cached_session() as sess:
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([1]))))
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor([1])))
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
constant_op.constant([1], dtype=dtypes.int64))))
self.assertAllEqual([1],
self.evaluate(
convert.partial_shape_to_tensor(
tensor_shape.TensorShape([1]))))
self.assertAllEqual([1], self.evaluate(
convert.partial_shape_to_tensor((1,))))
self.assertAllEqual([1], self.evaluate(
convert.partial_shape_to_tensor([1])))
self.assertAllEqual([1],
self.evaluate(
convert.partial_shape_to_tensor(
constant_op.constant([1], dtype=dtypes.int64))))
def testPartialShapeToTensorUnknownDimension(self):
with self.cached_session() as sess:
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([None]))))
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
(None,))))
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
[None])))
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
[-1])))
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
constant_op.constant([-1], dtype=dtypes.int64))))
self.assertAllEqual([-1],
self.evaluate(
convert.partial_shape_to_tensor(
tensor_shape.TensorShape([None]))))
self.assertAllEqual([-1],
self.evaluate(convert.partial_shape_to_tensor((None,))))
self.assertAllEqual([-1],
self.evaluate(convert.partial_shape_to_tensor([None])))
self.assertAllEqual([-1],
self.evaluate(convert.partial_shape_to_tensor([-1])))
self.assertAllEqual([-1],
self.evaluate(
convert.partial_shape_to_tensor(
constant_op.constant([-1],
dtype=dtypes.int64))))
with self.assertRaisesRegexp(
ValueError, r"The given shape .* must be a 1-D tensor of tf.int64 "
@ -84,42 +89,63 @@ class ConvertTest(test.TestCase):
convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
def testPartialShapeToTensorMultipleDimensions(self):
with self.cached_session() as sess:
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([3, 6]))))
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
(3, 6))))
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
[3, 6])))
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
constant_op.constant([3, 6], dtype=dtypes.int64))))
self.assertAllEqual([3, 6],
self.evaluate(
convert.partial_shape_to_tensor(
tensor_shape.TensorShape([3, 6]))))
self.assertAllEqual([3, 6],
self.evaluate(convert.partial_shape_to_tensor((3, 6))))
self.assertAllEqual([3, 6],
self.evaluate(convert.partial_shape_to_tensor([3, 6])))
self.assertAllEqual([3, 6],
self.evaluate(
convert.partial_shape_to_tensor(
constant_op.constant([3, 6],
dtype=dtypes.int64))))
self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([3, None]))))
self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
(3, None))))
self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
[3, None])))
self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
constant_op.constant([3, -1], dtype=dtypes.int64))))
self.assertAllEqual([3, -1],
self.evaluate(
convert.partial_shape_to_tensor(
tensor_shape.TensorShape([3, None]))))
self.assertAllEqual([3, -1],
self.evaluate(
convert.partial_shape_to_tensor((3, None))))
self.assertAllEqual([3, -1],
self.evaluate(
convert.partial_shape_to_tensor([3, None])))
self.assertAllEqual([3, -1],
self.evaluate(
convert.partial_shape_to_tensor(
constant_op.constant([3, -1],
dtype=dtypes.int64))))
self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([None, None]))))
self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
(None, None))))
self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
[None, None])))
self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
constant_op.constant([-1, -1], dtype=dtypes.int64))))
self.assertAllEqual([-1, -1],
self.evaluate(
convert.partial_shape_to_tensor(
tensor_shape.TensorShape([None, None]))))
self.assertAllEqual([-1, -1],
self.evaluate(
convert.partial_shape_to_tensor((None, None))))
self.assertAllEqual([-1, -1],
self.evaluate(
convert.partial_shape_to_tensor([None, None])))
self.assertAllEqual([-1, -1],
self.evaluate(
convert.partial_shape_to_tensor(
constant_op.constant([-1, -1],
dtype=dtypes.int64))))
def testPartialShapeToTensorScalar(self):
with self.cached_session() as sess:
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([]))))
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor([])))
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
constant_op.constant([], dtype=dtypes.int64))))
self.assertAllEqual([],
self.evaluate(
convert.partial_shape_to_tensor(
tensor_shape.TensorShape([]))))
self.assertAllEqual([], self.evaluate(convert.partial_shape_to_tensor(())))
self.assertAllEqual([], self.evaluate(convert.partial_shape_to_tensor([])))
self.assertAllEqual([],
self.evaluate(
convert.partial_shape_to_tensor(
constant_op.constant([], dtype=dtypes.int64))))
if __name__ == "__main__":

View File

@ -1583,7 +1583,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
x = variables.VariableV1([1, 3, 3, 7], name="x")
_, idx = array_ops.unique(x, name="x_unique")
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
sess.run(x.initializer)
self.evaluate(x.initializer)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(

View File

@ -126,8 +126,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
u = variables.Variable([12.0], name="u")
v = variables.Variable([30.0], name="v")
w = math_ops.add(u, v, name="w")
sess.run(u.initializer)
sess.run(v.initializer)
self.evaluate(u.initializer)
self.evaluate(v.initializer)
self._compareOriginalAndReconstructedGraphDefs(
sess, w, expected_output=[42.0])
@ -139,7 +139,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
b = math_ops.add(a, a, name="b")
with ops.control_dependencies([a, b]):
c = math_ops.multiply(b, b, name="c")
sess.run(a.initializer)
self.evaluate(a.initializer)
self._compareOriginalAndReconstructedGraphDefs(
sess, c, expected_output=400.0)
@ -150,8 +150,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
y = variables.Variable(20.0, name="y")
cond = control_flow_ops.cond(
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
sess.run(x.initializer)
sess.run(y.initializer)
self.evaluate(x.initializer)
self.evaluate(y.initializer)
self._compareOriginalAndReconstructedGraphDefs(
sess, cond, expected_output=21.0)
@ -173,8 +173,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
toy_loss = x * (u - v)
train_op = gradient_descent.GradientDescentOptimizer(
learning_rate=0.1).minimize(toy_loss, name="train_op")
sess.run(u.initializer)
sess.run(v.initializer)
self.evaluate(u.initializer)
self.evaluate(v.initializer)
self._compareOriginalAndReconstructedGraphDefs(sess, train_op)

View File

@ -67,7 +67,7 @@ class SessionDebugMultiGPUTest(test_util.TensorFlowTestCase):
u1 = math_ops.multiply(v, v, name="u1")
w = math_ops.subtract(u1, u0, name="w")
sess.run(v.initializer)
self.evaluate(v.initializer)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(run_options, sess.graph,

View File

@ -109,8 +109,8 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
self.w = math_ops.matmul(self.u, self.v, name="w")
self.w_line_number = line_number_above()
sess.run(self.u.initializer)
sess.run(self.v.initializer)
self.evaluate(self.u.initializer)
self.evaluate(self.v.initializer)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(

View File

@ -92,9 +92,9 @@ class AutoShardDatasetTest(test.TestCase):
with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testTFRecordDataset(self):
dataset = readers.TFRecordDataset(self._createTFRecordFiles())
@ -138,10 +138,10 @@ class AutoShardDatasetTest(test.TestCase):
actual, expected = [], []
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
actual.append(sess.run(next_element))
actual.append(self.evaluate(next_element))
expected.append(self._record(r, f))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
self.assertAllEqual(expected, actual)
def testComplexPipeline(self):
@ -171,9 +171,9 @@ class AutoShardDatasetTest(test.TestCase):
num_iterations = (self._num_files * self._num_records * num_epochs) // (
self._num_shards * batch_size)
for _ in range(num_iterations):
actual.extend(sess.run(next_element))
actual.extend(self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
expected = []
for f in range(0, self._num_files, self._num_shards):
@ -205,12 +205,13 @@ class AutoShardDatasetTest(test.TestCase):
with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(self._record(r, f), sess.run(next_element))
self.assertAllEqual(self._record(r, f), self.evaluate(next_element))
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
self.assertAllEqual(
self._text_line(r, f), self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
self.evaluate(next_element)
def testTextLineReader(self):
dataset = readers.TextLineDataset(self._createTextFiles())

View File

@ -149,9 +149,9 @@ class DefFunctionTest(test.TestCase):
result = fn(3.0)
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(sess.run(state[0]), 2.0)
self.assertAllEqual(sess.run(result), 6.0)
self.assertAllEqual(self.evaluate(result), 6.0)
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
with ops.Graph().as_default(), self.test_session() as sess:
@ -168,9 +168,9 @@ class DefFunctionTest(test.TestCase):
result = fn(3.0)
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(sess.run(state[0]), 6.0)
self.assertAllEqual(sess.run(result), 18.0)
self.assertAllEqual(self.evaluate(result), 18.0)
def testLegacyGraphModeInputDependentInitializerFails(self):
with ops.Graph().as_default():

View File

@ -78,7 +78,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
c = constant_op.constant([[2.]])
f_c = f(c)
g, = gradients_impl.gradients(f_c, c)
self.assertAllEqual(sess.run(g).values, [[1.0]])
self.assertAllEqual(self.evaluate(g).values, [[1.0]])
def testNoSymGradNestedDefun(self):

View File

@ -564,7 +564,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
variables.global_variables_initializer().run()
call = def_function.function(o.call)
op = call()
self.assertAllEqual(sess.run(op), 2.0)
self.assertAllEqual(self.evaluate(op), 2.0)
def testGraphModeManyFunctions(self):
with ops.Graph().as_default(), self.cached_session():
@ -1732,7 +1732,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
function.register(cpu_boost, x)
y = gpu_boost(x)
y_value = sess.run(y)
y_value = self.evaluate(y)
if test.is_gpu_available():
self.assertEqual(y_value, 5.0)

View File

@ -1027,7 +1027,7 @@ class CrossedColumnTest(test.TestCase):
outputs = _transform_features(features, [price_cross_wire])
output = outputs[price_cross_wire]
with self.cached_session() as sess:
output_val = sess.run(output)
output_val = self.evaluate(output)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
for val in output_val.values:
@ -1886,7 +1886,8 @@ class LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
self.evaluate(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
price = fc._numeric_column('price')
@ -2525,7 +2526,8 @@ class _LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
self.evaluate(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
price = fc._numeric_column('price')

View File

@ -1188,7 +1188,7 @@ class CrossedColumnTest(test.TestCase):
outputs = fc._transform_features_v2(features, [price_cross_wire], None)
output = outputs[price_cross_wire]
with self.cached_session() as sess:
output_val = sess.run(output)
output_val = self.evaluate(output)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
for val in output_val.values:
@ -2088,7 +2088,8 @@ class LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]],
self.evaluate(net))
coord.request_stop()
coord.join(threads)
@ -2124,7 +2125,8 @@ class LinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
self.evaluate(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
price = fc.numeric_column('price')
@ -2843,7 +2845,8 @@ class OldLinearModelTest(test.TestCase):
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
self.evaluate(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
price = fc.numeric_column('price')

View File

@ -42,7 +42,7 @@ class FileSystemTest(test.TestCase):
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
queue.enqueue_many([["test://foo"]]).run()
queue.close().run()
key, value = sess.run(reader.read(queue))
key, value = self.evaluate(reader.read(queue))
self.assertEqual(key, compat.as_bytes("test://foo"))
self.assertEqual(value, compat.as_bytes("AAAAAAAAAA"))

View File

@ -102,7 +102,7 @@ class FunctionTest(test.TestCase):
call = MyIdentityFunc([18.0])
self.assertEqual("MyIdentity", call.op.name)
with session.Session() as sess:
self.assertAllEqual([18.0], sess.run(call))
self.assertAllEqual([18.0], self.evaluate(call))
def testIdentityImplicitDeref(self):
@ -116,8 +116,8 @@ class FunctionTest(test.TestCase):
self.assertEqual("MyIdentity", call.op.name)
for cfg in _OptimizerOptions():
with session.Session(config=cfg) as sess:
sess.run(var.initializer)
self.assertAllEqual([18.0], sess.run(call))
self.evaluate(var.initializer)
self.assertAllEqual([18.0], self.evaluate(call))
def testIdentityOutputName(self):
@ -130,7 +130,7 @@ class FunctionTest(test.TestCase):
call = MyIdentityFunc([18.0])
self.assertEqual("MyIdentity", call.op.name)
with session.Session() as sess:
self.assertAllEqual([18.0], sess.run(call))
self.assertAllEqual([18.0], self.evaluate(call))
def testTooManyOutputNames(self):
@ -158,7 +158,7 @@ class FunctionTest(test.TestCase):
call = APlus2B([1.0], [2.0])
self.assertEqual("APlus2B", call.op.name)
with session.Session() as sess:
self.assertAllEqual([5.0], sess.run(call))
self.assertAllEqual([5.0], self.evaluate(call))
def testFunctionWithNoOutput(self):
@ -187,7 +187,7 @@ class FunctionTest(test.TestCase):
call = APlus2B([1.0], [2.0])
self.assertEqual("APlus2B", call.op.name)
with session.Session() as sess:
self.assertAllEqual([5.0], sess.run(call))
self.assertAllEqual([5.0], self.evaluate(call))
def testDefineFunctionDuplicateOutputs(self):
@ -224,8 +224,8 @@ class FunctionTest(test.TestCase):
call_g = XSquarePlusOneGrad([2.0], [0.1])
with session.Session() as sess:
self.assertAllClose([5.0], sess.run(call_f))
self.assertAllClose([0.4], sess.run(call_g))
self.assertAllClose([5.0], self.evaluate(call_f))
self.assertAllClose([0.4], self.evaluate(call_g))
def testTanhSymGrad(self):
@ -365,7 +365,7 @@ class FunctionTest(test.TestCase):
else:
dx, dy = gradients_impl.gradients([z], [x, y])
with session.Session() as sess:
dx_val, dy_val = sess.run([dx, dy])
dx_val, dy_val = self.evaluate([dx, dy])
self.assertEqual([2.0], dx_val)
self.assertEqual([0.0], dy_val)
@ -387,7 +387,7 @@ class FunctionTest(test.TestCase):
call = AConstant()
self.assertEqual("AConstant", call.op.name)
with session.Session() as sess:
self.assertAllEqual([42], sess.run(call))
self.assertAllEqual([42], self.evaluate(call))
def testDefineFunctionNames(self):
@ -468,7 +468,7 @@ class FunctionTest(test.TestCase):
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
ans = sess.run(loop)
ans = self.evaluate(loop)
self.assertAllClose(ans, 131072.)
def testControlFlowStrictness(self):
@ -650,8 +650,8 @@ class FunctionTest(test.TestCase):
# pylint: enable=unexpected-keyword-arg
self.assertEqual("next", call2.op.name)
with session.Session() as sess:
self.assertAllEqual([1], sess.run(call1))
self.assertAllEqual([0], sess.run(call2))
self.assertAllEqual([1], self.evaluate(call1))
self.assertAllEqual([0], self.evaluate(call2))
def testNestedFunction(self):
@ -794,7 +794,7 @@ class FunctionTest(test.TestCase):
y = Foo()
with self.session(graph=g) as sess:
self.assertEqual(sess.run(y), 10)
self.assertEqual(self.evaluate(y), 10)
def testCaptureInCond(self):
g = ops.Graph()
@ -809,8 +809,8 @@ class FunctionTest(test.TestCase):
z = Foo(False)
with self.session(graph=g) as sess:
self.assertEqual(sess.run(y), 1)
self.assertEqual(sess.run(z), 2)
self.assertEqual(self.evaluate(y), 1)
self.assertEqual(self.evaluate(z), 2)
def testStableName(self):
@ -854,7 +854,7 @@ class FunctionTest(test.TestCase):
z = Bar(x)
with self.session(graph=g) as sess:
v0, v1 = sess.run([y, z])
v0, v1 = self.evaluate([y, z])
self.assertAllEqual(v0, 20.)
self.assertAllEqual(v1, 20.)
@ -900,7 +900,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(global_vars[0].name, "linear/w:0")
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
output_val = sess.run(
output_op, feed_dict={input_op: np.random.rand(32, 100)})
self.assertEqual(output_val.shape, (32, 100))
@ -928,7 +928,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(global_vars[0].name, "vs1/var:0")
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
out1, out2 = sess.run(
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
self.assertAllEqual(out1, np.linspace(2, 11, 10))
@ -991,8 +991,8 @@ class FunctionTest(test.TestCase):
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
with session.Session() as sess:
self.assertEqual(4.0, sess.run(result_1))
self.assertEqual(100, sess.run(result_2))
self.assertEqual(4.0, self.evaluate(result_1))
self.assertEqual(100, self.evaluate(result_2))
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
def testStatefulFunction(self):
@ -1052,8 +1052,8 @@ class FunctionTest(test.TestCase):
for config in _OptimizerOptions():
config.device_count["CPU"] = 2
with session.Session(config=config) as sess:
self.assertEqual(42.0, sess.run(f_0))
self.assertEqual(44.0, sess.run(f_1))
self.assertEqual(42.0, self.evaluate(f_0))
self.assertEqual(44.0, self.evaluate(f_1))
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
def testGuaranteedConstsAreCaptured(self):
@ -1076,7 +1076,7 @@ class FunctionTest(test.TestCase):
return output
with self.session(use_gpu=False) as sess:
sess.run(var.initializer)
self.evaluate(var.initializer)
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
def testSameFunctionDifferentGrads(self):
@ -1127,7 +1127,7 @@ class FunctionTest(test.TestCase):
dx2, = gradients_impl.gradients(ys=[y2], xs=[x2])
with self.session(graph=g) as sess:
v0, v1, v2 = sess.run([dx0, dx1, dx2])
v0, v1, v2 = self.evaluate([dx0, dx1, dx2])
self.assertAllEqual(v0, 2.)
self.assertAllEqual(v1, 101.)
@ -1532,7 +1532,7 @@ class UnrollLSTMTest(test.TestCase):
tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
len(str(gdef)), len(gdef.SerializeToString()))
with g.as_default(), session.Session(config=cfg) as sess:
return sess.run(m)
return self.evaluate(m)
mv0 = RunForward("complete")
for cfg in _OptimizerOptions():
@ -1561,7 +1561,7 @@ class UnrollLSTMTest(test.TestCase):
tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
len(str(gdef)), len(gdef.SerializeToString()))
with g.as_default(), session.Session(config=cfg) as sess:
return sess.run(dw)
return self.evaluate(dw)
d0 = RunForwardBackward("complete")
for cfg in _OptimizerOptions():
@ -1651,8 +1651,8 @@ class ModuleFunctionTest(test.TestCase):
y = LinearWithCApi(a, b, c)
z = Linear2WithCApi(a, b, c, d, e)
with session.Session() as sess:
self.assertAllEqual([[1]], sess.run(y))
self.assertAllEqual([[5]], sess.run(z))
self.assertAllEqual([[1]], self.evaluate(y))
self.assertAllEqual([[5]], self.evaluate(z))
class VariableHoistingTest(test.TestCase):
@ -1704,8 +1704,8 @@ class VariableHoistingTest(test.TestCase):
self.assertEqual("Foo/b", b.op.name)
with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
self.evaluate(variables.global_variables_initializer())
w, b, x, y0, loss, dw, db = self.evaluate([w, b, x, y0, loss, dw, db])
self.assertAllEqual(w.shape, (64, 64))
self.assertAllClose(np.sum(w), 2050.44)

View File

@ -210,8 +210,8 @@ class DeviceFunctionsTest(test.TestCase):
with session.Session() as sess:
init = variables.variables_initializer([variable_node])
sess.run(init)
output = sess.run(output_node)
self.evaluate(init)
output = self.evaluate(output_node)
self.assertNear(4.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
@ -242,8 +242,8 @@ class DeviceFunctionsTest(test.TestCase):
output_node = math_ops_lib.multiply(
variable_node, 2.0, name="output_node")
with session.Session() as sess:
sess.run(variable_node.initializer)
output = sess.run(output_node)
self.evaluate(variable_node.initializer)
output = self.evaluate(output_node)
self.assertNear(2.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
# First get the constant_graph_def when variable_names_whitelist is
@ -256,7 +256,7 @@ class DeviceFunctionsTest(test.TestCase):
# Then initialize the unused variable, and get another
# constant_graph_def when variable_names_whitelist is not set.
sess.run(another_variable.initializer)
self.evaluate(another_variable.initializer)
constant_graph_def_without_variable_whitelist = (
graph_util.convert_variables_to_constants(
sess, variable_graph_def, ["output_node"]))
@ -295,7 +295,7 @@ class DeviceFunctionsTest(test.TestCase):
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
with session.Session() as sess:
output_node = sess.graph.get_tensor_by_name("output_node:0")
output = sess.run(output_node)
output = self.evaluate(output_node)
self.assertNear(2.0, output, 0.00001)
def create_node_def(self, op, name, inputs):

View File

@ -397,11 +397,11 @@ class ImportGraphDefTest(test.TestCase):
# Run the imported graph.
# TODO(b/76173421): make this work (currently DCHECKS)
# with self.cached_session() as sess:
# sess.run(imported_init)
# self.assertEqual(sess.run(imported_var), 1.0)
# self.assertEqual(sess.run(imported_assign), 2.0)
# self.assertEqual(list(sess.run(imported_shape)), [])
# self.assertEqual(list(sess.run(new_var_shape)), [])
# self.evaluate(imported_init)
# self.assertEqual(self.evaluate(imported_var), 1.0)
# self.assertEqual(self.evaluate(imported_assign), 2.0)
# self.assertEqual(list(self.evaluate(imported_shape)), [])
# self.assertEqual(list(self.evaluate(new_var_shape)), [])
def testWhileLoop(self):
# Produce GraphDef containing while loop.
@ -418,7 +418,7 @@ class ImportGraphDefTest(test.TestCase):
return_elements=[r.name])
self.assertEqual(imported_r.name, "import/" + r.name)
with self.cached_session() as sess:
self.assertEqual(sess.run(imported_r), 10)
self.assertEqual(self.evaluate(imported_r), 10)
def testImportWhileLoopInCond(self):
# Produce GraphDef containing while loop.
@ -458,7 +458,7 @@ class ImportGraphDefTest(test.TestCase):
lambda i: i < 2, ImportFn, [0],
shape_invariants=[tensor_shape.TensorShape(None)])
with self.cached_session() as sess:
self.assertEqual(sess.run(out), 10)
self.assertEqual(self.evaluate(out), 10)
def testTypeMismatchInGraphDef(self):
# TODO(skyewm): improve error message

View File

@ -492,8 +492,8 @@ class ScopedMetaGraphTest(test.TestCase):
init_op = variables.global_variables_initializer()
grad = gradients_impl.gradients([output], [var])
with session.Session() as sess:
sess.run(init_op)
expected_grad_value = sess.run(grad)
self.evaluate(init_op)
expected_grad_value = self.evaluate(grad)
# Restore the MetaGraphDef into a new Graph with an import scope.
with ops.Graph().as_default():
@ -518,8 +518,8 @@ class ScopedMetaGraphTest(test.TestCase):
init_op = variables.global_variables_initializer()
with session.Session() as sess:
sess.run(init_op)
actual_grad_value = sess.run(grad)
self.evaluate(init_op)
actual_grad_value = self.evaluate(grad)
self.assertEqual(expected_grad_value, actual_grad_value)
def testImportWhileLoopInWhileLoop(self):
@ -544,8 +544,8 @@ class ScopedMetaGraphTest(test.TestCase):
_, x = control_flow_ops.while_loop(lambda i, x: i < 2, body, [0, 0.0],
name="")
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(x)
self.evaluate(variables.global_variables_initializer())
self.evaluate(x)
def testScopedImportUnderNameScope(self):
graph = ops.Graph()
@ -868,8 +868,8 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
_, update_op = metrics.mean(values)
initializer = variables.local_variables_initializer()
sess.run(initializer)
sess.run(update_op)
self.evaluate(initializer)
self.evaluate(update_op)
meta_graph.export_scoped_meta_graph(
filename=meta_graph_filename, graph=graph)
@ -880,7 +880,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
with self.session(graph=graph) as sess:
meta_graph.import_scoped_meta_graph(meta_graph_filename)
initializer = variables.local_variables_initializer()
sess.run(initializer)
self.evaluate(initializer)
# Verifies that importing an old meta_graph where "local_variables"
# collection is of node_list type works, but cannot build initializer

View File

@ -503,7 +503,7 @@ class OperationTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Graph is invalid, contains a cycle with 2 nodes"):
sess.run(x)
self.evaluate(x)
def testUpdateInput(self):
g = ops.Graph()
@ -517,21 +517,21 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEquals(x.consumers(), [])
self.assertEquals(y.consumers(), [z.op, z.op])
with session.Session(graph=g) as sess:
self.assertEquals(sess.run(z), 4)
self.assertEquals(self.evaluate(z), 4)
z.op._update_input(0, x) # pylint: disable=protected-access
self.assertEquals(list(z.op.inputs), [x, y])
self.assertEquals(x.consumers(), [z.op])
self.assertEquals(y.consumers(), [z.op])
with session.Session(graph=g) as sess:
self.assertEquals(sess.run(z), 3)
self.assertEquals(self.evaluate(z), 3)
z.op._update_input(1, y) # pylint: disable=protected-access
self.assertEquals(list(z.op.inputs), [x, y])
self.assertEquals(x.consumers(), [z.op])
self.assertEquals(y.consumers(), [z.op])
with session.Session(graph=g) as sess:
self.assertEquals(sess.run(z), 3)
self.assertEquals(self.evaluate(z), 3)
def testUpdateInputGraphError(self):
g_0 = ops.Graph()
@ -557,7 +557,7 @@ class OperationTest(test_util.TensorFlowTestCase):
errors.InvalidArgumentError,
"Input 0 of node add was passed string from Const_1:0 incompatible "
"with expected int32"):
sess.run(z)
self.evaluate(z)
def testUpdateInputShapeError(self):
g = ops.Graph()
@ -2390,7 +2390,7 @@ class GraphTest(test_util.TensorFlowTestCase):
c = math_ops.add(a, b)
# Create a session we can delete
with session.Session(graph=g) as sess:
sess.run(c)
self.evaluate(c)
# Delete all references and trigger gc
del g
del a
@ -2406,7 +2406,7 @@ class GraphTest(test_util.TensorFlowTestCase):
math_ops.add([1, 2], [1, 2, 3])
a = constant_op.constant(1)
with session.Session() as sess:
sess.run(a)
self.evaluate(a)
def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
g = ops.Graph()
@ -2416,7 +2416,7 @@ class GraphTest(test_util.TensorFlowTestCase):
test_ops.kernel_label_required(1)
a = constant_op.constant(1)
with session.Session() as sess:
sess.run(a)
self.evaluate(a)
class AttrScopeTest(test_util.TensorFlowTestCase):

View File

@ -109,8 +109,8 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
exclusive=True)
with session.Session() as sess:
# No feed_dict necessary
self.assertEqual(sess.run(y), 1)
self.assertEqual(sess.run(z), 1)
self.assertEqual(self.evaluate(y), 1)
self.assertEqual(self.evaluate(z), 1)
def testFalse(self):
conditions = [(False, raise_exception)]
@ -121,8 +121,8 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
default=lambda: constant_op.constant(1),
exclusive=True)
with session.Session() as sess:
self.assertEqual(sess.run(y), 1)
self.assertEqual(sess.run(z), 1)
self.assertEqual(self.evaluate(y), 1)
self.assertEqual(self.evaluate(z), 1)
def testMix(self):
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])

View File

@ -50,7 +50,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
self.assertAllEqual(indices, value.indices)
self.assertAllEqual(values, value.values)
self.assertAllEqual(shape, value.dense_shape)
sess_run_value = sess.run(sp)
sess_run_value = self.evaluate(sp)
self.assertAllEqual(sess_run_value.indices, value.indices)
self.assertAllEqual(sess_run_value.values, value.values)
self.assertAllEqual(sess_run_value.dense_shape, value.dense_shape)

View File

@ -66,9 +66,9 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertTrue(c.op in d.op.control_inputs)
with self.cached_session() as sess:
c_out = sess.run([c])
n_out = sess.run([n])
d_out = sess.run([d])
c_out = self.evaluate([c])
n_out = self.evaluate([n])
d_out = self.evaluate([d])
self.assertEqual(n_out, [-2])
self.assertEqual(c_out, [2])
@ -145,8 +145,8 @@ class SubscribeTest(test_util.TensorFlowTestCase):
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
with self.cached_session() as sess:
c_out = sess.run([c])
d_out = sess.run([d])
c_out = self.evaluate([c])
d_out = self.evaluate([d])
self.assertEqual(c_out, [42])
self.assertEqual(d_out, [11])
@ -205,7 +205,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
# Expect the three side effect graphs to have been evaluated.
with self.cached_session() as sess:
sess.run([c_sub])
self.evaluate([c_sub])
self.assertIn('graph1', shared)
self.assertIn('graph2', shared)
self.assertIn('graph3', shared)
@ -229,20 +229,20 @@ class SubscribeTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess:
# Initialize the variables first.
sess.run([v1.initializer])
sess.run([v2.initializer])
self.evaluate([v1.initializer])
self.evaluate([v2.initializer])
# Expect the side effects to be triggered when evaluating the add op as
# it will read the value of the variable.
sess.run([add])
self.evaluate([add])
self.assertEqual(1, len(shared))
# Expect the side effect not to be triggered when evaluating the assign
# op as it will not access the 'read' output of the variable.
sess.run([assign_v1])
self.evaluate([assign_v1])
self.assertEqual(1, len(shared))
sess.run([add])
self.evaluate([add])
self.assertEqual(2, len(shared))
# Make sure the values read from the variable match the expected ones.
@ -273,7 +273,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))
with self.cached_session() as sess:
sess.run([reader])
self.evaluate([reader])
self.assertEqual(0, len(shared))
def testMultipleOutputs(self):
@ -304,7 +304,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
with self.cached_session() as sess:
sess.run([neg])
self.evaluate([neg])
# All three ops have been processed.
self.assertEqual(3, len(shared))
@ -375,7 +375,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIsNot(context(subscriptions[0]), context(subscriptions[1]))
with self.cached_session() as sess:
sess.run(cond)
self.evaluate(cond)
self.assertEqual(3, len(results))

View File

@ -771,7 +771,7 @@ class TensorUtilTest(test.TestCase):
with self.cached_session() as sess:
ma = MockArray(np.array([10, 20, 30]))
t = ops.convert_to_tensor(ma)
a = sess.run(t)
a = self.evaluate(t)
self.assertEquals(np.int64, a.dtype)
self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a)

View File

@ -61,7 +61,7 @@ class ConstantFoldingTest(test.TestCase):
back_prop=False,
parallel_iterations=1)
with session.Session() as sess:
y_v = sess.run(y)
y_v = self.evaluate(y)
self.assertAllEqual(np.zeros([10, 20, 30]), y_v)

View File

@ -241,7 +241,7 @@ class LayoutOptimizerTest(test.TestCase):
if restore:
saver.restore(sess, checkpoint_path)
else:
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
np.random.seed(0)
for _ in range(2):
@ -262,7 +262,7 @@ class LayoutOptimizerTest(test.TestCase):
output = _two_layer_model(x)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -365,7 +365,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(pad)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -396,7 +396,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -425,7 +425,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(cast)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -456,7 +456,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(squeeze)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -486,7 +486,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(squeeze)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -516,7 +516,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(squeeze)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -545,7 +545,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -574,7 +574,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -603,7 +603,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -632,7 +632,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -662,7 +662,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -691,7 +691,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(reduce_sum)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -724,7 +724,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(concat)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -835,7 +835,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(reverse)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -905,7 +905,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(select)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -966,7 +966,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(select)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -1179,7 +1179,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(s)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -1214,7 +1214,7 @@ class LayoutOptimizerTest(test.TestCase):
output = array_ops.identity(s)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -1347,7 +1347,7 @@ class LayoutOptimizerTest(test.TestCase):
output = _loop()
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -1374,7 +1374,7 @@ class LayoutOptimizerTest(test.TestCase):
output = _loop_with_branch()
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -1398,7 +1398,7 @@ class LayoutOptimizerTest(test.TestCase):
output = _loop_with_vec_and_4d()
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
@ -1422,7 +1422,7 @@ class LayoutOptimizerTest(test.TestCase):
output = _model_with_second_port()
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()

View File

@ -231,10 +231,10 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
train_op = graph.get_operation_by_name(train_op_name)
loss_op = graph.get_tensor_by_name(loss_op_name)
with session.Session(config=config, graph=graph) as sess:
sess.run(init_op)
sess.run(train_op)
sess.run(train_op)
return sess.run(loss_op)
self.evaluate(init_op)
self.evaluate(train_op)
self.evaluate(train_op)
return self.evaluate(loss_op)
def testRecomputationRewritingNoErrors(self):
"""Tests that graph output is not significantly different with rewriting."""
@ -295,8 +295,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
rewrite_options=manual_memory_config)
session_config = config_pb2.ConfigProto(graph_options=graph_options)
with session.Session(config=session_config) as sess:
sess.run(init_op)
sess.run(train_op)
self.evaluate(init_op)
self.evaluate(train_op)
def testHintDoesRewrite(self):
graph = self._annotated_graph()[0]

View File

@ -136,7 +136,7 @@ class BackendUtilsTest(test.TestCase):
x = keras.Input((3,))
y = keras.layers.BatchNormalization()(x)
if not context.executing_eagerly():
sess.run(variables.global_variables_initializer())
self.evaluate(variables.global_variables_initializer())
sess.run(y, feed_dict={x: np.random.random((2, 3))})
def test_learning_phase_scope(self):

View File

@ -1013,8 +1013,8 @@ class RNNTest(test.TestCase):
inputs, _ = cell(inputs, initial_state)
output = inputs
if not context.executing_eagerly():
sess.run(variables_lib.global_variables_initializer())
output = sess.run(output)
self.evaluate(variables_lib.global_variables_initializer())
output = self.evaluate(output)
return output
random_seed.set_random_seed(12345)

View File

@ -322,19 +322,19 @@ class KerasMetricsTest(test.TestCase):
m = metrics.Mean()
v = array_ops.placeholder(dtypes.float32)
w = array_ops.placeholder(dtypes.float32)
sess.run(variables.variables_initializer(m.variables))
self.evaluate(variables.variables_initializer(m.variables))
# check __call__()
result_t = m(v, sample_weight=w)
result = sess.run(result_t, feed_dict=({v: 100, w: 0.5}))
self.assertEqual(sess.run(m.total), 50)
self.assertEqual(sess.run(m.count), 0.5)
self.assertEqual(self.evaluate(m.total), 50)
self.assertEqual(self.evaluate(m.count), 0.5)
self.assertEqual(result, 50 / 0.5)
# check update_state() and result()
result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]}))
self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2
self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2
self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2
self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2
self.assertAlmostEqual(result, 52 / 1.7, 2)
@test_util.run_in_graph_and_eager_modes

Some files were not shown because too many files have changed in this diff Show More