C++ Gradients: Adds gradient checker.
Change: 133955605
commit a3c7fe1a73 (parent 510a7bfac3)
@@ -48,6 +48,42 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "gradient_checker",
+    srcs = ["framework/gradient_checker.cc"],
+    hdrs = ["framework/gradient_checker.h"],
+    deps = [
+        ":cc_ops",
+        ":client_session",
+        ":grad_op_registry",
+        ":gradients",
+        ":ops",
+        ":scope",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_cc_test(
+    name = "framework_gradient_checker_test",
+    srcs = ["framework/gradient_checker_test.cc"],
+    deps = [
+        ":cc_ops",
+        ":grad_op_registry",
+        ":grad_ops",
+        ":gradient_checker",
+        ":testutil",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 cc_library(
     name = "grad_ops",
     deps = [
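With these targets in place, the new test should be runnable directly, e.g. with "bazel test //tensorflow/cc:framework_gradient_checker_test" — assuming, as the framework/... source paths suggest, that this hunk belongs to tensorflow/cc/BUILD.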
tensorflow/cc/framework/gradient_checker.cc (new file, 165 lines)
@@ -0,0 +1,165 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/gradient_checker.h"

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Support returning relative error (as opposed to max error)
// between theoretical and numerical jacobians:
//   fabs(jac_t - jac_n) / max(fabs(jac_t), fabs(jac_n))

// TODO(andydavis) Vectorize and/or multi-thread Jacobian computations if
// performance becomes an issue.

template <typename T>
Status ComputeTheoreticalJacobianTranspose(
    const Scope& scope, const ops::Output& x, const TensorShape& x_shape,
    const Tensor& x_data, const ops::Output& y, const TensorShape& y_shape,
    Tensor* jacobian_t) {
  // Call AddSymbolicGradients to get 'dx' (we will feed 'dy').
  auto dy = Cast(scope, Const(scope, 1.0, y_shape), x.type());
  std::vector<ops::Output> outputs;
  TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, {y}, {x}, {dy}, &outputs));
  auto dx = outputs[0];

  // Initialize 'dy_data' to zeros.
  Tensor dy_data(y.type(), y_shape);
  auto dy_data_flat = dy_data.flat<T>();
  dy_data_flat.setZero();

  // Compute the theoretical Jacobian one row at a time by backpropagating
  // '1.0' for each element of 'dy', while holding all other elements of 'dy'
  // at zero.
  ClientSession session(scope);
  std::vector<Tensor> dxout;
  const int64 x_size = x_shape.num_elements();
  const int64 dy_size = y_shape.num_elements();
  auto jacobian = jacobian_t->matrix<T>();
  for (int c = 0; c < dy_size; ++c) {
    dy_data_flat(c) = 1.0;

    TF_RETURN_IF_ERROR(session.Run({{x, x_data}, {dy, dy_data}}, {dx}, &dxout));

    auto dx_flat = dxout[0].flat<T>();
    for (int r = 0; r < x_size; ++r) {
      jacobian(r, c) = dx_flat(r);
    }

    dy_data_flat(c) = 0.0;
  }
  return Status::OK();
}

template <typename T>
Status ComputeNumericJacobianTranspose(const Scope& scope, const ops::Output& x,
                                       const TensorShape& x_shape,
                                       const ops::Output& y,
                                       const TensorShape& y_shape,
                                       const T delta, Tensor* x_data,
                                       Tensor* jacobian_t) {
  const int64 x_size = x_shape.num_elements();
  const int64 y_size = y_shape.num_elements();
  auto x_data_flat = x_data->flat<T>();

  // Compute the numeric Jacobian one column at a time by perturbing each
  // element of 'x_data' (positively and negatively) by 'delta', and
  // updating the jacobian with the centered difference.
  ClientSession session(scope);
  std::vector<Tensor> yout;
  auto jacobian = jacobian_t->matrix<T>();
  for (int r = 0; r < x_size; ++r) {
    // Store current value of 'x' at 'r'.
    T v = x_data_flat(r);
    // Evaluate at positive delta.
    x_data_flat(r) = v + delta;
    TF_RETURN_IF_ERROR(session.Run({{x, *x_data}}, {y}, &yout));
    Tensor y_pos = yout[0];
    // Evaluate at negative delta.
    x_data_flat(r) = v - delta;
    TF_RETURN_IF_ERROR(session.Run({{x, *x_data}}, {y}, &yout));
    Tensor y_neg = yout[0];
    // Compute element-wise centered difference and store in Jacobian.
    auto y_pos_flat = y_pos.flat<T>();
    auto y_neg_flat = y_neg.flat<T>();
    const T scale = 2 * delta;
    for (int c = 0; c < y_size; ++c) {
      jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale;
    }
    // Restore pre-perturbation value.
    x_data_flat(r) = v;
  }
  return Status::OK();
}

}  // namespace

template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
                            const TensorShape& x_shape, const ops::Output& y,
                            const TensorShape& y_shape, T* max_error) {
  const int64 x_size = x_shape.num_elements();
  const int64 y_size = y_shape.num_elements();

  // Initialize 'x_data' to random values.
  Tensor x_data(x.type(), x_shape);
  auto x_data_flat = x_data.flat<T>();
  x_data_flat.setRandom();

  // Initialize theoretical Jacobian to zeros.
  Tensor jacobian_t(x.type(), {x_size, y_size});
  auto jacobian_t_flat = jacobian_t.flat<T>();
  jacobian_t_flat.setZero();

  // Compute theoretical Jacobian.
  TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
      scope, x, x_shape, x_data, y, y_shape, &jacobian_t));

  // Initialize numeric Jacobian to zeros.
  Tensor jacobian_n(x.type(), {x_size, y_size});
  auto jacobian_n_flat = jacobian_n.flat<T>();
  jacobian_n_flat.setZero();

  // Compute numeric Jacobian.
  TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
      scope, x, x_shape, y, y_shape, 1e-3, &x_data, &jacobian_n));

  // Compute the maximum error between theoretical and numeric Jacobians.
  *max_error = 0.0;
  auto jac_t = jacobian_t.matrix<T>();
  auto jac_n = jacobian_n.matrix<T>();
  for (int r = 0; r < x_size; ++r) {
    for (int c = 0; c < y_size; ++c) {
      *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
    }
  }
  return Status::OK();
}

#define INSTANTIATE_GRAD_ERR_TYPE(T)                                        \
  template Status ComputeGradientError<T>(                                  \
      const Scope& scope, const ops::Output& x, const TensorShape& x_shape, \
      const ops::Output& y, const TensorShape& y_shape, T* max_error)

INSTANTIATE_GRAD_ERR_TYPE(float);
INSTANTIATE_GRAD_ERR_TYPE(double);

}  // namespace tensorflow
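To see the whole algorithm end to end without any TensorFlow dependencies, here is a minimal standalone sketch (an editor's illustration, not part of this commit): it checks y = x * x elementwise, where the hand-written vector-Jacobian product dx = 2 * x * dy stands in for AddSymbolicGradients, and the centered difference plays the role of ComputeNumericJacobianTranspose. All names and sizes here are illustrative.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Toy setting: y = x * x elementwise, so dy/dx is diagonal with 2 * x(i).
  const int n = 3;
  const std::vector<double> x = {0.5, -1.25, 2.0};
  const double delta = 1e-3;  // same step ComputeGradientError uses

  // "Theoretical" Jacobian transpose: backpropagate a one-hot dy for each
  // output element c, filling column c of J^T, exactly as in
  // ComputeTheoreticalJacobianTranspose above.
  std::vector<double> jac_t(n * n, 0.0);
  for (int c = 0; c < n; ++c) {
    for (int r = 0; r < n; ++r) {
      const double dy = (r == c) ? 1.0 : 0.0;  // one-hot dy
      jac_t[r * n + c] = 2.0 * x[r] * dy;      // VJP for y = x * x
    }
  }

  // Numeric Jacobian transpose: perturb each input element r by +/- delta
  // and take the centered difference of every output element.
  std::vector<double> jac_n(n * n, 0.0);
  for (int r = 0; r < n; ++r) {
    for (int c = 0; c < n; ++c) {
      const double xp = (c == r) ? x[c] + delta : x[c];
      const double xm = (c == r) ? x[c] - delta : x[c];
      jac_n[r * n + c] = (xp * xp - xm * xm) / (2.0 * delta);
    }
  }

  // Maximum elementwise error, as in ComputeGradientError.
  double max_error = 0.0;
  for (int i = 0; i < n * n; ++i) {
    max_error = std::max(max_error, std::fabs(jac_t[i] - jac_n[i]));
  }
  std::printf("max error: %g\n", max_error);  // rounding noise only here
  return 0;
}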
tensorflow/cc/framework/gradient_checker.h (new file, 35 lines)
@@ -0,0 +1,35 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

// Returns in 'max_error' the maximum element-wise error for dy/dx between the
// computed and numeric Jacobian matrices where 'x' and 'y' are tensors.
// This function adds operations to the graph associated with 'scope'.
template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
                            const TensorShape& x_shape, const ops::Output& y,
                            const TensorShape& y_shape, T* max_error);

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
tensorflow/cc/framework/gradient_checker_test.cc (new file, 70 lines)
@@ -0,0 +1,70 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

TEST(GradientCheckerTest, BasicFloat) {
  Scope scope = Scope::NewRootScope();
  TensorShape shape({2, 4, 3});
  auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Square(scope, x);
  float max_error;
  TF_ASSERT_OK(
      ComputeGradientError<float>(scope, x, shape, y, shape, &max_error));
  EXPECT_LT(max_error, 1e-4);
}

TEST(GradientCheckerTest, BasicDouble) {
  Scope scope = Scope::NewRootScope();
  TensorShape shape({2, 4, 3});
  auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(shape));
  auto y = Square(scope, x);
  double max_error;
  TF_ASSERT_OK(
      ComputeGradientError<double>(scope, x, shape, y, shape, &max_error));
  EXPECT_LT(max_error, 1e-10);
}

TEST(GradientCheckerTest, MatMulGrad) {
  Scope scope = Scope::NewRootScope();

  TensorShape x_shape({4, 3});
  TensorShape y_shape({3, 2});
  TensorShape z_shape({4, 2});

  auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(x_shape));
  auto y = Const(scope, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, y_shape);
  auto z = MatMul(scope, x, y);
  double max_error;
  TF_ASSERT_OK(
      ComputeGradientError<double>(scope, x, x_shape, z, z_shape, &max_error));
  EXPECT_LT(max_error, 1e-10);
}

}  // namespace
}  // namespace tensorflow
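A note on the thresholds (an editor's reading, not stated in the commit): with the checker's step of delta = 1e-3, a centered difference carries O(delta^2) truncation error plus rounding noise on the order of machine epsilon divided by delta — roughly 1e-4 for float (epsilon ~ 1.2e-7) and well below 1e-10 for double (epsilon ~ 2.2e-16) — which lines up with the 1e-4 and 1e-10 bounds asserted above.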