Adding a gradient for Sqrt
PiperOrigin-RevId: 336161771 Change-Id: I7e66b81772132d4e45b951c975df503ac5469075
This commit is contained in:
parent
c4f499b603
commit
4d04d4cf32
@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
|||||||
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
||||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||||
|
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes
|
||||||
|
// y = sqrt(inputs[0])
|
||||||
|
// return grad(y, {inputs[0]})
|
||||||
|
Status SqrtGradModel(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs,
|
||||||
|
const GradientRegistry& registry) {
|
||||||
|
TapeVSpace vspace(ctx);
|
||||||
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
|
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||||
|
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
|
||||||
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
|
||||||
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
|
source_tensors_that_are_targets;
|
||||||
|
|
||||||
|
std::vector<AbstractTensorHandle*> out_grads;
|
||||||
|
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||||
|
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
|
||||||
|
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||||
|
/*output_gradients=*/{}, &out_grads,
|
||||||
|
/*build_default_zeros_grads=*/false));
|
||||||
|
for (auto sqrt_output : sqrt_outputs) {
|
||||||
|
sqrt_output->Unref();
|
||||||
|
}
|
||||||
|
outputs[0] = out_grads[0];
|
||||||
|
delete tape;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Computes
|
// Computes
|
||||||
// ignored, y = IdentityN(inputs[0], inputs[1])
|
// ignored, y = IdentityN(inputs[0], inputs[1])
|
||||||
// return grad(y, {inputs[0], inputs[1]})
|
// return grad(y, {inputs[0], inputs[1]})
|
||||||
@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
|
|||||||
result_tensor = nullptr;
|
result_tensor = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(CppGradients, TestSqrtGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  // Build the execution context selected by the test parameterization.
  AbstractContextPtr ctx;
  {
    AbstractContext* raw_ctx = nullptr;
    Status st =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &raw_ctx);
    ASSERT_EQ(errors::OK, st.code()) << st.error_message();
    ctx.reset(raw_ctx);
  }

  // Scalar input x = 1.0f.
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* raw_x = nullptr;
    Status st = TestScalarTensorHandle(ctx.get(), 1.0f, &raw_x);
    ASSERT_EQ(errors::OK, st.code()) << st.error_message();
    x.reset(raw_x);
  }

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Pseudo-code:
  //
  // tape.watch(x)
  // y = sqrt(x)
  // outputs = tape.gradient(y, x)
  std::vector<AbstractTensorHandle*> outputs(1);
  s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // d/dx sqrt(x) at x = 1 is 1 / (2 * sqrt(1)) = 0.5.
  TF_Tensor* result_tensor;
  s = getValue(outputs[0], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_NEAR(*result_value, 0.5, 0.001);
  outputs[0]->Unref();
  TF_DeleteTensor(result_tensor);
  result_tensor = nullptr;
}
|
||||||
|
|
||||||
TEST_P(CppGradients, TestIdentityNGrad) {
|
TEST_P(CppGradients, TestIdentityNGrad) {
|
||||||
// Pseudo-code:
|
// Pseudo-code:
|
||||||
//
|
//
|
||||||
|
@ -24,6 +24,7 @@ using std::vector;
|
|||||||
using tensorflow::ops::Conj;
|
using tensorflow::ops::Conj;
|
||||||
using tensorflow::ops::MatMul;
|
using tensorflow::ops::MatMul;
|
||||||
using tensorflow::ops::Mul;
|
using tensorflow::ops::Mul;
|
||||||
|
using tensorflow::ops::SqrtGrad;
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
|
|||||||
AbstractTensorHandlePtr exp_;
|
AbstractTensorHandlePtr exp_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SqrtGradientFunction : public GradientFunction {
|
||||||
|
public:
|
||||||
|
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
|
||||||
|
sqrt->Ref();
|
||||||
|
}
|
||||||
|
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||||
|
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||||
|
std::string name = "Sqrt_Grad";
|
||||||
|
grad_outputs->resize(1);
|
||||||
|
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
|
||||||
|
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
~SqrtGradientFunction() override {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
AbstractTensorHandlePtr sqrt_;
|
||||||
|
};
|
||||||
|
|
||||||
class MatMulGradientFunction : public GradientFunction {
|
class MatMulGradientFunction : public GradientFunction {
|
||||||
public:
|
public:
|
||||||
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||||
@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
|
|||||||
return new BackwardFunction(gradient_function, default_gradients);
|
return new BackwardFunction(gradient_function, default_gradients);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the backward function for Sqrt; caller takes ownership.
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
  // The sqrt gradient reuses the forward output y = sqrt(x).
  auto* gradient_function = new SqrtGradientFunction(op.outputs[0]);
  // A single-output op never has its gradient function invoked without an
  // incoming gradient, so no zero grads need to be materialized here.
  auto* default_gradients = new PassThroughDefaultGradients(op);
  return new BackwardFunction(gradient_function, default_gradients);
}
||||||
|
|
||||||
} // namespace gradients
|
} // namespace gradients
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,9 +19,12 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
|
|
||||||
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
||||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||||
|
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||||
|
|
||||||
} // namespace gradients
|
} // namespace gradients
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
|||||||
return exp_op->Execute(outputs, &num_retvals);
|
return exp_op->Execute(outputs, &num_retvals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes outputs[0] = sqrt(inputs[0]) element-wise via the "Sqrt" op.
Status Sqrt(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sqrt_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
  TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));

  int num_retvals = 1;
  return sqrt_op->Execute(outputs, &num_retvals);
}
||||||
|
|
||||||
|
// Computes the sqrt gradient: inputs are {y, dy} where y = sqrt(x); the
// "SqrtGrad" op returns dy * 0.5 / y in outputs[0].
Status SqrtGrad(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
  TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));

  int num_retvals = 1;
  return sqrt_grad_op->Execute(outputs, &num_retvals);
}
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
|
|||||||
|
|
||||||
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
|
Status Sqrt(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
|
Status SqrtGrad(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user