fix lambda capture error in BuildGradModel
commit eb0092fc21 (parent 6a79c9186a)
@@ -214,6 +214,7 @@ tf_cuda_cc_test(
         "//tensorflow/c/eager:c_api_test_util",
         "//tensorflow/c/experimental/gradients/tape:tape_context",
         "//tensorflow/c/experimental/ops:math_ops",
+        "//tensorflow/core/platform:tensor_float_32_utils",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ] + if_libtpu(
@@ -106,14 +106,16 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
   delete[] danalytical;
 }
 
-Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
+Model BuildGradModel(Model ops, size_t num_inputs, string ops_name,
                      GradientFunctionFactory gradient_function_factory) {
-  return [&, forward, gradient_function_factory](
+  return [num_inputs, forward_ops = std::move(ops),
+          forward_name = std::move(ops_name),
+          gradient_factory = std::move(gradient_function_factory)](
             AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs) -> Status {
     GradientRegistry registry;
-    TF_RETURN_IF_ERROR(registry.Register(op, gradient_function_factory));
+    TF_RETURN_IF_ERROR(registry.Register(forward_name, gradient_factory));
 
     Tape tape(/*persistent=*/false);
     for (size_t i{}; i < num_inputs; ++i) {
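Note: the capture error fixed above is that the old lambda's default `&` capture left `num_inputs` and `op` captured by reference. Because the lambda is returned from BuildGradModel and invoked after the function has exited, both references dangled. The new version captures everything by value, using init-captures to move the string and the factory into the closure. A minimal sketch of the same bug and fix outside TensorFlow (all names here are illustrative, not from the commit):

#include <functional>
#include <iostream>
#include <string>
#include <utility>

// BUG (mirrors the old code): `[&]` makes the returned closure hold a
// reference to the parameter `op`, which dies when the function returns.
std::function<std::string()> MakeBroken(const std::string& op) {
  return [&]() { return "grad of " + op; };
}

// FIX (mirrors the new code): take the string by value and move it into an
// init-capture, so the closure owns its own copy.
std::function<std::string()> MakeFixed(std::string op) {
  return [op = std::move(op)]() { return "grad of " + op; };
}

int main() {
  auto broken = MakeBroken("DivNoNan");
  (void)broken;  // calling broken() here would be undefined behavior
  std::cout << MakeFixed("DivNoNan")() << "\n";  // safe: "grad of DivNoNan"
}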
@@ -122,7 +124,7 @@ Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
     std::vector<AbstractTensorHandle*> temp_outputs(1);
     AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
     TF_RETURN_IF_ERROR(
-        forward(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
+        forward_ops(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
 
     TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                             /*sources=*/inputs,
@@ -30,7 +30,7 @@ void CompareNumericalAndAutodiffGradients(
 void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
                       absl::Span<const int64_t> dims, double abs_error = 1e-2);
 
-Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
+Model BuildGradModel(Model ops, size_t num_inputs, string ops_name,
                      GradientFunctionFactory gradient_function_factory);
 
 }  // namespace internal
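The header-side change from `const string& op` to `string ops_name` matches the definition: the parameter is now a by-value "sink", so the implementation can move it into the closure. Callers passing a temporary pay one move; callers passing an lvalue pay exactly one copy. A hedged illustration of the idiom (names hypothetical, not from the commit):

#include <string>
#include <utility>

struct GradModelSpec {
  std::string op_name;  // the object owns its copy of the name
};

// By-value sink parameter: the body can unconditionally std::move from it.
GradModelSpec MakeSpec(std::string op_name) {
  return GradModelSpec{std::move(op_name)};
}

int main() {
  std::string n = "DivNoNan";
  GradModelSpec a = MakeSpec(n);             // one copy; n still valid
  GradModelSpec b = MakeSpec(std::move(n));  // no copy; n is moved-from
  GradModelSpec c = MakeSpec("Log1p");       // temporary moved in
  (void)a; (void)b; (void)c;
}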
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/c/experimental/gradients/tape/tape_context.h"
 #include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -78,29 +79,6 @@ Status DivNoNanModel(AbstractContext* ctx,
   return ops::DivNoNan(ctx, inputs, outputs, "DivNoNan");
 }
 
-Status DivNoNanGradModel(AbstractContext* ctx,
-                         absl::Span<AbstractTensorHandle* const> inputs,
-                         absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(registry.Register("DivNoNan", DivNoNanRegisterer));
-
-  Tape tape(/*persistent=*/false);
-  tape.Watch(inputs[0]);
-  tape.Watch(inputs[1]);
-  std::vector<AbstractTensorHandle*> temp_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
-  TF_RETURN_IF_ERROR(ops::DivNoNan(
-      tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs), "DivNoNanGrad"));
-
-  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
-                                          /*sources=*/inputs,
-                                          /*output_gradients=*/{}, outputs));
-  for (auto temp_output : temp_outputs) {
-    temp_output->Unref();
-  }
-  return Status::OK();
-}
-
 class CppGradients
     : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
  protected:
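Note: the hand-rolled DivNoNanGradModel deleted above duplicated exactly the registry/tape/ComputeGradient sequence that the generic BuildGradModel helper performs. It appears to have existed only because the capture bug made the helper unusable for this op (see the TODO removed in the last hunk); with the fix, the test switches to BuildGradModel directly.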
@@ -117,6 +95,11 @@ class CppGradients
       ASSERT_EQ(errors::OK, s.code()) << s.error_message();
       ctx_.reset(ctx_raw);
     }
+
+    // Computing numerical gradients with TensorFloat-32 is numerically
+    // unstable. Some forward pass tests also fail with TensorFloat-32 due to
+    // low tolerances.
+    enable_tensor_float_32_execution(false);
   }
 
   AbstractContextPtr ctx_;
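For context: TensorFloat-32 is the reduced-precision mode that Ampere-class GPUs can apply to float32 matrix math, keeping roughly 10 mantissa bits. Numerical gradient checking computes (f(x+eps) - f(x-eps)) / (2*eps) with a small eps, so mantissa truncation in the function values is amplified by the division. A standalone sketch of the effect (plain C++, not TensorFlow; the truncation helper is a rough stand-in for TF32 rounding):

#include <cmath>
#include <cstdio>

// Keep roughly `bits` mantissa bits, as a stand-in for TF32's 10-bit mantissa.
double truncate_mantissa(double x, int bits) {
  int exp;
  double m = std::frexp(x, &exp);        // x == m * 2^exp, m in [0.5, 1)
  double scale = std::ldexp(1.0, bits);  // 2^bits
  return std::ldexp(std::nearbyint(m * scale) / scale, exp);
}

int main() {
  auto f = [](double x) { return x * x; };  // exact derivative at 1.5 is 3
  double eps = 1e-3;
  double full = (f(1.5 + eps) - f(1.5 - eps)) / (2 * eps);
  double coarse = (truncate_mantissa(f(1.5 + eps), 10) -
                   truncate_mantissa(f(1.5 - eps), 10)) / (2 * eps);
  std::printf("full precision : %f\n", full);    // 3.000000
  std::printf("10-bit mantissa: %f\n", coarse);  // ~3.906250, wildly off
}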
@@ -249,8 +232,9 @@ TEST_P(CppGradients, TestLog1pGrad) {
 }
 
 TEST_P(CppGradients, TestDivNoNanGrad) {
-  // TODO(vnvo2409): Figure out why `BuildGradModel` does not work with
-  // `DivNoNan`.
+  auto DivNoNanGradModel =
+      BuildGradModel(DivNoNanModel, 2, "DivNoNan", DivNoNanRegisterer);
+
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;