fix asan in gradient_checker
This commit is contained in:
parent
28c7e4d9f2
commit
e1b31fce08
@ -416,9 +416,7 @@ tf_cuda_cc_test(
|
|||||||
],
|
],
|
||||||
args = ["--heap_check=local"],
|
args = ["--heap_check=local"],
|
||||||
linkstatic = tf_kernel_tests_linkstatic(),
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
tags = tf_cuda_tests_tags() + [
|
tags = tf_cuda_tests_tags(),
|
||||||
"no_cuda_asan", # b/175330074
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":abstract_tensor_handle",
|
":abstract_tensor_handle",
|
||||||
":c_api_experimental",
|
":c_api_experimental",
|
||||||
|
|||||||
@ -54,15 +54,16 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
|
|||||||
// Run the model.
|
// Run the model.
|
||||||
TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
|
TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
|
||||||
absl::MakeSpan(model_outputs), use_function));
|
absl::MakeSpan(model_outputs), use_function));
|
||||||
AbstractTensorHandle* model_out = model_outputs[0];
|
AbstractTensorHandlePtr model_out(model_outputs[0]);
|
||||||
|
|
||||||
TF_Tensor* model_out_tensor;
|
TF_Tensor* model_out_tensor;
|
||||||
TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor));
|
TF_RETURN_IF_ERROR(GetValue(model_out.get(), &model_out_tensor));
|
||||||
int num_dims_out = TF_NumDims(model_out_tensor);
|
int num_dims_out = TF_NumDims(model_out_tensor);
|
||||||
|
TF_DeleteTensor(model_out_tensor);
|
||||||
|
|
||||||
// If the output is a scalar, then return the scalar output
|
// If the output is a scalar, then return the scalar output
|
||||||
if (num_dims_out == 0) {
|
if (num_dims_out == 0) {
|
||||||
outputs[0] = model_out;
|
outputs[0] = model_out.release();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,12 +82,8 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reduce sum the output on all dimensions.
|
// Reduce sum the output on all dimensions.
|
||||||
std::vector<AbstractTensorHandle*> sum_inputs(2);
|
TF_RETURN_IF_ERROR(ops::Sum(ctx, {model_out.get(), sum_dims.get()},
|
||||||
sum_inputs[0] = model_out;
|
absl::MakeSpan(model_outputs), "sum_output"));
|
||||||
sum_inputs[1] = sum_dims.get();
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
ops::Sum(ctx, sum_inputs, absl::MakeSpan(model_outputs), "sum_output"));
|
|
||||||
outputs[0] = model_outputs[0];
|
outputs[0] = model_outputs[0];
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -169,37 +166,38 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
|
|||||||
theta_inputs[input_index] = thetaPlus.get();
|
theta_inputs[input_index] = thetaPlus.get();
|
||||||
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
|
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
|
||||||
absl::MakeSpan(f_outputs), use_function));
|
absl::MakeSpan(f_outputs), use_function));
|
||||||
AbstractTensorHandle* fPlus = f_outputs[0];
|
AbstractTensorHandlePtr fPlus(f_outputs[0]);
|
||||||
|
|
||||||
// Get f(theta - eps):
|
// Get f(theta - eps):
|
||||||
theta_inputs[input_index] = thetaMinus.get();
|
theta_inputs[input_index] = thetaMinus.get();
|
||||||
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
|
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
|
||||||
absl::MakeSpan(f_outputs), use_function));
|
absl::MakeSpan(f_outputs), use_function));
|
||||||
AbstractTensorHandle* fMinus = f_outputs[0];
|
AbstractTensorHandlePtr fMinus(f_outputs[0]);
|
||||||
|
|
||||||
// Take Difference of both estimates: (f(theta + eps) - f(theta - eps)).
|
// Take Difference of both estimates: (f(theta + eps) - f(theta - eps)).
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(ops::Sub(ctx, {fPlus.get(), fMinus.get()},
|
||||||
ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top"));
|
absl::MakeSpan(f_outputs), "sub_top"));
|
||||||
AbstractTensorHandle* fDiff = f_outputs[0];
|
AbstractTensorHandlePtr fDiff(f_outputs[0]);
|
||||||
|
|
||||||
// Calculate using the difference quotient definition:
|
// Calculate using the difference quotient definition:
|
||||||
// (f(theta + eps) - f(theta - eps)) / (2 * eps).
|
// (f(theta + eps) - f(theta - eps)) / (2 * eps).
|
||||||
TF_RETURN_IF_ERROR(ops::Div(ctx, {fDiff, two_eps.get()},
|
TF_RETURN_IF_ERROR(ops::Div(ctx, {fDiff.get(), two_eps.get()},
|
||||||
absl::MakeSpan(f_outputs), "diff_quotient"));
|
absl::MakeSpan(f_outputs), "diff_quotient"));
|
||||||
AbstractTensorHandle* diff_quotient = f_outputs[0];
|
AbstractTensorHandlePtr diff_quotient(f_outputs[0]);
|
||||||
|
|
||||||
TF_Tensor* grad_tensor;
|
TF_Tensor* grad_tensor;
|
||||||
TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor));
|
TF_RETURN_IF_ERROR(GetValue(diff_quotient.get(), &grad_tensor));
|
||||||
float grad_data[1];
|
float grad_data[1];
|
||||||
memcpy(&grad_data[0], TF_TensorData(grad_tensor),
|
memcpy(&grad_data[0], TF_TensorData(grad_tensor),
|
||||||
TF_TensorByteSize(grad_tensor));
|
TF_TensorByteSize(grad_tensor));
|
||||||
|
TF_DeleteTensor(grad_tensor);
|
||||||
dtheta_approx[i] = grad_data[0];
|
dtheta_approx[i] = grad_data[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Populate *numerical_grad with the data from dtheta_approx.
|
// Populate *numerical_grad with the data from dtheta_approx.
|
||||||
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
|
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
|
||||||
ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
|
ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
|
||||||
|
TF_DeleteTensor(theta_tensor);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -34,13 +34,18 @@ void CompareNumericalAndManualGradients(
|
|||||||
absl::Span<AbstractTensorHandle* const> inputs, int input_index,
|
absl::Span<AbstractTensorHandle* const> inputs, int input_index,
|
||||||
float* expected_grad, int num_grad, bool use_function,
|
float* expected_grad, int num_grad, bool use_function,
|
||||||
double abs_error = 1e-2) {
|
double abs_error = 1e-2) {
|
||||||
AbstractTensorHandle* numerical_grad;
|
Status s;
|
||||||
Status s = CalcNumericalGrad(ctx, model, inputs, input_index, use_function,
|
AbstractTensorHandlePtr numerical_grad;
|
||||||
&numerical_grad);
|
{
|
||||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
AbstractTensorHandle* numerical_grad_raw;
|
||||||
|
s = CalcNumericalGrad(ctx, model, inputs, input_index, use_function,
|
||||||
|
&numerical_grad_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
numerical_grad.reset(numerical_grad_raw);
|
||||||
|
}
|
||||||
|
|
||||||
TF_Tensor* numerical_tensor;
|
TF_Tensor* numerical_tensor;
|
||||||
s = GetValue(numerical_grad, &numerical_tensor);
|
s = GetValue(numerical_grad.get(), &numerical_tensor);
|
||||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);
|
auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);
|
||||||
ASSERT_EQ(num_elem_numerical, num_grad);
|
ASSERT_EQ(num_elem_numerical, num_grad);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user