Update signature of VSpace::BuildOnesLike.

Remove obsolete comment.

PiperOrigin-RevId: 334637944
Change-Id: Ia376cbf5ba8f4d2ab1f8425a8dc6fd3c0c273a70
This commit is contained in:
Saurabh Saxena 2020-09-30 11:08:36 -07:00 committed by TensorFlower Gardener
parent 221282c169
commit a0dca5c683
4 changed files with 6 additions and 8 deletions

View File

@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction(
&ctx, incoming_gradients, result); &ctx, incoming_gradients, result);
} }
Status TapeVSpace::BuildOnesLike(TapeTensor t, Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const { AbstractTensorHandle** result) const {
AbstractOperationPtr op(ctx_->CreateOperation()); AbstractOperationPtr op(ctx_->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr)); TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));

View File

@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute // allow us to trace the data dependencies between operations and hence compute
// gradients. // gradients.
// //
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation // `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered // of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op. // for each op.
@ -233,7 +229,7 @@ class TapeVSpace
std::vector<AbstractTensorHandle*>* result) const override; std::vector<AbstractTensorHandle*>* result) const override;
// Builds a tensor filled with ones with the same shape and dtype as `t`. // Builds a tensor filled with ones with the same shape and dtype as `t`.
Status BuildOnesLike(TapeTensor t, Status BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const override; AbstractTensorHandle** result) const override;
// Looks up the ID of a Gradient. // Looks up the ID of a Gradient.

View File

@ -100,7 +100,8 @@ class VSpace {
std::vector<Gradient*>* result) const = 0; std::vector<Gradient*>* result) const = 0;
// Builds a tensor filled with ones with the same shape and dtype as `t`. // Builds a tensor filled with ones with the same shape and dtype as `t`.
virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0; virtual Status BuildOnesLike(const TapeTensor& t,
Gradient** result) const = 0;
// Looks up the ID of a Gradient. // Looks up the ID of a Gradient.
virtual int64 TensorId(Gradient* tensor) const = 0; virtual int64 TensorId(Gradient* tensor) const = 0;

View File

@ -1283,7 +1283,8 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
} }
// Builds a tensor filled with ones with the same shape and dtype as `t`. // Builds a tensor filled with ones with the same shape and dtype as `t`.
Status BuildOnesLike(PyTapeTensor t, PyObject** result) const override { Status BuildOnesLike(const PyTapeTensor& t,
PyObject** result) const override {
*result = t.OnesLike(); *result = t.OnesLike();
return Status::OK(); return Status::OK();
} }