Replace x.eval() with self.evaluate() in tests.
It makes the line V2 compatible $ sed -i "s/\([a-z][._a-z0-9]*\).eval()/self.evaluate(\1)/" ./third_party/tensorflow/python/kernel_tests/*.py And some manual cleaning. PiperOrigin-RevId: 324169377 Change-Id: I36a8d92c28a46e5a05fb32bdb634a84081951d88
This commit is contained in:
parent
dd472f98f5
commit
88c2594048
@ -1069,11 +1069,11 @@ class StridedSliceBenchmark(test_lib.Benchmark):
|
||||
def run_and_time(self, slice_op):
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
for _ in range(10):
|
||||
_ = slice_op.eval()
|
||||
_ = self.evaluate(slice_op)
|
||||
iters = 1000
|
||||
t0 = time.time()
|
||||
for _ in range(iters):
|
||||
slice_op.eval()
|
||||
self.evaluate(slice_op)
|
||||
t1 = time.time()
|
||||
self.report_benchmark(iters=iters, wall_time=(t1 - t0) / 1000.0)
|
||||
|
||||
@ -1474,7 +1474,7 @@ class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
|
||||
with self.cached_session():
|
||||
a = array_ops.constant(10)
|
||||
guarantee_a = array_ops.guarantee_const(a)
|
||||
self.assertEqual(10, guarantee_a.eval())
|
||||
self.assertEqual(10, self.evaluate(guarantee_a))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testVariables(self):
|
||||
@ -1487,7 +1487,7 @@ class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
|
||||
use_resource=use_resource)
|
||||
guarantee_a = array_ops.guarantee_const(a)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertEqual(10.0, guarantee_a.eval())
|
||||
self.assertEqual(10.0, self.evaluate(guarantee_a))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testResourceRejection(self):
|
||||
@ -1500,7 +1500,7 @@ class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
|
||||
"cannot be a resource variable"):
|
||||
guarantee_a.eval()
|
||||
self.evaluate(guarantee_a)
|
||||
|
||||
|
||||
class SnapshotOpTest(test_util.TensorFlowTestCase):
|
||||
|
@ -78,11 +78,11 @@ class BarrierTest(test.TestCase):
|
||||
insert_0_op = b.insert_many(0, keys, [10.0, 20.0, 30.0])
|
||||
insert_1_op = b.insert_many(1, keys, [100.0, 200.0, 300.0])
|
||||
|
||||
self.assertEqual(size_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(size_t), [0])
|
||||
insert_0_op.run()
|
||||
self.assertEqual(size_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(size_t), [0])
|
||||
insert_1_op.run()
|
||||
self.assertEqual(size_t.eval(), [3])
|
||||
self.assertEqual(self.evaluate(size_t), [3])
|
||||
|
||||
def testInsertManyEmptyTensor(self):
|
||||
with self.cached_session():
|
||||
@ -100,7 +100,7 @@ class BarrierTest(test.TestCase):
|
||||
self.assertEqual([], size_t.get_shape())
|
||||
keys = [b"a", b"b", b"c"]
|
||||
insert_0_op = b.insert_many(0, keys, np.array([[], [], []], np.float32))
|
||||
self.assertEqual(size_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(size_t), [0])
|
||||
with self.assertRaisesOpError(
|
||||
".*Tensors with no elements are not supported.*"):
|
||||
insert_0_op.run()
|
||||
@ -120,7 +120,7 @@ class BarrierTest(test.TestCase):
|
||||
|
||||
insert_0_op.run()
|
||||
insert_1_op.run()
|
||||
self.assertEqual(size_t.eval(), [3])
|
||||
self.assertEqual(self.evaluate(size_t), [3])
|
||||
|
||||
indices_val, keys_val, values_0_val, values_1_val = sess.run(
|
||||
[take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
|
||||
@ -157,8 +157,9 @@ class BarrierTest(test.TestCase):
|
||||
close_op.run()
|
||||
# Now we have a closed barrier with 2 ready elements. Running take_t
|
||||
# should return a reduced batch with 2 elements only.
|
||||
self.assertEqual(size_i.eval(), [2]) # assert that incomplete size = 2
|
||||
self.assertEqual(size_t.eval(), [2]) # assert that ready size = 2
|
||||
self.assertEqual(self.evaluate(size_i),
|
||||
[2]) # assert that incomplete size = 2
|
||||
self.assertEqual(self.evaluate(size_t), [2]) # assert that ready size = 2
|
||||
_, keys_val, values_0_val, values_1_val = sess.run(
|
||||
[index_t, key_t, value_list_t[0], value_list_t[1]])
|
||||
# Check that correct values have been returned.
|
||||
@ -170,8 +171,9 @@ class BarrierTest(test.TestCase):
|
||||
# The next insert completes the element with key "c". The next take_t
|
||||
# should return a batch with just 1 element.
|
||||
insert_1_2_op.run()
|
||||
self.assertEqual(size_i.eval(), [1]) # assert that incomplete size = 1
|
||||
self.assertEqual(size_t.eval(), [1]) # assert that ready size = 1
|
||||
self.assertEqual(self.evaluate(size_i),
|
||||
[1]) # assert that incomplete size = 1
|
||||
self.assertEqual(self.evaluate(size_t), [1]) # assert that ready size = 1
|
||||
_, keys_val, values_0_val, values_1_val = sess.run(
|
||||
[index_t, key_t, value_list_t[0], value_list_t[1]])
|
||||
# Check that correct values have been returned.
|
||||
@ -212,7 +214,7 @@ class BarrierTest(test.TestCase):
|
||||
|
||||
insert_0_op.run()
|
||||
insert_1_op.run()
|
||||
self.assertEqual(size_t.eval(), [3])
|
||||
self.assertEqual(self.evaluate(size_t), [3])
|
||||
|
||||
indices_val, keys_val, values_0_val, values_1_val = sess.run(
|
||||
[take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
|
||||
@ -237,7 +239,7 @@ class BarrierTest(test.TestCase):
|
||||
take_t = b.take_many(10)
|
||||
|
||||
self.evaluate(insert_ops)
|
||||
self.assertEqual(size_t.eval(), [10])
|
||||
self.assertEqual(self.evaluate(size_t), [10])
|
||||
|
||||
indices_val, keys_val, values_val = sess.run(
|
||||
[take_t[0], take_t[1], take_t[2][0]])
|
||||
@ -258,7 +260,7 @@ class BarrierTest(test.TestCase):
|
||||
take_t = [b.take_many(1) for _ in keys]
|
||||
|
||||
insert_op.run()
|
||||
self.assertEqual(size_t.eval(), [10])
|
||||
self.assertEqual(self.evaluate(size_t), [10])
|
||||
|
||||
index_fetches = []
|
||||
key_fetches = []
|
||||
@ -402,11 +404,11 @@ class BarrierTest(test.TestCase):
|
||||
take_t = b.take_many(3)
|
||||
take_too_many_t = b.take_many(4)
|
||||
|
||||
self.assertEqual(size_t.eval(), [0])
|
||||
self.assertEqual(incomplete_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(size_t), [0])
|
||||
self.assertEqual(self.evaluate(incomplete_t), [0])
|
||||
insert_0_op.run()
|
||||
self.assertEqual(size_t.eval(), [0])
|
||||
self.assertEqual(incomplete_t.eval(), [3])
|
||||
self.assertEqual(self.evaluate(size_t), [0])
|
||||
self.assertEqual(self.evaluate(incomplete_t), [3])
|
||||
close_op.run()
|
||||
|
||||
# This op should fail because the barrier is closed.
|
||||
@ -416,8 +418,8 @@ class BarrierTest(test.TestCase):
|
||||
# This op should succeed because the barrier has not canceled
|
||||
# pending enqueues
|
||||
insert_1_op.run()
|
||||
self.assertEqual(size_t.eval(), [3])
|
||||
self.assertEqual(incomplete_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(size_t), [3])
|
||||
self.assertEqual(self.evaluate(incomplete_t), [0])
|
||||
|
||||
# This op should fail because the barrier is closed.
|
||||
with self.assertRaisesOpError("is closed"):
|
||||
@ -462,11 +464,11 @@ class BarrierTest(test.TestCase):
|
||||
take_t = b.take_many(2)
|
||||
take_too_many_t = b.take_many(3)
|
||||
|
||||
self.assertEqual(size_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(size_t), [0])
|
||||
insert_0_op.run()
|
||||
insert_1_op.run()
|
||||
self.assertEqual(size_t.eval(), [2])
|
||||
self.assertEqual(incomplete_t.eval(), [1])
|
||||
self.assertEqual(self.evaluate(size_t), [2])
|
||||
self.assertEqual(self.evaluate(incomplete_t), [1])
|
||||
cancel_op.run()
|
||||
|
||||
# This op should fail because the queue is closed.
|
||||
@ -700,17 +702,17 @@ class BarrierTest(test.TestCase):
|
||||
(dtypes.float32,), shapes=(()), shared_name="b_a")
|
||||
b_a_2 = data_flow_ops.Barrier(
|
||||
(dtypes.int32,), shapes=(()), shared_name="b_a")
|
||||
b_a_1.barrier_ref.eval()
|
||||
self.evaluate(b_a_1.barrier_ref)
|
||||
with self.assertRaisesOpError("component types"):
|
||||
b_a_2.barrier_ref.eval()
|
||||
self.evaluate(b_a_2.barrier_ref)
|
||||
|
||||
b_b_1 = data_flow_ops.Barrier(
|
||||
(dtypes.float32,), shapes=(()), shared_name="b_b")
|
||||
b_b_2 = data_flow_ops.Barrier(
|
||||
(dtypes.float32, dtypes.int32), shapes=((), ()), shared_name="b_b")
|
||||
b_b_1.barrier_ref.eval()
|
||||
self.evaluate(b_b_1.barrier_ref)
|
||||
with self.assertRaisesOpError("component types"):
|
||||
b_b_2.barrier_ref.eval()
|
||||
self.evaluate(b_b_2.barrier_ref)
|
||||
|
||||
b_c_1 = data_flow_ops.Barrier(
|
||||
(dtypes.float32, dtypes.float32),
|
||||
@ -718,9 +720,9 @@ class BarrierTest(test.TestCase):
|
||||
shared_name="b_c")
|
||||
b_c_2 = data_flow_ops.Barrier(
|
||||
(dtypes.float32, dtypes.float32), shared_name="b_c")
|
||||
b_c_1.barrier_ref.eval()
|
||||
self.evaluate(b_c_1.barrier_ref)
|
||||
with self.assertRaisesOpError("component shapes"):
|
||||
b_c_2.barrier_ref.eval()
|
||||
self.evaluate(b_c_2.barrier_ref)
|
||||
|
||||
b_d_1 = data_flow_ops.Barrier(
|
||||
(dtypes.float32, dtypes.float32), shapes=((), ()), shared_name="b_d")
|
||||
@ -728,9 +730,9 @@ class BarrierTest(test.TestCase):
|
||||
(dtypes.float32, dtypes.float32),
|
||||
shapes=((2, 2), (8,)),
|
||||
shared_name="b_d")
|
||||
b_d_1.barrier_ref.eval()
|
||||
self.evaluate(b_d_1.barrier_ref)
|
||||
with self.assertRaisesOpError("component shapes"):
|
||||
b_d_2.barrier_ref.eval()
|
||||
self.evaluate(b_d_2.barrier_ref)
|
||||
|
||||
b_e_1 = data_flow_ops.Barrier(
|
||||
(dtypes.float32, dtypes.float32),
|
||||
@ -740,9 +742,9 @@ class BarrierTest(test.TestCase):
|
||||
(dtypes.float32, dtypes.float32),
|
||||
shapes=((2, 5), (8,)),
|
||||
shared_name="b_e")
|
||||
b_e_1.barrier_ref.eval()
|
||||
self.evaluate(b_e_1.barrier_ref)
|
||||
with self.assertRaisesOpError("component shapes"):
|
||||
b_e_2.barrier_ref.eval()
|
||||
self.evaluate(b_e_2.barrier_ref)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -89,7 +89,7 @@ class BatchToSpaceErrorHandlingTest(test.TestCase, PythonOpImpl):
|
||||
block_size = 0
|
||||
with self.assertRaises(ValueError):
|
||||
out_tf = self.batch_to_space(x_np, crops, block_size)
|
||||
out_tf.eval()
|
||||
self.evaluate(out_tf)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBlockSizeOne(self):
|
||||
@ -109,7 +109,7 @@ class BatchToSpaceErrorHandlingTest(test.TestCase, PythonOpImpl):
|
||||
block_size = 10
|
||||
with self.assertRaises(ValueError):
|
||||
out_tf = self.batch_to_space(x_np, crops, block_size)
|
||||
out_tf.eval()
|
||||
self.evaluate(out_tf)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBlockSizeSquaredNotDivisibleBatch(self):
|
||||
|
@ -381,8 +381,8 @@ class ClipTest(test.TestCase):
|
||||
np_ans_1 = [0.8, -1.6]
|
||||
|
||||
ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm)
|
||||
tf_ans_1 = ans[0].eval()
|
||||
tf_ans_2 = ans[1].values.eval()
|
||||
tf_ans_1 = self.evaluate(ans[0])
|
||||
tf_ans_2 = self.evaluate(ans[1].values)
|
||||
tf_norm = self.evaluate(norm)
|
||||
|
||||
self.assertAllClose(tf_norm, 5.0)
|
||||
|
@ -1254,7 +1254,7 @@ class CondV2CollectionTest(test.TestCase):
|
||||
return math_ops.add(x_const, y_const)
|
||||
|
||||
cnd = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
|
||||
self.assertEqual(cnd.eval(), 7)
|
||||
self.assertEqual(self.evaluate(cnd), 7)
|
||||
|
||||
def testCollectionTensorValueAccessInCond(self):
|
||||
"""Read tensors from collections inside of cond_v2 & use them."""
|
||||
@ -1271,7 +1271,7 @@ class CondV2CollectionTest(test.TestCase):
|
||||
return math_ops.add(x_read, y_read)
|
||||
|
||||
cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
|
||||
self.assertEqual(cnd.eval(), 7)
|
||||
self.assertEqual(self.evaluate(cnd), 7)
|
||||
|
||||
def testCollectionIntValueWriteInCond(self):
|
||||
"""Make sure Int writes to collections work inside of cond_v2."""
|
||||
@ -1289,7 +1289,7 @@ class CondV2CollectionTest(test.TestCase):
|
||||
return math_ops.mul(x, z)
|
||||
|
||||
cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
|
||||
self.assertEqual(cnd.eval(), 14)
|
||||
self.assertEqual(self.evaluate(cnd), 14)
|
||||
|
||||
read_z_collection = ops.get_collection("z")
|
||||
self.assertEqual(read_z_collection, [7])
|
||||
@ -1363,11 +1363,11 @@ class CondV2ContainerTest(test.TestCase):
|
||||
with ops.container("l1"):
|
||||
cnd_true = cond_v2.cond_v2(
|
||||
constant_op.constant(True), true_fn, false_fn)
|
||||
self.assertEqual(cnd_true.eval(), 2)
|
||||
self.assertEqual(self.evaluate(cnd_true), 2)
|
||||
|
||||
cnd_false = cond_v2.cond_v2(
|
||||
constant_op.constant(False), true_fn, false_fn)
|
||||
self.assertEqual(cnd_false.eval(), 6)
|
||||
self.assertEqual(self.evaluate(cnd_false), 6)
|
||||
|
||||
v4 = variables.Variable([3])
|
||||
q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
|
||||
|
@ -53,7 +53,7 @@ class DepthToSpaceTest(test.TestCase):
|
||||
with self.assertRaisesRegex(
|
||||
errors_impl.InvalidArgumentError,
|
||||
"No OpKernel was registered to support Op 'DepthToSpace'"):
|
||||
output_nhwc.eval()
|
||||
self.evaluate(output_nhwc)
|
||||
|
||||
if test.is_gpu_available():
|
||||
with self.cached_session(use_gpu=True):
|
||||
|
@ -809,7 +809,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
|
||||
initializer=initializer))
|
||||
for w in embedding_weights:
|
||||
self.evaluate(w.initializer)
|
||||
embedding_weights = [w.eval() for w in embedding_weights]
|
||||
embedding_weights = [self.evaluate(w) for w in embedding_weights]
|
||||
return embedding_weights
|
||||
|
||||
def _ids_and_weights_2d(self):
|
||||
|
@ -456,7 +456,7 @@ class UnconvertedFIFOQueueTests(test.TestCase):
|
||||
|
||||
dequeued_elems = []
|
||||
for _ in dequeue_counts:
|
||||
dequeued_elems.extend(dequeued_t.eval())
|
||||
dequeued_elems.extend(self.evaluate(dequeued_t))
|
||||
self.assertEqual(elems, dequeued_elems)
|
||||
|
||||
def testDequeueFromClosedQueue(self):
|
||||
@ -485,7 +485,7 @@ class UnconvertedFIFOQueueTests(test.TestCase):
|
||||
|
||||
enqueue_op.run()
|
||||
for _ in range(500):
|
||||
self.assertEqual(size_t.eval(), [1])
|
||||
self.assertEqual(self.evaluate(size_t), [1])
|
||||
|
||||
def testSharedQueueSameSession(self):
|
||||
with self.cached_session():
|
||||
@ -499,23 +499,23 @@ class UnconvertedFIFOQueueTests(test.TestCase):
|
||||
q1_size_t = q1.size()
|
||||
q2_size_t = q2.size()
|
||||
|
||||
self.assertEqual(q1_size_t.eval(), [1])
|
||||
self.assertEqual(q2_size_t.eval(), [1])
|
||||
self.assertEqual(self.evaluate(q1_size_t), [1])
|
||||
self.assertEqual(self.evaluate(q2_size_t), [1])
|
||||
|
||||
self.assertEqual(q2.dequeue().eval(), [10.0])
|
||||
|
||||
self.assertEqual(q1_size_t.eval(), [0])
|
||||
self.assertEqual(q2_size_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(q1_size_t), [0])
|
||||
self.assertEqual(self.evaluate(q2_size_t), [0])
|
||||
|
||||
q2.enqueue((20.0,)).run()
|
||||
|
||||
self.assertEqual(q1_size_t.eval(), [1])
|
||||
self.assertEqual(q2_size_t.eval(), [1])
|
||||
self.assertEqual(self.evaluate(q1_size_t), [1])
|
||||
self.assertEqual(self.evaluate(q2_size_t), [1])
|
||||
|
||||
self.assertEqual(q1.dequeue().eval(), [20.0])
|
||||
|
||||
self.assertEqual(q1_size_t.eval(), [0])
|
||||
self.assertEqual(q2_size_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(q1_size_t), [0])
|
||||
self.assertEqual(self.evaluate(q2_size_t), [0])
|
||||
|
||||
def testIncompatibleSharedQueueErrors(self):
|
||||
with self.cached_session():
|
||||
@ -796,7 +796,7 @@ class FIFOQueueParallelTests(test.TestCase):
|
||||
# Dequeue every element using a single thread.
|
||||
results = []
|
||||
for _ in xrange(len(elems)):
|
||||
results.append(dequeued_t.eval())
|
||||
results.append(self.evaluate(dequeued_t))
|
||||
self.assertItemsEqual(elems, results)
|
||||
|
||||
def testParallelDequeue(self):
|
||||
@ -906,27 +906,28 @@ class FIFOQueueParallelTests(test.TestCase):
|
||||
|
||||
# The enqueue should start and then block.
|
||||
results = []
|
||||
results.append(deq.eval()) # Will only complete after the enqueue starts.
|
||||
results.append(
|
||||
self.evaluate(deq)) # Will only complete after the enqueue starts.
|
||||
self.assertEqual(len(enq_done), 1)
|
||||
self.assertEqual(self.evaluate(size_op), 5)
|
||||
|
||||
for _ in range(3):
|
||||
results.append(deq.eval())
|
||||
results.append(self.evaluate(deq))
|
||||
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(len(enq_done), 1)
|
||||
self.assertEqual(self.evaluate(size_op), 5)
|
||||
|
||||
# This dequeue will unblock the thread.
|
||||
results.append(deq.eval())
|
||||
results.append(self.evaluate(deq))
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(len(enq_done), 2)
|
||||
thread.join()
|
||||
|
||||
for i in range(5):
|
||||
self.assertEqual(size_op.eval(), 5 - i)
|
||||
results.append(deq.eval())
|
||||
self.assertEqual(size_op.eval(), 5 - i - 1)
|
||||
self.assertEqual(self.evaluate(size_op), 5 - i)
|
||||
results.append(self.evaluate(deq))
|
||||
self.assertEqual(self.evaluate(size_op), 5 - i - 1)
|
||||
|
||||
self.assertAllEqual(elem, results)
|
||||
|
||||
@ -1404,7 +1405,7 @@ class FIFOQueueParallelTests(test.TestCase):
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
|
||||
self.assertCountEqual(self.evaluate(dequeued_t), elems * 10)
|
||||
|
||||
def testParallelDequeueMany(self):
|
||||
# We need each thread to keep its own device stack or the device scopes
|
||||
|
@ -1039,7 +1039,7 @@ class PartitionedCallTest(test.TestCase):
|
||||
output, = functional_ops.partitioned_call(
|
||||
args=[constant_op.constant(1.),
|
||||
constant_op.constant(2.)], f=Body)
|
||||
self.assertEqual(output.eval(), 12.)
|
||||
self.assertEqual(self.evaluate(output), 12.)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicMultiDeviceGPU(self):
|
||||
|
@ -1001,7 +1001,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
|
||||
shape=shape,
|
||||
initializer=init_ops.convolutional_orthogonal_1d)
|
||||
self.evaluate(x.initializer)
|
||||
y = np.sum(x.eval(), axis=0)
|
||||
y = np.sum(self.evaluate(x), axis=0)
|
||||
determinant = np.linalg.det(y)
|
||||
value += determinant
|
||||
abs_value += np.abs(determinant)
|
||||
@ -1230,7 +1230,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
|
||||
shape=shape,
|
||||
initializer=init_ops.convolutional_orthogonal_3d)
|
||||
self.evaluate(x.initializer)
|
||||
y = np.sum(x.eval(), axis=(0, 1, 2))
|
||||
y = np.sum(self.evaluate(x), axis=(0, 1, 2))
|
||||
determinant = np.linalg.det(y)
|
||||
value += determinant
|
||||
abs_value += np.abs(determinant)
|
||||
|
@ -42,7 +42,7 @@ class IoOpsTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
read = io_ops.read_file(temp.name)
|
||||
self.assertEqual([], read.get_shape())
|
||||
self.assertEqual(read.eval(), contents)
|
||||
self.assertEqual(self.evaluate(read), contents)
|
||||
os.remove(temp.name)
|
||||
|
||||
def testWriteFile(self):
|
||||
|
@ -1690,7 +1690,7 @@ class DenseHashTableOpTest(test.TestCase):
|
||||
[[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
|
||||
output = table.lookup(input_string)
|
||||
self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]],
|
||||
output.eval())
|
||||
self.evaluate(output))
|
||||
|
||||
@test_util.run_v1_only("Saver V1 only")
|
||||
def testVectorScalarSaveRestore(self):
|
||||
|
@ -132,7 +132,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
|
||||
loss = losses.softmax_cross_entropy(labels, logits)
|
||||
self.assertEqual('softmax_cross_entropy_loss/value', loss.op.name)
|
||||
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllWrong(self):
|
||||
@ -143,7 +143,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
loss = losses.softmax_cross_entropy(labels, logits)
|
||||
self.assertEqual(loss.op.name, 'softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 10.0, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithPythonScalarWeight(self):
|
||||
@ -225,7 +225,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
labels, logits, label_smoothing=label_smoothing)
|
||||
self.assertEqual(loss.op.name, 'softmax_cross_entropy_loss/value')
|
||||
expected_value = 400.0 * label_smoothing / 3.0
|
||||
self.assertAlmostEqual(loss.eval(), expected_value, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
|
||||
|
||||
|
||||
class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
@ -246,7 +246,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32)
|
||||
loss = losses.sparse_softmax_cross_entropy(labels, logits)
|
||||
self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testEagerNoMemoryLeaked(self):
|
||||
@ -263,7 +263,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int64)
|
||||
loss = losses.sparse_softmax_cross_entropy(labels, logits)
|
||||
self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrectNonColumnLabels(self):
|
||||
@ -273,7 +273,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
labels = constant_op.constant([0, 1, 2])
|
||||
loss = losses.sparse_softmax_cross_entropy(labels, logits)
|
||||
self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 0.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllWrongInt32Labels(self):
|
||||
@ -284,7 +284,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
loss = losses.sparse_softmax_cross_entropy(labels, logits)
|
||||
self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 10.0, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllWrongInt64Labels(self):
|
||||
@ -295,7 +295,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
loss = losses.sparse_softmax_cross_entropy(labels, logits)
|
||||
self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 10.0, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllWrongNonColumnLabels(self):
|
||||
@ -306,7 +306,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
loss = losses.sparse_softmax_cross_entropy(labels, logits)
|
||||
self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
|
||||
self.assertAlmostEqual(loss.eval(), 10.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 10.0, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonZeroLossWithPythonScalarWeight(self):
|
||||
@ -551,7 +551,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
|
||||
loss = losses.sigmoid_cross_entropy(labels, logits)
|
||||
self.assertEqual(logits.dtype, loss.dtype)
|
||||
self.assertEqual('sigmoid_cross_entropy_loss/value', loss.op.name)
|
||||
self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 600.0 / 9.0, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
|
||||
@ -630,7 +630,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
|
||||
self.assertEqual(logits.dtype, loss.dtype)
|
||||
self.assertEqual('sigmoid_cross_entropy_loss/value', loss.op.name)
|
||||
expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
|
||||
self.assertAlmostEqual(loss.eval(), expected_value, 3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
|
||||
@ -647,8 +647,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
|
||||
softmax_labels = constant_op.constant([[0, 1], [1, 0], [0, 1]])
|
||||
softmax_loss = losses.softmax_cross_entropy(
|
||||
softmax_labels, softmax_logits, label_smoothing=label_smoothing)
|
||||
self.assertAlmostEqual(sigmoid_loss.eval(), self.evaluate(softmax_loss),
|
||||
3)
|
||||
self.assertAlmostEqual(
|
||||
self.evaluate(sigmoid_loss), self.evaluate(softmax_loss), 3)
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
|
@ -264,8 +264,8 @@ class MeanTest(test.TestCase):
|
||||
variables.local_variables_initializer().run()
|
||||
for mean_result in mean_results:
|
||||
mean, update_op = mean_result
|
||||
self.assertAlmostEqual(expected, update_op.eval())
|
||||
self.assertAlmostEqual(expected, mean.eval())
|
||||
self.assertAlmostEqual(expected, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(expected, self.evaluate(mean))
|
||||
|
||||
def _test_3d_weighted(self, values, weights):
|
||||
expected = (
|
||||
@ -275,8 +275,8 @@ class MeanTest(test.TestCase):
|
||||
mean, update_op = metrics.mean(values, weights=weights)
|
||||
with self.cached_session():
|
||||
variables.local_variables_initializer().run()
|
||||
self.assertAlmostEqual(expected, update_op.eval(), places=5)
|
||||
self.assertAlmostEqual(expected, mean.eval(), places=5)
|
||||
self.assertAlmostEqual(expected, self.evaluate(update_op), places=5)
|
||||
self.assertAlmostEqual(expected, self.evaluate(mean), places=5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test1x1x1Weighted(self):
|
||||
@ -615,9 +615,9 @@ class AccuracyTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_accuracy = accuracy.eval()
|
||||
initial_accuracy = self.evaluate(accuracy)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_accuracy, accuracy.eval())
|
||||
self.assertEqual(initial_accuracy, self.evaluate(accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleUpdates(self):
|
||||
@ -646,7 +646,7 @@ class AccuracyTest(test.TestCase):
|
||||
for _ in xrange(3):
|
||||
self.evaluate(update_op)
|
||||
self.assertEqual(0.5, self.evaluate(update_op))
|
||||
self.assertEqual(0.5, accuracy.eval())
|
||||
self.assertEqual(0.5, self.evaluate(accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testEffectivelyEquivalentSizes(self):
|
||||
@ -656,8 +656,8 @@ class AccuracyTest(test.TestCase):
|
||||
accuracy, update_op = metrics.accuracy(labels, predictions)
|
||||
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(1.0, update_op.eval())
|
||||
self.assertEqual(1.0, accuracy.eval())
|
||||
self.assertEqual(1.0, self.evaluate(update_op))
|
||||
self.assertEqual(1.0, self.evaluate(accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testEffectivelyEquivalentSizesWithScalarWeight(self):
|
||||
@ -667,8 +667,8 @@ class AccuracyTest(test.TestCase):
|
||||
accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0)
|
||||
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(1.0, update_op.eval())
|
||||
self.assertEqual(1.0, accuracy.eval())
|
||||
self.assertEqual(1.0, self.evaluate(update_op))
|
||||
self.assertEqual(1.0, self.evaluate(accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testEffectivelyEquivalentSizesWithStaticShapedWeight(self):
|
||||
@ -685,8 +685,8 @@ class AccuracyTest(test.TestCase):
|
||||
# if streaming_accuracy does not flatten the weight, accuracy would be
|
||||
# 0.33333334 due to an intended broadcast of weight. Due to flattening,
|
||||
# it will be higher than .95
|
||||
self.assertGreater(update_op.eval(), .95)
|
||||
self.assertGreater(accuracy.eval(), .95)
|
||||
self.assertGreater(self.evaluate(update_op), .95)
|
||||
self.assertGreater(self.evaluate(accuracy), .95)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testEffectivelyEquivalentSizesWithDynamicallyShapedWeight(self):
|
||||
@ -746,7 +746,7 @@ class AccuracyTest(test.TestCase):
|
||||
for _ in xrange(3):
|
||||
self.evaluate(update_op)
|
||||
self.assertEqual(1.0, self.evaluate(update_op))
|
||||
self.assertEqual(1.0, accuracy.eval())
|
||||
self.assertEqual(1.0, self.evaluate(accuracy))
|
||||
|
||||
|
||||
class PrecisionTest(test.TestCase):
|
||||
@ -796,9 +796,9 @@ class PrecisionTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_precision = precision.eval()
|
||||
initial_precision = self.evaluate(precision)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_precision, precision.eval())
|
||||
self.assertEqual(initial_precision, self.evaluate(precision))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrect(self):
|
||||
@ -811,7 +811,7 @@ class PrecisionTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1.0, self.evaluate(update_op), 6)
|
||||
self.assertAlmostEqual(1.0, precision.eval(), 6)
|
||||
self.assertAlmostEqual(1.0, self.evaluate(precision), 6)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSomeCorrect_multipleInputDtypes(self):
|
||||
@ -824,8 +824,8 @@ class PrecisionTest(test.TestCase):
|
||||
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.5, update_op.eval())
|
||||
self.assertAlmostEqual(0.5, precision.eval())
|
||||
self.assertAlmostEqual(0.5, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(0.5, self.evaluate(precision))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted1d(self):
|
||||
@ -839,8 +839,8 @@ class PrecisionTest(test.TestCase):
|
||||
weighted_tp = 2.0 + 5.0
|
||||
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
|
||||
expected_precision = weighted_tp / weighted_positives
|
||||
self.assertAlmostEqual(expected_precision, update_op.eval())
|
||||
self.assertAlmostEqual(expected_precision, precision.eval())
|
||||
self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(expected_precision, self.evaluate(precision))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeightedScalar_placeholders(self):
|
||||
@ -897,8 +897,8 @@ class PrecisionTest(test.TestCase):
|
||||
weighted_tp = 3.0 + 4.0
|
||||
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
|
||||
expected_precision = weighted_tp / weighted_positives
|
||||
self.assertAlmostEqual(expected_precision, update_op.eval())
|
||||
self.assertAlmostEqual(expected_precision, precision.eval())
|
||||
self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(expected_precision, self.evaluate(precision))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted2d_placeholders(self):
|
||||
@ -934,7 +934,7 @@ class PrecisionTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate(update_op)
|
||||
self.assertAlmostEqual(0, precision.eval())
|
||||
self.assertAlmostEqual(0, self.evaluate(precision))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeroTrueAndFalsePositivesGivesZeroPrecision(self):
|
||||
@ -945,7 +945,7 @@ class PrecisionTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate(update_op)
|
||||
self.assertEqual(0.0, precision.eval())
|
||||
self.assertEqual(0.0, self.evaluate(precision))
|
||||
|
||||
|
||||
class RecallTest(test.TestCase):
|
||||
@ -996,9 +996,9 @@ class RecallTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_recall = recall.eval()
|
||||
initial_recall = self.evaluate(recall)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_recall, recall.eval())
|
||||
self.assertEqual(initial_recall, self.evaluate(recall))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrect(self):
|
||||
@ -1011,7 +1011,7 @@ class RecallTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate(update_op)
|
||||
self.assertAlmostEqual(1.0, recall.eval(), 6)
|
||||
self.assertAlmostEqual(1.0, self.evaluate(recall), 6)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSomeCorrect_multipleInputDtypes(self):
|
||||
@ -1024,8 +1024,8 @@ class RecallTest(test.TestCase):
|
||||
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.5, update_op.eval())
|
||||
self.assertAlmostEqual(0.5, recall.eval())
|
||||
self.assertAlmostEqual(0.5, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(0.5, self.evaluate(recall))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted1d(self):
|
||||
@ -1039,8 +1039,8 @@ class RecallTest(test.TestCase):
|
||||
weighted_tp = 2.0 + 5.0
|
||||
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
|
||||
expected_precision = weighted_tp / weighted_t
|
||||
self.assertAlmostEqual(expected_precision, update_op.eval())
|
||||
self.assertAlmostEqual(expected_precision, recall.eval())
|
||||
self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(expected_precision, self.evaluate(recall))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted2d(self):
|
||||
@ -1054,8 +1054,8 @@ class RecallTest(test.TestCase):
|
||||
weighted_tp = 3.0 + 1.0
|
||||
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
|
||||
expected_precision = weighted_tp / weighted_t
|
||||
self.assertAlmostEqual(expected_precision, update_op.eval())
|
||||
self.assertAlmostEqual(expected_precision, recall.eval())
|
||||
self.assertAlmostEqual(expected_precision, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(expected_precision, self.evaluate(recall))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllIncorrect(self):
|
||||
@ -1068,7 +1068,7 @@ class RecallTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate(update_op)
|
||||
self.assertEqual(0, recall.eval())
|
||||
self.assertEqual(0, self.evaluate(recall))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeroTruePositivesAndFalseNegativesGivesZeroRecall(self):
|
||||
@ -1079,7 +1079,7 @@ class RecallTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate(update_op)
|
||||
self.assertEqual(0, recall.eval())
|
||||
self.assertEqual(0, self.evaluate(recall))
|
||||
|
||||
|
||||
class AUCTest(test.TestCase):
|
||||
@ -1128,9 +1128,9 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_auc = auc.eval()
|
||||
initial_auc = self.evaluate(auc)
|
||||
for _ in range(10):
|
||||
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
|
||||
self.assertAlmostEqual(initial_auc, self.evaluate(auc), 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrect(self):
|
||||
@ -1147,7 +1147,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(1, self.evaluate(update_op))
|
||||
|
||||
self.assertEqual(1, auc.eval())
|
||||
self.assertEqual(1, self.evaluate(auc))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSomeCorrect_multipleLabelDtypes(self):
|
||||
@ -1163,7 +1163,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.5, self.evaluate(update_op))
|
||||
|
||||
self.assertAlmostEqual(0.5, auc.eval())
|
||||
self.assertAlmostEqual(0.5, self.evaluate(auc))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted1d(self):
|
||||
@ -1177,7 +1177,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.5, self.evaluate(update_op), 5)
|
||||
|
||||
self.assertAlmostEqual(0.5, auc.eval(), 5)
|
||||
self.assertAlmostEqual(0.5, self.evaluate(auc), 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted2d(self):
|
||||
@ -1191,7 +1191,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.7, self.evaluate(update_op), 5)
|
||||
|
||||
self.assertAlmostEqual(0.7, auc.eval(), 5)
|
||||
self.assertAlmostEqual(0.7, self.evaluate(auc), 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testManualThresholds(self):
|
||||
@ -1216,10 +1216,10 @@ class AUCTest(test.TestCase):
|
||||
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.875, self.evaluate(default_update_op), 3)
|
||||
self.assertAlmostEqual(0.875, default_auc.eval(), 3)
|
||||
self.assertAlmostEqual(0.875, self.evaluate(default_auc), 3)
|
||||
|
||||
self.assertAlmostEqual(0.75, self.evaluate(manual_update_op), 3)
|
||||
self.assertAlmostEqual(0.75, manual_auc.eval(), 3)
|
||||
self.assertAlmostEqual(0.75, self.evaluate(manual_auc), 3)
|
||||
|
||||
# Regarding the AUC-PR tests: note that the preferred method when
|
||||
# calculating AUC-PR is summation_method='careful_interpolation'.
|
||||
@ -1236,7 +1236,7 @@ class AUCTest(test.TestCase):
|
||||
# expected ~= 0.79726744594
|
||||
expected = 1 - math.log(1.5) / 2
|
||||
self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testCorrectAnotherAUCPRSpecialCase(self):
|
||||
@ -1253,7 +1253,7 @@ class AUCTest(test.TestCase):
|
||||
# expected ~= 0.61350593198
|
||||
expected = (2.5 - 2 * math.log(4./3) - 0.25 * math.log(7./5)) / 3
|
||||
self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testThirdCorrectAUCPRSpecialCase(self):
|
||||
@ -1270,7 +1270,7 @@ class AUCTest(test.TestCase):
|
||||
# expected ~= 0.90410597584
|
||||
expected = 1 - math.log(4./3) / 3
|
||||
self.assertAlmostEqual(expected, self.evaluate(update_op), delta=1e-3)
|
||||
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(expected, self.evaluate(auc), delta=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testIncorrectAUCPRSpecialCase(self):
|
||||
@ -1284,7 +1284,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.79166, self.evaluate(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.79166, self.evaluate(auc), delta=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAnotherIncorrectAUCPRSpecialCase(self):
|
||||
@ -1300,7 +1300,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.610317, self.evaluate(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.610317, self.evaluate(auc), delta=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testThirdIncorrectAUCPRSpecialCase(self):
|
||||
@ -1316,7 +1316,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.90277, self.evaluate(update_op), delta=1e-3)
|
||||
|
||||
self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3)
|
||||
self.assertAlmostEqual(0.90277, self.evaluate(auc), delta=1e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllIncorrect(self):
|
||||
@ -1330,7 +1330,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0, self.evaluate(update_op))
|
||||
|
||||
self.assertAlmostEqual(0, auc.eval())
|
||||
self.assertAlmostEqual(0, self.evaluate(auc))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
|
||||
@ -1342,7 +1342,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1, self.evaluate(update_op), 6)
|
||||
|
||||
self.assertAlmostEqual(1, auc.eval(), 6)
|
||||
self.assertAlmostEqual(1, self.evaluate(auc), 6)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
|
||||
@ -1354,7 +1354,7 @@ class AUCTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1, self.evaluate(update_op), 6)
|
||||
|
||||
self.assertAlmostEqual(1, auc.eval(), 6)
|
||||
self.assertAlmostEqual(1, self.evaluate(auc), 6)
|
||||
|
||||
def np_auc(self, predictions, labels, weights):
|
||||
"""Computes the AUC explicitly using Numpy.
|
||||
@ -1431,7 +1431,7 @@ class AUCTest(test.TestCase):
|
||||
# Since this is only approximate, we can't expect a 6 digits match.
|
||||
# Although with higher number of samples/thresholds we should see the
|
||||
# accuracy improving
|
||||
self.assertAlmostEqual(expected_auc, auc.eval(), 2)
|
||||
self.assertAlmostEqual(expected_auc, self.evaluate(auc), 2)
|
||||
|
||||
|
||||
class SpecificityAtSensitivityTest(test.TestCase):
|
||||
@ -1489,9 +1489,10 @@ class SpecificityAtSensitivityTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_specificity = specificity.eval()
|
||||
initial_specificity = self.evaluate(specificity)
|
||||
for _ in range(10):
|
||||
self.assertAlmostEqual(initial_specificity, specificity.eval(), 5)
|
||||
self.assertAlmostEqual(initial_specificity, self.evaluate(specificity),
|
||||
5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrect(self):
|
||||
@ -1505,7 +1506,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(1, self.evaluate(update_op))
|
||||
self.assertEqual(1, specificity.eval())
|
||||
self.assertEqual(1, self.evaluate(specificity))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSomeCorrectHighSensitivity(self):
|
||||
@ -1521,7 +1522,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1.0, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(1.0, specificity.eval())
|
||||
self.assertAlmostEqual(1.0, self.evaluate(specificity))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSomeCorrectLowSensitivity(self):
|
||||
@ -1538,7 +1539,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
|
||||
self.assertAlmostEqual(0.6, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(0.6, specificity.eval())
|
||||
self.assertAlmostEqual(0.6, self.evaluate(specificity))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted1d_multipleLabelDtypes(self):
|
||||
@ -1558,7 +1559,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
|
||||
self.assertAlmostEqual(0.6, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(0.6, specificity.eval())
|
||||
self.assertAlmostEqual(0.6, self.evaluate(specificity))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted2d(self):
|
||||
@ -1577,7 +1578,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
|
||||
self.assertAlmostEqual(8.0 / 15.0, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(8.0 / 15.0, specificity.eval())
|
||||
self.assertAlmostEqual(8.0 / 15.0, self.evaluate(specificity))
|
||||
|
||||
|
||||
class SensitivityAtSpecificityTest(test.TestCase):
|
||||
@ -1635,9 +1636,10 @@ class SensitivityAtSpecificityTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_sensitivity = sensitivity.eval()
|
||||
initial_sensitivity = self.evaluate(sensitivity)
|
||||
for _ in range(10):
|
||||
self.assertAlmostEqual(initial_sensitivity, sensitivity.eval(), 5)
|
||||
self.assertAlmostEqual(initial_sensitivity, self.evaluate(sensitivity),
|
||||
5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrect(self):
|
||||
@ -1651,7 +1653,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1.0, self.evaluate(update_op), 6)
|
||||
self.assertAlmostEqual(1.0, specificity.eval(), 6)
|
||||
self.assertAlmostEqual(1.0, self.evaluate(specificity), 6)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSomeCorrectHighSpecificity(self):
|
||||
@ -1667,7 +1669,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.8, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(0.8, specificity.eval())
|
||||
self.assertAlmostEqual(0.8, self.evaluate(specificity))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSomeCorrectLowSpecificity(self):
|
||||
@ -1683,7 +1685,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.6, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(0.6, specificity.eval())
|
||||
self.assertAlmostEqual(0.6, self.evaluate(specificity))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeighted_multipleLabelDtypes(self):
|
||||
@ -1703,7 +1705,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(0.675, self.evaluate(update_op))
|
||||
self.assertAlmostEqual(0.675, specificity.eval())
|
||||
self.assertAlmostEqual(0.675, self.evaluate(specificity))
|
||||
|
||||
|
||||
# TODO(nsilberman): Break this up into two sets of tests.
|
||||
@ -1771,8 +1773,8 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
|
||||
# Run several updates, then verify idempotency.
|
||||
self.evaluate([prec_op, rec_op])
|
||||
initial_prec = prec.eval()
|
||||
initial_rec = rec.eval()
|
||||
initial_prec = self.evaluate(prec)
|
||||
initial_rec = self.evaluate(rec)
|
||||
for _ in range(10):
|
||||
self.evaluate([prec_op, rec_op])
|
||||
self.assertAllClose(initial_prec, prec)
|
||||
@ -1795,8 +1797,8 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate([prec_op, rec_op])
|
||||
|
||||
self.assertEqual(1, prec.eval())
|
||||
self.assertEqual(1, rec.eval())
|
||||
self.assertEqual(1, self.evaluate(prec))
|
||||
self.assertEqual(1, self.evaluate(rec))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSomeCorrect_multipleLabelDtypes(self):
|
||||
@ -1816,8 +1818,8 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate([prec_op, rec_op])
|
||||
|
||||
self.assertAlmostEqual(0.5, prec.eval())
|
||||
self.assertAlmostEqual(0.5, rec.eval())
|
||||
self.assertAlmostEqual(0.5, self.evaluate(prec))
|
||||
self.assertAlmostEqual(0.5, self.evaluate(rec))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllIncorrect(self):
|
||||
@ -1835,8 +1837,8 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate([prec_op, rec_op])
|
||||
|
||||
self.assertAlmostEqual(0, prec.eval())
|
||||
self.assertAlmostEqual(0, rec.eval())
|
||||
self.assertAlmostEqual(0, self.evaluate(prec))
|
||||
self.assertAlmostEqual(0, self.evaluate(rec))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeights1d(self):
|
||||
@ -1864,10 +1866,10 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate([prec_op, rec_op])
|
||||
|
||||
self.assertAlmostEqual(1.0, prec_low.eval(), places=5)
|
||||
self.assertAlmostEqual(0.0, prec_high.eval(), places=5)
|
||||
self.assertAlmostEqual(1.0, rec_low.eval(), places=5)
|
||||
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
|
||||
self.assertAlmostEqual(1.0, self.evaluate(prec_low), places=5)
|
||||
self.assertAlmostEqual(0.0, self.evaluate(prec_high), places=5)
|
||||
self.assertAlmostEqual(1.0, self.evaluate(rec_low), places=5)
|
||||
self.assertAlmostEqual(0.0, self.evaluate(rec_high), places=5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWeights2d(self):
|
||||
@ -1895,10 +1897,10 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate([prec_op, rec_op])
|
||||
|
||||
self.assertAlmostEqual(1.0, prec_low.eval(), places=5)
|
||||
self.assertAlmostEqual(0.0, prec_high.eval(), places=5)
|
||||
self.assertAlmostEqual(1.0, rec_low.eval(), places=5)
|
||||
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
|
||||
self.assertAlmostEqual(1.0, self.evaluate(prec_low), places=5)
|
||||
self.assertAlmostEqual(0.0, self.evaluate(prec_high), places=5)
|
||||
self.assertAlmostEqual(1.0, self.evaluate(rec_low), places=5)
|
||||
self.assertAlmostEqual(0.0, self.evaluate(rec_high), places=5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testExtremeThresholds(self):
|
||||
@ -1920,10 +1922,10 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate([prec_op, rec_op])
|
||||
|
||||
self.assertAlmostEqual(0.75, prec_low.eval())
|
||||
self.assertAlmostEqual(0.0, prec_high.eval())
|
||||
self.assertAlmostEqual(1.0, rec_low.eval())
|
||||
self.assertAlmostEqual(0.0, rec_high.eval())
|
||||
self.assertAlmostEqual(0.75, self.evaluate(prec_low))
|
||||
self.assertAlmostEqual(0.0, self.evaluate(prec_high))
|
||||
self.assertAlmostEqual(1.0, self.evaluate(rec_low))
|
||||
self.assertAlmostEqual(0.0, self.evaluate(rec_high))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeroLabelsPredictions(self):
|
||||
@ -1939,8 +1941,8 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.evaluate([prec_op, rec_op])
|
||||
|
||||
self.assertAlmostEqual(0, prec.eval(), 6)
|
||||
self.assertAlmostEqual(0, rec.eval(), 6)
|
||||
self.assertAlmostEqual(0, self.evaluate(prec), 6)
|
||||
self.assertAlmostEqual(0, self.evaluate(rec), 6)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWithMultipleUpdates(self):
|
||||
@ -2011,8 +2013,8 @@ class PrecisionRecallThresholdsTest(test.TestCase):
|
||||
# Since this is only approximate, we can't expect a 6 digits match.
|
||||
# Although with higher number of samples/thresholds we should see the
|
||||
# accuracy improving
|
||||
self.assertAlmostEqual(expected_prec, prec.eval(), 2)
|
||||
self.assertAlmostEqual(expected_rec, rec.eval(), 2)
|
||||
self.assertAlmostEqual(expected_prec, self.evaluate(prec), 2)
|
||||
self.assertAlmostEqual(expected_rec, self.evaluate(rec), 2)
|
||||
|
||||
|
||||
def _test_precision_at_k(predictions,
|
||||
@ -3001,9 +3003,9 @@ class MeanAbsoluteErrorTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_error = error.eval()
|
||||
initial_error = self.evaluate(error)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_error, error.eval())
|
||||
self.assertEqual(initial_error, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithErrorAndWeights(self):
|
||||
@ -3018,7 +3020,7 @@ class MeanAbsoluteErrorTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(3, self.evaluate(update_op))
|
||||
self.assertEqual(3, error.eval())
|
||||
self.assertEqual(3, self.evaluate(error))
|
||||
|
||||
|
||||
class MeanRelativeErrorTest(test.TestCase):
|
||||
@ -3071,9 +3073,9 @@ class MeanRelativeErrorTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_error = error.eval()
|
||||
initial_error = self.evaluate(error)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_error, error.eval())
|
||||
self.assertEqual(initial_error, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateNormalizedByLabels(self):
|
||||
@ -3092,7 +3094,7 @@ class MeanRelativeErrorTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(expected_error, self.evaluate(update_op))
|
||||
self.assertEqual(expected_error, error.eval())
|
||||
self.assertEqual(expected_error, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateNormalizedByZeros(self):
|
||||
@ -3109,7 +3111,7 @@ class MeanRelativeErrorTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(0.0, self.evaluate(update_op))
|
||||
self.assertEqual(0.0, error.eval())
|
||||
self.assertEqual(0.0, self.evaluate(error))
|
||||
|
||||
|
||||
class MeanSquaredErrorTest(test.TestCase):
|
||||
@ -3156,9 +3158,9 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_error = error.eval()
|
||||
initial_error = self.evaluate(error)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_error, error.eval())
|
||||
self.assertEqual(initial_error, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateZeroError(self):
|
||||
@ -3170,7 +3172,7 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(0, self.evaluate(update_op))
|
||||
self.assertEqual(0, error.eval())
|
||||
self.assertEqual(0, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithError(self):
|
||||
@ -3184,7 +3186,7 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(6, self.evaluate(update_op))
|
||||
self.assertEqual(6, error.eval())
|
||||
self.assertEqual(6, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithErrorAndWeights(self):
|
||||
@ -3199,7 +3201,7 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(13, self.evaluate(update_op))
|
||||
self.assertEqual(13, error.eval())
|
||||
self.assertEqual(13, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleBatchesOfSizeOne(self):
|
||||
@ -3224,7 +3226,7 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
self.assertAlmostEqual(208.0 / 6, self.evaluate(update_op), 5)
|
||||
|
||||
self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
|
||||
self.assertAlmostEqual(208.0 / 6, self.evaluate(error), 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMetricsComputedConcurrently(self):
|
||||
@ -3294,8 +3296,8 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
self.evaluate([ma_update_op, ms_update_op])
|
||||
self.evaluate([ma_update_op, ms_update_op])
|
||||
|
||||
self.assertAlmostEqual(32.0 / 6, mae.eval(), 5)
|
||||
self.assertAlmostEqual(208.0 / 6, mse.eval(), 5)
|
||||
self.assertAlmostEqual(32.0 / 6, self.evaluate(mae), 5)
|
||||
self.assertAlmostEqual(208.0 / 6, self.evaluate(mse), 5)
|
||||
|
||||
|
||||
class RootMeanSquaredErrorTest(test.TestCase):
|
||||
@ -3343,9 +3345,9 @@ class RootMeanSquaredErrorTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_error = error.eval()
|
||||
initial_error = self.evaluate(error)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_error, error.eval())
|
||||
self.assertEqual(initial_error, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateZeroError(self):
|
||||
@ -3359,7 +3361,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(0, self.evaluate(update_op))
|
||||
|
||||
self.assertEqual(0, rmse.eval())
|
||||
self.assertEqual(0, self.evaluate(rmse))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithError(self):
|
||||
@ -3372,8 +3374,8 @@ class RootMeanSquaredErrorTest(test.TestCase):
|
||||
rmse, update_op = metrics.root_mean_squared_error(labels, predictions)
|
||||
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(math.sqrt(6), update_op.eval(), 5)
|
||||
self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
|
||||
self.assertAlmostEqual(math.sqrt(6), self.evaluate(update_op), 5)
|
||||
self.assertAlmostEqual(math.sqrt(6), self.evaluate(rmse), 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithErrorAndWeights(self):
|
||||
@ -3390,7 +3392,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(math.sqrt(13), self.evaluate(update_op))
|
||||
|
||||
self.assertAlmostEqual(math.sqrt(13), rmse.eval(), 5)
|
||||
self.assertAlmostEqual(math.sqrt(13), self.evaluate(rmse), 5)
|
||||
|
||||
|
||||
def _reweight(predictions, labels, weights):
|
||||
@ -3448,9 +3450,9 @@ class MeanCosineDistanceTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_error = error.eval()
|
||||
initial_error = self.evaluate(error)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_error, error.eval())
|
||||
self.assertEqual(initial_error, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateZeroError(self):
|
||||
@ -3466,7 +3468,7 @@ class MeanCosineDistanceTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(0, self.evaluate(update_op))
|
||||
self.assertEqual(0, error.eval())
|
||||
self.assertEqual(0, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithError1(self):
|
||||
@ -3483,7 +3485,7 @@ class MeanCosineDistanceTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1, self.evaluate(update_op), 5)
|
||||
self.assertAlmostEqual(1, error.eval(), 5)
|
||||
self.assertAlmostEqual(1, self.evaluate(error), 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithError2(self):
|
||||
@ -3505,7 +3507,7 @@ class MeanCosineDistanceTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAlmostEqual(1.0, self.evaluate(update_op), 5)
|
||||
self.assertAlmostEqual(1.0, error.eval(), 5)
|
||||
self.assertAlmostEqual(1.0, self.evaluate(error), 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithErrorAndWeights1(self):
|
||||
@ -3525,7 +3527,7 @@ class MeanCosineDistanceTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(0, self.evaluate(update_op))
|
||||
self.assertEqual(0, error.eval())
|
||||
self.assertEqual(0, self.evaluate(error))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSingleUpdateWithErrorAndWeights2(self):
|
||||
@ -3544,8 +3546,8 @@ class MeanCosineDistanceTest(test.TestCase):
|
||||
|
||||
with self.cached_session():
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(1.5, update_op.eval())
|
||||
self.assertEqual(1.5, error.eval())
|
||||
self.assertEqual(1.5, self.evaluate(update_op))
|
||||
self.assertEqual(1.5, self.evaluate(error))
|
||||
|
||||
|
||||
class PcntBelowThreshTest(test.TestCase):
|
||||
@ -3689,9 +3691,9 @@ class MeanIOUTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_mean_iou = mean_iou.eval()
|
||||
initial_mean_iou = self.evaluate(mean_iou)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_mean_iou, mean_iou.eval())
|
||||
self.assertEqual(initial_mean_iou, self.evaluate(mean_iou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleUpdates(self):
|
||||
@ -3723,7 +3725,7 @@ class MeanIOUTest(test.TestCase):
|
||||
for _ in range(5):
|
||||
self.evaluate(update_op)
|
||||
desired_output = np.mean([1.0 / 2.0, 1.0 / 4.0, 0.])
|
||||
self.assertEqual(desired_output, miou.eval())
|
||||
self.assertEqual(desired_output, self.evaluate(miou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleUpdatesWithWeights(self):
|
||||
@ -3769,7 +3771,7 @@ class MeanIOUTest(test.TestCase):
|
||||
for _ in range(6):
|
||||
self.evaluate(update_op)
|
||||
desired_output = np.mean([2.0 / 3.0, 1.0 / 2.0])
|
||||
self.assertAlmostEqual(desired_output, mean_iou.eval())
|
||||
self.assertAlmostEqual(desired_output, self.evaluate(mean_iou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleUpdatesWithMissingClass(self):
|
||||
@ -3806,7 +3808,7 @@ class MeanIOUTest(test.TestCase):
|
||||
for _ in range(5):
|
||||
self.evaluate(update_op)
|
||||
desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0])
|
||||
self.assertAlmostEqual(desired_output, miou.eval())
|
||||
self.assertAlmostEqual(desired_output, self.evaluate(miou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
|
||||
@ -3828,10 +3830,10 @@ class MeanIOUTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
confusion_matrix = update_op.eval()
|
||||
confusion_matrix = self.evaluate(update_op)
|
||||
self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix)
|
||||
desired_miou = np.mean([3. / 5., 5. / 7.])
|
||||
self.assertAlmostEqual(desired_miou, miou.eval())
|
||||
self.assertAlmostEqual(desired_miou, self.evaluate(miou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrect(self):
|
||||
@ -3841,8 +3843,8 @@ class MeanIOUTest(test.TestCase):
|
||||
with self.cached_session():
|
||||
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(40, update_op.eval()[0])
|
||||
self.assertEqual(1.0, miou.eval())
|
||||
self.assertEqual(40, self.evaluate(update_op)[0])
|
||||
self.assertEqual(1.0, self.evaluate(miou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllWrong(self):
|
||||
@ -3853,7 +3855,7 @@ class MeanIOUTest(test.TestCase):
|
||||
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAllEqual([[0, 0], [40, 0]], update_op)
|
||||
self.assertEqual(0., miou.eval())
|
||||
self.assertEqual(0., self.evaluate(miou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testResultsWithSomeMissing(self):
|
||||
@ -3886,7 +3888,7 @@ class MeanIOUTest(test.TestCase):
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAllEqual([[2, 0], [2, 4]], update_op)
|
||||
desired_miou = np.mean([2. / 4., 4. / 6.])
|
||||
self.assertAlmostEqual(desired_miou, miou.eval())
|
||||
self.assertAlmostEqual(desired_miou, self.evaluate(miou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMissingClassInLabels(self):
|
||||
@ -3907,7 +3909,7 @@ class MeanIOUTest(test.TestCase):
|
||||
self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op)
|
||||
self.assertAlmostEqual(
|
||||
1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)),
|
||||
miou.eval())
|
||||
self.evaluate(miou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMissingClassOverallSmall(self):
|
||||
@ -3918,7 +3920,7 @@ class MeanIOUTest(test.TestCase):
|
||||
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAllEqual([[1, 0], [0, 0]], update_op)
|
||||
self.assertAlmostEqual(1, miou.eval())
|
||||
self.assertAlmostEqual(1, self.evaluate(miou))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMissingClassOverallLarge(self):
|
||||
@ -3937,8 +3939,8 @@ class MeanIOUTest(test.TestCase):
|
||||
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op)
|
||||
self.assertAlmostEqual(
|
||||
1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval())
|
||||
self.assertAlmostEqual(1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)),
|
||||
self.evaluate(miou))
|
||||
|
||||
|
||||
class MeanPerClassAccuracyTest(test.TestCase):
|
||||
@ -4011,9 +4013,9 @@ class MeanPerClassAccuracyTest(test.TestCase):
|
||||
self.evaluate(update_op)
|
||||
|
||||
# Then verify idempotency.
|
||||
initial_mean_accuracy = mean_accuracy.eval()
|
||||
initial_mean_accuracy = self.evaluate(mean_accuracy)
|
||||
for _ in range(10):
|
||||
self.assertEqual(initial_mean_accuracy, mean_accuracy.eval())
|
||||
self.assertEqual(initial_mean_accuracy, self.evaluate(mean_accuracy))
|
||||
|
||||
num_classes = 3
|
||||
with self.cached_session() as sess:
|
||||
@ -4044,7 +4046,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
|
||||
for _ in range(5):
|
||||
self.evaluate(update_op)
|
||||
desired_output = np.mean([1.0, 1.0 / 3.0, 0.0])
|
||||
self.assertAlmostEqual(desired_output, mean_accuracy.eval())
|
||||
self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleUpdatesWithWeights(self):
|
||||
@ -4090,7 +4092,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
|
||||
for _ in range(6):
|
||||
self.evaluate(update_op)
|
||||
desired_output = np.mean([2.0 / 2.0, 0.5 / 1.5])
|
||||
self.assertAlmostEqual(desired_output, mean_accuracy.eval())
|
||||
self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMultipleUpdatesWithMissingClass(self):
|
||||
@ -4128,7 +4130,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
|
||||
for _ in range(5):
|
||||
self.evaluate(update_op)
|
||||
desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.])
|
||||
self.assertAlmostEqual(desired_output, mean_accuracy.eval())
|
||||
self.assertAlmostEqual(desired_output, self.evaluate(mean_accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllCorrect(self):
|
||||
@ -4139,8 +4141,8 @@ class MeanPerClassAccuracyTest(test.TestCase):
|
||||
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
|
||||
labels, predictions, num_classes)
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertEqual(1.0, update_op.eval()[0])
|
||||
self.assertEqual(1.0, mean_accuracy.eval())
|
||||
self.assertEqual(1.0, self.evaluate(update_op)[0])
|
||||
self.assertEqual(1.0, self.evaluate(mean_accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAllWrong(self):
|
||||
@ -4152,7 +4154,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
|
||||
labels, predictions, num_classes)
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
self.assertAllEqual([0.0, 0.0], update_op)
|
||||
self.assertEqual(0., mean_accuracy.eval())
|
||||
self.assertEqual(0., self.evaluate(mean_accuracy))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testResultsWithSomeMissing(self):
|
||||
@ -4174,7 +4176,8 @@ class MeanPerClassAccuracyTest(test.TestCase):
|
||||
desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32)
|
||||
self.assertAllEqual(desired_accuracy, update_op)
|
||||
desired_mean_accuracy = np.mean(desired_accuracy)
|
||||
self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval())
|
||||
self.assertAlmostEqual(desired_mean_accuracy,
|
||||
self.evaluate(mean_accuracy))
|
||||
|
||||
|
||||
class FalseNegativesTest(test.TestCase):
|
||||
|
@ -145,7 +145,7 @@ class PaddingFIFOQueueTest(test.TestCase):
|
||||
# Dequeue every element using a single thread.
|
||||
results = []
|
||||
for _ in xrange(len(elems)):
|
||||
results.append(dequeued_t.eval())
|
||||
results.append(self.evaluate(dequeued_t))
|
||||
self.assertItemsEqual(elems, results)
|
||||
|
||||
def testParallelDequeue(self):
|
||||
@ -321,8 +321,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"When providing partial shapes, a list of shapes must be provided."):
|
||||
data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32,
|
||||
None).queue_ref.eval()
|
||||
self.evaluate(
|
||||
data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32,
|
||||
None).queue_ref)
|
||||
|
||||
def testMultiEnqueueMany(self):
|
||||
with self.cached_session() as sess:
|
||||
@ -656,7 +657,7 @@ class PaddingFIFOQueueTest(test.TestCase):
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
|
||||
self.assertCountEqual(self.evaluate(dequeued_t), elems * 10)
|
||||
|
||||
def testParallelDequeueMany(self):
|
||||
# We need each thread to keep its own device stack or the device scopes
|
||||
@ -898,7 +899,7 @@ class PaddingFIFOQueueTest(test.TestCase):
|
||||
|
||||
dequeued_elems = []
|
||||
for _ in dequeue_counts:
|
||||
dequeued_elems.extend(dequeued_t.eval())
|
||||
dequeued_elems.extend(self.evaluate(dequeued_t))
|
||||
self.assertEqual(elems, dequeued_elems)
|
||||
|
||||
def testDequeueFromClosedQueue(self):
|
||||
@ -1335,7 +1336,7 @@ class PaddingFIFOQueueTest(test.TestCase):
|
||||
|
||||
enqueue_op.run()
|
||||
for _ in range(500):
|
||||
self.assertEqual(size_t.eval(), [1])
|
||||
self.assertEqual(self.evaluate(size_t), [1])
|
||||
|
||||
def testSharedQueueSameSession(self):
|
||||
with self.cached_session():
|
||||
@ -1349,23 +1350,23 @@ class PaddingFIFOQueueTest(test.TestCase):
|
||||
q1_size_t = q1.size()
|
||||
q2_size_t = q2.size()
|
||||
|
||||
self.assertEqual(q1_size_t.eval(), [1])
|
||||
self.assertEqual(q2_size_t.eval(), [1])
|
||||
self.assertEqual(self.evaluate(q1_size_t), [1])
|
||||
self.assertEqual(self.evaluate(q2_size_t), [1])
|
||||
|
||||
self.assertEqual(q2.dequeue().eval(), [10.0])
|
||||
|
||||
self.assertEqual(q1_size_t.eval(), [0])
|
||||
self.assertEqual(q2_size_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(q1_size_t), [0])
|
||||
self.assertEqual(self.evaluate(q2_size_t), [0])
|
||||
|
||||
q2.enqueue((20.0,)).run()
|
||||
|
||||
self.assertEqual(q1_size_t.eval(), [1])
|
||||
self.assertEqual(q2_size_t.eval(), [1])
|
||||
self.assertEqual(self.evaluate(q1_size_t), [1])
|
||||
self.assertEqual(self.evaluate(q2_size_t), [1])
|
||||
|
||||
self.assertEqual(q1.dequeue().eval(), [20.0])
|
||||
|
||||
self.assertEqual(q1_size_t.eval(), [0])
|
||||
self.assertEqual(q2_size_t.eval(), [0])
|
||||
self.assertEqual(self.evaluate(q1_size_t), [0])
|
||||
self.assertEqual(self.evaluate(q2_size_t), [0])
|
||||
|
||||
def testIncompatibleSharedQueueErrors(self):
|
||||
with self.cached_session():
|
||||
@ -1509,27 +1510,28 @@ class PaddingFIFOQueueTest(test.TestCase):
|
||||
|
||||
# The enqueue should start and then block.
|
||||
results = []
|
||||
results.append(deq.eval()) # Will only complete after the enqueue starts.
|
||||
results.append(
|
||||
self.evaluate(deq)) # Will only complete after the enqueue starts.
|
||||
self.assertEqual(len(enq_done), 1)
|
||||
self.assertEqual(self.evaluate(size_op), 5)
|
||||
|
||||
for _ in range(3):
|
||||
results.append(deq.eval())
|
||||
results.append(self.evaluate(deq))
|
||||
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(len(enq_done), 1)
|
||||
self.assertEqual(self.evaluate(size_op), 5)
|
||||
|
||||
# This dequeue will unblock the thread.
|
||||
results.append(deq.eval())
|
||||
results.append(self.evaluate(deq))
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(len(enq_done), 2)
|
||||
thread.join()
|
||||
|
||||
for i in range(5):
|
||||
self.assertEqual(size_op.eval(), 5 - i)
|
||||
results.append(deq.eval())
|
||||
self.assertEqual(size_op.eval(), 5 - i - 1)
|
||||
self.assertEqual(self.evaluate(size_op), 5 - i)
|
||||
results.append(self.evaluate(deq))
|
||||
self.assertEqual(self.evaluate(size_op), 5 - i - 1)
|
||||
|
||||
self.assertAllEqual(elem, results)
|
||||
|
||||
|
@ -251,7 +251,7 @@ class PyFuncTest(PyFuncTestBase):
|
||||
y, = script_ops.py_func(read_object_array, [],
|
||||
[dtypes.string])
|
||||
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
|
||||
self.assertListEqual(list(z.eval()), [b"hello there", b"hi ya"])
|
||||
self.assertListEqual(list(self.evaluate(z)), [b"hello there", b"hi ya"])
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testStringPadding(self):
|
||||
@ -308,7 +308,7 @@ class PyFuncTest(PyFuncTestBase):
|
||||
return correct
|
||||
|
||||
z, = script_ops.py_func(unicode_string, [], [dtypes.string])
|
||||
self.assertEqual(z.eval(), correct.encode("utf8"))
|
||||
self.assertEqual(self.evaluate(z), correct.encode("utf8"))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testBadNumpyReturnType(self):
|
||||
|
@ -187,7 +187,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||
with self.cached_session():
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1], name="foo")
|
||||
self.assertNotEmpty(handle.eval())
|
||||
self.assertNotEmpty(self.evaluate(handle))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testCachedValueReadBeforeWrite(self):
|
||||
|
@ -2197,7 +2197,7 @@ class RawRNNTest(test.TestCase):
|
||||
|
||||
r = rnn.raw_rnn(cell, loop_fn)
|
||||
loop_state = r[-1]
|
||||
self.assertEqual([10], loop_state.eval())
|
||||
self.assertEqual([10], self.evaluate(loop_state))
|
||||
|
||||
@test_util.run_v1_only("b/124229375")
|
||||
def testLoopStateWithTensorArray(self):
|
||||
|
@ -556,7 +556,7 @@ class ScatterNdTest(test.TestCase):
|
||||
scatter = self.scatter_nd(indices, updates, shape)
|
||||
|
||||
with self.cached_session():
|
||||
self.assertEqual(scatter.eval().size, 0)
|
||||
self.assertEqual(self.evaluate(scatter).size, 0)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRank3InvalidShape1(self):
|
||||
|
@ -374,7 +374,7 @@ class SpaceToBatchErrorHandlingTest(test.TestCase, PythonOpImpl):
|
||||
block_size = 10
|
||||
with self.assertRaises(ValueError):
|
||||
out_tf = self.space_to_batch(x_np, paddings, block_size)
|
||||
out_tf.eval()
|
||||
self.evaluate(out_tf)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBlockSizeNotDivisibleWidth(self):
|
||||
|
@ -154,15 +154,16 @@ class MatMulGradientTest(test.TestCase):
|
||||
transpose_b=tr_b,
|
||||
a_is_sparse=sp_a,
|
||||
b_is_sparse=sp_b)
|
||||
err = (gradient_checker.compute_gradient_error(
|
||||
a, [2, 3] if tr_a else [3, 2],
|
||||
m, [3, 4],
|
||||
x_init_value=a.eval(),
|
||||
delta=delta) + gradient_checker.compute_gradient_error(
|
||||
b, [4, 2] if tr_b else [2, 4],
|
||||
err = (
|
||||
gradient_checker.compute_gradient_error(
|
||||
a, [2, 3] if tr_a else [3, 2],
|
||||
m, [3, 4],
|
||||
x_init_value=b.eval(),
|
||||
delta=delta))
|
||||
x_init_value=self.evaluate(a),
|
||||
delta=delta) + gradient_checker.compute_gradient_error(
|
||||
b, [4, 2] if tr_b else [2, 4],
|
||||
m, [3, 4],
|
||||
x_init_value=self.evaluate(b),
|
||||
delta=delta))
|
||||
self.assertLessEqual(err, delta / 2.)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
|
@ -254,8 +254,8 @@ class SparseSliceOpTest(test.TestCase):
|
||||
with self.session(use_gpu=False):
|
||||
for start, size in start_and_size:
|
||||
sp_output = sparse_ops.sparse_slice(sp_input, start, size)
|
||||
nnz_in = len(sp_input.values.eval())
|
||||
nnz_out = len(sp_output.values.eval())
|
||||
nnz_in = len(self.evaluate(sp_input.values))
|
||||
nnz_out = len(self.evaluate(sp_output.values))
|
||||
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
[sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,))
|
||||
|
@ -265,7 +265,7 @@ class StackOpRefTest(test.TestCase):
|
||||
h2 = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
|
||||
c2 = gen_data_flow_ops.stack_push(h2, 5.0)
|
||||
_ = c1 + c2
|
||||
self.assertNotEqual(h1.eval()[1], self.evaluate(h2)[1])
|
||||
self.assertNotEqual(self.evaluate(h1)[1], self.evaluate(h2)[1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSameNameStacks(self):
|
||||
|
@ -41,7 +41,7 @@ class StringJoinOpTest(test.TestCase):
|
||||
self.assertAllEqual(output, [b"a--a--a", b"b--a--b"])
|
||||
|
||||
output = string_ops.string_join([input1] * 4, separator="!")
|
||||
self.assertEqual(output.eval(), b"a!a!a!a")
|
||||
self.assertEqual(self.evaluate(output), b"a!a!a!a")
|
||||
|
||||
output = string_ops.string_join([input2] * 2, separator="")
|
||||
self.assertAllEqual(output, [[b"bb"], [b"cc"]])
|
||||
|
@ -617,7 +617,7 @@ class TensorArrayTest(test.TestCase):
|
||||
with self.assertRaisesOpError(
|
||||
r"TensorArray foo_.*: Could not write to TensorArray index 2 because "
|
||||
r"it has already been written to."):
|
||||
w1.flow.eval()
|
||||
self.evaluate(w1.flow)
|
||||
|
||||
# Using differing shapes causes an exception
|
||||
wb0_grad = ta_grad.write(1, c(1.0))
|
||||
@ -626,7 +626,7 @@ class TensorArrayTest(test.TestCase):
|
||||
with self.assertRaisesOpError(
|
||||
r"Could not aggregate to TensorArray index 1 because the "
|
||||
r"existing shape is \[\] but the new input shape is \[1\]"):
|
||||
wb1_grad.flow.eval()
|
||||
self.evaluate(wb1_grad.flow)
|
||||
|
||||
@test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
|
||||
@test_util.run_v1_only("v2 does not support TensorArray.grad.")
|
||||
|
Loading…
Reference in New Issue
Block a user