Update signature of VSpace::BuildOnesLike.
Remove obsolete comment. PiperOrigin-RevId: 334637944 Change-Id: Ia376cbf5ba8f4d2ab1f8425a8dc6fd3c0c273a70
This commit is contained in:
parent
221282c169
commit
a0dca5c683
@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction(
|
|||||||
&ctx, incoming_gradients, result);
|
&ctx, incoming_gradients, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TapeVSpace::BuildOnesLike(TapeTensor t,
|
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
|
||||||
AbstractTensorHandle** result) const {
|
AbstractTensorHandle** result) const {
|
||||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||||
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
||||||
|
@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
|
|||||||
// allow us to trace the data dependencies between operations and hence compute
|
// allow us to trace the data dependencies between operations and hence compute
|
||||||
// gradients.
|
// gradients.
|
||||||
//
|
//
|
||||||
// This also implements `OnesLike` to create the default
|
|
||||||
// incoming gradients for tensors which do not already have an incoming
|
|
||||||
// gradient.
|
|
||||||
//
|
|
||||||
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
|
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
|
||||||
// of default zeros grads is handled by the `DefaultGradientFunction` registered
|
// of default zeros grads is handled by the `DefaultGradientFunction` registered
|
||||||
// for each op.
|
// for each op.
|
||||||
@ -233,7 +229,7 @@ class TapeVSpace
|
|||||||
std::vector<AbstractTensorHandle*>* result) const override;
|
std::vector<AbstractTensorHandle*>* result) const override;
|
||||||
|
|
||||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||||
Status BuildOnesLike(TapeTensor t,
|
Status BuildOnesLike(const TapeTensor& t,
|
||||||
AbstractTensorHandle** result) const override;
|
AbstractTensorHandle** result) const override;
|
||||||
|
|
||||||
// Looks up the ID of a Gradient.
|
// Looks up the ID of a Gradient.
|
||||||
|
@ -100,7 +100,8 @@ class VSpace {
|
|||||||
std::vector<Gradient*>* result) const = 0;
|
std::vector<Gradient*>* result) const = 0;
|
||||||
|
|
||||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||||
virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0;
|
virtual Status BuildOnesLike(const TapeTensor& t,
|
||||||
|
Gradient** result) const = 0;
|
||||||
|
|
||||||
// Looks up the ID of a Gradient.
|
// Looks up the ID of a Gradient.
|
||||||
virtual int64 TensorId(Gradient* tensor) const = 0;
|
virtual int64 TensorId(Gradient* tensor) const = 0;
|
||||||
|
@ -1283,7 +1283,8 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||||
Status BuildOnesLike(PyTapeTensor t, PyObject** result) const override {
|
Status BuildOnesLike(const PyTapeTensor& t,
|
||||||
|
PyObject** result) const override {
|
||||||
*result = t.OnesLike();
|
*result = t.OnesLike();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user