generalize BuildGradModel
commit 1697abe35d
parent 1b03b99346
@@ -106,25 +106,20 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
   delete[] danalytical;
 }
 
-Model BuildGradModel(Model ops, size_t num_inputs, string ops_name,
-                     GradientFunctionFactory gradient_function_factory) {
-  return [num_inputs, forward_ops = std::move(ops),
-          forward_name = std::move(ops_name),
-          gradient_factory = std::move(gradient_function_factory)](
+Model BuildGradModel(Model forward, GradientRegistry registry) {
+  return [forward_model = std::move(forward),
+          grad_registry = std::move(registry)](
              AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs) -> Status {
-    GradientRegistry registry;
-    TF_RETURN_IF_ERROR(registry.Register(forward_name, gradient_factory));
-
     Tape tape(/*persistent=*/false);
-    for (size_t i{}; i < num_inputs; ++i) {
+    for (size_t i{}; i < inputs.size(); ++i) {
       tape.Watch(inputs[i]);
     }
     std::vector<AbstractTensorHandle*> temp_outputs(1);
-    AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
+    AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, grad_registry));
     TF_RETURN_IF_ERROR(
-        forward_ops(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
+        forward_model(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
 
     TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                             /*sources=*/inputs,
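With this change the returned closure no longer constructs and populates a registry on every call; the caller registers gradients once and passes the populated registry in, and the closure captures it by value. A minimal usage sketch (op name and registerer taken from the tests updated below, not a verbatim excerpt from this commit):

  GradientRegistry registry;
  Status s = registry.Register("Exp", ExpRegisterer);  // populate once up front
  // `registry` is moved into the closure, so the grad model is self-contained.
  Model exp_grad_model = BuildGradModel(ExpModel, registry);
  // exp_grad_model(ctx, {x}, outputs) now runs ExpModel under a tape and
  // emits gradients w.r.t. every input; no op name or factory is needed.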
@@ -30,8 +30,7 @@ void CompareNumericalAndAutodiffGradients(
 void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
                       absl::Span<const int64_t> dims, double abs_error = 1e-2);
 
-Model BuildGradModel(Model ops, size_t num_inputs, string ops_name,
-                     GradientFunctionFactory gradient_function_factory);
+Model BuildGradModel(Model forward, GradientRegistry registry);
 
 }  // namespace internal
 }  // namespace gradients
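Because the registry is taken by value, one populated registry can back any number of grad models, and each model keeps its own snapshot of the registrations it was built with. A sketch of that pattern (names from the tests in this commit; assumes GradientRegistry is copyable, as the by-value signature suggests):

  GradientRegistry registry;
  Status s1 = registry.Register("AddV2", AddRegisterer);
  Status s2 = registry.Register("Sub", SubRegisterer);
  // Both grad models are built from the same set of registrations.
  Model add_grad = BuildGradModel(AddModel, registry);
  Model sub_grad = BuildGradModel(SubModel, registry);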
@@ -85,8 +85,8 @@ class CppGradients
   void SetUp() override {
     TF_StatusPtr status(TF_NewStatus());
     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
-    Status s = StatusFromTF_Status(status.get());
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = StatusFromTF_Status(status.get());
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
 
     {
       AbstractContext* ctx_raw = nullptr;
@@ -103,6 +103,8 @@ class CppGradients
   }
 
   AbstractContextPtr ctx_;
+  GradientRegistry registry_;
+  Status status_;
 
  public:
   bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
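The test hunks that follow all apply the same mechanical update, shown here once so the per-op diffs are easier to scan: error checks move from a local Status s to the new fixture-wide status_, the op's gradient is registered on the shared registry_, and that registry replaces the (num_inputs, op name, registerer) arguments to BuildGradModel. Using the AddV2 case from the next hunk as the template:

  status_ = registry_.Register("AddV2", AddRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      AddModel, BuildGradModel(AddModel, registry_), ctx_.get(),
      {x.get(), y.get()}, UseFunction()));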
@@ -113,21 +115,24 @@ TEST_P(CppGradients, TestAddGrad) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     x.reset(x_raw);
   }
 
   AbstractTensorHandlePtr y;
   {
     AbstractTensorHandle* y_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     y.reset(y_raw);
   }
 
+  status_ = registry_.Register("AddV2", AddRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      AddModel, BuildGradModel(AddModel, 2, "AddV2", AddRegisterer), ctx_.get(),
+      AddModel, BuildGradModel(AddModel, registry_), ctx_.get(),
       {x.get(), y.get()}, UseFunction()));
 }
 
@@ -135,14 +140,17 @@ TEST_P(CppGradients, TestExpGrad) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     x.reset(x_raw);
   }
 
+  status_ = registry_.Register("Exp", ExpRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      ExpModel, BuildGradModel(ExpModel, 1, "Exp", ExpRegisterer), ctx_.get(),
-      {x.get()}, UseFunction()));
+      ExpModel, BuildGradModel(ExpModel, registry_), ctx_.get(), {x.get()},
+      UseFunction()));
 }
 
 TEST_P(CppGradients, TestMatMulGrad) {
@@ -151,9 +159,9 @@ TEST_P(CppGradients, TestMatMulGrad) {
   AbstractTensorHandlePtr A;
   {
     AbstractTensorHandle* A_raw;
-    Status s =
+    status_ =
         TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     A.reset(A_raw);
   }
 
@@ -162,12 +170,15 @@ TEST_P(CppGradients, TestMatMulGrad) {
   AbstractTensorHandlePtr B;
   {
     AbstractTensorHandle* B_raw;
-    Status s =
+    status_ =
         TestTensorHandleWithDimsFloat(ctx_.get(), B_vals, B_dims, 2, &B_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     B.reset(B_raw);
   }
 
+  status_ = registry_.Register("MatMul", MatMulRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
   for (bool transpose_a : {false, true}) {
     for (bool transpose_b : {false, true}) {
       Model MatMulModel =
@@ -182,9 +193,8 @@ TEST_P(CppGradients, TestMatMulGrad) {
       // well with `MatMul` and remove `TestMatMul*` in
       // `mnist_gradients_test` when done.
       ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-          MatMulModel,
-          BuildGradModel(MatMulModel, 2, "MatMul", MatMulRegisterer),
-          ctx_.get(), {A.get(), B.get()}, UseFunction(), /*abs_error*/ 0.4));
+          MatMulModel, BuildGradModel(MatMulModel, registry_), ctx_.get(),
+          {A.get(), B.get()}, UseFunction(), /*abs_error*/ 0.4));
     }
   }
 }
@@ -193,49 +203,58 @@ TEST_P(CppGradients, TestSqrtGrad) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     x.reset(x_raw);
   }
 
+  status_ = registry_.Register("Sqrt", SqrtRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      SqrtModel, BuildGradModel(SqrtModel, 1, "Sqrt", SqrtRegisterer),
-      ctx_.get(), {x.get()}, UseFunction()));
+      SqrtModel, BuildGradModel(SqrtModel, registry_), ctx_.get(), {x.get()},
+      UseFunction()));
 }
 
 TEST_P(CppGradients, TestNegGrad) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     x.reset(x_raw);
   }
 
+  status_ = registry_.Register("Neg", NegRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      NegModel, BuildGradModel(NegModel, 1, "Neg", NegRegisterer), ctx_.get(),
-      {x.get()}, UseFunction()));
+      NegModel, BuildGradModel(NegModel, registry_), ctx_.get(), {x.get()},
+      UseFunction()));
 }
 
 TEST_P(CppGradients, TestSubGrad) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     x.reset(x_raw);
   }
 
   AbstractTensorHandlePtr y;
   {
     AbstractTensorHandle* y_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     y.reset(y_raw);
   }
 
+  status_ = registry_.Register("Sub", SubRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      SubModel, BuildGradModel(SubModel, 2, "Sub", SubRegisterer), ctx_.get(),
+      SubModel, BuildGradModel(SubModel, registry_), ctx_.get(),
       {x.get(), y.get()}, UseFunction()));
 }
 
@@ -243,21 +262,24 @@ TEST_P(CppGradients, TestMulGrad) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     x.reset(x_raw);
   }
 
   AbstractTensorHandlePtr y;
   {
     AbstractTensorHandle* y_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     y.reset(y_raw);
   }
 
+  status_ = registry_.Register("Mul", MulRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      MulModel, BuildGradModel(MulModel, 2, "Mul", MulRegisterer), ctx_.get(),
+      MulModel, BuildGradModel(MulModel, registry_), ctx_.get(),
       {x.get(), y.get()}, UseFunction()));
 }
 
@@ -265,33 +287,38 @@ TEST_P(CppGradients, TestLog1pGrad) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     x.reset(x_raw);
   }
 
+  status_ = registry_.Register("Log1p", Log1pRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      Log1pModel, BuildGradModel(Log1pModel, 1, "Log1p", Log1pRegisterer),
-      ctx_.get(), {x.get()}, UseFunction()));
+      Log1pModel, BuildGradModel(Log1pModel, registry_), ctx_.get(), {x.get()},
+      UseFunction()));
 }
 
 TEST_P(CppGradients, TestDivNoNanGrad) {
-  auto DivNoNanGradModel =
-      BuildGradModel(DivNoNanModel, 2, "DivNoNan", DivNoNanRegisterer);
+  status_ = registry_.Register("DivNoNan", DivNoNanRegisterer);
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
+
+  auto DivNoNanGradModel = BuildGradModel(DivNoNanModel, registry_);
 
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     x.reset(x_raw);
   }
 
   AbstractTensorHandlePtr y;
   {
     AbstractTensorHandle* y_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     y.reset(y_raw);
   }
 
@@ -303,14 +330,14 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
   AbstractTensorHandlePtr z;
   {
     AbstractTensorHandle* z_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 0.0f, &z_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    status_ = TestScalarTensorHandle(ctx_.get(), 0.0f, &z_raw);
+    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
     z.reset(z_raw);
   }
   std::vector<AbstractTensorHandle*> outputs(2);
-  auto s = RunModel(DivNoNanGradModel, ctx_.get(), {x.get(), z.get()},
-                    absl::MakeSpan(outputs), UseFunction());
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  status_ = RunModel(DivNoNanGradModel, ctx_.get(), {x.get(), z.get()},
+                     absl::MakeSpan(outputs), UseFunction());
+  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
   ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {},
                                            /*abs_error*/ 0));
   ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], {0.0f}, /*dims*/ {},
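For context on the asserted values: away from zero, d(x/z)/dx = 1/z and d(x/z)/dz = -x/z^2, but DivNoNan defines the result to be 0 when the divisor is 0. The test feeds z = 0, so both gradients are expected to be exactly 0, which is why these CheckTensorValue calls use abs_error 0.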