fix lambda capture error in BuildGradModel
This commit is contained in:
parent
6a79c9186a
commit
eb0092fc21
@ -214,6 +214,7 @@ tf_cuda_cc_test(
|
|||||||
"//tensorflow/c/eager:c_api_test_util",
|
"//tensorflow/c/eager:c_api_test_util",
|
||||||
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||||
"//tensorflow/c/experimental/ops:math_ops",
|
"//tensorflow/c/experimental/ops:math_ops",
|
||||||
|
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
] + if_libtpu(
|
] + if_libtpu(
|
||||||
|
@ -106,14 +106,16 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
|
|||||||
delete[] danalytical;
|
delete[] danalytical;
|
||||||
}
|
}
|
||||||
|
|
||||||
Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
|
Model BuildGradModel(Model ops, size_t num_inputs, string ops_name,
|
||||||
GradientFunctionFactory gradient_function_factory) {
|
GradientFunctionFactory gradient_function_factory) {
|
||||||
return [&, forward, gradient_function_factory](
|
return [num_inputs, forward_ops = std::move(ops),
|
||||||
|
forward_name = std::move(ops_name),
|
||||||
|
gradient_factory = std::move(gradient_function_factory)](
|
||||||
AbstractContext* ctx,
|
AbstractContext* ctx,
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
absl::Span<AbstractTensorHandle*> outputs) -> Status {
|
absl::Span<AbstractTensorHandle*> outputs) -> Status {
|
||||||
GradientRegistry registry;
|
GradientRegistry registry;
|
||||||
TF_RETURN_IF_ERROR(registry.Register(op, gradient_function_factory));
|
TF_RETURN_IF_ERROR(registry.Register(forward_name, gradient_factory));
|
||||||
|
|
||||||
Tape tape(/*persistent=*/false);
|
Tape tape(/*persistent=*/false);
|
||||||
for (size_t i{}; i < num_inputs; ++i) {
|
for (size_t i{}; i < num_inputs; ++i) {
|
||||||
@ -122,7 +124,7 @@ Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
|
|||||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
forward(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
|
forward_ops(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
|
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
|
||||||
/*sources=*/inputs,
|
/*sources=*/inputs,
|
||||||
|
@ -30,7 +30,7 @@ void CompareNumericalAndAutodiffGradients(
|
|||||||
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
|
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
|
||||||
absl::Span<const int64_t> dims, double abs_error = 1e-2);
|
absl::Span<const int64_t> dims, double abs_error = 1e-2);
|
||||||
|
|
||||||
Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
|
Model BuildGradModel(Model ops, size_t num_inputs, string ops_name,
|
||||||
GradientFunctionFactory gradient_function_factory);
|
GradientFunctionFactory gradient_function_factory);
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -78,29 +79,6 @@ Status DivNoNanModel(AbstractContext* ctx,
|
|||||||
return ops::DivNoNan(ctx, inputs, outputs, "DivNoNan");
|
return ops::DivNoNan(ctx, inputs, outputs, "DivNoNan");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DivNoNanGradModel(AbstractContext* ctx,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs) {
|
|
||||||
GradientRegistry registry;
|
|
||||||
TF_RETURN_IF_ERROR(registry.Register("DivNoNan", DivNoNanRegisterer));
|
|
||||||
|
|
||||||
Tape tape(/*persistent=*/false);
|
|
||||||
tape.Watch(inputs[0]);
|
|
||||||
tape.Watch(inputs[1]);
|
|
||||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
|
||||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
|
|
||||||
TF_RETURN_IF_ERROR(ops::DivNoNan(
|
|
||||||
tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs), "DivNoNanGrad"));
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
|
|
||||||
/*sources=*/inputs,
|
|
||||||
/*output_gradients=*/{}, outputs));
|
|
||||||
for (auto temp_output : temp_outputs) {
|
|
||||||
temp_output->Unref();
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
class CppGradients
|
class CppGradients
|
||||||
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||||
protected:
|
protected:
|
||||||
@ -117,6 +95,11 @@ class CppGradients
|
|||||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
ctx_.reset(ctx_raw);
|
ctx_.reset(ctx_raw);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computing numerical gradients with TensorFloat-32 is numerically
|
||||||
|
// unstable. Some forward pass tests also fail with TensorFloat-32 due to
|
||||||
|
// low tolerances
|
||||||
|
enable_tensor_float_32_execution(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractContextPtr ctx_;
|
AbstractContextPtr ctx_;
|
||||||
@ -249,8 +232,9 @@ TEST_P(CppGradients, TestLog1pGrad) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(CppGradients, TestDivNoNanGrad) {
|
TEST_P(CppGradients, TestDivNoNanGrad) {
|
||||||
// TODO(vnvo2409): Figure out why `BuildGradModel` does not work with
|
auto DivNoNanGradModel =
|
||||||
// `DivNoNan`.
|
BuildGradModel(DivNoNanModel, 2, "DivNoNan", DivNoNanRegisterer);
|
||||||
|
|
||||||
AbstractTensorHandlePtr x;
|
AbstractTensorHandlePtr x;
|
||||||
{
|
{
|
||||||
AbstractTensorHandle* x_raw = nullptr;
|
AbstractTensorHandle* x_raw = nullptr;
|
||||||
|
Loading…
Reference in New Issue
Block a user