Fix the gradient for functions when its output has no dependencies on
its inputs. Change: 120128592
This commit is contained in:
parent
3c280f6fa0
commit
449ecb561f
@ -156,9 +156,8 @@ class SymbolicGradientBuilder {
|
||||
// add dy as an input of the gradient function.
|
||||
std::deque<Node*> ready_;
|
||||
|
||||
// The set of nodes at which to stop backprop.
|
||||
// Maps from node.id -> index of 'x_node_outputs_'
|
||||
std::unordered_map<int, int> stop_nodes_;
|
||||
// The set of node ids at which to stop backprop.
|
||||
std::unordered_set<int> stop_nodes_;
|
||||
|
||||
// Initialize pending_ and ready_.
|
||||
void InitBackprop();
|
||||
@ -190,7 +189,7 @@ SymbolicGradientBuilder::SymbolicGradientBuilder(
|
||||
x_grad_node_outputs_->resize(x_node_outputs_.size());
|
||||
stop_nodes_.reserve(x_node_outputs_.size());
|
||||
for (int i = 0; i < x_node_outputs_.size(); ++i) {
|
||||
stop_nodes_.insert(std::make_pair(x_node_outputs_[i].node->id(), i));
|
||||
stop_nodes_.insert(x_node_outputs_[i].node->id());
|
||||
}
|
||||
}
|
||||
|
||||
@ -319,11 +318,9 @@ Status SymbolicGradientBuilder::Compute() {
|
||||
|
||||
auto iter = stop_nodes_.find(n->id());
|
||||
if (iter != stop_nodes_.end()) {
|
||||
// Stop backprop and add gradient sum to 'x_grad_node_outputs_'.
|
||||
// Stop backprop.
|
||||
// TODO(andydavis) Support stop nodes with more than one output.
|
||||
CHECK_EQ(1, num_y);
|
||||
const int index = iter->second;
|
||||
(*x_grad_node_outputs_)[index] = SumGradients(x_node_outputs_[index]);
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -362,6 +359,10 @@ Status SymbolicGradientBuilder::Compute() {
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < x_node_outputs_.size(); ++i) {
|
||||
(*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -232,6 +232,21 @@ class FunctionTest(tf.test.TestCase):
|
||||
self.assertEquals(x.get_shape(), dx.get_shape())
|
||||
self.assertEquals(y.get_shape(), dy.get_shape())
|
||||
|
||||
def testZNoDepOnY(self):
|
||||
with tf.Graph().as_default():
|
||||
# z = Foo(x, y). z doe
|
||||
@function.Defun(tf.float32, tf.float32)
|
||||
def Foo(x, y):
|
||||
return x * 2
|
||||
x = tf.constant(1.0)
|
||||
y = tf.constant(2.0)
|
||||
z = Foo(x, y)
|
||||
dx, dy = tf.gradients([z], [x, y])
|
||||
with tf.Session() as sess:
|
||||
dx_val, dy_val = sess.run([dx, dy])
|
||||
self.assertEquals([2.0], dx_val)
|
||||
self.assertEquals([0.0], dy_val)
|
||||
|
||||
def testDefineFunctionNoArgs(self):
|
||||
|
||||
def AConstant():
|
||||
|
Loading…
x
Reference in New Issue
Block a user