Update signature of VSpace::BuildOnesLike.
Remove obsolete comment. PiperOrigin-RevId: 334637944 Change-Id: Ia376cbf5ba8f4d2ab1f8425a8dc6fd3c0c273a70
This commit is contained in:
parent
221282c169
commit
a0dca5c683
@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction(
|
||||
&ctx, incoming_gradients, result);
|
||||
}
|
||||
|
||||
Status TapeVSpace::BuildOnesLike(TapeTensor t,
|
||||
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
|
||||
AbstractTensorHandle** result) const {
|
||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
||||
|
@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
|
||||
// allow us to trace the data dependencies between operations and hence compute
|
||||
// gradients.
|
||||
//
|
||||
// This also implements `OnesLike` to create the default
|
||||
// incoming gradients for tensors which do not already have an incoming
|
||||
// gradient.
|
||||
//
|
||||
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
|
||||
// of default zeros grads is handled by the `DefaultGradientFunction` registered
|
||||
// for each op.
|
||||
@ -233,7 +229,7 @@ class TapeVSpace
|
||||
std::vector<AbstractTensorHandle*>* result) const override;
|
||||
|
||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||
Status BuildOnesLike(TapeTensor t,
|
||||
Status BuildOnesLike(const TapeTensor& t,
|
||||
AbstractTensorHandle** result) const override;
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
|
@ -100,7 +100,8 @@ class VSpace {
|
||||
std::vector<Gradient*>* result) const = 0;
|
||||
|
||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||
virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0;
|
||||
virtual Status BuildOnesLike(const TapeTensor& t,
|
||||
Gradient** result) const = 0;
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
virtual int64 TensorId(Gradient* tensor) const = 0;
|
||||
|
@ -1283,7 +1283,8 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
|
||||
}
|
||||
|
||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||
Status BuildOnesLike(PyTapeTensor t, PyObject** result) const override {
|
||||
Status BuildOnesLike(const PyTapeTensor& t,
|
||||
PyObject** result) const override {
|
||||
*result = t.OnesLike();
|
||||
return Status::OK();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user