Change accumulate_n() to no longer require the output shape to be known in advance, and to no longer block until the first input is available.

Change: 129657548
A. Unique TensorFlower 2016-08-08 11:13:19 -08:00 committed by TensorFlower Gardener
parent 8f6a17f052
commit 8e58388189
4 changed files with 200 additions and 35 deletions
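
A minimal usage sketch (not part of the commit; it assumes the graph-mode tf.placeholder / tf.Session API of this TensorFlow release) of what the relaxed shape requirement enables. Before this change, the same graph raised a ValueError asking for the `shape` argument:

import numpy as np
import tensorflow as tf

# Inputs whose static shapes are completely unknown at graph-construction time.
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
total = tf.accumulate_n([a, b])  # no `shape=` argument needed any more

with tf.Session() as sess:
  x = np.ones((2, 3), dtype=np.float32)
  # The accumulator picks up its shape at run time; prints a 2x3 array of 3.0.
  print(sess.run(total, feed_dict={a: x, b: 2 * x}))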


@@ -1439,6 +1439,15 @@ filegroup(
    visibility = ["//tensorflow:__subpackages__"],
)

cuda_py_test(
    name = "accumulate_n_benchmark",
    size = "large",
    srcs = [
        "ops/accumulate_n_benchmark.py",
    ],
    main = "ops/accumulate_n_benchmark.py",
)

cuda_py_test(
    name = "batch_norm_benchmark",
    srcs = [


@@ -0,0 +1,137 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for accumulate_n() in math_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import time
import tensorflow as tf
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_state_ops
class AccumulateNBenchmark(tf.test.Benchmark):

  def _AccumulateNTemplate(self, inputs, init, shape, validate_shape):
    var = gen_state_ops._temporary_variable(
        shape=shape, dtype=inputs[0].dtype.base_dtype)
    ref = tf.assign(var, init, validate_shape=validate_shape)
    update_ops = [tf.assign_add(ref, tensor, use_locking=True).op
                  for tensor in inputs]
    with tf.control_dependencies(update_ops):
      return gen_state_ops._destroy_temporary_variable(
          ref, var_name=var.op.name)

  def _AccumulateNInitializedWithFirst(self, inputs):
    return self._AccumulateNTemplate(
        inputs, init=tf.zeros_like(inputs[0]),
        shape=inputs[0].get_shape(), validate_shape=True)

  def _AccumulateNInitializedWithMerge(self, inputs):
    return self._AccumulateNTemplate(
        inputs,
        init=tf.zeros_like(gen_control_flow_ops._merge(inputs)[0]),
        shape=tensor_shape.vector(0),
        validate_shape=False)

  def _AccumulateNInitializedWithShape(self, inputs):
    return self._AccumulateNTemplate(
        inputs,
        init=tf.zeros(shape=inputs[0].get_shape(),
                      dtype=inputs[0].dtype.base_dtype),
        shape=inputs[0].get_shape(),
        validate_shape=True)

  def _GenerateUnorderedInputs(self, size, n):
    inputs = [tf.random_uniform(shape=[size]) for _ in xrange(n)]
    random.shuffle(inputs)
    return inputs

  def _GenerateReplicatedInputs(self, size, n):
    return n * self._GenerateUnorderedInputs(size, 1)

  def _GenerateOrderedInputs(self, size, n):
    inputs = self._GenerateUnorderedInputs(size, 1)
    queue = tf.FIFOQueue(capacity=1, dtypes=[inputs[0].dtype],
                         shapes=[inputs[0].get_shape()])
    for _ in xrange(n - 1):
      op = queue.enqueue(inputs[-1])
      with tf.control_dependencies([op]):
        inputs.append(tf.tanh(1.0 + queue.dequeue()))
    return inputs

  def _GenerateReversedInputs(self, size, n):
    inputs = self._GenerateOrderedInputs(size, n)
    inputs.reverse()
    return inputs

  def _SetupAndRunBenchmark(self, graph, inputs, repeats, format_args):
    with graph.as_default():
      add_n = tf.add_n(inputs)
      acc_n_first = self._AccumulateNInitializedWithFirst(inputs)
      acc_n_merge = self._AccumulateNInitializedWithMerge(inputs)
      acc_n_shape = self._AccumulateNInitializedWithShape(inputs)

      test_ops = (("AddN", add_n.op),
                  ("AccNFirst", acc_n_first.op),
                  ("AccNMerge", acc_n_merge.op),
                  ("AccNShape", acc_n_shape.op))

    with tf.Session(graph=graph):
      for tag, op in test_ops:
        for _ in xrange(100):
          op.run()  # Run for warm up.

        start = time.time()
        for _ in xrange(repeats):
          op.run()
        duration = time.time() - start

        args = format_args + (tag, duration)
        print(self._template.format(*args))

  def _RunBenchmark(self, tag, input_fn, sizes, ninputs, repeats):
    for size in sizes:
      for ninput in ninputs:
        graph = tf.Graph()
        with graph.as_default():
          inputs = input_fn(size, ninput)

        format_args = (tag, size, ninput, repeats)
        self._SetupAndRunBenchmark(graph, inputs, repeats, format_args)

  def benchmarkAccumulateN(self):
    self._template = "{:<15}" * 6
    args = {"sizes": (128, 128 ** 2),
            "ninputs": (1, 10, 100, 300),
            "repeats": 100}
    benchmarks = (("Replicated", self._GenerateReplicatedInputs),
                  ("Unordered", self._GenerateUnorderedInputs),
                  ("Ordered", self._GenerateOrderedInputs),
                  ("Reversed", self._GenerateReversedInputs))

    print(self._template.format(
        "", "Size", "#Inputs", "#Repeat", "Method", "Duration"))
    print("-" * 90)
    for benchmark in benchmarks:
      self._RunBenchmark(*benchmark, **args)


if __name__ == '__main__':
  tf.test.main()
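
The benchmark above compares AddN against three temporary-variable accumulation strategies (initialized from the first input, from a Merge of the inputs, and from an explicitly known shape) over replicated, unordered, ordered, and reversed input streams. As a hedged note on running it (the target path is inferred from the build rule above and the --benchmarks flag from TensorFlow's usual benchmark conventions; neither is spelled out in the commit):

python accumulate_n_benchmark.py --benchmarks=AccumulateNBenchmark
# or via bazel:
bazel run //tensorflow/python:accumulate_n_benchmark -- --benchmarks=AccumulateNBenchmark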


@@ -228,6 +228,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_sparse_ops
@@ -1505,46 +1506,39 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
     ValueError: If `inputs` don't all have same shape and dtype or the shape
       cannot be inferred.
   """
-  if tensor_dtype is None:
-    if not inputs or not isinstance(inputs, (list, tuple)):
-      raise ValueError("inputs must be a list of at least one Tensor with the "
-                       "same dtype and shape")
-    inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
-    if not all(isinstance(x, ops.Tensor) for x in inputs):
-      raise ValueError("inputs must be a list of at least one Tensor with the "
-                       "same dtype and shape")
-    if not all(x.dtype == inputs[0].dtype for x in inputs):
-      raise ValueError("inputs must be a list of at least one Tensor with the "
-                       "same dtype and shape")
-    tensor_dtype = inputs[0].dtype
+  if not inputs or not isinstance(inputs, (list, tuple)):
+    raise ValueError("inputs must be a list of at least one Tensor with the "
+                     "same dtype and shape")
+  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
+  if not all(isinstance(x, ops.Tensor) for x in inputs):
+    raise ValueError("inputs must be a list of at least one Tensor with the "
+                     "same dtype and shape")
+  if not all(x.dtype == inputs[0].dtype for x in inputs):
+    raise ValueError("inputs must be a list of at least one Tensor with the "
+                     "same dtype and shape")
   if shape is not None:
     shape = tensor_shape.as_shape(shape)
   else:
     shape = tensor_shape.unknown_shape()
-    for input_tensor in inputs:
-      if isinstance(input_tensor, ops.Tensor):
-        shape = shape.merge_with(input_tensor.get_shape())
-    if not shape.is_fully_defined():
-      # TODO(pbar): Make a version of assign_add that accepts an uninitialized
-      # lvalue, and takes its shape from that? This would allow accumulate_n to
-      # work in all situations that add_n currently works.
-      raise ValueError("Cannot infer the shape of the accumulator for "
-                       "accumulate_n. Pass the shape argument, or set the shape "
-                       "of at least one of the inputs.")
+  for input_tensor in inputs:
+    if isinstance(input_tensor, ops.Tensor):
+      shape = shape.merge_with(input_tensor.get_shape())
+  if len(inputs) == 1:
+    return inputs[0]
+  if tensor_dtype is None:
+    tensor_dtype = inputs[0].dtype
   with ops.op_scope(inputs, name, "AccumulateN") as name:
-    if len(inputs) == 1:
-      return inputs[0]
-    var = gen_state_ops._temporary_variable(shape=shape, dtype=tensor_dtype)
-    var_name = var.op.name
-    var = state_ops.assign(var, array_ops.zeros_like(inputs[0]))
-    update_ops = []
-    for input_tensor in inputs:
-      op = state_ops.assign_add(var, input_tensor, use_locking=True)
-      update_ops.append(op)
-    with ops.control_dependencies(update_ops):
-      return gen_state_ops._destroy_temporary_variable(var,
-                                                       var_name=var_name,
-                                                       name=name)
+    var = gen_state_ops._temporary_variable(shape=tensor_shape.vector(0),
+                                            dtype=tensor_dtype)
+    with ops.colocate_with(var):
+      zeros = array_ops.zeros_like(gen_control_flow_ops._merge(inputs)[0])
+      zeros.set_shape(shape)
+      ref = state_ops.assign(var, zeros, validate_shape=False)
+      update_ops = [state_ops.assign_add(ref, input_tensor, use_locking=True)
+                    for input_tensor in inputs]
+      with ops.control_dependencies(update_ops):
+        return gen_state_ops._destroy_temporary_variable(
+            ref, var_name=var.op.name, name=name)
@ops.RegisterShape("BatchMatMul")
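
The rewritten body above creates the temporary accumulator with a dummy vector(0) shape and initializes it from zeros_like(Merge(inputs)[0]); since Merge forwards whichever input tensor becomes available first, the accumulator no longer has to wait for inputs[0] specifically, and assign(..., validate_shape=False) lets it adopt its shape at run time. For callers the result is unchanged: accumulate_n still returns the same elementwise sum as add_n, it just adds into a single temporary buffer instead of consuming all inputs at once. A small illustrative sketch (values invented, public API only):

import tensorflow as tf

xs = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0]), tf.constant([5.0, 6.0])]

with tf.Session() as sess:
  print(sess.run(tf.add_n(xs)))         # [ 9. 12.] -- AddN needs all inputs at once
  print(sess.run(tf.accumulate_n(xs)))  # [ 9. 12.] -- adds into one temporary
                                        #   variable as inputs become available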


@@ -40,6 +40,7 @@ class ReduceTest(test_util.TensorFlowTestCase):
      y_tf = math_ops.reduce_sum(x).eval()
      self.assertEqual(y_tf, 21)


class RoundTest(test_util.TensorFlowTestCase):

  def testRounding(self):
@@ -95,7 +96,9 @@ class SquaredDifferenceTest(test_util.TensorFlowTestCase):
      z_tf = math_ops.squared_difference(x, y).eval()
      self.assertAllClose(z, z_tf)


class ScalarMulTest(test_util.TensorFlowTestCase):

  def testAcceptsRefs(self):
    var = variables.Variable(10)
    result = math_ops.scalar_mul(3, var)
@@ -126,5 +129,27 @@ class ScalarMulTest(test_util.TensorFlowTestCase):
      self.assertAllEqual(x.values.eval(), [[-6, -9], [-15, -21], [0, 3]])
      self.assertAllEqual(x.indices.eval(), [0, 2, 5])


class AccumulateNTest(test_util.TensorFlowTestCase):

  def testFloat(self):
    np.random.seed(12345)
    x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)]
    tf_x = ops.convert_n_to_tensor(x)
    for u in tf_x:
      print("shape=%s" % u.get_shape())
    with self.test_session():
      self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).eval())
      self.assertAllClose(x[0] * 5, math_ops.accumulate_n([tf_x[0]] * 5).eval())

  def testInt(self):
    np.random.seed(54321)
    x = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(6)]
    tf_x = ops.convert_n_to_tensor(x)
    with self.test_session():
      self.assertAllEqual(sum(x), math_ops.accumulate_n(tf_x).eval())
      self.assertAllEqual(x[0] * 6, math_ops.accumulate_n([tf_x[0]] * 6).eval())


if __name__ == "__main__":
  googletest.main()