add BuildGradModel

commit 6a79c9186a
parent 1c1a91e70f

tensorflow/c/experimental/gradients
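
Factor the per-op *GradModel helper functions in the gradient tests into a
single BuildGradModel factory in grad_test_helper. The factory registers the
op's gradient function, runs the forward model under a Tape that watches each
input, and computes the gradients, so each test can build its gradient model
in one expression.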
@@ -166,7 +166,9 @@ cc_library(
    ],
    deps = [
        "//tensorflow/c/eager:gradient_checker",
        "//tensorflow/c/eager:gradients_internal",
        "//tensorflow/c/eager:unified_api_testutil",
        "//tensorflow/c/experimental/gradients/tape:tape_context",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/grad_test_helper.h"

#include "tensorflow/c/eager/gradient_checker.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
@@ -105,6 +106,34 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
  delete[] danalytical;
}

Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
                     GradientFunctionFactory gradient_function_factory) {
  // Capture by value: the returned lambda outlives this function's parameters.
  return [forward, num_inputs, op, gradient_function_factory](
             AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs) -> Status {
    GradientRegistry registry;
    TF_RETURN_IF_ERROR(registry.Register(op, gradient_function_factory));

    Tape tape(/*persistent=*/false);
    for (size_t i{}; i < num_inputs; ++i) {
      tape.Watch(inputs[i]);
    }
    std::vector<AbstractTensorHandle*> temp_outputs(1);
    AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
    TF_RETURN_IF_ERROR(
        forward(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));

    TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                            /*sources=*/inputs,
                                            /*output_gradients=*/{}, outputs));
    for (auto temp_output : temp_outputs) {
      temp_output->Unref();
    }
    return Status::OK();
  };
}

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
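
With the helper above, a test obtains a gradient model in one expression
instead of a dedicated *GradModel function. A minimal sketch of a call site,
using names from the test changes below (ctx_, UseFunction(), x, and y come
from the CppGradients test fixture; AddModel and AddRegisterer are defined in
the test file):

    // Build a gradient model for AddV2; both inputs are watched by the tape.
    Model add_grad_model =
        BuildGradModel(AddModel, /*num_inputs=*/2, "AddV2", AddRegisterer);
    // Compare its autodiff gradients against numerical gradients.
    ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
        AddModel, add_grad_model, ctx_.get(), {x.get(), y.get()},
        UseFunction()));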

@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_

#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/unified_api_testutil.h"

namespace tensorflow {
@@ -29,6 +30,9 @@ void CompareNumericalAndAutodiffGradients(
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
                      absl::Span<const int64_t> dims, double abs_error = 1e-2);

Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
                     GradientFunctionFactory gradient_function_factory);

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
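
For context, Model comes from unified_api_testutil.h (hence the new include
and BUILD dependencies); at the time of this commit it is roughly an alias of
this shape (paraphrased; exact parameter names may differ):

    using Model = std::function<Status(AbstractContext*,
                                       absl::Span<AbstractTensorHandle* const>,
                                       absl::Span<AbstractTensorHandle*>)>;

It matches the signature of the hand-written *GradModel functions removed
below, which is why the lambda returned by BuildGradModel can replace them.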

@@ -36,199 +36,42 @@ Status AddModel(AbstractContext* ctx,
  return ops::Add(ctx, inputs, outputs, "Add");
}

Status AddGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("AddV2", AddRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  tape.Watch(inputs[1]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
                              absl::MakeSpan(temp_outputs), "AddGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status ExpModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs) {
  return ops::Exp(ctx, inputs, outputs, "Exp");
}

Status ExpGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Exp", ExpRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Exp(tape_ctx.get(), inputs,
                              absl::MakeSpan(temp_outputs), "ExpGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status SqrtModel(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs) {
  return ops::Sqrt(ctx, inputs, outputs, "Sqrt");
}

Status SqrtGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Sqrt", SqrtRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Sqrt(tape_ctx.get(), inputs,
                               absl::MakeSpan(temp_outputs), "SqrtGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status NegModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs) {
  return ops::Neg(ctx, inputs, outputs, "Neg");
}

Status NegGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Neg", NegRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Neg(tape_ctx.get(), inputs,
                              absl::MakeSpan(temp_outputs), "NegGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status SubModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs) {
  return ops::Sub(ctx, inputs, outputs, "Sub");
}

Status SubGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Sub", SubRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  tape.Watch(inputs[1]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
                              absl::MakeSpan(temp_outputs), "SubGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status MulModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs) {
  return ops::Mul(ctx, inputs, outputs, "Mul");
}

Status MulGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Mul", MulRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  tape.Watch(inputs[1]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
                              absl::MakeSpan(temp_outputs), "MulGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status Log1pModel(AbstractContext* ctx,
                  absl::Span<AbstractTensorHandle* const> inputs,
                  absl::Span<AbstractTensorHandle*> outputs) {
  return ops::Log1p(ctx, inputs, outputs, "Log1p");
}

Status Log1pGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Log1p", Log1pRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
                                absl::MakeSpan(temp_outputs), "Log1pGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status DivNoNanModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs) {
@@ -301,7 +144,8 @@ TEST_P(CppGradients, TestAddGrad) {
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      AddModel, AddGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
      AddModel, BuildGradModel(AddModel, 2, "AddV2", AddRegisterer), ctx_.get(),
      {x.get(), y.get()}, UseFunction()));
}

TEST_P(CppGradients, TestExpGrad) {
@@ -314,7 +158,8 @@ TEST_P(CppGradients, TestExpGrad) {
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      ExpModel, ExpGradModel, ctx_.get(), {x.get()}, UseFunction()));
      ExpModel, BuildGradModel(ExpModel, 1, "Exp", ExpRegisterer), ctx_.get(),
      {x.get()}, UseFunction()));
}

TEST_P(CppGradients, TestSqrtGrad) {
@@ -327,7 +172,8 @@ TEST_P(CppGradients, TestSqrtGrad) {
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      SqrtModel, SqrtGradModel, ctx_.get(), {x.get()}, UseFunction()));
      SqrtModel, BuildGradModel(SqrtModel, 1, "Sqrt", SqrtRegisterer),
      ctx_.get(), {x.get()}, UseFunction()));
}

TEST_P(CppGradients, TestNegGrad) {
@@ -340,7 +186,8 @@ TEST_P(CppGradients, TestNegGrad) {
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      NegModel, NegGradModel, ctx_.get(), {x.get()}, UseFunction()));
      NegModel, BuildGradModel(NegModel, 1, "Neg", NegRegisterer), ctx_.get(),
      {x.get()}, UseFunction()));
}

TEST_P(CppGradients, TestSubGrad) {
@@ -361,7 +208,8 @@ TEST_P(CppGradients, TestSubGrad) {
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      SubModel, SubGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
      SubModel, BuildGradModel(SubModel, 2, "Sub", SubRegisterer), ctx_.get(),
      {x.get(), y.get()}, UseFunction()));
}

TEST_P(CppGradients, TestMulGrad) {
@@ -382,7 +230,8 @@ TEST_P(CppGradients, TestMulGrad) {
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      MulModel, MulGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
      MulModel, BuildGradModel(MulModel, 2, "Mul", MulRegisterer), ctx_.get(),
      {x.get(), y.get()}, UseFunction()));
}

TEST_P(CppGradients, TestLog1pGrad) {
@@ -395,10 +244,13 @@ TEST_P(CppGradients, TestLog1pGrad) {
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      Log1pModel, Log1pGradModel, ctx_.get(), {x.get()}, UseFunction()));
      Log1pModel, BuildGradModel(Log1pModel, 1, "Log1p", Log1pRegisterer),
      ctx_.get(), {x.get()}, UseFunction()));
}

TEST_P(CppGradients, TestDivNoNanGrad) {
  // TODO(vnvo2409): Figure out why `BuildGradModel` does not work with
  // `DivNoNan`.
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;