Unreachable input gradients (#13071)
* Check if inputs are reachable from outputs in AddSymbolicGradients.
* Remove LOG.
* Edit following the PR comments.
* Wrap lines longer than 80 chars.
* Format comments in gradients_test.
* Eliminate m2 and rename m1 -> m, dm1 -> diff_m.
* Edit InvalidArgument string concatenation.
This commit is contained in:
parent
66eed36466
commit
a0bbeb10e2
@ -175,8 +175,14 @@ Status SymbolicGradientBuilder::Initialize() {
|
||||
"Must specify a gradient input for each output.");
|
||||
}
|
||||
std::vector<bool> reachable_nodes = GetReachableNodes();
|
||||
// TODO(theflofly) Check that inputs_ are reachable from
|
||||
// outputs_ using reachable_nodes
|
||||
for (const Output& input : inputs_) {
|
||||
if (!reachable_nodes[input.node()->id()]) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot compute the partial derivative for node '",
|
||||
input.node()->name(),
|
||||
"' as it's unreachable from the output node(s).");
|
||||
}
|
||||
}
|
||||
grad_outputs_->clear();
|
||||
grad_outputs_->resize(inputs_.size());
|
||||
// Populate `output_nodes_` from node ids in `outputs_`.
|
||||
|
@ -48,9 +48,9 @@ class GradientsTest : public ::testing::Test {
|
||||
Scope scope_test_;
|
||||
};
|
||||
|
||||
// EX.
|
||||
// Example:
|
||||
// ^ ^
|
||||
// dy| dx| // MatMul Gradient Graph
|
||||
// dy| dx| (MatMul Gradient Graph)
|
||||
// | |
|
||||
// MatMul_1 MatMul_2
|
||||
// ^ ^ ^ ^
|
||||
@ -61,7 +61,7 @@ class GradientsTest : public ::testing::Test {
|
||||
// | Const_3 |
|
||||
// | |
|
||||
// | ^ |
|
||||
// | z| | // MatMul Forward Graph
|
||||
// | z| | (MatMul Forward Graph)
|
||||
// | | |
|
||||
// | MatMul_0 |
|
||||
// | / \ |
|
||||
@ -373,24 +373,22 @@ TEST_F(GradientsTest, UnreachableEdgeGradOneOutput) {
|
||||
auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
|
||||
auto y_assign = Assign(scope_test_, y, y_const);
|
||||
|
||||
auto m1 = MatMul(scope_test_, x, y);
|
||||
auto m = MatMul(scope_test_, x, y);
|
||||
|
||||
auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
|
||||
auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
|
||||
auto z_assign = Assign(scope_test_, z, z_const);
|
||||
|
||||
auto m2 = MatMul(scope_test_, y, z);
|
||||
|
||||
auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
|
||||
auto diff_m = Const(scope_test_, {{0.5}, {0.5}});
|
||||
|
||||
std::vector<Output> grad_outputs;
|
||||
TF_ASSERT_OK(
|
||||
AddSymbolicGradients(scope_test_, {m1}, {y}, {dm1}, &grad_outputs));
|
||||
AddSymbolicGradients(scope_test_, {m}, {y}, {diff_m}, &grad_outputs));
|
||||
|
||||
std::vector<Tensor> outputs;
|
||||
test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
|
||||
{grad_outputs[0]}, &outputs);
|
||||
// dz/dy = xT * dm1
|
||||
// dz/dy = xT * diff_m
|
||||
test::ExpectTensorNear<double>(
|
||||
outputs[0], test::AsTensor<double>({2.5, 3.5, 4.5}, {3, 1}), 1e-5);
|
||||
}
|
||||
@ -424,13 +422,36 @@ TEST_F(GradientsTest, UnreachableEdgeGradTwoOutputs) {
|
||||
test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
|
||||
{grad_outputs[0]}, &outputs);
|
||||
|
||||
// the gradients from m1 and m2 will be summed to compute the gradient
|
||||
// w.r.t y
|
||||
// The gradients from m1 and m2 will be summed to compute the gradient
|
||||
// w.r.t y:
|
||||
// dz/dy = xT * dm1 + dm2 * zT
|
||||
test::ExpectTensorNear<double>(
|
||||
outputs[0], test::AsTensor<double>({17.5, 24.7, 26.8}, {3, 1}), 1e-5);
|
||||
}
|
||||
|
||||
// Verifies that AddSymbolicGradients reports an error instead of producing a
// gradient when a requested input cannot be reached from the given outputs.
// Here the gradient of m1 = MatMul(x, y) is requested with respect to z, but
// z only feeds m2 = MatMul(y, z), so no path connects z to m1.
TEST_F(GradientsTest, UnreachableInput) {
|
||||
auto x = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
|
||||
auto y = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
|
||||
// z is explicitly named "z" so the expected error message below can refer to
// it by name.
auto z = Const(scope_test_.WithOpName("z"), {{9.0, 10.0, 11.0}});
|
||||
|
||||
auto m1 = MatMul(scope_test_, x, y);
|
||||
auto m2 = MatMul(scope_test_, y, z);
|
||||
auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
|
||||
|
||||
// From m1, z is unreachable, so an error status should be returned.
|
||||
// m2 m1
|
||||
// | |
|
||||
// * *
|
||||
// / \ / \
|
||||
// z y x
|
||||
std::vector<Output> grad_outputs;
|
||||
// Request d(m1)/d(z) with dm1 as the incoming gradient; the status is
// captured rather than asserted OK because a failure is the expected result.
Status status = AddSymbolicGradients(scope_test_, {m1}, {z}, {dm1},
|
||||
&grad_outputs);
|
||||
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
|
||||
// The message must name the unreachable input node ('z').
EXPECT_EQ(status.error_message(), "Cannot compute the partial derivative"
|
||||
" for node 'z' as it's unreachable from the output node(s).");
|
||||
}
|
||||
|
||||
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
|
||||
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
|
||||
// a single node's output.
|
||||
|
Loading…
Reference in New Issue
Block a user