Replace a few calls of Session run
with evaluate
In order to support tests running in eager mode we need to avoid unnecessary use of Sessions in tests. This moves to remove some of the uses of the `run` function in favor of `evaluate`. PiperOrigin-RevId: 222013881
This commit is contained in:
parent
1aaa68d93c
commit
1fdd7c7408
@ -60,7 +60,7 @@ class CategoricalTest(xla_test.XLATestCase):
|
|||||||
random_seed.set_random_seed(1618)
|
random_seed.set_random_seed(1618)
|
||||||
op = random_ops.multinomial(logits, num_samples,
|
op = random_ops.multinomial(logits, num_samples,
|
||||||
output_dtype=dtypes.int32)
|
output_dtype=dtypes.int32)
|
||||||
d = sess.run(op)
|
d = self.evaluate(op)
|
||||||
|
|
||||||
batch_size, num_classes = logits.shape
|
batch_size, num_classes = logits.shape
|
||||||
freqs_mat = []
|
freqs_mat = []
|
||||||
@ -85,9 +85,9 @@ class CategoricalTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
# The random-number generator, if working correctly, should produce the
|
# The random-number generator, if working correctly, should produce the
|
||||||
# same output multiple times with low probability.
|
# same output multiple times with low probability.
|
||||||
y = sess.run(x)
|
y = self.evaluate(x)
|
||||||
z = sess.run(x)
|
z = self.evaluate(x)
|
||||||
w = sess.run(x)
|
w = self.evaluate(x)
|
||||||
|
|
||||||
# We use exact equality here. If the random-number generator is producing
|
# We use exact equality here. If the random-number generator is producing
|
||||||
# deterministic output, all three outputs will be bitwise identical.
|
# deterministic output, all three outputs will be bitwise identical.
|
||||||
@ -112,7 +112,7 @@ class CategoricalTest(xla_test.XLATestCase):
|
|||||||
x = random_ops.multinomial(
|
x = random_ops.multinomial(
|
||||||
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
|
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
|
||||||
output_dtype=output_dtype)
|
output_dtype=output_dtype)
|
||||||
y = sess.run(x)
|
y = self.evaluate(x)
|
||||||
self.assertTrue((y >= 0).sum() == 1000)
|
self.assertTrue((y >= 0).sum() == 1000)
|
||||||
self.assertTrue((y < 20).sum() == 1000)
|
self.assertTrue((y < 20).sum() == 1000)
|
||||||
|
|
||||||
|
@ -337,7 +337,7 @@ class ConcatOffsetTest(xla_test.XLATestCase):
|
|||||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||||
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
|
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
|
||||||
ans = sess.run(off)
|
ans = self.evaluate(off)
|
||||||
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
|
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
|
||||||
|
|
||||||
|
|
||||||
@ -350,7 +350,7 @@ class PackTest(xla_test.XLATestCase):
|
|||||||
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
|
||||||
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
|
||||||
packed = array_ops.stack([s0, s1, s2])
|
packed = array_ops.stack([s0, s1, s2])
|
||||||
ans = sess.run(packed)
|
ans = self.evaluate(packed)
|
||||||
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
|
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
|
||||||
|
|
||||||
def testScalars(self):
|
def testScalars(self):
|
||||||
@ -360,7 +360,7 @@ class PackTest(xla_test.XLATestCase):
|
|||||||
s1 = constant_op.constant(3, dtypes.int32)
|
s1 = constant_op.constant(3, dtypes.int32)
|
||||||
s2 = constant_op.constant(5, dtypes.int32)
|
s2 = constant_op.constant(5, dtypes.int32)
|
||||||
packed = array_ops.stack([s0, s1, s2])
|
packed = array_ops.stack([s0, s1, s2])
|
||||||
ans = sess.run(packed)
|
ans = self.evaluate(packed)
|
||||||
self.assertAllEqual(ans, [2, 3, 5])
|
self.assertAllEqual(ans, [2, 3, 5])
|
||||||
|
|
||||||
def testEmpty(self):
|
def testEmpty(self):
|
||||||
@ -370,7 +370,7 @@ class PackTest(xla_test.XLATestCase):
|
|||||||
s1 = constant_op.constant([[]], dtypes.int32)
|
s1 = constant_op.constant([[]], dtypes.int32)
|
||||||
s2 = constant_op.constant([[]], dtypes.int32)
|
s2 = constant_op.constant([[]], dtypes.int32)
|
||||||
packed = array_ops.stack([s0, s1, s2])
|
packed = array_ops.stack([s0, s1, s2])
|
||||||
ans = sess.run(packed)
|
ans = self.evaluate(packed)
|
||||||
self.assertAllEqual(ans, [[[]], [[]], [[]]])
|
self.assertAllEqual(ans, [[[]], [[]], [[]]])
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ class EagerTest(xla_test.XLATestCase):
|
|||||||
three = constant_op.constant(3)
|
three = constant_op.constant(3)
|
||||||
five = constant_op.constant(5)
|
five = constant_op.constant(5)
|
||||||
product = three * five
|
product = three * five
|
||||||
self.assertAllEqual(15, sess.run(product))
|
self.assertAllEqual(15, self.evaluate(product))
|
||||||
|
|
||||||
def testDegenerateSlices(self):
|
def testDegenerateSlices(self):
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
|
@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||||||
b = constant_op.constant(bval, name="b")
|
b = constant_op.constant(bval, name="b")
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
call_f = Foo(a, b)
|
call_f = Foo(a, b)
|
||||||
result = sess.run(call_f)
|
result = self.evaluate(call_f)
|
||||||
self.assertAllClose(result, expected, rtol=1e-3)
|
self.assertAllClose(result, expected, rtol=1e-3)
|
||||||
|
|
||||||
def testNestedFunctions(self):
|
def testNestedFunctions(self):
|
||||||
@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||||||
b = constant_op.constant(bval, name="b")
|
b = constant_op.constant(bval, name="b")
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
call_g = Foo(a, b)
|
call_g = Foo(a, b)
|
||||||
result = sess.run(call_g)
|
result = self.evaluate(call_g)
|
||||||
self.assertAllClose(result, expected, rtol=1e-3)
|
self.assertAllClose(result, expected, rtol=1e-3)
|
||||||
|
|
||||||
def testFunctionMultipleRetvals(self):
|
def testFunctionMultipleRetvals(self):
|
||||||
@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
|
|||||||
b = constant_op.constant(bval, name="b")
|
b = constant_op.constant(bval, name="b")
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
call_f = Foo(a, b)
|
call_f = Foo(a, b)
|
||||||
result = sess.run(call_f)
|
result = self.evaluate(call_f)
|
||||||
self.assertAllClose(result, expected, rtol=1e-3)
|
self.assertAllClose(result, expected, rtol=1e-3)
|
||||||
|
|
||||||
def testCompileTimeConstantsInDefun(self):
|
def testCompileTimeConstantsInDefun(self):
|
||||||
|
@ -88,7 +88,7 @@ class LSTMTest(test.TestCase):
|
|||||||
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
|
(basename, m_prev_scalar, c_prev_scalar, pad_scalar))
|
||||||
|
|
||||||
# Initialize variables and run the unrolled LSTM step.
|
# Initialize variables and run the unrolled LSTM step.
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
return sess.run([m, c])
|
return sess.run([m, c])
|
||||||
|
|
||||||
def testLSTMCell(self):
|
def testLSTMCell(self):
|
||||||
@ -173,7 +173,7 @@ class LSTMTest(test.TestCase):
|
|||||||
(basename, m_init_scalar, c_init_scalar, pad_scalar))
|
(basename, m_init_scalar, c_init_scalar, pad_scalar))
|
||||||
|
|
||||||
# Initialize variables and run the unrolled LSTM layer.
|
# Initialize variables and run the unrolled LSTM layer.
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
return sess.run(out_seq)
|
return sess.run(out_seq)
|
||||||
|
|
||||||
def testLSTMLayer(self):
|
def testLSTMLayer(self):
|
||||||
|
@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase):
|
|||||||
ph = array_ops.placeholder_with_default(v, shape=[])
|
ph = array_ops.placeholder_with_default(v, shape=[])
|
||||||
out = ph * 2
|
out = ph * 2
|
||||||
sess.run(variables.variables_initializer([v]))
|
sess.run(variables.variables_initializer([v]))
|
||||||
self.assertEqual(8.0, sess.run(out))
|
self.assertEqual(8.0, self.evaluate(out))
|
||||||
|
|
||||||
def test_placeholder_with_default_fed(self):
|
def test_placeholder_with_default_fed(self):
|
||||||
with self.cached_session() as sess, self.test_scope():
|
with self.cached_session() as sess, self.test_scope():
|
||||||
|
@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
# The random-number generator, if working correctly, should produce the
|
# The random-number generator, if working correctly, should produce the
|
||||||
# same output multiple times with low probability.
|
# same output multiple times with low probability.
|
||||||
y = sess.run(x)
|
y = self.evaluate(x)
|
||||||
z = sess.run(x)
|
z = self.evaluate(x)
|
||||||
w = sess.run(x)
|
w = self.evaluate(x)
|
||||||
|
|
||||||
# We use exact equality here. If the random-number generator is producing
|
# We use exact equality here. If the random-number generator is producing
|
||||||
# deterministic output, all three outputs will be bitwise identical.
|
# deterministic output, all three outputs will be bitwise identical.
|
||||||
@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
|||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
x = random_ops.random_uniform(
|
x = random_ops.random_uniform(
|
||||||
shape=[1000], dtype=dtype, minval=-2, maxval=33)
|
shape=[1000], dtype=dtype, minval=-2, maxval=33)
|
||||||
y = sess.run(x)
|
y = self.evaluate(x)
|
||||||
self.assertTrue((y >= -2).sum() == 1000)
|
self.assertTrue((y >= -2).sum() == 1000)
|
||||||
self.assertTrue((y < 33).sum() == 1000)
|
self.assertTrue((y < 33).sum() == 1000)
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
|
||||||
y = sess.run(x)
|
y = self.evaluate(x)
|
||||||
|
|
||||||
def normal_cdf(x):
|
def normal_cdf(x):
|
||||||
return .5 * math.erfc(-x / math.sqrt(2))
|
return .5 * math.erfc(-x / math.sqrt(2))
|
||||||
@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
|||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
x = math_ops.range(1 << 16)
|
x = math_ops.range(1 << 16)
|
||||||
shuffle = random_ops.random_shuffle(x)
|
shuffle = random_ops.random_shuffle(x)
|
||||||
result = sess.run(shuffle)
|
result = self.evaluate(shuffle)
|
||||||
expected = range(1 << 16)
|
expected = range(1 << 16)
|
||||||
# Compare sets to avoid randomness behavior changes but make sure still
|
# Compare sets to avoid randomness behavior changes but make sure still
|
||||||
# have all the values.
|
# have all the values.
|
||||||
@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase):
|
|||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
x = array_ops.diag(math_ops.range(20))
|
x = array_ops.diag(math_ops.range(20))
|
||||||
shuffle = random_ops.random_shuffle(x)
|
shuffle = random_ops.random_shuffle(x)
|
||||||
result = sess.run(shuffle)
|
result = self.evaluate(shuffle)
|
||||||
expected = np.diag(range(20)).flatten()
|
expected = np.diag(range(20)).flatten()
|
||||||
# Compare sets to avoid randomness behavior changes but make sure still
|
# Compare sets to avoid randomness behavior changes but make sure still
|
||||||
# have all the values.
|
# have all the values.
|
||||||
|
@ -505,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase):
|
|||||||
[-0.5, 1.5], # read(0) gradient
|
[-0.5, 1.5], # read(0) gradient
|
||||||
[20.0, 30.0, 40.0, 50.0], # concat gradient
|
[20.0, 30.0, 40.0, 50.0], # concat gradient
|
||||||
])
|
])
|
||||||
grad_vals = sess.run(grad_r) # 2 + 2 entries
|
grad_vals = self.evaluate(grad_r) # 2 + 2 entries
|
||||||
|
|
||||||
self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
|
self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
|
||||||
self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
|
self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
|
||||||
|
@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_add(
|
resource_variable_ops.resource_scatter_add(
|
||||||
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
|
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertAllEqual(sess.run(read), [[3], [7]])
|
self.assertAllEqual(self.evaluate(read), [[3], [7]])
|
||||||
|
|
||||||
def testScatterSub(self):
|
def testScatterSub(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_sub(
|
resource_variable_ops.resource_scatter_sub(
|
||||||
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
|
handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertAllEqual(sess.run(read), [[4], [-1]])
|
self.assertAllEqual(self.evaluate(read), [[4], [-1]])
|
||||||
|
|
||||||
def testScatterMul(self):
|
def testScatterMul(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_mul(
|
resource_variable_ops.resource_scatter_mul(
|
||||||
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
|
handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[5]])
|
self.assertEqual(self.evaluate(read), [[5]])
|
||||||
|
|
||||||
def testScatterDiv(self):
|
def testScatterDiv(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_div(
|
resource_variable_ops.resource_scatter_div(
|
||||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertAllEqual(sess.run(read), [[2]])
|
self.assertAllEqual(self.evaluate(read), [[2]])
|
||||||
|
|
||||||
def testScatterMin(self):
|
def testScatterMin(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_min(
|
resource_variable_ops.resource_scatter_min(
|
||||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[3]])
|
self.assertEqual(self.evaluate(read), [[3]])
|
||||||
|
|
||||||
def testScatterMax(self):
|
def testScatterMax(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_max(
|
resource_variable_ops.resource_scatter_max(
|
||||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[6]])
|
self.assertEqual(self.evaluate(read), [[6]])
|
||||||
|
|
||||||
def testScatterUpdate(self):
|
def testScatterUpdate(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_update(
|
resource_variable_ops.resource_scatter_update(
|
||||||
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[3]])
|
self.assertEqual(self.evaluate(read), [[3]])
|
||||||
|
|
||||||
def testScatterAddScalar(self):
|
def testScatterAddScalar(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_add(
|
resource_variable_ops.resource_scatter_add(
|
||||||
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[3]])
|
self.assertEqual(self.evaluate(read), [[3]])
|
||||||
|
|
||||||
def testScatterSubScalar(self):
|
def testScatterSubScalar(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_sub(
|
resource_variable_ops.resource_scatter_sub(
|
||||||
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[-1]])
|
self.assertEqual(self.evaluate(read), [[-1]])
|
||||||
|
|
||||||
def testScatterMulScalar(self):
|
def testScatterMulScalar(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_mul(
|
resource_variable_ops.resource_scatter_mul(
|
||||||
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
|
handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[5]])
|
self.assertEqual(self.evaluate(read), [[5]])
|
||||||
|
|
||||||
def testScatterDivScalar(self):
|
def testScatterDivScalar(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_div(
|
resource_variable_ops.resource_scatter_div(
|
||||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[2]])
|
self.assertEqual(self.evaluate(read), [[2]])
|
||||||
|
|
||||||
def testScatterMinScalar(self):
|
def testScatterMinScalar(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_min(
|
resource_variable_ops.resource_scatter_min(
|
||||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[3]])
|
self.assertEqual(self.evaluate(read), [[3]])
|
||||||
|
|
||||||
def testScatterMaxScalar(self):
|
def testScatterMaxScalar(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
resource_variable_ops.resource_scatter_max(
|
resource_variable_ops.resource_scatter_max(
|
||||||
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
|
||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(sess.run(read), [[6]])
|
self.assertEqual(self.evaluate(read), [[6]])
|
||||||
|
|
||||||
def testScatterNdAddOps(self):
|
def testScatterNdAddOps(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
|
sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
|
||||||
read = resource_variable_ops.read_variable_op(
|
read = resource_variable_ops.read_variable_op(
|
||||||
handle, dtype=dtypes.float32)
|
handle, dtype=dtypes.float32)
|
||||||
self.assertAllClose(expected, sess.run(read))
|
self.assertAllClose(expected, self.evaluate(read))
|
||||||
|
|
||||||
def testScatterNdUpdateAddOps(self):
|
def testScatterNdUpdateAddOps(self):
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase):
|
|||||||
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
|
gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
|
||||||
read = resource_variable_ops.read_variable_op(
|
read = resource_variable_ops.read_variable_op(
|
||||||
handle, dtype=dtypes.float32)
|
handle, dtype=dtypes.float32)
|
||||||
self.assertAllClose(expected, sess.run(read))
|
self.assertAllClose(expected, self.evaluate(read))
|
||||||
|
|
||||||
|
|
||||||
class StridedSliceAssignChecker(object):
|
class StridedSliceAssignChecker(object):
|
||||||
|
@ -96,7 +96,7 @@ class KerasTest(tf.test.TestCase):
|
|||||||
sess.run(init)
|
sess.run(init)
|
||||||
sample_input = tf.random_uniform((1, 10, 10, 1))
|
sample_input = tf.random_uniform((1, 10, 10, 1))
|
||||||
output = model(sample_input) # pylint: disable=not-callable
|
output = model(sample_input) # pylint: disable=not-callable
|
||||||
self.assertEqual(sess.run(output).shape, (1, 3))
|
self.assertEqual(self.evaluate(output).shape, (1, 3))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -34,7 +34,7 @@ class ListLiteralsTest(tf.test.TestCase):
|
|||||||
result = converted()
|
result = converted()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(result), [1, 2, 3])
|
self.assertAllEqual(self.evaluate(result), [1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -35,7 +35,7 @@ class InputDataTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sample_data = tf.zeros([32000, 2])
|
sample_data = tf.zeros([32000, 2])
|
||||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||||
wav_data = sess.run(wav_encoder)
|
wav_data = self.evaluate(wav_encoder)
|
||||||
return wav_data
|
return wav_data
|
||||||
|
|
||||||
def _saveTestWavFile(self, filename, wav_data):
|
def _saveTestWavFile(self, filename, wav_data):
|
||||||
|
@ -33,7 +33,7 @@ class LabelWavTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sample_data = tf.zeros([1000, 2])
|
sample_data = tf.zeros([1000, 2])
|
||||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||||
wav_data = sess.run(wav_encoder)
|
wav_data = self.evaluate(wav_encoder)
|
||||||
return wav_data
|
return wav_data
|
||||||
|
|
||||||
def _saveTestWavFile(self, filename, wav_data):
|
def _saveTestWavFile(self, filename, wav_data):
|
||||||
|
@ -33,7 +33,7 @@ class WavToFeaturesTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sample_data = tf.zeros([32000, 2])
|
sample_data = tf.zeros([32000, 2])
|
||||||
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
|
||||||
wav_data = sess.run(wav_encoder)
|
wav_data = self.evaluate(wav_encoder)
|
||||||
return wav_data
|
return wav_data
|
||||||
|
|
||||||
def _saveTestWavFile(self, filename, wav_data):
|
def _saveTestWavFile(self, filename, wav_data):
|
||||||
|
@ -113,7 +113,7 @@ class CallTreesTest(converter_testing.TestCase):
|
|||||||
with self.compiled(node, ns) as result:
|
with self.compiled(node, ns) as result:
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
result_tensor = result.test_fn(constant_op.constant(1))
|
result_tensor = result.test_fn(constant_op.constant(1))
|
||||||
self.assertEquals(sess.run(result_tensor), 3)
|
self.assertEquals(self.evaluate(result_tensor), 3)
|
||||||
|
|
||||||
def test_call_to_decorated_function(self):
|
def test_call_to_decorated_function(self):
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ class ListTest(converter_testing.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
tl = result.test_fn()
|
tl = result.test_fn()
|
||||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||||
self.assertAllEqual(sess.run(r), [1, 2, 3])
|
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
|
||||||
|
|
||||||
def test_list_pop(self):
|
def test_list_pop(self):
|
||||||
|
|
||||||
@ -91,8 +91,8 @@ class ListTest(converter_testing.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
ts, tl = result.test_fn()
|
ts, tl = result.test_fn()
|
||||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||||
self.assertAllEqual(sess.run(r), [1, 2])
|
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||||
self.assertAllEqual(sess.run(ts), 3)
|
self.assertAllEqual(self.evaluate(ts), 3)
|
||||||
|
|
||||||
def test_double_list_pop(self):
|
def test_double_list_pop(self):
|
||||||
|
|
||||||
|
@ -48,12 +48,12 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
|||||||
with self.compiled(node, {}, state_ops.assign) as result:
|
with self.compiled(node, {}, state_ops.assign) as result:
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
v = variable_scope.get_variable('test', initializer=2)
|
v = variable_scope.get_variable('test', initializer=2)
|
||||||
sess.run(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
sess.run(result.test_fn(v))
|
sess.run(result.test_fn(v))
|
||||||
# TODO(mdan): Add support for this use case.
|
# TODO(mdan): Add support for this use case.
|
||||||
# Right now the variable `a` is not conditioned on the `assign` because
|
# Right now the variable `a` is not conditioned on the `assign` because
|
||||||
# there's no way to add control dependencies to a variable object.
|
# there's no way to add control dependencies to a variable object.
|
||||||
self.assertEqual(2, sess.run(v))
|
self.assertEqual(2, self.evaluate(v))
|
||||||
|
|
||||||
def test_side_effect_on_used_variable(self):
|
def test_side_effect_on_used_variable(self):
|
||||||
|
|
||||||
@ -69,11 +69,11 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
|||||||
with self.compiled(node, {}, state_ops.assign) as result:
|
with self.compiled(node, {}, state_ops.assign) as result:
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
v = variable_scope.get_variable('test', initializer=2)
|
v = variable_scope.get_variable('test', initializer=2)
|
||||||
sess.run(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
sess.run(result.test_fn(v))
|
sess.run(result.test_fn(v))
|
||||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||||
# Right now it's 3 or 4 based on whether the read is synchronized.
|
# Right now it's 3 or 4 based on whether the read is synchronized.
|
||||||
self.assertEqual(3, sess.run(v))
|
self.assertEqual(3, self.evaluate(v))
|
||||||
|
|
||||||
def test_side_effect_on_tensor(self):
|
def test_side_effect_on_tensor(self):
|
||||||
|
|
||||||
@ -109,10 +109,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
|||||||
with self.compiled(node, {}, state_ops.assign_add) as result:
|
with self.compiled(node, {}, state_ops.assign_add) as result:
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
v = variable_scope.get_variable('test', initializer=2)
|
v = variable_scope.get_variable('test', initializer=2)
|
||||||
sess.run(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
sess.run(result.test_fn(v))
|
sess.run(result.test_fn(v))
|
||||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||||
self.assertEqual(4, sess.run(v))
|
self.assertEqual(4, self.evaluate(v))
|
||||||
|
|
||||||
def test_multiline_nested_block(self):
|
def test_multiline_nested_block(self):
|
||||||
|
|
||||||
@ -130,10 +130,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
|||||||
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
|
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
v = variable_scope.get_variable('test', initializer=2)
|
v = variable_scope.get_variable('test', initializer=2)
|
||||||
sess.run(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
sess.run(result.test_fn(v))
|
sess.run(result.test_fn(v))
|
||||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||||
self.assertEqual(3, sess.run(v))
|
self.assertEqual(3, self.evaluate(v))
|
||||||
|
|
||||||
def test_multiline_block_unsafe(self):
|
def test_multiline_block_unsafe(self):
|
||||||
|
|
||||||
@ -153,10 +153,10 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
|||||||
state_ops.assign_add) as result:
|
state_ops.assign_add) as result:
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
v = variable_scope.get_variable('test', initializer=2)
|
v = variable_scope.get_variable('test', initializer=2)
|
||||||
sess.run(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
sess.run(result.test_fn(v))
|
sess.run(result.test_fn(v))
|
||||||
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
|
||||||
self.assertEqual(4, sess.run(v))
|
self.assertEqual(4, self.evaluate(v))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -49,7 +49,7 @@ class SliceTest(converter_testing.TestCase):
|
|||||||
tl = list_ops.tensor_list_from_tensor(
|
tl = list_ops.tensor_list_from_tensor(
|
||||||
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
|
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
|
||||||
y = result.test_fn(tl)
|
y = result.test_fn(tl)
|
||||||
self.assertEqual(2, sess.run(y))
|
self.assertEqual(2, self.evaluate(y))
|
||||||
|
|
||||||
def test_index_access_multiple_definitions(self):
|
def test_index_access_multiple_definitions(self):
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ class ApiTest(test.TestCase):
|
|||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
def test_decorator_does_not_recurse(self):
|
def test_decorator_does_not_recurse(self):
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ class ApiTest(test.TestCase):
|
|||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
def test_decorator_calls_unconverted_graph(self):
|
def test_decorator_calls_unconverted_graph(self):
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ class ApiTest(test.TestCase):
|
|||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
def test_decorator_calls_unconverted_py_func(self):
|
def test_decorator_calls_unconverted_py_func(self):
|
||||||
|
|
||||||
@ -130,7 +130,7 @@ class ApiTest(test.TestCase):
|
|||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
def test_decorator_calls_decorated(self):
|
def test_decorator_calls_decorated(self):
|
||||||
|
|
||||||
@ -153,7 +153,7 @@ class ApiTest(test.TestCase):
|
|||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
def test_decorator_preserves_argspec(self):
|
def test_decorator_preserves_argspec(self):
|
||||||
|
|
||||||
@ -192,7 +192,7 @@ class ApiTest(test.TestCase):
|
|||||||
x = tc.test_method(
|
x = tc.test_method(
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||||
constant_op.constant(-2))
|
constant_op.constant(-2))
|
||||||
self.assertListEqual([0, 1], sess.run(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
def test_converted_call_builtin(self):
|
def test_converted_call_builtin(self):
|
||||||
x = api.converted_call(range, None, converter.ConversionOptions(), 3)
|
x = api.converted_call(range, None, converter.ConversionOptions(), 3)
|
||||||
@ -208,7 +208,7 @@ class ApiTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = api.converted_call(test_fn, None, converter.ConversionOptions(),
|
x = api.converted_call(test_fn, None, converter.ConversionOptions(),
|
||||||
constant_op.constant(-1))
|
constant_op.constant(-1))
|
||||||
self.assertEqual(1, sess.run(x))
|
self.assertEqual(1, self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_method_explicit_owner(self):
|
def test_converted_call_method_explicit_owner(self):
|
||||||
# TODO(mdan): Implement.
|
# TODO(mdan): Implement.
|
||||||
@ -234,7 +234,7 @@ class ApiTest(test.TestCase):
|
|||||||
tc = TestClass(constant_op.constant(-1))
|
tc = TestClass(constant_op.constant(-1))
|
||||||
x = api.converted_call(tc.test_method, None,
|
x = api.converted_call(tc.test_method, None,
|
||||||
converter.ConversionOptions(), tc)
|
converter.ConversionOptions(), tc)
|
||||||
self.assertEqual(1, sess.run(x))
|
self.assertEqual(1, self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_method_by_class(self):
|
def test_converted_call_method_by_class(self):
|
||||||
|
|
||||||
@ -252,7 +252,7 @@ class ApiTest(test.TestCase):
|
|||||||
tc = TestClass(constant_op.constant(-1))
|
tc = TestClass(constant_op.constant(-1))
|
||||||
x = api.converted_call(TestClass.test_method, None,
|
x = api.converted_call(TestClass.test_method, None,
|
||||||
converter.ConversionOptions(), tc)
|
converter.ConversionOptions(), tc)
|
||||||
self.assertEqual(1, sess.run(x))
|
self.assertEqual(1, self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_callable_object(self):
|
def test_converted_call_callable_object(self):
|
||||||
|
|
||||||
@ -269,7 +269,7 @@ class ApiTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
tc = TestClass(constant_op.constant(-1))
|
tc = TestClass(constant_op.constant(-1))
|
||||||
x = api.converted_call(tc, None, converter.ConversionOptions())
|
x = api.converted_call(tc, None, converter.ConversionOptions())
|
||||||
self.assertEqual(1, sess.run(x))
|
self.assertEqual(1, self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_constructor(self):
|
def test_converted_call_constructor(self):
|
||||||
|
|
||||||
@ -288,7 +288,7 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant(-1))
|
constant_op.constant(-1))
|
||||||
# tc is now a converted object.
|
# tc is now a converted object.
|
||||||
x = tc.test_method()
|
x = tc.test_method()
|
||||||
self.assertEqual(1, sess.run(x))
|
self.assertEqual(1, self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_already_converted(self):
|
def test_converted_call_already_converted(self):
|
||||||
|
|
||||||
@ -298,12 +298,12 @@ class ApiTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = api.converted_call(f, None, converter.ConversionOptions(),
|
x = api.converted_call(f, None, converter.ConversionOptions(),
|
||||||
constant_op.constant(0))
|
constant_op.constant(0))
|
||||||
self.assertTrue(sess.run(x))
|
self.assertTrue(self.evaluate(x))
|
||||||
|
|
||||||
converted_f = api.to_graph(f)
|
converted_f = api.to_graph(f)
|
||||||
x = api.converted_call(converted_f, None, converter.ConversionOptions(),
|
x = api.converted_call(converted_f, None, converter.ConversionOptions(),
|
||||||
constant_op.constant(0))
|
constant_op.constant(0))
|
||||||
self.assertTrue(sess.run(x))
|
self.assertTrue(self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_no_user_code(self):
|
def test_converted_call_no_user_code(self):
|
||||||
|
|
||||||
@ -334,8 +334,8 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant([[0.0]]), training=True)
|
constant_op.constant([[0.0]]), training=True)
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_whitelisted_method_extra_self(self):
|
def test_converted_call_whitelisted_method_extra_self(self):
|
||||||
|
|
||||||
@ -349,8 +349,8 @@ class ApiTest(test.TestCase):
|
|||||||
model, constant_op.constant([[0.0]]), training=True)
|
model, constant_op.constant([[0.0]]), training=True)
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_whitelisted_method_via_owner(self):
|
def test_converted_call_whitelisted_method_via_owner(self):
|
||||||
|
|
||||||
@ -364,8 +364,8 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant([[0.0]]), training=True)
|
constant_op.constant([[0.0]]), training=True)
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertAllEqual([[0.0, 0.0]], sess.run(x))
|
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_lambda(self):
|
def test_converted_call_lambda(self):
|
||||||
|
|
||||||
@ -376,8 +376,8 @@ class ApiTest(test.TestCase):
|
|||||||
x = api.converted_call(l, None, opts, constant_op.constant(0))
|
x = api.converted_call(l, None, opts, constant_op.constant(0))
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertAllEqual(True, sess.run(x))
|
self.assertAllEqual(True, self.evaluate(x))
|
||||||
|
|
||||||
def test_to_graph_basic(self):
|
def test_to_graph_basic(self):
|
||||||
|
|
||||||
@ -390,7 +390,7 @@ class ApiTest(test.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
x = compiled_fn(constant_op.constant([4, 8]), 4)
|
||||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||||
|
|
||||||
def test_to_graph_with_defaults(self):
|
def test_to_graph_with_defaults(self):
|
||||||
|
|
||||||
@ -405,7 +405,7 @@ class ApiTest(test.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = compiled_fn(constant_op.constant([4, 8]))
|
x = compiled_fn(constant_op.constant([4, 8]))
|
||||||
self.assertListEqual([1, 2], sess.run(x).tolist())
|
self.assertListEqual([1, 2], self.evaluate(x).tolist())
|
||||||
|
|
||||||
def test_to_code_basic(self):
|
def test_to_code_basic(self):
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ class SpecialFunctionsTest(test.TestCase):
|
|||||||
python_one = special_functions.match_staging_level(1, 1)
|
python_one = special_functions.match_staging_level(1, 1)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertTrue(tensor_util.is_tensor(tensor_one))
|
self.assertTrue(tensor_util.is_tensor(tensor_one))
|
||||||
self.assertAllEqual(sess.run(tensor_one), 1)
|
self.assertAllEqual(self.evaluate(tensor_one), 1)
|
||||||
self.assertEqual(python_one, 1)
|
self.assertEqual(python_one, 1)
|
||||||
|
|
||||||
def test_tensor_list_empty_list(self):
|
def test_tensor_list_empty_list(self):
|
||||||
@ -45,21 +45,21 @@ class SpecialFunctionsTest(test.TestCase):
|
|||||||
element_shape=())
|
element_shape=())
|
||||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(sl), [])
|
self.assertAllEqual(self.evaluate(sl), [])
|
||||||
|
|
||||||
l = special_functions.tensor_list((),
|
l = special_functions.tensor_list((),
|
||||||
element_dtype=dtypes.int32,
|
element_dtype=dtypes.int32,
|
||||||
element_shape=())
|
element_shape=())
|
||||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(sl), [])
|
self.assertAllEqual(self.evaluate(sl), [])
|
||||||
|
|
||||||
def test_tensor_list_tensor(self):
|
def test_tensor_list_tensor(self):
|
||||||
l = special_functions.tensor_list(
|
l = special_functions.tensor_list(
|
||||||
constant_op.constant([], dtype=dtypes.int32))
|
constant_op.constant([], dtype=dtypes.int32))
|
||||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(sl), [])
|
self.assertAllEqual(self.evaluate(sl), [])
|
||||||
|
|
||||||
def test_tensor_list_unsupported_initializer(self):
|
def test_tensor_list_unsupported_initializer(self):
|
||||||
with self.assertRaisesRegexp(ValueError, 'unknown type'):
|
with self.assertRaisesRegexp(ValueError, 'unknown type'):
|
||||||
@ -76,7 +76,7 @@ class SpecialFunctionsTest(test.TestCase):
|
|||||||
l = special_functions.tensor_list(elements)
|
l = special_functions.tensor_list(elements)
|
||||||
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
|
||||||
|
|
||||||
def test_tensor_list_array_from_elements(self):
|
def test_tensor_list_array_from_elements(self):
|
||||||
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
|
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
|
||||||
@ -84,7 +84,7 @@ class SpecialFunctionsTest(test.TestCase):
|
|||||||
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
l = special_functions.tensor_list(elements, use_tensor_array=True)
|
||||||
sl = l.stack()
|
sl = l.stack()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
|
self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
|
||||||
|
|
||||||
def test_stack(self):
|
def test_stack(self):
|
||||||
self.assertEqual(special_functions.stack(1, strict=False), 1)
|
self.assertEqual(special_functions.stack(1, strict=False), 1)
|
||||||
|
@ -35,7 +35,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
body=lambda i, s: (s + i,),
|
body=lambda i, s: (s + i,),
|
||||||
init_state=(0,))
|
init_state=(0,))
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual((10,), sess.run(s))
|
self.assertEqual((10,), self.evaluate(s))
|
||||||
|
|
||||||
def test_python(self):
|
def test_python(self):
|
||||||
s = control_flow.for_stmt(
|
s = control_flow.for_stmt(
|
||||||
@ -53,7 +53,7 @@ class ForLoopTest(test.TestCase):
|
|||||||
body=lambda i, s: (s + i,),
|
body=lambda i, s: (s + i,),
|
||||||
init_state=(0,))
|
init_state=(0,))
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual((10,), sess.run(s))
|
self.assertEqual((10,), self.evaluate(s))
|
||||||
|
|
||||||
|
|
||||||
class WhileLoopTest(test.TestCase):
|
class WhileLoopTest(test.TestCase):
|
||||||
@ -66,7 +66,7 @@ class WhileLoopTest(test.TestCase):
|
|||||||
init_state=(0, 0),
|
init_state=(0, 0),
|
||||||
extra_deps=(n,))
|
extra_deps=(n,))
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual((5, 10), sess.run(results))
|
self.assertEqual((5, 10), self.evaluate(results))
|
||||||
|
|
||||||
def test_python(self):
|
def test_python(self):
|
||||||
n = 5
|
n = 5
|
||||||
@ -90,9 +90,9 @@ class IfStmtTest(test.TestCase):
|
|||||||
def test_tensor(self):
|
def test_tensor(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = self.single_return_if_stmt(constant_op.constant(True))
|
t = self.single_return_if_stmt(constant_op.constant(True))
|
||||||
self.assertEqual(1, sess.run(t))
|
self.assertEqual(1, self.evaluate(t))
|
||||||
t = self.single_return_if_stmt(constant_op.constant(False))
|
t = self.single_return_if_stmt(constant_op.constant(False))
|
||||||
self.assertEqual(-1, sess.run(t))
|
self.assertEqual(-1, self.evaluate(t))
|
||||||
|
|
||||||
def test_python(self):
|
def test_python(self):
|
||||||
self.assertEqual(1, self.single_return_if_stmt(True))
|
self.assertEqual(1, self.single_return_if_stmt(True))
|
||||||
@ -101,9 +101,9 @@ class IfStmtTest(test.TestCase):
|
|||||||
def test_tensor_multiple_returns(self):
|
def test_tensor_multiple_returns(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = self.multi_return_if_stmt(constant_op.constant(True))
|
t = self.multi_return_if_stmt(constant_op.constant(True))
|
||||||
self.assertAllEqual([1, 2], sess.run(t))
|
self.assertAllEqual([1, 2], self.evaluate(t))
|
||||||
t = self.multi_return_if_stmt(constant_op.constant(False))
|
t = self.multi_return_if_stmt(constant_op.constant(False))
|
||||||
self.assertAllEqual([-1, -2], sess.run(t))
|
self.assertAllEqual([-1, -2], self.evaluate(t))
|
||||||
|
|
||||||
def test_python_multiple_returns(self):
|
def test_python_multiple_returns(self):
|
||||||
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
|
self.assertEqual((1, 2), self.multi_return_if_stmt(True))
|
||||||
|
@ -43,7 +43,7 @@ class ListTest(test.TestCase):
|
|||||||
l = data_structures.tf_tensor_list_new([3, 4, 5])
|
l = data_structures.tf_tensor_list_new([3, 4, 5])
|
||||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||||
|
|
||||||
def test_tf_tensor_list_new_empty(self):
|
def test_tf_tensor_list_new_empty(self):
|
||||||
l = data_structures.tf_tensor_list_new([],
|
l = data_structures.tf_tensor_list_new([],
|
||||||
@ -51,13 +51,13 @@ class ListTest(test.TestCase):
|
|||||||
element_shape=())
|
element_shape=())
|
||||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(t), [])
|
self.assertAllEqual(self.evaluate(t), [])
|
||||||
|
|
||||||
def test_tf_tensor_list_new_from_tensor(self):
|
def test_tf_tensor_list_new_from_tensor(self):
|
||||||
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
|
l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
|
||||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||||
|
|
||||||
def test_tf_tensor_list_new_illegal_input(self):
|
def test_tf_tensor_list_new_illegal_input(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@ -77,7 +77,7 @@ class ListTest(test.TestCase):
|
|||||||
l = data_structures.tf_tensor_array_new([3, 4, 5])
|
l = data_structures.tf_tensor_array_new([3, 4, 5])
|
||||||
t = l.stack()
|
t = l.stack()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(t), [3, 4, 5])
|
self.assertAllEqual(self.evaluate(t), [3, 4, 5])
|
||||||
|
|
||||||
def test_tf_tensor_array_new_illegal_input(self):
|
def test_tf_tensor_array_new_illegal_input(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@ -102,7 +102,7 @@ class ListTest(test.TestCase):
|
|||||||
|
|
||||||
t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
|
t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(t), [[1, 2, 3]])
|
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
|
||||||
|
|
||||||
def test_append_tensorarray(self):
|
def test_append_tensorarray(self):
|
||||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||||
@ -131,10 +131,10 @@ class ListTest(test.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
l, x = data_structures.list_pop(l, None, opts)
|
l, x = data_structures.list_pop(l, None, opts)
|
||||||
self.assertAllEqual(sess.run(x), [3, 4])
|
self.assertAllEqual(self.evaluate(x), [3, 4])
|
||||||
|
|
||||||
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
||||||
self.assertAllEqual(sess.run(t), [[1, 2]])
|
self.assertAllEqual(self.evaluate(t), [[1, 2]])
|
||||||
|
|
||||||
def test_pop_python(self):
|
def test_pop_python(self):
|
||||||
l = [1, 2, 3]
|
l = [1, 2, 3]
|
||||||
@ -152,7 +152,7 @@ class ListTest(test.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = data_structures.list_stack(l, opts)
|
t = data_structures.list_stack(l, opts)
|
||||||
self.assertAllEqual(sess.run(t), sess.run(initial_list))
|
self.assertAllEqual(sess.run(t), self.evaluate(initial_list))
|
||||||
|
|
||||||
def test_stack_tensor_list_empty(self):
|
def test_stack_tensor_list_empty(self):
|
||||||
l = list_ops.empty_tensor_list(
|
l = list_ops.empty_tensor_list(
|
||||||
|
@ -45,11 +45,11 @@ class LogicalOperatorsTest(test.TestCase):
|
|||||||
def test_and_tf(self):
|
def test_and_tf(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = logical.and_(self._tf_true, self._tf_true)
|
t = logical.and_(self._tf_true, self._tf_true)
|
||||||
self.assertEqual(sess.run(t), True)
|
self.assertEqual(self.evaluate(t), True)
|
||||||
t = logical.and_(self._tf_true, lambda: True)
|
t = logical.and_(self._tf_true, lambda: True)
|
||||||
self.assertEqual(sess.run(t), True)
|
self.assertEqual(self.evaluate(t), True)
|
||||||
t = logical.and_(self._tf_false, lambda: True)
|
t = logical.and_(self._tf_false, lambda: True)
|
||||||
self.assertEqual(sess.run(t), False)
|
self.assertEqual(self.evaluate(t), False)
|
||||||
# TODO(mdan): Add a test for ops with side effects.
|
# TODO(mdan): Add a test for ops with side effects.
|
||||||
|
|
||||||
def test_or_python(self):
|
def test_or_python(self):
|
||||||
@ -63,11 +63,11 @@ class LogicalOperatorsTest(test.TestCase):
|
|||||||
def test_or_tf(self):
|
def test_or_tf(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = logical.or_(self._tf_false, self._tf_true)
|
t = logical.or_(self._tf_false, self._tf_true)
|
||||||
self.assertEqual(sess.run(t), True)
|
self.assertEqual(self.evaluate(t), True)
|
||||||
t = logical.or_(self._tf_false, lambda: True)
|
t = logical.or_(self._tf_false, lambda: True)
|
||||||
self.assertEqual(sess.run(t), True)
|
self.assertEqual(self.evaluate(t), True)
|
||||||
t = logical.or_(self._tf_true, lambda: True)
|
t = logical.or_(self._tf_true, lambda: True)
|
||||||
self.assertEqual(sess.run(t), True)
|
self.assertEqual(self.evaluate(t), True)
|
||||||
# TODO(mdan): Add a test for ops with side effects.
|
# TODO(mdan): Add a test for ops with side effects.
|
||||||
|
|
||||||
def test_not_python(self):
|
def test_not_python(self):
|
||||||
@ -78,7 +78,7 @@ class LogicalOperatorsTest(test.TestCase):
|
|||||||
def test_not_tf(self):
|
def test_not_tf(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = logical.not_(self._tf_false())
|
t = logical.not_(self._tf_false())
|
||||||
self.assertEqual(sess.run(t), True)
|
self.assertEqual(self.evaluate(t), True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -38,29 +38,29 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
self.assertEqual(py_builtins.abs_(-1), 1)
|
self.assertEqual(py_builtins.abs_(-1), 1)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = py_builtins.abs_(constant_op.constant(-1))
|
t = py_builtins.abs_(constant_op.constant(-1))
|
||||||
self.assertEqual(sess.run(t), 1)
|
self.assertEqual(self.evaluate(t), 1)
|
||||||
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
|
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
|
||||||
self.assertAllEqual(sess.run(t), [1, 2, 3])
|
self.assertAllEqual(self.evaluate(t), [1, 2, 3])
|
||||||
|
|
||||||
def test_float(self):
|
def test_float(self):
|
||||||
self.assertEqual(py_builtins.float_(10), 10.0)
|
self.assertEqual(py_builtins.float_(10), 10.0)
|
||||||
self.assertEqual(py_builtins.float_('10.0'), 10.0)
|
self.assertEqual(py_builtins.float_('10.0'), 10.0)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
|
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
|
||||||
self.assertEqual(sess.run(t), 1.0)
|
self.assertEqual(self.evaluate(t), 1.0)
|
||||||
st = py_builtins.float_(constant_op.constant('1.0'))
|
st = py_builtins.float_(constant_op.constant('1.0'))
|
||||||
self.assertEqual(sess.run(st), 1.0)
|
self.assertEqual(self.evaluate(st), 1.0)
|
||||||
|
|
||||||
def test_int(self):
|
def test_int(self):
|
||||||
self.assertEqual(py_builtins.int_(10.0), 10)
|
self.assertEqual(py_builtins.int_(10.0), 10)
|
||||||
self.assertEqual(py_builtins.int_('11', 2), 3)
|
self.assertEqual(py_builtins.int_('11', 2), 3)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
|
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
|
||||||
self.assertEqual(sess.run(t), 1)
|
self.assertEqual(self.evaluate(t), 1)
|
||||||
st = py_builtins.int_(constant_op.constant('1'))
|
st = py_builtins.int_(constant_op.constant('1'))
|
||||||
self.assertEqual(sess.run(st), 1)
|
self.assertEqual(self.evaluate(st), 1)
|
||||||
st = py_builtins.int_(constant_op.constant('1'), 10)
|
st = py_builtins.int_(constant_op.constant('1'), 10)
|
||||||
self.assertEqual(sess.run(st), 1)
|
self.assertEqual(self.evaluate(st), 1)
|
||||||
|
|
||||||
def test_int_unsupported_base(self):
|
def test_int_unsupported_base(self):
|
||||||
t = constant_op.constant(1, dtype=dtypes.float64)
|
t = constant_op.constant(1, dtype=dtypes.float64)
|
||||||
@ -73,9 +73,9 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
|
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
|
||||||
self.assertEqual(t, 3)
|
self.assertEqual(t, 3)
|
||||||
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
|
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
|
||||||
self.assertEqual(sess.run(ta), 5)
|
self.assertEqual(self.evaluate(ta), 5)
|
||||||
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
|
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
|
||||||
self.assertEqual(sess.run(tl), 3)
|
self.assertEqual(self.evaluate(tl), 3)
|
||||||
|
|
||||||
def test_len_scalar(self):
|
def test_len_scalar(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@ -120,18 +120,18 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
def test_range_tensor(self):
|
def test_range_tensor(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
r = py_builtins.range_(constant_op.constant(3))
|
r = py_builtins.range_(constant_op.constant(3))
|
||||||
self.assertAllEqual(sess.run(r), [0, 1, 2])
|
self.assertAllEqual(self.evaluate(r), [0, 1, 2])
|
||||||
r = py_builtins.range_(1, constant_op.constant(3))
|
r = py_builtins.range_(1, constant_op.constant(3))
|
||||||
self.assertAllEqual(sess.run(r), [1, 2])
|
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||||
r = py_builtins.range_(2, 0, constant_op.constant(-1))
|
r = py_builtins.range_(2, 0, constant_op.constant(-1))
|
||||||
self.assertAllEqual(sess.run(r), [2, 1])
|
self.assertAllEqual(self.evaluate(r), [2, 1])
|
||||||
|
|
||||||
def test_range_tensor_empty_range(self):
|
def test_range_tensor_empty_range(self):
|
||||||
with self.session() as sess:
|
with self.session() as sess:
|
||||||
r = py_builtins.range_(constant_op.constant(-3))
|
r = py_builtins.range_(constant_op.constant(-3))
|
||||||
self.assertAllEqual(sess.run(r), [])
|
self.assertAllEqual(self.evaluate(r), [])
|
||||||
r = py_builtins.range_(5, constant_op.constant(2))
|
r = py_builtins.range_(5, constant_op.constant(2))
|
||||||
self.assertAllEqual(sess.run(r), [])
|
self.assertAllEqual(self.evaluate(r), [])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -34,7 +34,7 @@ class SlicesTest(test.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
|
||||||
self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
|
self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
|
||||||
|
|
||||||
def test_get_item_tensor_list(self):
|
def test_get_item_tensor_list(self):
|
||||||
initial_list = constant_op.constant([[1, 2], [3, 4]])
|
initial_list = constant_op.constant([[1, 2], [3, 4]])
|
||||||
@ -44,7 +44,7 @@ class SlicesTest(test.TestCase):
|
|||||||
l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
|
l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(t), [3, 4])
|
self.assertAllEqual(self.evaluate(t), [3, 4])
|
||||||
|
|
||||||
def test_get_item_tensor_string(self):
|
def test_get_item_tensor_string(self):
|
||||||
initial_str = constant_op.constant('abcd')
|
initial_str = constant_op.constant('abcd')
|
||||||
@ -52,14 +52,14 @@ class SlicesTest(test.TestCase):
|
|||||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(t), b'b')
|
self.assertEqual(self.evaluate(t), b'b')
|
||||||
|
|
||||||
initial_list_str = constant_op.constant(['abcd', 'bcde'])
|
initial_list_str = constant_op.constant(['abcd', 'bcde'])
|
||||||
t = slices.get_item(initial_list_str, 1,
|
t = slices.get_item(initial_list_str, 1,
|
||||||
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
slices.GetItemOpts(element_dtype=initial_str.dtype))
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(t), b'bcde')
|
self.assertEqual(self.evaluate(t), b'bcde')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -32,7 +32,7 @@ class MiscTest(test.TestCase):
|
|||||||
new_a = alias_tensors(a)
|
new_a = alias_tensors(a)
|
||||||
self.assertFalse(new_a is a)
|
self.assertFalse(new_a is a)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(1, sess.run(new_a))
|
self.assertEqual(1, self.evaluate(new_a))
|
||||||
|
|
||||||
def test_alias_tensors(self):
|
def test_alias_tensors(self):
|
||||||
a = constant(1)
|
a = constant(1)
|
||||||
@ -47,7 +47,7 @@ class MiscTest(test.TestCase):
|
|||||||
self.assertTrue(new_s is s)
|
self.assertTrue(new_s is s)
|
||||||
self.assertTrue(new_l is l)
|
self.assertTrue(new_l is l)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(1, sess.run(new_a))
|
self.assertEqual(1, self.evaluate(new_a))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -34,13 +34,13 @@ class PyFuncTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||||
(1, constant_op.constant(1), 1))
|
(1, constant_op.constant(1), 1))
|
||||||
self.assertEqual(3, sess.run(result))
|
self.assertEqual(3, self.evaluate(result))
|
||||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
|
result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
|
||||||
self.assertEqual(3, sess.run(result))
|
self.assertEqual(3, self.evaluate(result))
|
||||||
result = py_func.wrap_py_func(
|
result = py_func.wrap_py_func(
|
||||||
test_fn, dtypes.int64,
|
test_fn, dtypes.int64,
|
||||||
(constant_op.constant(1), 1, constant_op.constant(1)))
|
(constant_op.constant(1), 1, constant_op.constant(1)))
|
||||||
self.assertEqual(3, sess.run(result))
|
self.assertEqual(3, self.evaluate(result))
|
||||||
|
|
||||||
def test_wrap_py_func_complex_args(self):
|
def test_wrap_py_func_complex_args(self):
|
||||||
|
|
||||||
@ -54,10 +54,10 @@ class PyFuncTest(test.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
|
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
|
||||||
self.assertEqual(35, sess.run(result))
|
self.assertEqual(35, self.evaluate(result))
|
||||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||||
(constant_op.constant(7), TestClass()))
|
(constant_op.constant(7), TestClass()))
|
||||||
self.assertEqual(35, sess.run(result))
|
self.assertEqual(35, self.evaluate(result))
|
||||||
|
|
||||||
def test_wrap_py_func_kwargs(self):
|
def test_wrap_py_func_kwargs(self):
|
||||||
|
|
||||||
@ -74,13 +74,13 @@ class PyFuncTest(test.TestCase):
|
|||||||
'c': 11,
|
'c': 11,
|
||||||
'd': TestClass(13)
|
'd': TestClass(13)
|
||||||
})
|
})
|
||||||
self.assertEqual(178, sess.run(result))
|
self.assertEqual(178, self.evaluate(result))
|
||||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||||
(constant_op.constant(7), TestClass(5)), {
|
(constant_op.constant(7), TestClass(5)), {
|
||||||
'c': constant_op.constant(11),
|
'c': constant_op.constant(11),
|
||||||
'd': TestClass(13)
|
'd': TestClass(13)
|
||||||
})
|
})
|
||||||
self.assertEqual(178, sess.run(result))
|
self.assertEqual(178, self.evaluate(result))
|
||||||
|
|
||||||
def test_wrap_py_func_dummy_return(self):
|
def test_wrap_py_func_dummy_return(self):
|
||||||
|
|
||||||
@ -91,11 +91,11 @@ class PyFuncTest(test.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
|
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
|
||||||
self.assertEqual(1, sess.run(result))
|
self.assertEqual(1, self.evaluate(result))
|
||||||
self.assertEqual([1], side_counter)
|
self.assertEqual([1], side_counter)
|
||||||
result = py_func.wrap_py_func(
|
result = py_func.wrap_py_func(
|
||||||
test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
|
test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
|
||||||
self.assertEqual(1, sess.run(result))
|
self.assertEqual(1, self.evaluate(result))
|
||||||
self.assertEqual([2], side_counter)
|
self.assertEqual([2], side_counter)
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,13 +43,13 @@ class TensorListTest(test.TestCase):
|
|||||||
l = tl.dynamic_list_append(l, 1)
|
l = tl.dynamic_list_append(l, 1)
|
||||||
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(s), [1])
|
self.assertAllEqual(self.evaluate(s), [1])
|
||||||
|
|
||||||
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
|
||||||
l = tl.dynamic_list_append(l, 1)
|
l = tl.dynamic_list_append(l, 1)
|
||||||
s = l.stack()
|
s = l.stack()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(sess.run(s), [1])
|
self.assertAllEqual(self.evaluate(s), [1])
|
||||||
|
|
||||||
l = tl.TensorList(self._shape(()), dtypes.int32)
|
l = tl.TensorList(self._shape(()), dtypes.int32)
|
||||||
l = tl.dynamic_list_append(l, 1)
|
l = tl.dynamic_list_append(l, 1)
|
||||||
|
@ -62,7 +62,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
const = constant_op.constant(17)
|
const = constant_op.constant(17)
|
||||||
sess = session.Session(server1.target, config=config)
|
sess = session.Session(server1.target, config=config)
|
||||||
output = sess.run(const)
|
output = self.evaluate(const)
|
||||||
self.assertEqual(17, output)
|
self.assertEqual(17, output)
|
||||||
|
|
||||||
def testClusterSpecPropagationWorker2Placement(self):
|
def testClusterSpecPropagationWorker2Placement(self):
|
||||||
@ -106,7 +106,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
|
with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
|
||||||
const = constant_op.constant(17)
|
const = constant_op.constant(17)
|
||||||
sess = session.Session(server1.target, config=config, graph=g)
|
sess = session.Session(server1.target, config=config, graph=g)
|
||||||
output = sess.run(const)
|
output = self.evaluate(const)
|
||||||
self.assertEqual(17, output)
|
self.assertEqual(17, output)
|
||||||
|
|
||||||
def testCanonicalDeviceNames(self):
|
def testCanonicalDeviceNames(self):
|
||||||
@ -208,7 +208,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.device('/job:worker/task:0/cpu:0'):
|
with ops.device('/job:worker/task:0/cpu:0'):
|
||||||
sum3 = sum1 + sum2
|
sum3 = sum1 + sum2
|
||||||
sess = session.Session(server1.target, config=config, graph=g)
|
sess = session.Session(server1.target, config=config, graph=g)
|
||||||
output = sess.run(sum3)
|
output = self.evaluate(sum3)
|
||||||
self.assertEqual(40, output)
|
self.assertEqual(40, output)
|
||||||
|
|
||||||
def testLegacyDeviceNames(self):
|
def testLegacyDeviceNames(self):
|
||||||
|
@ -147,7 +147,7 @@ class TimelineTest(test.TestCase):
|
|||||||
num2 = variables.Variable(2.0, name='num2')
|
num2 = variables.Variable(2.0, name='num2')
|
||||||
with ops.device('/cpu:2'):
|
with ops.device('/cpu:2'):
|
||||||
result = num1 + num2 + num1 * num2
|
result = num1 + num2 + num1 * num2
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(result, options=run_options, run_metadata=run_metadata)
|
sess.run(result, options=run_options, run_metadata=run_metadata)
|
||||||
|
|
||||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||||
@ -176,7 +176,7 @@ class TimelineTest(test.TestCase):
|
|||||||
num2 = variables.Variable(2.0, name='num2')
|
num2 = variables.Variable(2.0, name='num2')
|
||||||
with ops.device('/cpu:2'):
|
with ops.device('/cpu:2'):
|
||||||
result = num1 + num2 + num1 * num2
|
result = num1 + num2 + num1 * num2
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(result, options=run_options, run_metadata=run_metadata)
|
sess.run(result, options=run_options, run_metadata=run_metadata)
|
||||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||||
step_stats = run_metadata.step_stats
|
step_stats = run_metadata.step_stats
|
||||||
|
@ -216,7 +216,7 @@ class VirtualGpuTest(test_util.TensorFlowTestCase):
|
|||||||
for d in self._util.devices:
|
for d in self._util.devices:
|
||||||
with ops.device(d):
|
with ops.device(d):
|
||||||
var = variables.Variable(random_ops.random_uniform(mat_shape))
|
var = variables.Variable(random_ops.random_uniform(mat_shape))
|
||||||
sess.run(var.initializer)
|
self.evaluate(var.initializer)
|
||||||
data.append(var)
|
data.append(var)
|
||||||
s = data[0]
|
s = data[0]
|
||||||
for i in range(1, len(data)):
|
for i in range(1, len(data)):
|
||||||
|
@ -53,10 +53,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
|
|
||||||
for start in range(0, len(components), 4):
|
for start in range(0, len(components), 4):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual([[i, j]
|
self.assertAllEqual([[i, j]
|
||||||
for i, c in enumerate(components[start:start + 4])
|
for i, c in enumerate(components[start:start + 4])
|
||||||
for j in range(c)], results.indices)
|
for j in range(c)], results.indices)
|
||||||
@ -81,10 +81,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
|
|
||||||
for start in range(0, len(components), 4):
|
for start in range(0, len(components), 4):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual([[i, j, z]
|
self.assertAllEqual([[i, j, z]
|
||||||
for i, c in enumerate(components[start:start + 4])
|
for i, c in enumerate(components[start:start + 4])
|
||||||
for j in range(c)
|
for j in range(c)
|
||||||
@ -141,7 +141,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
self.assertEqual(i, sess.run(next_elem))
|
self.assertEqual(i, self.evaluate(next_elem))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_elem)
|
sess.run(next_elem)
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual((i,) * 3, sess.run(op))
|
self.assertEqual((i,) * 3, self.evaluate(op))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(op)
|
sess.run(op)
|
||||||
@ -179,7 +179,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
|
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(op)
|
sess.run(op)
|
||||||
@ -198,7 +198,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
st_row = sess.run(next_element)
|
st_row = self.evaluate(next_element)
|
||||||
self.assertEqual([i], st_row.indices)
|
self.assertEqual([i], st_row.indices)
|
||||||
self.assertEqual([i], st_row.values)
|
self.assertEqual([i], st_row.values)
|
||||||
self.assertEqual([10], st_row.dense_shape)
|
self.assertEqual([10], st_row.dense_shape)
|
||||||
@ -219,7 +219,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
dense_elem, st_row = sess.run(next_element)
|
dense_elem, st_row = self.evaluate(next_element)
|
||||||
self.assertEqual(i, dense_elem)
|
self.assertEqual(i, dense_elem)
|
||||||
self.assertEqual([i], st_row.indices)
|
self.assertEqual([i], st_row.indices)
|
||||||
self.assertEqual([i], st_row.values)
|
self.assertEqual([i], st_row.values)
|
||||||
@ -241,7 +241,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(((i,),) * 3, sess.run(op))
|
self.assertEqual(((i,),) * 3, self.evaluate(op))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(op)
|
sess.run(op)
|
||||||
@ -354,7 +354,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||||
num_batches = (28 * 7) // 14
|
num_batches = (28 * 7) // 14
|
||||||
for i in range(num_batches):
|
for i in range(num_batches):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range(14):
|
for j in range(14):
|
||||||
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
||||||
@ -369,12 +369,12 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
# We expect (num_batches - 1) full-sized batches.
|
# We expect (num_batches - 1) full-sized batches.
|
||||||
num_batches = int(math.ceil((14 * 7) / 8))
|
num_batches = int(math.ceil((14 * 7) / 8))
|
||||||
for i in range(num_batches - 1):
|
for i in range(num_batches - 1):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range(8):
|
for j in range(8):
|
||||||
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
||||||
result_component[j])
|
result_component[j])
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range((14 * 7) % 8):
|
for j in range((14 * 7) % 8):
|
||||||
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
||||||
@ -408,10 +408,10 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||||
if not drop_remainder:
|
if not drop_remainder:
|
||||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -423,9 +423,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -439,7 +439,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
elements.append(iterator.get_next())
|
elements.append(iterator.get_next())
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
got = sess.run(elements)
|
got = self.evaluate(elements)
|
||||||
got.sort(key=lambda x: x[0])
|
got.sort(key=lambda x: x[0])
|
||||||
expected = []
|
expected = []
|
||||||
for j in range(100):
|
for j in range(100):
|
||||||
@ -459,7 +459,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
elements.append(iterator.get_next())
|
elements.append(iterator.get_next())
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
got = sess.run(elements)
|
got = self.evaluate(elements)
|
||||||
got.sort(key=lambda x: x[0])
|
got.sort(key=lambda x: x[0])
|
||||||
expected = []
|
expected = []
|
||||||
for j in range(100):
|
for j in range(100):
|
||||||
@ -480,9 +480,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected = sparse_tensor.SparseTensorValue(
|
expected = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||||
@ -524,7 +524,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
init_op = iterator.initializer
|
init_op = iterator.initializer
|
||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||||
"number of elements does not match"):
|
"number of elements does not match"):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -576,7 +576,8 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(threshold // 10):
|
for i in range(threshold // 10):
|
||||||
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
|
self.assertAllEqual([i * 10 + j for j in range(10)],
|
||||||
|
self.evaluate(get_next))
|
||||||
if threshold % 10 != 0:
|
if threshold % 10 != 0:
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
||||||
@ -609,7 +610,8 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
|
self.assertAllEqual([element for _ in range(10)],
|
||||||
|
self.evaluate(get_next))
|
||||||
|
|
||||||
|
|
||||||
class UnbatchDatasetBenchmark(test.Benchmark):
|
class UnbatchDatasetBenchmark(test.Benchmark):
|
||||||
|
@ -300,7 +300,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
while True:
|
while True:
|
||||||
output = sess.run(batch)
|
output = self.evaluate(batch)
|
||||||
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
|
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
|
||||||
tuple(output.values))
|
tuple(output.values))
|
||||||
all_sparse_tensors.add(sprs_tensor)
|
all_sparse_tensors.add(sprs_tensor)
|
||||||
|
@ -57,7 +57,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -134,7 +134,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -160,7 +160,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual({"a": i}, sess.run(next_element))
|
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -186,7 +186,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual({"a": i}, sess.run(next_element))
|
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -217,7 +217,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
actual = sess.run(next_element)
|
actual = self.evaluate(next_element)
|
||||||
self.assertAllEqual([i], actual.values)
|
self.assertAllEqual([i], actual.values)
|
||||||
self.assertAllEqual([[0, 0]], actual.indices)
|
self.assertAllEqual([[0, 0]], actual.indices)
|
||||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||||
@ -251,7 +251,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
actual = sess.run(next_element)
|
actual = self.evaluate(next_element)
|
||||||
self.assertAllEqual([i], actual.values)
|
self.assertAllEqual([i], actual.values)
|
||||||
self.assertAllEqual([[0, 0]], actual.indices)
|
self.assertAllEqual([[0, 0]], actual.indices)
|
||||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||||
@ -271,9 +271,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -290,9 +290,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -323,9 +323,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
x, y, z = sess.run(next_element)
|
x, y, z = self.evaluate(next_element)
|
||||||
self.assertEqual(i**2, x)
|
self.assertEqual(i**2, x)
|
||||||
self.assertEqual(float(i**2), y)
|
self.assertEqual(float(i**2), y)
|
||||||
self.assertEqual(util_compat.as_bytes(str(i)), z)
|
self.assertEqual(util_compat.as_bytes(str(i)), z)
|
||||||
@ -345,8 +345,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -363,8 +363,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
|
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -381,8 +381,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
|
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -399,8 +399,8 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
|
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -420,9 +420,9 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -447,12 +447,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -477,12 +477,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -499,12 +499,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -521,12 +521,12 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -553,7 +553,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
# For each element of the dataset, assert that the optional evaluates to
|
# For each element of the dataset, assert that the optional evaluates to
|
||||||
# the expected value.
|
# the expected value.
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
||||||
self.assertTrue(elem_has_value)
|
self.assertTrue(elem_has_value)
|
||||||
@ -562,7 +562,7 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
||||||
# false, and attempting to get the value will fail.
|
# false, and attempting to get the value will fail.
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
self.assertFalse(sess.run(elem_has_value_t))
|
self.assertFalse(self.evaluate(elem_has_value_t))
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(elem_value_t)
|
sess.run(elem_value_t)
|
||||||
|
|
||||||
|
@ -38,13 +38,13 @@ class CounterTest(test_base.DatasetTestBase):
|
|||||||
negative_get_next = negative_iterator.get_next()
|
negative_get_next = negative_iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(3, sess.run(get_next))
|
self.assertEqual(3, self.evaluate(get_next))
|
||||||
self.assertEqual(3 + 4, sess.run(get_next))
|
self.assertEqual(3 + 4, self.evaluate(get_next))
|
||||||
self.assertEqual(3 + 2 * 4, sess.run(get_next))
|
self.assertEqual(3 + 2 * 4, self.evaluate(get_next))
|
||||||
|
|
||||||
self.assertEqual(0, sess.run(negative_get_next))
|
self.assertEqual(0, self.evaluate(negative_get_next))
|
||||||
self.assertEqual(-1, sess.run(negative_get_next))
|
self.assertEqual(-1, self.evaluate(negative_get_next))
|
||||||
self.assertEqual(-2, sess.run(negative_get_next))
|
self.assertEqual(-2, self.evaluate(negative_get_next))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -41,10 +41,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
|
|
||||||
for start in range(0, len(components), 4):
|
for start in range(0, len(components), 4):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual([[i, j]
|
self.assertAllEqual([[i, j]
|
||||||
for i, c in enumerate(components[start:start + 4])
|
for i, c in enumerate(components[start:start + 4])
|
||||||
for j in range(c)], results.indices)
|
for j in range(c)], results.indices)
|
||||||
@ -69,10 +69,10 @@ class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
|
|
||||||
for start in range(0, len(components), 4):
|
for start in range(0, len(components), 4):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual([[i, j, z]
|
self.assertAllEqual([[i, j, z]
|
||||||
for i, c in enumerate(components[start:start + 4])
|
for i, c in enumerate(components[start:start + 4])
|
||||||
for j in range(c)
|
for j in range(c)
|
||||||
|
@ -40,10 +40,10 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -107,7 +107,7 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in choice_array:
|
for i in choice_array:
|
||||||
self.assertEqual(words[i], sess.run(next_element))
|
self.assertEqual(words[i], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@ -44,9 +44,9 @@ class EnumerateDatasetTest(test_base.DatasetTestBase):
|
|||||||
[t.shape for t in get_next[1]])
|
[t.shape for t in get_next[1]])
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
|
self.assertEqual((20, (b"a", 1, 37.0)), self.evaluate(get_next))
|
||||||
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
|
self.assertEqual((21, (b"b", 2, 38.0)), self.evaluate(get_next))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
@ -94,18 +94,18 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
|||||||
device0, device1)
|
device0, device1)
|
||||||
|
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [1.0])
|
self.assertEqual(elem, [1.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [2.0])
|
self.assertEqual(elem, [2.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [3.0])
|
self.assertEqual(elem, [3.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [4.0])
|
self.assertEqual(elem, [4.0])
|
||||||
self._event.wait()
|
self._event.wait()
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [5.0])
|
self.assertEqual(elem, [5.0])
|
||||||
sess.run(destroy_op)
|
self.evaluate(destroy_op)
|
||||||
|
|
||||||
def testSameDeviceCPU(self):
|
def testSameDeviceCPU(self):
|
||||||
self._prefetch_fn_helper_one_shot("same_device_cpu",
|
self._prefetch_fn_helper_one_shot("same_device_cpu",
|
||||||
@ -135,35 +135,35 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
|||||||
ds, ds_iterator, "reinit", device0, device1)
|
ds, ds_iterator, "reinit", device0, device1)
|
||||||
|
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
sess.run(ds_iterator.initializer)
|
self.evaluate(ds_iterator.initializer)
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [1.0])
|
self.assertEqual(elem, [1.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [2.0])
|
self.assertEqual(elem, [2.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [3.0])
|
self.assertEqual(elem, [3.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [4.0])
|
self.assertEqual(elem, [4.0])
|
||||||
self._event.wait()
|
self._event.wait()
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [5.0])
|
self.assertEqual(elem, [5.0])
|
||||||
# Lets reset the function buffering resource and reinitialize the
|
# Lets reset the function buffering resource and reinitialize the
|
||||||
# iterator. Should be able to go through this again.
|
# iterator. Should be able to go through this again.
|
||||||
self._event.clear()
|
self._event.clear()
|
||||||
sess.run(reset_op)
|
self.evaluate(reset_op)
|
||||||
sess.run(ds_iterator.initializer)
|
self.evaluate(ds_iterator.initializer)
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [1.0])
|
self.assertEqual(elem, [1.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [2.0])
|
self.assertEqual(elem, [2.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [3.0])
|
self.assertEqual(elem, [3.0])
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [4.0])
|
self.assertEqual(elem, [4.0])
|
||||||
self._event.wait()
|
self._event.wait()
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [5.0])
|
self.assertEqual(elem, [5.0])
|
||||||
sess.run(destroy_op)
|
self.evaluate(destroy_op)
|
||||||
|
|
||||||
def testReinitializationOutOfRange(self):
|
def testReinitializationOutOfRange(self):
|
||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
@ -175,30 +175,30 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
|||||||
ds, ds_iterator, "reinit", device0, device1)
|
ds, ds_iterator, "reinit", device0, device1)
|
||||||
|
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
sess.run(ds_iterator.initializer)
|
self.evaluate(ds_iterator.initializer)
|
||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [float(i)])
|
self.assertEqual(elem, [float(i)])
|
||||||
# Try fetching after its over twice to test out end of sequence.
|
# Try fetching after its over twice to test out end of sequence.
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(prefetch_op)
|
self.evaluate(prefetch_op)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(prefetch_op)
|
self.evaluate(prefetch_op)
|
||||||
|
|
||||||
# Now reset everything and try it out again.
|
# Now reset everything and try it out again.
|
||||||
self._event.clear()
|
self._event.clear()
|
||||||
sess.run(reset_op)
|
self.evaluate(reset_op)
|
||||||
sess.run(ds_iterator.initializer)
|
self.evaluate(ds_iterator.initializer)
|
||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
elem = sess.run(prefetch_op)
|
elem = self.evaluate(prefetch_op)
|
||||||
self.assertEqual(elem, [float(i)])
|
self.assertEqual(elem, [float(i)])
|
||||||
# Try fetching after its over twice to test out end of sequence.
|
# Try fetching after its over twice to test out end of sequence.
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(prefetch_op)
|
self.evaluate(prefetch_op)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(prefetch_op)
|
self.evaluate(prefetch_op)
|
||||||
|
|
||||||
sess.run(destroy_op)
|
self.evaluate(destroy_op)
|
||||||
|
|
||||||
def testStringsGPU(self):
|
def testStringsGPU(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
@ -235,13 +235,13 @@ class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
|||||||
buffer_resource_handle, ignore_lookup_error=True)
|
buffer_resource_handle, ignore_lookup_error=True)
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual([b"a"], sess.run(prefetch_op))
|
self.assertEqual([b"a"], self.evaluate(prefetch_op))
|
||||||
self.assertEqual([b"b"], sess.run(prefetch_op))
|
self.assertEqual([b"b"], self.evaluate(prefetch_op))
|
||||||
self.assertEqual([b"c"], sess.run(prefetch_op))
|
self.assertEqual([b"c"], self.evaluate(prefetch_op))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(prefetch_op)
|
self.evaluate(prefetch_op)
|
||||||
|
|
||||||
sess.run(destroy_op)
|
self.evaluate(destroy_op)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -39,7 +39,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
|||||||
get_next = dataset.make_one_shot_iterator().get_next()
|
get_next = dataset.make_one_shot_iterator().get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for expected in values:
|
for expected in values:
|
||||||
got = sess.run(get_next)
|
got = self.evaluate(get_next)
|
||||||
self.assertEqual(got, expected)
|
self.assertEqual(got, expected)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -127,7 +127,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
|||||||
iterator = dataset.make_one_shot_iterator()
|
iterator = dataset.make_one_shot_iterator()
|
||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x, y = sess.run(get_next)
|
x, y = self.evaluate(get_next)
|
||||||
self.assertAllEqual([0] * (2**i), x)
|
self.assertAllEqual([0] * (2**i), x)
|
||||||
self.assertAllEqual(np.array(1, ndmin=i), y)
|
self.assertAllEqual(np.array(1, ndmin=i), y)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -190,7 +190,7 @@ class GroupByReducerTest(test_base.DatasetTestBase):
|
|||||||
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
|
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
|
||||||
get_next = dataset.make_one_shot_iterator().get_next()
|
get_next = dataset.make_one_shot_iterator().get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x, y = sess.run(get_next)
|
x, y = self.evaluate(get_next)
|
||||||
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
|
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
|
||||||
self.assertEqual(y, 45)
|
self.assertEqual(y, 45)
|
||||||
|
|
||||||
|
@ -68,9 +68,9 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
|
|
||||||
which_bucket, bucketed_values = sess.run(get_next)
|
which_bucket, bucketed_values = self.evaluate(get_next)
|
||||||
|
|
||||||
self.assertEqual(0, which_bucket)
|
self.assertEqual(0, which_bucket)
|
||||||
|
|
||||||
@ -103,11 +103,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
|
|
||||||
# Get two minibatches (one containing even values, one containing odds)
|
# Get two minibatches (one containing even values, one containing odds)
|
||||||
which_bucket_even, bucketed_values_even = sess.run(get_next)
|
which_bucket_even, bucketed_values_even = self.evaluate(get_next)
|
||||||
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
|
which_bucket_odd, bucketed_values_odd = self.evaluate(get_next)
|
||||||
|
|
||||||
# Count number of bucket_tensors.
|
# Count number of bucket_tensors.
|
||||||
self.assertEqual(3, len(bucketed_values_even))
|
self.assertEqual(3, len(bucketed_values_even))
|
||||||
@ -174,11 +174,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
|
|
||||||
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
||||||
which_bucket0, bucketed_values_even0 = sess.run(get_next)
|
which_bucket0, bucketed_values_even0 = self.evaluate(get_next)
|
||||||
which_bucket1, bucketed_values_even1 = sess.run(get_next)
|
which_bucket1, bucketed_values_even1 = self.evaluate(get_next)
|
||||||
|
|
||||||
# Ensure that bucket 1 was completely filtered out
|
# Ensure that bucket 1 was completely filtered out
|
||||||
self.assertAllEqual(0, which_bucket0)
|
self.assertAllEqual(0, which_bucket0)
|
||||||
@ -207,11 +207,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
batches = 0
|
batches = 0
|
||||||
while True:
|
while True:
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
is_even = all(x % 2 == 0 for x in result)
|
is_even = all(x % 2 == 0 for x in result)
|
||||||
is_odd = all(x % 2 == 1 for x in result)
|
is_odd = all(x % 2 == 1 for x in result)
|
||||||
self.assertTrue(is_even or is_odd)
|
self.assertTrue(is_even or is_odd)
|
||||||
@ -232,11 +232,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
counts = []
|
counts = []
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
while True:
|
while True:
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
all(x % 2 == 0
|
all(x % 2 == 0
|
||||||
for x in result) or all(x % 2 == 1)
|
for x in result) or all(x % 2 == 1)
|
||||||
@ -259,16 +259,16 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# The input is infinite, so this test demonstrates that:
|
# The input is infinite, so this test demonstrates that:
|
||||||
# 1. We produce output without having to consume the entire input,
|
# 1. We produce output without having to consume the entire input,
|
||||||
# 2. Different buckets can produce output at different rates, and
|
# 2. Different buckets can produce output at different rates, and
|
||||||
# 3. For deterministic input, the output is deterministic.
|
# 3. For deterministic input, the output is deterministic.
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
|
||||||
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
|
self.assertAllEqual([2, 2, 2, 2], self.evaluate(get_next))
|
||||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||||
|
|
||||||
def testSmallGroups(self):
|
def testSmallGroups(self):
|
||||||
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
||||||
@ -280,13 +280,13 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
self.assertAllEqual([0, 0, 0, 0], self.evaluate(get_next))
|
||||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
self.assertAllEqual([1, 1, 1, 1], self.evaluate(get_next))
|
||||||
# The small outputs at the end are deterministically produced in key
|
# The small outputs at the end are deterministically produced in key
|
||||||
# order.
|
# order.
|
||||||
self.assertAllEqual([0, 0, 0], sess.run(get_next))
|
self.assertAllEqual([0, 0, 0], self.evaluate(get_next))
|
||||||
self.assertAllEqual([1], sess.run(get_next))
|
self.assertAllEqual([1], self.evaluate(get_next))
|
||||||
|
|
||||||
def testEmpty(self):
|
def testEmpty(self):
|
||||||
iterator = (
|
iterator = (
|
||||||
@ -297,7 +297,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
errors.InvalidArgumentError,
|
errors.InvalidArgumentError,
|
||||||
"Window size must be greater than zero, but got 0."):
|
"Window size must be greater than zero, but got 0."):
|
||||||
@ -323,7 +323,7 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -351,11 +351,11 @@ class GroupByWindowTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
counts = []
|
counts = []
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
while True:
|
while True:
|
||||||
tight_result, multiple_of_10_result = sess.run(get_next)
|
tight_result, multiple_of_10_result = self.evaluate(get_next)
|
||||||
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
|
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
|
||||||
self.assertAllEqual(tight_result,
|
self.assertAllEqual(tight_result,
|
||||||
multiple_of_10_result[:, :tight_result.shape[1]])
|
multiple_of_10_result[:, :tight_result.shape[1]])
|
||||||
|
@ -47,9 +47,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for x in [1., 2., 3., 5.]:
|
for x in [1., 2., 3., 5.]:
|
||||||
self.assertEqual(x, sess.run(get_next))
|
self.assertEqual(x, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -65,9 +65,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for x in [1., 2., 3., 5.]:
|
for x in [1., 2., 3., 5.]:
|
||||||
self.assertEqual(x, sess.run(get_next))
|
self.assertEqual(x, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -93,9 +93,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
# All of the files are present.
|
# All of the files are present.
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -104,9 +104,9 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
# Attempting to read filenames[0] will fail, but ignore_errors()
|
# Attempting to read filenames[0] will fail, but ignore_errors()
|
||||||
# will catch the error.
|
# will catch the error.
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for filename in filenames[1:]:
|
for filename in filenames[1:]:
|
||||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
|||||||
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
|
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
|
||||||
materialized = ds.materialize()
|
materialized = ds.materialize()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(materialized.initializer)
|
self.evaluate(materialized.initializer)
|
||||||
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
|
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
|
||||||
for i in range(16):
|
for i in range(16):
|
||||||
output = sess.run(
|
output = sess.run(
|
||||||
@ -68,9 +68,9 @@ class IndexedDatasetOpsTest(test_base.DatasetTestBase):
|
|||||||
itr = ds.make_initializable_iterator()
|
itr = ds.make_initializable_iterator()
|
||||||
n = itr.get_next()
|
n = itr.get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(itr.initializer)
|
self.evaluate(itr.initializer)
|
||||||
for i in range(16):
|
for i in range(16):
|
||||||
output = sess.run(n)
|
output = self.evaluate(n)
|
||||||
self.assertEqual(i, output)
|
self.assertEqual(i, output)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(n)
|
sess.run(n)
|
||||||
|
@ -112,10 +112,10 @@ class MakeBatchedFeaturesDatasetTest(
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
|
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
|
||||||
range(self._num_files), 2, 10):
|
range(self._num_files), 2, 10):
|
||||||
actual_batch = sess.run(next_element)
|
actual_batch = self.evaluate(next_element)
|
||||||
self.assertAllEqual(file_batch, actual_batch["file"])
|
self.assertAllEqual(file_batch, actual_batch["file"])
|
||||||
self.assertAllEqual(record_batch, actual_batch["record"])
|
self.assertAllEqual(record_batch, actual_batch["record"])
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
@ -90,7 +90,7 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
|||||||
batch_size,
|
batch_size,
|
||||||
num_epochs,
|
num_epochs,
|
||||||
):
|
):
|
||||||
actual_features = sess.run(nxt)
|
actual_features = self.evaluate(nxt)
|
||||||
|
|
||||||
if label_name is not None:
|
if label_name is not None:
|
||||||
expected_labels = expected_features.pop(label_name)
|
expected_labels = expected_features.pop(label_name)
|
||||||
|
@ -105,7 +105,7 @@ class MakeTFRecordDatasetTest(
|
|||||||
for expected_batch in self._next_expected_batch(
|
for expected_batch in self._next_expected_batch(
|
||||||
file_indices, batch_size, num_epochs, interleave_cycle_length,
|
file_indices, batch_size, num_epochs, interleave_cycle_length,
|
||||||
drop_final_batch, use_parser_fn):
|
drop_final_batch, use_parser_fn):
|
||||||
actual_batch = sess.run(outputs)
|
actual_batch = self.evaluate(outputs)
|
||||||
self.assertAllEqual(expected_batch, actual_batch)
|
self.assertAllEqual(expected_batch, actual_batch)
|
||||||
|
|
||||||
def _read_test(self, batch_size, num_epochs, file_index=None,
|
def _read_test(self, batch_size, num_epochs, file_index=None,
|
||||||
@ -188,7 +188,7 @@ class MakeTFRecordDatasetTest(
|
|||||||
iterator = dataset.make_initializable_iterator()
|
iterator = dataset.make_initializable_iterator()
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
first_batches = []
|
first_batches = []
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@ -196,7 +196,7 @@ class MakeTFRecordDatasetTest(
|
|||||||
except errors.OutOfRangeError:
|
except errors.OutOfRangeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
second_batches = []
|
second_batches = []
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
@ -89,7 +89,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||||
num_batches = (28 * 7) // 14
|
num_batches = (28 * 7) // 14
|
||||||
for i in range(num_batches):
|
for i in range(num_batches):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range(14):
|
for j in range(14):
|
||||||
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
||||||
@ -104,12 +104,12 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
# We expect (num_batches - 1) full-sized batches.
|
# We expect (num_batches - 1) full-sized batches.
|
||||||
num_batches = int(math.ceil((14 * 7) / 8))
|
num_batches = int(math.ceil((14 * 7) / 8))
|
||||||
for i in range(num_batches - 1):
|
for i in range(num_batches - 1):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range(8):
|
for j in range(8):
|
||||||
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
||||||
result_component[j])
|
result_component[j])
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range((14 * 7) % 8):
|
for j in range((14 * 7) % 8):
|
||||||
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
||||||
@ -152,10 +152,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||||
if not drop_remainder:
|
if not drop_remainder:
|
||||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -177,9 +177,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
|
||||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
|
||||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -201,7 +201,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
elements.append(iterator.get_next())
|
elements.append(iterator.get_next())
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
got = sess.run(elements)
|
got = self.evaluate(elements)
|
||||||
got.sort(key=lambda x: x[0])
|
got.sort(key=lambda x: x[0])
|
||||||
expected = []
|
expected = []
|
||||||
for j in range(100):
|
for j in range(100):
|
||||||
@ -230,7 +230,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
elements.append(iterator.get_next())
|
elements.append(iterator.get_next())
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
got = sess.run(elements)
|
got = self.evaluate(elements)
|
||||||
got.sort(key=lambda x: x[0])
|
got.sort(key=lambda x: x[0])
|
||||||
expected = []
|
expected = []
|
||||||
for j in range(100):
|
for j in range(100):
|
||||||
@ -261,9 +261,9 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected = sparse_tensor.SparseTensorValue(
|
expected = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||||
@ -321,7 +321,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
init_op = iterator.initializer
|
init_op = iterator.initializer
|
||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||||
"number of elements does not match"):
|
"number of elements does not match"):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -393,7 +393,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(threshold // 10):
|
for i in range(threshold // 10):
|
||||||
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
|
self.assertAllEqual([i * 10 + j for j in range(10)],
|
||||||
|
self.evaluate(get_next))
|
||||||
if threshold % 10 != 0:
|
if threshold % 10 != 0:
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
||||||
@ -442,7 +443,8 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
|
self.assertAllEqual([element for _ in range(10)],
|
||||||
|
self.evaluate(get_next))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
("Identity", None, lambda x: x, None),
|
("Identity", None, lambda x: x, None),
|
||||||
@ -462,7 +464,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
else:
|
else:
|
||||||
expected = map_fn(
|
expected = map_fn(
|
||||||
sess.run(self.structuredElement(structure, shape=[10])))
|
sess.run(self.structuredElement(structure, shape=[10])))
|
||||||
self.assertAllEqual(expected, sess.run(get_next))
|
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||||
|
|
||||||
def testShortCircuitCapturedInput(self):
|
def testShortCircuitCapturedInput(self):
|
||||||
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
|
captured_t = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
@ -473,7 +475,7 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
||||||
self.assertAllEqual([42] * 10, sess.run(get_next))
|
self.assertAllEqual([42] * 10, self.evaluate(get_next))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
("Normal", False),
|
("Normal", False),
|
||||||
|
@ -218,7 +218,7 @@ class MapDefunTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
def _assert_op_cancelled(self, sess, map_defun_op):
|
def _assert_op_cancelled(self, sess, map_defun_op):
|
||||||
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
|
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
|
||||||
sess.run(map_defun_op)
|
self.evaluate(map_defun_op)
|
||||||
|
|
||||||
def testMapDefunWithParentCancellation(self):
|
def testMapDefunWithParentCancellation(self):
|
||||||
# Checks that a cancellation of the parent graph is threaded through to
|
# Checks that a cancellation of the parent graph is threaded through to
|
||||||
|
@ -72,7 +72,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
thread_ids = []
|
thread_ids = []
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
@ -637,11 +637,11 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||||
self.assertAllEqual(expected, sess.run(get_next))
|
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -796,7 +796,7 @@ class ParallelInterleaveTest(test_base.DatasetTestBase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
elements = []
|
elements = []
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
elements.extend(sess.run(next_element))
|
elements.extend(sess.run(next_element))
|
||||||
|
@ -57,7 +57,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -117,7 +117,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual({"a": i}, sess.run(next_element))
|
self.assertEqual({"a": i}, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
actual = sess.run(next_element)
|
actual = self.evaluate(next_element)
|
||||||
self.assertAllEqual([i], actual.values)
|
self.assertAllEqual([i], actual.values)
|
||||||
self.assertAllEqual([[0, 0]], actual.indices)
|
self.assertAllEqual([[0, 0]], actual.indices)
|
||||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||||
@ -170,7 +170,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -199,12 +199,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=worker_config) as sess:
|
with self.test_session(config=worker_config) as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -220,12 +220,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ class ScanTest(test_base.DatasetTestBase):
|
|||||||
feed_dict={start: start_val, step: step_val, take: take_val})
|
feed_dict={start: start_val, step: step_val, take: take_val})
|
||||||
for expected, _ in zip(
|
for expected, _ in zip(
|
||||||
itertools.count(start_val, step_val), range(take_val)):
|
itertools.count(start_val, step_val), range(take_val)):
|
||||||
self.assertEqual(expected, sess.run(next_element))
|
self.assertEqual(expected, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ class ScanTest(test_base.DatasetTestBase):
|
|||||||
feed_dict={start: start_val, step: step_val, take: take_val})
|
feed_dict={start: start_val, step: step_val, take: take_val})
|
||||||
for expected, _ in zip(
|
for expected, _ in zip(
|
||||||
itertools.count(start_val, step_val), range(take_val)):
|
itertools.count(start_val, step_val), range(take_val)):
|
||||||
self.assertEqual(expected, sess.run(next_element).values[0])
|
self.assertEqual(expected, self.evaluate(next_element).values[0])
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ class ScanTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
|
(longer_vector_val, larger_rank_val), _ = self.evaluate(next_element)
|
||||||
self.assertAllEqual([0] * (2**i), longer_vector_val)
|
self.assertAllEqual([0] * (2**i), longer_vector_val)
|
||||||
self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
|
self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
@ -71,19 +71,19 @@ class RangeDatasetSerializationTest(
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(start, break_point):
|
for i in range(start, break_point):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point, stop):
|
for i in range(break_point, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -91,14 +91,14 @@ class RangeDatasetSerializationTest(
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(start, break_point):
|
for i in range(start, break_point):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point, stop):
|
for i in range(break_point, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ class SerializationIntegrationTest(test.TestCase):
|
|||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_ops)
|
sess.run(init_ops)
|
||||||
for _ in range(break_point):
|
for _ in range(break_point):
|
||||||
output = sess.run(get_next_ops)
|
output = self.evaluate(get_next_ops)
|
||||||
for i in range(num_pipelines):
|
for i in range(num_pipelines):
|
||||||
all_outputs[i].append(output[i])
|
all_outputs[i].append(output[i])
|
||||||
saver.save(sess, self._ckpt_path())
|
saver.save(sess, self._ckpt_path())
|
||||||
@ -73,7 +73,7 @@ class SerializationIntegrationTest(test.TestCase):
|
|||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
saver.restore(sess, self._ckpt_path())
|
saver.restore(sess, self._ckpt_path())
|
||||||
for _ in range(num_outputs - break_point):
|
for _ in range(num_outputs - break_point):
|
||||||
output = sess.run(get_next_ops)
|
output = self.evaluate(get_next_ops)
|
||||||
for i in range(num_pipelines):
|
for i in range(num_pipelines):
|
||||||
all_outputs[i].append(output[i])
|
all_outputs[i].append(output[i])
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
|
|||||||
shuffle_ops.shuffle_and_repeat(buffer_size=21))
|
shuffle_ops.shuffle_and_repeat(buffer_size=21))
|
||||||
get_next_op = ds.make_one_shot_iterator().get_next()
|
get_next_op = ds.make_one_shot_iterator().get_next()
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -38,10 +38,10 @@ class SleepTest(test_base.DatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
|
self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
@ -39,8 +39,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
for _ in range(2): # Dataset is repeated. See setUp.
|
for _ in range(2): # Dataset is repeated. See setUp.
|
||||||
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
|
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
|
self.assertEqual((b"Jane", b"Moe", b"Hi again!"),
|
||||||
|
self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -58,7 +59,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"ON students.first_name = people.first_name "
|
"ON students.first_name = people.first_name "
|
||||||
"AND students.last_name = people.last_name"
|
"AND students.last_name = people.last_name"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
|
self.assertEqual((b"John", b"California", b"Hi!"),
|
||||||
|
self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -75,8 +77,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"SELECT first_name, last_name, favorite_nonsense_word "
|
"SELECT first_name, last_name, favorite_nonsense_word "
|
||||||
"FROM students ORDER BY first_name DESC"
|
"FROM students ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
|
self.assertEqual((b"John", b"Doe", b"n\0nsense"), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
|
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"),
|
||||||
|
self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -93,8 +96,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, last_name, motto FROM students "
|
self.query: "SELECT first_name, last_name, motto FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
|
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
|
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
sess.run(
|
sess.run(
|
||||||
@ -103,7 +106,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, last_name, state FROM people "
|
self.query: "SELECT first_name, last_name, state FROM people "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
|
self.assertEqual((b"John", b"Doe", b"California"),
|
||||||
|
self.evaluate(get_next))
|
||||||
self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
|
self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
|
||||||
sess.run(get_next))
|
sess.run(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -212,8 +216,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, desk_number FROM students "
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -230,7 +234,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"FROM students "
|
"FROM students "
|
||||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -246,9 +250,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"SELECT desk_number, favorite_negative_number FROM students "
|
"SELECT desk_number, favorite_negative_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((9, -2), sess.run(get_next))
|
self.assertEqual((9, -2), self.evaluate(get_next))
|
||||||
# Max and min values of int8
|
# Max and min values of int8
|
||||||
self.assertEqual((127, -128), sess.run(get_next))
|
self.assertEqual((127, -128), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -263,8 +267,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, desk_number FROM students "
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -281,7 +285,7 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"FROM students "
|
"FROM students "
|
||||||
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -297,9 +301,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"FROM students ORDER BY first_name DESC"
|
"FROM students ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
# Max value of int16
|
# Max value of int16
|
||||||
self.assertEqual((b"John", 32767), sess.run(get_next))
|
self.assertEqual((b"John", 32767), self.evaluate(get_next))
|
||||||
# Min value of int16
|
# Min value of int16
|
||||||
self.assertEqual((b"Jane", -32768), sess.run(get_next))
|
self.assertEqual((b"Jane", -32768), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -314,8 +318,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, desk_number FROM students "
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||||
|
|
||||||
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
# SQLite database table and place it in an `int32` tensor.
|
# SQLite database table and place it in an `int32` tensor.
|
||||||
@ -328,8 +332,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, income FROM students "
|
self.query: "SELECT first_name, income FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -345,9 +349,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
# Max value of int32
|
# Max value of int32
|
||||||
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
self.assertEqual((b"John", 2147483647), self.evaluate(get_next))
|
||||||
# Min value of int32
|
# Min value of int32
|
||||||
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
self.assertEqual((b"Jane", -2147483648), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -362,8 +366,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, school_id FROM students "
|
self.query: "SELECT first_name, school_id FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 123), sess.run(get_next))
|
self.assertEqual((b"John", 123), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", 1000), sess.run(get_next))
|
self.assertEqual((b"Jane", 1000), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -378,8 +382,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, desk_number FROM students "
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -394,8 +398,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, income FROM students "
|
self.query: "SELECT first_name, income FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -412,9 +416,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
# Max value of int64
|
# Max value of int64
|
||||||
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
|
self.assertEqual((b"John", 9223372036854775807), self.evaluate(get_next))
|
||||||
# Min value of int64
|
# Min value of int64
|
||||||
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
|
self.assertEqual((b"Jane", -9223372036854775808), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -429,8 +433,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, desk_number FROM students "
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -446,9 +450,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
# Min value of uint8
|
# Min value of uint8
|
||||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||||
# Max value of uint8
|
# Max value of uint8
|
||||||
self.assertEqual((b"Jane", 255), sess.run(get_next))
|
self.assertEqual((b"Jane", 255), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -463,8 +467,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, desk_number FROM students "
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
self.assertEqual((b"John", 9), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -480,9 +484,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
# Min value of uint16
|
# Min value of uint16
|
||||||
self.assertEqual((b"John", 0), sess.run(get_next))
|
self.assertEqual((b"John", 0), self.evaluate(get_next))
|
||||||
# Max value of uint16
|
# Max value of uint16
|
||||||
self.assertEqual((b"Jane", 65535), sess.run(get_next))
|
self.assertEqual((b"Jane", 65535), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -499,8 +503,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"SELECT first_name, registration_complete FROM students "
|
"SELECT first_name, registration_complete FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", True), sess.run(get_next))
|
self.assertEqual((b"John", True), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", False), sess.run(get_next))
|
self.assertEqual((b"Jane", False), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -515,8 +519,8 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
self.query: "SELECT first_name, favorite_medium_sized_number "
|
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||||
"FROM students ORDER BY first_name DESC"
|
"FROM students ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", True), sess.run(get_next))
|
self.assertEqual((b"John", True), self.evaluate(get_next))
|
||||||
self.assertEqual((b"Jane", True), sess.run(get_next))
|
self.assertEqual((b"Jane", True), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -533,8 +537,9 @@ class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
|||||||
"SELECT first_name, last_name, victories FROM townspeople "
|
"SELECT first_name, last_name, victories FROM townspeople "
|
||||||
"ORDER BY first_name"
|
"ORDER BY first_name"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
|
self.assertEqual((b"George", b"Washington", 20.0),
|
||||||
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
|
self.evaluate(get_next))
|
||||||
|
self.assertEqual((b"John", b"Adams", -19.95), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -74,18 +74,18 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
summary_t = aggregator.get_summary()
|
summary_t = aggregator.get_summary()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
expected_sum = 0.0
|
expected_sum = 0.0
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||||
summary_str = sess.run(summary_t)
|
summary_str = self.evaluate(summary_t)
|
||||||
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
|
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
|
||||||
expected_sum += i * 8.0
|
expected_sum += i * 8.0
|
||||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
summary_str = sess.run(summary_t)
|
summary_str = self.evaluate(summary_t)
|
||||||
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
|
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
|
||||||
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
|
||||||
|
|
||||||
@ -99,14 +99,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
summary_t = aggregator.get_summary()
|
summary_t = aggregator.get_summary()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
sess.run(summary_t), "record_latency", float(i + 1))
|
sess.run(summary_t), "record_latency", float(i + 1))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
|
self._assertSummaryHasCount(
|
||||||
|
self.evaluate(summary_t), "record_latency", 100.0)
|
||||||
|
|
||||||
def testPrefetchBufferUtilization(self, dataset_transformation):
|
def testPrefetchBufferUtilization(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
@ -118,11 +119,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
summary_t = aggregator.get_summary()
|
summary_t = aggregator.get_summary()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||||
summary_str = sess.run(summary_t)
|
summary_str = self.evaluate(summary_t)
|
||||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||||
float(i + 1))
|
float(i + 1))
|
||||||
self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
|
self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
|
||||||
@ -131,7 +132,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
0, 1)
|
0, 1)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
summary_str = sess.run(summary_t)
|
summary_str = self.evaluate(summary_t)
|
||||||
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
|
||||||
100)
|
100)
|
||||||
|
|
||||||
@ -145,11 +146,11 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
summary_t = aggregator.get_summary()
|
summary_t = aggregator.get_summary()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
np.array([i] * i, dtype=np.int64), sess.run(next_element))
|
||||||
summary_str = sess.run(summary_t)
|
summary_str = self.evaluate(summary_t)
|
||||||
self._assertSummaryHasScalarValue(summary_str,
|
self._assertSummaryHasScalarValue(summary_str,
|
||||||
"Prefetch::buffer_capacity", 0)
|
"Prefetch::buffer_capacity", 0)
|
||||||
self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
|
self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
|
||||||
@ -167,9 +168,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
summary_t = aggregator.get_summary()
|
summary_t = aggregator.get_summary()
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(34):
|
for i in range(34):
|
||||||
self.assertEqual(i * 3, sess.run(next_element))
|
self.assertEqual(i * 3, self.evaluate(next_element))
|
||||||
if i is not 0:
|
if i is not 0:
|
||||||
self._assertSummaryHasScalarValue(
|
self._assertSummaryHasScalarValue(
|
||||||
sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
|
sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
|
||||||
@ -261,9 +262,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for j in range(5):
|
for j in range(5):
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
|
sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -278,9 +279,9 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -295,16 +296,17 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
summary_t = aggregator.get_summary()
|
summary_t = aggregator.get_summary()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
sess.run(summary_t), "record_latency", float(i + 1))
|
sess.run(summary_t), "record_latency", float(i + 1))
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
sess.run(summary_t), "record_latency_2", float(i + 1))
|
sess.run(summary_t), "record_latency_2", float(i + 1))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
|
self._assertSummaryHasCount(
|
||||||
|
self.evaluate(summary_t), "record_latency", 100.0)
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
sess.run(summary_t), "record_latency_2", 100.0)
|
sess.run(summary_t), "record_latency_2", 100.0)
|
||||||
|
|
||||||
@ -319,14 +321,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
summary_t = aggregator.get_summary()
|
summary_t = aggregator.get_summary()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i, sess.run(next_element))
|
self.assertEqual(i, self.evaluate(next_element))
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
|
self._assertSummaryHasCount(
|
||||||
|
self.evaluate(summary_t), "record_latency", 200.0)
|
||||||
|
|
||||||
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
|
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
@ -341,12 +344,13 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run([iterator_0.initializer, iterator_1.initializer])
|
sess.run([iterator_0.initializer, iterator_1.initializer])
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i * 2, sess.run(next_element))
|
self.assertEqual(i * 2, self.evaluate(next_element))
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
sess.run(summary_t), "record_latency", float(2 * (i + 1)))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
|
self._assertSummaryHasCount(
|
||||||
|
self.evaluate(summary_t), "record_latency", 200.0)
|
||||||
|
|
||||||
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
|
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
|
||||||
aggregator = stats_aggregator.StatsAggregator()
|
aggregator = stats_aggregator.StatsAggregator()
|
||||||
@ -364,7 +368,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
|||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run([iterator_0.initializer, iterator_1.initializer])
|
sess.run([iterator_0.initializer, iterator_1.initializer])
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i * 2, sess.run(next_element))
|
self.assertEqual(i * 2, self.evaluate(next_element))
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
sess.run(summary_t), "dataset1_record_latency", float(i + 1))
|
sess.run(summary_t), "dataset1_record_latency", float(i + 1))
|
||||||
self._assertSummaryHasCount(
|
self._assertSummaryHasCount(
|
||||||
@ -421,7 +425,7 @@ class FeatureStatsDatasetTest(
|
|||||||
summary_t = aggregator.get_summary()
|
summary_t = aggregator.get_summary()
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for _ in range(num_output):
|
for _ in range(num_output):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
self.assertEqual(i, sess.run(next_elem))
|
self.assertEqual(i, self.evaluate(next_elem))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_elem)
|
sess.run(next_elem)
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual((i,) * 3, sess.run(op))
|
self.assertEqual((i,) * 3, self.evaluate(op))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(op)
|
sess.run(op)
|
||||||
@ -88,7 +88,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
|
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(op)
|
sess.run(op)
|
||||||
@ -107,7 +107,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
st_row = sess.run(next_element)
|
st_row = self.evaluate(next_element)
|
||||||
self.assertEqual([i], st_row.indices)
|
self.assertEqual([i], st_row.indices)
|
||||||
self.assertEqual([i], st_row.values)
|
self.assertEqual([i], st_row.values)
|
||||||
self.assertEqual([10], st_row.dense_shape)
|
self.assertEqual([10], st_row.dense_shape)
|
||||||
@ -128,7 +128,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
dense_elem, st_row = sess.run(next_element)
|
dense_elem, st_row = self.evaluate(next_element)
|
||||||
self.assertEqual(i, dense_elem)
|
self.assertEqual(i, dense_elem)
|
||||||
self.assertEqual([i], st_row.indices)
|
self.assertEqual([i], st_row.indices)
|
||||||
self.assertEqual([i], st_row.values)
|
self.assertEqual([i], st_row.values)
|
||||||
@ -150,7 +150,7 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(((i,),) * 3, sess.run(op))
|
self.assertEqual(((i,),) * 3, self.evaluate(op))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(op)
|
sess.run(op)
|
||||||
|
@ -49,11 +49,11 @@ class UniqueTest(test_base.DatasetTestBase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for test_case, expected in test_cases:
|
for test_case, expected in test_cases:
|
||||||
current_test_case = test_case
|
current_test_case = test_case
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for element in expected:
|
for element in expected:
|
||||||
if dtype == dtypes.string:
|
if dtype == dtypes.string:
|
||||||
element = compat.as_bytes(element)
|
element = compat.as_bytes(element)
|
||||||
self.assertAllEqual(element, sess.run(next_element))
|
self.assertAllEqual(element, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@ -93,13 +93,13 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
})
|
})
|
||||||
num_full_batches = (count * 7) // batch_size
|
num_full_batches = (count * 7) // batch_size
|
||||||
for i in range(num_full_batches):
|
for i in range(num_full_batches):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range(batch_size):
|
for j in range(batch_size):
|
||||||
self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
|
self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
|
||||||
result_component[j])
|
result_component[j])
|
||||||
if not drop_remainder and (count * 7) % batch_size > 0:
|
if not drop_remainder and (count * 7) % batch_size > 0:
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range((count * 7) % batch_size):
|
for j in range((count * 7) % batch_size):
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
@ -128,9 +128,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected = sparse_tensor.SparseTensorValue(
|
expected = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||||
@ -155,9 +155,9 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected_indices = []
|
expected_indices = []
|
||||||
expected_values = []
|
expected_values = []
|
||||||
for j in range(5):
|
for j in range(5):
|
||||||
@ -185,8 +185,8 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected = sparse_tensor.SparseTensorValue(
|
expected = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
|
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
|
||||||
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
|
[1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
|
||||||
@ -211,7 +211,7 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
errors.InvalidArgumentError,
|
errors.InvalidArgumentError,
|
||||||
r'Cannot batch tensors with different shapes in component 0. '
|
r'Cannot batch tensors with different shapes in component 0. '
|
||||||
@ -271,7 +271,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
num_full_batches = len(seq_lens) // batch_size
|
num_full_batches = len(seq_lens) // batch_size
|
||||||
|
|
||||||
for i in range(num_full_batches):
|
for i in range(num_full_batches):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
padded_len = padded_shapes[0]
|
padded_len = padded_shapes[0]
|
||||||
if padded_len is None or padded_len == -1:
|
if padded_len is None or padded_len == -1:
|
||||||
padded_len = np.max(result) if result.size > 0 else 0
|
padded_len = np.max(result) if result.size > 0 else 0
|
||||||
@ -283,7 +283,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
[0] * (padded_len - seq_len))
|
[0] * (padded_len - seq_len))
|
||||||
|
|
||||||
if not drop_remainder and len(seq_lens) % batch_size > 0:
|
if not drop_remainder and len(seq_lens) % batch_size > 0:
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
padded_len = np.max(result) if result.size > 0 else 0
|
padded_len = np.max(result) if result.size > 0 else 0
|
||||||
self.assertEqual((len(seq_lens) % batch_size, padded_len),
|
self.assertEqual((len(seq_lens) % batch_size, padded_len),
|
||||||
result.shape)
|
result.shape)
|
||||||
@ -315,7 +315,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
self.assertAllEqual([[], [], [], []], result)
|
self.assertAllEqual([[], [], [], []], result)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -347,7 +347,7 @@ class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
seq_lens: random_seq_lens
|
seq_lens: random_seq_lens
|
||||||
})
|
})
|
||||||
for i in range(8):
|
for i in range(8):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
padded_len = np.max(result[0])
|
padded_len = np.max(result[0])
|
||||||
self.assertEqual((4, padded_len), result[0].shape)
|
self.assertEqual((4, padded_len), result[0].shape)
|
||||||
self.assertEqual((4, padded_len), result[1].shape)
|
self.assertEqual((4, padded_len), result[1].shape)
|
||||||
|
@ -71,7 +71,7 @@ class FileCacheDatasetTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
# First run without caching to collect the "ground truth".
|
# First run without caching to collect the "ground truth".
|
||||||
sess.run(init_fifo_op)
|
self.evaluate(init_fifo_op)
|
||||||
elements = []
|
elements = []
|
||||||
for _ in range(20):
|
for _ in range(20):
|
||||||
elements.append(sess.run(get_next))
|
elements.append(sess.run(get_next))
|
||||||
@ -220,14 +220,14 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
|
|
||||||
sess.run(repeat_count.initializer)
|
self.evaluate(repeat_count.initializer)
|
||||||
sess.run(cached_iterator.initializer)
|
self.evaluate(cached_iterator.initializer)
|
||||||
sess.run(uncached_iterator.initializer)
|
self.evaluate(uncached_iterator.initializer)
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
self.assertEqual(sess.run(cached_next), i)
|
self.assertEqual(self.evaluate(cached_next), i)
|
||||||
self.assertEqual(sess.run(uncached_next), i)
|
self.assertEqual(self.evaluate(uncached_next), i)
|
||||||
|
|
||||||
sess.run(repeat_count.assign(0))
|
sess.run(repeat_count.assign(0))
|
||||||
|
|
||||||
@ -238,7 +238,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
|||||||
# The cached iterator replays from cache.
|
# The cached iterator replays from cache.
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
self.assertEqual(sess.run(cached_next), i)
|
self.assertEqual(self.evaluate(cached_next), i)
|
||||||
|
|
||||||
# The cached iterator should now be empty.
|
# The cached iterator should now be empty.
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -280,7 +280,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
|||||||
i2 = d2.make_initializable_iterator()
|
i2 = d2.make_initializable_iterator()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(i1.initializer)
|
self.evaluate(i1.initializer)
|
||||||
|
|
||||||
self.assertEqual(1, sess.run(i1.get_next()))
|
self.assertEqual(1, sess.run(i1.get_next()))
|
||||||
self.assertEqual(2, sess.run(i1.get_next()))
|
self.assertEqual(2, sess.run(i1.get_next()))
|
||||||
@ -307,7 +307,7 @@ class MemoryCacheDatasetTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for i, expected in enumerate(expected_values):
|
for i, expected in enumerate(expected_values):
|
||||||
self.assertEqual(expected, sess.run(n),
|
self.assertEqual(expected, self.evaluate(n),
|
||||||
"Unexpected value at index %s" % i)
|
"Unexpected value at index %s" % i)
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
@ -51,9 +51,9 @@ class ConcatenateDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(9):
|
for i in range(9):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
if i < 4:
|
if i < 4:
|
||||||
for component, result_component in zip(input_components, result):
|
for component, result_component in zip(input_components, result):
|
||||||
self.assertAllEqual(component[i], result_component)
|
self.assertAllEqual(component[i], result_component)
|
||||||
@ -85,9 +85,9 @@ class ConcatenateDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(9):
|
for i in range(9):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
if i < 4:
|
if i < 4:
|
||||||
for component, result_component in zip(input_components, result):
|
for component, result_component in zip(input_components, result):
|
||||||
self.assertAllEqual(component[i], result_component)
|
self.assertAllEqual(component[i], result_component)
|
||||||
|
@ -52,8 +52,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
[t.shape for t in get_next])
|
[t.shape for t in get_next])
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertAllEqual(component, result_component)
|
self.assertAllEqual(component, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -81,8 +81,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
[shape for shape in iterator.output_shapes])
|
[shape for shape in iterator.output_shapes])
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertSparseValuesEqual(component, result_component)
|
self.assertSparseValuesEqual(component, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -112,8 +112,8 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
], [shape for shape in iterator.output_shapes])
|
], [shape for shape in iterator.output_shapes])
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
if sparse_tensor.is_sparse(component):
|
if sparse_tensor.is_sparse(component):
|
||||||
self.assertSparseValuesEqual(component, result_component)
|
self.assertSparseValuesEqual(component, result_component)
|
||||||
@ -139,9 +139,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
[t.shape for t in get_next])
|
[t.shape for t in get_next])
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertAllEqual(component[i], result_component)
|
self.assertAllEqual(component[i], result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -169,7 +169,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
[shape for shape in iterator.output_shapes])
|
[shape for shape in iterator.output_shapes])
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
expected = [
|
expected = [
|
||||||
(sparse_tensor.SparseTensorValue(
|
(sparse_tensor.SparseTensorValue(
|
||||||
indices=np.array([[0]]),
|
indices=np.array([[0]]),
|
||||||
@ -197,7 +197,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
dense_shape=np.array([3]))),
|
dense_shape=np.array([3]))),
|
||||||
]
|
]
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(expected[i], results):
|
for component, result_component in zip(expected[i], results):
|
||||||
self.assertSparseValuesEqual(component, result_component)
|
self.assertSparseValuesEqual(component, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -229,7 +229,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
], [shape for shape in iterator.output_shapes])
|
], [shape for shape in iterator.output_shapes])
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
expected = [
|
expected = [
|
||||||
(sparse_tensor.SparseTensorValue(
|
(sparse_tensor.SparseTensorValue(
|
||||||
indices=np.array([[0]]),
|
indices=np.array([[0]]),
|
||||||
@ -257,7 +257,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
dense_shape=np.array([3]))),
|
dense_shape=np.array([3]))),
|
||||||
]
|
]
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(
|
for component, result_component in zip(
|
||||||
(list(zip(*components[:3]))[i] + expected[i]), results):
|
(list(zip(*components[:3]))[i] + expected[i]), results):
|
||||||
if sparse_tensor.is_sparse(component):
|
if sparse_tensor.is_sparse(component):
|
||||||
@ -280,9 +280,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
self.assertEqual((1,), iterator.output_shapes["bar"])
|
self.assertEqual((1,), iterator.output_shapes["bar"])
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertEqual(components["foo"][i], results["foo"])
|
self.assertEqual(components["foo"][i], results["foo"])
|
||||||
self.assertEqual(components["bar"][i], results["bar"])
|
self.assertEqual(components["bar"][i], results["bar"])
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -308,7 +308,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
dense_shape)
|
dense_shape)
|
||||||
sess.run(init_op, feed_dict={st: sparse_feed})
|
sess.run(init_op, feed_dict={st: sparse_feed})
|
||||||
for i, s in enumerate(slices):
|
for i, s in enumerate(slices):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual(s, results.values)
|
self.assertAllEqual(s, results.values)
|
||||||
expected_indices = np.array(
|
expected_indices = np.array(
|
||||||
[[j] for j in range(len(slices[i]))]).reshape([-1, 1])
|
[[j] for j in range(len(slices[i]))]).reshape([-1, 1])
|
||||||
@ -474,15 +474,15 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
|
var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
|
||||||
dataset = dataset.map(lambda x: x + var_0.read_value())
|
dataset = dataset.map(lambda x: x + var_0.read_value())
|
||||||
sess.run(var_0.initializer)
|
self.evaluate(var_0.initializer)
|
||||||
|
|
||||||
with ops.device("/cpu:1"):
|
with ops.device("/cpu:1"):
|
||||||
var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
|
var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
|
||||||
dataset = dataset.map(lambda x: x + var_1.read_value())
|
dataset = dataset.map(lambda x: x + var_1.read_value())
|
||||||
sess.run(var_1.initializer)
|
self.evaluate(var_1.initializer)
|
||||||
|
|
||||||
iterator = dataset.make_initializable_iterator()
|
iterator = dataset.make_initializable_iterator()
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
errors.FailedPreconditionError,
|
errors.FailedPreconditionError,
|
||||||
@ -506,7 +506,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
# Run one whole epoch to burn in the computation.
|
# Run one whole epoch to burn in the computation.
|
||||||
for _ in range(input_size // batch_size):
|
for _ in range(input_size // batch_size):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
@ -543,7 +543,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
get_next_element = sess.make_callable(next_element)
|
get_next_element = sess.make_callable(next_element)
|
||||||
# Run one whole epoch to burn in the computation.
|
# Run one whole epoch to burn in the computation.
|
||||||
for _ in range(input_size // batch_size):
|
for _ in range(input_size // batch_size):
|
||||||
@ -582,7 +582,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
get_next_element = sess.make_callable(next_element)
|
get_next_element = sess.make_callable(next_element)
|
||||||
# Run one whole epoch to burn in the computation.
|
# Run one whole epoch to burn in the computation.
|
||||||
for _ in range(input_size // batch_size):
|
for _ in range(input_size // batch_size):
|
||||||
@ -620,7 +620,7 @@ class DatasetConstructorBenchmark(test.Benchmark):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
get_next_element = sess.make_callable(next_element)
|
get_next_element = sess.make_callable(next_element)
|
||||||
# Run one whole epoch to burn in the computation.
|
# Run one whole epoch to burn in the computation.
|
||||||
for _ in range(input_size // batch_size):
|
for _ in range(input_size // batch_size):
|
||||||
|
@ -47,10 +47,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for _ in range(2): # Run twice to test reinitialization.
|
for _ in range(2): # Run twice to test reinitialization.
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
for elem in elem_sequence:
|
for elem in elem_sequence:
|
||||||
self.assertAllEqual(elem, sess.run(get_next))
|
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
for elem in elem_sequence:
|
for elem in elem_sequence:
|
||||||
self.assertAllEqual(elem, sess.run(get_next))
|
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -133,10 +133,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for _ in range(num_inner_repeats * num_outer_repeats):
|
for _ in range(num_inner_repeats * num_outer_repeats):
|
||||||
for elem in input_list:
|
for elem in input_list:
|
||||||
val0, val1 = sess.run(get_next)
|
val0, val1 = self.evaluate(get_next)
|
||||||
self.assertAllEqual(elem[0], val0)
|
self.assertAllEqual(elem[0], val0)
|
||||||
self.assertAllEqual(elem[1], val1)
|
self.assertAllEqual(elem[1], val1)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -192,10 +192,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for elem in [0, 1]:
|
for elem in [0, 1]:
|
||||||
for _ in range(num_parallel_iterators):
|
for _ in range(num_parallel_iterators):
|
||||||
self.assertAllEqual(elem, sess.run(get_next))
|
self.assertAllEqual(elem, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -215,9 +215,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
self.assertEqual(dtype, get_next.dtype)
|
self.assertEqual(dtype, get_next.dtype)
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for expected in [[1], [2], [3]]:
|
for expected in [[1], [2], [3]]:
|
||||||
next_val = sess.run(get_next)
|
next_val = self.evaluate(get_next)
|
||||||
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
|
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
|
||||||
self.assertAllEqual(expected, next_val)
|
self.assertAllEqual(expected, next_val)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -236,9 +236,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for expected in [b"foo", b"bar", b"baz"]:
|
for expected in [b"foo", b"bar", b"baz"]:
|
||||||
next_val = sess.run(get_next)
|
next_val = self.evaluate(get_next)
|
||||||
self.assertAllEqual(expected, next_val)
|
self.assertAllEqual(expected, next_val)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -257,12 +257,12 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
|
||||||
with self.assertRaisesOpError("The expected type was int64"):
|
with self.assertRaisesOpError("The expected type was int64"):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
self.assertAllEqual([7, 8, 9], sess.run(get_next))
|
self.assertAllEqual([7, 8, 9], self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -280,12 +280,12 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
|
||||||
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
|
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
self.assertAllEqual([11, 12, 13], sess.run(get_next))
|
self.assertAllEqual([11, 12, 13], self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -304,16 +304,16 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertEqual((1, 2), sess.run(get_next))
|
self.assertEqual((1, 2), self.evaluate(get_next))
|
||||||
self.assertEqual((3, 4), sess.run(get_next))
|
self.assertEqual((3, 4), self.evaluate(get_next))
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
self.assertEqual((9, 10), sess.run(get_next))
|
self.assertEqual((9, 10), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -329,9 +329,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual(1, sess.run(get_next))
|
self.assertAllEqual(1, self.evaluate(get_next))
|
||||||
self.assertAllEqual([2, 3], sess.run(get_next))
|
self.assertAllEqual([2, 3], self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -349,9 +349,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual(0, sess.run(get_next))
|
self.assertAllEqual(0, self.evaluate(get_next))
|
||||||
self.assertAllEqual(1, sess.run(get_next))
|
self.assertAllEqual(1, self.evaluate(get_next))
|
||||||
|
|
||||||
def testFromGeneratorDestructorCalled(self):
|
def testFromGeneratorDestructorCalled(self):
|
||||||
# Use an `Event` to signal that the generator has been deleted.
|
# Use an `Event` to signal that the generator has been deleted.
|
||||||
@ -378,9 +378,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual(42, sess.run(get_next))
|
self.assertAllEqual(42, self.evaluate(get_next))
|
||||||
self.assertAllEqual(42, sess.run(get_next))
|
self.assertAllEqual(42, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
# Test that `GeneratorWrapper` object is destroyed when the
|
# Test that `GeneratorWrapper` object is destroyed when the
|
||||||
@ -407,10 +407,10 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
|
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
|
||||||
for x in expected:
|
for x in expected:
|
||||||
self.assertEqual(x, sess.run(get_next))
|
self.assertEqual(x, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -436,13 +436,13 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
expected = [(0, b"Hi!"),
|
expected = [(0, b"Hi!"),
|
||||||
(0, b"Hi!"), (1, b"Hi!"),
|
(0, b"Hi!"), (1, b"Hi!"),
|
||||||
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
|
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
|
||||||
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
|
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
|
||||||
for x in expected:
|
for x in expected:
|
||||||
self.assertEqual(x, sess.run(get_next))
|
self.assertEqual(x, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -470,9 +470,9 @@ class DatasetConstructorTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual(37, sess.run(get_next))
|
self.assertAllEqual(37, self.evaluate(get_next))
|
||||||
self.assertAllEqual(37, sess.run(get_next))
|
self.assertAllEqual(37, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
self.assertTrue(event.is_set())
|
self.assertTrue(event.is_set())
|
||||||
|
@ -67,7 +67,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
|||||||
sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
|
sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
|
||||||
for _ in range(count_val):
|
for _ in range(count_val):
|
||||||
for i in [x for x in range(7) if x**2 % modulus_val == 0]:
|
for i in [x for x in range(7) if x**2 % modulus_val == 0]:
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
self.assertAllEqual(component[i]**2, result_component)
|
self.assertAllEqual(component[i]**2, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -86,9 +86,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(0, sess.run(get_next))
|
self.assertEqual(0, self.evaluate(get_next))
|
||||||
self.assertEqual(1, sess.run(get_next))
|
self.assertEqual(1, self.evaluate(get_next))
|
||||||
self.assertEqual(3, sess.run(get_next))
|
self.assertEqual(3, self.evaluate(get_next))
|
||||||
|
|
||||||
def testFilterDict(self):
|
def testFilterDict(self):
|
||||||
iterator = (dataset_ops.Dataset.range(10)
|
iterator = (dataset_ops.Dataset.range(10)
|
||||||
@ -100,10 +100,10 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
if (i ** 2) % 2 == 0:
|
if (i ** 2) % 2 == 0:
|
||||||
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
|
self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -125,8 +125,8 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual(input_data[0], sess.run(get_next))
|
self.assertAllEqual(input_data[0], self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -148,9 +148,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
|
self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
|
||||||
self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0])
|
self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0])
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -166,9 +166,9 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual((i, True), sess.run(get_next))
|
self.assertEqual((i, True), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
|
|||||||
iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
|
iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
|
||||||
next_elements = [iterator.get_next() for iterator in iterators]
|
next_elements = [iterator.get_next() for iterator in iterators]
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
|
self.assertEqual([0 for _ in range(10)], self.evaluate(next_elements))
|
||||||
|
|
||||||
|
|
||||||
class FilterDatasetBenchmark(test.Benchmark):
|
class FilterDatasetBenchmark(test.Benchmark):
|
||||||
|
@ -45,10 +45,10 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in repeats:
|
for i in repeats:
|
||||||
for _ in range(i):
|
for _ in range(i):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -64,11 +64,11 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for row in repeats:
|
for row in repeats:
|
||||||
for i in row:
|
for i in row:
|
||||||
for _ in range(i):
|
for _ in range(i):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -94,12 +94,12 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
|||||||
with session.Session(server.target) as sess2:
|
with session.Session(server.target) as sess2:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
sess = random.choice([sess1, sess2])
|
sess = random.choice([sess1, sess2])
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for row in repeats:
|
for row in repeats:
|
||||||
for i in row:
|
for i in row:
|
||||||
for _ in range(i):
|
for _ in range(i):
|
||||||
sess = random.choice([sess1, sess2])
|
sess = random.choice([sess1, sess2])
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess = random.choice([sess1, sess2])
|
sess = random.choice([sess1, sess2])
|
||||||
@ -115,10 +115,10 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
for _ in range(i ** 2):
|
for _ in range(i ** 2):
|
||||||
self.assertEqual(i * 2, sess.run(get_next))
|
self.assertEqual(i * 2, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
# pylint: enable=g-long-lambda
|
# pylint: enable=g-long-lambda
|
||||||
@ -139,11 +139,11 @@ class FlatMapDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||||
self.assertAllEqual(expected, sess.run(get_next))
|
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -196,7 +196,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for expected_element in _interleave(
|
for expected_element in _interleave(
|
||||||
_repeat(input_values, count), cycle_length, block_length):
|
_repeat(input_values, count), cycle_length, block_length):
|
||||||
self.assertEqual(expected_element, sess.run(get_next))
|
self.assertEqual(expected_element, self.evaluate(get_next))
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -231,7 +231,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(value, sess.run(get_next))
|
self.assertEqual(value, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -254,7 +254,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
for i in range(10):
|
for i in range(10):
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
expected = [i, 0] if j % 2 == 0 else [0, -i]
|
||||||
self.assertAllEqual(expected, sess.run(get_next))
|
self.assertAllEqual(expected, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -308,7 +308,7 @@ class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
for element in elements:
|
for element in elements:
|
||||||
coordination_events[element].set()
|
coordination_events[element].set()
|
||||||
self.assertEqual(element * element, sess.run(get_next))
|
self.assertEqual(element * element, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ class IteratorClusterTest(test.TestCase):
|
|||||||
|
|
||||||
with session.Session(worker[0].target) as sess:
|
with session.Session(worker[0].target) as sess:
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
def _testRemoteIteratorHelper(self, device0, device1, target):
|
def _testRemoteIteratorHelper(self, device0, device1, target):
|
||||||
with ops.device(device1):
|
with ops.device(device1):
|
||||||
@ -134,12 +134,12 @@ class IteratorClusterTest(test.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with session.Session(worker[0].target) as sess:
|
with session.Session(worker[0].target) as sess:
|
||||||
sess.run(table.initializer)
|
self.evaluate(table.initializer)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))
|
self.assertAllEqual([0, 0, -1, 1, 2], self.evaluate(get_next))
|
||||||
|
|
||||||
with session.Session(worker[0].target) as sess:
|
with session.Session(worker[0].target) as sess:
|
||||||
self.assertAllEqual([2, 0], sess.run(get_next))
|
self.assertAllEqual([2, 0], self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -166,7 +166,7 @@ class IteratorClusterTest(test.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with session.Session(worker[0].target) as sess:
|
with session.Session(worker[0].target) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -97,7 +97,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for _ in range(14):
|
for _ in range(14):
|
||||||
for i in range(7):
|
for i in range(7):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
self.assertAllEqual(component[i]**2, result_component)
|
self.assertAllEqual(component[i]**2, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -123,7 +123,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for _ in range(14):
|
for _ in range(14):
|
||||||
for i in range(7):
|
for i in range(7):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
self.assertAllEqual(component[i]**2, result_component)
|
self.assertAllEqual(component[i]**2, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -159,7 +159,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
for _ in range(14):
|
for _ in range(14):
|
||||||
for i in range(7):
|
for i in range(7):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
self.assertAllEqual(component[i]**2, result_component)
|
self.assertAllEqual(component[i]**2, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -175,7 +175,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
config = config_pb2.ConfigProto(
|
config = config_pb2.ConfigProto(
|
||||||
inter_op_parallelism_threads=1, use_per_session_threads=True)
|
inter_op_parallelism_threads=1, use_per_session_threads=True)
|
||||||
with session.Session(config=config) as sess:
|
with session.Session(config=config) as sess:
|
||||||
self.assertAllEqual([1, 4, 9], sess.run(next_element))
|
self.assertAllEqual([1, 4, 9], self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -254,15 +254,15 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with session.Session(server.target) as sess:
|
with session.Session(server.target) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertAllEqual(component, result_component)
|
self.assertAllEqual(component, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
# Re-initialize the iterator in the first session.
|
# Re-initialize the iterator in the first session.
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
|
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
# Re-define the iterator manually, without defining any of the
|
# Re-define the iterator manually, without defining any of the
|
||||||
@ -277,7 +277,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
with session.Session(server.target) as sess:
|
with session.Session(server.target) as sess:
|
||||||
# Use the iterator without re-initializing in the second session.
|
# Use the iterator without re-initializing in the second session.
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertAllEqual(component, result_component)
|
self.assertAllEqual(component, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -317,20 +317,20 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
# Initialize with one dataset.
|
# Initialize with one dataset.
|
||||||
sess.run(dataset_3_init_op)
|
self.evaluate(dataset_3_init_op)
|
||||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
# Initialize with a different dataset.
|
# Initialize with a different dataset.
|
||||||
sess.run(dataset_4_init_op)
|
self.evaluate(dataset_4_init_op)
|
||||||
self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
|
self.assertAllEqual([4, 5, 6, 7], self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
# Reinitialize with the first dataset.
|
# Reinitialize with the first dataset.
|
||||||
sess.run(dataset_3_init_op)
|
self.evaluate(dataset_3_init_op)
|
||||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -348,7 +348,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
g, output_types=dtypes.int64)
|
g, output_types=dtypes.int64)
|
||||||
sess.run(iterator.make_initializer(dataset_1))
|
sess.run(iterator.make_initializer(dataset_1))
|
||||||
for expected in range(10):
|
for expected in range(10):
|
||||||
self.assertEqual(expected, sess.run(next_element))
|
self.assertEqual(expected, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -356,7 +356,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
g, output_types=dtypes.int64)
|
g, output_types=dtypes.int64)
|
||||||
sess.run(iterator.make_initializer(dataset_2))
|
sess.run(iterator.make_initializer(dataset_2))
|
||||||
for expected in range(10):
|
for expected in range(10):
|
||||||
self.assertEqual(expected, sess.run(next_element))
|
self.assertEqual(expected, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -679,10 +679,10 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
n = itr.get_next()
|
n = itr.get_next()
|
||||||
|
|
||||||
with session.Session(s3.target, config=config) as sess:
|
with session.Session(s3.target, config=config) as sess:
|
||||||
sess.run(itr.initializer)
|
self.evaluate(itr.initializer)
|
||||||
expected_values = worker_devices
|
expected_values = worker_devices
|
||||||
for expected in expected_values:
|
for expected in expected_values:
|
||||||
self.assertEqual((compat.as_bytes(expected),), sess.run(n))
|
self.assertEqual((compat.as_bytes(expected),), self.evaluate(n))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(n)
|
sess.run(n)
|
||||||
@ -786,8 +786,8 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, _, save_op, _ = _build_range_dataset_graph()
|
init_op, _, save_op, _ = _build_range_dataset_graph()
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
# Attempt to restore the saved iterator into an IteratorResource of
|
# Attempt to restore the saved iterator into an IteratorResource of
|
||||||
# incompatible type. An iterator of RangeDataset has output type int64,
|
# incompatible type. An iterator of RangeDataset has output type int64,
|
||||||
@ -798,7 +798,7 @@ class IteratorTest(test.TestCase, parameterized.TestCase):
|
|||||||
_, _, _, restore_op = _build_reader_dataset_graph()
|
_, _, _, restore_op = _build_reader_dataset_graph()
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
|
|
||||||
def testRepeatedGetNextWarning(self):
|
def testRepeatedGetNextWarning(self):
|
||||||
iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
|
iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
|
||||||
@ -949,7 +949,7 @@ class IteratorCheckpointingTest(test.TestCase):
|
|||||||
checkpoint.restore(checkpoint_management.latest_checkpoint(
|
checkpoint.restore(checkpoint_management.latest_checkpoint(
|
||||||
checkpoint_directory)).initialize_or_restore(sess)
|
checkpoint_directory)).initialize_or_restore(sess)
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
self.assertEqual(i * 2 + j, sess.run(get_next))
|
self.assertEqual(i * 2 + j, self.evaluate(get_next))
|
||||||
checkpoint.save(file_prefix=checkpoint_prefix)
|
checkpoint.save(file_prefix=checkpoint_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ class ListFilesDatasetOpTest(test_base.DatasetTestBase):
|
|||||||
all_produced_filenames = []
|
all_produced_filenames = []
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
produced_filenames = []
|
produced_filenames = []
|
||||||
sess.run(itr.initializer)
|
self.evaluate(itr.initializer)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
produced_filenames.append(sess.run(next_element))
|
produced_filenames.append(sess.run(next_element))
|
||||||
|
@ -114,7 +114,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
sess.run(init_op, feed_dict={count: 14})
|
sess.run(init_op, feed_dict={count: 14})
|
||||||
for _ in range(14):
|
for _ in range(14):
|
||||||
for i in range(7):
|
for i in range(7):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
self.assertAllEqual(component[i]**2, result_component)
|
self.assertAllEqual(component[i]**2, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -185,7 +185,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
output_buffer_size: output_buffer_size_val})
|
output_buffer_size: output_buffer_size_val})
|
||||||
for _ in range(14):
|
for _ in range(14):
|
||||||
for i in range(7):
|
for i in range(7):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
self.assertAllEqual(component[i]**2, result_component)
|
self.assertAllEqual(component[i]**2, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -242,7 +242,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -257,7 +257,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -272,7 +272,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
||||||
@ -293,7 +293,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
||||||
@ -325,10 +325,10 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
captured_init_op, init_op, get_next = _build_graph()
|
captured_init_op, init_op, get_next = _build_graph()
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(captured_init_op)
|
self.evaluate(captured_init_op)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i * i, sess.run(get_next))
|
self.assertEqual(i * i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -353,8 +353,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(table.initializer)
|
self.evaluate(table.initializer)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -371,11 +371,11 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(enqueue_op)
|
self.evaluate(enqueue_op)
|
||||||
sess.run(close_op)
|
self.evaluate(close_op)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for element in elements:
|
for element in elements:
|
||||||
self.assertEqual(element, sess.run(get_next))
|
self.assertEqual(element, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -396,9 +396,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(enqueue_op)
|
self.evaluate(enqueue_op)
|
||||||
sess.run(close_op)
|
self.evaluate(close_op)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]),
|
self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]),
|
||||||
sorted(sess.run(get_next)))
|
sorted(sess.run(get_next)))
|
||||||
@ -415,15 +415,15 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(counter_var.initializer)
|
self.evaluate(counter_var.initializer)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(counter_var))
|
self.assertEqual(i, self.evaluate(counter_var))
|
||||||
self.assertEqual(i + 1, sess.run(get_next))
|
self.assertEqual(i + 1, self.evaluate(get_next))
|
||||||
self.assertEqual(10, sess.run(counter_var))
|
self.assertEqual(10, self.evaluate(counter_var))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
self.assertEqual(10, sess.run(counter_var))
|
self.assertEqual(10, self.evaluate(counter_var))
|
||||||
|
|
||||||
def testCaptureUninitializedVariableError(self):
|
def testCaptureUninitializedVariableError(self):
|
||||||
counter_var = variable_scope.get_variable(
|
counter_var = variable_scope.get_variable(
|
||||||
@ -435,7 +435,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -447,14 +447,14 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
random_values = []
|
random_values = []
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
while True:
|
while True:
|
||||||
random_values.extend(sess.run(get_next))
|
random_values.extend(sess.run(get_next))
|
||||||
self.assertEqual(10, len(random_values))
|
self.assertEqual(10, len(random_values))
|
||||||
self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
|
self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
random_values_2 = []
|
random_values_2 = []
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
while True:
|
while True:
|
||||||
@ -473,8 +473,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
random_values = sess.run(get_next)
|
random_values = self.evaluate(get_next)
|
||||||
|
|
||||||
# Assert that one of the next 99 batches yielded by the iterator is
|
# Assert that one of the next 99 batches yielded by the iterator is
|
||||||
# different from the first.
|
# different from the first.
|
||||||
@ -500,15 +500,15 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(counter_var.initializer)
|
self.evaluate(counter_var.initializer)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i, sess.run(counter_var))
|
self.assertEqual(i, self.evaluate(counter_var))
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
self.assertEqual(10, sess.run(counter_var))
|
self.assertEqual(10, self.evaluate(counter_var))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
self.assertEqual(10, sess.run(counter_var))
|
self.assertEqual(10, self.evaluate(counter_var))
|
||||||
|
|
||||||
def testMapDict(self):
|
def testMapDict(self):
|
||||||
iterator = (dataset_ops.Dataset.range(10)
|
iterator = (dataset_ops.Dataset.range(10)
|
||||||
@ -519,9 +519,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
|
self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -569,8 +569,8 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual(row ** 2, sess.run(get_next))
|
self.assertAllEqual(row**2, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -611,7 +611,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
row = np.arange(6)
|
row = np.arange(6)
|
||||||
for num in [2, 3, 4]:
|
for num in [2, 3, 4]:
|
||||||
init_op, get_next = build_dataset(row, num)
|
init_op, get_next = build_dataset(row, num)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(6):
|
for i in range(6):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
(i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2,
|
(i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2,
|
||||||
@ -652,7 +652,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
row = np.arange(6)
|
row = np.arange(6)
|
||||||
for num in [2, 3, 4]:
|
for num in [2, 3, 4]:
|
||||||
init_op, get_next = build_dataset(row, num)
|
init_op, get_next = build_dataset(row, num)
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
[x // 2 if (num == 2 or num == 3) else x * 2 for x in row],
|
[x // 2 if (num == 2 or num == 3) else x * 2 for x in row],
|
||||||
sess.run(get_next))
|
sess.run(get_next))
|
||||||
@ -697,7 +697,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
self.assertAllEqual([(x // 2 if x % 2 else x * 2) if
|
self.assertAllEqual([(x // 2 if x % 2 else x * 2) if
|
||||||
(num == 2 or num == 3) else x * 2 for x in row],
|
(num == 2 or num == 3) else x * 2 for x in row],
|
||||||
sess.run(get_next))
|
sess.run(get_next))
|
||||||
@ -735,7 +735,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
for buffer_size in [1, 10, 100, 1000]:
|
for buffer_size in [1, 10, 100, 1000]:
|
||||||
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
|
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i * i, sess.run(get_next))
|
self.assertEqual(i * i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -753,10 +753,10 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
|
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
|
||||||
for i in range(event_will_be_set_after_consuming):
|
for i in range(event_will_be_set_after_consuming):
|
||||||
self.assertFalse(ev.is_set())
|
self.assertFalse(ev.is_set())
|
||||||
self.assertEqual(i * i, sess.run(get_next))
|
self.assertEqual(i * i, self.evaluate(get_next))
|
||||||
ev.wait()
|
ev.wait()
|
||||||
for i in range(event_will_be_set_after_consuming, 100):
|
for i in range(event_will_be_set_after_consuming, 100):
|
||||||
self.assertEqual(i * i, sess.run(get_next))
|
self.assertEqual(i * i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -768,9 +768,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual((i, 37.0), sess.run(get_next))
|
self.assertEqual((i, 37.0), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -789,9 +789,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual((i, 37.0), sess.run(get_next))
|
self.assertEqual((i, 37.0), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -810,9 +810,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
|
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
|
||||||
self.assertSparseValuesEqual(actual, _sparse(i))
|
self.assertSparseValuesEqual(actual, _sparse(i))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -837,9 +837,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
|
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
|
||||||
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
|
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -861,9 +861,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -875,9 +875,9 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertEqual((i, b"hello", 10), sess.run(get_next))
|
self.assertEqual((i, b"hello", 10), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -945,7 +945,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
|
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
|
|
||||||
# pylint: disable=g-long-lambda
|
# pylint: disable=g-long-lambda
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
@ -972,7 +972,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
tids = sess.run(get_next)
|
tids = self.evaluate(get_next)
|
||||||
self.assertTrue(all(tids[0] == tid for tid in tids))
|
self.assertTrue(all(tids[0] == tid for tid in tids))
|
||||||
# pylint: enable=g-long-lambda
|
# pylint: enable=g-long-lambda
|
||||||
|
|
||||||
@ -996,7 +996,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
expected = map_fn(*sess.run(self.structuredElement(structure)))
|
expected = map_fn(*sess.run(self.structuredElement(structure)))
|
||||||
else:
|
else:
|
||||||
expected = map_fn(sess.run(self.structuredElement(structure)))
|
expected = map_fn(sess.run(self.structuredElement(structure)))
|
||||||
self.assertEqual(expected, sess.run(get_next))
|
self.assertEqual(expected, self.evaluate(get_next))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
("Sequential", None),
|
("Sequential", None),
|
||||||
@ -1011,7 +1011,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
||||||
self.assertEqual(42, sess.run(get_next))
|
self.assertEqual(42, self.evaluate(get_next))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
("1", 1, 1),
|
("1", 1, 1),
|
||||||
@ -1030,7 +1030,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.cached_session(config=config) as sess:
|
with self.cached_session(config=config) as sess:
|
||||||
for i in range(num_elements):
|
for i in range(num_elements):
|
||||||
coordination_events[i].set()
|
coordination_events[i].set()
|
||||||
self.assertEqual(i * i, sess.run(get_next))
|
self.assertEqual(i * i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -1052,7 +1052,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
for element in elements:
|
for element in elements:
|
||||||
coordination_events[element].set()
|
coordination_events[element].set()
|
||||||
self.assertEqual(element * element, sess.run(get_next))
|
self.assertEqual(element * element, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
|
|
||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
dataset = dataset_ops.Dataset.range(10)
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
@ -50,10 +50,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 10, 2):
|
for i in range(0, 10, 2):
|
||||||
self.assertEqual(i, sess.run(elem_on_1))
|
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(elem_on_1)
|
sess.run(elem_on_1)
|
||||||
sess.run(elem_on_2)
|
sess.run(elem_on_2)
|
||||||
@ -67,10 +67,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 10, 2):
|
for i in range(0, 10, 2):
|
||||||
self.assertEqual(i, sess.run(elem_on_1))
|
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(elem_on_1)
|
sess.run(elem_on_1)
|
||||||
sess.run(elem_on_2)
|
sess.run(elem_on_2)
|
||||||
@ -85,12 +85,12 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 20, 4):
|
for i in range(0, 20, 4):
|
||||||
self.assertEqual(i, sess.run(elem_on_1))
|
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||||
self.assertEqual(i + 2, sess.run(elem_on_3))
|
self.assertEqual(i + 2, self.evaluate(elem_on_3))
|
||||||
self.assertEqual(i + 3, sess.run(elem_on_4))
|
self.assertEqual(i + 3, self.evaluate(elem_on_4))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(elem_on_1)
|
sess.run(elem_on_1)
|
||||||
sess.run(elem_on_2)
|
sess.run(elem_on_2)
|
||||||
@ -105,11 +105,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 8, 2):
|
for i in range(0, 8, 2):
|
||||||
self.assertEqual(i, sess.run(elem_on_1))
|
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||||
self.assertEqual(8, sess.run(elem_on_1))
|
self.assertEqual(8, self.evaluate(elem_on_1))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(elem_on_1)
|
sess.run(elem_on_1)
|
||||||
sess.run(elem_on_2)
|
sess.run(elem_on_2)
|
||||||
@ -126,7 +126,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 8, 2):
|
for i in range(0, 8, 2):
|
||||||
elem_on_1_has_value, elem_on_1_value = sess.run(
|
elem_on_1_has_value, elem_on_1_value = sess.run(
|
||||||
[elem_on_1_has_value_t, elem_on_1_t])
|
[elem_on_1_has_value_t, elem_on_1_t])
|
||||||
@ -140,8 +140,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
[elem_on_1_has_value_t, elem_on_1_t])
|
[elem_on_1_has_value_t, elem_on_1_t])
|
||||||
self.assertTrue(elem_on_1_has_value)
|
self.assertTrue(elem_on_1_has_value)
|
||||||
self.assertEqual(8, elem_on_1_value)
|
self.assertEqual(8, elem_on_1_value)
|
||||||
self.assertFalse(sess.run(elem_on_1_has_value_t))
|
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
|
||||||
self.assertFalse(sess.run(elem_on_2_has_value_t))
|
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(elem_on_1_t)
|
sess.run(elem_on_1_t)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
@ -155,11 +155,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 10, 2):
|
for i in range(0, 10, 2):
|
||||||
self.assertEqual(i, sess.run(elem_on_1))
|
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||||
for i in range(0, 10, 2):
|
for i in range(0, 10, 2):
|
||||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(elem_on_1)
|
sess.run(elem_on_1)
|
||||||
sess.run(elem_on_2)
|
sess.run(elem_on_2)
|
||||||
@ -192,10 +192,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 10, 2):
|
for i in range(0, 10, 2):
|
||||||
self.assertEqual(i, sess.run(elem_on_1))
|
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(elem_on_1)
|
sess.run(elem_on_1)
|
||||||
sess.run(elem_on_2)
|
sess.run(elem_on_2)
|
||||||
@ -211,11 +211,11 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 10, 2):
|
for i in range(0, 10, 2):
|
||||||
self.assertEqual(i, sess.run(elem_on_1))
|
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||||
for i in range(0, 10, 2):
|
for i in range(0, 10, 2):
|
||||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(elem_on_1)
|
sess.run(elem_on_1)
|
||||||
sess.run(elem_on_2)
|
sess.run(elem_on_2)
|
||||||
@ -235,7 +235,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 8, 2):
|
for i in range(0, 8, 2):
|
||||||
elem_on_1_has_value, elem_on_1_value = sess.run(
|
elem_on_1_has_value, elem_on_1_value = sess.run(
|
||||||
[elem_on_1_has_value_t, elem_on_1_t])
|
[elem_on_1_has_value_t, elem_on_1_t])
|
||||||
@ -249,8 +249,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
[elem_on_1_has_value_t, elem_on_1_t])
|
[elem_on_1_has_value_t, elem_on_1_t])
|
||||||
self.assertTrue(elem_on_1_has_value)
|
self.assertTrue(elem_on_1_has_value)
|
||||||
self.assertEqual(8, elem_on_1_value)
|
self.assertEqual(8, elem_on_1_value)
|
||||||
self.assertFalse(sess.run(elem_on_1_has_value_t))
|
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
|
||||||
self.assertFalse(sess.run(elem_on_2_has_value_t))
|
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(elem_on_1_t)
|
sess.run(elem_on_1_t)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
@ -272,10 +272,10 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
|||||||
|
|
||||||
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
config = config_pb2.ConfigProto(device_count={"CPU": 3})
|
||||||
with self.test_session(config=config) as sess:
|
with self.test_session(config=config) as sess:
|
||||||
sess.run(multi_device_iterator.initializer)
|
self.evaluate(multi_device_iterator.initializer)
|
||||||
for i in range(0, 10, 2):
|
for i in range(0, 10, 2):
|
||||||
self.assertEqual(i, sess.run(elem_on_1))
|
self.assertEqual(i, self.evaluate(elem_on_1))
|
||||||
self.assertEqual(i + 1, sess.run(elem_on_2))
|
self.assertEqual(i + 1, self.evaluate(elem_on_2))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(elem_on_1)
|
sess.run(elem_on_1)
|
||||||
sess.run(elem_on_2)
|
sess.run(elem_on_2)
|
||||||
|
@ -227,7 +227,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
# For each element of the dataset, assert that the optional evaluates to
|
# For each element of the dataset, assert that the optional evaluates to
|
||||||
# the expected value.
|
# the expected value.
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
|
||||||
self.assertTrue(elem_has_value)
|
self.assertTrue(elem_has_value)
|
||||||
@ -236,7 +236,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
|
||||||
# false, and attempting to get the value will fail.
|
# false, and attempting to get the value will fail.
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
self.assertFalse(sess.run(elem_has_value_t))
|
self.assertFalse(self.evaluate(elem_has_value_t))
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(elem_value_t)
|
sess.run(elem_value_t)
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
|
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
|
||||||
for m in range(10):
|
for m in range(10):
|
||||||
self.assertEqual(m, sess.run(get_next))
|
self.assertEqual(m, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -124,19 +124,19 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(start, break_point):
|
for i in range(start, break_point):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point, stop):
|
for i in range(break_point, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -144,14 +144,14 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(start, break_point):
|
for i in range(start, break_point):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point, stop):
|
for i in range(break_point, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -175,14 +175,14 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
|
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for _ in range(break_epoch):
|
for _ in range(break_epoch):
|
||||||
for i in range(start, stop):
|
for i in range(start, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
for i in range(start, break_point):
|
for i in range(start, break_point):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
# Create an empty IteratorResource and restore the Iterator into it.
|
# Create an empty IteratorResource and restore the Iterator into it.
|
||||||
@ -193,12 +193,12 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
restore_op = self._restore_op(iterator._iterator_resource)
|
restore_op = self._restore_op(iterator._iterator_resource)
|
||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point, stop):
|
for i in range(break_point, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
for _ in range(break_epoch + 1, num_epochs):
|
for _ in range(break_epoch + 1, num_epochs):
|
||||||
for i in range(start, stop):
|
for i in range(start, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -221,20 +221,20 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(start, break_point):
|
for i in range(start, break_point):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
# Intentionally build a graph with a different value for stop to make sure
|
# Intentionally build a graph with a different value for stop to make sure
|
||||||
# the original dataset graph is actually getting loaded.
|
# the original dataset graph is actually getting loaded.
|
||||||
init_op, get_next, _, restore_op = _build_graph(start, stop_1)
|
init_op, get_next, _, restore_op = _build_graph(start, stop_1)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point, stop):
|
for i in range(break_point, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -259,19 +259,19 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(start, break_point):
|
for i in range(start, break_point):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
init_op, get_next, _, restore_op = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point, stop):
|
for i in range(break_point, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -294,27 +294,27 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
init_op, get_next, save_op, _ = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
for i in range(start, break_point1):
|
for i in range(start, break_point1):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point1, break_point2):
|
for i in range(break_point1, break_point2):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
break_point2 = 7
|
break_point2 = 7
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_point2, stop):
|
for i in range(break_point2, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -338,28 +338,28 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
init_op, get_next, save_op, restore_op = _build_graph(
|
init_op, get_next, save_op, restore_op = _build_graph(
|
||||||
start, stop, num_epochs)
|
start, stop, num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||||
# raised.
|
# raised.
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for _ in range(break_epoch - 1):
|
for _ in range(break_epoch - 1):
|
||||||
for i in range(start, stop):
|
for i in range(start, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
for i in range(start, break_range):
|
for i in range(start, break_range):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
|
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for i in range(break_range, stop):
|
for i in range(break_range, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
for _ in range(break_epoch, num_epochs):
|
for _ in range(break_epoch, num_epochs):
|
||||||
for i in range(start, stop):
|
for i in range(start, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -381,23 +381,23 @@ class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
|
|||||||
init_op, get_next, save_op, restore_op = _build_graph(
|
init_op, get_next, save_op, restore_op = _build_graph(
|
||||||
start, stop, num_epochs)
|
start, stop, num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||||
# raised.
|
# raised.
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for _ in range(num_epochs):
|
for _ in range(num_epochs):
|
||||||
for i in range(start, stop):
|
for i in range(start, stop):
|
||||||
self.assertEqual(i, sess.run(get_next))
|
self.assertEqual(i, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
|
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
|||||||
init_op, feed_dict={filenames: [test_filenames[0]],
|
init_op, feed_dict={filenames: [test_filenames[0]],
|
||||||
num_epochs: 1})
|
num_epochs: 1})
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(self._lineText(0, i), sess.run(get_next))
|
self.assertEqual(self._lineText(0, i), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -116,7 +116,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
|||||||
init_op, feed_dict={filenames: [test_filenames[1]],
|
init_op, feed_dict={filenames: [test_filenames[1]],
|
||||||
num_epochs: 1})
|
num_epochs: 1})
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(self._lineText(1, i), sess.run(get_next))
|
self.assertEqual(self._lineText(1, i), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
|||||||
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
|
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(self._lineText(j, i), sess.run(get_next))
|
self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ class TextLineDatasetTest(test_base.DatasetTestBase):
|
|||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
for j in range(2):
|
for j in range(2):
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.assertEqual(self._lineText(j, i), sess.run(get_next))
|
self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -267,7 +267,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
init_op, feed_dict={filenames: [test_filenames[0]],
|
init_op, feed_dict={filenames: [test_filenames[0]],
|
||||||
num_epochs: 1})
|
num_epochs: 1})
|
||||||
for i in range(self._num_records):
|
for i in range(self._num_records):
|
||||||
self.assertEqual(self._record(0, i), sess.run(get_next))
|
self.assertEqual(self._record(0, i), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -276,7 +276,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
init_op, feed_dict={filenames: [test_filenames[1]],
|
init_op, feed_dict={filenames: [test_filenames[1]],
|
||||||
num_epochs: 1})
|
num_epochs: 1})
|
||||||
for i in range(self._num_records):
|
for i in range(self._num_records):
|
||||||
self.assertEqual(self._record(1, i), sess.run(get_next))
|
self.assertEqual(self._record(1, i), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -284,7 +284,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
|
sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
|
||||||
for j in range(self._num_files):
|
for j in range(self._num_files):
|
||||||
for i in range(self._num_records):
|
for i in range(self._num_records):
|
||||||
self.assertEqual(self._record(j, i), sess.run(get_next))
|
self.assertEqual(self._record(j, i), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -293,7 +293,7 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
for j in range(self._num_files):
|
for j in range(self._num_files):
|
||||||
for i in range(self._num_records):
|
for i in range(self._num_records):
|
||||||
self.assertEqual(self._record(j, i), sess.run(get_next))
|
self.assertEqual(self._record(j, i), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -405,19 +405,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||||
# raised.
|
# raised.
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
if (epoch == epoch_break and f == file_break and
|
if (epoch == epoch_break and f == file_break and
|
||||||
r == record_break):
|
r == record_break):
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
break
|
break
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
@ -426,13 +426,13 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
@ -441,9 +441,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
(epoch == epoch_break and f == file_break and
|
(epoch == epoch_break and f == file_break and
|
||||||
r < record_break)):
|
r < record_break)):
|
||||||
continue
|
continue
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
def testInitThenRestore(self):
|
def testInitThenRestore(self):
|
||||||
# Note: Calling init_op before restore_op is redundant. This test just makes
|
# Note: Calling init_op before restore_op is redundant. This test just makes
|
||||||
@ -458,19 +458,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||||
# raised.
|
# raised.
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
if (epoch == epoch_break and f == file_break and
|
if (epoch == epoch_break and f == file_break and
|
||||||
r == record_break):
|
r == record_break):
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
break
|
break
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
@ -479,14 +479,14 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
@ -495,9 +495,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
(epoch == epoch_break and f == file_break and
|
(epoch == epoch_break and f == file_break and
|
||||||
r < record_break)):
|
r < record_break)):
|
||||||
continue
|
continue
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
def testRestoreInModifiedGraph(self):
|
def testRestoreInModifiedGraph(self):
|
||||||
num_epochs = 10
|
num_epochs = 10
|
||||||
@ -510,19 +510,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||||
# raised.
|
# raised.
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
if (epoch == epoch_break and f == file_break and
|
if (epoch == epoch_break and f == file_break and
|
||||||
r == record_break):
|
r == record_break):
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
break
|
break
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
@ -531,13 +531,13 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs_1)
|
num_epochs=num_epochs_1)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
@ -546,9 +546,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
(epoch == epoch_break and f == file_break and
|
(epoch == epoch_break and f == file_break and
|
||||||
r < record_break)):
|
r < record_break)):
|
||||||
continue
|
continue
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
def testRestoreWithoutBuildingDatasetGraph(self):
|
def testRestoreWithoutBuildingDatasetGraph(self):
|
||||||
num_epochs = 10
|
num_epochs = 10
|
||||||
@ -560,19 +560,19 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||||
# raised.
|
# raised.
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
if (epoch == epoch_break and f == file_break and
|
if (epoch == epoch_break and f == file_break and
|
||||||
r == record_break):
|
r == record_break):
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
break
|
break
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
@ -581,12 +581,12 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
restore_op, get_next_op = self._restore_iterator()
|
restore_op, get_next_op = self._restore_iterator()
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
@ -595,9 +595,9 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
(epoch == epoch_break and f == file_break and
|
(epoch == epoch_break and f == file_break and
|
||||||
r < record_break)):
|
r < record_break)):
|
||||||
continue
|
continue
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
def testRestoreUnusedIterator(self):
|
def testRestoreUnusedIterator(self):
|
||||||
num_epochs = 10
|
num_epochs = 10
|
||||||
@ -605,22 +605,22 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||||
# raised.
|
# raised.
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
# Save unused iterator.
|
# Save unused iterator.
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for _ in range(num_epochs * self._num_files * self._num_records):
|
for _ in range(num_epochs * self._num_files * self._num_records):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
def testRestoreExhaustedIterator(self):
|
def testRestoreExhaustedIterator(self):
|
||||||
num_epochs = 10
|
num_epochs = 10
|
||||||
@ -629,26 +629,26 @@ class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
|
|||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Note: There is no checkpoint saved currently so a NotFoundError is
|
# Note: There is no checkpoint saved currently so a NotFoundError is
|
||||||
# raised.
|
# raised.
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(errors.NotFoundError):
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
for _ in range(num_epochs):
|
for _ in range(num_epochs):
|
||||||
for f in range(self._num_files):
|
for f in range(self._num_files):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
self.assertEqual(self._record(f, r), sess.run(get_next_op))
|
self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
sess.run(save_op)
|
self.evaluate(save_op)
|
||||||
|
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
|
||||||
num_epochs=num_epochs)
|
num_epochs=num_epochs)
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(restore_op)
|
self.evaluate(restore_op)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next_op)
|
self.evaluate(get_next_op)
|
||||||
|
|
||||||
|
|
||||||
class TFRecordDatasetTest(test_base.DatasetTestBase):
|
class TFRecordDatasetTest(test_base.DatasetTestBase):
|
||||||
@ -807,7 +807,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for j in range(self._num_files):
|
for j in range(self._num_files):
|
||||||
for i in range(self._num_records):
|
for i in range(self._num_records):
|
||||||
self.assertAllEqual(self._record(j, i), sess.run(next_element))
|
self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -819,7 +819,7 @@ class TFRecordDatasetTest(test_base.DatasetTestBase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for j in range(self._num_files):
|
for j in range(self._num_files):
|
||||||
for i in range(self._num_records):
|
for i in range(self._num_records):
|
||||||
self.assertAllEqual(self._record(j, i), sess.run(next_element))
|
self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
ds = dataset_ops.Dataset.range(1, i + 1)
|
ds = dataset_ops.Dataset.range(1, i + 1)
|
||||||
result = ds.reduce(np.int64(0), lambda x, y: x + y)
|
result = ds.reduce(np.int64(0), lambda x, y: x + y)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(((i + 1) * i) // 2, sess.run(result))
|
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
||||||
|
|
||||||
def testSumTuple(self):
|
def testSumTuple(self):
|
||||||
|
|
||||||
@ -49,7 +49,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
ds = dataset_ops.Dataset.zip((ds, ds))
|
ds = dataset_ops.Dataset.zip((ds, ds))
|
||||||
result = ds.reduce(np.int64(0), reduce_fn)
|
result = ds.reduce(np.int64(0), reduce_fn)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(((i + 1) * i), sess.run(result))
|
self.assertEqual(((i + 1) * i), self.evaluate(result))
|
||||||
|
|
||||||
def testSumAndCount(self):
|
def testSumAndCount(self):
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
ds = dataset_ops.Dataset.range(1, i + 1)
|
ds = dataset_ops.Dataset.range(1, i + 1)
|
||||||
result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
|
result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
s, c = sess.run(result)
|
s, c = self.evaluate(result)
|
||||||
self.assertEqual(((i + 1) * i) // 2, s)
|
self.assertEqual(((i + 1) * i) // 2, s)
|
||||||
self.assertEqual(i, c)
|
self.assertEqual(i, c)
|
||||||
|
|
||||||
@ -93,7 +93,8 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
|
ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
|
||||||
result = ds.reduce(make_sparse_fn(0), reduce_fn)
|
result = ds.reduce(make_sparse_fn(0), reduce_fn)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result))
|
self.assertSparseValuesEqual(
|
||||||
|
make_sparse_fn(i + 1), self.evaluate(result))
|
||||||
|
|
||||||
def testNested(self):
|
def testNested(self):
|
||||||
|
|
||||||
@ -116,7 +117,7 @@ class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
|
ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
|
||||||
result = ds.reduce(map_fn(0), reduce_fn)
|
result = ds.reduce(map_fn(0), reduce_fn)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
result = sess.run(result)
|
result = self.evaluate(result)
|
||||||
self.assertEqual(((i + 1) * i) // 2, result["dense"])
|
self.assertEqual(((i + 1) * i) // 2, result["dense"])
|
||||||
self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
|
self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
# Test a finite repetition.
|
# Test a finite repetition.
|
||||||
sess.run(init_op, feed_dict={count_placeholder: 3})
|
sess.run(init_op, feed_dict={count_placeholder: 3})
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertAllEqual(component, result_component)
|
self.assertAllEqual(component, result_component)
|
||||||
|
|
||||||
@ -59,7 +59,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
# Test a different finite repetition.
|
# Test a different finite repetition.
|
||||||
sess.run(init_op, feed_dict={count_placeholder: 7})
|
sess.run(init_op, feed_dict={count_placeholder: 7})
|
||||||
for _ in range(7):
|
for _ in range(7):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertAllEqual(component, result_component)
|
self.assertAllEqual(component, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -75,7 +75,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
# actually is infinite.
|
# actually is infinite.
|
||||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||||
for _ in range(17):
|
for _ in range(17):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertAllEqual(component, result_component)
|
self.assertAllEqual(component, result_component)
|
||||||
|
|
||||||
@ -95,7 +95,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
# Take fewer than input size
|
# Take fewer than input size
|
||||||
sess.run(init_op, feed_dict={count_placeholder: 4})
|
sess.run(init_op, feed_dict={count_placeholder: 4})
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual(results, components[0][i:i+1])
|
self.assertAllEqual(results, components[0][i:i+1])
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -104,7 +104,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
# Take more than input size
|
# Take more than input size
|
||||||
sess.run(init_op, feed_dict={count_placeholder: 25})
|
sess.run(init_op, feed_dict={count_placeholder: 25})
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual(results, components[0][i:i+1])
|
self.assertAllEqual(results, components[0][i:i+1])
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -113,7 +113,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
# Take all of input
|
# Take all of input
|
||||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual(results, components[0][i:i+1])
|
self.assertAllEqual(results, components[0][i:i+1])
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -142,7 +142,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
# the first 4 elements and then read the rest.
|
# the first 4 elements and then read the rest.
|
||||||
sess.run(init_op, feed_dict={count_placeholder: 4})
|
sess.run(init_op, feed_dict={count_placeholder: 4})
|
||||||
for i in range(4, 10):
|
for i in range(4, 10):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual(results, components[0][i:i+1])
|
self.assertAllEqual(results, components[0][i:i+1])
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -165,7 +165,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
# Skip nothing
|
# Skip nothing
|
||||||
sess.run(init_op, feed_dict={count_placeholder: 0})
|
sess.run(init_op, feed_dict={count_placeholder: 0})
|
||||||
for i in range(0, 10):
|
for i in range(0, 10):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
self.assertAllEqual(results, components[0][i:i+1])
|
self.assertAllEqual(results, components[0][i:i+1])
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -187,7 +187,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
|
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
|
||||||
for _ in range(7 * 14):
|
for _ in range(7 * 14):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, results):
|
for component, result_component in zip(components, results):
|
||||||
self.assertAllEqual(component, result_component)
|
self.assertAllEqual(component, result_component)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -201,7 +201,7 @@ class SequenceDatasetTest(test_base.DatasetTestBase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
# First run without shuffling to collect the "ground truth".
|
# First run without shuffling to collect the "ground truth".
|
||||||
sess.run(init_fifo_op)
|
self.evaluate(init_fifo_op)
|
||||||
unshuffled_elements = []
|
unshuffled_elements = []
|
||||||
for _ in range(20):
|
for _ in range(20):
|
||||||
unshuffled_elements.append(sess.run(get_next))
|
unshuffled_elements.append(sess.run(get_next))
|
||||||
@ -159,7 +159,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
|
sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
|
||||||
for elem in elems:
|
for elem in elems:
|
||||||
self.assertEqual(elem, sess.run(get_next))
|
self.assertEqual(elem, self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
@ -188,9 +188,9 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
initial_permutation = sess.run(next_element)
|
initial_permutation = self.evaluate(next_element)
|
||||||
self.assertAllEqual(initial_permutation, sess.run(next_element))
|
self.assertAllEqual(initial_permutation, self.evaluate(next_element))
|
||||||
self.assertAllEqual(initial_permutation, sess.run(next_element))
|
self.assertAllEqual(initial_permutation, self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -261,7 +261,7 @@ class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
for iterator in iterators:
|
for iterator in iterators:
|
||||||
if initializable:
|
if initializable:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
run_results = []
|
run_results = []
|
||||||
for _ in range(300):
|
for _ in range(300):
|
||||||
|
@ -102,7 +102,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
num_full_batches = max(
|
num_full_batches = max(
|
||||||
0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
|
0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
|
||||||
for i in range(num_full_batches):
|
for i in range(num_full_batches):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
for j in range(size):
|
for j in range(size):
|
||||||
self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
|
self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
|
||||||
@ -111,7 +111,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
num_partial_batches = (count * 7) // shift + (
|
num_partial_batches = (count * 7) // shift + (
|
||||||
(count * 7) % shift > 0) - num_full_batches
|
(count * 7) % shift > 0) - num_full_batches
|
||||||
for i in range(num_partial_batches):
|
for i in range(num_partial_batches):
|
||||||
result = sess.run(get_next)
|
result = self.evaluate(get_next)
|
||||||
for component, result_component in zip(components, result):
|
for component, result_component in zip(components, result):
|
||||||
remaining = (count * 7) - ((num_full_batches + i) * shift)
|
remaining = (count * 7) - ((num_full_batches + i) * shift)
|
||||||
num_elements = remaining // stride + ((remaining % stride) > 0)
|
num_elements = remaining // stride + ((remaining % stride) > 0)
|
||||||
@ -164,10 +164,10 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
num_batches = (10 - 5) // 3 + 1
|
num_batches = (10 - 5) // 3 + 1
|
||||||
for i in range(num_batches):
|
for i in range(num_batches):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected = sparse_tensor.SparseTensorValue(
|
expected = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||||
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
|
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
|
||||||
@ -193,10 +193,10 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
num_batches = (10 - 5) // 3 + 1
|
num_batches = (10 - 5) // 3 + 1
|
||||||
for i in range(num_batches):
|
for i in range(num_batches):
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected_indices = []
|
expected_indices = []
|
||||||
expected_values = []
|
expected_values = []
|
||||||
for j in range(5):
|
for j in range(5):
|
||||||
@ -227,9 +227,9 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
# Slide: 1st batch.
|
# Slide: 1st batch.
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected = sparse_tensor.SparseTensorValue(
|
expected = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
|
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
|
||||||
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
|
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
|
||||||
@ -239,7 +239,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||||
self.assertSparseValuesEqual(actual, expected)
|
self.assertSparseValuesEqual(actual, expected)
|
||||||
# Slide: 2nd batch.
|
# Slide: 2nd batch.
|
||||||
actual = sess.run(get_next)
|
actual = self.evaluate(get_next)
|
||||||
expected = sparse_tensor.SparseTensorValue(
|
expected = sparse_tensor.SparseTensorValue(
|
||||||
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
|
indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
|
||||||
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
|
[1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
|
||||||
@ -265,7 +265,7 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
next_element = iterator.get_next()
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(iterator.initializer)
|
self.evaluate(iterator.initializer)
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
errors.InvalidArgumentError,
|
errors.InvalidArgumentError,
|
||||||
r"Cannot batch tensors with different shapes in component 0. "
|
r"Cannot batch tensors with different shapes in component 0. "
|
||||||
@ -281,8 +281,8 @@ class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
get_next = dataset.make_one_shot_iterator().get_next()
|
get_next = dataset.make_one_shot_iterator().get_next()
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next))
|
self.assertAllEqual(np.float32([1., 2.]), self.evaluate(get_next))
|
||||||
self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next))
|
self.assertAllEqual(np.float32([2., 3.]), self.evaluate(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
|||||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||||
component_placeholders, equal_length_components)})
|
component_placeholders, equal_length_components)})
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(
|
for component, result_component in zip(
|
||||||
equal_length_components, results):
|
equal_length_components, results):
|
||||||
self.assertAllEqual(component[i], result_component)
|
self.assertAllEqual(component[i], result_component)
|
||||||
@ -66,7 +66,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
|||||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||||
component_placeholders, variable_length_components)})
|
component_placeholders, variable_length_components)})
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
results = sess.run(get_next)
|
results = self.evaluate(get_next)
|
||||||
for component, result_component in zip(
|
for component, result_component in zip(
|
||||||
variable_length_components, results):
|
variable_length_components, results):
|
||||||
self.assertAllEqual(component[i], result_component)
|
self.assertAllEqual(component[i], result_component)
|
||||||
@ -103,7 +103,7 @@ class ZipDatasetTest(test_base.DatasetTestBase):
|
|||||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||||
component_placeholders, equal_length_components)})
|
component_placeholders, equal_length_components)})
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
result1, (result2, result3) = sess.run(get_next)
|
result1, (result2, result3) = self.evaluate(get_next)
|
||||||
self.assertAllEqual(equal_length_components[0][i], result1)
|
self.assertAllEqual(equal_length_components[0][i], result1)
|
||||||
self.assertAllEqual(equal_length_components[1][i], result2)
|
self.assertAllEqual(equal_length_components[1][i], result2)
|
||||||
self.assertAllEqual(equal_length_components[2][i], result3)
|
self.assertAllEqual(equal_length_components[2][i], result3)
|
||||||
|
@ -31,24 +31,24 @@ class ConvertTest(test.TestCase):
|
|||||||
def testInteger(self):
|
def testInteger(self):
|
||||||
resp = convert.optional_param_to_tensor("foo", 3)
|
resp = convert.optional_param_to_tensor("foo", 3)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(3, sess.run(resp))
|
self.assertEqual(3, self.evaluate(resp))
|
||||||
|
|
||||||
def testIntegerDefault(self):
|
def testIntegerDefault(self):
|
||||||
resp = convert.optional_param_to_tensor("foo", None)
|
resp = convert.optional_param_to_tensor("foo", None)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(0, sess.run(resp))
|
self.assertEqual(0, self.evaluate(resp))
|
||||||
|
|
||||||
def testStringDefault(self):
|
def testStringDefault(self):
|
||||||
resp = convert.optional_param_to_tensor("bar", None, "default",
|
resp = convert.optional_param_to_tensor("bar", None, "default",
|
||||||
dtypes.string)
|
dtypes.string)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
|
self.assertEqual(compat.as_bytes("default"), self.evaluate(resp))
|
||||||
|
|
||||||
def testString(self):
|
def testString(self):
|
||||||
resp = convert.optional_param_to_tensor("bar", "value", "default",
|
resp = convert.optional_param_to_tensor("bar", "value", "default",
|
||||||
dtypes.string)
|
dtypes.string)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
|
self.assertEqual(compat.as_bytes("value"), self.evaluate(resp))
|
||||||
|
|
||||||
def testPartialShapeToTensorKnownDimension(self):
|
def testPartialShapeToTensorKnownDimension(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
|
@ -1583,7 +1583,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
x = variables.VariableV1([1, 3, 3, 7], name="x")
|
x = variables.VariableV1([1, 3, 3, 7], name="x")
|
||||||
_, idx = array_ops.unique(x, name="x_unique")
|
_, idx = array_ops.unique(x, name="x_unique")
|
||||||
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
|
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
|
||||||
sess.run(x.initializer)
|
self.evaluate(x.initializer)
|
||||||
|
|
||||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
|
@ -126,8 +126,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
|||||||
u = variables.Variable([12.0], name="u")
|
u = variables.Variable([12.0], name="u")
|
||||||
v = variables.Variable([30.0], name="v")
|
v = variables.Variable([30.0], name="v")
|
||||||
w = math_ops.add(u, v, name="w")
|
w = math_ops.add(u, v, name="w")
|
||||||
sess.run(u.initializer)
|
self.evaluate(u.initializer)
|
||||||
sess.run(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
|
|
||||||
self._compareOriginalAndReconstructedGraphDefs(
|
self._compareOriginalAndReconstructedGraphDefs(
|
||||||
sess, w, expected_output=[42.0])
|
sess, w, expected_output=[42.0])
|
||||||
@ -139,7 +139,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
|||||||
b = math_ops.add(a, a, name="b")
|
b = math_ops.add(a, a, name="b")
|
||||||
with ops.control_dependencies([a, b]):
|
with ops.control_dependencies([a, b]):
|
||||||
c = math_ops.multiply(b, b, name="c")
|
c = math_ops.multiply(b, b, name="c")
|
||||||
sess.run(a.initializer)
|
self.evaluate(a.initializer)
|
||||||
|
|
||||||
self._compareOriginalAndReconstructedGraphDefs(
|
self._compareOriginalAndReconstructedGraphDefs(
|
||||||
sess, c, expected_output=400.0)
|
sess, c, expected_output=400.0)
|
||||||
@ -150,8 +150,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
|||||||
y = variables.Variable(20.0, name="y")
|
y = variables.Variable(20.0, name="y")
|
||||||
cond = control_flow_ops.cond(
|
cond = control_flow_ops.cond(
|
||||||
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
|
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
|
||||||
sess.run(x.initializer)
|
self.evaluate(x.initializer)
|
||||||
sess.run(y.initializer)
|
self.evaluate(y.initializer)
|
||||||
|
|
||||||
self._compareOriginalAndReconstructedGraphDefs(
|
self._compareOriginalAndReconstructedGraphDefs(
|
||||||
sess, cond, expected_output=21.0)
|
sess, cond, expected_output=21.0)
|
||||||
@ -173,8 +173,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
|
|||||||
toy_loss = x * (u - v)
|
toy_loss = x * (u - v)
|
||||||
train_op = gradient_descent.GradientDescentOptimizer(
|
train_op = gradient_descent.GradientDescentOptimizer(
|
||||||
learning_rate=0.1).minimize(toy_loss, name="train_op")
|
learning_rate=0.1).minimize(toy_loss, name="train_op")
|
||||||
sess.run(u.initializer)
|
self.evaluate(u.initializer)
|
||||||
sess.run(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
|
|
||||||
self._compareOriginalAndReconstructedGraphDefs(sess, train_op)
|
self._compareOriginalAndReconstructedGraphDefs(sess, train_op)
|
||||||
|
|
||||||
|
@ -131,8 +131,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
|||||||
with session.Session(
|
with session.Session(
|
||||||
config=self.session_config, graph=graph,
|
config=self.session_config, graph=graph,
|
||||||
target=self.server_target) as sess:
|
target=self.server_target) as sess:
|
||||||
sess.run(self.a.initializer)
|
self.evaluate(self.a.initializer)
|
||||||
sess.run(self.b.initializer)
|
self.evaluate(self.b.initializer)
|
||||||
|
|
||||||
run_options = config_pb2.RunOptions()
|
run_options = config_pb2.RunOptions()
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
@ -198,8 +198,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
|||||||
with session.Session(
|
with session.Session(
|
||||||
config=self.session_config, graph=graph,
|
config=self.session_config, graph=graph,
|
||||||
target=self.server_target) as sess:
|
target=self.server_target) as sess:
|
||||||
sess.run(self.a.initializer)
|
self.evaluate(self.a.initializer)
|
||||||
sess.run(self.b.initializer)
|
self.evaluate(self.b.initializer)
|
||||||
|
|
||||||
def watch_fn(feeds, fetch_keys):
|
def watch_fn(feeds, fetch_keys):
|
||||||
del feeds, fetch_keys
|
del feeds, fetch_keys
|
||||||
|
@ -67,7 +67,7 @@ class SessionDebugMultiGPUTest(test_util.TensorFlowTestCase):
|
|||||||
u1 = math_ops.multiply(v, v, name="u1")
|
u1 = math_ops.multiply(v, v, name="u1")
|
||||||
w = math_ops.subtract(u1, u0, name="w")
|
w = math_ops.subtract(u1, u0, name="w")
|
||||||
|
|
||||||
sess.run(v.initializer)
|
self.evaluate(v.initializer)
|
||||||
|
|
||||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||||
debug_utils.watch_graph(run_options, sess.graph,
|
debug_utils.watch_graph(run_options, sess.graph,
|
||||||
|
@ -109,8 +109,8 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
|
|||||||
self.w = math_ops.matmul(self.u, self.v, name="w")
|
self.w = math_ops.matmul(self.u, self.v, name="w")
|
||||||
self.w_line_number = line_number_above()
|
self.w_line_number = line_number_above()
|
||||||
|
|
||||||
sess.run(self.u.initializer)
|
self.evaluate(self.u.initializer)
|
||||||
sess.run(self.v.initializer)
|
self.evaluate(self.v.initializer)
|
||||||
|
|
||||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
|
@ -235,7 +235,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
|
|||||||
result = math_ops.add_n(xs)
|
result = math_ops.add_n(xs)
|
||||||
|
|
||||||
variables.global_variables_initializer().run()
|
variables.global_variables_initializer().run()
|
||||||
result_value = sess.run(result)
|
result_value = self.evaluate(result)
|
||||||
self.assertEqual(result_value, expected)
|
self.assertEqual(result_value, expected)
|
||||||
if result_value == expected:
|
if result_value == expected:
|
||||||
self._result_correct += 1
|
self._result_correct += 1
|
||||||
@ -294,7 +294,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
|
|||||||
if len(uninit_vars) == 0:
|
if len(uninit_vars) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
sess.run(train_op)
|
self.evaluate(train_op)
|
||||||
|
|
||||||
# Synchronize workers after one step to make sure they all have finished
|
# Synchronize workers after one step to make sure they all have finished
|
||||||
# training.
|
# training.
|
||||||
@ -327,7 +327,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
|
|||||||
|
|
||||||
# The monitored session will run init or ready ops.
|
# The monitored session will run init or ready ops.
|
||||||
with monitored_session.MonitoredSession() as sess:
|
with monitored_session.MonitoredSession() as sess:
|
||||||
sess.run(train_op)
|
self.evaluate(train_op)
|
||||||
|
|
||||||
# Synchronize workers after one step to make sure they all have finished
|
# Synchronize workers after one step to make sure they all have finished
|
||||||
# training.
|
# training.
|
||||||
|
@ -92,7 +92,7 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
|
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
@ -205,10 +205,11 @@ class AutoShardDatasetTest(test.TestCase):
|
|||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
self.assertAllEqual(self._record(r, f), sess.run(next_element))
|
self.assertAllEqual(self._record(r, f), self.evaluate(next_element))
|
||||||
for f in range(self._shard_index, self._num_files, self._num_shards):
|
for f in range(self._shard_index, self._num_files, self._num_shards):
|
||||||
for r in range(self._num_records):
|
for r in range(self._num_records):
|
||||||
self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
|
self.assertAllEqual(
|
||||||
|
self._text_line(r, f), self.evaluate(next_element))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
@ -149,9 +149,9 @@ class DefFunctionTest(test.TestCase):
|
|||||||
|
|
||||||
result = fn(3.0)
|
result = fn(3.0)
|
||||||
|
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertAllEqual(sess.run(state[0]), 2.0)
|
self.assertAllEqual(sess.run(state[0]), 2.0)
|
||||||
self.assertAllEqual(sess.run(result), 6.0)
|
self.assertAllEqual(self.evaluate(result), 6.0)
|
||||||
|
|
||||||
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
|
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
|
||||||
with ops.Graph().as_default(), self.test_session() as sess:
|
with ops.Graph().as_default(), self.test_session() as sess:
|
||||||
@ -168,9 +168,9 @@ class DefFunctionTest(test.TestCase):
|
|||||||
|
|
||||||
result = fn(3.0)
|
result = fn(3.0)
|
||||||
|
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertAllEqual(sess.run(state[0]), 6.0)
|
self.assertAllEqual(sess.run(state[0]), 6.0)
|
||||||
self.assertAllEqual(sess.run(result), 18.0)
|
self.assertAllEqual(self.evaluate(result), 18.0)
|
||||||
|
|
||||||
def testLegacyGraphModeInputDependentInitializerFails(self):
|
def testLegacyGraphModeInputDependentInitializerFails(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
|
@ -78,7 +78,7 @@ class FunctionGradientsTest(test.TestCase, parameterized.TestCase):
|
|||||||
c = constant_op.constant([[2.]])
|
c = constant_op.constant([[2.]])
|
||||||
f_c = f(c)
|
f_c = f(c)
|
||||||
g, = gradients_impl.gradients(f_c, c)
|
g, = gradients_impl.gradients(f_c, c)
|
||||||
self.assertAllEqual(sess.run(g).values, [[1.0]])
|
self.assertAllEqual(self.evaluate(g).values, [[1.0]])
|
||||||
|
|
||||||
def testNoSymGradNestedDefun(self):
|
def testNoSymGradNestedDefun(self):
|
||||||
|
|
||||||
|
@ -564,7 +564,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
variables.global_variables_initializer().run()
|
variables.global_variables_initializer().run()
|
||||||
call = def_function.function(o.call)
|
call = def_function.function(o.call)
|
||||||
op = call()
|
op = call()
|
||||||
self.assertAllEqual(sess.run(op), 2.0)
|
self.assertAllEqual(self.evaluate(op), 2.0)
|
||||||
|
|
||||||
def testGraphModeManyFunctions(self):
|
def testGraphModeManyFunctions(self):
|
||||||
with ops.Graph().as_default(), self.cached_session():
|
with ops.Graph().as_default(), self.cached_session():
|
||||||
@ -1733,7 +1733,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
function.register(cpu_boost, x)
|
function.register(cpu_boost, x)
|
||||||
y = gpu_boost(x)
|
y = gpu_boost(x)
|
||||||
y_value = sess.run(y)
|
y_value = self.evaluate(y)
|
||||||
|
|
||||||
if test.is_gpu_available():
|
if test.is_gpu_available():
|
||||||
self.assertEqual(y_value, 5.0)
|
self.assertEqual(y_value, 5.0)
|
||||||
|
@ -1026,7 +1026,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
outputs = _transform_features(features, [price_cross_wire])
|
outputs = _transform_features(features, [price_cross_wire])
|
||||||
output = outputs[price_cross_wire]
|
output = outputs[price_cross_wire]
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
output_val = sess.run(output)
|
output_val = self.evaluate(output)
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
||||||
for val in output_val.values:
|
for val in output_val.values:
|
||||||
@ -1880,7 +1880,8 @@ class LinearModelTest(test.TestCase):
|
|||||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||||
sess.run(bias.assign([5.]))
|
sess.run(bias.assign([5.]))
|
||||||
|
|
||||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
|
self.evaluate(net))
|
||||||
|
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
@ -2514,7 +2515,8 @@ class _LinearModelTest(test.TestCase):
|
|||||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||||
sess.run(bias.assign([5.]))
|
sess.run(bias.assign([5.]))
|
||||||
|
|
||||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
|
self.evaluate(net))
|
||||||
|
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
price = fc.numeric_column('price')
|
price = fc.numeric_column('price')
|
||||||
|
@ -1190,7 +1190,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
outputs = fc._transform_features(features, [price_cross_wire], None)
|
outputs = fc._transform_features(features, [price_cross_wire], None)
|
||||||
output = outputs[price_cross_wire]
|
output = outputs[price_cross_wire]
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
output_val = sess.run(output)
|
output_val = self.evaluate(output)
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
|
||||||
for val in output_val.values:
|
for val in output_val.values:
|
||||||
@ -2091,7 +2091,8 @@ class LinearModelTest(test.TestCase):
|
|||||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||||
sess.run(bias.assign([5.]))
|
sess.run(bias.assign([5.]))
|
||||||
|
|
||||||
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
|
self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]],
|
||||||
|
self.evaluate(net))
|
||||||
|
|
||||||
coord.request_stop()
|
coord.request_stop()
|
||||||
coord.join(threads)
|
coord.join(threads)
|
||||||
@ -2127,7 +2128,8 @@ class LinearModelTest(test.TestCase):
|
|||||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||||
sess.run(bias.assign([5.]))
|
sess.run(bias.assign([5.]))
|
||||||
|
|
||||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
|
self.evaluate(net))
|
||||||
|
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
price = fc.numeric_column_v2('price')
|
price = fc.numeric_column_v2('price')
|
||||||
@ -2849,7 +2851,8 @@ class OldLinearModelTest(test.TestCase):
|
|||||||
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
|
||||||
sess.run(bias.assign([5.]))
|
sess.run(bias.assign([5.]))
|
||||||
|
|
||||||
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
|
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
|
||||||
|
self.evaluate(net))
|
||||||
|
|
||||||
def test_with_1d_unknown_shape_sparse_tensor(self):
|
def test_with_1d_unknown_shape_sparse_tensor(self):
|
||||||
price = fc.numeric_column_v2('price')
|
price = fc.numeric_column_v2('price')
|
||||||
|
@ -102,7 +102,7 @@ class FunctionTest(test.TestCase):
|
|||||||
call = MyIdentityFunc([18.0])
|
call = MyIdentityFunc([18.0])
|
||||||
self.assertEqual("MyIdentity", call.op.name)
|
self.assertEqual("MyIdentity", call.op.name)
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllEqual([18.0], sess.run(call))
|
self.assertAllEqual([18.0], self.evaluate(call))
|
||||||
|
|
||||||
def testIdentityImplicitDeref(self):
|
def testIdentityImplicitDeref(self):
|
||||||
|
|
||||||
@ -116,8 +116,8 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertEqual("MyIdentity", call.op.name)
|
self.assertEqual("MyIdentity", call.op.name)
|
||||||
for cfg in _OptimizerOptions():
|
for cfg in _OptimizerOptions():
|
||||||
with session.Session(config=cfg) as sess:
|
with session.Session(config=cfg) as sess:
|
||||||
sess.run(var.initializer)
|
self.evaluate(var.initializer)
|
||||||
self.assertAllEqual([18.0], sess.run(call))
|
self.assertAllEqual([18.0], self.evaluate(call))
|
||||||
|
|
||||||
def testIdentityOutputName(self):
|
def testIdentityOutputName(self):
|
||||||
|
|
||||||
@ -130,7 +130,7 @@ class FunctionTest(test.TestCase):
|
|||||||
call = MyIdentityFunc([18.0])
|
call = MyIdentityFunc([18.0])
|
||||||
self.assertEqual("MyIdentity", call.op.name)
|
self.assertEqual("MyIdentity", call.op.name)
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllEqual([18.0], sess.run(call))
|
self.assertAllEqual([18.0], self.evaluate(call))
|
||||||
|
|
||||||
def testTooManyOutputNames(self):
|
def testTooManyOutputNames(self):
|
||||||
|
|
||||||
@ -158,7 +158,7 @@ class FunctionTest(test.TestCase):
|
|||||||
call = APlus2B([1.0], [2.0])
|
call = APlus2B([1.0], [2.0])
|
||||||
self.assertEqual("APlus2B", call.op.name)
|
self.assertEqual("APlus2B", call.op.name)
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllEqual([5.0], sess.run(call))
|
self.assertAllEqual([5.0], self.evaluate(call))
|
||||||
|
|
||||||
def testFunctionWithNoOutput(self):
|
def testFunctionWithNoOutput(self):
|
||||||
|
|
||||||
@ -187,7 +187,7 @@ class FunctionTest(test.TestCase):
|
|||||||
call = APlus2B([1.0], [2.0])
|
call = APlus2B([1.0], [2.0])
|
||||||
self.assertEqual("APlus2B", call.op.name)
|
self.assertEqual("APlus2B", call.op.name)
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllEqual([5.0], sess.run(call))
|
self.assertAllEqual([5.0], self.evaluate(call))
|
||||||
|
|
||||||
def testDefineFunctionDuplicateOutputs(self):
|
def testDefineFunctionDuplicateOutputs(self):
|
||||||
|
|
||||||
@ -224,8 +224,8 @@ class FunctionTest(test.TestCase):
|
|||||||
call_g = XSquarePlusOneGrad([2.0], [0.1])
|
call_g = XSquarePlusOneGrad([2.0], [0.1])
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllClose([5.0], sess.run(call_f))
|
self.assertAllClose([5.0], self.evaluate(call_f))
|
||||||
self.assertAllClose([0.4], sess.run(call_g))
|
self.assertAllClose([0.4], self.evaluate(call_g))
|
||||||
|
|
||||||
def testTanhSymGrad(self):
|
def testTanhSymGrad(self):
|
||||||
|
|
||||||
@ -387,7 +387,7 @@ class FunctionTest(test.TestCase):
|
|||||||
call = AConstant()
|
call = AConstant()
|
||||||
self.assertEqual("AConstant", call.op.name)
|
self.assertEqual("AConstant", call.op.name)
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllEqual([42], sess.run(call))
|
self.assertAllEqual([42], self.evaluate(call))
|
||||||
|
|
||||||
def testDefineFunctionNames(self):
|
def testDefineFunctionNames(self):
|
||||||
|
|
||||||
@ -468,7 +468,7 @@ class FunctionTest(test.TestCase):
|
|||||||
|
|
||||||
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
|
loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
|
||||||
|
|
||||||
ans = sess.run(loop)
|
ans = self.evaluate(loop)
|
||||||
self.assertAllClose(ans, 131072.)
|
self.assertAllClose(ans, 131072.)
|
||||||
|
|
||||||
def testControlFlowStrictness(self):
|
def testControlFlowStrictness(self):
|
||||||
@ -650,8 +650,8 @@ class FunctionTest(test.TestCase):
|
|||||||
# pylint: enable=unexpected-keyword-arg
|
# pylint: enable=unexpected-keyword-arg
|
||||||
self.assertEqual("next", call2.op.name)
|
self.assertEqual("next", call2.op.name)
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllEqual([1], sess.run(call1))
|
self.assertAllEqual([1], self.evaluate(call1))
|
||||||
self.assertAllEqual([0], sess.run(call2))
|
self.assertAllEqual([0], self.evaluate(call2))
|
||||||
|
|
||||||
def testNestedFunction(self):
|
def testNestedFunction(self):
|
||||||
|
|
||||||
@ -794,7 +794,7 @@ class FunctionTest(test.TestCase):
|
|||||||
y = Foo()
|
y = Foo()
|
||||||
|
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
self.assertEqual(sess.run(y), 10)
|
self.assertEqual(self.evaluate(y), 10)
|
||||||
|
|
||||||
def testCaptureInCond(self):
|
def testCaptureInCond(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
@ -809,8 +809,8 @@ class FunctionTest(test.TestCase):
|
|||||||
z = Foo(False)
|
z = Foo(False)
|
||||||
|
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
self.assertEqual(sess.run(y), 1)
|
self.assertEqual(self.evaluate(y), 1)
|
||||||
self.assertEqual(sess.run(z), 2)
|
self.assertEqual(self.evaluate(z), 2)
|
||||||
|
|
||||||
def testStableName(self):
|
def testStableName(self):
|
||||||
|
|
||||||
@ -900,7 +900,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertEqual(global_vars[0].name, "linear/w:0")
|
self.assertEqual(global_vars[0].name, "linear/w:0")
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
output_val = sess.run(
|
output_val = sess.run(
|
||||||
output_op, feed_dict={input_op: np.random.rand(32, 100)})
|
output_op, feed_dict={input_op: np.random.rand(32, 100)})
|
||||||
self.assertEqual(output_val.shape, (32, 100))
|
self.assertEqual(output_val.shape, (32, 100))
|
||||||
@ -928,7 +928,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertEqual(global_vars[0].name, "vs1/var:0")
|
self.assertEqual(global_vars[0].name, "vs1/var:0")
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
out1, out2 = sess.run(
|
out1, out2 = sess.run(
|
||||||
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
|
[out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
|
||||||
self.assertAllEqual(out1, np.linspace(2, 11, 10))
|
self.assertAllEqual(out1, np.linspace(2, 11, 10))
|
||||||
@ -991,8 +991,8 @@ class FunctionTest(test.TestCase):
|
|||||||
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
|
result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertEqual(4.0, sess.run(result_1))
|
self.assertEqual(4.0, self.evaluate(result_1))
|
||||||
self.assertEqual(100, sess.run(result_2))
|
self.assertEqual(100, self.evaluate(result_2))
|
||||||
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
|
self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
|
||||||
|
|
||||||
def testStatefulFunction(self):
|
def testStatefulFunction(self):
|
||||||
@ -1052,8 +1052,8 @@ class FunctionTest(test.TestCase):
|
|||||||
for config in _OptimizerOptions():
|
for config in _OptimizerOptions():
|
||||||
config.device_count["CPU"] = 2
|
config.device_count["CPU"] = 2
|
||||||
with session.Session(config=config) as sess:
|
with session.Session(config=config) as sess:
|
||||||
self.assertEqual(42.0, sess.run(f_0))
|
self.assertEqual(42.0, self.evaluate(f_0))
|
||||||
self.assertEqual(44.0, sess.run(f_1))
|
self.assertEqual(44.0, self.evaluate(f_1))
|
||||||
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
|
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
|
||||||
|
|
||||||
def testGuaranteedConstsAreCaptured(self):
|
def testGuaranteedConstsAreCaptured(self):
|
||||||
@ -1076,7 +1076,7 @@ class FunctionTest(test.TestCase):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
with self.session(use_gpu=False) as sess:
|
with self.session(use_gpu=False) as sess:
|
||||||
sess.run(var.initializer)
|
self.evaluate(var.initializer)
|
||||||
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
|
_ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
|
||||||
|
|
||||||
def testSameFunctionDifferentGrads(self):
|
def testSameFunctionDifferentGrads(self):
|
||||||
@ -1651,8 +1651,8 @@ class ModuleFunctionTest(test.TestCase):
|
|||||||
y = LinearWithCApi(a, b, c)
|
y = LinearWithCApi(a, b, c)
|
||||||
z = Linear2WithCApi(a, b, c, d, e)
|
z = Linear2WithCApi(a, b, c, d, e)
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
self.assertAllEqual([[1]], sess.run(y))
|
self.assertAllEqual([[1]], self.evaluate(y))
|
||||||
self.assertAllEqual([[5]], sess.run(z))
|
self.assertAllEqual([[5]], self.evaluate(z))
|
||||||
|
|
||||||
|
|
||||||
class VariableHoistingTest(test.TestCase):
|
class VariableHoistingTest(test.TestCase):
|
||||||
@ -1704,7 +1704,7 @@ class VariableHoistingTest(test.TestCase):
|
|||||||
self.assertEqual("Foo/b", b.op.name)
|
self.assertEqual("Foo/b", b.op.name)
|
||||||
|
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
|
w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
|
||||||
|
|
||||||
self.assertAllEqual(w.shape, (64, 64))
|
self.assertAllEqual(w.shape, (64, 64))
|
||||||
|
@ -211,7 +211,7 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
init = variables.variables_initializer([variable_node])
|
init = variables.variables_initializer([variable_node])
|
||||||
sess.run(init)
|
sess.run(init)
|
||||||
output = sess.run(output_node)
|
output = self.evaluate(output_node)
|
||||||
self.assertNear(4.0, output, 0.00001)
|
self.assertNear(4.0, output, 0.00001)
|
||||||
variable_graph_def = sess.graph.as_graph_def()
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
|
||||||
@ -242,8 +242,8 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
output_node = math_ops_lib.multiply(
|
output_node = math_ops_lib.multiply(
|
||||||
variable_node, 2.0, name="output_node")
|
variable_node, 2.0, name="output_node")
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(variable_node.initializer)
|
self.evaluate(variable_node.initializer)
|
||||||
output = sess.run(output_node)
|
output = self.evaluate(output_node)
|
||||||
self.assertNear(2.0, output, 0.00001)
|
self.assertNear(2.0, output, 0.00001)
|
||||||
variable_graph_def = sess.graph.as_graph_def()
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
# First get the constant_graph_def when variable_names_whitelist is
|
# First get the constant_graph_def when variable_names_whitelist is
|
||||||
@ -256,7 +256,7 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
|
|
||||||
# Then initialize the unused variable, and get another
|
# Then initialize the unused variable, and get another
|
||||||
# constant_graph_def when variable_names_whitelist is not set.
|
# constant_graph_def when variable_names_whitelist is not set.
|
||||||
sess.run(another_variable.initializer)
|
self.evaluate(another_variable.initializer)
|
||||||
constant_graph_def_without_variable_whitelist = (
|
constant_graph_def_without_variable_whitelist = (
|
||||||
graph_util.convert_variables_to_constants(
|
graph_util.convert_variables_to_constants(
|
||||||
sess, variable_graph_def, ["output_node"]))
|
sess, variable_graph_def, ["output_node"]))
|
||||||
@ -295,7 +295,7 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
||||||
output = sess.run(output_node)
|
output = self.evaluate(output_node)
|
||||||
self.assertNear(2.0, output, 0.00001)
|
self.assertNear(2.0, output, 0.00001)
|
||||||
|
|
||||||
def create_node_def(self, op, name, inputs):
|
def create_node_def(self, op, name, inputs):
|
||||||
|
@ -398,10 +398,10 @@ class ImportGraphDefTest(test.TestCase):
|
|||||||
# TODO(b/76173421): make this work (currently DCHECKS)
|
# TODO(b/76173421): make this work (currently DCHECKS)
|
||||||
# with self.cached_session() as sess:
|
# with self.cached_session() as sess:
|
||||||
# sess.run(imported_init)
|
# sess.run(imported_init)
|
||||||
# self.assertEqual(sess.run(imported_var), 1.0)
|
# self.assertEqual(self.evaluate(imported_var), 1.0)
|
||||||
# self.assertEqual(sess.run(imported_assign), 2.0)
|
# self.assertEqual(self.evaluate(imported_assign), 2.0)
|
||||||
# self.assertEqual(list(sess.run(imported_shape)), [])
|
# self.assertEqual(list(self.evaluate(imported_shape)), [])
|
||||||
# self.assertEqual(list(sess.run(new_var_shape)), [])
|
# self.assertEqual(list(self.evaluate(new_var_shape)), [])
|
||||||
|
|
||||||
def testWhileLoop(self):
|
def testWhileLoop(self):
|
||||||
# Produce GraphDef containing while loop.
|
# Produce GraphDef containing while loop.
|
||||||
@ -418,7 +418,7 @@ class ImportGraphDefTest(test.TestCase):
|
|||||||
return_elements=[r.name])
|
return_elements=[r.name])
|
||||||
self.assertEqual(imported_r.name, "import/" + r.name)
|
self.assertEqual(imported_r.name, "import/" + r.name)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(imported_r), 10)
|
self.assertEqual(self.evaluate(imported_r), 10)
|
||||||
|
|
||||||
def testImportWhileLoopInCond(self):
|
def testImportWhileLoopInCond(self):
|
||||||
# Produce GraphDef containing while loop.
|
# Produce GraphDef containing while loop.
|
||||||
@ -458,7 +458,7 @@ class ImportGraphDefTest(test.TestCase):
|
|||||||
lambda i: i < 2, ImportFn, [0],
|
lambda i: i < 2, ImportFn, [0],
|
||||||
shape_invariants=[tensor_shape.TensorShape(None)])
|
shape_invariants=[tensor_shape.TensorShape(None)])
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
self.assertEqual(sess.run(out), 10)
|
self.assertEqual(self.evaluate(out), 10)
|
||||||
|
|
||||||
def testTypeMismatchInGraphDef(self):
|
def testTypeMismatchInGraphDef(self):
|
||||||
# TODO(skyewm): improve error message
|
# TODO(skyewm): improve error message
|
||||||
|
@ -492,8 +492,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
init_op = variables.global_variables_initializer()
|
init_op = variables.global_variables_initializer()
|
||||||
grad = gradients_impl.gradients([output], [var])
|
grad = gradients_impl.gradients([output], [var])
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
expected_grad_value = sess.run(grad)
|
expected_grad_value = self.evaluate(grad)
|
||||||
|
|
||||||
# Restore the MetaGraphDef into a new Graph with an import scope.
|
# Restore the MetaGraphDef into a new Graph with an import scope.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -518,8 +518,8 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
init_op = variables.global_variables_initializer()
|
init_op = variables.global_variables_initializer()
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(init_op)
|
self.evaluate(init_op)
|
||||||
actual_grad_value = sess.run(grad)
|
actual_grad_value = self.evaluate(grad)
|
||||||
self.assertEqual(expected_grad_value, actual_grad_value)
|
self.assertEqual(expected_grad_value, actual_grad_value)
|
||||||
|
|
||||||
def testImportWhileLoopInWhileLoop(self):
|
def testImportWhileLoopInWhileLoop(self):
|
||||||
@ -544,7 +544,7 @@ class ScopedMetaGraphTest(test.TestCase):
|
|||||||
_, x = control_flow_ops.while_loop(lambda i, x: i < 2, body, [0, 0.0],
|
_, x = control_flow_ops.while_loop(lambda i, x: i < 2, body, [0, 0.0],
|
||||||
name="")
|
name="")
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
sess.run(x)
|
sess.run(x)
|
||||||
|
|
||||||
def testScopedImportUnderNameScope(self):
|
def testScopedImportUnderNameScope(self):
|
||||||
@ -869,7 +869,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
|
|||||||
|
|
||||||
initializer = variables.local_variables_initializer()
|
initializer = variables.local_variables_initializer()
|
||||||
sess.run(initializer)
|
sess.run(initializer)
|
||||||
sess.run(update_op)
|
self.evaluate(update_op)
|
||||||
|
|
||||||
meta_graph.export_scoped_meta_graph(
|
meta_graph.export_scoped_meta_graph(
|
||||||
filename=meta_graph_filename, graph=graph)
|
filename=meta_graph_filename, graph=graph)
|
||||||
|
@ -517,21 +517,21 @@ class OperationTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEquals(x.consumers(), [])
|
self.assertEquals(x.consumers(), [])
|
||||||
self.assertEquals(y.consumers(), [z.op, z.op])
|
self.assertEquals(y.consumers(), [z.op, z.op])
|
||||||
with session.Session(graph=g) as sess:
|
with session.Session(graph=g) as sess:
|
||||||
self.assertEquals(sess.run(z), 4)
|
self.assertEquals(self.evaluate(z), 4)
|
||||||
|
|
||||||
z.op._update_input(0, x) # pylint: disable=protected-access
|
z.op._update_input(0, x) # pylint: disable=protected-access
|
||||||
self.assertEquals(list(z.op.inputs), [x, y])
|
self.assertEquals(list(z.op.inputs), [x, y])
|
||||||
self.assertEquals(x.consumers(), [z.op])
|
self.assertEquals(x.consumers(), [z.op])
|
||||||
self.assertEquals(y.consumers(), [z.op])
|
self.assertEquals(y.consumers(), [z.op])
|
||||||
with session.Session(graph=g) as sess:
|
with session.Session(graph=g) as sess:
|
||||||
self.assertEquals(sess.run(z), 3)
|
self.assertEquals(self.evaluate(z), 3)
|
||||||
|
|
||||||
z.op._update_input(1, y) # pylint: disable=protected-access
|
z.op._update_input(1, y) # pylint: disable=protected-access
|
||||||
self.assertEquals(list(z.op.inputs), [x, y])
|
self.assertEquals(list(z.op.inputs), [x, y])
|
||||||
self.assertEquals(x.consumers(), [z.op])
|
self.assertEquals(x.consumers(), [z.op])
|
||||||
self.assertEquals(y.consumers(), [z.op])
|
self.assertEquals(y.consumers(), [z.op])
|
||||||
with session.Session(graph=g) as sess:
|
with session.Session(graph=g) as sess:
|
||||||
self.assertEquals(sess.run(z), 3)
|
self.assertEquals(self.evaluate(z), 3)
|
||||||
|
|
||||||
def testUpdateInputGraphError(self):
|
def testUpdateInputGraphError(self):
|
||||||
g_0 = ops.Graph()
|
g_0 = ops.Graph()
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user