STT-tensorflow/tensorflow/python/kernel_tests/barrier_ops_test.py
Gaurav Jain 24f578cd66 Add @run_deprecated_v1 annotation to tests failing in v2
PiperOrigin-RevId: 223422907
2018-11-29 15:43:25 -08:00

750 lines
27 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for barrier ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import test
class BarrierTest(test.TestCase):
def testConstructorWithShapes(self):
with ops.Graph().as_default():
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32),
shapes=((1, 2, 3), (8,)),
shared_name="B",
name="B")
self.assertTrue(isinstance(b.barrier_ref, ops.Tensor))
self.assertProtoEquals("""
name:'B' op:'Barrier'
attr {
key: "capacity"
value {
i: -1
}
}
attr { key: 'component_types'
value { list { type: DT_FLOAT type: DT_FLOAT } } }
attr {
key: 'shapes'
value {
list {
shape {
dim { size: 1 } dim { size: 2 } dim { size: 3 }
}
shape {
dim { size: 8 }
}
}
}
}
attr { key: 'container' value { s: "" } }
attr { key: 'shared_name' value: { s: 'B' } }
""", b.barrier_ref.op.node_def)
@test_util.run_deprecated_v1
def testInsertMany(self):
with self.cached_session():
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
self.assertEqual([], size_t.get_shape())
keys = [b"a", b"b", b"c"]
insert_0_op = b.insert_many(0, keys, [10.0, 20.0, 30.0])
insert_1_op = b.insert_many(1, keys, [100.0, 200.0, 300.0])
self.assertEquals(size_t.eval(), [0])
insert_0_op.run()
self.assertEquals(size_t.eval(), [0])
insert_1_op.run()
self.assertEquals(size_t.eval(), [3])
def testInsertManyEmptyTensor(self):
with self.cached_session():
error_message = ("Empty tensors are not supported, but received shape "
r"\'\(0,\)\' at index 1")
with self.assertRaisesRegexp(ValueError, error_message):
data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((1,), (0,)), name="B")
@test_util.run_deprecated_v1
def testInsertManyEmptyTensorUnknown(self):
with self.cached_session():
b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32), name="B")
size_t = b.ready_size()
self.assertEqual([], size_t.get_shape())
keys = [b"a", b"b", b"c"]
insert_0_op = b.insert_many(0, keys, np.array([[], [], []], np.float32))
self.assertEquals(size_t.eval(), [0])
with self.assertRaisesOpError(
".*Tensors with no elements are not supported.*"):
insert_0_op.run()
@test_util.run_deprecated_v1
def testTakeMany(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
keys = [b"a", b"b", b"c"]
values_0 = [10.0, 20.0, 30.0]
values_1 = [100.0, 200.0, 300.0]
insert_0_op = b.insert_many(0, keys, values_0)
insert_1_op = b.insert_many(1, keys, values_1)
take_t = b.take_many(3)
insert_0_op.run()
insert_1_op.run()
self.assertEquals(size_t.eval(), [3])
indices_val, keys_val, values_0_val, values_1_val = sess.run(
[take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
self.assertAllEqual(indices_val, [-2**63] * 3)
for k, v0, v1 in zip(keys, values_0, values_1):
idx = keys_val.tolist().index(k)
self.assertEqual(values_0_val[idx], v0)
self.assertEqual(values_1_val[idx], v1)
@test_util.run_deprecated_v1
def testTakeManySmallBatch(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
size_i = b.incomplete_size()
keys = [b"a", b"b", b"c", b"d"]
values_0 = [10.0, 20.0, 30.0, 40.0]
values_1 = [100.0, 200.0, 300.0, 400.0]
insert_0_op = b.insert_many(0, keys, values_0)
# Split adding of the second component into two independent operations.
# After insert_1_1_op, we'll have two ready elements in the barrier,
# 2 will still be incomplete.
insert_1_1_op = b.insert_many(1, keys[0:2], values_1[0:2]) # add "a", "b"
insert_1_2_op = b.insert_many(1, keys[2:3], values_1[2:3]) # add "c"
insert_1_3_op = b.insert_many(1, keys[3:], values_1[3:]) # add "d"
insert_empty_op = b.insert_many(0, [], [])
close_op = b.close()
close_op_final = b.close(cancel_pending_enqueues=True)
index_t, key_t, value_list_t = b.take_many(3, allow_small_batch=True)
insert_0_op.run()
insert_1_1_op.run()
close_op.run()
# Now we have a closed barrier with 2 ready elements. Running take_t
# should return a reduced batch with 2 elements only.
self.assertEquals(size_i.eval(), [2]) # assert that incomplete size = 2
self.assertEquals(size_t.eval(), [2]) # assert that ready size = 2
_, keys_val, values_0_val, values_1_val = sess.run(
[index_t, key_t, value_list_t[0], value_list_t[1]])
# Check that correct values have been returned.
for k, v0, v1 in zip(keys[0:2], values_0[0:2], values_1[0:2]):
idx = keys_val.tolist().index(k)
self.assertEqual(values_0_val[idx], v0)
self.assertEqual(values_1_val[idx], v1)
# The next insert completes the element with key "c". The next take_t
# should return a batch with just 1 element.
insert_1_2_op.run()
self.assertEquals(size_i.eval(), [1]) # assert that incomplete size = 1
self.assertEquals(size_t.eval(), [1]) # assert that ready size = 1
_, keys_val, values_0_val, values_1_val = sess.run(
[index_t, key_t, value_list_t[0], value_list_t[1]])
# Check that correct values have been returned.
for k, v0, v1 in zip(keys[2:3], values_0[2:3], values_1[2:3]):
idx = keys_val.tolist().index(k)
self.assertEqual(values_0_val[idx], v0)
self.assertEqual(values_1_val[idx], v1)
# Adding nothing ought to work, even if the barrier is closed.
insert_empty_op.run()
# currently keys "a" and "b" are not in the barrier, adding them
# again after it has been closed, ought to cause failure.
with self.assertRaisesOpError("is closed"):
insert_1_1_op.run()
close_op_final.run()
# These ops should fail because the barrier has now been closed with
# cancel_pending_enqueues = True.
with self.assertRaisesOpError("is closed"):
insert_empty_op.run()
with self.assertRaisesOpError("is closed"):
insert_1_3_op.run()
@test_util.run_deprecated_v1
def testUseBarrierWithShape(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((2, 2), (8,)), name="B")
size_t = b.ready_size()
keys = [b"a", b"b", b"c"]
values_0 = np.array(
[[[10.0] * 2] * 2, [[20.0] * 2] * 2, [[30.0] * 2] * 2], np.float32)
values_1 = np.array([[100.0] * 8, [200.0] * 8, [300.0] * 8], np.float32)
insert_0_op = b.insert_many(0, keys, values_0)
insert_1_op = b.insert_many(1, keys, values_1)
take_t = b.take_many(3)
insert_0_op.run()
insert_1_op.run()
self.assertEquals(size_t.eval(), [3])
indices_val, keys_val, values_0_val, values_1_val = sess.run(
[take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
self.assertAllEqual(indices_val, [-2**63] * 3)
self.assertShapeEqual(keys_val, take_t[1])
self.assertShapeEqual(values_0_val, take_t[2][0])
self.assertShapeEqual(values_1_val, take_t[2][1])
for k, v0, v1 in zip(keys, values_0, values_1):
idx = keys_val.tolist().index(k)
self.assertAllEqual(values_0_val[idx], v0)
self.assertAllEqual(values_1_val[idx], v1)
@test_util.run_deprecated_v1
def testParallelInsertMany(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
size_t = b.ready_size()
keys = [str(x).encode("ascii") for x in range(10)]
values = [float(x) for x in range(10)]
insert_ops = [b.insert_many(0, [k], [v]) for k, v in zip(keys, values)]
take_t = b.take_many(10)
self.evaluate(insert_ops)
self.assertEquals(size_t.eval(), [10])
indices_val, keys_val, values_val = sess.run(
[take_t[0], take_t[1], take_t[2][0]])
self.assertAllEqual(indices_val, [-2**63 + x for x in range(10)])
for k, v in zip(keys, values):
idx = keys_val.tolist().index(k)
self.assertEqual(values_val[idx], v)
@test_util.run_deprecated_v1
def testParallelTakeMany(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
size_t = b.ready_size()
keys = [str(x).encode("ascii") for x in range(10)]
values = [float(x) for x in range(10)]
insert_op = b.insert_many(0, keys, values)
take_t = [b.take_many(1) for _ in keys]
insert_op.run()
self.assertEquals(size_t.eval(), [10])
index_fetches = []
key_fetches = []
value_fetches = []
for ix_t, k_t, v_t in take_t:
index_fetches.append(ix_t)
key_fetches.append(k_t)
value_fetches.append(v_t[0])
vals = sess.run(index_fetches + key_fetches + value_fetches)
index_vals = vals[:len(keys)]
key_vals = vals[len(keys):2 * len(keys)]
value_vals = vals[2 * len(keys):]
taken_elems = []
for k, v in zip(key_vals, value_vals):
taken_elems.append((k[0], v[0]))
self.assertAllEqual(np.hstack(index_vals), [-2**63] * 10)
self.assertItemsEqual(
zip(keys, values), [(k[0], v[0]) for k, v in zip(key_vals, value_vals)])
@test_util.run_deprecated_v1
def testBlockingTakeMany(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
keys = [str(x).encode("ascii") for x in range(10)]
values = [float(x) for x in range(10)]
insert_ops = [b.insert_many(0, [k], [v]) for k, v in zip(keys, values)]
take_t = b.take_many(10)
def take():
indices_val, keys_val, values_val = sess.run(
[take_t[0], take_t[1], take_t[2][0]])
self.assertAllEqual(indices_val,
[int(x.decode("ascii")) - 2**63 for x in keys_val])
self.assertItemsEqual(zip(keys, values), zip(keys_val, values_val))
t = self.checkedThread(target=take)
t.start()
time.sleep(0.1)
for insert_op in insert_ops:
insert_op.run()
t.join()
@test_util.run_deprecated_v1
def testParallelInsertManyTakeMany(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 100
keys = [str(x) for x in range(10)]
values_0 = np.asarray(range(10), dtype=np.float32)
values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64)
keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys]
insert_0_ops = [
b.insert_many(0, keys_i(i), values_0 + i)
for i in range(num_iterations)
]
insert_1_ops = [
b.insert_many(1, keys_i(i), values_1 + i)
for i in range(num_iterations)
]
take_ops = [b.take_many(10) for _ in range(num_iterations)]
def take(sess, i, taken):
indices_val, keys_val, values_0_val, values_1_val = sess.run([
take_ops[i][0], take_ops[i][1], take_ops[i][2][0], take_ops[i][2][1]
])
taken.append({
"indices": indices_val,
"keys": keys_val,
"values_0": values_0_val,
"values_1": values_1_val
})
def insert(sess, i):
sess.run([insert_0_ops[i], insert_1_ops[i]])
taken = []
take_threads = [
self.checkedThread(
target=take, args=(sess, i, taken)) for i in range(num_iterations)
]
insert_threads = [
self.checkedThread(
target=insert, args=(sess, i)) for i in range(num_iterations)
]
for t in take_threads:
t.start()
time.sleep(0.1)
for t in insert_threads:
t.start()
for t in take_threads:
t.join()
for t in insert_threads:
t.join()
self.assertEquals(len(taken), num_iterations)
flatten = lambda l: [item for sublist in l for item in sublist]
all_indices = sorted(flatten([t_i["indices"] for t_i in taken]))
all_keys = sorted(flatten([t_i["keys"] for t_i in taken]))
expected_keys = sorted(
flatten([keys_i(i) for i in range(num_iterations)]))
expected_indices = sorted(
flatten([-2**63 + j] * 10 for j in range(num_iterations)))
self.assertAllEqual(all_indices, expected_indices)
self.assertAllEqual(all_keys, expected_keys)
for taken_i in taken:
outer_indices_from_keys = np.array(
[int(k.decode("ascii").split(":")[0]) for k in taken_i["keys"]])
inner_indices_from_keys = np.array(
[int(k.decode("ascii").split(":")[1]) for k in taken_i["keys"]])
self.assertAllEqual(taken_i["values_0"],
outer_indices_from_keys + inner_indices_from_keys)
expected_values_1 = np.vstack(
(1 + outer_indices_from_keys + inner_indices_from_keys,
2 + outer_indices_from_keys + inner_indices_from_keys)).T
self.assertAllEqual(taken_i["values_1"], expected_values_1)
@test_util.run_deprecated_v1
def testClose(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
incomplete_t = b.incomplete_size()
keys = [b"a", b"b", b"c"]
values_0 = [10.0, 20.0, 30.0]
values_1 = [100.0, 200.0, 300.0]
insert_0_op = b.insert_many(0, keys, values_0)
insert_1_op = b.insert_many(1, keys, values_1)
close_op = b.close()
fail_insert_op = b.insert_many(0, ["f"], [60.0])
take_t = b.take_many(3)
take_too_many_t = b.take_many(4)
self.assertEquals(size_t.eval(), [0])
self.assertEquals(incomplete_t.eval(), [0])
insert_0_op.run()
self.assertEquals(size_t.eval(), [0])
self.assertEquals(incomplete_t.eval(), [3])
close_op.run()
# This op should fail because the barrier is closed.
with self.assertRaisesOpError("is closed"):
fail_insert_op.run()
# This op should succeed because the barrier has not canceled
# pending enqueues
insert_1_op.run()
self.assertEquals(size_t.eval(), [3])
self.assertEquals(incomplete_t.eval(), [0])
# This op should fail because the barrier is closed.
with self.assertRaisesOpError("is closed"):
fail_insert_op.run()
# This op should fail because we requested more elements than are
# available in incomplete + ready queue.
with self.assertRaisesOpError(r"is closed and has insufficient elements "
r"\(requested 4, total size 3\)"):
sess.run(take_too_many_t[0]) # Sufficient to request just the indices
# This op should succeed because there are still completed elements
# to process.
indices_val, keys_val, values_0_val, values_1_val = sess.run(
[take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
self.assertAllEqual(indices_val, [-2**63] * 3)
for k, v0, v1 in zip(keys, values_0, values_1):
idx = keys_val.tolist().index(k)
self.assertEqual(values_0_val[idx], v0)
self.assertEqual(values_1_val[idx], v1)
# This op should fail because there are no more completed elements and
# the queue is closed.
with self.assertRaisesOpError("is closed and has insufficient elements"):
sess.run(take_t[0])
@test_util.run_deprecated_v1
def testCancel(self):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
incomplete_t = b.incomplete_size()
keys = [b"a", b"b", b"c"]
values_0 = [10.0, 20.0, 30.0]
values_1 = [100.0, 200.0, 300.0]
insert_0_op = b.insert_many(0, keys, values_0)
insert_1_op = b.insert_many(1, keys[0:2], values_1[0:2])
insert_2_op = b.insert_many(1, keys[2:], values_1[2:])
cancel_op = b.close(cancel_pending_enqueues=True)
fail_insert_op = b.insert_many(0, ["f"], [60.0])
take_t = b.take_many(2)
take_too_many_t = b.take_many(3)
self.assertEquals(size_t.eval(), [0])
insert_0_op.run()
insert_1_op.run()
self.assertEquals(size_t.eval(), [2])
self.assertEquals(incomplete_t.eval(), [1])
cancel_op.run()
# This op should fail because the queue is closed.
with self.assertRaisesOpError("is closed"):
fail_insert_op.run()
# This op should fail because the queue is canceled.
with self.assertRaisesOpError("is closed"):
insert_2_op.run()
# This op should fail because we requested more elements than are
# available in incomplete + ready queue.
with self.assertRaisesOpError(r"is closed and has insufficient elements "
r"\(requested 3, total size 2\)"):
sess.run(take_too_many_t[0]) # Sufficient to request just the indices
# This op should succeed because there are still completed elements
# to process.
indices_val, keys_val, values_0_val, values_1_val = sess.run(
[take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
self.assertAllEqual(indices_val, [-2**63] * 2)
for k, v0, v1 in zip(keys[0:2], values_0[0:2], values_1[0:2]):
idx = keys_val.tolist().index(k)
self.assertEqual(values_0_val[idx], v0)
self.assertEqual(values_1_val[idx], v1)
# This op should fail because there are no more completed elements and
# the queue is closed.
with self.assertRaisesOpError("is closed and has insufficient elements"):
sess.run(take_t[0])
def _testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self, cancel):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
take_t = b.take_many(1, allow_small_batch=True)
self.evaluate(b.close(cancel))
with self.assertRaisesOpError("is closed and has insufficient elements"):
self.evaluate(take_t)
@test_util.run_deprecated_v1
def testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self):
self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=False)
self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=True)
def _testParallelInsertManyTakeManyCloseHalfwayThrough(self, cancel):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 50
keys = [str(x) for x in range(10)]
values_0 = np.asarray(range(10), dtype=np.float32)
values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64)
keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys]
insert_0_ops = [
b.insert_many(0, keys_i(i), values_0 + i)
for i in range(num_iterations)
]
insert_1_ops = [
b.insert_many(1, keys_i(i), values_1 + i)
for i in range(num_iterations)
]
take_ops = [b.take_many(10) for _ in range(num_iterations)]
close_op = b.close(cancel_pending_enqueues=cancel)
def take(sess, i, taken):
try:
indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run([
take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
take_ops[i][2][1]
])
taken.append(len(indices_val))
except errors_impl.OutOfRangeError:
taken.append(0)
def insert(sess, i):
try:
sess.run([insert_0_ops[i], insert_1_ops[i]])
except errors_impl.CancelledError:
pass
taken = []
take_threads = [
self.checkedThread(
target=take, args=(sess, i, taken)) for i in range(num_iterations)
]
insert_threads = [
self.checkedThread(
target=insert, args=(sess, i)) for i in range(num_iterations)
]
first_half_insert_threads = insert_threads[:num_iterations // 2]
second_half_insert_threads = insert_threads[num_iterations // 2:]
for t in take_threads:
t.start()
for t in first_half_insert_threads:
t.start()
for t in first_half_insert_threads:
t.join()
close_op.run()
for t in second_half_insert_threads:
t.start()
for t in take_threads:
t.join()
for t in second_half_insert_threads:
t.join()
self.assertEqual(
sorted(taken),
[0] * (num_iterations // 2) + [10] * (num_iterations // 2))
@test_util.run_deprecated_v1
def testParallelInsertManyTakeManyCloseHalfwayThrough(self):
self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=False)
@test_util.run_deprecated_v1
def testParallelInsertManyTakeManyCancelHalfwayThrough(self):
self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=True)
def _testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self, cancel):
with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 100
keys = [str(x) for x in range(10)]
values_0 = np.asarray(range(10), dtype=np.float32)
values_1 = np.asarray([[x + 1, x + 2] for x in range(10)], dtype=np.int64)
keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii") for k in keys]
insert_0_ops = [
b.insert_many(
0, keys_i(i), values_0 + i, name="insert_0_%d" % i)
for i in range(num_iterations)
]
close_op = b.close(cancel_pending_enqueues=cancel)
take_ops = [
b.take_many(
10, name="take_%d" % i) for i in range(num_iterations)
]
# insert_1_ops will only run after closure
insert_1_ops = [
b.insert_many(
1, keys_i(i), values_1 + i, name="insert_1_%d" % i)
for i in range(num_iterations)
]
def take(sess, i, taken):
if cancel:
try:
indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run(
[
take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
take_ops[i][2][1]
])
taken.append(len(indices_val))
except errors_impl.OutOfRangeError:
taken.append(0)
else:
indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run([
take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
take_ops[i][2][1]
])
taken.append(len(indices_val))
def insert_0(sess, i):
insert_0_ops[i].run(session=sess)
def insert_1(sess, i):
if cancel:
try:
insert_1_ops[i].run(session=sess)
except errors_impl.CancelledError:
pass
else:
insert_1_ops[i].run(session=sess)
taken = []
take_threads = [
self.checkedThread(
target=take, args=(sess, i, taken)) for i in range(num_iterations)
]
insert_0_threads = [
self.checkedThread(
target=insert_0, args=(sess, i)) for i in range(num_iterations)
]
insert_1_threads = [
self.checkedThread(
target=insert_1, args=(sess, i)) for i in range(num_iterations)
]
for t in insert_0_threads:
t.start()
for t in insert_0_threads:
t.join()
for t in take_threads:
t.start()
close_op.run()
for t in insert_1_threads:
t.start()
for t in take_threads:
t.join()
for t in insert_1_threads:
t.join()
if cancel:
self.assertEqual(taken, [0] * num_iterations)
else:
self.assertEqual(taken, [10] * num_iterations)
@test_util.run_deprecated_v1
def testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self):
self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=False)
@test_util.run_deprecated_v1
def testParallelPartialInsertManyTakeManyCancelHalfwayThrough(self):
self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=True)
@test_util.run_deprecated_v1
def testIncompatibleSharedBarrierErrors(self):
with self.cached_session():
# Do component types and shapes.
b_a_1 = data_flow_ops.Barrier(
(dtypes.float32,), shapes=(()), shared_name="b_a")
b_a_2 = data_flow_ops.Barrier(
(dtypes.int32,), shapes=(()), shared_name="b_a")
b_a_1.barrier_ref.eval()
with self.assertRaisesOpError("component types"):
b_a_2.barrier_ref.eval()
b_b_1 = data_flow_ops.Barrier(
(dtypes.float32,), shapes=(()), shared_name="b_b")
b_b_2 = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int32), shapes=((), ()), shared_name="b_b")
b_b_1.barrier_ref.eval()
with self.assertRaisesOpError("component types"):
b_b_2.barrier_ref.eval()
b_c_1 = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32),
shapes=((2, 2), (8,)),
shared_name="b_c")
b_c_2 = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shared_name="b_c")
b_c_1.barrier_ref.eval()
with self.assertRaisesOpError("component shapes"):
b_c_2.barrier_ref.eval()
b_d_1 = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), shared_name="b_d")
b_d_2 = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32),
shapes=((2, 2), (8,)),
shared_name="b_d")
b_d_1.barrier_ref.eval()
with self.assertRaisesOpError("component shapes"):
b_d_2.barrier_ref.eval()
b_e_1 = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32),
shapes=((2, 2), (8,)),
shared_name="b_e")
b_e_2 = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32),
shapes=((2, 5), (8,)),
shared_name="b_e")
b_e_1.barrier_ref.eval()
with self.assertRaisesOpError("component shapes"):
b_e_2.barrier_ref.eval()
if __name__ == "__main__":
test.main()