From 1009006e3e792485690c48ccd31fbe29f9b95a59 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Tue, 15 Sep 2020 11:02:18 -0700 Subject: [PATCH] [C++ Gradients] Remove superfluous Identity nodes from AddGradientFunction and just `Ref` the incoming gradient. PiperOrigin-RevId: 331808120 Change-Id: I18b92395aad88931965e3bbeb451b286fb9ea1f2 --- tensorflow/c/experimental/gradients/BUILD | 3 --- .../c/experimental/gradients/math_grad.cc | 19 +++++-------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 53423edd0a7..5ae3cf501ac 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -31,14 +31,11 @@ cc_library( "//tensorflow:internal", ], deps = [ - "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:gradients_internal", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/core/lib/llvm_rtti", ], ) diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc index 3537b30c597..c2aa9caf814 100644 --- a/tensorflow/c/experimental/gradients/math_grad.cc +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -22,10 +22,8 @@ limitations under the License. using std::vector; using tensorflow::ops::Conj; -using tensorflow::ops::Identity; using tensorflow::ops::MatMul; using tensorflow::ops::Mul; -using tensorflow::ops::ZerosLike; namespace tensorflow { namespace gradients { @@ -36,21 +34,14 @@ class AddGradientFunction : public GradientFunction { Status Compute(Context* ctx, const IncomingGradients& grad_inputs, vector* grad_outputs) override { grad_outputs->resize(2); - vector identity_outputs(1); - // TODO(b/145674566): Handle name unification in tracing code. // TODO(b/161805092): Support broadcasting. - std::string name = "Identity_A"; - TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), - name.c_str())); - (*grad_outputs)[0] = identity_outputs[0]; + DCHECK(grad_inputs[0]); + (*grad_outputs)[0] = grad_inputs[0]; + (*grad_outputs)[1] = grad_inputs[0]; - name = "Identity_B"; - TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), - name.c_str())); - (*grad_outputs)[1] = identity_outputs[0]; + (*grad_outputs)[0]->Ref(); + (*grad_outputs)[1]->Ref(); return Status::OK(); } ~AddGradientFunction() override {}