Adding a gradient for Sqrt

PiperOrigin-RevId: 336161771
Change-Id: I7e66b81772132d4e45b951c975df503ac5469075
Author: Rohan Jain (committed by TensorFlower Gardener)
Date: 2020-10-08 14:10:11 -07:00
Parent: c4f499b603
Commit: 4d04d4cf32
5 changed files with 146 additions and 1 deletion

tensorflow/c/eager/gradients_test.cc

@@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
return Status::OK();
}
@@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = sqrt(inputs[0])
// return grad(y, {inputs[0]})
Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto sqrt_output : sqrt_outputs) {
sqrt_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
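For context, the quantity this model computes has a closed form. With y = sqrt(x), standard calculus gives

    y = \sqrt{x}, \qquad \frac{dy}{dx} = \frac{1}{2\sqrt{x}} = \frac{1}{2y}

Writing the derivative in terms of the forward output y rather than the input x is what allows the backward pass to reuse sqrt_outputs[0] instead of recomputing the square root.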
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
@@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSqrtGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = sqrt(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
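The expected value in EXPECT_NEAR is the identity above evaluated at the test input x = 1:

    \left.\frac{dy}{dx}\right|_{x=1} = \frac{1}{2\sqrt{1}} = 0.5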
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//

tensorflow/c/experimental/gradients/math_grad.cc

@@ -24,6 +24,7 @@ using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::SqrtGrad;
namespace tensorflow {
namespace gradients {
@@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
AbstractTensorHandlePtr exp_;
};
class SqrtGradientFunction : public GradientFunction {
public:
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
sqrt->Ref();
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
std::string name = "Sqrt_Grad";
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~SqrtGradientFunction() override {}
private:
AbstractTensorHandlePtr sqrt_;
};
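The class holds on to the forward output y = sqrt(x) via sqrt_, taking a Ref in the constructor so the handle outlives the forward pass. Compute then feeds {y, dz} to the SqrtGrad op, where dz = grad_inputs[0] is the incoming gradient. Assuming SqrtGrad has the usual TensorFlow kernel semantics (the kernel itself is not shown in this diff), it evaluates

    \frac{\partial L}{\partial x} = dz \cdot \frac{1}{2y}

which matches the closed form above while avoiding a second square-root computation.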
class MatMulGradientFunction : public GradientFunction {
public:
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
@@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient, so we do not need to create zero-valued gradients
// here.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

tensorflow/c/experimental/gradients/math_grad.h

@@ -19,10 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

tensorflow/c/experimental/ops/math_ops.cc

@@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
return exp_op->Execute(outputs, &num_retvals);
}
Status Sqrt(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
int num_retvals = 1;
Status s = sqrt_op->Execute(outputs, &num_retvals);
return s;
}
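A minimal calling sketch for the new wrapper, assuming a valid AbstractContext* ctx and an input handle x (names here are illustrative; the pattern mirrors SqrtGradModel above):

    // Sketch only: run the forward Sqrt op through the wrapper.
    std::vector<AbstractTensorHandle*> sqrt_outputs(1);
    TF_RETURN_IF_ERROR(
        ops::Sqrt(ctx, {x}, absl::MakeSpan(sqrt_outputs), "Sqrt"));
    // sqrt_outputs[0] now holds sqrt(x); callers Unref it when done.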
Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
int num_retvals = 1;
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
return s;
}
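SqrtGrad takes two inputs, the forward output y and the incoming gradient dz, in that order (matching the call in SqrtGradientFunction::Compute). A corresponding sketch, with y and dz assumed to be valid handles:

    // Sketch only: y = forward Sqrt output, dz = upstream gradient.
    std::vector<AbstractTensorHandle*> grad_outputs(1);
    TF_RETURN_IF_ERROR(
        ops::SqrtGrad(ctx, {y, dz}, absl::MakeSpan(grad_outputs), "SqrtGrad"));
    // Expected result: grad_outputs[0] = dz * 0.5 / y.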
} // namespace ops
} // namespace tensorflow

tensorflow/c/experimental/ops/math_ops.h

@@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sqrt(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow