137 lines
5.0 KiB
C++
137 lines
5.0 KiB
C++
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
#include "tensorflow/c/experimental/gradients/grad_test_helper.h"
|
|
|
|
#include "tensorflow/c/eager/gradient_checker.h"
|
|
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
|
|
namespace tensorflow {
|
|
namespace gradients {
|
|
namespace internal {
|
|
|
|
void CompareNumericalAndAutodiffGradients(
|
|
Model model, Model grad_model, AbstractContext* ctx,
|
|
absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
|
|
double abs_error) {
|
|
auto num_inputs = inputs.size();
|
|
std::vector<AbstractTensorHandle*> outputs(num_inputs);
|
|
auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
|
|
/*use_function=*/use_function);
|
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
|
|
|
for (int i = 0; i < num_inputs; ++i) {
|
|
if (!outputs[i]) continue;
|
|
|
|
AbstractTensorHandlePtr numerical_grad;
|
|
{
|
|
AbstractTensorHandle* numerical_grad_raw;
|
|
s = CalcNumericalGrad(ctx, model, inputs,
|
|
/*input_index=*/i, use_function,
|
|
&numerical_grad_raw);
|
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
|
numerical_grad.reset(numerical_grad_raw);
|
|
}
|
|
|
|
TF_Tensor* numerical_tensor;
|
|
s = GetValue(numerical_grad.get(), &numerical_tensor);
|
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
|
auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);
|
|
|
|
TF_Tensor* analytical_tensor;
|
|
s = GetValue(outputs[i], &analytical_tensor);
|
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
|
auto num_elem_analytical = TF_TensorElementCount(analytical_tensor);
|
|
|
|
ASSERT_EQ(num_elem_numerical, num_elem_analytical);
|
|
|
|
float* dnumerical = new float[num_elem_numerical]{0};
|
|
memcpy(&dnumerical[0], TF_TensorData(numerical_tensor),
|
|
TF_TensorByteSize(numerical_tensor));
|
|
float* danalytical = new float[num_elem_analytical]{0};
|
|
memcpy(&danalytical[0], TF_TensorData(analytical_tensor),
|
|
TF_TensorByteSize(analytical_tensor));
|
|
|
|
for (int j = 0; j < num_elem_numerical; j++) {
|
|
ASSERT_NEAR(dnumerical[j], danalytical[j], abs_error);
|
|
}
|
|
TF_DeleteTensor(analytical_tensor);
|
|
TF_DeleteTensor(numerical_tensor);
|
|
delete[] danalytical;
|
|
delete[] dnumerical;
|
|
outputs[i]->Unref();
|
|
}
|
|
}
|
|
|
|
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
|
|
absl::Span<const int64_t> dims, double abs_error) {
|
|
TF_Tensor* analytical_tensor;
|
|
auto s = GetValue(t, &analytical_tensor);
|
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
|
|
|
int64_t num_elem_analytical = 1;
|
|
auto num_dims_analytical = TF_NumDims(analytical_tensor);
|
|
ASSERT_EQ(dims.size(), num_dims_analytical);
|
|
for (int j = 0; j < num_dims_analytical; j++) {
|
|
auto dim_analytical = TF_Dim(analytical_tensor, j);
|
|
ASSERT_EQ(dims[j], dim_analytical);
|
|
num_elem_analytical *= dim_analytical;
|
|
}
|
|
|
|
float* danalytical = new float[num_elem_analytical]{0};
|
|
memcpy(&danalytical[0], TF_TensorData(analytical_tensor),
|
|
TF_TensorByteSize(analytical_tensor));
|
|
|
|
for (int64_t j = 0; j < num_elem_analytical; j++) {
|
|
if (abs_error == 0) {
|
|
ASSERT_EQ(manuals[j], danalytical[j]);
|
|
} else {
|
|
ASSERT_NEAR(manuals[j], danalytical[j], abs_error);
|
|
}
|
|
}
|
|
|
|
TF_DeleteTensor(analytical_tensor);
|
|
delete[] danalytical;
|
|
}
|
|
|
|
Model BuildGradModel(Model forward, GradientRegistry registry) {
|
|
return [forward_model = std::move(forward),
|
|
grad_registry = std::move(registry)](
|
|
AbstractContext* ctx,
|
|
absl::Span<AbstractTensorHandle* const> inputs,
|
|
absl::Span<AbstractTensorHandle*> outputs) -> Status {
|
|
Tape tape(/*persistent=*/false);
|
|
for (size_t i{}; i < inputs.size(); ++i) {
|
|
tape.Watch(inputs[i]);
|
|
}
|
|
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, grad_registry));
|
|
TF_RETURN_IF_ERROR(
|
|
forward_model(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
|
|
|
|
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
|
|
/*sources=*/inputs,
|
|
/*output_gradients=*/{}, outputs));
|
|
for (auto temp_output : temp_outputs) {
|
|
temp_output->Unref();
|
|
}
|
|
return Status::OK();
|
|
};
|
|
}
|
|
|
|
} // namespace internal
|
|
} // namespace gradients
|
|
} // namespace tensorflow
|