fix lambda capture error in BuildGradModel

Võ Văn Nghĩa 2021-01-14 14:53:46 +07:00
parent 6a79c9186a
commit eb0092fc21
4 changed files with 17 additions and 30 deletions
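The root cause: the lambda returned by BuildGradModel captured its parameters with [&], so the references to the op-name string and to num_inputs dangled as soon as BuildGradModel returned; the fix switches to by-value init-captures (std::move) so the closure owns everything it uses. Below is a minimal standalone sketch of the same pattern; the MakeModel* names are illustrative and not from the TensorFlow sources.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <utility>

using Model = std::function<std::string()>;

// Buggy: [&] captures the by-value parameter num_inputs and the string op by
// reference; both dangle once MakeModelBuggy returns, so invoking the
// returned lambda later is undefined behavior.
Model MakeModelBuggy(std::size_t num_inputs, const std::string& op) {
  return [&]() { return op + "/" + std::to_string(num_inputs); };
}

// Fixed: capture by value / move, mirroring the commit's
// forward_ops / forward_name / gradient_factory init-captures.
Model MakeModelFixed(std::size_t num_inputs, std::string op) {
  return [num_inputs, op = std::move(op)]() {
    return op + "/" + std::to_string(num_inputs);
  };
}

int main() {
  Model safe = MakeModelFixed(2, "DivNoNan");
  std::cout << safe() << "\n";  // prints "DivNoNan/2"; the closure owns its data
  return 0;
}

The move init-captures also avoid copying the ops and gradient_function_factory callables a second time, which is why the commit moves them into the closure rather than copying.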


@@ -214,6 +214,7 @@ tf_cuda_cc_test(
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
] + if_libtpu(


@@ -106,14 +106,16 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
delete[] danalytical;
}
- Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
+ Model BuildGradModel(Model ops, size_t num_inputs, string ops_name,
GradientFunctionFactory gradient_function_factory) {
- return [&, forward, gradient_function_factory](
+ return [num_inputs, forward_ops = std::move(ops),
+ forward_name = std::move(ops_name),
+ gradient_factory = std::move(gradient_function_factory)](
AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) -> Status {
GradientRegistry registry;
- TF_RETURN_IF_ERROR(registry.Register(op, gradient_function_factory));
+ TF_RETURN_IF_ERROR(registry.Register(forward_name, gradient_factory));
Tape tape(/*persistent=*/false);
for (size_t i{}; i < num_inputs; ++i) {
@@ -122,7 +124,7 @@ Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(
- forward(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
+ forward_ops(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,


@@ -30,7 +30,7 @@ void CompareNumericalAndAutodiffGradients(
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
absl::Span<const int64_t> dims, double abs_error = 1e-2);
- Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
+ Model BuildGradModel(Model ops, size_t num_inputs, string ops_name,
GradientFunctionFactory gradient_function_factory);
} // namespace internal


@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -78,29 +79,6 @@ Status DivNoNanModel(AbstractContext* ctx,
return ops::DivNoNan(ctx, inputs, outputs, "DivNoNan");
}
- Status DivNoNanGradModel(AbstractContext* ctx,
- absl::Span<AbstractTensorHandle* const> inputs,
- absl::Span<AbstractTensorHandle*> outputs) {
- GradientRegistry registry;
- TF_RETURN_IF_ERROR(registry.Register("DivNoNan", DivNoNanRegisterer));
- Tape tape(/*persistent=*/false);
- tape.Watch(inputs[0]);
- tape.Watch(inputs[1]);
- std::vector<AbstractTensorHandle*> temp_outputs(1);
- AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
- TF_RETURN_IF_ERROR(ops::DivNoNan(
- tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs), "DivNoNanGrad"));
- TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
- /*sources=*/inputs,
- /*output_gradients=*/{}, outputs));
- for (auto temp_output : temp_outputs) {
- temp_output->Unref();
- }
- return Status::OK();
- }
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
@@ -117,6 +95,11 @@ class CppGradients
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx_.reset(ctx_raw);
}
+ // Computing numerical gradients with TensorFloat-32 is numerically
+ // unstable. Some forward pass tests also fail with TensorFloat-32 due to
+ // low tolerances
+ enable_tensor_float_32_execution(false);
}
AbstractContextPtr ctx_;
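Why the new enable_tensor_float_32_execution(false) call in the fixture setup matters: TensorFloat-32 keeps only 10 explicit mantissa bits, so a finite-difference quotient built from values that differ in the third or fourth decimal place loses most of its significant digits, and the numerical-vs-autodiff comparison (the helper declarations above use a 1e-2 default abs_error) can fail. The sketch below is a rough, self-contained illustration of that precision loss; it simply truncates floats to a TF32-like mantissa and is not how the test computes gradients.

#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

// Crude stand-in for TensorFloat-32: keep only the top 10 of a float's 23
// fraction bits (real hardware rounds to nearest; this truncates, but the
// order of magnitude of the lost precision is the same).
float ToTf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;  // clear the low 13 fraction bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  const float x = 1.0f;
  const float eps = 1e-3f;
  // Central-difference estimate of d(x*x)/dx with the inputs squeezed through
  // TF32 precision, compared against the exact derivative 2*x.
  const float lhs = ToTf32(x + eps) * ToTf32(x + eps);
  const float rhs = ToTf32(x - eps) * ToTf32(x - eps);
  const float numerical = (lhs - rhs) / (2 * eps);
  std::cout << "numerical: " << numerical << "  exact: " << 2 * x
            << "  abs error: " << std::fabs(numerical - 2 * x) << "\n";
  // The absolute error comes out around 0.4 here, far above a 1e-2 tolerance,
  // which is why these tests disable TF32 before checking gradients.
  return 0;
}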
@@ -249,8 +232,9 @@ TEST_P(CppGradients, TestLog1pGrad) {
}
TEST_P(CppGradients, TestDivNoNanGrad) {
- // TODO(vnvo2409): Figure out why `BuildGradModel` does not work with
- // `DivNoNan`.
+ auto DivNoNanGradModel =
+ BuildGradModel(DivNoNanModel, 2, "DivNoNan", DivNoNanRegisterer);
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;