Fix the gradient for functions whose output has no dependency on their inputs. Change: 120128592
parent 3c280f6fa0
commit 449ecb561f
@@ -156,9 +156,8 @@ class SymbolicGradientBuilder {
   // add dy as an input of the gradient function.
   std::deque<Node*> ready_;
 
-  // The set of nodes at which to stop backprop.
-  // Maps from node.id -> index of 'x_node_outputs_'
-  std::unordered_map<int, int> stop_nodes_;
+  // The set of node ids at which to stop backprop.
+  std::unordered_set<int> stop_nodes_;
 
   // Initialize pending_ and ready_.
   void InitBackprop();
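Taken together with the constructor change in the next hunk, the id-to-index map becomes unnecessary: the stop-node check in Compute() only needs to know whether a node is one of the requested inputs, because the gradient outputs are now written by position in a final pass. A minimal, self-contained sketch of that bookkeeping (toy names, not the TensorFlow code):

    #include <iostream>
    #include <unordered_set>
    #include <vector>

    int main() {
      // Hypothetical node ids of the requested inputs,
      // i.e. x_node_outputs_[i].node->id().
      std::vector<int> x_node_ids = {7, 11, 42};

      // After this change only membership matters during backprop; the
      // output index is recovered later from the position in the final pass.
      std::unordered_set<int> stop_nodes(x_node_ids.begin(), x_node_ids.end());

      for (int id : {7, 99}) {
        std::cout << "node " << id << ": "
                  << (stop_nodes.count(id) ? "stop backprop" : "keep propagating")
                  << "\n";
      }
      return 0;
    }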
@@ -190,7 +189,7 @@ SymbolicGradientBuilder::SymbolicGradientBuilder(
   x_grad_node_outputs_->resize(x_node_outputs_.size());
   stop_nodes_.reserve(x_node_outputs_.size());
   for (int i = 0; i < x_node_outputs_.size(); ++i) {
-    stop_nodes_.insert(std::make_pair(x_node_outputs_[i].node->id(), i));
+    stop_nodes_.insert(x_node_outputs_[i].node->id());
   }
 }
 
@@ -319,11 +318,9 @@ Status SymbolicGradientBuilder::Compute() {
 
     auto iter = stop_nodes_.find(n->id());
     if (iter != stop_nodes_.end()) {
-      // Stop backprop and add gradient sum to 'x_grad_node_outputs_'.
+      // Stop backprop.
       // TODO(andydavis) Support stop nodes with more than one output.
       CHECK_EQ(1, num_y);
-      const int index = iter->second;
-      (*x_grad_node_outputs_)[index] = SumGradients(x_node_outputs_[index]);
       continue;
     }
 
@@ -362,6 +359,10 @@ Status SymbolicGradientBuilder::Compute() {
     }
   }
 
+  for (int i = 0; i < x_node_outputs_.size(); ++i) {
+    (*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]);
+  }
+
   return Status::OK();
 }
 
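The net effect of the last two hunks: backprop no longer writes x_grad_node_outputs_ only when it happens to reach a stop node; instead, a final pass calls SumGradients for every requested input, so an input the function's output never touches still gets a gradient (which sums to zero) rather than being skipped. A toy sketch of that pattern, with assumed names standing in for the accumulated gradient sums (not the TensorFlow implementation):

    #include <iostream>
    #include <unordered_map>
    #include <vector>

    // Stand-in for SumGradients: return the accumulated sum, or zero if
    // backprop never reached this node (e.g. the output does not depend on it).
    float SumGradientsFor(const std::unordered_map<int, float>& accumulated,
                          int node_id) {
      auto it = accumulated.find(node_id);
      return it == accumulated.end() ? 0.0f : it->second;
    }

    int main() {
      // Gradient sums accumulated during backprop, keyed by node id.
      // Node 11 is never reached because the output does not use it.
      std::unordered_map<int, float> accumulated = {{7, 2.0f}};

      std::vector<int> requested_ids = {7, 11};
      std::vector<float> grads(requested_ids.size());

      // Final pass: every requested input gets a value, defaulting to zero.
      for (size_t i = 0; i < requested_ids.size(); ++i) {
        grads[i] = SumGradientsFor(accumulated, requested_ids[i]);
      }

      std::cout << "dz/dx = " << grads[0] << ", dz/dy = " << grads[1] << "\n";
      return 0;
    }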
@@ -232,6 +232,21 @@ class FunctionTest(tf.test.TestCase):
       self.assertEquals(x.get_shape(), dx.get_shape())
       self.assertEquals(y.get_shape(), dy.get_shape())
 
+  def testZNoDepOnY(self):
+    with tf.Graph().as_default():
+      # z = Foo(x, y). z does not depend on y.
+      @function.Defun(tf.float32, tf.float32)
+      def Foo(x, y):
+        return x * 2
+      x = tf.constant(1.0)
+      y = tf.constant(2.0)
+      z = Foo(x, y)
+      dx, dy = tf.gradients([z], [x, y])
+      with tf.Session() as sess:
+        dx_val, dy_val = sess.run([dx, dy])
+        self.assertEquals([2.0], dx_val)
+        self.assertEquals([0.0], dy_val)
+
   def testDefineFunctionNoArgs(self):
 
     def AConstant():
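The new test exercises exactly this case: Foo ignores y, so tf.gradients([z], [x, y]) now yields dz/dx = 2.0 and dz/dy = 0.0 instead of leaving the entry for y unassigned.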