Move Python batch ops from contrib to core.

PiperOrigin-RevId: 242728245
A. Unique TensorFlower 2019-04-09 13:28:03 -07:00 committed by TensorFlower Gardener
parent a1cffe1a7a
commit a1e538fc9f
16 changed files with 630 additions and 451 deletions
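For readers tracking the API move: the decorator that previously lived under `tf.contrib.batching` is now exported from core as `tf.nondifferentiable_batch_function` (see the `tf_export` decorator and the `tf_upgrade_v2` rename mapping below). A minimal sketch of what a call site looks like after this change; the surrounding code is illustrative, not taken from this commit:
```python
import tensorflow as tf

# Before this change the decorator was reached via contrib:
#   @tf.contrib.batching.batch_function(1, 2, 3)
# After this change the same functionality is a core symbol:
@tf.nondifferentiable_batch_function(num_batch_threads=1, max_batch_size=2,
                                     batch_timeout_micros=3000)
def layer(a):
  return tf.matmul(a, a)

# `layer` now batches concurrent callers along the first axis and runs the
# matmul once per batch; see the batch_function docstring below for details.
```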

View File

@@ -6,31 +6,18 @@ package(
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load(
"//tensorflow:tensorflow.bzl",
"py_test",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
"tf_kernel_library",
)
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
py_library(
name = "batch_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:batch_ops",
"//tensorflow/python:batch_ops_gen",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:platform",
"//tensorflow/python:script_ops",
"//tensorflow/python:util",
],
)

View File

@@ -18,14 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_batch_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_batch_ops import *
# pylint: enable=wildcard-import
# pylint: disable=unused-import
from tensorflow.python.ops.batch_ops import batch
from tensorflow.python.ops.batch_ops import batch_function
from tensorflow.python.ops.batch_ops import unbatch
# pylint: enable=unused-import
@ops.RegisterGradient("Batch")
@@ -55,85 +54,6 @@ def _UnbatchGrad(op, grad): # pylint: disable=invalid-name
]
def batch_function(num_batch_threads,
max_batch_size,
batch_timeout_micros,
allowed_batch_sizes=None,
max_enqueued_batches=10):
"""Batches the computation done by the decorated function.
So, for example, in the following code
```python
@batch_function(1, 2, 3)
def layer(a):
return tf.matmul(a, a)
b = layer(w)
```
if more than one session.run call is simultaneously trying to compute `b`
the values of `w` will be gathered, non-deterministically concatenated
along the first axis, and only one thread will run the computation. See the
documentation of the `Batch` op for more details.
Assumes that all arguments of the decorated function are Tensors which will
be batched along their first dimension.
SparseTensor is not supported. The return value of the decorated function
must be a Tensor or a list/tuple of Tensors.
Args:
num_batch_threads: Number of scheduling threads for processing batches
of work. Determines the number of batches processed in parallel.
max_batch_size: Batch sizes will never be bigger than this.
batch_timeout_micros: Maximum number of microseconds to wait before
outputting an incomplete batch.
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
does nothing. Otherwise, supplies a list of batch sizes, causing the op
to pad batches up to one of those sizes. The entries must increase
monotonically, and the final entry must equal max_batch_size.
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
Returns:
The decorated function will return the unbatched computation output Tensors.
"""
def decorator(fn): # pylint: disable=missing-docstring
def decorated(*args): # pylint: disable=missing-docstring
@function.defun(autograph=False)
def computation(*computation_args):
return fn(*computation_args)
computation = computation.get_concrete_function(
*[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
for i, x in enumerate(args)])
with ops.name_scope("batch") as name:
for a in args:
if not isinstance(a, ops.Tensor):
raise ValueError("All arguments to functions decorated with "
"`batch_function` are supposed to be Tensors; "
"found %s" % repr(a))
return gen_batch_ops.batch_function(
num_batch_threads=num_batch_threads,
max_batch_size=max_batch_size,
batch_timeout_micros=batch_timeout_micros,
allowed_batch_sizes=allowed_batch_sizes,
max_enqueued_batches=max_enqueued_batches,
shared_name=name,
f=computation,
in_tensors=list(args),
captured_tensors=computation.captured_inputs,
Tout=[o.dtype for o in computation.outputs])
return decorated
return decorator
def batch_function_v1(num_batch_threads,
max_batch_size,
batch_timeout_micros,

View File

@@ -23,12 +23,8 @@ import time
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@@ -41,153 +37,6 @@ def delayed_plus1(x):
class BatchOpsTest(test.TestCase):
"""Tests for batch_ops.{un,}batch."""
def testBasicBatch(self):
"""Tests that a single batched tensor executes together and only once."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=36000000, grad_timeout_micros=0,
batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched, index], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched, index], feed_dict={inp: [2]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0][0]
index_t = thread_results[1]
empty_b = main_results[0][0]
empty_m = main_results[1]
else:
batch_t = main_results[0][0]
index_t = main_results[1]
empty_b = thread_results[0][0]
empty_m = thread_results[1]
# Check that both the inputs made it out exactly once.
self.assertAllEqual(sorted(batch_t), (1, 2))
# Check that we get 2 rows in the index tensor.
self.assertEqual(len(index_t), 2)
# Check that the other ones are empty.
self.assertEqual(len(empty_b), 0)
self.assertEqual(len(empty_m), 0)
def testBatchWithPadding(self):
"""Test that batching with padding up to an allowed batch size works."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[5, 10],
grad_timeout_micros=0, batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched, index], feed_dict={inp: [1, 3]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched, index], feed_dict={inp: [2, 4]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0][0]
else:
batch_t = main_results[0][0]
# Check that the batch tensor incorporates the padding.
self.assertEqual(len(batch_t), 5)
def testMultipleBatch(self):
"""Tests that multiple batched tensors execute together."""
with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
[inp0, inp1],
num_batch_threads=1,
max_batch_size=2,
batch_timeout_micros=36000000,
grad_timeout_micros=0,
batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched], feed_dict={inp0: [1],
inp1: [2]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0]
empty_t = main_results[0]
else:
batch_t = main_results[0]
empty_t = thread_results[0]
# Assert that the tensors were batched together.
self.assertAllEqual(sorted(batch_t[0]), [1, 2])
self.assertAllEqual(sorted(batch_t[1]), [2, 3])
self.assertAllEqual(empty_t[0], [])
self.assertAllEqual(empty_t[1], [])
def testIllegalBatchDifferentDim0Sizes(self):
"""Tests illegally feeding tensors with different dim0 sizes."""
with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp0, inp1], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
with self.assertRaises(Exception) as raised:
_ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})
self.assertGreater(
raised.exception.message.find("must have equal 0th-dimension size"),
0)
def testBasicUnbatch(self):
"""Tests that batch and unbatch work together."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[3, 10],
grad_timeout_micros=0, batching_queue="")
computation = batched[0] + 1
result = batch_ops.unbatch(computation, index, id_t,
timeout_micros=1000000, shared_name="unbatch")
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBasicUnbatchV1Decorated(self):
"""Tests that the batch_function_v1 decorator works."""
with self.cached_session() as sess:
@@ -210,206 +59,6 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
self.assertTrue(in_t.shape is not None)
return in_t + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchDecoratedWithCapturedInput(self):
"""Tests that the batch_function decorator works."""
with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return in_t + captured_inp0 - captured_inp1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
return in_t + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = gen_batch_ops.batch_function(
[inp],
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000,
Tout=[dtypes.int32],
f=computation,
captured_tensors=computation.captured_inputs)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithCapturedInput(self):
"""Tests that batch_function op works with captured input."""
with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32)
def computation(inp):
return inp + captured_inp0 - captured_inp1
result = gen_batch_ops.batch_function(
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[3, 10],
batching_queue="",
f=computation,
in_tensors=[inp],
captured_tensors=computation.captured_inputs,
Tout=[o.type for o in computation.definition.signature.output_arg])
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithInputError(self):
"""Tests that batch_function op works with error in the inputs."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
def computation(in0, in1):
return in0 + in1
result = gen_batch_ops.batch_function(
[inp], # computation actually expects 2 inputs.
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000, # 100ms
batching_queue="",
f=computation,
captured_tensors=computation.captured_inputs,
Tout=[o.type for o in computation.definition.signature.output_arg])
with self.assertRaisesRegexp(InvalidArgumentError,
".*2 arguments.*but 1.*"):
sess.run([result], feed_dict={inp: [2]})
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return array_ops.reshape(in_t, [-1]) + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [[2]]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=36000000, grad_timeout_micros=0,
batching_queue="")
computation = batched[0] + 1
timeout_micros = 10
result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
shared_name="shared_unbatch")
# Set up a parallel pipeline that delays the computation, but uses the
# same unbatch resource object as the non-delayed pipeline.
computation_delayed = script_ops.py_func(delayed_plus1,
[batched[0]],
dtypes.int32)
result_delayed = batch_ops.unbatch(computation_delayed,
index,
id_t,
timeout_micros,
shared_name="shared_unbatch")
thread_results = []
def worker():
# A first call using the non-delayed pipeline. The batcher will send an
# empty tensor along the non-delayed pipeline.
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
time.sleep(0.1) # Ensure the thread's call starts first.
# A second call using the delayed pipeline. The batcher will send the
# batched tensor along the delayed pipeline, thus delaying the arrival of
# the batched tensor at the unbatch op, relative to the empty tensor.
#
# TODO(olston, apassos): Avoid relying on the order in which the batch op
# emits the empty tensor versus the batched one.
_ = sess.run([result_delayed], feed_dict={inp: [2]})
worker_thread.join()
# The thread's call should hit the timeout, and thus get 0 results.
self.assertEqual(len(thread_results), 0)
def testUnbatchGrad(self):
"""Tests that batch and unbatch are differentiable."""
with self.cached_session() as sess:
@@ -434,6 +83,5 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [4])
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,4 @@
op {
graph_op_name: "Batch"
visibility: HIDDEN
}

View File

@@ -0,0 +1,4 @@
op {
graph_op_name: "BatchFunction"
visibility: HIDDEN
}

View File

@@ -0,0 +1,4 @@
op {
graph_op_name: "Unbatch"
visibility: HIDDEN
}

View File

@@ -0,0 +1,4 @@
op {
graph_op_name: "UnbatchGrad"
visibility: HIDDEN
}
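These four api_def files mark the raw batching ops as HIDDEN, so they are not re-exported as top-level `tf.*` symbols; graph code reaches them through the generated `gen_batch_ops` wrappers instead. A short sketch of calling the hidden `BatchFunction` op directly, mirroring the pattern used by the test file added below (graph mode is assumed):
```python
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_batch_ops

@function.Defun(dtypes.int32)
def computation(in_t):
  return in_t + 1

inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
# BatchFunction is hidden from tf.*, but the generated wrapper still exposes it.
result = gen_batch_ops.batch_function(
    [inp],
    num_batch_threads=1,
    max_batch_size=10,
    batch_timeout_micros=100000,
    Tout=[dtypes.int32],
    f=computation,
    captured_tensors=computation.captured_inputs)
```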

View File

@@ -1961,6 +1961,38 @@ tf_gen_op_wrapper_private_py(
],
)
py_library(
name = "batch_ops",
srcs = [
"ops/batch_ops.py",
],
srcs_version = "PY2AND3",
deps = [
":batch_ops_gen",
],
)
py_test(
name = "batch_ops_test",
size = "small",
srcs = ["ops/batch_ops_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual",
"no_pip",
"nomac",
],
deps = [
":batch_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:gradients",
"//tensorflow/python:script_ops",
],
)
tf_gen_op_wrapper_private_py(
name = "manip_ops_gen",
visibility = [
@@ -3304,6 +3336,7 @@ py_library(
deps = [
":array_grad",
":array_ops",
":batch_ops",
":check_ops",
":clip_ops",
":confusion_matrix",

View File

@@ -0,0 +1,111 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for automatic batching and unbatching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_batch_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_batch_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.tf_export import tf_export
@tf_export("nondifferentiable_batch_function")
def batch_function(num_batch_threads,
max_batch_size,
batch_timeout_micros,
allowed_batch_sizes=None,
max_enqueued_batches=10,
autograph=True):
"""Batches the computation done by the decorated function.
So, for example, in the following code
```python
@batch_function(1, 2, 3)
def layer(a):
return tf.matmul(a, a)
b = layer(w)
```
if more than one session.run call is simultaneously trying to compute `b`
the values of `w` will be gathered, non-deterministically concatenated
along the first axis, and only one thread will run the computation. See the
documentation of the `Batch` op for more details.
Assumes that all arguments of the decorated function are Tensors which will
be batched along their first dimension.
SparseTensor is not supported. The return value of the decorated function
must be a Tensor or a list/tuple of Tensors.
Args:
num_batch_threads: Number of scheduling threads for processing batches
of work. Determines the number of batches processed in parallel.
max_batch_size: Batch sizes will never be bigger than this.
batch_timeout_micros: Maximum number of microseconds to wait before
outputting an incomplete batch.
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
does nothing. Otherwise, supplies a list of batch sizes, causing the op
to pad batches up to one of those sizes. The entries must increase
monotonically, and the final entry must equal max_batch_size.
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
autograph: Whether to use autograph to compile python and eager style code
for efficient graph-mode execution.
Returns:
The decorated function will return the unbatched computation output Tensors.
"""
def decorator(fn): # pylint: disable=missing-docstring
def decorated(*args): # pylint: disable=missing-docstring
@function.defun(autograph=autograph)
def computation(*computation_args):
return fn(*computation_args)
computation = computation.get_concrete_function(
*[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
for i, x in enumerate(args)])
with ops.name_scope("batch") as name:
for a in args:
if not isinstance(a, ops.Tensor):
raise ValueError("All arguments to functions decorated with "
"`batch_function` are supposed to be Tensors; "
"found %s" % repr(a))
return gen_batch_ops.batch_function(
num_batch_threads=num_batch_threads,
max_batch_size=max_batch_size,
batch_timeout_micros=batch_timeout_micros,
allowed_batch_sizes=allowed_batch_sizes,
max_enqueued_batches=max_enqueued_batches,
shared_name=name,
f=computation,
in_tensors=list(args),
captured_tensors=computation.captured_inputs,
Tout=[o.dtype for o in computation.outputs])
return decorated
return decorator
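A minimal end-to-end sketch of the new core decorator in graph mode, assuming two concurrent `session.run` calls share one batch; this mirrors the decorator tests added below, and the names `plus_one`, `inp`, and `sess` are illustrative:
```python
import threading
import tensorflow as tf  # assumes TF 1.x-style graph/session APIs

@tf.nondifferentiable_batch_function(num_batch_threads=1, max_batch_size=2,
                                     batch_timeout_micros=100000)
def plus_one(t):
  return t + 1

inp = tf.placeholder(dtype=tf.int32, shape=[1])
result = plus_one(inp)

with tf.Session() as sess:
  thread_out = []
  worker = threading.Thread(
      target=lambda: thread_out.extend(sess.run([result], feed_dict={inp: [1]})))
  worker.start()
  main_out = sess.run([result], feed_dict={inp: [2]})  # batched with the worker
  worker.join()
  # Each caller gets back only its own unbatched slice:
  # thread_out[0] == [2] and main_out[0] == [3]
```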

View File

@@ -0,0 +1,421 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the currently experimental in-graph batch ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import time
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import test_util
from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import batch_ops
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
def delayed_plus1(x):
"""Sleeps for 100ms then returns x+1."""
time.sleep(0.1)
return x + 1
@test_util.run_all_in_graph_and_eager_modes
class BatchOpsTest(test.TestCase):
"""Tests for batch_ops.{un,}batch."""
# These tests only run in non-eager (graph) mode, since batching as a
# functionality in an eager context is still TBD.
def testBasicBatch(self):
"""Tests that a single batched tensor executes together and only once."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=36000000, grad_timeout_micros=0,
batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched, index], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched, index], feed_dict={inp: [2]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0][0]
index_t = thread_results[1]
empty_b = main_results[0][0]
empty_m = main_results[1]
else:
batch_t = main_results[0][0]
index_t = main_results[1]
empty_b = thread_results[0][0]
empty_m = thread_results[1]
# Check that both the inputs made it out exactly once.
self.assertAllEqual(sorted(batch_t), (1, 2))
# Check that we get 2 rows in the index tensor.
self.assertEqual(len(index_t), 2)
# Check that the other ones are empty.
self.assertEqual(len(empty_b), 0)
self.assertEqual(len(empty_m), 0)
def testBatchWithPadding(self):
"""Test that batching with padding up to an allowed batch size works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[5, 10],
grad_timeout_micros=0, batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched, index], feed_dict={inp: [1, 3]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched, index], feed_dict={inp: [2, 4]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0][0]
else:
batch_t = main_results[0][0]
# Check that the batch tensor incorporates the padding.
self.assertEqual(len(batch_t), 5)
def testMultipleBatch(self):
"""Tests that multiple batched tensors execute together."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
[inp0, inp1],
num_batch_threads=1,
max_batch_size=2,
batch_timeout_micros=36000000,
grad_timeout_micros=0,
batching_queue="")
thread_results = []
def worker():
thread_results.extend(
sess.run([batched], feed_dict={inp0: [1],
inp1: [2]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]})
worker_thread.join()
# At this point either the thread or the main did the batch and the other
# should have empty results.
if list(thread_results[0][0]):
batch_t = thread_results[0]
empty_t = main_results[0]
else:
batch_t = main_results[0]
empty_t = thread_results[0]
# Assert that the tensors were batched together.
self.assertAllEqual(sorted(batch_t[0]), [1, 2])
self.assertAllEqual(sorted(batch_t[1]), [2, 3])
self.assertAllEqual(empty_t[0], [])
self.assertAllEqual(empty_t[1], [])
def testIllegalBatchDifferentDim0Sizes(self):
"""Tests illegally feeding tensors with different dim0 sizes."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp0, inp1], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
with self.assertRaises(Exception) as raised:
_ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})
self.assertGreater(
raised.exception.message.find("must have equal 0th-dimension size"),
0)
def testBasicUnbatch(self):
"""Tests that batch and unbatch work together."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[3, 10],
grad_timeout_micros=0, batching_queue="")
computation = batched[0] + 1
result = batch_ops.unbatch(computation, index, id_t,
timeout_micros=1000000, shared_name="unbatch")
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
self.assertTrue(in_t.shape is not None)
return in_t + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchDecoratedWithCapturedInput(self):
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return in_t + captured_inp0 - captured_inp1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
return in_t + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = gen_batch_ops.batch_function(
[inp],
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000,
Tout=[dtypes.int32],
f=computation,
captured_tensors=computation.captured_inputs)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithCapturedInput(self):
"""Tests that batch_function op works with captured input."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32)
def computation(inp):
return inp + captured_inp0 - captured_inp1
result = gen_batch_ops.batch_function(
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[3, 10],
batching_queue="",
f=computation,
in_tensors=[inp],
captured_tensors=computation.captured_inputs,
Tout=[o.type for o in computation.definition.signature.output_arg])
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testBatchFunctionOpWithInputError(self):
"""Tests that batch_function op works with error in the inputs."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
def computation(in0, in1):
return in0 + in1
result = gen_batch_ops.batch_function(
[inp], # computation actually expects 2 inputs.
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000, # 100ms
batching_queue="",
f=computation,
captured_tensors=computation.captured_inputs,
Tout=[o.type for o in computation.definition.signature.output_arg])
with self.assertRaisesRegexp(InvalidArgumentError,
".*2 arguments.*but 1.*"):
sess.run([result], feed_dict={inp: [2]})
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return array_ops.reshape(in_t, [-1]) + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [[2]]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
if context.executing_eagerly():
return
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=36000000, grad_timeout_micros=0,
batching_queue="")
computation = batched[0] + 1
timeout_micros = 10
result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
shared_name="shared_unbatch")
# Set up a parallel pipeline that delays the computation, but uses the
# same unbatch resource object as the non-delayed pipeline.
computation_delayed = script_ops.py_func(delayed_plus1,
[batched[0]],
dtypes.int32)
result_delayed = batch_ops.unbatch(computation_delayed,
index,
id_t,
timeout_micros,
shared_name="shared_unbatch")
thread_results = []
def worker():
# A first call using the non-delayed pipeline. The batcher will send an
# empty tensor along the non-delayed pipeline.
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
time.sleep(0.1) # Ensure the thread's call starts first.
# A second call using the delayed pipeline. The batcher will send the
# batched tensor along the delayed pipeline, thus delaying the arrival of
# the batched tensor at the unbatch op, relative to the empty tensor.
#
# TODO(olston, apassos): Avoid relying on the order in which the batch op
# emits the empty tensor versus the batched one.
_ = sess.run([result_delayed], feed_dict={inp: [2]})
worker_thread.join()
# The thread's call should hit the timeout, and thus get 0 results.
self.assertEqual(len(thread_results), 0)
if __name__ == "__main__":
test.main()

View File

@@ -54,6 +54,7 @@ from tensorflow.python.ops.control_flow_ops import tuple # pylint: disable=rede
# pylint: enable=redefined-builtin
from tensorflow.python.eager import wrap_function
from tensorflow.python.ops.control_flow_ops import while_loop
from tensorflow.python.ops.batch_ops import *
from tensorflow.python.ops.critical_section_ops import *
from tensorflow.python.ops.data_flow_ops import *
from tensorflow.python.ops.functional_ops import *

View File

@@ -1672,6 +1672,10 @@ tf_module {
name: "no_regularizer"
argspec: "args=[\'_\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "nondifferentiable_batch_function"
argspec: "args=[\'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'allowed_batch_sizes\', \'max_enqueued_batches\', \'autograph\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\'], "
}
member_method {
name: "norm"
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "

View File

@@ -276,6 +276,10 @@ tf_module {
name: "BarrierTakeMany"
argspec: "args=[\'handle\', \'num_elements\', \'component_types\', \'allow_small_batch\', \'wait_for_incomplete\', \'timeout_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'-1\', \'None\'], "
}
member_method {
name: "Batch"
argspec: "args=[\'in_tensors\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'grad_timeout_micros\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
}
member_method {
name: "BatchCholesky"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -304,6 +308,10 @@ tf_module {
name: "BatchFFT3D"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BatchFunction"
argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
}
member_method {
name: "BatchIFFT"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -4228,6 +4236,14 @@ tf_module {
name: "TruncatedNormal"
argspec: "args=[\'shape\', \'dtype\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
}
member_method {
name: "Unbatch"
argspec: "args=[\'batched_tensor\', \'batch_index\', \'id\', \'timeout_micros\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
}
member_method {
name: "UnbatchGrad"
argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
}
member_method {
name: "UnicodeDecode"
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "

View File

@@ -780,6 +780,10 @@ tf_module {
name: "no_regularizer"
argspec: "args=[\'_\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "nondifferentiable_batch_function"
argspec: "args=[\'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'allowed_batch_sizes\', \'max_enqueued_batches\', \'autograph\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\'], "
}
member_method {
name: "norm"
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\'], "

View File

@@ -276,6 +276,10 @@ tf_module {
name: "BarrierTakeMany"
argspec: "args=[\'handle\', \'num_elements\', \'component_types\', \'allow_small_batch\', \'wait_for_incomplete\', \'timeout_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'-1\', \'None\'], "
}
member_method {
name: "Batch"
argspec: "args=[\'in_tensors\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'grad_timeout_micros\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
}
member_method {
name: "BatchCholesky"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -304,6 +308,10 @@ tf_module {
name: "BatchFFT3D"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BatchFunction"
argspec: "args=[\'in_tensors\', \'captured_tensors\', \'f\', \'num_batch_threads\', \'max_batch_size\', \'batch_timeout_micros\', \'Tout\', \'max_enqueued_batches\', \'allowed_batch_sizes\', \'container\', \'shared_name\', \'batching_queue\', \'name\'], varargs=None, keywords=None, defaults=[\'10\', \'[]\', \'\', \'\', \'\', \'None\'], "
}
member_method {
name: "BatchIFFT"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -4228,6 +4236,14 @@ tf_module {
name: "TruncatedNormal"
argspec: "args=[\'shape\', \'dtype\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
}
member_method {
name: "Unbatch"
argspec: "args=[\'batched_tensor\', \'batch_index\', \'id\', \'timeout_micros\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
}
member_method {
name: "UnbatchGrad"
argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
}
member_method {
name: "UnicodeDecode"
argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "

View File

@@ -603,6 +603,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.nest.map_structure",
"tf.contrib.framework.nest.pack_sequence_as":
"tf.nest.pack_sequence_as",
"tf.contrib.batching.batch_function":
"tf.nondifferentiable_batch_function",
"tf.contrib.util.constant_value":
"tf.get_static_value",
"tf.contrib.saved_model.load_keras_model":