[C++ Gradients] Remove superfluous Identity nodes from AddGradientFunction and just Ref
the incoming gradient.
PiperOrigin-RevId: 331808120 Change-Id: I18b92395aad88931965e3bbeb451b286fb9ea1f2
This commit is contained in:
parent
6f0dd6eac4
commit
1009006e3e
@ -31,14 +31,11 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,10 +22,8 @@ limitations under the License.
|
||||
|
||||
using std::vector;
|
||||
using tensorflow::ops::Conj;
|
||||
using tensorflow::ops::Identity;
|
||||
using tensorflow::ops::MatMul;
|
||||
using tensorflow::ops::Mul;
|
||||
using tensorflow::ops::ZerosLike;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
@ -36,21 +34,14 @@ class AddGradientFunction : public GradientFunction {
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
grad_outputs->resize(2);
|
||||
vector<AbstractTensorHandle*> identity_outputs(1);
|
||||
// TODO(b/145674566): Handle name unification in tracing code.
|
||||
// TODO(b/161805092): Support broadcasting.
|
||||
|
||||
std::string name = "Identity_A";
|
||||
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(identity_outputs),
|
||||
name.c_str()));
|
||||
(*grad_outputs)[0] = identity_outputs[0];
|
||||
DCHECK(grad_inputs[0]);
|
||||
(*grad_outputs)[0] = grad_inputs[0];
|
||||
(*grad_outputs)[1] = grad_inputs[0];
|
||||
|
||||
name = "Identity_B";
|
||||
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(identity_outputs),
|
||||
name.c_str()));
|
||||
(*grad_outputs)[1] = identity_outputs[0];
|
||||
(*grad_outputs)[0]->Ref();
|
||||
(*grad_outputs)[1]->Ref();
|
||||
return Status::OK();
|
||||
}
|
||||
~AddGradientFunction() override {}
|
||||
|
Loading…
Reference in New Issue
Block a user