Adding a gradient for Sqrt
PiperOrigin-RevId: 336161771 Change-Id: I7e66b81772132d4e45b951c975df503ac5469075
This commit is contained in:
parent
c4f499b603
commit
4d04d4cf32
@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = sqrt(inputs[0])
|
||||
// return grad(y, {inputs[0]})
|
||||
Status SqrtGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto sqrt_output : sqrt_outputs) {
|
||||
sqrt_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// ignored, y = IdentityN(inputs[0], inputs[1])
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSqrtGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = sqrt(x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestIdentityNGrad) {
|
||||
// Pseudo-code:
|
||||
//
|
||||
|
@ -24,6 +24,7 @@ using std::vector;
|
||||
using tensorflow::ops::Conj;
|
||||
using tensorflow::ops::MatMul;
|
||||
using tensorflow::ops::Mul;
|
||||
using tensorflow::ops::SqrtGrad;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
|
||||
AbstractTensorHandlePtr exp_;
|
||||
};
|
||||
|
||||
class SqrtGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
|
||||
sqrt->Ref();
|
||||
}
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
std::string name = "Sqrt_Grad";
|
||||
grad_outputs->resize(1);
|
||||
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~SqrtGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractTensorHandlePtr sqrt_;
|
||||
};
|
||||
|
||||
class MatMulGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||
@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -19,10 +19,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||
|
@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
return exp_op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
Status Sqrt(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
Status s = sqrt_op->Execute(outputs, &num_retvals);
|
||||
return s;
|
||||
}
|
||||
|
||||
Status SqrtGrad(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
|
||||
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
|
||||
|
||||
int num_retvals = 1;
|
||||
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
|
||||
|
||||
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Sqrt(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status SqrtGrad(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user