Move python batch op from contrib to core.
PiperOrigin-RevId: 242728245
This commit is contained in:
parent
a1cffe1a7a
commit
a1e538fc9f
tensorflow
contrib/batching
core/api_def/python_api
python
tools
api/golden
compatibility
@ -6,31 +6,18 @@ package(
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"py_test",
|
||||
"tf_custom_op_library",
|
||||
"tf_gen_op_libs",
|
||||
"tf_gen_op_wrapper_py",
|
||||
"tf_kernel_library",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
|
||||
py_library(
|
||||
name = "batch_py",
|
||||
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:batch_ops",
|
||||
"//tensorflow/python:batch_ops_gen",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,14 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import gen_batch_ops
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_batch_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.ops.batch_ops import batch
|
||||
from tensorflow.python.ops.batch_ops import batch_function
|
||||
from tensorflow.python.ops.batch_ops import unbatch
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
||||
@ops.RegisterGradient("Batch")
|
||||
@ -55,85 +54,6 @@ def _UnbatchGrad(op, grad): # pylint: disable=invalid-name
|
||||
]
|
||||
|
||||
|
||||
def batch_function(num_batch_threads,
|
||||
max_batch_size,
|
||||
batch_timeout_micros,
|
||||
allowed_batch_sizes=None,
|
||||
max_enqueued_batches=10):
|
||||
"""Batches the computation done by the decorated function.
|
||||
|
||||
So, for example, in the following code
|
||||
|
||||
```python
|
||||
@batch_function(1, 2, 3)
|
||||
def layer(a):
|
||||
return tf.matmul(a, a)
|
||||
|
||||
b = layer(w)
|
||||
```
|
||||
|
||||
if more than one session.run call is simultaneously trying to compute `b`
|
||||
the values of `w` will be gathered, non-deterministically concatenated
|
||||
along the first axis, and only one thread will run the computation. See the
|
||||
documentation of the `Batch` op for more details.
|
||||
|
||||
Assumes that all arguments of the decorated function are Tensors which will
|
||||
be batched along their first dimension.
|
||||
|
||||
SparseTensor is not supported. The return value of the decorated function
|
||||
must be a Tensor or a list/tuple of Tensors.
|
||||
|
||||
Args:
|
||||
num_batch_threads: Number of scheduling threads for processing batches
|
||||
of work. Determines the number of batches processed in parallel.
|
||||
max_batch_size: Batch sizes will never be bigger than this.
|
||||
batch_timeout_micros: Maximum number of microseconds to wait before
|
||||
outputting an incomplete batch.
|
||||
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
|
||||
does nothing. Otherwise, supplies a list of batch sizes, causing the op
|
||||
to pad batches up to one of those sizes. The entries must increase
|
||||
monotonically, and the final entry must equal max_batch_size.
|
||||
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
The decorated function will return the unbatched computation output Tensors.
|
||||
"""
|
||||
|
||||
def decorator(fn): # pylint: disable=missing-docstring
|
||||
|
||||
def decorated(*args): # pylint: disable=missing-docstring
|
||||
|
||||
@function.defun(autograph=False)
|
||||
def computation(*computation_args):
|
||||
return fn(*computation_args)
|
||||
|
||||
computation = computation.get_concrete_function(
|
||||
*[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
|
||||
for i, x in enumerate(args)])
|
||||
|
||||
with ops.name_scope("batch") as name:
|
||||
for a in args:
|
||||
if not isinstance(a, ops.Tensor):
|
||||
raise ValueError("All arguments to functions decorated with "
|
||||
"`batch_function` are supposed to be Tensors; "
|
||||
"found %s" % repr(a))
|
||||
return gen_batch_ops.batch_function(
|
||||
num_batch_threads=num_batch_threads,
|
||||
max_batch_size=max_batch_size,
|
||||
batch_timeout_micros=batch_timeout_micros,
|
||||
allowed_batch_sizes=allowed_batch_sizes,
|
||||
max_enqueued_batches=max_enqueued_batches,
|
||||
shared_name=name,
|
||||
f=computation,
|
||||
in_tensors=list(args),
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.dtype for o in computation.outputs])
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def batch_function_v1(num_batch_threads,
|
||||
max_batch_size,
|
||||
batch_timeout_micros,
|
||||
|
@ -23,12 +23,8 @@ import time
|
||||
|
||||
from tensorflow.contrib.batching.python.ops import batch_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework.errors import InvalidArgumentError
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_batch_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -41,153 +37,6 @@ def delayed_plus1(x):
|
||||
class BatchOpsTest(test.TestCase):
|
||||
"""Tests for batch_ops.{un,}batch."""
|
||||
|
||||
def testBasicBatch(self):
|
||||
"""Tests that a single batched tensor executes together and only once."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=36000000, grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched, index], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched, index], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0][0]
|
||||
index_t = thread_results[1]
|
||||
empty_b = main_results[0][0]
|
||||
empty_m = main_results[1]
|
||||
else:
|
||||
batch_t = main_results[0][0]
|
||||
index_t = main_results[1]
|
||||
empty_b = thread_results[0][0]
|
||||
empty_m = thread_results[1]
|
||||
|
||||
# Check that both the inputs made it out exactly once.
|
||||
self.assertAllEqual(sorted(batch_t), (1, 2))
|
||||
# Check that we get 2 rows in the index tensor.
|
||||
self.assertEqual(len(index_t), 2)
|
||||
# Check that the other ones are empty.
|
||||
self.assertEqual(len(empty_b), 0)
|
||||
self.assertEqual(len(empty_m), 0)
|
||||
|
||||
def testBatchWithPadding(self):
|
||||
"""Test that batching with padding up to an allowed batch size works."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[5, 10],
|
||||
grad_timeout_micros=0, batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched, index], feed_dict={inp: [1, 3]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched, index], feed_dict={inp: [2, 4]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0][0]
|
||||
else:
|
||||
batch_t = main_results[0][0]
|
||||
|
||||
# Check that the batch tensor incorporates the padding.
|
||||
self.assertEqual(len(batch_t), 5)
|
||||
|
||||
def testMultipleBatch(self):
|
||||
"""Tests that multiple batched tensors execute together."""
|
||||
with self.cached_session() as sess:
|
||||
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, _, _ = batch_ops.batch(
|
||||
[inp0, inp1],
|
||||
num_batch_threads=1,
|
||||
max_batch_size=2,
|
||||
batch_timeout_micros=36000000,
|
||||
grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched], feed_dict={inp0: [1],
|
||||
inp1: [2]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0]
|
||||
empty_t = main_results[0]
|
||||
else:
|
||||
batch_t = main_results[0]
|
||||
empty_t = thread_results[0]
|
||||
|
||||
# Assert that the tensors were batched together.
|
||||
self.assertAllEqual(sorted(batch_t[0]), [1, 2])
|
||||
self.assertAllEqual(sorted(batch_t[1]), [2, 3])
|
||||
self.assertAllEqual(empty_t[0], [])
|
||||
self.assertAllEqual(empty_t[1], [])
|
||||
|
||||
def testIllegalBatchDifferentDim0Sizes(self):
|
||||
"""Tests illegally feeding tensors with different dim0 sizes."""
|
||||
with self.cached_session() as sess:
|
||||
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp0, inp1], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
|
||||
with self.assertRaises(Exception) as raised:
|
||||
_ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})
|
||||
self.assertGreater(
|
||||
raised.exception.message.find("must have equal 0th-dimension size"),
|
||||
0)
|
||||
|
||||
def testBasicUnbatch(self):
|
||||
"""Tests that batch and unbatch work together."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, id_t = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[3, 10],
|
||||
grad_timeout_micros=0, batching_queue="")
|
||||
computation = batched[0] + 1
|
||||
result = batch_ops.unbatch(computation, index, id_t,
|
||||
timeout_micros=1000000, shared_name="unbatch")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBasicUnbatchV1Decorated(self):
|
||||
"""Tests that the batch_function_v1 decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
@ -210,206 +59,6 @@ class BatchOpsTest(test.TestCase):
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBasicUnbatchDecorated(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
# TODO(apassos): Removing this line causes test flakiness! Ideally should
|
||||
# be investigated.
|
||||
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
self.assertTrue(in_t.shape is not None)
|
||||
return in_t + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchDecoratedWithCapturedInput(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
|
||||
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
return in_t + captured_inp0 - captured_inp1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOp(self):
|
||||
"""Tests that the batch_function op works."""
|
||||
with self.cached_session() as sess:
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def computation(in_t):
|
||||
return in_t + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = gen_batch_ops.batch_function(
|
||||
[inp],
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000,
|
||||
Tout=[dtypes.int32],
|
||||
f=computation,
|
||||
captured_tensors=computation.captured_inputs)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOpWithCapturedInput(self):
|
||||
"""Tests that batch_function op works with captured input."""
|
||||
with self.cached_session() as sess:
|
||||
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
|
||||
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def computation(inp):
|
||||
return inp + captured_inp0 - captured_inp1
|
||||
|
||||
result = gen_batch_ops.batch_function(
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[3, 10],
|
||||
batching_queue="",
|
||||
f=computation,
|
||||
in_tensors=[inp],
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.type for o in computation.definition.signature.output_arg])
|
||||
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOpWithInputError(self):
|
||||
"""Tests that batch_function op works with error in the inputs."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
|
||||
@function.Defun(dtypes.int32, dtypes.int32)
|
||||
def computation(in0, in1):
|
||||
return in0 + in1
|
||||
|
||||
result = gen_batch_ops.batch_function(
|
||||
[inp], # computation actually expects 2 inputs.
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
batching_queue="",
|
||||
f=computation,
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.type for o in computation.definition.signature.output_arg])
|
||||
|
||||
with self.assertRaisesRegexp(InvalidArgumentError,
|
||||
".*2 arguments.*but 1.*"):
|
||||
sess.run([result], feed_dict={inp: [2]})
|
||||
|
||||
def testBasicUnbatchDecoratedWithReshape(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
return array_ops.reshape(in_t, [-1]) + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [[2]]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testUnbatchTimeout(self):
|
||||
"""Tests that the unbatch timeout works."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, id_t = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=36000000, grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
computation = batched[0] + 1
|
||||
timeout_micros = 10
|
||||
result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
|
||||
shared_name="shared_unbatch")
|
||||
# Set up a parallel pipeline that delays the computation, but uses the
|
||||
# same unbatch resource object as the non-delayed pipeline.
|
||||
computation_delayed = script_ops.py_func(delayed_plus1,
|
||||
[batched[0]],
|
||||
dtypes.int32)
|
||||
result_delayed = batch_ops.unbatch(computation_delayed,
|
||||
index,
|
||||
id_t,
|
||||
timeout_micros,
|
||||
shared_name="shared_unbatch")
|
||||
|
||||
thread_results = []
|
||||
def worker():
|
||||
# A first call using the non-delayed pipeline. The batcher will send an
|
||||
# empty tensor along the non-delayed pipeline.
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
time.sleep(0.1) # Ensure the thread's call starts first.
|
||||
# A second call using the delayed pipeline. The batcher will send the
|
||||
# batched tensor along the delayed pipeline, thus delaying the arrival of
|
||||
# the batched tensor at the unbatch op, relative to the empty tensor.
|
||||
#
|
||||
# TODO(olston, apassos): Avoid relying on the order in which the batch op
|
||||
# emits the empty tensor versus the batched one.
|
||||
_ = sess.run([result_delayed], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
# The thread's call should hit the timeout, and thus get 0 results.
|
||||
self.assertEqual(len(thread_results), 0)
|
||||
|
||||
def testUnbatchGrad(self):
|
||||
"""Tests that batch and unbatch are differentiable."""
|
||||
with self.cached_session() as sess:
|
||||
@ -434,6 +83,5 @@ class BatchOpsTest(test.TestCase):
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [4])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
4
tensorflow/core/api_def/python_api/api_def_Batch.pbtxt
Normal file
4
tensorflow/core/api_def/python_api/api_def_Batch.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Batch"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "BatchFunction"
|
||||
visibility: HIDDEN
|
||||
}
|
4
tensorflow/core/api_def/python_api/api_def_Unbatch.pbtxt
Normal file
4
tensorflow/core/api_def/python_api/api_def_Unbatch.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Unbatch"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "UnbatchGrad"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -1961,6 +1961,38 @@ tf_gen_op_wrapper_private_py(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "batch_ops",
|
||||
srcs = [
|
||||
"ops/batch_ops.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":batch_ops_gen",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "batch_ops_test",
|
||||
size = "small",
|
||||
srcs = ["ops/batch_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"manual",
|
||||
"no_pip",
|
||||
"nomac",
|
||||
],
|
||||
deps = [
|
||||
":batch_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:script_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "manip_ops_gen",
|
||||
visibility = [
|
||||
@ -3304,6 +3336,7 @@ py_library(
|
||||
deps = [
|
||||
":array_grad",
|
||||
":array_ops",
|
||||
":batch_ops",
|
||||
":check_ops",
|
||||
":clip_ops",
|
||||
":confusion_matrix",
|
||||
|
111
tensorflow/python/ops/batch_ops.py
Normal file
111
tensorflow/python/ops/batch_ops.py
Normal file
@ -0,0 +1,111 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Operations for automatic batching and unbatching."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import gen_batch_ops
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_batch_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("nondifferentiable_batch_function")
|
||||
def batch_function(num_batch_threads,
|
||||
max_batch_size,
|
||||
batch_timeout_micros,
|
||||
allowed_batch_sizes=None,
|
||||
max_enqueued_batches=10,
|
||||
autograph=True):
|
||||
"""Batches the computation done by the decorated function.
|
||||
|
||||
So, for example, in the following code
|
||||
|
||||
```python
|
||||
@batch_function(1, 2, 3)
|
||||
def layer(a):
|
||||
return tf.matmul(a, a)
|
||||
|
||||
b = layer(w)
|
||||
```
|
||||
|
||||
if more than one session.run call is simultaneously trying to compute `b`
|
||||
the values of `w` will be gathered, non-deterministically concatenated
|
||||
along the first axis, and only one thread will run the computation. See the
|
||||
documentation of the `Batch` op for more details.
|
||||
|
||||
Assumes that all arguments of the decorated function are Tensors which will
|
||||
be batched along their first dimension.
|
||||
|
||||
SparseTensor is not supported. The return value of the decorated function
|
||||
must be a Tensor or a list/tuple of Tensors.
|
||||
|
||||
Args:
|
||||
num_batch_threads: Number of scheduling threads for processing batches
|
||||
of work. Determines the number of batches processed in parallel.
|
||||
max_batch_size: Batch sizes will never be bigger than this.
|
||||
batch_timeout_micros: Maximum number of microseconds to wait before
|
||||
outputting an incomplete batch.
|
||||
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
|
||||
does nothing. Otherwise, supplies a list of batch sizes, causing the op
|
||||
to pad batches up to one of those sizes. The entries must increase
|
||||
monotonically, and the final entry must equal max_batch_size.
|
||||
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
|
||||
autograph: Whether to use autograph to compile python and eager style code
|
||||
for efficient graph-mode execution.
|
||||
|
||||
Returns:
|
||||
The decorated function will return the unbatched computation output Tensors.
|
||||
"""
|
||||
|
||||
def decorator(fn): # pylint: disable=missing-docstring
|
||||
|
||||
def decorated(*args): # pylint: disable=missing-docstring
|
||||
|
||||
@function.defun(autograph=autograph)
|
||||
def computation(*computation_args):
|
||||
return fn(*computation_args)
|
||||
|
||||
computation = computation.get_concrete_function(
|
||||
*[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
|
||||
for i, x in enumerate(args)])
|
||||
|
||||
with ops.name_scope("batch") as name:
|
||||
for a in args:
|
||||
if not isinstance(a, ops.Tensor):
|
||||
raise ValueError("All arguments to functions decorated with "
|
||||
"`batch_function` are supposed to be Tensors; "
|
||||
"found %s" % repr(a))
|
||||
return gen_batch_ops.batch_function(
|
||||
num_batch_threads=num_batch_threads,
|
||||
max_batch_size=max_batch_size,
|
||||
batch_timeout_micros=batch_timeout_micros,
|
||||
allowed_batch_sizes=allowed_batch_sizes,
|
||||
max_enqueued_batches=max_enqueued_batches,
|
||||
shared_name=name,
|
||||
f=computation,
|
||||
in_tensors=list(args),
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.dtype for o in computation.outputs])
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
421
tensorflow/python/ops/batch_ops_test.py
Normal file
421
tensorflow/python/ops/batch_ops_test.py
Normal file
@ -0,0 +1,421 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for the currently experimental in-graph batch ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework.errors import InvalidArgumentError
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import batch_ops
|
||||
from tensorflow.python.ops import gen_batch_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def delayed_plus1(x):
|
||||
"""Sleeps for 100ms then returns x+1."""
|
||||
time.sleep(0.1)
|
||||
return x + 1
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class BatchOpsTest(test.TestCase):
|
||||
"""Tests for batch_ops.{un,}batch."""
|
||||
|
||||
# Test for only non eager mode as batching in eager context as a functionality
|
||||
# is TBD.
|
||||
def testBasicBatch(self):
|
||||
"""Tests that a single batched tensor executes together and only once."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=36000000, grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched, index], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched, index], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0][0]
|
||||
index_t = thread_results[1]
|
||||
empty_b = main_results[0][0]
|
||||
empty_m = main_results[1]
|
||||
else:
|
||||
batch_t = main_results[0][0]
|
||||
index_t = main_results[1]
|
||||
empty_b = thread_results[0][0]
|
||||
empty_m = thread_results[1]
|
||||
|
||||
# Check that both the inputs made it out exactly once.
|
||||
self.assertAllEqual(sorted(batch_t), (1, 2))
|
||||
# Check that we get 2 rows in the index tensor.
|
||||
self.assertEqual(len(index_t), 2)
|
||||
# Check that the other ones are empty.
|
||||
self.assertEqual(len(empty_b), 0)
|
||||
self.assertEqual(len(empty_m), 0)
|
||||
|
||||
def testBatchWithPadding(self):
|
||||
"""Test that batching with padding up to an allowed batch size works."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[5, 10],
|
||||
grad_timeout_micros=0, batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched, index], feed_dict={inp: [1, 3]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched, index], feed_dict={inp: [2, 4]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0][0]
|
||||
else:
|
||||
batch_t = main_results[0][0]
|
||||
|
||||
# Check that the batch tensor incorporates the padding.
|
||||
self.assertEqual(len(batch_t), 5)
|
||||
|
||||
def testMultipleBatch(self):
|
||||
"""Tests that multiple batched tensors execute together."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, _, _ = batch_ops.batch(
|
||||
[inp0, inp1],
|
||||
num_batch_threads=1,
|
||||
max_batch_size=2,
|
||||
batch_timeout_micros=36000000,
|
||||
grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(
|
||||
sess.run([batched], feed_dict={inp0: [1],
|
||||
inp1: [2]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]})
|
||||
worker_thread.join()
|
||||
|
||||
# At this point either the thread or the main did the batch and the other
|
||||
# should have empty results.
|
||||
if list(thread_results[0][0]):
|
||||
batch_t = thread_results[0]
|
||||
empty_t = main_results[0]
|
||||
else:
|
||||
batch_t = main_results[0]
|
||||
empty_t = thread_results[0]
|
||||
|
||||
# Assert that the tensors were batched together.
|
||||
self.assertAllEqual(sorted(batch_t[0]), [1, 2])
|
||||
self.assertAllEqual(sorted(batch_t[1]), [2, 3])
|
||||
self.assertAllEqual(empty_t[0], [])
|
||||
self.assertAllEqual(empty_t[1], [])
|
||||
|
||||
def testIllegalBatchDifferentDim0Sizes(self):
|
||||
"""Tests illegally feeding tensors with different dim0 sizes."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
|
||||
batched, index, _ = batch_ops.batch(
|
||||
[inp0, inp1], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
|
||||
with self.assertRaises(Exception) as raised:
|
||||
_ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})
|
||||
self.assertGreater(
|
||||
raised.exception.message.find("must have equal 0th-dimension size"),
|
||||
0)
|
||||
|
||||
def testBasicUnbatch(self):
|
||||
"""Tests that batch and unbatch work together."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, id_t = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[3, 10],
|
||||
grad_timeout_micros=0, batching_queue="")
|
||||
computation = batched[0] + 1
|
||||
result = batch_ops.unbatch(computation, index, id_t,
|
||||
timeout_micros=1000000, shared_name="unbatch")
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBasicUnbatchDecorated(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
# TODO(apassos): Removing this line causes test flakiness! Ideally should
|
||||
# be investigated.
|
||||
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
self.assertTrue(in_t.shape is not None)
|
||||
return in_t + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchDecoratedWithCapturedInput(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
|
||||
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
return in_t + captured_inp0 - captured_inp1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOp(self):
|
||||
"""Tests that the batch_function op works."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def computation(in_t):
|
||||
return in_t + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = gen_batch_ops.batch_function(
|
||||
[inp],
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000,
|
||||
Tout=[dtypes.int32],
|
||||
f=computation,
|
||||
captured_tensors=computation.captured_inputs)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOpWithCapturedInput(self):
|
||||
"""Tests that batch_function op works with captured input."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
|
||||
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
|
||||
@function.Defun(dtypes.int32)
|
||||
def computation(inp):
|
||||
return inp + captured_inp0 - captured_inp1
|
||||
|
||||
result = gen_batch_ops.batch_function(
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[3, 10],
|
||||
batching_queue="",
|
||||
f=computation,
|
||||
in_tensors=[inp],
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.type for o in computation.definition.signature.output_arg])
|
||||
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testBatchFunctionOpWithInputError(self):
|
||||
"""Tests that batch_function op works with error in the inputs."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
|
||||
@function.Defun(dtypes.int32, dtypes.int32)
|
||||
def computation(in0, in1):
|
||||
return in0 + in1
|
||||
|
||||
result = gen_batch_ops.batch_function(
|
||||
[inp], # computation actually expects 2 inputs.
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
batching_queue="",
|
||||
f=computation,
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.type for o in computation.definition.signature.output_arg])
|
||||
|
||||
with self.assertRaisesRegexp(InvalidArgumentError,
|
||||
".*2 arguments.*but 1.*"):
|
||||
sess.run([result], feed_dict={inp: [2]})
|
||||
|
||||
def testBasicUnbatchDecoratedWithReshape(self):
|
||||
"""Tests that the batch_function decorator works."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
|
||||
@batch_ops.batch_function(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
return array_ops.reshape(in_t, [-1]) + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [[2]]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testUnbatchTimeout(self):
|
||||
"""Tests that the unbatch timeout works."""
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
batched, index, id_t = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=36000000, grad_timeout_micros=0,
|
||||
batching_queue="")
|
||||
computation = batched[0] + 1
|
||||
timeout_micros = 10
|
||||
result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
|
||||
shared_name="shared_unbatch")
|
||||
# Set up a parallel pipeline that delays the computation, but uses the
|
||||
# same unbatch resource object as the non-delayed pipeline.
|
||||
computation_delayed = script_ops.py_func(delayed_plus1,
|
||||
[batched[0]],
|
||||
dtypes.int32)
|
||||
result_delayed = batch_ops.unbatch(computation_delayed,
|
||||
index,
|
||||
id_t,
|
||||
timeout_micros,
|
||||
shared_name="shared_unbatch")
|
||||
|
||||
thread_results = []
|
||||
def worker():
|
||||
# A first call using the non-delayed pipeline. The batcher will send an
|
||||
# empty tensor along the non-delayed pipeline.
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
time.sleep(0.1) # Ensure the thread's call starts first.
|
||||
# A second call using the delayed pipeline. The batcher will send the
|
||||
# batched tensor along the delayed pipeline, thus delaying the arrival of
|
||||
# the batched tensor at the unbatch op, relative to the empty tensor.
|
||||
#
|
||||
# TODO(olston, apassos): Avoid relying on the order in which the batch op
|
||||
# emits the empty tensor versus the batched one.
|
||||
_ = sess.run([result_delayed], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
# The thread's call should hit the timeout, and thus get 0 results.
|
||||
self.assertEqual(len(thread_results), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -54,6 +54,7 @@ from tensorflow.python.ops.control_flow_ops import tuple # pylint: disable=rede
|
||||
# pylint: enable=redefined-builtin
|
||||
from tensorflow.python.eager import wrap_function
|
||||
from tensorflow.python.ops.control_flow_ops import while_loop
|
||||
from tensorflow.python.ops.batch_ops import *
|
||||
from tensorflow.python.ops.critical_section_ops import *
|
||||
from tensorflow.python.ops.data_flow_ops import *
|
||||
from tensorflow.python.ops.functional_ops import *
|
||||
|
@ -1672,6 +1672,10 @@ tf_module {
|
||||
name: "no_regularizer"
|
||||
argspec: "args=[\'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "nondifferentiable_batch_function"
|
||||
argspec: "args=[\'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'allowed_batch_sizes\', \'max_enqueued_batches\', \'autograph\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "norm"
|
||||
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
|
@ -276,6 +276,10 @@ tf_module {
|
||||
name: "BarrierTakeMany"
|
||||
argspec: "args=[\'handle\', \'num_elements\', \'component_types\', \'allow_small_batch\', \'wait_for_incomplete\', \'timeout_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'-1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Batch"
|
||||
argspec: "args=[\'in_tensors\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'grad_timeout_micros\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BatchCholesky"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -304,6 +308,10 @@ tf_module {
|
||||
name: "BatchFFT3D"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BatchFunction"
|
||||
argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BatchIFFT"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -4228,6 +4236,14 @@ tf_module {
|
||||
name: "TruncatedNormal"
|
||||
argspec: "args=[\'shape\', \'dtype\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Unbatch"
|
||||
argspec: "args=[\'batched_tensor\', \'batch_index\', \'id\', \'timeout_micros\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "UnbatchGrad"
|
||||
argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "UnicodeDecode"
|
||||
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
|
||||
|
@ -780,6 +780,10 @@ tf_module {
|
||||
name: "no_regularizer"
|
||||
argspec: "args=[\'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "nondifferentiable_batch_function"
|
||||
argspec: "args=[\'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'allowed_batch_sizes\', \'max_enqueued_batches\', \'autograph\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "norm"
|
||||
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\'], "
|
||||
|
@ -276,6 +276,10 @@ tf_module {
|
||||
name: "BarrierTakeMany"
|
||||
argspec: "args=[\'handle\', \'num_elements\', \'component_types\', \'allow_small_batch\', \'wait_for_incomplete\', \'timeout_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'-1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Batch"
|
||||
argspec: "args=[\'in_tensors\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'grad_timeout_micros\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BatchCholesky"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -304,6 +308,10 @@ tf_module {
|
||||
name: "BatchFFT3D"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BatchFunction"
|
||||
argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BatchIFFT"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -4228,6 +4236,14 @@ tf_module {
|
||||
name: "TruncatedNormal"
|
||||
argspec: "args=[\'shape\', \'dtype\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Unbatch"
|
||||
argspec: "args=[\'batched_tensor\', \'batch_index\', \'id\', \'timeout_micros\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "UnbatchGrad"
|
||||
argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "UnicodeDecode"
|
||||
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
|
||||
|
@ -603,6 +603,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.nest.map_structure",
|
||||
"tf.contrib.framework.nest.pack_sequence_as":
|
||||
"tf.nest.pack_sequence_as",
|
||||
"tf.contrib.batching.batch_function":
|
||||
"tf.nondifferentiable_batch_function",
|
||||
"tf.contrib.util.constant_value":
|
||||
"tf.get_static_value",
|
||||
"tf.contrib.saved_model.load_keras_model":
|
||||
|
Loading…
Reference in New Issue
Block a user