[C++ Gradients] Remove superfluous Identity nodes from AddGradientFunction and just Ref the incoming gradient.

PiperOrigin-RevId: 331808120
Change-Id: I18b92395aad88931965e3bbeb451b286fb9ea1f2
This commit is contained in:
Saurabh Saxena 2020-09-15 11:02:18 -07:00 committed by TensorFlower Gardener
parent 6f0dd6eac4
commit 1009006e3e
2 changed files with 5 additions and 17 deletions

View File

@ -31,14 +31,11 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)

View File

@ -22,10 +22,8 @@ limitations under the License.
using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::ZerosLike;
namespace tensorflow {
namespace gradients {
@ -36,21 +34,14 @@ class AddGradientFunction : public GradientFunction {
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.
std::string name = "Identity_A";
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
name.c_str()));
(*grad_outputs)[0] = identity_outputs[0];
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[1] = grad_inputs[0];
name = "Identity_B";
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
name.c_str()));
(*grad_outputs)[1] = identity_outputs[0];
(*grad_outputs)[0]->Ref();
(*grad_outputs)[1]->Ref();
return Status::OK();
}
~AddGradientFunction() override {}