add unit test for bugfix

This commit is contained in:
sunchenggen 2020-03-20 20:54:12 +08:00
parent c2f7bee604
commit 86ece822c8

View File

@ -503,6 +503,42 @@ TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
}
TEST_F(GradientsTest, AddSymbolicGradientsTest) {
Scope scope = Scope::NewRootScope();
for (int cnt = 0; cnt < 100; ++cnt) {
int N = 5 + rand() % 10;
// Construct forward graph.
OutputList inputs;
for (int i = 0; i < N; ++i) {
auto a = Const(scope, i, {1});
inputs.push_back(a);
}
auto pack = Stack(scope, inputs);
TF_ASSERT_OK(scope.status());
// Construct grad inputs.
OutputList output_grads;
Tensor ts(DT_INT32, {N, 1});
auto v = ts.matrix<int32>();
for (int i = 0; i < N; ++i) {
v(i, 0) = i;
}
auto dy = Const(scope, ts);
output_grads.push_back(dy);
// Call AddSymbolicGradients.
std::vector<Output> grad_outputs;
TF_ASSERT_OK(AddSymbolicGradients(scope, {pack.output}, inputs,
output_grads, &grad_outputs));
ClientSession session((scope));
std::vector<Tensor> in_grad;
TF_ASSERT_OK(session.Run(grad_outputs, &in_grad));
for (int i = 0; i < N; ++i) {
test::ExpectTensorEqual<int>(in_grad[i], test::AsTensor<int>({i}, {1}));
}
}
}
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
// a single nodes output.