add BuildGradModel

Replace the hand-written per-op *GradModel helpers in the gradient tests with
a single BuildGradModel helper that registers the op's gradient function,
watches the inputs on a Tape, runs the forward model, and differentiates its
output with respect to the inputs.

Võ Văn Nghĩa 2021-01-14 00:56:12 +07:00
parent 1c1a91e70f
commit 6a79c9186a
4 changed files with 51 additions and 164 deletions

View File

@@ -166,7 +166,9 @@ cc_library(
],
deps = [
"//tensorflow/c/eager:gradient_checker",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/eager:unified_api_testutil",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],

View File

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/grad_test_helper.h"
#include "tensorflow/c/eager/gradient_checker.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -105,6 +106,34 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
delete[] danalytical;
}
Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
                     GradientFunctionFactory gradient_function_factory) {
  // Capture by value so the returned Model remains valid after
  // BuildGradModel's parameters go out of scope.
  return [forward, num_inputs, op, gradient_function_factory](
             AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs) -> Status {
    GradientRegistry registry;
    TF_RETURN_IF_ERROR(registry.Register(op, gradient_function_factory));

    // Watch the first `num_inputs` inputs and run the forward model under
    // the tape. A single forward output is assumed.
    Tape tape(/*persistent=*/false);
    for (size_t i{}; i < num_inputs; ++i) {
      tape.Watch(inputs[i]);
    }
    std::vector<AbstractTensorHandle*> temp_outputs(1);
    AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
    TF_RETURN_IF_ERROR(
        forward(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));

    TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                            /*sources=*/inputs,
                                            /*output_gradients=*/{}, outputs));
    for (auto temp_output : temp_outputs) {
      temp_output->Unref();
    }
    return Status::OK();
  };
}
} // namespace internal
} // namespace gradients
} // namespace tensorflow
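
With BuildGradModel in place, each test builds the backward model inline rather than hand-writing a per-op *GradModel. The resulting call pattern, taken from the test changes below (ExpModel, ExpRegisterer, ctx_, x, and UseFunction are all defined in the test file), is:

ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
    ExpModel, BuildGradModel(ExpModel, /*num_inputs=*/1, "Exp", ExpRegisterer),
    ctx_.get(), {x.get()}, UseFunction()));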

View File

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
namespace tensorflow {
@@ -29,6 +30,9 @@ void CompareNumericalAndAutodiffGradients(
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
absl::Span<const int64_t> dims, double abs_error = 1e-2);
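// Returns a Model that computes the gradient of `forward` with respect to its
// first `num_inputs` inputs. The gradient function produced by
// `gradient_function_factory` is registered for `op`, the inputs are watched
// on a Tape, and the (single) forward output is differentiated with respect
// to the watched inputs.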
Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
GradientFunctionFactory gradient_function_factory);
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@@ -36,199 +36,42 @@ Status AddModel(AbstractContext* ctx,
return ops::Add(ctx, inputs, outputs, "Add");
}
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("AddV2", AddRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
tape.Watch(inputs[1]);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "AddGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status ExpModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::Exp(ctx, inputs, outputs, "Exp");
}
Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("Exp", ExpRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::Exp(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "ExpGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status SqrtModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::Sqrt(ctx, inputs, outputs, "Sqrt");
}
Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("Sqrt", SqrtRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::Sqrt(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "SqrtGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status NegModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::Neg(ctx, inputs, outputs, "Neg");
}
Status NegGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("Neg", NegRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::Neg(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "NegGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status SubModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::Sub(ctx, inputs, outputs, "Sub");
}
Status SubGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("Sub", SubRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
tape.Watch(inputs[1]);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "SubGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::Mul(ctx, inputs, outputs, "Mul");
}
Status MulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("Mul", MulRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
tape.Watch(inputs[1]);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "MulGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status Log1pModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::Log1p(ctx, inputs, outputs, "Log1p");
}
Status Log1pGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("Log1p", Log1pRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "Log1pGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status DivNoNanModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
@@ -301,7 +144,8 @@ TEST_P(CppGradients, TestAddGrad) {
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
AddModel, AddGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
AddModel, BuildGradModel(AddModel, 2, "AddV2", AddRegisterer), ctx_.get(),
{x.get(), y.get()}, UseFunction()));
}
TEST_P(CppGradients, TestExpGrad) {
@@ -314,7 +158,8 @@ TEST_P(CppGradients, TestExpGrad) {
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
ExpModel, ExpGradModel, ctx_.get(), {x.get()}, UseFunction()));
ExpModel, BuildGradModel(ExpModel, 1, "Exp", ExpRegisterer), ctx_.get(),
{x.get()}, UseFunction()));
}
TEST_P(CppGradients, TestSqrtGrad) {
@@ -327,7 +172,8 @@ TEST_P(CppGradients, TestSqrtGrad) {
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
SqrtModel, SqrtGradModel, ctx_.get(), {x.get()}, UseFunction()));
SqrtModel, BuildGradModel(SqrtModel, 1, "Sqrt", SqrtRegisterer),
ctx_.get(), {x.get()}, UseFunction()));
}
TEST_P(CppGradients, TestNegGrad) {
@@ -340,7 +186,8 @@ TEST_P(CppGradients, TestNegGrad) {
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
NegModel, NegGradModel, ctx_.get(), {x.get()}, UseFunction()));
NegModel, BuildGradModel(NegModel, 1, "Neg", NegRegisterer), ctx_.get(),
{x.get()}, UseFunction()));
}
TEST_P(CppGradients, TestSubGrad) {
@@ -361,7 +208,8 @@ TEST_P(CppGradients, TestSubGrad) {
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
SubModel, SubGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
SubModel, BuildGradModel(SubModel, 2, "Sub", SubRegisterer), ctx_.get(),
{x.get(), y.get()}, UseFunction()));
}
TEST_P(CppGradients, TestMulGrad) {
@@ -382,7 +230,8 @@ TEST_P(CppGradients, TestMulGrad) {
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
MulModel, MulGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
MulModel, BuildGradModel(MulModel, 2, "Mul", MulRegisterer), ctx_.get(),
{x.get(), y.get()}, UseFunction()));
}
TEST_P(CppGradients, TestLog1pGrad) {
@@ -395,10 +244,13 @@ TEST_P(CppGradients, TestLog1pGrad) {
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
Log1pModel, Log1pGradModel, ctx_.get(), {x.get()}, UseFunction()));
Log1pModel, BuildGradModel(Log1pModel, 1, "Log1p", Log1pRegisterer),
ctx_.get(), {x.get()}, UseFunction()));
}
TEST_P(CppGradients, TestDivNoNanGrad) {
// TODO(vnvo2409): Figure out why `BuildGradModel` does not work with
// `DivNoNan`.
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;