From 449ecb561f6b480a6043d23160be00f35b524aa9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Apr 2016 08:10:32 -0800 Subject: [PATCH] Fix the gradient for functions when its output has no dependencies on its inputs. Change: 120128592 --- tensorflow/core/graph/gradients.cc | 15 ++++++++------- tensorflow/python/framework/function_test.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc index 1c902d29a02..4e5677414c5 100644 --- a/tensorflow/core/graph/gradients.cc +++ b/tensorflow/core/graph/gradients.cc @@ -156,9 +156,8 @@ class SymbolicGradientBuilder { // add dy as an input of the gradient function. std::deque ready_; - // The set of nodes at which to stop backprop. - // Maps from node.id -> index of 'x_node_outputs_' - std::unordered_map stop_nodes_; + // The set of node ids at which to stop backprop. + std::unordered_set stop_nodes_; // Initialize pending_ and ready_. void InitBackprop(); @@ -190,7 +189,7 @@ SymbolicGradientBuilder::SymbolicGradientBuilder( x_grad_node_outputs_->resize(x_node_outputs_.size()); stop_nodes_.reserve(x_node_outputs_.size()); for (int i = 0; i < x_node_outputs_.size(); ++i) { - stop_nodes_.insert(std::make_pair(x_node_outputs_[i].node->id(), i)); + stop_nodes_.insert(x_node_outputs_[i].node->id()); } } @@ -319,11 +318,9 @@ Status SymbolicGradientBuilder::Compute() { auto iter = stop_nodes_.find(n->id()); if (iter != stop_nodes_.end()) { - // Stop backprop and add gradient sum to 'x_grad_node_outputs_'. + // Stop backprop. // TODO(andydavis) Support stop nodes with more than one output. CHECK_EQ(1, num_y); - const int index = iter->second; - (*x_grad_node_outputs_)[index] = SumGradients(x_node_outputs_[index]); continue; } @@ -362,6 +359,10 @@ Status SymbolicGradientBuilder::Compute() { } } + for (int i = 0; i < x_node_outputs_.size(); ++i) { + (*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]); + } + return Status::OK(); } diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index f5f91118c26..f9e4ec258b6 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -232,6 +232,21 @@ class FunctionTest(tf.test.TestCase): self.assertEquals(x.get_shape(), dx.get_shape()) self.assertEquals(y.get_shape(), dy.get_shape()) + def testZNoDepOnY(self): + with tf.Graph().as_default(): + # z = Foo(x, y). z doe + @function.Defun(tf.float32, tf.float32) + def Foo(x, y): + return x * 2 + x = tf.constant(1.0) + y = tf.constant(2.0) + z = Foo(x, y) + dx, dy = tf.gradients([z], [x, y]) + with tf.Session() as sess: + dx_val, dy_val = sess.run([dx, dy]) + self.assertEquals([2.0], dx_val) + self.assertEquals([0.0], dy_val) + def testDefineFunctionNoArgs(self): def AConstant():