diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index fc9c19d127c..58ffcf247cf 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction( &ctx, incoming_gradients, result); } -Status TapeVSpace::BuildOnesLike(TapeTensor t, +Status TapeVSpace::BuildOnesLike(const TapeTensor& t, AbstractTensorHandle** result) const { AbstractOperationPtr op(ctx_->CreateOperation()); TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr)); diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index d240806c044..f7d80cbeb34 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t); // allow us to trace the data dependencies between operations and hence compute // gradients. // -// This also implements `OnesLike` to create the default -// incoming gradients for tensors which do not already have an incoming -// gradient. -// // `ZerosLike` is not expected to be called and returns a nullptr. The creation // of default zeros grads is handled by the `DefaultGradientFunction` registered // for each op. @@ -233,7 +229,7 @@ class TapeVSpace std::vector* result) const override; // Builds a tensor filled with ones with the same shape and dtype as `t`. - Status BuildOnesLike(TapeTensor t, + Status BuildOnesLike(const TapeTensor& t, AbstractTensorHandle** result) const override; // Looks up the ID of a Gradient. diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 80db01c7924..efab4dfbeb2 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -100,7 +100,8 @@ class VSpace { std::vector* result) const = 0; // Builds a tensor filled with ones with the same shape and dtype as `t`. - virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0; + virtual Status BuildOnesLike(const TapeTensor& t, + Gradient** result) const = 0; // Looks up the ID of a Gradient. virtual int64 TensorId(Gradient* tensor) const = 0; diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 1ab72a5125b..128fb09d114 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1283,7 +1283,8 @@ class PyVSpace : public tensorflow::eager::VSpace