diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 97b7b4f0a6c..65ec0054933 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1439,6 +1439,15 @@ filegroup(
     visibility = ["//tensorflow:__subpackages__"],
 )
 
+cuda_py_test(
+    name = "accumulate_n_benchmark",
+    size = "large",
+    srcs = [
+        "ops/accumulate_n_benchmark.py",
+    ],
+    main = "ops/accumulate_n_benchmark.py",
+)
+
 cuda_py_test(
     name = "batch_norm_benchmark",
     srcs = [
diff --git a/tensorflow/python/ops/accumulate_n_benchmark.py b/tensorflow/python/ops/accumulate_n_benchmark.py
new file mode 100644
index 00000000000..b7947ef1030
--- /dev/null
+++ b/tensorflow/python/ops/accumulate_n_benchmark.py
@@ -0,0 +1,184 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for accumulate_n() in math_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import time
+
+import tensorflow as tf
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_state_ops
+
+
+class AccumulateNBenchmark(tf.test.Benchmark):
+  """Times tf.add_n against three temporary-variable accumulators.
+
+  Each accumulator assigns an initial value to a temporary variable, adds every
+  input to it with a locking assign_add, and destroys the variable once all of
+  the additions have run. The variants differ only in the initial value.
+  """
+
+  def _AccumulateNTemplate(self, inputs, init, shape, validate_shape):
+    """Builds the accumulation graph shared by all three variants.
+
+    Args:
+      inputs: Non-empty list of tensors to sum; inputs[0] supplies the dtype.
+      init: Value assigned to the temporary variable before accumulating.
+      shape: Shape the temporary variable is declared with.
+      validate_shape: Forwarded to tf.assign for the initial assignment.
+
+    Returns:
+      The output of _destroy_temporary_variable, created under control
+      dependencies on every assign_add op.
+    """
+    var = gen_state_ops._temporary_variable(
+        shape=shape, dtype=inputs[0].dtype.base_dtype)
+    ref = tf.assign(var, init, validate_shape=validate_shape)
+    update_ops = [tf.assign_add(ref, tensor, use_locking=True).op
+                  for tensor in inputs]
+    with tf.control_dependencies(update_ops):
+      return gen_state_ops._destroy_temporary_variable(
+          ref, var_name=var.op.name)
+
+  def _AccumulateNInitializedWithFirst(self, inputs):
+    """Sums `inputs` starting from zeros_like(inputs[0])."""
+    return self._AccumulateNTemplate(
+        inputs, init=tf.zeros_like(inputs[0]),
+        shape=inputs[0].get_shape(), validate_shape=True)
+
+  def _AccumulateNInitializedWithMerge(self, inputs):
+    """Sums `inputs` starting from zeros_like of a merge of the inputs.
+
+    The temporary variable is declared with a vector(0) shape and the initial
+    assignment is made with validate_shape=False.
+    NOTE(review): presumably this lets the variable take the shape of the
+    merged tensor at run time -- confirm against the tf.assign documentation.
+    """
+    return self._AccumulateNTemplate(
+        inputs,
+        init=tf.zeros_like(gen_control_flow_ops._merge(inputs)[0]),
+        shape=tensor_shape.vector(0),
+        validate_shape=False)
+
+  def _AccumulateNInitializedWithShape(self, inputs):
+    """Sums `inputs` starting from tf.zeros of inputs[0]'s static shape."""
+    return self._AccumulateNTemplate(
+        inputs,
+        init=tf.zeros(shape=inputs[0].get_shape(),
+                      dtype=inputs[0].dtype.base_dtype),
+        shape=inputs[0].get_shape(),
+        validate_shape=True)
+
+  def _GenerateUnorderedInputs(self, size, n):
+    """Returns `n` random_uniform tensors of shape [size] in shuffled order."""
+    inputs = [tf.random_uniform(shape=[size]) for _ in xrange(n)]
+    random.shuffle(inputs)
+    return inputs
+
+  def _GenerateReplicatedInputs(self, size, n):
+    """Returns a list repeating one random [size] tensor `n` times."""
+    return n * self._GenerateUnorderedInputs(size, 1)
+
+  def _GenerateOrderedInputs(self, size, n):
+    """Returns `n` tensors in which each one is computed from the previous one.
+
+    Every tensor after the first is tanh(1.0 + x), where x is dequeued from a
+    capacity-1 FIFOQueue that the previous tensor was enqueued into.
+    """
+    inputs = self._GenerateUnorderedInputs(size, 1)
+    queue = tf.FIFOQueue(capacity=1, dtypes=[inputs[0].dtype],
+                         shapes=[inputs[0].get_shape()])
+    for _ in xrange(n - 1):
+      op = queue.enqueue(inputs[-1])
+      # Create the dequeue under a control dependency on the enqueue.
+      with tf.control_dependencies([op]):
+        inputs.append(tf.tanh(1.0 + queue.dequeue()))
+    return inputs
+
+  def _GenerateReversedInputs(self, size, n):
+    """Returns the list from _GenerateOrderedInputs in reverse order."""
+    inputs = self._GenerateOrderedInputs(size, n)
+    inputs.reverse()
+    return inputs
+
+  def _SetupAndRunBenchmark(self, graph, inputs, repeats, format_args):
+    """Builds AddN and the three accumulators in `graph` and times each one.
+
+    Args:
+      graph: Graph that the ops are added to and that the session runs.
+      inputs: List of tensors to sum.
+      repeats: Number of timed runs per op, after 100 untimed warm-up runs.
+      format_args: Tuple of leading column values; the method tag and the total
+        duration of the timed runs are appended before self._template is used.
+    """
+    with graph.as_default():
+      add_n = tf.add_n(inputs)
+      acc_n_first = self._AccumulateNInitializedWithFirst(inputs)
+      acc_n_merge = self._AccumulateNInitializedWithMerge(inputs)
+      acc_n_shape = self._AccumulateNInitializedWithShape(inputs)
+
+    test_ops = (("AddN", add_n.op),
+                ("AccNFirst", acc_n_first.op),
+                ("AccNMerge", acc_n_merge.op),
+                ("AccNShape", acc_n_shape.op))
+
+    with tf.Session(graph=graph):
+      for tag, op in test_ops:
+        for _ in xrange(100):
+          op.run()  # Run for warm up.
+        start = time.time()
+        for _ in xrange(repeats):
+          op.run()
+        duration = time.time() - start
+        args = format_args + (tag, duration)
+        print(self._template.format(*args))
+
+  def _RunBenchmark(self, tag, input_fn, sizes, ninputs, repeats):
+    """Runs the comparison in a fresh graph for every (size, ninput) pair."""
+    for size in sizes:
+      for ninput in ninputs:
+        graph = tf.Graph()
+        with graph.as_default():
+          inputs = input_fn(size, ninput)
+
+        format_args = (tag, size, ninput, repeats)
+        self._SetupAndRunBenchmark(graph, inputs, repeats, format_args)
+
+  def benchmarkAccumulateN(self):
+    """Prints a timing table for every input generator and sweep setting."""
+    self._template = "{:<15}" * 6
+    args = {"sizes": (128, 128 ** 2),
+            "ninputs": (1, 10, 100, 300),
+            "repeats": 100}
+    benchmarks = (("Replicated", self._GenerateReplicatedInputs),
+                  ("Unordered", self._GenerateUnorderedInputs),
+                  ("Ordered", self._GenerateOrderedInputs),
+                  ("Reversed", self._GenerateReversedInputs))
+
+    print(self._template.format(
+        "", "Size", "#Inputs", "#Repeat", "Method", "Duration"))
+    print("-" * 90)
+    for benchmark in benchmarks:
+      self._RunBenchmark(*benchmark, **args)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 2c3cd2859ed..3417b517e1a 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -228,6 +228,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_control_flow_ops
 from tensorflow.python.ops import gen_data_flow_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import gen_sparse_ops
@@ -1505,46 +1506,39 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
     ValueError: If `inputs` don't all have same shape and dtype or the shape
     cannot be inferred.
   """
-  if tensor_dtype is None:
-    if not inputs or not isinstance(inputs, (list, tuple)):
-      raise ValueError("inputs must be a list of at least one Tensor with the "
-                       "same dtype and shape")
-    inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
-    if not all(isinstance(x, ops.Tensor) for x in inputs):
-      raise ValueError("inputs must be a list of at least one Tensor with the "
-                       "same dtype and shape")
-    if not all(x.dtype == inputs[0].dtype for x in inputs):
-      raise ValueError("inputs must be a list of at least one Tensor with the "
-                       "same dtype and shape")
-    tensor_dtype = inputs[0].dtype
+  if not inputs or not isinstance(inputs, (list, tuple)):
+    raise ValueError("inputs must be a list of at least one Tensor with the "
+                     "same dtype and shape")
+  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
+  if not all(isinstance(x, ops.Tensor) for x in inputs):
+    raise ValueError("inputs must be a list of at least one Tensor with the "
+                     "same dtype and shape")
+  if not all(x.dtype == inputs[0].dtype for x in inputs):
+    raise ValueError("inputs must be a list of at least one Tensor with the "
+                     "same dtype and shape")
   if shape is not None:
     shape = tensor_shape.as_shape(shape)
   else:
     shape = tensor_shape.unknown_shape()
-    for input_tensor in inputs:
-      if isinstance(input_tensor, ops.Tensor):
-        shape = shape.merge_with(input_tensor.get_shape())
-  if not shape.is_fully_defined():
-    # TODO(pbar): Make a version of assign_add that accepts an uninitialized
-    # lvalue, and takes its shape from that? This would allow accumulate_n to
-    # work in all situations that add_n currently works.
-    raise ValueError("Cannot infer the shape of the accumulator for "
-                     "accumulate_n. Pass the shape argument, or set the shape "
-                     "of at least one of the inputs.")
+  for input_tensor in inputs:
+    if isinstance(input_tensor, ops.Tensor):
+      shape = shape.merge_with(input_tensor.get_shape())
+  if len(inputs) == 1:
+    return inputs[0]
+  if tensor_dtype is None:
+    tensor_dtype = inputs[0].dtype
   with ops.op_scope(inputs, name, "AccumulateN") as name:
-    if len(inputs) == 1:
-      return inputs[0]
-    var = gen_state_ops._temporary_variable(shape=shape, dtype=tensor_dtype)
-    var_name = var.op.name
-    var = state_ops.assign(var, array_ops.zeros_like(inputs[0]))
-    update_ops = []
-    for input_tensor in inputs:
-      op = state_ops.assign_add(var, input_tensor, use_locking=True)
-      update_ops.append(op)
-    with ops.control_dependencies(update_ops):
-      return gen_state_ops._destroy_temporary_variable(var,
-                                                        var_name=var_name,
-                                                        name=name)
+    var = gen_state_ops._temporary_variable(shape=tensor_shape.vector(0),
+                                            dtype=tensor_dtype)
+    with ops.colocate_with(var):
+      zeros = array_ops.zeros_like(gen_control_flow_ops._merge(inputs)[0])
+      zeros.set_shape(shape)
+      ref = state_ops.assign(var, zeros, validate_shape=False)
+    update_ops = [state_ops.assign_add(ref, input_tensor, use_locking=True)
+                  for input_tensor in inputs]
+    with ops.control_dependencies(update_ops):
+      return gen_state_ops._destroy_temporary_variable(
+          ref, var_name=var.op.name, name=name)
 
 
 @ops.RegisterShape("BatchMatMul")
diff --git 
a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index a793f4968cc..27e4a037486 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -40,6 +40,7 @@ class ReduceTest(test_util.TensorFlowTestCase):
       y_tf = math_ops.reduce_sum(x).eval()
       self.assertEqual(y_tf, 21)
 
+
 class RoundTest(test_util.TensorFlowTestCase):
 
   def testRounding(self):
@@ -95,7 +96,9 @@ class SquaredDifferenceTest(test_util.TensorFlowTestCase):
       z_tf = math_ops.squared_difference(x, y).eval()
       self.assertAllClose(z, z_tf)
 
+
 class ScalarMulTest(test_util.TensorFlowTestCase):
+
   def testAcceptsRefs(self):
     var = variables.Variable(10)
     result = math_ops.scalar_mul(3, var)
@@ -126,5 +129,28 @@ class ScalarMulTest(test_util.TensorFlowTestCase):
       self.assertAllEqual(x.values.eval(), [[-6, -9], [-15, -21], [0, 3]])
       self.assertAllEqual(x.indices.eval(), [0, 2, 5])
 
+
+class AccumulateNTest(test_util.TensorFlowTestCase):
+  """Checks accumulate_n against sums computed from the NumPy inputs."""
+
+  def testFloat(self):
+    """Sums five distinct float tensors, then one tensor repeated five times."""
+    np.random.seed(12345)
+    x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)]
+    tf_x = ops.convert_n_to_tensor(x)
+    with self.test_session():
+      self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).eval())
+      self.assertAllClose(x[0] * 5, math_ops.accumulate_n([tf_x[0]] * 5).eval())
+
+  def testInt(self):
+    """Sums six distinct integer tensors, then one tensor repeated six times."""
+    np.random.seed(54321)
+    x = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(6)]
+    tf_x = ops.convert_n_to_tensor(x)
+    with self.test_session():
+      self.assertAllEqual(sum(x), math_ops.accumulate_n(tf_x).eval())
+      self.assertAllEqual(x[0] * 6, math_ops.accumulate_n([tf_x[0]] * 6).eval())
+
+
 if __name__ == "__main__":
   googletest.main()