Replace a few calls of Session run
with evaluate
In order to support tests running in eager mode we need to avoid unnecessary use of Sessions in tests. This moves to remove some of the uses of the `run` function in favor of `evaluate`. PiperOrigin-RevId: 223009795
This commit is contained in:
parent
edf88fcda8
commit
b17d53c0cd
@ -57,11 +57,11 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
Returns:
|
||||
Frequencies from sampled classes; shape [batch_size, num_classes].
|
||||
"""
|
||||
with self.cached_session() as sess, self.test_scope():
|
||||
with self.cached_session(), self.test_scope():
|
||||
random_seed.set_random_seed(1618)
|
||||
op = random_ops.multinomial(logits, num_samples,
|
||||
output_dtype=dtypes.int32)
|
||||
d = sess.run(op)
|
||||
d = self.evaluate(op)
|
||||
|
||||
batch_size, num_classes = logits.shape
|
||||
freqs_mat = []
|
||||
@ -80,15 +80,15 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
|
||||
def _testRngIsNotConstant(self, rng, dtype, output_dtype):
|
||||
# Tests that 'rng' does not always return the same value.
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
with self.test_scope():
|
||||
x = rng(dtype, output_dtype)
|
||||
|
||||
# The random-number generator, if working correctly, should produce the
|
||||
# same output multiple times with low probability.
|
||||
y = sess.run(x)
|
||||
z = sess.run(x)
|
||||
w = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
z = self.evaluate(x)
|
||||
w = self.evaluate(x)
|
||||
|
||||
# We use exact equality here. If the random-number generator is producing
|
||||
# deterministic output, all three outputs will be bitwise identical.
|
||||
@ -108,12 +108,12 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
def testCategoricalIsInRange(self):
|
||||
for dtype in self.float_types:
|
||||
for output_dtype in self.output_dtypes():
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
with self.test_scope():
|
||||
x = random_ops.multinomial(
|
||||
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
|
||||
output_dtype=output_dtype)
|
||||
y = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
self.assertTrue((y >= 0).sum() == 1000)
|
||||
self.assertTrue((y < 20).sum() == 1000)
|
||||
|
||||
@ -170,11 +170,11 @@ class CategoricalTest(xla_test.XLATestCase):
|
||||
self.assertEqual(s0 == s1, np.all(v0 == v1))
|
||||
|
||||
def testEmpty(self):
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
with self.test_scope():
|
||||
x = random_ops.multinomial(
|
||||
array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
|
||||
y = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
self.assertEqual(y.shape, (42, 0))
|
||||
|
||||
def testEmptyStateless(self):
|
||||
|
@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase):
|
||||
def DISABLED_testZeroSize(self):
|
||||
# Verify that concat doesn't crash and burn for zero size inputs
|
||||
np.random.seed(7)
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
with self.test_scope():
|
||||
for shape0 in (), (2,):
|
||||
axis = len(shape0)
|
||||
@ -270,7 +270,7 @@ class ConcatTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual(c.eval(), correct)
|
||||
# Check gradients
|
||||
dc = np.random.randn(*c.get_shape().as_list())
|
||||
dxs = sess.run(gradients_impl.gradients(c, xs, dc))
|
||||
dxs = self.evaluate(gradients_impl.gradients(c, xs, dc))
|
||||
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
|
||||
|
||||
def testConcatTuple(self):
|
||||
@ -330,47 +330,47 @@ class ConcatTest(xla_test.XLATestCase):
|
||||
class ConcatOffsetTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
with self.test_scope():
|
||||
cdim = constant_op.constant(1, dtypes.int32)
|
||||
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
|
||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
|
||||
ans = sess.run(off)
|
||||
ans = self.evaluate(off)
|
||||
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
|
||||
|
||||
|
||||
class PackTest(xla_test.XLATestCase):
|
||||
|
||||
def testBasic(self):
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
with self.test_scope():
|
||||
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
|
||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
ans = self.evaluate(packed)
|
||||
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
|
||||
|
||||
def testScalars(self):
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
with self.test_scope():
|
||||
s0 = constant_op.constant(2, dtypes.int32)
|
||||
s1 = constant_op.constant(3, dtypes.int32)
|
||||
s2 = constant_op.constant(5, dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
ans = self.evaluate(packed)
|
||||
self.assertAllEqual(ans, [2, 3, 5])
|
||||
|
||||
def testEmpty(self):
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
with self.test_scope():
|
||||
s0 = constant_op.constant([[]], dtypes.int32)
|
||||
s1 = constant_op.constant([[]], dtypes.int32)
|
||||
s2 = constant_op.constant([[]], dtypes.int32)
|
||||
packed = array_ops.stack([s0, s1, s2])
|
||||
ans = sess.run(packed)
|
||||
ans = self.evaluate(packed)
|
||||
self.assertAllEqual(ans, [[[]], [[]], [[]]])
|
||||
|
||||
|
||||
|
@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase):
|
||||
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
|
||||
y = layers.dense(x, 3)
|
||||
|
||||
sess.run(variables.initialize_all_variables())
|
||||
self.evaluate(variables.initialize_all_variables())
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
test_utils.RunWithWarmup(
|
||||
sess,
|
||||
@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase):
|
||||
with jit_scope():
|
||||
y = layers.dense(x, 3)
|
||||
|
||||
sess.run(variables.initialize_all_variables())
|
||||
self.evaluate(variables.initialize_all_variables())
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
test_utils.RunWithWarmup(
|
||||
sess,
|
||||
@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase):
|
||||
with jit_scope():
|
||||
y = layers.dense(x, 3)
|
||||
|
||||
sess.run(variables.initialize_all_variables())
|
||||
self.evaluate(variables.initialize_all_variables())
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
test_utils.RunWithWarmup(
|
||||
sess,
|
||||
|
@ -101,12 +101,12 @@ class EagerTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual(15, product)
|
||||
|
||||
# Run some ops graphly
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
with context.graph_mode(), self.cached_session():
|
||||
with self.test_scope():
|
||||
three = constant_op.constant(3)
|
||||
five = constant_op.constant(5)
|
||||
product = three * five
|
||||
self.assertAllEqual(15, sess.run(product))
|
||||
self.assertAllEqual(15, self.evaluate(product))
|
||||
|
||||
def testDegenerateSlices(self):
|
||||
with self.test_scope():
|
||||
|
@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
|
||||
expected = APlus2B(aval, bval)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def Foo(a, b):
|
||||
@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_f = Foo(a, b)
|
||||
result = sess.run(call_f)
|
||||
result = self.evaluate(call_f)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testNestedFunctions(self):
|
||||
@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
|
||||
expected = APlus2B(aval, bval)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def Foo(a, b):
|
||||
@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_g = Foo(a, b)
|
||||
result = sess.run(call_g)
|
||||
result = self.evaluate(call_g)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testFunctionMultipleRetvals(self):
|
||||
@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
|
||||
expected = Func(aval, bval)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
|
||||
@function.Defun(dtypes.float32, dtypes.float32)
|
||||
def Foo(a, b):
|
||||
@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
|
||||
b = constant_op.constant(bval, name="b")
|
||||
with self.test_scope():
|
||||
call_f = Foo(a, b)
|
||||
result = sess.run(call_f)
|
||||
result = self.evaluate(call_f)
|
||||
self.assertAllClose(result, expected, rtol=1e-3)
|
||||
|
||||
def testCompileTimeConstantsInDefun(self):
|
||||
|
@ -33,13 +33,13 @@ class ListDiffTest(xla_test.XLATestCase):
|
||||
def _testListDiff(self, x, y, out, idx):
|
||||
for dtype in [dtypes.int32, dtypes.int64]:
|
||||
for index_dtype in [dtypes.int32, dtypes.int64]:
|
||||
with self.cached_session() as sess:
|
||||
with self.cached_session():
|
||||
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
|
||||
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
|
||||
with self.test_scope():
|
||||
out_tensor, idx_tensor = array_ops.listdiff(
|
||||
x_tensor, y_tensor, out_idx=index_dtype)
|
||||
tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
|
||||
tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor])
|
||||
self.assertAllEqual(out, tf_out)
|
||||
self.assertAllEqual(idx, tf_idx)
|
||||
self.assertEqual(1, out_tensor.get_shape().ndims)
|
||||
|
@ -88,8 +88,8 @@ class LSTMTest(test.TestCase):
|
||||
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
|
||||
|
||||
# Initialize variables and run the unrolled LSTM step.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
return sess.run([m, c])
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
return self.evaluate([m, c])
|
||||
|
||||
def testLSTMCell(self):
|
||||
# Run with all-0 weights, no padding.
|
||||
@ -173,8 +173,8 @@ class LSTMTest(test.TestCase):
|
||||
(basename, m_init_scalar, c_init_scalar, pad_scalar))
|
||||
|
||||
# Initialize variables and run the unrolled LSTM layer.
|
||||
sess.run(variables.global_variables_initializer())
|
||||
return sess.run(out_seq)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
return self.evaluate(out_seq)
|
||||
|
||||
def testLSTMLayer(self):
|
||||
# Run with all-0 weights, no padding.
|
||||
|
@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase):
|
||||
ph = array_ops.placeholder_with_default(v, shape=[])
|
||||
out = ph * 2
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
self.assertEqual(8.0, sess.run(out))
|
||||
self.assertEqual(8.0, self.evaluate(out))
|
||||
|
||||
def test_placeholder_with_default_fed(self):
|
||||
with self.cached_session() as sess, self.test_scope():
|
||||
|
@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
|
||||
# The random-number generator, if working correctly, should produce the
|
||||
# same output multiple times with low probability.
|
||||
y = sess.run(x)
|
||||
z = sess.run(x)
|
||||
w = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
z = self.evaluate(x)
|
||||
w = self.evaluate(x)
|
||||
|
||||
# We use exact equality here. If the random-number generator is producing
|
||||
# deterministic output, all three outputs will be bitwise identical.
|
||||
@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = random_ops.random_uniform(
|
||||
shape=[1000], dtype=dtype, minval=-2, maxval=33)
|
||||
y = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
self.assertTrue((y >= -2).sum() == 1000)
|
||||
self.assertTrue((y < 33).sum() == 1000)
|
||||
|
||||
@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.cached_session() as sess:
|
||||
with self.test_scope():
|
||||
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
||||
y = sess.run(x)
|
||||
y = self.evaluate(x)
|
||||
|
||||
def normal_cdf(x):
|
||||
return .5 * math.erfc(-x / math.sqrt(2))
|
||||
@ -111,7 +111,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
|
||||
|
||||
def probit(x, sess=sess):
|
||||
return sess.run(special_math.ndtri(x))
|
||||
return self.evaluate(special_math.ndtri(x))
|
||||
|
||||
a = -2.
|
||||
b = 2.
|
||||
@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = math_ops.range(1 << 16)
|
||||
shuffle = random_ops.random_shuffle(x)
|
||||
result = sess.run(shuffle)
|
||||
result = self.evaluate(shuffle)
|
||||
expected = range(1 << 16)
|
||||
# Compare sets to avoid randomness behavior changes but make sure still
|
||||
# have all the values.
|
||||
@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
with self.test_scope():
|
||||
x = array_ops.diag(math_ops.range(20))
|
||||
shuffle = random_ops.random_shuffle(x)
|
||||
result = sess.run(shuffle)
|
||||
result = self.evaluate(shuffle)
|
||||
expected = np.diag(range(20)).flatten()
|
||||
# Compare sets to avoid randomness behavior changes but make sure still
|
||||
# have all the values.
|
||||
|
@ -156,7 +156,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
||||
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
|
||||
|
||||
def probit(x, sess=sess):
|
||||
return sess.run(special_math.ndtri(x))
|
||||
return self.evaluate(special_math.ndtri(x))
|
||||
|
||||
a = -2.
|
||||
b = 2.
|
||||
|
@ -505,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase):
|
||||
[-0.5, 1.5], # read(0) gradient
|
||||
[20.0, 30.0, 40.0, 50.0], # concat gradient
|
||||
])
|
||||
grad_vals = sess.run(grad_r) # 2 + 2 entries
|
||||
grad_vals = self.evaluate(grad_r) # 2 + 2 entries
|
||||
|
||||
self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
|
||||
self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
|
||||
|
@ -77,7 +77,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
x = v.sparse_read(2)
|
||||
self.assertAllClose(
|
||||
np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x))
|
||||
np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x))
|
||||
|
||||
def testSparseRead1DIndices(self):
|
||||
for dtype in self.numeric_types:
|
||||
@ -89,7 +89,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
x = v.sparse_read([2, 1])
|
||||
self.assertAllClose(
|
||||
np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
|
||||
sess.run(x))
|
||||
self.evaluate(x))
|
||||
|
||||
def testSparseRead2DIndices(self):
|
||||
for dtype in self.numeric_types:
|
||||
@ -102,7 +102,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
self.assertAllClose(
|
||||
np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
|
||||
[[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
|
||||
sess.run(x))
|
||||
self.evaluate(x))
|
||||
|
||||
def testSparseRead2DIndices3DTensor(self):
|
||||
for dtype in self.numeric_types:
|
||||
@ -115,9 +115,9 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
x = v.sparse_read([[2, 1], [3, 0]])
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]
|
||||
], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
|
||||
],).astype(dtype), sess.run(x))
|
||||
[[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
|
||||
[[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
|
||||
],).astype(dtype), self.evaluate(x))
|
||||
|
||||
def testShape(self):
|
||||
for dtype in self.numeric_types:
|
||||
@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_add(
|
||||
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(sess.run(read), [[3], [7]])
|
||||
self.assertAllEqual(self.evaluate(read), [[3], [7]])
|
||||
|
||||
def testScatterSub(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_sub(
|
||||
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(sess.run(read), [[4], [-1]])
|
||||
self.assertAllEqual(self.evaluate(read), [[4], [-1]])
|
||||
|
||||
def testScatterMul(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_mul(
|
||||
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[5]])
|
||||
self.assertEqual(self.evaluate(read), [[5]])
|
||||
|
||||
def testScatterDiv(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_div(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertAllEqual(sess.run(read), [[2]])
|
||||
self.assertAllEqual(self.evaluate(read), [[2]])
|
||||
|
||||
def testScatterMin(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_min(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
def testScatterMax(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_max(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[6]])
|
||||
self.assertEqual(self.evaluate(read), [[6]])
|
||||
|
||||
def testScatterUpdate(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_update(
|
||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
def testScatterAddScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_add(
|
||||
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
def testScatterSubScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_sub(
|
||||
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[-1]])
|
||||
self.assertEqual(self.evaluate(read), [[-1]])
|
||||
|
||||
def testScatterMulScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_mul(
|
||||
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[5]])
|
||||
self.assertEqual(self.evaluate(read), [[5]])
|
||||
|
||||
def testScatterDivScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_div(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[2]])
|
||||
self.assertEqual(self.evaluate(read), [[2]])
|
||||
|
||||
def testScatterMinScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_min(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[3]])
|
||||
self.assertEqual(self.evaluate(read), [[3]])
|
||||
|
||||
def testScatterMaxScalar(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
resource_variable_ops.resource_scatter_max(
|
||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||
self.assertEqual(sess.run(read), [[6]])
|
||||
self.assertEqual(self.evaluate(read), [[6]])
|
||||
|
||||
def testScatterNdAddOps(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
|
||||
read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.float32)
|
||||
self.assertAllClose(expected, sess.run(read))
|
||||
self.assertAllClose(expected, self.evaluate(read))
|
||||
|
||||
def testScatterNdUpdateAddOps(self):
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
||||
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
|
||||
read = resource_variable_ops.read_variable_op(
|
||||
handle, dtype=dtypes.float32)
|
||||
self.assertAllClose(expected, sess.run(read))
|
||||
self.assertAllClose(expected, self.evaluate(read))
|
||||
|
||||
|
||||
class StridedSliceAssignChecker(object):
|
||||
|
@ -81,7 +81,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
|
||||
with self.cached_session() as sess:
|
||||
with self.test_scope():
|
||||
x = gen_control_flow_ops.control_trigger()
|
||||
sess.run(x)
|
||||
self.evaluate(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -93,10 +93,10 @@ class KerasTest(tf.test.TestCase):
|
||||
init = tf.global_variables_initializer()
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(init)
|
||||
self.evaluate(init)
|
||||
sample_input = tf.random_uniform((1, 10, 10, 1))
|
||||
output = model(sample_input) # pylint: disable=not-callable
|
||||
self.assertEqual(sess.run(output).shape, (1, 3))
|
||||
self.assertEqual(self.evaluate(output).shape, (1, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,7 +34,7 @@ class ListLiteralsTest(tf.test.TestCase):
|
||||
result = converted()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(result), [1, 2, 3])
|
||||
self.assertAllEqual(self.evaluate(result), [1, 2, 3])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -35,7 +35,7 @@ class InputDataTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([32000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -33,7 +33,7 @@ class LabelWavTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([1000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -33,7 +33,7 @@ class WavToFeaturesTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sample_data = tf.zeros([32000, 2])
|
||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||
wav_data = sess.run(wav_encoder)
|
||||
wav_data = self.evaluate(wav_encoder)
|
||||
return wav_data
|
||||
|
||||
def _saveTestWavFile(self, filename, wav_data):
|
||||
|
@ -111,7 +111,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Initialize variables
|
||||
init = tf.global_variables_initializer()
|
||||
sess.run(init)
|
||||
self.evaluate(init)
|
||||
for _ in range(TRAIN_STEPS):
|
||||
batch_x, batch_y = self.mnist.train.next_batch(
|
||||
batch_size=self.batch_size, shuffle=False)
|
||||
|
@ -41,7 +41,7 @@ class AssertsTest(converter_testing.TestCase):
|
||||
op = result.test_fn(constant_op.constant(False))
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'test message'):
|
||||
sess.run(op)
|
||||
self.evaluate(op)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -94,7 +94,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
dtypes.int64) as result:
|
||||
with self.cached_session() as sess:
|
||||
self.assertTrue(isinstance(result.test_fn(), ops.Tensor))
|
||||
self.assertIn(sess.run(result.test_fn()), (0, 1, 2))
|
||||
self.assertIn(self.evaluate(result.test_fn()), (0, 1, 2))
|
||||
|
||||
def test_uncompiled_modules(self):
|
||||
|
||||
@ -113,7 +113,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
with self.compiled(node, ns) as result:
|
||||
with self.cached_session() as sess:
|
||||
result_tensor = result.test_fn(constant_op.constant(1))
|
||||
self.assertEquals(sess.run(result_tensor), 3)
|
||||
self.assertEquals(self.evaluate(result_tensor), 3)
|
||||
|
||||
def test_call_to_decorated_function(self):
|
||||
|
||||
|
@ -68,7 +68,7 @@ class ListTest(converter_testing.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
tl = result.test_fn()
|
||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||
self.assertAllEqual(sess.run(r), [1, 2, 3])
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
|
||||
|
||||
def test_list_pop(self):
|
||||
|
||||
@ -91,8 +91,8 @@ class ListTest(converter_testing.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
ts, tl = result.test_fn()
|
||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||
self.assertAllEqual(sess.run(r), [1, 2])
|
||||
self.assertAllEqual(sess.run(ts), 3)
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||
self.assertAllEqual(self.evaluate(ts), 3)
|
||||
|
||||
def test_double_list_pop(self):
|
||||
|
||||
@ -123,7 +123,7 @@ class ListTest(converter_testing.TestCase):
|
||||
|
||||
with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
|
||||
self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
|
||||
|
||||
# TODO(mdan): Add a test with tf.stack with axis kwarg.
|
||||
|
||||
|
@ -48,12 +48,12 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
self.evaluate(v.initializer)
|
||||
self.evaluate(result.test_fn(v))
|
||||
# TODO(mdan): Add support for this use case.
|
||||
# Right now the variable `a` is not conditioned on the `assign` because
|
||||
# there's no way to add control dependencies to a variable object.
|
||||
self.assertEqual(2, sess.run(v))
|
||||
self.assertEqual(2, self.evaluate(v))
|
||||
|
||||
def test_side_effect_on_used_variable(self):
|
||||
|
||||
@ -69,11 +69,11 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
self.evaluate(v.initializer)
|
||||
self.evaluate(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
# Right now it's 3 or 4 based on whether the read is synchronized.
|
||||
self.assertEqual(3, sess.run(v))
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
|
||||
def test_side_effect_on_tensor(self):
|
||||
|
||||
@ -109,10 +109,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign_add) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
self.evaluate(v.initializer)
|
||||
self.evaluate(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(4, sess.run(v))
|
||||
self.assertEqual(4, self.evaluate(v))
|
||||
|
||||
def test_multiline_nested_block(self):
|
||||
|
||||
@ -130,10 +130,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
self.evaluate(v.initializer)
|
||||
self.evaluate(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(3, sess.run(v))
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
|
||||
def test_multiline_block_unsafe(self):
|
||||
|
||||
@ -153,10 +153,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
state_ops.assign_add) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
sess.run(v.initializer)
|
||||
sess.run(result.test_fn(v))
|
||||
self.evaluate(v.initializer)
|
||||
self.evaluate(result.test_fn(v))
|
||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||
self.assertEqual(4, sess.run(v))
|
||||
self.assertEqual(4, self.evaluate(v))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -49,7 +49,7 @@ class SliceTest(converter_testing.TestCase):
|
||||
tl = list_ops.tensor_list_from_tensor(
|
||||
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
|
||||
y = result.test_fn(tl)
|
||||
self.assertEqual(2, sess.run(y))
|
||||
self.assertEqual(2, self.evaluate(y))
|
||||
|
||||
def test_index_access_multiple_definitions(self):
|
||||
|
||||
|
@ -55,7 +55,7 @@ class RuntimeErrorsTest(test.TestCase):
|
||||
with self.assertRaises(errors.TfRuntimeError) as cm:
|
||||
with errors.improved_errors(zero_div_caller):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(ops)
|
||||
self.evaluate(ops)
|
||||
|
||||
for frame in cm.exception.custom_traceback:
|
||||
_, _, function_name, _ = frame
|
||||
@ -70,7 +70,7 @@ class RuntimeErrorsTest(test.TestCase):
|
||||
with self.assertRaises(errors.TfRuntimeError) as cm:
|
||||
with errors.improved_errors(zero_div_caller):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(ops)
|
||||
self.evaluate(ops)
|
||||
|
||||
all_function_names = set()
|
||||
for frame in cm.exception.custom_traceback:
|
||||
@ -87,7 +87,7 @@ class RuntimeErrorsTest(test.TestCase):
|
||||
with self.assertRaises(tf_errors.InvalidArgumentError):
|
||||
with errors.improved_errors(zero_div_caller):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(ops)
|
||||
self.evaluate(ops)
|
||||
|
||||
def test_improved_errors_validation(self):
|
||||
with self.assertRaisesRegexp(
|
||||
|
@ -63,7 +63,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_does_not_recurse(self):
|
||||
|
||||
@ -83,7 +83,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_calls_unconverted_graph(self):
|
||||
|
||||
@ -104,7 +104,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_calls_unconverted_py_func(self):
|
||||
|
||||
@ -130,7 +130,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_calls_decorated(self):
|
||||
|
||||
@ -153,7 +153,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_decorator_preserves_argspec(self):
|
||||
|
||||
@ -192,7 +192,7 @@ class ApiTest(test.TestCase):
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
def test_converted_call_builtin(self):
|
||||
x = api.converted_call(range, None, converter.ConversionOptions(), 3)
|
||||
@ -208,7 +208,7 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
x = api.converted_call(test_fn, None, converter.ConversionOptions(),
|
||||
constant_op.constant(-1))
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_method_explicit_owner(self):
|
||||
# TODO(mdan): Implement.
|
||||
@ -234,7 +234,7 @@ class ApiTest(test.TestCase):
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(tc.test_method, None,
|
||||
converter.ConversionOptions(), tc)
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_method_by_class(self):
|
||||
|
||||
@ -252,7 +252,7 @@ class ApiTest(test.TestCase):
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(TestClass.test_method, None,
|
||||
converter.ConversionOptions(), tc)
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_callable_object(self):
|
||||
|
||||
@ -269,7 +269,7 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(tc, None, converter.ConversionOptions())
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_constructor(self):
|
||||
|
||||
@ -288,7 +288,7 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant(-1))
|
||||
# tc is now a converted object.
|
||||
x = tc.test_method()
|
||||
self.assertEqual(1, sess.run(x))
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_already_converted(self):
|
||||
|
||||
@ -298,12 +298,12 @@ class ApiTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
x = api.converted_call(f, None, converter.ConversionOptions(),
|
||||
constant_op.constant(0))
|
||||
self.assertTrue(sess.run(x))
|
||||
self.assertTrue(self.evaluate(x))
|
||||
|
||||
converted_f = api.to_graph(f)
|
||||
x = api.converted_call(converted_f, None, converter.ConversionOptions(),
|
||||
constant_op.constant(0))
|
||||
self.assertTrue(sess.run(x))
|
||||
self.assertTrue(self.evaluate(x))
|
||||
|
||||
def test_converted_call_no_user_code(self):
|
||||
|
||||
@ -334,8 +334,8 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
|
||||
def test_converted_call_whitelisted_method_extra_self(self):
|
||||
|
||||
@ -349,8 +349,8 @@ class ApiTest(test.TestCase):
|
||||
model, constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
|
||||
def test_converted_call_whitelisted_method_via_owner(self):
|
||||
|
||||
@ -364,8 +364,8 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant([[0.0]]), training=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
|
||||
def test_converted_call_lambda(self):
|
||||
|
||||
@ -376,8 +376,8 @@ class ApiTest(test.TestCase):
|
||||
x = api.converted_call(l, None, opts, constant_op.constant(0))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.assertAllEqual(True, sess.run(x))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual(True, self.evaluate(x))
|
||||
|
||||
def test_to_graph_basic(self):
|
||||
|
||||
@ -390,7 +390,7 @@ class ApiTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||
|
||||
def test_to_graph_with_defaults(self):
|
||||
|
||||
@ -405,7 +405,7 @@ class ApiTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
x = compiled_fn(constant_op.constant([4, 8]))
|
||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
||||
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||
|
||||
def test_to_code_basic(self):
|
||||
|
||||
|
@ -36,7 +36,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
python_one = special_functions.match_staging_level(1, 1)
|
||||
with self.cached_session() as sess:
|
||||
self.assertTrue(tensor_util.is_tensor(tensor_one))
|
||||
self.assertAllEqual(sess.run(tensor_one), 1)
|
||||
self.assertAllEqual(self.evaluate(tensor_one), 1)
|
||||
self.assertEqual(python_one, 1)
|
||||
|
||||
def test_tensor_list_empty_list(self):
|
||||
@ -45,21 +45,21 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
|
||||
l = special_functions.tensor_list((),
|
||||
element_dtype=dtypes.int32,
|
||||
element_shape=())
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
|
||||
def test_tensor_list_tensor(self):
|
||||
l = special_functions.tensor_list(
|
||||
constant_op.constant([], dtype=dtypes.int32))
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [])
|
||||
self.assertAllEqual(self.evaluate(sl), [])
|
||||
|
||||
def test_tensor_list_unsupported_initializer(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown type'):
|
||||
@ -76,7 +76,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
l = special_functions.tensor_list(elements)
|
||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_tensor_list_array_from_elements(self):
|
||||
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
|
||||
@ -84,7 +84,7 @@ class SpecialFunctionsTest(test.TestCase):
|
||||
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
||||
sl = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
||||
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
|
||||
|
||||
def test_stack(self):
|
||||
self.assertEqual(special_functions.stack(1, strict=False), 1)
|
||||
|
@ -35,7 +35,7 @@ class ForLoopTest(test.TestCase):
|
||||
body=lambda i, s: (s + i,),
|
||||
init_state=(0,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((10,), sess.run(s))
|
||||
self.assertEqual((10,), self.evaluate(s))
|
||||
|
||||
def test_python(self):
|
||||
s = control_flow.for_stmt(
|
||||
@ -53,7 +53,7 @@ class ForLoopTest(test.TestCase):
|
||||
body=lambda i, s: (s + i,),
|
||||
init_state=(0,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((10,), sess.run(s))
|
||||
self.assertEqual((10,), self.evaluate(s))
|
||||
|
||||
|
||||
class WhileLoopTest(test.TestCase):
|
||||
@ -66,7 +66,7 @@ class WhileLoopTest(test.TestCase):
|
||||
init_state=(0, 0),
|
||||
extra_deps=(n,))
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual((5, 10), sess.run(results))
|
||||
self.assertEqual((5, 10), self.evaluate(results))
|
||||
|
||||
def test_python(self):
|
||||
n = 5
|
||||
@ -90,9 +90,9 @@ class IfStmtTest(test.TestCase):
|
||||
def test_tensor(self):
|
||||
with self.cached_session() as sess:
|
||||
t = self.single_return_if_stmt(constant_op.constant(True))
|
||||
self.assertEqual(1, sess.run(t))
|
||||
self.assertEqual(1, self.evaluate(t))
|
||||
t = self.single_return_if_stmt(constant_op.constant(False))
|
||||
self.assertEqual(-1, sess.run(t))
|
||||
self.assertEqual(-1, self.evaluate(t))
|
||||
|
||||
def test_python(self):
|
||||
self.assertEqual(1, self.single_return_if_stmt(True))
|
||||
@ -101,9 +101,9 @@ class IfStmtTest(test.TestCase):
|
||||
def test_tensor_multiple_returns(self):
|
||||
with self.cached_session() as sess:
|
||||
t = self.multi_return_if_stmt(constant_op.constant(True))
|
||||
self.assertAllEqual([1, 2], sess.run(t))
|
||||
self.assertAllEqual([1, 2], self.evaluate(t))
|
||||
t = self.multi_return_if_stmt(constant_op.constant(False))
|
||||
self.assertAllEqual([-1, -2], sess.run(t))
|
||||
self.assertAllEqual([-1, -2], self.evaluate(t))
|
||||
|
||||
def test_python_multiple_returns(self):
|
||||
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
|
||||
|
@ -43,7 +43,7 @@ class ListTest(test.TestCase):
|
||||
l = data_structures.tf_tensor_list_new([3, 4, 5])
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_list_new_empty(self):
|
||||
l = data_structures.tf_tensor_list_new([],
|
||||
@ -51,13 +51,13 @@ class ListTest(test.TestCase):
|
||||
element_shape=())
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [])
|
||||
self.assertAllEqual(self.evaluate(t), [])
|
||||
|
||||
def test_tf_tensor_list_new_from_tensor(self):
|
||||
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_list_new_illegal_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -77,7 +77,7 @@ class ListTest(test.TestCase):
|
||||
l = data_structures.tf_tensor_array_new([3, 4, 5])
|
||||
t = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||
|
||||
def test_tf_tensor_array_new_illegal_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -102,15 +102,15 @@ class ListTest(test.TestCase):
|
||||
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [[1, 2, 3]])
|
||||
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
|
||||
|
||||
def test_append_tensorarray(self):
|
||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||
l1 = data_structures.list_append(l, 1)
|
||||
l2 = data_structures.list_append(l1, 2)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(l1.stack()), [1])
|
||||
self.assertAllEqual(sess.run(l2.stack()), [1, 2])
|
||||
self.assertAllEqual(self.evaluate(l1.stack()), [1])
|
||||
self.assertAllEqual(self.evaluate(l2.stack()), [1, 2])
|
||||
|
||||
def test_append_python(self):
|
||||
l = []
|
||||
@ -131,10 +131,10 @@ class ListTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
l, x = data_structures.list_pop(l, None, opts)
|
||||
self.assertAllEqual(sess.run(x), [3, 4])
|
||||
self.assertAllEqual(self.evaluate(x), [3, 4])
|
||||
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
||||
self.assertAllEqual(sess.run(t), [[1, 2]])
|
||||
self.assertAllEqual(self.evaluate(t), [[1, 2]])
|
||||
|
||||
def test_pop_python(self):
|
||||
l = [1, 2, 3]
|
||||
@ -152,7 +152,7 @@ class ListTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
t = data_structures.list_stack(l, opts)
|
||||
self.assertAllEqual(sess.run(t), sess.run(initial_list))
|
||||
self.assertAllEqual(self.evaluate(t), self.evaluate(initial_list))
|
||||
|
||||
def test_stack_tensor_list_empty(self):
|
||||
l = list_ops.empty_tensor_list(
|
||||
|
@ -30,7 +30,7 @@ class ExceptionsTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
t = exceptions.assert_stmt(
|
||||
constant_op.constant(True), lambda: constant_op.constant('ignored'))
|
||||
sess.run(t)
|
||||
self.evaluate(t)
|
||||
|
||||
def test_assert_tf_triggered(self):
|
||||
with self.cached_session() as sess:
|
||||
@ -40,7 +40,7 @@ class ExceptionsTest(test.TestCase):
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'test message'):
|
||||
sess.run(t)
|
||||
self.evaluate(t)
|
||||
|
||||
def test_assert_tf_multiple_printed_values(self):
|
||||
two_tensors = [
|
||||
@ -53,7 +53,7 @@ class ExceptionsTest(test.TestCase):
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'test message.*another message'):
|
||||
sess.run(t)
|
||||
self.evaluate(t)
|
||||
|
||||
def test_assert_python_untriggered(self):
|
||||
side_effect_trace = []
|
||||
|
@ -45,11 +45,11 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_and_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.and_(self._tf_true, self._tf_true)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
t = logical.and_(self._tf_true, lambda: True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
t = logical.and_(self._tf_false, lambda: True)
|
||||
self.assertEqual(sess.run(t), False)
|
||||
self.assertEqual(self.evaluate(t), False)
|
||||
# TODO(mdan): Add a test for ops with side effects.
|
||||
|
||||
def test_or_python(self):
|
||||
@ -63,11 +63,11 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_or_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.or_(self._tf_false, self._tf_true)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
t = logical.or_(self._tf_false, lambda: True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
t = logical.or_(self._tf_true, lambda: True)
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
# TODO(mdan): Add a test for ops with side effects.
|
||||
|
||||
def test_not_python(self):
|
||||
@ -78,7 +78,7 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
def test_not_tf(self):
|
||||
with self.cached_session() as sess:
|
||||
t = logical.not_(self._tf_false())
|
||||
self.assertEqual(sess.run(t), True)
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -38,29 +38,29 @@ class PyBuiltinsTest(test.TestCase):
|
||||
self.assertEqual(py_builtins.abs_(-1), 1)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.abs_(constant_op.constant(-1))
|
||||
self.assertEqual(sess.run(t), 1)
|
||||
self.assertEqual(self.evaluate(t), 1)
|
||||
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
|
||||
self.assertAllEqual(sess.run(t), [1, 2, 3])
|
||||
self.assertAllEqual(self.evaluate(t), [1, 2, 3])
|
||||
|
||||
def test_float(self):
|
||||
self.assertEqual(py_builtins.float_(10), 10.0)
|
||||
self.assertEqual(py_builtins.float_('10.0'), 10.0)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
|
||||
self.assertEqual(sess.run(t), 1.0)
|
||||
self.assertEqual(self.evaluate(t), 1.0)
|
||||
st = py_builtins.float_(constant_op.constant('1.0'))
|
||||
self.assertEqual(sess.run(st), 1.0)
|
||||
self.assertEqual(self.evaluate(st), 1.0)
|
||||
|
||||
def test_int(self):
|
||||
self.assertEqual(py_builtins.int_(10.0), 10)
|
||||
self.assertEqual(py_builtins.int_('11', 2), 3)
|
||||
with self.cached_session() as sess:
|
||||
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
|
||||
self.assertEqual(sess.run(t), 1)
|
||||
self.assertEqual(self.evaluate(t), 1)
|
||||
st = py_builtins.int_(constant_op.constant('1'))
|
||||
self.assertEqual(sess.run(st), 1)
|
||||
self.assertEqual(self.evaluate(st), 1)
|
||||
st = py_builtins.int_(constant_op.constant('1'), 10)
|
||||
self.assertEqual(sess.run(st), 1)
|
||||
self.assertEqual(self.evaluate(st), 1)
|
||||
|
||||
def test_int_unsupported_base(self):
|
||||
t = constant_op.constant(1, dtype=dtypes.float64)
|
||||
@ -73,9 +73,9 @@ class PyBuiltinsTest(test.TestCase):
|
||||
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
|
||||
self.assertEqual(t, 3)
|
||||
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
|
||||
self.assertEqual(sess.run(ta), 5)
|
||||
self.assertEqual(self.evaluate(ta), 5)
|
||||
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
|
||||
self.assertEqual(sess.run(tl), 3)
|
||||
self.assertEqual(self.evaluate(tl), 3)
|
||||
|
||||
def test_len_scalar(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -120,18 +120,18 @@ class PyBuiltinsTest(test.TestCase):
|
||||
def test_range_tensor(self):
|
||||
with self.cached_session() as sess:
|
||||
r = py_builtins.range_(constant_op.constant(3))
|
||||
self.assertAllEqual(sess.run(r), [0, 1, 2])
|
||||
self.assertAllEqual(self.evaluate(r), [0, 1, 2])
|
||||
r = py_builtins.range_(1, constant_op.constant(3))
|
||||
self.assertAllEqual(sess.run(r), [1, 2])
|
||||
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||
r = py_builtins.range_(2, 0, constant_op.constant(-1))
|
||||
self.assertAllEqual(sess.run(r), [2, 1])
|
||||
self.assertAllEqual(self.evaluate(r), [2, 1])
|
||||
|
||||
def test_range_tensor_empty_range(self):
|
||||
with self.session() as sess:
|
||||
r = py_builtins.range_(constant_op.constant(-3))
|
||||
self.assertAllEqual(sess.run(r), [])
|
||||
self.assertAllEqual(self.evaluate(r), [])
|
||||
r = py_builtins.range_(5, constant_op.constant(2))
|
||||
self.assertAllEqual(sess.run(r), [])
|
||||
self.assertAllEqual(self.evaluate(r), [])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,7 +34,7 @@ class SlicesTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
||||
self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
|
||||
self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
|
||||
|
||||
def test_get_item_tensor_list(self):
|
||||
initial_list = constant_op.constant([[1, 2], [3, 4]])
|
||||
@ -44,7 +44,7 @@ class SlicesTest(test.TestCase):
|
||||
l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(t), [3, 4])
|
||||
self.assertAllEqual(self.evaluate(t), [3, 4])
|
||||
|
||||
def test_get_item_tensor_string(self):
|
||||
initial_str = constant_op.constant('abcd')
|
||||
@ -52,14 +52,14 @@ class SlicesTest(test.TestCase):
|
||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(t), b'b')
|
||||
self.assertEqual(self.evaluate(t), b'b')
|
||||
|
||||
initial_list_str = constant_op.constant(['abcd', 'bcde'])
|
||||
t = slices.get_item(initial_list_str, 1,
|
||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(t), b'bcde')
|
||||
self.assertEqual(self.evaluate(t), b'bcde')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -32,7 +32,7 @@ class MiscTest(test.TestCase):
|
||||
new_a = alias_tensors(a)
|
||||
self.assertFalse(new_a is a)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(1, sess.run(new_a))
|
||||
self.assertEqual(1, self.evaluate(new_a))
|
||||
|
||||
def test_alias_tensors(self):
|
||||
a = constant(1)
|
||||
@ -47,7 +47,7 @@ class MiscTest(test.TestCase):
|
||||
self.assertTrue(new_s is s)
|
||||
self.assertTrue(new_l is l)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(1, sess.run(new_a))
|
||||
self.assertEqual(1, self.evaluate(new_a))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,13 +34,13 @@ class PyFuncTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(1, constant_op.constant(1), 1))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(
|
||||
test_fn, dtypes.int64,
|
||||
(constant_op.constant(1), 1, constant_op.constant(1)))
|
||||
self.assertEqual(3, sess.run(result))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
|
||||
def test_wrap_py_func_complex_args(self):
|
||||
|
||||
@ -54,10 +54,10 @@ class PyFuncTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
|
||||
self.assertEqual(35, sess.run(result))
|
||||
self.assertEqual(35, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(constant_op.constant(7), TestClass()))
|
||||
self.assertEqual(35, sess.run(result))
|
||||
self.assertEqual(35, self.evaluate(result))
|
||||
|
||||
def test_wrap_py_func_kwargs(self):
|
||||
|
||||
@ -74,13 +74,13 @@ class PyFuncTest(test.TestCase):
|
||||
'c': 11,
|
||||
'd': TestClass(13)
|
||||
})
|
||||
self.assertEqual(178, sess.run(result))
|
||||
self.assertEqual(178, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
(constant_op.constant(7), TestClass(5)), {
|
||||
'c': constant_op.constant(11),
|
||||
'd': TestClass(13)
|
||||
})
|
||||
self.assertEqual(178, sess.run(result))
|
||||
self.assertEqual(178, self.evaluate(result))
|
||||
|
||||
def test_wrap_py_func_dummy_return(self):
|
||||
|
||||
@ -91,11 +91,11 @@ class PyFuncTest(test.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
|
||||
self.assertEqual(1, sess.run(result))
|
||||
self.assertEqual(1, self.evaluate(result))
|
||||
self.assertEqual([1], side_counter)
|
||||
result = py_func.wrap_py_func(
|
||||
test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
|
||||
self.assertEqual(1, sess.run(result))
|
||||
self.assertEqual(1, self.evaluate(result))
|
||||
self.assertEqual([2], side_counter)
|
||||
|
||||
|
||||
|
@ -43,13 +43,13 @@ class TensorListTest(test.TestCase):
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(s), [1])
|
||||
self.assertAllEqual(self.evaluate(s), [1])
|
||||
|
||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
s = l.stack()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(sess.run(s), [1])
|
||||
self.assertAllEqual(self.evaluate(s), [1])
|
||||
|
||||
l = tl.TensorList(self._shape(()), dtypes.int32)
|
||||
l = tl.dynamic_list_append(l, 1)
|
||||
@ -92,7 +92,7 @@ class TensorListTest(test.TestCase):
|
||||
a2 = l.pop()
|
||||
c4 = l.count()
|
||||
with Session() as sess:
|
||||
c1, c2, c3, c4, a, a2 = sess.run([c1, c2, c3, c4, a, a2])
|
||||
c1, c2, c3, c4, a, a2 = self.evaluate([c1, c2, c3, c4, a, a2])
|
||||
self.assertEqual(c1, 1)
|
||||
self.assertEqual(c2, 2)
|
||||
self.assertEqual(c3, 1)
|
||||
@ -108,7 +108,7 @@ class TensorListTest(test.TestCase):
|
||||
l[0] = b
|
||||
l1 = l[0]
|
||||
with self.cached_session() as sess:
|
||||
l0, l1, a, b = sess.run([l0, l1, a, b])
|
||||
l0, l1, a, b = self.evaluate([l0, l1, a, b])
|
||||
self.assertEqual(l0, a)
|
||||
self.assertEqual(l1, b)
|
||||
|
||||
|
@ -62,7 +62,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
const = constant_op.constant(17)
|
||||
sess = session.Session(server1.target, config=config)
|
||||
output = sess.run(const)
|
||||
output = self.evaluate(const)
|
||||
self.assertEqual(17, output)
|
||||
|
||||
def testClusterSpecPropagationWorker2Placement(self):
|
||||
@ -106,7 +106,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
|
||||
const = constant_op.constant(17)
|
||||
sess = session.Session(server1.target, config=config, graph=g)
|
||||
output = sess.run(const)
|
||||
output = self.evaluate(const)
|
||||
self.assertEqual(17, output)
|
||||
|
||||
def testCanonicalDeviceNames(self):
|
||||
@ -208,7 +208,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
||||
with ops.device('/job:worker/task:0/cpu:0'):
|
||||
sum3 = sum1 + sum2
|
||||
sess = session.Session(server1.target, config=config, graph=g)
|
||||
output = sess.run(sum3)
|
||||
output = self.evaluate(sum3)
|
||||
self.assertEqual(40, output)
|
||||
|
||||
def testLegacyDeviceNames(self):
|
||||
|
@ -117,7 +117,7 @@ class PartialRunTest(test_util.TensorFlowTestCase):
|
||||
a = constant_op.constant(2.0, dtypes.float32)
|
||||
b = a * 2
|
||||
c = b * 3
|
||||
r1 = sess.run([b, c])
|
||||
r1 = self.evaluate([b, c])
|
||||
h = sess.partial_run_setup([b, c], [])
|
||||
r2 = sess.partial_run(h, [b, c])
|
||||
self.assertEqual(r1, r2)
|
||||
|
@ -147,7 +147,7 @@ class TimelineTest(test.TestCase):
|
||||
num2 = variables.Variable(2.0, name='num2')
|
||||
with ops.device('/cpu:2'):
|
||||
result = num1 + num2 + num1 * num2
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(result, options=run_options, run_metadata=run_metadata)
|
||||
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
@ -176,7 +176,7 @@ class TimelineTest(test.TestCase):
|
||||
num2 = variables.Variable(2.0, name='num2')
|
||||
with ops.device('/cpu:2'):
|
||||
result = num1 + num2 + num1 * num2
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(result, options=run_options, run_metadata=run_metadata)
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
step_stats = run_metadata.step_stats
|
||||
|
@ -216,7 +216,7 @@ class VirtualGpuTest(test_util.TensorFlowTestCase):
|
||||
for d in self._util.devices:
|
||||
with ops.device(d):
|
||||
var = variables.Variable(random_ops.random_uniform(mat_shape))
|
||||
sess.run(var.initializer)
|
||||
self.evaluate(var.initializer)
|
||||
data.append(var)
|
||||
s = data[0]
|
||||
for i in range(1, len(data)):
|
||||
|
@ -110,9 +110,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
batches = []
|
||||
for _ in range(4):
|
||||
batches.append(sess.run(batch))
|
||||
batches.append(self.evaluate(batch))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(batch)
|
||||
self.evaluate(batch)
|
||||
batch_sizes_val = []
|
||||
lengths_val = []
|
||||
for batch in batches:
|
||||
@ -160,9 +160,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
batches = []
|
||||
for _ in range(3):
|
||||
batches.append(sess.run(batch))
|
||||
batches.append(self.evaluate(batch))
|
||||
with self.assertRaisesOpError("bucket_boundaries"):
|
||||
sess.run(batch)
|
||||
self.evaluate(batch)
|
||||
batch_sizes_val = []
|
||||
lengths_val = []
|
||||
for batch in batches:
|
||||
@ -197,9 +197,9 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
batches = []
|
||||
for _ in range(5):
|
||||
batches.append(sess.run(batch))
|
||||
batches.append(self.evaluate(batch))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(batch)
|
||||
self.evaluate(batch)
|
||||
|
||||
self.assertAllEqual(batches[0], [[1, 0],
|
||||
[1, 1]])
|
||||
@ -300,7 +300,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
output = sess.run(batch)
|
||||
output = self.evaluate(batch)
|
||||
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
|
||||
tuple(output.values))
|
||||
all_sparse_tensors.add(sprs_tensor)
|
||||
|
@ -57,9 +57,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceInt32(self):
|
||||
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
|
||||
@ -82,9 +82,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToSameDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
@ -108,9 +108,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceWithPrefetch(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
@ -134,9 +134,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyDictToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||
@ -160,9 +160,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyDictToDeviceWithPrefetch(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||
@ -186,9 +186,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopySparseTensorsToDevice(self):
|
||||
|
||||
@ -217,12 +217,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = sess.run(next_element)
|
||||
actual = self.evaluate(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopySparseTensorsToDeviceWithPrefetch(self):
|
||||
|
||||
@ -251,12 +251,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = sess.run(next_element)
|
||||
actual = self.evaluate(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -271,11 +271,11 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpuWithPrefetch(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -290,11 +290,11 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpuWithMap(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -323,14 +323,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
x, y, z = sess.run(next_element)
|
||||
x, y, z = self.evaluate(next_element)
|
||||
self.assertEqual(i**2, x)
|
||||
self.assertEqual(float(i**2), y)
|
||||
self.assertEqual(util_compat.as_bytes(str(i)), z)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpuInt32(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -345,10 +345,10 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpuInt32AndPrefetch(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -363,10 +363,10 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpuStrings(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -381,10 +381,10 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpuStringsAndPrefetch(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -399,10 +399,10 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDevicePingPongCPUGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -420,11 +420,11 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceWithReInit(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
@ -447,14 +447,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceWithReInitAndPrefetch(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
@ -477,14 +477,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpuWithReInit(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -499,14 +499,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testCopyToDeviceGpuWithReInitAndPrefetch(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -521,14 +521,14 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testIteratorGetNextAsOptionalOnGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -547,24 +547,25 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
# Before initializing the iterator, evaluating the optional fails with
|
||||
# a FailedPreconditionError.
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
sess.run(elem_has_value_t)
|
||||
self.evaluate(elem_has_value_t)
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
sess.run(elem_value_t)
|
||||
self.evaluate(elem_value_t)
|
||||
|
||||
# For each element of the dataset, assert that the optional evaluates to
|
||||
# the expected value.
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(3):
|
||||
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
||||
elem_has_value, elem_value = self.evaluate(
|
||||
[elem_has_value_t, elem_value_t])
|
||||
self.assertTrue(elem_has_value)
|
||||
self.assertEqual(i, elem_value)
|
||||
|
||||
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
||||
# false, and attempting to get the value will fail.
|
||||
for _ in range(2):
|
||||
self.assertFalse(sess.run(elem_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_value_t)
|
||||
self.evaluate(elem_value_t)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -38,13 +38,13 @@ class CounterTest(test_base.DatasetTestBase):
|
||||
negative_get_next = negative_iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(3, sess.run(get_next))
|
||||
self.assertEqual(3 + 4, sess.run(get_next))
|
||||
self.assertEqual(3 + 2 * 4, sess.run(get_next))
|
||||
self.assertEqual(3, self.evaluate(get_next))
|
||||
self.assertEqual(3 + 4, self.evaluate(get_next))
|
||||
self.assertEqual(3 + 2 * 4, self.evaluate(get_next))
|
||||
|
||||
self.assertEqual(0, sess.run(negative_get_next))
|
||||
self.assertEqual(-1, sess.run(negative_get_next))
|
||||
self.assertEqual(-2, sess.run(negative_get_next))
|
||||
self.assertEqual(0, self.evaluate(negative_get_next))
|
||||
self.assertEqual(-1, self.evaluate(negative_get_next))
|
||||
self.assertEqual(-2, self.evaluate(negative_get_next))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -41,10 +41,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual([[i, j]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)], results.indices)
|
||||
@ -56,7 +56,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
results.dense_shape)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testDenseToSparseBatchDatasetWithUnknownShape(self):
|
||||
components = np.random.randint(5, size=(40,)).astype(np.int32)
|
||||
@ -69,10 +69,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
results = self.evaluate(get_next)
|
||||
self.assertAllEqual([[i, j, z]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)
|
||||
@ -89,7 +89,7 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
], results.dense_shape)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testDenseToSparseBatchDatasetWithInvalidShape(self):
|
||||
input_tensor = array_ops.constant([[1]])
|
||||
@ -111,13 +111,13 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
sess.run(init_op, feed_dict={input_tensor: [[1]]})
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"incompatible with the row shape"):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Initialize with an input tensor that is larger than `row_shape`.
|
||||
sess.run(init_op, feed_dict={input_tensor: range(13)})
|
||||
with self.assertRaisesRegexp(errors.DataLossError,
|
||||
"larger than the row shape"):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -40,12 +40,12 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for _ in range(100):
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def _normalize(self, vec):
|
||||
return vec / vec.sum()
|
||||
@ -71,9 +71,9 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
freqs = np.zeros([num_datasets])
|
||||
for _ in range(num_samples):
|
||||
freqs[sess.run(next_element)] += 1
|
||||
freqs[self.evaluate(next_element)] += 1
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
return freqs
|
||||
|
||||
@ -107,9 +107,9 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in choice_array:
|
||||
self.assertEqual(words[i], sess.run(next_element))
|
||||
self.assertEqual(words[i], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testErrors(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
|
@ -44,12 +44,12 @@ class EnumerateDatasetTest(test_base.DatasetTestBase):
|
||||
[t.shape for t in get_next[1]])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
|
||||
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertEqual((20, (b"a", 1, 37.0)), self.evaluate(get_next))
|
||||
self.assertEqual((21, (b"b", 2, 38.0)), self.evaluate(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -52,12 +52,12 @@ class FilterBenchmark(test.Benchmark):
|
||||
|
||||
with session.Session() as sess:
|
||||
for _ in range(10):
|
||||
sess.run(next_element.op)
|
||||
self.evaluate(next_element.op)
|
||||
deltas = []
|
||||
for _ in range(100):
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
sess.run(next_element.op)
|
||||
self.evaluate(next_element.op)
|
||||
end = time.time()
|
||||
deltas.append(end - start)
|
||||
|
||||
|
@ -94,18 +94,18 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
sess.run(destroy_op)
|
||||
self.evaluate(destroy_op)
|
||||
|
||||
def testSameDeviceCPU(self):
|
||||
self._prefetch_fn_helper_one_shot("same_device_cpu",
|
||||
@ -135,35 +135,35 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
ds, ds_iterator, "reinit", device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(ds_iterator.initializer)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
# Lets reset the function buffering resource and reinitialize the
|
||||
# iterator. Should be able to go through this again.
|
||||
self._event.clear()
|
||||
sess.run(reset_op)
|
||||
sess.run(ds_iterator.initializer)
|
||||
elem = sess.run(prefetch_op)
|
||||
self.evaluate(reset_op)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [1.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [2.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [3.0])
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [4.0])
|
||||
self._event.wait()
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [5.0])
|
||||
sess.run(destroy_op)
|
||||
self.evaluate(destroy_op)
|
||||
|
||||
def testReinitializationOutOfRange(self):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
@ -175,30 +175,30 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
ds, ds_iterator, "reinit", device0, device1)
|
||||
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(ds_iterator.initializer)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
for i in range(1, 10):
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [float(i)])
|
||||
# Try fetching after its over twice to test out end of sequence.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
|
||||
# Now reset everything and try it out again.
|
||||
self._event.clear()
|
||||
sess.run(reset_op)
|
||||
sess.run(ds_iterator.initializer)
|
||||
self.evaluate(reset_op)
|
||||
self.evaluate(ds_iterator.initializer)
|
||||
for i in range(1, 10):
|
||||
elem = sess.run(prefetch_op)
|
||||
elem = self.evaluate(prefetch_op)
|
||||
self.assertEqual(elem, [float(i)])
|
||||
# Try fetching after its over twice to test out end of sequence.
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
|
||||
sess.run(destroy_op)
|
||||
self.evaluate(destroy_op)
|
||||
|
||||
def testStringsGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -235,13 +235,13 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
buffer_resource_handle, ignore_lookup_error=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual([b"a"], sess.run(prefetch_op))
|
||||
self.assertEqual([b"b"], sess.run(prefetch_op))
|
||||
self.assertEqual([b"c"], sess.run(prefetch_op))
|
||||
self.assertEqual([b"a"], self.evaluate(prefetch_op))
|
||||
self.assertEqual([b"b"], self.evaluate(prefetch_op))
|
||||
self.assertEqual([b"c"], self.evaluate(prefetch_op))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(prefetch_op)
|
||||
self.evaluate(prefetch_op)
|
||||
|
||||
sess.run(destroy_op)
|
||||
self.evaluate(destroy_op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -39,10 +39,10 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
with self.cached_session() as sess:
|
||||
for expected in values:
|
||||
got = sess.run(get_next)
|
||||
got = self.evaluate(get_next)
|
||||
self.assertEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testSum(self):
|
||||
reducer = grouping.Reducer(
|
||||
@ -127,11 +127,11 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
x, y = sess.run(get_next)
|
||||
x, y = self.evaluate(get_next)
|
||||
self.assertAllEqual([0] * (2**i), x)
|
||||
self.assertAllEqual(np.array(1, ndmin=i), y)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testTypeMismatch(self):
|
||||
reducer = grouping.Reducer(
|
||||
@ -190,7 +190,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
with self.cached_session() as sess:
|
||||
x, y = sess.run(get_next)
|
||||
x, y = self.evaluate(get_next)
|
||||
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
|
||||
self.assertEqual(y, 45)
|
||||
|
||||
|
@ -68,9 +68,9 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
which_bucket, bucketed_values = sess.run(get_next)
|
||||
which_bucket, bucketed_values = self.evaluate(get_next)
|
||||
|
||||
self.assertEqual(0, which_bucket)
|
||||
|
||||
@ -103,11 +103,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
# Get two minibatches (one containing even values, one containing odds)
|
||||
which_bucket_even, bucketed_values_even = sess.run(get_next)
|
||||
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
|
||||
which_bucket_even, bucketed_values_even = self.evaluate(get_next)
|
||||
which_bucket_odd, bucketed_values_odd = self.evaluate(get_next)
|
||||
|
||||
# Count number of bucket_tensors.
|
||||
self.assertEqual(3, len(bucketed_values_even))
|
||||
@ -174,11 +174,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
|
||||
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
||||
which_bucket0, bucketed_values_even0 = sess.run(get_next)
|
||||
which_bucket1, bucketed_values_even1 = sess.run(get_next)
|
||||
which_bucket0, bucketed_values_even0 = self.evaluate(get_next)
|
||||
which_bucket1, bucketed_values_even1 = self.evaluate(get_next)
|
||||
|
||||
# Ensure that bucket 1 was completely filtered out
|
||||
self.assertAllEqual(0, which_bucket0)
|
||||
@ -207,11 +207,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
batches = 0
|
||||
while True:
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
is_even = all(x % 2 == 0 for x in result)
|
||||
is_odd = all(x % 2 == 1 for x in result)
|
||||
self.assertTrue(is_even or is_odd)
|
||||
@ -232,11 +232,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
counts = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
self.assertTrue(
|
||||
all(x % 2 == 0
|
||||
for x in result) or all(x % 2 == 1)
|
||||
@ -259,16 +259,16 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
# The input is infinite, so this test demonstrates that:
|
||||
# 1. We produce output without having to consume the entire input,
|
||||
# 2. Different buckets can produce output at different rates, and
|
||||
# 3. For deterministic input, the output is deterministic.
|
||||
for _ in range(3):
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
||||
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
|
||||
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
|
||||
def testSmallGroups(self):
|
||||
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
||||
@ -280,13 +280,13 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
||||
self.evaluate(init_op)
|
||||
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
|
||||
# The small outputs at the end are deterministically produced in key
|
||||
# order.
|
||||
self.assertAllEqual([0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1], sess.run(get_next))
|
||||
self.assertAllEqual([0, 0, 0], self.evaluate(get_next))
|
||||
self.assertAllEqual([1], self.evaluate(get_next))
|
||||
|
||||
def testEmpty(self):
|
||||
iterator = (
|
||||
@ -297,11 +297,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Window size must be greater than zero, but got 0."):
|
||||
print(sess.run(get_next))
|
||||
print(self.evaluate(get_next))
|
||||
|
||||
def testReduceFuncError(self):
|
||||
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
||||
@ -323,9 +323,9 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testConsumeWindowDatasetMoreThanOnce(self):
|
||||
components = np.random.randint(50, size=(200,)).astype(np.int64)
|
||||
@ -351,11 +351,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
counts = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
tight_result, multiple_of_10_result = sess.run(get_next)
|
||||
tight_result, multiple_of_10_result = self.evaluate(get_next)
|
||||
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
|
||||
self.assertAllEqual(tight_result,
|
||||
multiple_of_10_result[:, :tight_result.shape[1]])
|
||||
|
@ -47,11 +47,11 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for x in [1., 2., 3., 5.]:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testParallelMapIgnoreError(self):
|
||||
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
||||
@ -65,11 +65,11 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for x in [1., 2., 3., 5.]:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
self.assertEqual(x, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testReadFileIgnoreError(self):
|
||||
|
||||
@ -93,22 +93,22 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# All of the files are present.
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for filename in filenames:
|
||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
||||
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Delete one of the files.
|
||||
os.remove(filenames[0])
|
||||
|
||||
# Attempting to read filenames[0] will fail, but ignore_errors()
|
||||
# will catch the error.
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for filename in filenames[1:]:
|
||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
||||
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -46,14 +46,14 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
||||
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(materialize)
|
||||
self.evaluate(materialize)
|
||||
self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
|
||||
|
||||
def testIdentityIndexedDataset(self):
|
||||
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
|
||||
materialized = ds.materialize()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(materialized.initializer)
|
||||
self.evaluate(materialized.initializer)
|
||||
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
|
||||
for i in range(16):
|
||||
output = sess.run(
|
||||
@ -68,12 +68,13 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
||||
itr = ds.make_initializable_iterator()
|
||||
n = itr.get_next()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(itr.initializer)
|
||||
self.evaluate(itr.initializer)
|
||||
for i in range(16):
|
||||
output = sess.run(n)
|
||||
output = self.evaluate(n)
|
||||
self.assertEqual(i, output)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
self.evaluate(n)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -112,14 +112,14 @@ class MakeBatchedFeaturesDatasetTest(
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
|
||||
range(self._num_files), 2, 10):
|
||||
actual_batch = sess.run(next_element)
|
||||
actual_batch = self.evaluate(next_element)
|
||||
self.assertAllEqual(file_batch, actual_batch["file"])
|
||||
self.assertAllEqual(record_batch, actual_batch["record"])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testReadWithFusedShuffleRepeatDataset(self):
|
||||
num_epochs = 5
|
||||
|
@ -90,7 +90,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
batch_size,
|
||||
num_epochs,
|
||||
):
|
||||
actual_features = sess.run(nxt)
|
||||
actual_features = self.evaluate(nxt)
|
||||
|
||||
if label_name is not None:
|
||||
expected_labels = expected_features.pop(label_name)
|
||||
@ -102,7 +102,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
self.assertAllEqual(expected_features[k], actual_features[k])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(nxt)
|
||||
self.evaluate(nxt)
|
||||
|
||||
def _test_dataset(self,
|
||||
inputs,
|
||||
@ -607,8 +607,8 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
outputs1 = dataset1.make_one_shot_iterator().get_next()
|
||||
outputs2 = dataset2.make_one_shot_iterator().get_next()
|
||||
for _ in range(total_records // batch_size):
|
||||
batch1 = nest.flatten(sess.run(outputs1))
|
||||
batch2 = nest.flatten(sess.run(outputs2))
|
||||
batch1 = nest.flatten(self.evaluate(outputs1))
|
||||
batch2 = nest.flatten(self.evaluate(outputs2))
|
||||
for i in range(len(batch1)):
|
||||
self.assertAllEqual(batch1[i], batch2[i])
|
||||
|
||||
@ -639,8 +639,8 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
outputs2 = dataset2.make_one_shot_iterator().get_next()
|
||||
all_equal = False
|
||||
for _ in range(total_records // batch_size):
|
||||
batch1 = nest.flatten(sess.run(outputs1))
|
||||
batch2 = nest.flatten(sess.run(outputs2))
|
||||
batch1 = nest.flatten(self.evaluate(outputs1))
|
||||
batch2 = nest.flatten(self.evaluate(outputs2))
|
||||
for i in range(len(batch1)):
|
||||
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
|
||||
self.assertFalse(all_equal)
|
||||
|
@ -105,7 +105,7 @@ class MakeTFRecordDatasetTest(
|
||||
for expected_batch in self._next_expected_batch(
|
||||
file_indices, batch_size, num_epochs, interleave_cycle_length,
|
||||
drop_final_batch, use_parser_fn):
|
||||
actual_batch = sess.run(outputs)
|
||||
actual_batch = self.evaluate(outputs)
|
||||
self.assertAllEqual(expected_batch, actual_batch)
|
||||
|
||||
def _read_test(self, batch_size, num_epochs, file_index=None,
|
||||
@ -135,7 +135,7 @@ class MakeTFRecordDatasetTest(
|
||||
interleave_cycle_length=num_parallel_reads,
|
||||
drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(outputs)
|
||||
self.evaluate(outputs)
|
||||
|
||||
def testRead(self):
|
||||
for batch_size in [1, 2]:
|
||||
@ -188,19 +188,19 @@ class MakeTFRecordDatasetTest(
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
first_batches = []
|
||||
try:
|
||||
while True:
|
||||
first_batches.append(sess.run(next_element))
|
||||
first_batches.append(self.evaluate(next_element))
|
||||
except errors.OutOfRangeError:
|
||||
pass
|
||||
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
second_batches = []
|
||||
try:
|
||||
while True:
|
||||
second_batches.append(sess.run(next_element))
|
||||
second_batches.append(self.evaluate(next_element))
|
||||
except errors.OutOfRangeError:
|
||||
pass
|
||||
|
||||
|
@ -89,13 +89,13 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||
num_batches = (28 * 7) // 14
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Batch of a finite input, where the batch_size does not
|
||||
# divide the total number of elements.
|
||||
@ -104,23 +104,23 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
# We expect (num_batches - 1) full-sized batches.
|
||||
num_batches = int(math.ceil((14 * 7) / 8))
|
||||
for i in range(num_batches - 1):
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(8):
|
||||
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
result = sess.run(get_next)
|
||||
result = self.evaluate(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((14 * 7) % 8):
|
||||
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Batch of an empty input should fail straight away.
|
||||
sess.run(init_op, feed_dict={count: 0, batch_size: 8})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Empty batch should be an initialization time error.
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
@ -152,12 +152,12 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
if not drop_remainder:
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
@ -177,11 +177,11 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
@ -201,14 +201,14 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
got = sess.run(elements)
|
||||
got = self.evaluate(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
|
||||
self.assertAllEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elements)
|
||||
self.evaluate(elements)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
@ -230,14 +230,14 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(4):
|
||||
got = sess.run(elements)
|
||||
got = self.evaluate(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
|
||||
self.assertAllEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elements)
|
||||
self.evaluate(elements)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
@ -261,9 +261,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(2):
|
||||
actual = sess.run(get_next)
|
||||
actual = self.evaluate(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||
@ -271,7 +271,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||
self.assertSparseValuesEqual(actual, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
@ -321,10 +321,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"number of elements does not match"):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
@ -354,7 +354,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", 0, False),
|
||||
@ -393,13 +393,14 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(threshold // 10):
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)],
|
||||
self.evaluate(get_next))
|
||||
if threshold % 10 != 0:
|
||||
self.assertAllEqual(
|
||||
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
||||
sess.run(get_next))
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", False, dtypes.bool, False),
|
||||
@ -442,7 +443,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(10):
|
||||
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
|
||||
self.assertAllEqual([element for _ in range(10)],
|
||||
self.evaluate(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Identity", None, lambda x: x, None),
|
||||
@ -462,7 +464,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
else:
|
||||
expected = map_fn(
|
||||
sess.run(self.structuredElement(structure, shape=[10])))
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
|
||||
def testShortCircuitCapturedInput(self):
|
||||
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
@ -473,7 +475,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
||||
self.assertAllEqual([42] * 10, sess.run(get_next))
|
||||
self.assertAllEqual([42] * 10, self.evaluate(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
@ -501,13 +503,13 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
print("Case %d" % i)
|
||||
if i < 5:
|
||||
self.assertAllEqual([i * 10 + j + 1 for j in range(10)],
|
||||
sess.run(get_next))
|
||||
self.evaluate(get_next))
|
||||
else:
|
||||
self.assertAllEqual(
|
||||
[((i * 10) + j) * ((i * 10) + j) for j in range(10)],
|
||||
sess.run(get_next))
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -218,7 +218,7 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
|
||||
def _assert_op_cancelled(self, sess, map_defun_op):
|
||||
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
|
||||
sess.run(map_defun_op)
|
||||
self.evaluate(map_defun_op)
|
||||
|
||||
def testMapDefunWithParentCancellation(self):
|
||||
# Checks that a cancellation of the parent graph is threaded through to
|
||||
@ -260,10 +260,10 @@ class MapDefunBenchmark(test.Benchmark):
|
||||
with session.Session() as sess:
|
||||
# Warm up the session
|
||||
for _ in range(5):
|
||||
sess.run(op)
|
||||
self.evaluate(op)
|
||||
start = time.time()
|
||||
for _ in range(num_iters):
|
||||
sess.run(op)
|
||||
self.evaluate(op)
|
||||
end = time.time()
|
||||
mean_us = (end - start) * 1e6 / num_iters
|
||||
self.report_benchmark(
|
||||
|
@ -41,9 +41,9 @@ class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(0, sess.run(get_next))
|
||||
self.assertEqual(0, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -51,7 +51,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# TODO(b/117581999): Add eager coverage for the following tests.
|
||||
def testSkipEagerOptimizationLargeInputFromTensorSlices(self):
|
||||
@ -64,7 +64,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testOptimizationNestedDataset(self):
|
||||
|
||||
|
@ -55,11 +55,11 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
thread_ids = []
|
||||
try:
|
||||
while True:
|
||||
thread_ids.append(sess.run(next_element))
|
||||
thread_ids.append(self.evaluate(next_element))
|
||||
except errors.OutOfRangeError:
|
||||
pass
|
||||
self.assertLen(thread_ids, len(set(thread_ids)))
|
||||
|
@ -195,9 +195,9 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
|
||||
self.write_coordination_events[expected_element].set()
|
||||
self.assertEqual(expected_element * expected_element,
|
||||
sess.run(self.next_element))
|
||||
self.evaluate(self.next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testSingleThreaded(self):
|
||||
self._testSingleThreaded()
|
||||
@ -235,10 +235,10 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
for expected_element in self._interleave(
|
||||
[[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
|
||||
self.write_coordination_events[expected_element].set()
|
||||
output = sess.run(self.next_element)
|
||||
output = self.evaluate(self.next_element)
|
||||
self.assertEqual(expected_element * expected_element, output)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def _testTwoThreadsNoContention(self, sloppy=False):
|
||||
# num_threads > 1.
|
||||
@ -262,7 +262,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.write_coordination_events[expected_element].set()
|
||||
if done_first_event: # First event starts the worker threads.
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
if not done_first_event:
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
done_first_event = True
|
||||
@ -270,7 +270,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testTwoThreadsNoContention(self):
|
||||
self._testTwoThreadsNoContention()
|
||||
@ -309,7 +309,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
else:
|
||||
self.write_coordination_events[expected_element].set()
|
||||
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
if not done_first_event:
|
||||
done_first_event = True
|
||||
self.assertTrue(
|
||||
@ -318,7 +318,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testTwoThreadsNoContentionWithRaces(self):
|
||||
self._testTwoThreadsNoContentionWithRaces()
|
||||
@ -348,7 +348,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.write_coordination_events[expected_element].set()
|
||||
if done_first_event: # First event starts the worker threads.
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
if not done_first_event:
|
||||
done_first_event = True
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
@ -356,7 +356,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testTwoThreadsNoContentionBlockLength(self):
|
||||
self._testTwoThreadsNoContentionBlockLength()
|
||||
@ -396,7 +396,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
else:
|
||||
self.write_coordination_events[expected_element].set()
|
||||
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
if not done_first_event:
|
||||
done_first_event = True
|
||||
self.assertTrue(
|
||||
@ -405,7 +405,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
|
||||
self._testTwoThreadsNoContentionWithRacesAndBlocking()
|
||||
@ -428,7 +428,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testEmptyInput(self):
|
||||
self._testEmptyInput()
|
||||
@ -451,7 +451,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.prefetch_input_elements: 0,
|
||||
})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testNonEmptyInputIntoEmptyOutputs(self):
|
||||
self._testNonEmptyInputIntoEmptyOutputs()
|
||||
@ -484,7 +484,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
# presence of finishing iterators.
|
||||
if done_first_event and not (sloppy and (i in race_indices)):
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
if not done_first_event or (sloppy and (i in race_indices)):
|
||||
done_first_event = True
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
@ -520,10 +520,10 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
]
|
||||
for element in mis_ordering:
|
||||
self.write_coordination_events[element].set()
|
||||
self.assertEqual(element * element, sess.run(self.next_element))
|
||||
self.assertEqual(element * element, self.evaluate(self.next_element))
|
||||
self.assertTrue(self.read_coordination_events[element].acquire(False))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testBlockLengthWithContentionSloppy(self):
|
||||
with self.cached_session() as sess:
|
||||
@ -549,7 +549,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.write_coordination_events[expected_element].set()
|
||||
if done_first_event: # First event starts the worker threads.
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
if not done_first_event:
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
done_first_event = True
|
||||
@ -557,7 +557,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def _testEarlyExit(self, sloppy=False):
|
||||
# Exiting without consuming all input should not block
|
||||
@ -575,7 +575,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
})
|
||||
for i in range(4, 7):
|
||||
self.write_coordination_events[i].set()
|
||||
elem = sess.run(self.next_element) # Start all workers
|
||||
elem = self.evaluate(self.next_element) # Start all workers
|
||||
# Allow the one successful worker to progress beyond the py_func again.
|
||||
elem = int(math.sqrt(elem))
|
||||
self.write_coordination_events[elem].set()
|
||||
@ -608,7 +608,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
output_values = []
|
||||
for _ in range(30):
|
||||
output_values.append(sess.run(iterator.get_next()))
|
||||
output_values.append(self.evaluate(iterator.get_next()))
|
||||
|
||||
expected_values = self._interleave(
|
||||
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
|
||||
@ -637,13 +637,13 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.evaluate(init_op)
|
||||
for i in range(10):
|
||||
for j in range(2):
|
||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||
self.assertAllEqual(expected, sess.run(get_next))
|
||||
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def testErrorsInOutputFn(self):
|
||||
with self.cached_session() as sess:
|
||||
@ -668,15 +668,15 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self.error = ValueError()
|
||||
self.write_coordination_events[expected_element].set()
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
else:
|
||||
self.write_coordination_events[expected_element].set()
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
self.assertEqual(expected_element * expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testErrorsInInputFn(self):
|
||||
|
||||
@ -720,14 +720,14 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
|
||||
if expected_element == 5:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
else:
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
self.assertEqual(expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testErrorsInInterleaveFn(self):
|
||||
|
||||
@ -769,14 +769,14 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
|
||||
if expected_element == 5:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
else:
|
||||
actual_element = sess.run(self.next_element)
|
||||
actual_element = self.evaluate(self.next_element)
|
||||
self.assertEqual(expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
self.evaluate(self.next_element)
|
||||
|
||||
def testShutdownRace(self):
|
||||
dataset = dataset_ops.Dataset.range(20)
|
||||
@ -796,10 +796,10 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(2):
|
||||
elements = []
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
try:
|
||||
while True:
|
||||
elements.extend(sess.run(next_element))
|
||||
elements.extend(self.evaluate(next_element))
|
||||
except errors.OutOfRangeError:
|
||||
pass
|
||||
results.append(elements)
|
||||
|
@ -57,9 +57,9 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testPrefetchToSameDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
@ -87,9 +87,9 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testPrefetchDictToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||
@ -117,9 +117,9 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testPrefetchSparseTensorsToDevice(self):
|
||||
def make_tensor(i):
|
||||
@ -150,12 +150,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = sess.run(next_element)
|
||||
actual = self.evaluate(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testPrefetchToDeviceGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -170,9 +170,9 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testPrefetchToDeviceWithReInit(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
@ -199,14 +199,14 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testPrefetchToDeviceGpuWithReInit(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -220,14 +220,14 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
sess.run(iterator.initializer)
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -60,9 +60,9 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
feed_dict={start: start_val, step: step_val, take: take_val})
|
||||
for expected, _ in zip(
|
||||
itertools.count(start_val, step_val), range(take_val)):
|
||||
self.assertEqual(expected, sess.run(next_element))
|
||||
self.assertEqual(expected, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testFibonacci(self):
|
||||
@ -110,9 +110,9 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
feed_dict={start: start_val, step: step_val, take: take_val})
|
||||
for expected, _ in zip(
|
||||
itertools.count(start_val, step_val), range(take_val)):
|
||||
self.assertEqual(expected, sess.run(next_element).values[0])
|
||||
self.assertEqual(expected, self.evaluate(next_element).values[0])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testChangingStateShape(self):
|
||||
# Test the fixed-point shape invariant calculations: start with
|
||||
@ -136,11 +136,11 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
|
||||
(longer_vector_val, larger_rank_val), _ = self.evaluate(next_element)
|
||||
self.assertAllEqual([0] * (2**i), longer_vector_val)
|
||||
self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testIncorrectStateType(self):
|
||||
|
||||
|
@ -71,36 +71,36 @@ class RangeDatasetSerializationTest(
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(restore_op)
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Saving and restoring in same session.
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(init_op)
|
||||
for i in range(start, break_point):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
sess.run(save_op)
|
||||
sess.run(restore_op)
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
self.evaluate(save_op)
|
||||
self.evaluate(restore_op)
|
||||
for i in range(break_point, stop):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
self.assertEqual(i, self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
def _build_range_dataset(self, start, stop):
|
||||
return dataset_ops.Dataset.range(start, stop)
|
||||
|
@ -60,9 +60,9 @@ class SerializationIntegrationTest(test.TestCase):
|
||||
init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
|
||||
num_outputs)
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(init_ops)
|
||||
self.evaluate(init_ops)
|
||||
for _ in range(break_point):
|
||||
output = sess.run(get_next_ops)
|
||||
output = self.evaluate(get_next_ops)
|
||||
for i in range(num_pipelines):
|
||||
all_outputs[i].append(output[i])
|
||||
saver.save(sess, self._ckpt_path())
|
||||
@ -73,7 +73,7 @@ class SerializationIntegrationTest(test.TestCase):
|
||||
with self.session(graph=g) as sess:
|
||||
saver.restore(sess, self._ckpt_path())
|
||||
for _ in range(num_outputs - break_point):
|
||||
output = sess.run(get_next_ops)
|
||||
output = self.evaluate(get_next_ops)
|
||||
for i in range(num_pipelines):
|
||||
all_outputs[i].append(output[i])
|
||||
|
||||
|
@ -138,9 +138,9 @@ class ShuffleDatasetSerializationTest(
|
||||
saver = saver_lib.Saver(allow_empty=True)
|
||||
with self.session(graph=g) as sess:
|
||||
self._save(sess, saver)
|
||||
expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
|
||||
expected = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
|
||||
self._restore(saver, sess)
|
||||
actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
|
||||
actual = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
|
||||
self.match(expected, actual)
|
||||
|
||||
|
||||
|
@ -38,10 +38,10 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
outputs = []
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(num_outputs):
|
||||
outputs.append(sess.run(get_next))
|
||||
outputs.append(self.evaluate(get_next))
|
||||
if verify_exhausted:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
return outputs
|
||||
|
||||
def testCorrectOutput(self):
|
||||
@ -108,7 +108,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
||||
shuffle_ops.shuffle_and_repeat(buffer_size=21))
|
||||
get_next_op = ds.make_one_shot_iterator().get_next()
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(get_next_op)
|
||||
self.evaluate(get_next_op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -38,14 +38,14 @@ class SleepTest(test_base.DatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
start_time = time.time()
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
end_time = time.time()
|
||||
self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -39,10 +39,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
for _ in range(2): # Dataset is repeated. See setUp.
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"),
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that SqlDataset works on a join query.
|
||||
def testReadResultSetJoinQuery(self):
|
||||
@ -58,9 +59,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ON students.first_name = people.first_name "
|
||||
"AND students.last_name = people.last_name"
|
||||
})
|
||||
self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"California", b"Hi!"),
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that SqlDataset can read a database entry with a null-terminator
|
||||
# in the middle of the text and place the entry in a `string` tensor.
|
||||
@ -75,10 +77,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, last_name, favorite_nonsense_word "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"n\0nsense"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"),
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that SqlDataset works when used on two different queries.
|
||||
# Because the output types of the dataset must be determined at graph-creation
|
||||
@ -93,21 +96,22 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, last_name, motto FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
sess.run(
|
||||
init_op,
|
||||
feed_dict={
|
||||
self.query: "SELECT first_name, last_name, state FROM people "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Doe", b"California"),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
|
||||
sess.run(get_next))
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that an `OutOfRangeError` is raised on the first call to
|
||||
# `get_next_str_only` if result set is empty.
|
||||
@ -122,7 +126,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"WHERE first_name = 'Nonexistent'"
|
||||
})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that an error is raised when `driver_name` is invalid.
|
||||
def testReadResultSetWithInvalidDriverName(self):
|
||||
@ -151,7 +155,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
with self.assertRaises(errors.UnknownError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that an error is raised when there is a syntax error in `query`.
|
||||
def testReadResultSetOfQueryWithSyntaxError(self):
|
||||
@ -166,7 +170,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
with self.assertRaises(errors.UnknownError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that an error is raised when the number of columns in `query`
|
||||
# does not match the length of `output_types`.
|
||||
@ -181,7 +185,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that no results are returned when `query` is an insert query rather
|
||||
# than a select query. In particular, the error refers to the number of
|
||||
@ -199,7 +203,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
|
||||
})
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int8` tensor.
|
||||
@ -212,10 +216,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int8` tensor.
|
||||
@ -230,9 +234,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students "
|
||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int8` tensor.
|
||||
@ -246,11 +250,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT desk_number, favorite_negative_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((9, -2), sess.run(get_next))
|
||||
self.assertEqual((9, -2), self.evaluate(get_next))
|
||||
# Max and min values of int8
|
||||
self.assertEqual((127, -128), sess.run(get_next))
|
||||
self.assertEqual((127, -128), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int16` tensor.
|
||||
@ -263,10 +267,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int16` tensor.
|
||||
@ -281,9 +285,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students "
|
||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int16` tensor.
|
||||
@ -297,11 +301,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int16
|
||||
self.assertEqual((b"John", 32767), sess.run(get_next))
|
||||
self.assertEqual((b"John", 32767), self.evaluate(get_next))
|
||||
# Min value of int16
|
||||
self.assertEqual((b"Jane", -32768), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -32768), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in an `int32` tensor.
|
||||
@ -314,8 +318,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int32` tensor.
|
||||
@ -328,10 +332,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, income FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int32` tensor.
|
||||
@ -345,11 +349,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int32
|
||||
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
||||
self.assertEqual((b"John", 2147483647), self.evaluate(get_next))
|
||||
# Min value of int32
|
||||
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -2147483648), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
|
||||
# table and place it in an `int32` tensor.
|
||||
@ -362,10 +366,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, school_id FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 123), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 1000), sess.run(get_next))
|
||||
self.assertEqual((b"John", 123), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 1000), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||
# and place it in an `int64` tensor.
|
||||
@ -378,10 +382,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||
# SQLite database table and place it in an `int64` tensor.
|
||||
@ -394,10 +398,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, income FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||
# a SQLite database table and place it in an `int64` tensor.
|
||||
@ -412,11 +416,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Max value of int64
|
||||
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9223372036854775807), self.evaluate(get_next))
|
||||
# Min value of int64
|
||||
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", -9223372036854775808), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||
# place it in a `uint8` tensor.
|
||||
@ -429,10 +433,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
|
||||
# SQLite database table and place them in `uint8` tensors.
|
||||
@ -446,11 +450,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Min value of uint8
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
# Max value of uint8
|
||||
self.assertEqual((b"Jane", 255), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 255), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||
# and place it in a `uint16` tensor.
|
||||
@ -463,10 +467,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, desk_number FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
|
||||
# SQLite database table and place them in `uint16` tensors.
|
||||
@ -480,11 +484,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
# Min value of uint16
|
||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||
# Max value of uint16
|
||||
self.assertEqual((b"Jane", 65535), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", 65535), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
|
||||
# SQLite database table and place them as `True` and `False` respectively
|
||||
@ -499,10 +503,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, registration_complete FROM students "
|
||||
"ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", True), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", False), sess.run(get_next))
|
||||
self.assertEqual((b"John", True), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", False), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
|
||||
# from a SQLite database table and place it as `True` in a `bool` tensor.
|
||||
@ -515,10 +519,10 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||
"FROM students ORDER BY first_name DESC"
|
||||
})
|
||||
self.assertEqual((b"John", True), sess.run(get_next))
|
||||
self.assertEqual((b"Jane", True), sess.run(get_next))
|
||||
self.assertEqual((b"John", True), self.evaluate(get_next))
|
||||
self.assertEqual((b"Jane", True), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a float from a SQLite database table
|
||||
# and place it in a `float64` tensor.
|
||||
@ -533,10 +537,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"SELECT first_name, last_name, victories FROM townspeople "
|
||||
"ORDER BY first_name"
|
||||
})
|
||||
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
|
||||
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
|
||||
self.assertEqual((b"George", b"Washington", 20.0),
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual((b"John", b"Adams", -19.95), self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a float from a SQLite database table beyond
|
||||
# the precision of 64-bit IEEE, without throwing an error. Test that
|
||||
@ -555,13 +560,13 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
self.assertEqual(
|
||||
(b"George", b"Washington",
|
||||
1331241.321342132321324589798264627463827647382647382643874),
|
||||
sess.run(get_next))
|
||||
self.evaluate(get_next))
|
||||
self.assertEqual(
|
||||
(b"John", b"Adams",
|
||||
1331241321342132321324589798264627463827647382647382643874.0),
|
||||
sess.run(get_next))
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
# Test that `SqlDataset` can read a float from a SQLite database table,
|
||||
# representing the largest integer representable as a 64-bit IEEE float
|
||||
@ -579,11 +584,11 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
"ORDER BY first_name"
|
||||
})
|
||||
self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
|
||||
sess.run(get_next))
|
||||
self.evaluate(get_next))
|
||||
self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
|
||||
sess.run(get_next))
|
||||
self.evaluate(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.evaluate(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -70,18 +70,18 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
expected_sum = 0.0
|
||||
for i in range(100):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = sess.run(summary_t)
|
||||
np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
|
||||
expected_sum += i * 8.0
|
||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
summary_str = sess.run(summary_t)
|
||||
self.evaluate(next_element)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
|
||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||
|
||||
@ -95,14 +95,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(i + 1))
|
||||
self.evaluate(summary_t), "record_latency", float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
|
||||
self.evaluate(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 100.0)
|
||||
|
||||
def testPrefetchBufferUtilization(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -114,11 +115,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = sess.run(summary_t)
|
||||
np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||
float(i + 1))
|
||||
self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
|
||||
@ -126,8 +127,8 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
|
||||
0, 1)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
summary_str = sess.run(summary_t)
|
||||
self.evaluate(next_element)
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||
100)
|
||||
|
||||
@ -141,17 +142,17 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(10):
|
||||
self.assertAllEqual(
|
||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||
summary_str = sess.run(summary_t)
|
||||
np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
|
||||
summary_str = self.evaluate(summary_t)
|
||||
self._assertSummaryHasScalarValue(summary_str,
|
||||
"Prefetch::buffer_capacity", 0)
|
||||
self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
|
||||
0)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testFilteredElementsStats(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -163,20 +164,21 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(34):
|
||||
self.assertEqual(i * 3, sess.run(next_element))
|
||||
self.assertEqual(i * 3, self.evaluate(next_element))
|
||||
if i is not 0:
|
||||
self._assertSummaryHasScalarValue(
|
||||
sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
|
||||
self.evaluate(summary_t), "Filter::dropped_elements",
|
||||
float(i * 2))
|
||||
self._assertSummaryHasScalarValue(
|
||||
sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
|
||||
self.evaluate(summary_t), "Filter::filtered_elements", float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
self._assertSummaryHasScalarValue(
|
||||
sess.run(summary_t), "Filter::dropped_elements", 67.0)
|
||||
self.evaluate(summary_t), "Filter::dropped_elements", 67.0)
|
||||
self._assertSummaryHasScalarValue(
|
||||
sess.run(summary_t), "Filter::filtered_elements", 34.0)
|
||||
self.evaluate(summary_t), "Filter::filtered_elements", 34.0)
|
||||
|
||||
def testMapBufferUtilization(self, dataset_transformation):
|
||||
|
||||
@ -257,15 +259,16 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for j in range(5):
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
|
||||
self.evaluate(summary_t), "record_latency",
|
||||
float((j * 100) + i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", (j + 1) * 100.0)
|
||||
self.evaluate(summary_t), "record_latency", (j + 1) * 100.0)
|
||||
|
||||
def testNoAggregatorRegistered(self, dataset_transformation):
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
@ -274,11 +277,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testMultipleTags(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -291,18 +294,19 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(i + 1))
|
||||
self.evaluate(summary_t), "record_latency", float(i + 1))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency_2", float(i + 1))
|
||||
self.evaluate(summary_t), "record_latency_2", float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
|
||||
self.evaluate(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency_2", 100.0)
|
||||
self.evaluate(summary_t), "record_latency", 100.0)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency_2", 100.0)
|
||||
|
||||
def testRepeatedTags(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -315,14 +319,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for i in range(100):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
self.assertEqual(i, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
||||
self.evaluate(summary_t), "record_latency", float(2 * (i + 1)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
|
||||
self.evaluate(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 200.0)
|
||||
|
||||
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -335,14 +340,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run([iterator_0.initializer, iterator_1.initializer])
|
||||
self.evaluate([iterator_0.initializer, iterator_1.initializer])
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, sess.run(next_element))
|
||||
self.assertEqual(i * 2, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
||||
self.evaluate(summary_t), "record_latency", float(2 * (i + 1)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
|
||||
self.evaluate(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
self.evaluate(summary_t), "record_latency", 200.0)
|
||||
|
||||
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
|
||||
aggregator = stats_aggregator.StatsAggregator()
|
||||
@ -358,19 +364,19 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run([iterator_0.initializer, iterator_1.initializer])
|
||||
self.evaluate([iterator_0.initializer, iterator_1.initializer])
|
||||
for i in range(100):
|
||||
self.assertEqual(i * 2, sess.run(next_element))
|
||||
self.assertEqual(i * 2, self.evaluate(next_element))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "dataset1_record_latency", float(i + 1))
|
||||
self.evaluate(summary_t), "dataset1_record_latency", float(i + 1))
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "dataset2_record_latency", float(i + 1))
|
||||
self.evaluate(summary_t), "dataset2_record_latency", float(i + 1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "dataset1_record_latency", 100.0)
|
||||
self.evaluate(summary_t), "dataset1_record_latency", 100.0)
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "dataset2_record_latency", 100.0)
|
||||
self.evaluate(summary_t), "dataset2_record_latency", 100.0)
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -417,20 +423,21 @@ class FeatureStatsDatasetTest(
|
||||
summary_t = aggregator.get_summary()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for _ in range(num_output):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_stats_features", total_records)
|
||||
self.evaluate(summary_t), "record_stats_features", total_records)
|
||||
self._assertSummaryHasCount(
|
||||
sess.run(summary_t), "record_stats_feature-values", total_records)
|
||||
self.evaluate(summary_t), "record_stats_feature-values",
|
||||
total_records)
|
||||
self._assertSummaryHasSum(
|
||||
sess.run(summary_t), "record_stats_features", total_records * 4)
|
||||
self.evaluate(summary_t), "record_stats_features", total_records * 4)
|
||||
self._assertSummaryHasSum(
|
||||
sess.run(summary_t), "record_stats_feature-values",
|
||||
self.evaluate(summary_t), "record_stats_feature-values",
|
||||
self._sum_keywords(1) * num_epochs + 3 * total_records)
|
||||
|
||||
|
||||
|
@ -47,9 +47,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
||||
for i in range(4):
|
||||
self.assertEqual(i, sess.run(next_elem))
|
||||
self.assertEqual(i, self.evaluate(next_elem))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_elem)
|
||||
self.evaluate(next_elem)
|
||||
|
||||
def testUnbatchScalarDataset(self):
|
||||
data = tuple([math_ops.range(10) for _ in range(3)])
|
||||
@ -65,10 +65,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i,) * 3, sess.run(op))
|
||||
self.assertEqual((i,) * 3, self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
self.evaluate(op)
|
||||
|
||||
def testUnbatchDatasetWithStrings(self):
|
||||
data = tuple([math_ops.range(10) for _ in range(3)])
|
||||
@ -85,10 +85,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
self.evaluate(op)
|
||||
|
||||
def testUnbatchDatasetWithSparseTensor(self):
|
||||
st = sparse_tensor.SparseTensorValue(
|
||||
@ -104,12 +104,12 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
st_row = sess.run(next_element)
|
||||
st_row = self.evaluate(next_element)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
self.assertEqual([10], st_row.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testUnbatchDatasetWithDenseAndSparseTensor(self):
|
||||
st = sparse_tensor.SparseTensorValue(
|
||||
@ -125,13 +125,13 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
dense_elem, st_row = sess.run(next_element)
|
||||
dense_elem, st_row = self.evaluate(next_element)
|
||||
self.assertEqual(i, dense_elem)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
self.assertEqual([10], st_row.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testUnbatchSingleElementTupleDataset(self):
|
||||
data = tuple([(math_ops.range(10),) for _ in range(3)])
|
||||
@ -147,10 +147,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(((i,),) * 3, sess.run(op))
|
||||
self.assertEqual(((i,),) * 3, self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
self.evaluate(op)
|
||||
|
||||
def testUnbatchMultiElementTupleDataset(self):
|
||||
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
|
||||
@ -168,10 +168,10 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
|
||||
sess.run(op))
|
||||
self.evaluate(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
self.evaluate(op)
|
||||
|
||||
def testUnbatchEmpty(self):
|
||||
data = dataset_ops.Dataset.from_tensors(
|
||||
@ -183,7 +183,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testUnbatchStaticShapeMismatch(self):
|
||||
data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
|
||||
@ -208,7 +208,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ph2: np.arange(8).astype(np.int32)
|
||||
})
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
# No 0th dimension (i.e. scalar value) for one component.
|
||||
sess.run(
|
||||
@ -218,7 +218,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ph2: 7
|
||||
})
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -49,13 +49,13 @@ class UniqueTest(test_base.DatasetTestBase):
|
||||
with self.cached_session() as sess:
|
||||
for test_case, expected in test_cases:
|
||||
current_test_case = test_case
|
||||
sess.run(iterator.initializer)
|
||||
self.evaluate(iterator.initializer)
|
||||
for element in expected:
|
||||
if dtype == dtypes.string:
|
||||
element = compat.as_bytes(element)
|
||||
self.assertAllEqual(element, sess.run(next_element))
|
||||
self.assertAllEqual(element, self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testSimpleInt(self):
|
||||
for dtype in [dtypes.int32, dtypes.int64]:
|
||||
|
@ -41,7 +41,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
|
||||
def testBasic(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -51,13 +51,13 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
def testOneOnSameDevice(self):
|
||||
with ops.device("/cpu:0"):
|
||||
@ -68,13 +68,13 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
def testRepeatDevices(self):
|
||||
with ops.device("/cpu:0"):
|
||||
@ -86,17 +86,17 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 20, 4):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i + 2, sess.run(elem_on_3))
|
||||
self.assertEqual(i + 3, sess.run(elem_on_4))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(i + 2, self.evaluate(elem_on_3))
|
||||
self.assertEqual(i + 3, self.evaluate(elem_on_4))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
sess.run(elem_on_3)
|
||||
sess.run(elem_on_4)
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
self.evaluate(elem_on_3)
|
||||
self.evaluate(elem_on_4)
|
||||
|
||||
def testNotFullyDivisible(self):
|
||||
dataset = dataset_ops.Dataset.range(9)
|
||||
@ -106,14 +106,14 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(8, sess.run(elem_on_1))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
self.assertEqual(8, self.evaluate(elem_on_1))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
def testGetNextAsOptional(self):
|
||||
dataset = dataset_ops.Dataset.range(9)
|
||||
@ -127,7 +127,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
elem_on_1_has_value, elem_on_1_value = sess.run(
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
@ -141,12 +141,12 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
self.assertTrue(elem_on_1_has_value)
|
||||
self.assertEqual(8, elem_on_1_value)
|
||||
self.assertFalse(sess.run(elem_on_1_has_value_t))
|
||||
self.assertFalse(sess.run(elem_on_2_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_on_1_t)
|
||||
self.evaluate(elem_on_1_t)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_on_2_t)
|
||||
self.evaluate(elem_on_2_t)
|
||||
|
||||
def testUneven(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -156,14 +156,14 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
def testMultipleInitializations(self):
|
||||
with ops.device("/cpu:0"):
|
||||
@ -180,7 +180,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
with self.test_session(config=config) as sess:
|
||||
for i in range(1000):
|
||||
sess.run(init_op, feed_dict={epoch: i})
|
||||
self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
|
||||
self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
|
||||
elem_on_2]))
|
||||
|
||||
def testBasicGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -193,13 +194,13 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
def testUnevenGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -212,14 +213,14 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
def testGetNextAsOptionalGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -236,7 +237,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 8, 2):
|
||||
elem_on_1_has_value, elem_on_1_value = sess.run(
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
@ -250,12 +251,12 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
[elem_on_1_has_value_t, elem_on_1_t])
|
||||
self.assertTrue(elem_on_1_has_value)
|
||||
self.assertEqual(8, elem_on_1_value)
|
||||
self.assertFalse(sess.run(elem_on_1_has_value_t))
|
||||
self.assertFalse(sess.run(elem_on_2_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
|
||||
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_on_1_t)
|
||||
self.evaluate(elem_on_1_t)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(elem_on_2_t)
|
||||
self.evaluate(elem_on_2_t)
|
||||
|
||||
def testOptimization(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -273,13 +274,13 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
|
||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||
with self.test_session(config=config) as sess:
|
||||
sess.run(multi_device_iterator.initializer)
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
for i in range(0, 10, 2):
|
||||
self.assertEqual(i, sess.run(elem_on_1))
|
||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
||||
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elem_on_1)
|
||||
sess.run(elem_on_2)
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -30,47 +30,52 @@ class ConvertTest(test.TestCase):
|
||||
|
||||
def testInteger(self):
|
||||
resp = convert.optional_param_to_tensor("foo", 3)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(3, sess.run(resp))
|
||||
self.assertEqual(3, self.evaluate(resp))
|
||||
|
||||
def testIntegerDefault(self):
|
||||
resp = convert.optional_param_to_tensor("foo", None)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(0, sess.run(resp))
|
||||
self.assertEqual(0, self.evaluate(resp))
|
||||
|
||||
def testStringDefault(self):
|
||||
resp = convert.optional_param_to_tensor("bar", None, "default",
|
||||
dtypes.string)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
|
||||
self.assertEqual(compat.as_bytes("default"), self.evaluate(resp))
|
||||
|
||||
def testString(self):
|
||||
resp = convert.optional_param_to_tensor("bar", "value", "default",
|
||||
dtypes.string)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
|
||||
self.assertEqual(compat.as_bytes("value"), self.evaluate(resp))
|
||||
|
||||
def testPartialShapeToTensorKnownDimension(self):
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([1]))))
|
||||
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
|
||||
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor([1])))
|
||||
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
|
||||
constant_op.constant([1], dtype=dtypes.int64))))
|
||||
self.assertAllEqual([1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([1]))))
|
||||
self.assertAllEqual([1], self.evaluate(
|
||||
convert.partial_shape_to_tensor((1,))))
|
||||
self.assertAllEqual([1], self.evaluate(
|
||||
convert.partial_shape_to_tensor([1])))
|
||||
self.assertAllEqual([1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
constant_op.constant([1], dtype=dtypes.int64))))
|
||||
|
||||
def testPartialShapeToTensorUnknownDimension(self):
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([None]))))
|
||||
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
|
||||
(None,))))
|
||||
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
|
||||
[None])))
|
||||
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
|
||||
[-1])))
|
||||
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
|
||||
constant_op.constant([-1], dtype=dtypes.int64))))
|
||||
self.assertAllEqual([-1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([None]))))
|
||||
self.assertAllEqual([-1],
|
||||
self.evaluate(convert.partial_shape_to_tensor((None,))))
|
||||
self.assertAllEqual([-1],
|
||||
self.evaluate(convert.partial_shape_to_tensor([None])))
|
||||
self.assertAllEqual([-1],
|
||||
self.evaluate(convert.partial_shape_to_tensor([-1])))
|
||||
self.assertAllEqual([-1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
constant_op.constant([-1],
|
||||
dtype=dtypes.int64))))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r"The given shape .* must be a 1-D tensor of tf.int64 "
|
||||
@ -84,42 +89,63 @@ class ConvertTest(test.TestCase):
|
||||
convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
|
||||
|
||||
def testPartialShapeToTensorMultipleDimensions(self):
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([3, 6]))))
|
||||
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
|
||||
(3, 6))))
|
||||
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
|
||||
[3, 6])))
|
||||
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
|
||||
constant_op.constant([3, 6], dtype=dtypes.int64))))
|
||||
self.assertAllEqual([3, 6],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([3, 6]))))
|
||||
self.assertAllEqual([3, 6],
|
||||
self.evaluate(convert.partial_shape_to_tensor((3, 6))))
|
||||
self.assertAllEqual([3, 6],
|
||||
self.evaluate(convert.partial_shape_to_tensor([3, 6])))
|
||||
self.assertAllEqual([3, 6],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
constant_op.constant([3, 6],
|
||||
dtype=dtypes.int64))))
|
||||
|
||||
self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([3, None]))))
|
||||
self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
|
||||
(3, None))))
|
||||
self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
|
||||
[3, None])))
|
||||
self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
|
||||
constant_op.constant([3, -1], dtype=dtypes.int64))))
|
||||
self.assertAllEqual([3, -1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([3, None]))))
|
||||
self.assertAllEqual([3, -1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor((3, None))))
|
||||
self.assertAllEqual([3, -1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor([3, None])))
|
||||
self.assertAllEqual([3, -1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
constant_op.constant([3, -1],
|
||||
dtype=dtypes.int64))))
|
||||
|
||||
self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([None, None]))))
|
||||
self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
|
||||
(None, None))))
|
||||
self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
|
||||
[None, None])))
|
||||
self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
|
||||
constant_op.constant([-1, -1], dtype=dtypes.int64))))
|
||||
self.assertAllEqual([-1, -1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([None, None]))))
|
||||
self.assertAllEqual([-1, -1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor((None, None))))
|
||||
self.assertAllEqual([-1, -1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor([None, None])))
|
||||
self.assertAllEqual([-1, -1],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
constant_op.constant([-1, -1],
|
||||
dtype=dtypes.int64))))
|
||||
|
||||
def testPartialShapeToTensorScalar(self):
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([]))))
|
||||
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))
|
||||
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor([])))
|
||||
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
|
||||
constant_op.constant([], dtype=dtypes.int64))))
|
||||
self.assertAllEqual([],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
tensor_shape.TensorShape([]))))
|
||||
self.assertAllEqual([], self.evaluate(convert.partial_shape_to_tensor(())))
|
||||
self.assertAllEqual([], self.evaluate(convert.partial_shape_to_tensor([])))
|
||||
self.assertAllEqual([],
|
||||
self.evaluate(
|
||||
convert.partial_shape_to_tensor(
|
||||
constant_op.constant([], dtype=dtypes.int64))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1583,7 +1583,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
x = variables.VariableV1([1, 3, 3, 7], name="x")
|
||||
_, idx = array_ops.unique(x, name="x_unique")
|
||||
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
|
||||
sess.run(x.initializer)
|
||||
self.evaluate(x.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(
|
||||
|
@ -126,8 +126,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
u = variables.Variable([12.0], name="u")
|
||||
v = variables.Variable([30.0], name="v")
|
||||
w = math_ops.add(u, v, name="w")
|
||||
sess.run(u.initializer)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(u.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, w, expected_output=[42.0])
|
||||
@ -139,7 +139,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
b = math_ops.add(a, a, name="b")
|
||||
with ops.control_dependencies([a, b]):
|
||||
c = math_ops.multiply(b, b, name="c")
|
||||
sess.run(a.initializer)
|
||||
self.evaluate(a.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, c, expected_output=400.0)
|
||||
@ -150,8 +150,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
y = variables.Variable(20.0, name="y")
|
||||
cond = control_flow_ops.cond(
|
||||
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
|
||||
sess.run(x.initializer)
|
||||
sess.run(y.initializer)
|
||||
self.evaluate(x.initializer)
|
||||
self.evaluate(y.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(
|
||||
sess, cond, expected_output=21.0)
|
||||
@ -173,8 +173,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
||||
toy_loss = x * (u - v)
|
||||
train_op = gradient_descent.GradientDescentOptimizer(
|
||||
learning_rate=0.1).minimize(toy_loss, name="train_op")
|
||||
sess.run(u.initializer)
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(u.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
self._compareOriginalAndReconstructedGraphDefs(sess, train_op)
|
||||
|
||||
|
@ -67,7 +67,7 @@ class SessionDebugMultiGPUTest(test_util.TensorFlowTestCase):
|
||||
u1 = math_ops.multiply(v, v, name="u1")
|
||||
w = math_ops.subtract(u1, u0, name="w")
|
||||
|
||||
sess.run(v.initializer)
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(run_options, sess.graph,
|
||||
|
@ -109,8 +109,8 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
|
||||
self.w = math_ops.matmul(self.u, self.v, name="w")
|
||||
self.w_line_number = line_number_above()
|
||||
|
||||
sess.run(self.u.initializer)
|
||||
sess.run(self.v.initializer)
|
||||
self.evaluate(self.u.initializer)
|
||||
self.evaluate(self.v.initializer)
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(
|
||||
|
@ -92,9 +92,9 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
|
||||
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testTFRecordDataset(self):
|
||||
dataset = readers.TFRecordDataset(self._createTFRecordFiles())
|
||||
@ -138,10 +138,10 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
actual, expected = [], []
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
actual.append(sess.run(next_element))
|
||||
actual.append(self.evaluate(next_element))
|
||||
expected.append(self._record(r, f))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
self.assertAllEqual(expected, actual)
|
||||
|
||||
def testComplexPipeline(self):
|
||||
@ -171,9 +171,9 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
num_iterations = (self._num_files * self._num_records * num_epochs) // (
|
||||
self._num_shards * batch_size)
|
||||
for _ in range(num_iterations):
|
||||
actual.extend(sess.run(next_element))
|
||||
actual.extend(self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
expected = []
|
||||
for f in range(0, self._num_files, self._num_shards):
|
||||
@ -205,12 +205,13 @@ class AutoShardDatasetTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(self._record(r, f), sess.run(next_element))
|
||||
self.assertAllEqual(self._record(r, f), self.evaluate(next_element))
|
||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||
for r in range(self._num_records):
|
||||
self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
|
||||
self.assertAllEqual(
|
||||
self._text_line(r, f), self.evaluate(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
self.evaluate(next_element)
|
||||
|
||||
def testTextLineReader(self):
|
||||
dataset = readers.TextLineDataset(self._createTextFiles())
|
||||
|
@ -149,9 +149,9 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
result = fn(3.0)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual(sess.run(state[0]), 2.0)
|
||||
self.assertAllEqual(sess.run(result), 6.0)
|
||||
self.assertAllEqual(self.evaluate(result), 6.0)
|
||||
|
||||
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
@ -168,9 +168,9 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
result = fn(3.0)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual(sess.run(state[0]), 6.0)
|
||||
self.assertAllEqual(sess.run(result), 18.0)
|
||||
self.assertAllEqual(self.evaluate(result), 18.0)
|
||||
|
||||
def testLegacyGraphModeInputDependentInitializerFails(self):
|
||||
with ops.Graph().as_default():
|
||||
|
@ -78,7 +78,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
||||
c = constant_op.constant([[2.]])
|
||||
f_c = f(c)
|
||||
g, = gradients_impl.gradients(f_c, c)
|
||||
self.assertAllEqual(sess.run(g).values, [[1.0]])
|
||||
self.assertAllEqual(self.evaluate(g).values, [[1.0]])
|
||||
|
||||
def testNoSymGradNestedDefun(self):
|
||||
|
||||
|
@ -564,7 +564,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
variables.global_variables_initializer().run()
|
||||
call = def_function.function(o.call)
|
||||
op = call()
|
||||
self.assertAllEqual(sess.run(op), 2.0)
|
||||
self.assertAllEqual(self.evaluate(op), 2.0)
|
||||
|
||||
def testGraphModeManyFunctions(self):
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
@ -1732,7 +1732,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
function.register(cpu_boost, x)
|
||||
y = gpu_boost(x)
|
||||
y_value = sess.run(y)
|
||||
y_value = self.evaluate(y)
|
||||
|
||||
if test.is_gpu_available():
|
||||
self.assertEqual(y_value, 5.0)
|
||||
|
@ -1027,7 +1027,7 @@ class CrossedColumnTest(test.TestCase):
|
||||
outputs = _transform_features(features, [price_cross_wire])
|
||||
output = outputs[price_cross_wire]
|
||||
with self.cached_session() as sess:
|
||||
output_val = sess.run(output)
|
||||
output_val = self.evaluate(output)
|
||||
self.assertAllEqual(
|
||||
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
||||
for val in output_val.values:
|
||||
@ -1886,7 +1886,8 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc._numeric_column('price')
|
||||
@ -2525,7 +2526,8 @@ class _LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc._numeric_column('price')
|
||||
|
@ -1188,7 +1188,7 @@ class CrossedColumnTest(test.TestCase):
|
||||
outputs = fc._transform_features_v2(features, [price_cross_wire], None)
|
||||
output = outputs[price_cross_wire]
|
||||
with self.cached_session() as sess:
|
||||
output_val = sess.run(output)
|
||||
output_val = self.evaluate(output)
|
||||
self.assertAllEqual(
|
||||
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
||||
for val in output_val.values:
|
||||
@ -2088,7 +2088,8 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
@ -2124,7 +2125,8 @@ class LinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column('price')
|
||||
@ -2843,7 +2845,8 @@ class OldLinearModelTest(test.TestCase):
|
||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||
sess.run(bias.assign([5.]))
|
||||
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||
self.evaluate(net))
|
||||
|
||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||
price = fc.numeric_column('price')
|
||||
|
@ -42,7 +42,7 @@ class FileSystemTest(test.TestCase):
|
||||
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
|
||||
queue.enqueue_many([["test://foo"]]).run()
|
||||
queue.close().run()
|
||||
key, value = sess.run(reader.read(queue))
|
||||
key, value = self.evaluate(reader.read(queue))
|
||||
self.assertEqual(key, compat.as_bytes("test://foo"))
|
||||
self.assertEqual(value, compat.as_bytes("AAAAAAAAAA"))
|
||||
|
||||
|
@ -102,7 +102,7 @@ class FunctionTest(test.TestCase):
|
||||
call = MyIdentityFunc([18.0])
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
|
||||
def testIdentityImplicitDeref(self):
|
||||
|
||||
@ -116,8 +116,8 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
for cfg in _OptimizerOptions():
|
||||
with session.Session(config=cfg) as sess:
|
||||
sess.run(var.initializer)
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
self.evaluate(var.initializer)
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
|
||||
def testIdentityOutputName(self):
|
||||
|
||||
@ -130,7 +130,7 @@ class FunctionTest(test.TestCase):
|
||||
call = MyIdentityFunc([18.0])
|
||||
self.assertEqual("MyIdentity", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([18.0], sess.run(call))
|
||||
self.assertAllEqual([18.0], self.evaluate(call))
|
||||
|
||||
def testTooManyOutputNames(self):
|
||||
|
||||
@ -158,7 +158,7 @@ class FunctionTest(test.TestCase):
|
||||
call = APlus2B([1.0], [2.0])
|
||||
self.assertEqual("APlus2B", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([5.0], sess.run(call))
|
||||
self.assertAllEqual([5.0], self.evaluate(call))
|
||||
|
||||
def testFunctionWithNoOutput(self):
|
||||
|
||||
@ -187,7 +187,7 @@ class FunctionTest(test.TestCase):
|
||||
call = APlus2B([1.0], [2.0])
|
||||
self.assertEqual("APlus2B", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([5.0], sess.run(call))
|
||||
self.assertAllEqual([5.0], self.evaluate(call))
|
||||
|
||||
def testDefineFunctionDuplicateOutputs(self):
|
||||
|
||||
@ -224,8 +224,8 @@ class FunctionTest(test.TestCase):
|
||||
call_g = XSquarePlusOneGrad([2.0], [0.1])
|
||||
|
||||
with session.Session() as sess:
|
||||
self.assertAllClose([5.0], sess.run(call_f))
|
||||
self.assertAllClose([0.4], sess.run(call_g))
|
||||
self.assertAllClose([5.0], self.evaluate(call_f))
|
||||
self.assertAllClose([0.4], self.evaluate(call_g))
|
||||
|
||||
def testTanhSymGrad(self):
|
||||
|
||||
@ -365,7 +365,7 @@ class FunctionTest(test.TestCase):
|
||||
else:
|
||||
dx, dy = gradients_impl.gradients([z], [x, y])
|
||||
with session.Session() as sess:
|
||||
dx_val, dy_val = sess.run([dx, dy])
|
||||
dx_val, dy_val = self.evaluate([dx, dy])
|
||||
self.assertEqual([2.0], dx_val)
|
||||
self.assertEqual([0.0], dy_val)
|
||||
|
||||
@ -387,7 +387,7 @@ class FunctionTest(test.TestCase):
|
||||
call = AConstant()
|
||||
self.assertEqual("AConstant", call.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([42], sess.run(call))
|
||||
self.assertAllEqual([42], self.evaluate(call))
|
||||
|
||||
def testDefineFunctionNames(self):
|
||||
|
||||
@ -468,7 +468,7 @@ class FunctionTest(test.TestCase):
|
||||
|
||||
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
|
||||
|
||||
ans = sess.run(loop)
|
||||
ans = self.evaluate(loop)
|
||||
self.assertAllClose(ans, 131072.)
|
||||
|
||||
def testControlFlowStrictness(self):
|
||||
@ -650,8 +650,8 @@ class FunctionTest(test.TestCase):
|
||||
# pylint: enable=unexpected-keyword-arg
|
||||
self.assertEqual("next", call2.op.name)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([1], sess.run(call1))
|
||||
self.assertAllEqual([0], sess.run(call2))
|
||||
self.assertAllEqual([1], self.evaluate(call1))
|
||||
self.assertAllEqual([0], self.evaluate(call2))
|
||||
|
||||
def testNestedFunction(self):
|
||||
|
||||
@ -794,7 +794,7 @@ class FunctionTest(test.TestCase):
|
||||
y = Foo()
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
self.assertEqual(sess.run(y), 10)
|
||||
self.assertEqual(self.evaluate(y), 10)
|
||||
|
||||
def testCaptureInCond(self):
|
||||
g = ops.Graph()
|
||||
@ -809,8 +809,8 @@ class FunctionTest(test.TestCase):
|
||||
z = Foo(False)
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
self.assertEqual(sess.run(y), 1)
|
||||
self.assertEqual(sess.run(z), 2)
|
||||
self.assertEqual(self.evaluate(y), 1)
|
||||
self.assertEqual(self.evaluate(z), 2)
|
||||
|
||||
def testStableName(self):
|
||||
|
||||
@ -854,7 +854,7 @@ class FunctionTest(test.TestCase):
|
||||
z = Bar(x)
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
v0, v1 = sess.run([y, z])
|
||||
v0, v1 = self.evaluate([y, z])
|
||||
self.assertAllEqual(v0, 20.)
|
||||
self.assertAllEqual(v1, 20.)
|
||||
|
||||
@ -900,7 +900,7 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual(global_vars[0].name, "linear/w:0")
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
output_val = sess.run(
|
||||
output_op, feed_dict={input_op: np.random.rand(32, 100)})
|
||||
self.assertEqual(output_val.shape, (32, 100))
|
||||
@ -928,7 +928,7 @@ class FunctionTest(test.TestCase):
|
||||
self.assertEqual(global_vars[0].name, "vs1/var:0")
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
out1, out2 = sess.run(
|
||||
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
|
||||
self.assertAllEqual(out1, np.linspace(2, 11, 10))
|
||||
@ -991,8 +991,8 @@ class FunctionTest(test.TestCase):
|
||||
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
|
||||
|
||||
with session.Session() as sess:
|
||||
self.assertEqual(4.0, sess.run(result_1))
|
||||
self.assertEqual(100, sess.run(result_2))
|
||||
self.assertEqual(4.0, self.evaluate(result_1))
|
||||
self.assertEqual(100, self.evaluate(result_2))
|
||||
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
|
||||
|
||||
def testStatefulFunction(self):
|
||||
@ -1052,8 +1052,8 @@ class FunctionTest(test.TestCase):
|
||||
for config in _OptimizerOptions():
|
||||
config.device_count["CPU"] = 2
|
||||
with session.Session(config=config) as sess:
|
||||
self.assertEqual(42.0, sess.run(f_0))
|
||||
self.assertEqual(44.0, sess.run(f_1))
|
||||
self.assertEqual(42.0, self.evaluate(f_0))
|
||||
self.assertEqual(44.0, self.evaluate(f_1))
|
||||
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
|
||||
|
||||
def testGuaranteedConstsAreCaptured(self):
|
||||
@ -1076,7 +1076,7 @@ class FunctionTest(test.TestCase):
|
||||
return output
|
||||
|
||||
with self.session(use_gpu=False) as sess:
|
||||
sess.run(var.initializer)
|
||||
self.evaluate(var.initializer)
|
||||
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
|
||||
|
||||
def testSameFunctionDifferentGrads(self):
|
||||
@ -1127,7 +1127,7 @@ class FunctionTest(test.TestCase):
|
||||
dx2, = gradients_impl.gradients(ys=[y2], xs=[x2])
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
v0, v1, v2 = sess.run([dx0, dx1, dx2])
|
||||
v0, v1, v2 = self.evaluate([dx0, dx1, dx2])
|
||||
|
||||
self.assertAllEqual(v0, 2.)
|
||||
self.assertAllEqual(v1, 101.)
|
||||
@ -1532,7 +1532,7 @@ class UnrollLSTMTest(test.TestCase):
|
||||
tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
|
||||
len(str(gdef)), len(gdef.SerializeToString()))
|
||||
with g.as_default(), session.Session(config=cfg) as sess:
|
||||
return sess.run(m)
|
||||
return self.evaluate(m)
|
||||
|
||||
mv0 = RunForward("complete")
|
||||
for cfg in _OptimizerOptions():
|
||||
@ -1561,7 +1561,7 @@ class UnrollLSTMTest(test.TestCase):
|
||||
tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
|
||||
len(str(gdef)), len(gdef.SerializeToString()))
|
||||
with g.as_default(), session.Session(config=cfg) as sess:
|
||||
return sess.run(dw)
|
||||
return self.evaluate(dw)
|
||||
|
||||
d0 = RunForwardBackward("complete")
|
||||
for cfg in _OptimizerOptions():
|
||||
@ -1651,8 +1651,8 @@ class ModuleFunctionTest(test.TestCase):
|
||||
y = LinearWithCApi(a, b, c)
|
||||
z = Linear2WithCApi(a, b, c, d, e)
|
||||
with session.Session() as sess:
|
||||
self.assertAllEqual([[1]], sess.run(y))
|
||||
self.assertAllEqual([[5]], sess.run(z))
|
||||
self.assertAllEqual([[1]], self.evaluate(y))
|
||||
self.assertAllEqual([[5]], self.evaluate(z))
|
||||
|
||||
|
||||
class VariableHoistingTest(test.TestCase):
|
||||
@ -1704,8 +1704,8 @@ class VariableHoistingTest(test.TestCase):
|
||||
self.assertEqual("Foo/b", b.op.name)
|
||||
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
w, b, x, y0, loss, dw, db = self.evaluate([w, b, x, y0, loss, dw, db])
|
||||
|
||||
self.assertAllEqual(w.shape, (64, 64))
|
||||
self.assertAllClose(np.sum(w), 2050.44)
|
||||
|
@ -210,8 +210,8 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
|
||||
with session.Session() as sess:
|
||||
init = variables.variables_initializer([variable_node])
|
||||
sess.run(init)
|
||||
output = sess.run(output_node)
|
||||
self.evaluate(init)
|
||||
output = self.evaluate(output_node)
|
||||
self.assertNear(4.0, output, 0.00001)
|
||||
variable_graph_def = sess.graph.as_graph_def()
|
||||
|
||||
@ -242,8 +242,8 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
output_node = math_ops_lib.multiply(
|
||||
variable_node, 2.0, name="output_node")
|
||||
with session.Session() as sess:
|
||||
sess.run(variable_node.initializer)
|
||||
output = sess.run(output_node)
|
||||
self.evaluate(variable_node.initializer)
|
||||
output = self.evaluate(output_node)
|
||||
self.assertNear(2.0, output, 0.00001)
|
||||
variable_graph_def = sess.graph.as_graph_def()
|
||||
# First get the constant_graph_def when variable_names_whitelist is
|
||||
@ -256,7 +256,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
|
||||
# Then initialize the unused variable, and get another
|
||||
# constant_graph_def when variable_names_whitelist is not set.
|
||||
sess.run(another_variable.initializer)
|
||||
self.evaluate(another_variable.initializer)
|
||||
constant_graph_def_without_variable_whitelist = (
|
||||
graph_util.convert_variables_to_constants(
|
||||
sess, variable_graph_def, ["output_node"]))
|
||||
@ -295,7 +295,7 @@ class DeviceFunctionsTest(test.TestCase):
|
||||
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||
with session.Session() as sess:
|
||||
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
||||
output = sess.run(output_node)
|
||||
output = self.evaluate(output_node)
|
||||
self.assertNear(2.0, output, 0.00001)
|
||||
|
||||
def create_node_def(self, op, name, inputs):
|
||||
|
@ -397,11 +397,11 @@ class ImportGraphDefTest(test.TestCase):
|
||||
# Run the imported graph.
|
||||
# TODO(b/76173421): make this work (currently DCHECKS)
|
||||
# with self.cached_session() as sess:
|
||||
# sess.run(imported_init)
|
||||
# self.assertEqual(sess.run(imported_var), 1.0)
|
||||
# self.assertEqual(sess.run(imported_assign), 2.0)
|
||||
# self.assertEqual(list(sess.run(imported_shape)), [])
|
||||
# self.assertEqual(list(sess.run(new_var_shape)), [])
|
||||
# self.evaluate(imported_init)
|
||||
# self.assertEqual(self.evaluate(imported_var), 1.0)
|
||||
# self.assertEqual(self.evaluate(imported_assign), 2.0)
|
||||
# self.assertEqual(list(self.evaluate(imported_shape)), [])
|
||||
# self.assertEqual(list(self.evaluate(new_var_shape)), [])
|
||||
|
||||
def testWhileLoop(self):
|
||||
# Produce GraphDef containing while loop.
|
||||
@ -418,7 +418,7 @@ class ImportGraphDefTest(test.TestCase):
|
||||
return_elements=[r.name])
|
||||
self.assertEqual(imported_r.name, "import/" + r.name)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(imported_r), 10)
|
||||
self.assertEqual(self.evaluate(imported_r), 10)
|
||||
|
||||
def testImportWhileLoopInCond(self):
|
||||
# Produce GraphDef containing while loop.
|
||||
@ -458,7 +458,7 @@ class ImportGraphDefTest(test.TestCase):
|
||||
lambda i: i < 2, ImportFn, [0],
|
||||
shape_invariants=[tensor_shape.TensorShape(None)])
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(out), 10)
|
||||
self.assertEqual(self.evaluate(out), 10)
|
||||
|
||||
def testTypeMismatchInGraphDef(self):
|
||||
# TODO(skyewm): improve error message
|
||||
|
@ -492,8 +492,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
init_op = variables.global_variables_initializer()
|
||||
grad = gradients_impl.gradients([output], [var])
|
||||
with session.Session() as sess:
|
||||
sess.run(init_op)
|
||||
expected_grad_value = sess.run(grad)
|
||||
self.evaluate(init_op)
|
||||
expected_grad_value = self.evaluate(grad)
|
||||
|
||||
# Restore the MetaGraphDef into a new Graph with an import scope.
|
||||
with ops.Graph().as_default():
|
||||
@ -518,8 +518,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
init_op = variables.global_variables_initializer()
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(init_op)
|
||||
actual_grad_value = sess.run(grad)
|
||||
self.evaluate(init_op)
|
||||
actual_grad_value = self.evaluate(grad)
|
||||
self.assertEqual(expected_grad_value, actual_grad_value)
|
||||
|
||||
def testImportWhileLoopInWhileLoop(self):
|
||||
@ -544,8 +544,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
_, x = control_flow_ops.while_loop(lambda i, x: i < 2, body, [0, 0.0],
|
||||
name="")
|
||||
with session.Session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(x)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(x)
|
||||
|
||||
def testScopedImportUnderNameScope(self):
|
||||
graph = ops.Graph()
|
||||
@ -868,8 +868,8 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
|
||||
_, update_op = metrics.mean(values)
|
||||
|
||||
initializer = variables.local_variables_initializer()
|
||||
sess.run(initializer)
|
||||
sess.run(update_op)
|
||||
self.evaluate(initializer)
|
||||
self.evaluate(update_op)
|
||||
|
||||
meta_graph.export_scoped_meta_graph(
|
||||
filename=meta_graph_filename, graph=graph)
|
||||
@ -880,7 +880,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
|
||||
with self.session(graph=graph) as sess:
|
||||
meta_graph.import_scoped_meta_graph(meta_graph_filename)
|
||||
initializer = variables.local_variables_initializer()
|
||||
sess.run(initializer)
|
||||
self.evaluate(initializer)
|
||||
|
||||
# Verifies that importing an old meta_graph where "local_variables"
|
||||
# collection is of node_list type works, but cannot build initializer
|
||||
|
@ -503,7 +503,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Graph is invalid, contains a cycle with 2 nodes"):
|
||||
sess.run(x)
|
||||
self.evaluate(x)
|
||||
|
||||
def testUpdateInput(self):
|
||||
g = ops.Graph()
|
||||
@ -517,21 +517,21 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
self.assertEquals(x.consumers(), [])
|
||||
self.assertEquals(y.consumers(), [z.op, z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 4)
|
||||
self.assertEquals(self.evaluate(z), 4)
|
||||
|
||||
z.op._update_input(0, x) # pylint: disable=protected-access
|
||||
self.assertEquals(list(z.op.inputs), [x, y])
|
||||
self.assertEquals(x.consumers(), [z.op])
|
||||
self.assertEquals(y.consumers(), [z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 3)
|
||||
self.assertEquals(self.evaluate(z), 3)
|
||||
|
||||
z.op._update_input(1, y) # pylint: disable=protected-access
|
||||
self.assertEquals(list(z.op.inputs), [x, y])
|
||||
self.assertEquals(x.consumers(), [z.op])
|
||||
self.assertEquals(y.consumers(), [z.op])
|
||||
with session.Session(graph=g) as sess:
|
||||
self.assertEquals(sess.run(z), 3)
|
||||
self.assertEquals(self.evaluate(z), 3)
|
||||
|
||||
def testUpdateInputGraphError(self):
|
||||
g_0 = ops.Graph()
|
||||
@ -557,7 +557,7 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
errors.InvalidArgumentError,
|
||||
"Input 0 of node add was passed string from Const_1:0 incompatible "
|
||||
"with expected int32"):
|
||||
sess.run(z)
|
||||
self.evaluate(z)
|
||||
|
||||
def testUpdateInputShapeError(self):
|
||||
g = ops.Graph()
|
||||
@ -2390,7 +2390,7 @@ class GraphTest(test_util.TensorFlowTestCase):
|
||||
c = math_ops.add(a, b)
|
||||
# Create a session we can delete
|
||||
with session.Session(graph=g) as sess:
|
||||
sess.run(c)
|
||||
self.evaluate(c)
|
||||
# Delete all references and trigger gc
|
||||
del g
|
||||
del a
|
||||
@ -2406,7 +2406,7 @@ class GraphTest(test_util.TensorFlowTestCase):
|
||||
math_ops.add([1, 2], [1, 2, 3])
|
||||
a = constant_op.constant(1)
|
||||
with session.Session() as sess:
|
||||
sess.run(a)
|
||||
self.evaluate(a)
|
||||
|
||||
def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
|
||||
g = ops.Graph()
|
||||
@ -2416,7 +2416,7 @@ class GraphTest(test_util.TensorFlowTestCase):
|
||||
test_ops.kernel_label_required(1)
|
||||
a = constant_op.constant(1)
|
||||
with session.Session() as sess:
|
||||
sess.run(a)
|
||||
self.evaluate(a)
|
||||
|
||||
|
||||
class AttrScopeTest(test_util.TensorFlowTestCase):
|
||||
|
@ -109,8 +109,8 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
|
||||
exclusive=True)
|
||||
with session.Session() as sess:
|
||||
# No feed_dict necessary
|
||||
self.assertEqual(sess.run(y), 1)
|
||||
self.assertEqual(sess.run(z), 1)
|
||||
self.assertEqual(self.evaluate(y), 1)
|
||||
self.assertEqual(self.evaluate(z), 1)
|
||||
|
||||
def testFalse(self):
|
||||
conditions = [(False, raise_exception)]
|
||||
@ -121,8 +121,8 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
|
||||
default=lambda: constant_op.constant(1),
|
||||
exclusive=True)
|
||||
with session.Session() as sess:
|
||||
self.assertEqual(sess.run(y), 1)
|
||||
self.assertEqual(sess.run(z), 1)
|
||||
self.assertEqual(self.evaluate(y), 1)
|
||||
self.assertEqual(self.evaluate(z), 1)
|
||||
|
||||
def testMix(self):
|
||||
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
|
||||
|
@ -50,7 +50,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(indices, value.indices)
|
||||
self.assertAllEqual(values, value.values)
|
||||
self.assertAllEqual(shape, value.dense_shape)
|
||||
sess_run_value = sess.run(sp)
|
||||
sess_run_value = self.evaluate(sp)
|
||||
self.assertAllEqual(sess_run_value.indices, value.indices)
|
||||
self.assertAllEqual(sess_run_value.values, value.values)
|
||||
self.assertAllEqual(sess_run_value.dense_shape, value.dense_shape)
|
||||
|
@ -66,9 +66,9 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(c.op in d.op.control_inputs)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
c_out = sess.run([c])
|
||||
n_out = sess.run([n])
|
||||
d_out = sess.run([d])
|
||||
c_out = self.evaluate([c])
|
||||
n_out = self.evaluate([n])
|
||||
d_out = self.evaluate([d])
|
||||
|
||||
self.assertEqual(n_out, [-2])
|
||||
self.assertEqual(c_out, [2])
|
||||
@ -145,8 +145,8 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
||||
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
c_out = sess.run([c])
|
||||
d_out = sess.run([d])
|
||||
c_out = self.evaluate([c])
|
||||
d_out = self.evaluate([d])
|
||||
|
||||
self.assertEqual(c_out, [42])
|
||||
self.assertEqual(d_out, [11])
|
||||
@ -205,7 +205,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Expect the three side effect graphs to have been evaluated.
|
||||
with self.cached_session() as sess:
|
||||
sess.run([c_sub])
|
||||
self.evaluate([c_sub])
|
||||
self.assertIn('graph1', shared)
|
||||
self.assertIn('graph2', shared)
|
||||
self.assertIn('graph3', shared)
|
||||
@ -229,20 +229,20 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# Initialize the variables first.
|
||||
sess.run([v1.initializer])
|
||||
sess.run([v2.initializer])
|
||||
self.evaluate([v1.initializer])
|
||||
self.evaluate([v2.initializer])
|
||||
|
||||
# Expect the side effects to be triggered when evaluating the add op as
|
||||
# it will read the value of the variable.
|
||||
sess.run([add])
|
||||
self.evaluate([add])
|
||||
self.assertEqual(1, len(shared))
|
||||
|
||||
# Expect the side effect not to be triggered when evaluating the assign
|
||||
# op as it will not access the 'read' output of the variable.
|
||||
sess.run([assign_v1])
|
||||
self.evaluate([assign_v1])
|
||||
self.assertEqual(1, len(shared))
|
||||
|
||||
sess.run([add])
|
||||
self.evaluate([add])
|
||||
self.assertEqual(2, len(shared))
|
||||
|
||||
# Make sure the values read from the variable match the expected ones.
|
||||
@ -273,7 +273,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
||||
self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run([reader])
|
||||
self.evaluate([reader])
|
||||
self.assertEqual(0, len(shared))
|
||||
|
||||
def testMultipleOutputs(self):
|
||||
@ -304,7 +304,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
||||
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run([neg])
|
||||
self.evaluate([neg])
|
||||
|
||||
# All three ops have been processed.
|
||||
self.assertEqual(3, len(shared))
|
||||
@ -375,7 +375,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
||||
self.assertIsNot(context(subscriptions[0]), context(subscriptions[1]))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(cond)
|
||||
self.evaluate(cond)
|
||||
|
||||
self.assertEqual(3, len(results))
|
||||
|
||||
|
@ -771,7 +771,7 @@ class TensorUtilTest(test.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
ma = MockArray(np.array([10, 20, 30]))
|
||||
t = ops.convert_to_tensor(ma)
|
||||
a = sess.run(t)
|
||||
a = self.evaluate(t)
|
||||
self.assertEquals(np.int64, a.dtype)
|
||||
self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a)
|
||||
|
||||
|
@ -61,7 +61,7 @@ class ConstantFoldingTest(test.TestCase):
|
||||
back_prop=False,
|
||||
parallel_iterations=1)
|
||||
with session.Session() as sess:
|
||||
y_v = sess.run(y)
|
||||
y_v = self.evaluate(y)
|
||||
self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
|
||||
|
||||
|
||||
|
@ -241,7 +241,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
if restore:
|
||||
saver.restore(sess, checkpoint_path)
|
||||
else:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
np.random.seed(0)
|
||||
for _ in range(2):
|
||||
@ -262,7 +262,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = _two_layer_model(x)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -365,7 +365,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(pad)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -396,7 +396,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(reduce_sum)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -425,7 +425,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(cast)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -456,7 +456,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(squeeze)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -486,7 +486,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(squeeze)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -516,7 +516,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(squeeze)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -545,7 +545,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(reduce_sum)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -574,7 +574,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(reduce_sum)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -603,7 +603,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(reduce_sum)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -632,7 +632,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(reduce_sum)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -662,7 +662,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(reduce_sum)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -691,7 +691,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(reduce_sum)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -724,7 +724,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(concat)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -835,7 +835,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(reverse)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -905,7 +905,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(select)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -966,7 +966,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(select)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -1179,7 +1179,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(s)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -1214,7 +1214,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = array_ops.identity(s)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -1347,7 +1347,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = _loop()
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -1374,7 +1374,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = _loop_with_branch()
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -1398,7 +1398,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = _loop_with_vec_and_4d()
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
@ -1422,7 +1422,7 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
output = _model_with_second_port()
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
output_val_ref = self.evaluate(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
|
@ -231,10 +231,10 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
|
||||
train_op = graph.get_operation_by_name(train_op_name)
|
||||
loss_op = graph.get_tensor_by_name(loss_op_name)
|
||||
with session.Session(config=config, graph=graph) as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(train_op)
|
||||
sess.run(train_op)
|
||||
return sess.run(loss_op)
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(train_op)
|
||||
self.evaluate(train_op)
|
||||
return self.evaluate(loss_op)
|
||||
|
||||
def testRecomputationRewritingNoErrors(self):
|
||||
"""Tests that graph output is not significantly different with rewriting."""
|
||||
@ -295,8 +295,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
|
||||
rewrite_options=manual_memory_config)
|
||||
session_config = config_pb2.ConfigProto(graph_options=graph_options)
|
||||
with session.Session(config=session_config) as sess:
|
||||
sess.run(init_op)
|
||||
sess.run(train_op)
|
||||
self.evaluate(init_op)
|
||||
self.evaluate(train_op)
|
||||
|
||||
def testHintDoesRewrite(self):
|
||||
graph = self._annotated_graph()[0]
|
||||
|
@ -136,7 +136,7 @@ class BackendUtilsTest(test.TestCase):
|
||||
x = keras.Input((3,))
|
||||
y = keras.layers.BatchNormalization()(x)
|
||||
if not context.executing_eagerly():
|
||||
sess.run(variables.global_variables_initializer())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
sess.run(y, feed_dict={x: np.random.random((2, 3))})
|
||||
|
||||
def test_learning_phase_scope(self):
|
||||
|
@ -1013,8 +1013,8 @@ class RNNTest(test.TestCase):
|
||||
inputs, _ = cell(inputs, initial_state)
|
||||
output = inputs
|
||||
if not context.executing_eagerly():
|
||||
sess.run(variables_lib.global_variables_initializer())
|
||||
output = sess.run(output)
|
||||
self.evaluate(variables_lib.global_variables_initializer())
|
||||
output = self.evaluate(output)
|
||||
return output
|
||||
|
||||
random_seed.set_random_seed(12345)
|
||||
|
@ -322,19 +322,19 @@ class KerasMetricsTest(test.TestCase):
|
||||
m = metrics.Mean()
|
||||
v = array_ops.placeholder(dtypes.float32)
|
||||
w = array_ops.placeholder(dtypes.float32)
|
||||
sess.run(variables.variables_initializer(m.variables))
|
||||
self.evaluate(variables.variables_initializer(m.variables))
|
||||
|
||||
# check __call__()
|
||||
result_t = m(v, sample_weight=w)
|
||||
result = sess.run(result_t, feed_dict=({v: 100, w: 0.5}))
|
||||
self.assertEqual(sess.run(m.total), 50)
|
||||
self.assertEqual(sess.run(m.count), 0.5)
|
||||
self.assertEqual(self.evaluate(m.total), 50)
|
||||
self.assertEqual(self.evaluate(m.count), 0.5)
|
||||
self.assertEqual(result, 50 / 0.5)
|
||||
|
||||
# check update_state() and result()
|
||||
result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]}))
|
||||
self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2
|
||||
self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2
|
||||
self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2
|
||||
self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2
|
||||
self.assertAlmostEqual(result, 52 / 1.7, 2)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user