[C++ Gradients] Remove superfluous Identity nodes from AddGradientFunction and just Ref the incoming gradient.

PiperOrigin-RevId: 331808120
Change-Id: I18b92395aad88931965e3bbeb451b286fb9ea1f2
This commit is contained in:
Saurabh Saxena 2020-09-15 11:02:18 -07:00 committed by TensorFlower Gardener
parent 6f0dd6eac4
commit 1009006e3e
2 changed files with 5 additions and 17 deletions

View File

@ -31,14 +31,11 @@ cc_library(
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients_internal", "//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
], ],
) )

View File

@ -22,10 +22,8 @@ limitations under the License.
using std::vector; using std::vector;
using tensorflow::ops::Conj; using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::MatMul; using tensorflow::ops::MatMul;
using tensorflow::ops::Mul; using tensorflow::ops::Mul;
using tensorflow::ops::ZerosLike;
namespace tensorflow { namespace tensorflow {
namespace gradients { namespace gradients {
@ -36,21 +34,14 @@ class AddGradientFunction : public GradientFunction {
Status Compute(Context* ctx, const IncomingGradients& grad_inputs, Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override { vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2); grad_outputs->resize(2);
vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting. // TODO(b/161805092): Support broadcasting.
std::string name = "Identity_A"; DCHECK(grad_inputs[0]);
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, (*grad_outputs)[0] = grad_inputs[0];
absl::MakeSpan(identity_outputs), (*grad_outputs)[1] = grad_inputs[0];
name.c_str()));
(*grad_outputs)[0] = identity_outputs[0];
name = "Identity_B"; (*grad_outputs)[0]->Ref();
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, (*grad_outputs)[1]->Ref();
absl::MakeSpan(identity_outputs),
name.c_str()));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK(); return Status::OK();
} }
~AddGradientFunction() override {} ~AddGradientFunction() override {}