diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 53423edd0a7..5ae3cf501ac 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -31,14 +31,11 @@ cc_library( "//tensorflow:internal", ], deps = [ - "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:gradients_internal", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/core/lib/llvm_rtti", ], ) diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc index 3537b30c597..c2aa9caf814 100644 --- a/tensorflow/c/experimental/gradients/math_grad.cc +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -22,10 +22,8 @@ limitations under the License. using std::vector; using tensorflow::ops::Conj; -using tensorflow::ops::Identity; using tensorflow::ops::MatMul; using tensorflow::ops::Mul; -using tensorflow::ops::ZerosLike; namespace tensorflow { namespace gradients { @@ -36,21 +34,14 @@ class AddGradientFunction : public GradientFunction { Status Compute(Context* ctx, const IncomingGradients& grad_inputs, vector<AbstractTensorHandle*>* grad_outputs) override { grad_outputs->resize(2); - vector<AbstractTensorHandle*> identity_outputs(1); - // TODO(b/145674566): Handle name unification in tracing code. // TODO(b/161805092): Support broadcasting. - std::string name = "Identity_A"; - TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), - name.c_str())); - (*grad_outputs)[0] = identity_outputs[0]; + DCHECK(grad_inputs[0]); + (*grad_outputs)[0] = grad_inputs[0]; + (*grad_outputs)[1] = grad_inputs[0]; - name = "Identity_B"; - TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), - name.c_str())); - (*grad_outputs)[1] = identity_outputs[0]; + (*grad_outputs)[0]->Ref(); + (*grad_outputs)[1]->Ref(); return Status::OK(); } ~AddGradientFunction() override {}