Fix the gradient of a function whose output has no dependency on its inputs.
Change: 120128592
A. Unique TensorFlower 2016-04-18 08:10:32 -08:00 committed by TensorFlower Gardener
parent 3c280f6fa0
commit 449ecb561f
2 changed files with 23 additions and 7 deletions


@@ -156,9 +156,8 @@ class SymbolicGradientBuilder {
   // add dy as an input of the gradient function.
   std::deque<Node*> ready_;
 
-  // The set of nodes at which to stop backprop.
-  // Maps from node.id -> index of 'x_node_outputs_'
-  std::unordered_map<int, int> stop_nodes_;
+  // The set of node ids at which to stop backprop.
+  std::unordered_set<int> stop_nodes_;
 
   // Initialize pending_ and ready_.
   void InitBackprop();
@@ -190,7 +189,7 @@ SymbolicGradientBuilder::SymbolicGradientBuilder(
   x_grad_node_outputs_->resize(x_node_outputs_.size());
   stop_nodes_.reserve(x_node_outputs_.size());
   for (int i = 0; i < x_node_outputs_.size(); ++i) {
-    stop_nodes_.insert(std::make_pair(x_node_outputs_[i].node->id(), i));
+    stop_nodes_.insert(x_node_outputs_[i].node->id());
   }
 }
@@ -319,11 +318,9 @@ Status SymbolicGradientBuilder::Compute() {
     auto iter = stop_nodes_.find(n->id());
     if (iter != stop_nodes_.end()) {
-      // Stop backprop and add gradient sum to 'x_grad_node_outputs_'.
+      // Stop backprop.
       // TODO(andydavis) Support stop nodes with more than one output.
       CHECK_EQ(1, num_y);
-      const int index = iter->second;
-      (*x_grad_node_outputs_)[index] = SumGradients(x_node_outputs_[index]);
       continue;
     }
@@ -362,6 +359,10 @@ Status SymbolicGradientBuilder::Compute() {
     }
   }
+
+  for (int i = 0; i < x_node_outputs_.size(); ++i) {
+    (*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]);
+  }
   return Status::OK();
 }
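Net effect of the three hunks above: stop nodes no longer write gradient sums into x_grad_node_outputs_ mid-backprop; a final pass calls SumGradients for every requested input, and an input that no gradient ever reached sums to zero instead of being left unset. Below is a minimal pure-Python sketch of that control flow. It models the C++ logic, not any TensorFlow API; all names (backprop, graph, seed_grads) are made up for illustration, and the identity pass-through gradient rule is a toy stand-in.

from collections import defaultdict, deque

def backprop(graph, outputs, xs, seed_grads):
  # Toy model of SymbolicGradientBuilder::Compute(). 'graph' maps a
  # node to the list of nodes it consumes; gradients are plain floats
  # and the "gradient function" just passes dy through unchanged.
  grads = defaultdict(list)
  for out, g in zip(outputs, seed_grads):
    grads[out].append(g)
  stop_nodes = set(xs)          # was a node->index map; now just ids
  ready = deque(outputs)
  while ready:
    n = ready.popleft()
    if n in stop_nodes:
      continue                  # stop backprop; sum in the final pass
    dy = sum(grads[n])
    for inp in graph.get(n, []):
      grads[inp].append(dy)     # toy pass-through gradient rule
      ready.append(inp)
  # Final pass: an x that never received a gradient sums to zero,
  # mirroring SumGradients over an empty gradient list.
  return [sum(grads[x]) if grads[x] else 0.0 for x in xs]

# z = Foo(x, y) where the body ignores y: y never feeds z.
print(backprop({"z": ["x"]}, ["z"], ["x", "y"], [1.0]))  # [1.0, 0.0]

Under the old map-based bookkeeping, the gradient slot for y was only filled when backprop actually visited y, so a function output with no dependency on y left it unset.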


@@ -232,6 +232,21 @@ class FunctionTest(tf.test.TestCase):
       self.assertEquals(x.get_shape(), dx.get_shape())
       self.assertEquals(y.get_shape(), dy.get_shape())
 
+  def testZNoDepOnY(self):
+    with tf.Graph().as_default():
+      # z = Foo(x, y). z does not depend on y.
+      @function.Defun(tf.float32, tf.float32)
+      def Foo(x, y):
+        return x * 2
+      x = tf.constant(1.0)
+      y = tf.constant(2.0)
+      z = Foo(x, y)
+      dx, dy = tf.gradients([z], [x, y])
+      with tf.Session() as sess:
+        dx_val, dy_val = sess.run([dx, dy])
+      self.assertEquals([2.0], dx_val)
+      self.assertEquals([0.0], dy_val)
+
   def testDefineFunctionNoArgs(self):
 
     def AConstant():
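An aside on why the test goes through Defun at all (TF 1.x semantics as I understand them, not confirmed by this commit): without @function.Defun the unused tensor is simply disconnected from z in the graph, and tf.gradients returns None for it rather than materializing a zero. Inside a Defun call, y is a real input of the function op, so the symbolic gradient must produce a value for it, and this commit makes that value a proper zero. A sketch of the contrast:

import tensorflow as tf

with tf.Graph().as_default():
  x = tf.constant(1.0)
  y = tf.constant(2.0)
  z = x * 2  # no Defun: y does not appear anywhere under z
  dx, dy = tf.gradients([z], [x, y])
  print(dy)  # None: z is unconnected to y, so no gradient is built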