diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 26e3170ad8e..75291678177 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -503,6 +503,42 @@ TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) { EXPECT_EQ(grad_result[0].flat()(0), 17610.0f); } +TEST_F(GradientsTest, AddSymbolicGradientsTest) { + Scope scope = Scope::NewRootScope(); + for (int cnt = 0; cnt < 100; ++cnt) { + int N = 5 + rand() % 10; + // Construct forward graph. + OutputList inputs; + for (int i = 0; i < N; ++i) { + auto a = Const(scope, i, {1}); + inputs.push_back(a); + } + + auto pack = Stack(scope, inputs); + TF_ASSERT_OK(scope.status()); + + // Construct grad inputs. + OutputList output_grads; + Tensor ts(DT_INT32, {N, 1}); + auto v = ts.matrix(); + for (int i = 0; i < N; ++i) { + v(i, 0) = i; + } + auto dy = Const(scope, ts); + output_grads.push_back(dy); + // Call AddSymbolicGradients. + std::vector grad_outputs; + TF_ASSERT_OK(AddSymbolicGradients(scope, {pack.output}, inputs, + output_grads, &grad_outputs)); + ClientSession session((scope)); + std::vector in_grad; + TF_ASSERT_OK(session.Run(grad_outputs, &in_grad)); + for (int i = 0; i < N; ++i) { + test::ExpectTensorEqual(in_grad[i], test::AsTensor({i}, {1})); + } + } +} + // StopGradientSingleOutputMultiEdgeTest tests combinations of valid and // 'NoGradient' (induced by StopGradient op) returned along multiple edges from // a single nodes output.