Extend c++ gradient_checker to complex types.
PiperOrigin-RevId: 168392949
This commit is contained in:
parent
f63aa7f49f
commit
e6b011763a
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/gradients.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
#include "tensorflow/core/framework/type_traits.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -31,7 +32,74 @@ namespace {
|
||||
// TODO(andydavis) Vectorize and/or multi-thread Jacobian computations if
|
||||
// performance becomes an issue.
|
||||
|
||||
// BaseUnitsForType provides a list of typed unit values for each basis in the
|
||||
// requested type.
|
||||
// When T is real,
|
||||
// BaseUnitsForType<T>::values() is just a single-entry vector [1]
|
||||
// When T is complex,
|
||||
// BaseUnitsForType<T>::values() is a two-entry vector [1, i] - the unit
|
||||
// values in each of its two bases.
|
||||
template <typename T>
|
||||
struct BaseUnitsForType {}; // Specializations below
|
||||
|
||||
// Template specialization for BaseUnitsForType
|
||||
#define SET_BASE_UNITS_FOR_TYPE(TYPE, INIT) \
|
||||
template <> \
|
||||
struct BaseUnitsForType<TYPE> { \
|
||||
static const std::vector<TYPE>& values() { \
|
||||
static std::vector<TYPE>* units = new std::vector<TYPE> INIT; \
|
||||
return *units; \
|
||||
} \
|
||||
}
|
||||
|
||||
SET_BASE_UNITS_FOR_TYPE(float, {1});
|
||||
SET_BASE_UNITS_FOR_TYPE(double, {1});
|
||||
SET_BASE_UNITS_FOR_TYPE(complex64, ({{1, 0}, {0, 1}}));
|
||||
SET_BASE_UNITS_FOR_TYPE(complex128, ({{1, 0}, {0, 1}}));
|
||||
|
||||
// SetJacobian sets the jacobian value at the provided row and column from a
|
||||
// tensor entry with type T.
|
||||
// When T is real, this is a simple assignment that casts the entry into the
|
||||
// jacobian type.
|
||||
// When T is complex, it assigns the real and complex values to successive rows
|
||||
// or columns in the matrix depending on the expand_by_row parameter
|
||||
template <typename T, typename JAC_T>
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type SetJacobian(
|
||||
typename TTypes<JAC_T>::Matrix* jacobian, const int row, const int col,
|
||||
const T& value, const bool expand_by_row) {
|
||||
(*jacobian)(row, col) = JAC_T{value};
|
||||
}
|
||||
|
||||
template <typename T, typename JAC_T>
|
||||
typename std::enable_if<is_complex<T>::value>::type SetJacobian(
|
||||
typename TTypes<JAC_T>::Matrix* jacobian, const int row, const int col,
|
||||
const T& value, const bool expand_by_row) {
|
||||
(*jacobian)(row, col) = JAC_T{value.real()};
|
||||
if (expand_by_row) {
|
||||
(*jacobian)(row + 1, col) = JAC_T{value.imag()};
|
||||
} else {
|
||||
(*jacobian)(row, col + 1) = JAC_T{value.imag()};
|
||||
}
|
||||
}
|
||||
|
||||
// JacobianStride<T>::value holds the number of Jacobian elements needed to
|
||||
// represent one element of the given type.
|
||||
// When T is real the stride is 1, and when T is complex the stride is 2.
|
||||
template <typename T>
|
||||
struct JacobianStride {}; // Specializations below
|
||||
|
||||
#define SET_JACOBIAN_STRIDE(TYPE, VALUE) \
|
||||
template <> \
|
||||
struct JacobianStride<TYPE> { \
|
||||
static constexpr int value = VALUE; \
|
||||
}
|
||||
|
||||
SET_JACOBIAN_STRIDE(float, 1);
|
||||
SET_JACOBIAN_STRIDE(double, 1);
|
||||
SET_JACOBIAN_STRIDE(complex64, 2);
|
||||
SET_JACOBIAN_STRIDE(complex128, 2);
|
||||
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
Status ComputeTheoreticalJacobianTranspose(
|
||||
const Scope& scope, const OutputList& xs,
|
||||
const std::vector<TensorShape>& x_shapes,
|
||||
@ -44,9 +112,9 @@ Status ComputeTheoreticalJacobianTranspose(
|
||||
OutputList dys;
|
||||
dys.reserve(y_shapes.size());
|
||||
for (const auto& y_shape : y_shapes) {
|
||||
// TODO(suharshs): This currently assumes that all x's are the same type.
|
||||
// TODO(suharshs): This currently assumes that all y's are the same type.
|
||||
dys.push_back(
|
||||
ops::Cast(scope, ops::Const(scope, 1.0, y_shape), xs[0].type()));
|
||||
ops::Cast(scope, ops::Const(scope, 1.0, y_shape), ys[0].type()));
|
||||
}
|
||||
OutputList dxs;
|
||||
TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, ys, xs, dys, &dxs));
|
||||
@ -55,7 +123,7 @@ Status ComputeTheoreticalJacobianTranspose(
|
||||
std::vector<Tensor> dy_datas(y_num);
|
||||
for (int i = 0; i < y_num; i++) {
|
||||
dy_datas[i] = Tensor(ys[i].type(), y_shapes[i]);
|
||||
auto dy_data_flat = dy_datas[i].flat<T>();
|
||||
auto dy_data_flat = dy_datas[i].flat<Y_T>();
|
||||
dy_data_flat.setZero();
|
||||
}
|
||||
|
||||
@ -68,30 +136,41 @@ Status ComputeTheoreticalJacobianTranspose(
|
||||
feed_list.insert({dys[i], dy_datas[i]});
|
||||
}
|
||||
|
||||
// x_stride and y_stride are used to calculate the correct jacobian row and
|
||||
// column position for a pair of elements at positions r, c within the x and y
|
||||
// tensors respectively.
|
||||
const int x_stride = JacobianStride<X_T>::value;
|
||||
const int y_stride = JacobianStride<Y_T>::value;
|
||||
ClientSession session(scope);
|
||||
for (int y_idx = 0; y_idx < y_num; y_idx++) {
|
||||
auto dy_data_flat = dy_datas[y_idx].flat<T>();
|
||||
auto dy_data_flat = dy_datas[y_idx].flat<Y_T>();
|
||||
const int64 dy_size = y_shapes[y_idx].num_elements();
|
||||
|
||||
// Compute the theoretical Jacobians one row at a time by back propagating
|
||||
// '1.0' for each element of 'dy', while holding all other elements of 'dy'
|
||||
// at zero.
|
||||
// '1.0' (or '1' and 'i' if y is complex) for each element of 'dy', while
|
||||
// holding all other elements of 'dy' at zero.
|
||||
for (int c = 0; c < dy_size; ++c) {
|
||||
dy_data_flat(c) = 1.0;
|
||||
int unit_dimension = 0;
|
||||
for (Y_T unit : BaseUnitsForType<Y_T>::values()) {
|
||||
dy_data_flat(c) = unit;
|
||||
|
||||
std::vector<Tensor> dxout;
|
||||
TF_RETURN_IF_ERROR(session.Run(feed_list, dxs, &dxout));
|
||||
std::vector<Tensor> dxout;
|
||||
TF_RETURN_IF_ERROR(session.Run(feed_list, dxs, &dxout));
|
||||
|
||||
for (int x_idx = 0; x_idx < x_num; x_idx++) {
|
||||
const int64 x_size = x_shapes[x_idx].num_elements();
|
||||
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>();
|
||||
auto dx_flat = dxout[x_idx].flat<T>();
|
||||
for (int r = 0; r < x_size; ++r) {
|
||||
jacobian(r, c) = dx_flat(r);
|
||||
for (int x_idx = 0; x_idx < x_num; x_idx++) {
|
||||
const int64 x_size = x_shapes[x_idx].num_elements();
|
||||
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
|
||||
auto dx_flat = dxout[x_idx].flat<X_T>();
|
||||
for (int r = 0; r < x_size; ++r) {
|
||||
SetJacobian<X_T, JAC_T>(&jacobian, r * x_stride,
|
||||
c * y_stride + unit_dimension, dx_flat(r),
|
||||
true /* expand_by_row=true */);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dy_data_flat(c) = 0.0;
|
||||
dy_data_flat(c) = Y_T{0};
|
||||
unit_dimension++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -122,104 +201,154 @@ Status EvaluateGraph(ClientSession* session, const OutputList& xs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
|
||||
const std::vector<TensorShape>& x_shapes,
|
||||
const OutputList& ys,
|
||||
const std::vector<TensorShape>& y_shapes,
|
||||
const T delta,
|
||||
const JAC_T delta,
|
||||
std::vector<Tensor>* x_datas,
|
||||
std::vector<Tensor>* jacobian_ts) {
|
||||
size_t y_num = y_shapes.size();
|
||||
size_t x_num = x_shapes.size();
|
||||
// x_stride and y_stride are used to calculate the correct jacobian row and
|
||||
// column position for a pair of elements at positions r, c within the x and y
|
||||
// tensors respectively.
|
||||
const int x_stride = JacobianStride<X_T>::value;
|
||||
const int y_stride = JacobianStride<Y_T>::value;
|
||||
|
||||
ClientSession session(scope);
|
||||
for (int x_idx = 0; x_idx < x_num; x_idx++) {
|
||||
auto x_data_flat = (*x_datas)[x_idx].flat<T>();
|
||||
auto x_data_flat = (*x_datas)[x_idx].flat<X_T>();
|
||||
const int64 x_size = x_shapes[x_idx].num_elements();
|
||||
|
||||
// Compute the numeric Jacobian one column at a time by perturbing each
|
||||
// element of 'x_data' (positively and negatively) by 'delta', and
|
||||
// updating the jacobian with the centered difference.
|
||||
// updating the jacobian with the centered difference. When x_data is
|
||||
// complex-valued, we perturb its real and complex parts separately.
|
||||
for (int r = 0; r < x_size; ++r) {
|
||||
// Store current value of 'x' at 'r'.
|
||||
T v = x_data_flat(r);
|
||||
// Evaluate at positive delta.
|
||||
x_data_flat(r) = v + delta;
|
||||
std::vector<Tensor> y_pos;
|
||||
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos));
|
||||
// Evaluate at negative delta.
|
||||
x_data_flat(r) = v - delta;
|
||||
std::vector<Tensor> y_neg;
|
||||
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg));
|
||||
int unit_dimension = 0;
|
||||
for (X_T unit : BaseUnitsForType<X_T>::values()) {
|
||||
X_T x_delta = unit * X_T{delta};
|
||||
// Store current value of 'x' at 'r'.
|
||||
X_T v = x_data_flat(r);
|
||||
// Evaluate at positive delta.
|
||||
x_data_flat(r) = v + x_delta;
|
||||
std::vector<Tensor> y_pos;
|
||||
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos));
|
||||
// Evaluate at negative delta.
|
||||
x_data_flat(r) = v - x_delta;
|
||||
std::vector<Tensor> y_neg;
|
||||
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg));
|
||||
|
||||
for (int y_idx = 0; y_idx < y_num; y_idx++) {
|
||||
// Compute element-wise centered difference and store in each Jacobian.
|
||||
auto y_pos_flat = y_pos[y_idx].flat<T>();
|
||||
auto y_neg_flat = y_neg[y_idx].flat<T>();
|
||||
const int64 y_size = y_shapes[y_idx].num_elements();
|
||||
const T scale = 2 * delta;
|
||||
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>();
|
||||
for (int c = 0; c < y_size; ++c) {
|
||||
jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale;
|
||||
for (int y_idx = 0; y_idx < y_num; y_idx++) {
|
||||
// Compute element-wise centered difference and store in each
|
||||
// Jacobian.
|
||||
auto y_pos_flat = y_pos[y_idx].flat<Y_T>();
|
||||
auto y_neg_flat = y_neg[y_idx].flat<Y_T>();
|
||||
const int64 y_size = y_shapes[y_idx].num_elements();
|
||||
const Y_T scale = Y_T{2 * delta};
|
||||
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
|
||||
for (int c = 0; c < y_size; ++c) {
|
||||
SetJacobian<Y_T, JAC_T>(&jacobian, r * x_stride + unit_dimension,
|
||||
c * y_stride,
|
||||
(y_pos_flat(c) - y_neg_flat(c)) / scale,
|
||||
false /* expand_by_row=false */);
|
||||
}
|
||||
}
|
||||
// Restore pre-perturbation value.
|
||||
x_data_flat(r) = v;
|
||||
unit_dimension++;
|
||||
}
|
||||
// Restore pre-perturbation value.
|
||||
x_data_flat(r) = v;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
// The Jacobian is always a real-valued matrix.
|
||||
// Given y = f(x) for tensors y and x, it contains the derivatives dy_i/dx_j for
|
||||
// every pair y_i in y and x_j in x. Note that the Jacobian is defined directly
|
||||
// over the elements of tensors y and x, and doesn't depend on their shapes.
|
||||
//
|
||||
// If x = (x_1, x_2, ..., x_m) and y = (y_1, y_2, .., y_n) the matrix evaluated
|
||||
// is actually the Jacobian transpose, defined as this mxn matrix:
|
||||
// dy_1/d_x1 dy_2/dx_1 ... dy_n/dx_1
|
||||
// dy_1/dx_2 dy_2/dx_2 ... dy_n/dx_2
|
||||
// .
|
||||
// .
|
||||
// .
|
||||
// dy_1/dx_m dy_2/dx_m ... dy_n/dx_m
|
||||
//
|
||||
// If x or y is complex, each complex entry is "expanded" into a real and
|
||||
// imaginary entry, and the Jacobian is organized as above on the expanded list.
|
||||
// e.g.
|
||||
// [y1, y2] = Square([x1, x2]) where x and y are complex.
|
||||
// Writing
|
||||
// x = [x1_real, x1_imag, x2_real, x2_imag]
|
||||
// y = [y1_real, y1_imag, y2_real, y2_imag]
|
||||
// the Jacobian transpose is
|
||||
// the 4x4 matrix:
|
||||
// dy1_real/dx1_real dy1_imag/dx1_real dy2_real/dx1_real dy2_imag/dx1_real
|
||||
// dy1_real/dx1_imag dy1_imag/dx1_imag dy2_real/dx1_imag dy2_imag/dx1_imag
|
||||
// dy1_real/dx2_real dy1_imag/dx2_real dy2_real/dx2_real dy2_imag/dx2_real
|
||||
// dy1_real/dx2_imag dy1_imag/dx2_imag dy2_real/dx2_imag dy2_imag/dx2_imag
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
void InitJacobians(const OutputList& xs,
|
||||
const std::vector<TensorShape>& x_shapes,
|
||||
const std::vector<TensorShape>& y_shapes,
|
||||
std::vector<Tensor>* jacobians) {
|
||||
size_t y_num = y_shapes.size();
|
||||
size_t x_num = x_shapes.size();
|
||||
const size_t y_num = y_shapes.size();
|
||||
const size_t x_num = x_shapes.size();
|
||||
const DataType jacobian_type = DataTypeToEnum<JAC_T>::v();
|
||||
|
||||
jacobians->resize(y_num * x_num);
|
||||
for (int x_idx = 0; x_idx < x_num; x_idx++) {
|
||||
const int64 x_size = x_shapes[x_idx].num_elements();
|
||||
// The number of rows is the number of elements in the x tensor multiplied
|
||||
// by the number of Jacobian entries needed to represent each x type.
|
||||
const int64 x_size =
|
||||
x_shapes[x_idx].num_elements() * JacobianStride<X_T>::value;
|
||||
for (int y_idx = 0; y_idx < y_num; y_idx++) {
|
||||
const int64 y_size = y_shapes[y_idx].num_elements();
|
||||
Tensor jacobian_t(xs[x_idx].type(), {x_size, y_size});
|
||||
auto jacobian_t_flat = jacobian_t.flat<T>();
|
||||
// The number of columns is the number of elements in the y tensor
|
||||
// multiplied by the number of Jacobian entries needed to represent each
|
||||
// y type.
|
||||
const int64 y_size =
|
||||
y_shapes[y_idx].num_elements() * JacobianStride<Y_T>::value;
|
||||
Tensor jacobian_t(jacobian_type, {x_size, y_size});
|
||||
auto jacobian_t_flat = jacobian_t.flat<JAC_T>();
|
||||
jacobian_t_flat.setZero();
|
||||
(*jacobians)[x_idx * y_num + y_idx] = std::move(jacobian_t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
|
||||
const std::vector<TensorShape>& x_shapes,
|
||||
const OutputList& ys,
|
||||
const std::vector<TensorShape>& y_shapes,
|
||||
std::vector<Tensor>* x_datas,
|
||||
T* max_error) {
|
||||
JAC_T* max_error) {
|
||||
// Initialize theoretical Jacobians to zeros.
|
||||
std::vector<Tensor> jacobian_ts;
|
||||
InitJacobians<T>(xs, x_shapes, y_shapes, &jacobian_ts);
|
||||
InitJacobians<X_T, Y_T, JAC_T>(xs, x_shapes, y_shapes, &jacobian_ts);
|
||||
|
||||
// Compute theoretical Jacobian.
|
||||
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
|
||||
scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts));
|
||||
TF_RETURN_IF_ERROR((ComputeTheoreticalJacobianTranspose<X_T, Y_T, JAC_T>(
|
||||
scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts)));
|
||||
|
||||
// Initialize numeric Jacobian to zeros.
|
||||
std::vector<Tensor> jacobian_ns;
|
||||
InitJacobians<T>(xs, x_shapes, y_shapes, &jacobian_ns);
|
||||
InitJacobians<X_T, Y_T, JAC_T>(xs, x_shapes, y_shapes, &jacobian_ns);
|
||||
|
||||
// Compute numeric Jacobian.
|
||||
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
|
||||
scope, xs, x_shapes, ys, y_shapes, 1e-3, x_datas, &jacobian_ns));
|
||||
TF_RETURN_IF_ERROR((ComputeNumericJacobianTranspose<X_T, Y_T, JAC_T>(
|
||||
scope, xs, x_shapes, ys, y_shapes, JAC_T{1e-3f}, x_datas, &jacobian_ns)));
|
||||
|
||||
for (int i = 0; i < jacobian_ts.size(); i++) {
|
||||
// Compute the maximum error between theoretical and numeric Jacobians.
|
||||
*max_error = 0.0;
|
||||
auto jac_t = jacobian_ts[i].matrix<T>();
|
||||
auto jac_n = jacobian_ns[i].matrix<T>();
|
||||
auto jac_t = jacobian_ts[i].matrix<JAC_T>();
|
||||
auto jac_n = jacobian_ns[i].matrix<JAC_T>();
|
||||
for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) {
|
||||
for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) {
|
||||
*max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
|
||||
@ -231,12 +360,12 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
Status ComputeGradientError(const Scope& scope, const OutputList& xs,
|
||||
const std::vector<TensorShape>& x_shapes,
|
||||
const OutputList& ys,
|
||||
const std::vector<TensorShape>& y_shapes,
|
||||
T* max_error) {
|
||||
JAC_T* max_error) {
|
||||
if (xs.size() != x_shapes.size()) {
|
||||
return errors::InvalidArgument("xs(size ", xs.size(),
|
||||
") and x_shapes(size ", x_shapes.size(),
|
||||
@ -251,35 +380,39 @@ Status ComputeGradientError(const Scope& scope, const OutputList& xs,
|
||||
std::vector<Tensor> x_datas(x_shapes.size());
|
||||
for (int i = 0; i < x_shapes.size(); i++) {
|
||||
x_datas[i] = Tensor(xs[i].type(), x_shapes[i]);
|
||||
auto x_data_flat = x_datas[i].flat<T>();
|
||||
auto x_data_flat = x_datas[i].flat<X_T>();
|
||||
x_data_flat.setRandom();
|
||||
}
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(scope, xs, x_shapes, ys, y_shapes,
|
||||
&x_datas, max_error);
|
||||
return ComputeGradientErrorInternal<X_T, Y_T, JAC_T>(
|
||||
scope, xs, x_shapes, ys, y_shapes, &x_datas, max_error);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
Status ComputeGradientError(const Scope& scope, const Output& x,
|
||||
const Tensor& x_init_value, const Output& y,
|
||||
const TensorShape& y_shape, T* max_error) {
|
||||
const TensorShape& y_shape, JAC_T* max_error) {
|
||||
// Initialize 'x_data' from 'x_init_value'.
|
||||
std::vector<Tensor> x_datas(1, Tensor(x_init_value));
|
||||
// Compute gradient error.
|
||||
return ComputeGradientErrorInternal(scope, {x}, {x_datas[0].shape()}, {y},
|
||||
{y_shape}, &x_datas, max_error);
|
||||
return ComputeGradientErrorInternal<X_T, Y_T, JAC_T>(
|
||||
scope, {x}, {x_datas[0].shape()}, {y}, {y_shape}, &x_datas, max_error);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_GRAD_ERR_TYPE(T) \
|
||||
template Status ComputeGradientError<T>( \
|
||||
#define INSTANTIATE_GRAD_ERR_TYPE(X_T, Y_T, JAC_T) \
|
||||
template Status ComputeGradientError<X_T, Y_T, JAC_T>( \
|
||||
const Scope& scope, const OutputList& xs, \
|
||||
const std::vector<TensorShape>& x_shapes, const OutputList& ys, \
|
||||
const std::vector<TensorShape>& y_shapes, T* max_error); \
|
||||
template Status ComputeGradientError<T>( \
|
||||
const std::vector<TensorShape>& y_shapes, JAC_T* max_error); \
|
||||
template Status ComputeGradientError<X_T, Y_T, JAC_T>( \
|
||||
const Scope& scope, const Output& x, const Tensor& x_init_value, \
|
||||
const Output& y, const TensorShape& y_shape, T* max_error);
|
||||
const Output& y, const TensorShape& y_shape, JAC_T* max_error);
|
||||
|
||||
INSTANTIATE_GRAD_ERR_TYPE(float);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(double);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(float, float, float);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(double, double, double);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(complex64, complex64, float);
|
||||
INSTANTIATE_GRAD_ERR_TYPE(complex128, complex128, double);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -24,19 +24,39 @@ namespace tensorflow {
|
||||
|
||||
/// Returns in 'max_error' the maximum element-wise error for dy/dx between the
|
||||
/// computed and numeric Jacobian matrices where 'xs' and 'ys' are tensors.
|
||||
/// X_T and Y_T are the c++ types for the x and y tensors, and JAC_T is a
|
||||
/// real-valued type to store the Jacobian derivatives dy/dx.
|
||||
/// This function adds operations to the graph associated with 'scope'.
|
||||
template <typename T>
|
||||
///
|
||||
/// Examples:
|
||||
/// if y = Square(x), where x (and so y) are DT_FLOAT,
|
||||
/// <X_T, Y_T, JAC_T> should be <float, float, float>
|
||||
///
|
||||
/// if y = Square(x), where x (and so y) are DT_DOUBLE,
|
||||
/// <X_T, Y_T, JAC_T> should be <double, double, double>
|
||||
///
|
||||
/// if y = Square(x), where x (and so y) are DT_COMPLEX64,
|
||||
/// <X_T, Y_T, JAC_T> should be <complex64, complex64, float>
|
||||
/// Note that JAC_T is always real-valued, and should be an appropriate
|
||||
/// precision to host the partial derivatives for dy/dx
|
||||
///
|
||||
/// if y = ComplexAbs(x) where x is DT_COMPLEX64 (so y is DT_FLOAT)
|
||||
/// <X_T, Y_T, JAC_T> should be <complex64, float, float>
|
||||
///
|
||||
/// if y = Complex(x, x) where x is DT_FLOAT (so y is DT_COMPLEX64)
|
||||
/// <X_T, Y_T, JAC_T> should be <float, complex64, float>
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
Status ComputeGradientError(const Scope& scope, const OutputList& xs,
|
||||
const std::vector<TensorShape>& x_shapes,
|
||||
const OutputList& ys,
|
||||
const std::vector<TensorShape>& y_shapes,
|
||||
T* max_error);
|
||||
JAC_T* max_error);
|
||||
|
||||
/// Overload of ComputeGradientError which takes an initial value for 'x'.
|
||||
template <typename T>
|
||||
template <typename X_T, typename Y_T, typename JAC_T>
|
||||
Status ComputeGradientError(const Scope& scope, const Output& x,
|
||||
const Tensor& x_init_value, const Output& y,
|
||||
const TensorShape& y_shape, T* max_error);
|
||||
const TensorShape& y_shape, JAC_T* max_error);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -34,8 +34,8 @@ TEST(GradientCheckerTest, BasicFloat) {
|
||||
auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Square(scope, x);
|
||||
float max_error;
|
||||
TF_ASSERT_OK(ComputeGradientError<float>(scope, {x}, {shape}, {y}, {shape},
|
||||
&max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope, {x}, {shape}, {y}, {shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
@ -45,11 +45,57 @@ TEST(GradientCheckerTest, BasicDouble) {
|
||||
auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(shape));
|
||||
auto y = Square(scope, x);
|
||||
double max_error;
|
||||
TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {shape}, {y}, {shape},
|
||||
&max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
|
||||
scope, {x}, {shape}, {y}, {shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-10);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, BasicComplex64) {
|
||||
Scope scope = Scope::NewRootScope();
|
||||
TensorShape shape({2, 4, 3});
|
||||
auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape));
|
||||
auto y = Square(scope, x);
|
||||
float max_error;
|
||||
TF_ASSERT_OK((ComputeGradientError<complex64, complex64, float>(
|
||||
scope, {x}, {shape}, {y}, {shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, BasicComplex128) {
|
||||
Scope scope = Scope::NewRootScope();
|
||||
TensorShape shape({2, 4, 3});
|
||||
auto x = Placeholder(scope, DT_COMPLEX128, Placeholder::Shape(shape));
|
||||
auto y = Square(scope, x);
|
||||
double max_error;
|
||||
TF_ASSERT_OK((ComputeGradientError<complex128, complex128, double>(
|
||||
scope, {x}, {shape}, {y}, {shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-10);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, FloatToComplex64) {
|
||||
// Test an op whose inputs are real and outputs are complex
|
||||
Scope scope = Scope::NewRootScope();
|
||||
TensorShape shape({2, 4, 3});
|
||||
auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = Complex(scope, x, x);
|
||||
float max_error;
|
||||
TF_ASSERT_OK((ComputeGradientError<float, complex64, float>(
|
||||
scope, {x}, {shape}, {y}, {shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, Complex64ToFloat) {
|
||||
// Test an op whose inputs are complex and outputs are real
|
||||
Scope scope = Scope::NewRootScope();
|
||||
TensorShape shape({2, 4, 3});
|
||||
auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape));
|
||||
auto y = Real(scope, x);
|
||||
float max_error;
|
||||
TF_ASSERT_OK((ComputeGradientError<complex64, float, float>(
|
||||
scope, {x}, {shape}, {y}, {shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
TEST(GradientCheckerTest, MatMulGrad) {
|
||||
Scope scope = Scope::NewRootScope();
|
||||
|
||||
@ -61,8 +107,8 @@ TEST(GradientCheckerTest, MatMulGrad) {
|
||||
auto y = Const(scope, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, y_shape);
|
||||
auto z = MatMul(scope, x, y);
|
||||
double max_error;
|
||||
TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {x_shape}, {z},
|
||||
{z_shape}, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
|
||||
scope, {x}, {x_shape}, {z}, {z_shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-10);
|
||||
}
|
||||
|
||||
@ -76,8 +122,8 @@ TEST(GradientCheckerTest, SplitGrad) {
|
||||
auto y = Split(scope, split_dim, x, /* num_split */ 2);
|
||||
TensorShape y_shape = TensorShape({5, 1});
|
||||
double max_error;
|
||||
TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {x_shape}, y.output,
|
||||
{y_shape, y_shape}, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
|
||||
scope, {x}, {x_shape}, y.output, {y_shape, y_shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-10);
|
||||
}
|
||||
|
||||
@ -91,8 +137,8 @@ TEST(GradientCheckerTest, StackGrad) {
|
||||
auto y = Stack(scope, xs, Stack::Axis(0));
|
||||
TensorShape y_shape({2, 1, 2, 3});
|
||||
double max_error;
|
||||
TF_ASSERT_OK(ComputeGradientError<double>(scope, xs, {x_shape, x_shape}, {y},
|
||||
{y_shape}, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
|
||||
scope, xs, {x_shape, x_shape}, {y}, {y_shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-10);
|
||||
}
|
||||
|
||||
@ -107,8 +153,8 @@ TEST(GradientCheckerTest, StackUnstackGrad) {
|
||||
auto tmp = Stack(scope, xs, Stack::Axis(0));
|
||||
auto y = Unstack(scope, tmp, 2, Unstack::Axis(0));
|
||||
double max_error;
|
||||
TF_ASSERT_OK(ComputeGradientError<double>(scope, xs, {shape, shape}, y.output,
|
||||
{shape, shape}, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
|
||||
scope, xs, {shape, shape}, y.output, {shape, shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-10);
|
||||
}
|
||||
|
||||
|
@ -36,8 +36,8 @@ class ArrayGradTest : public ::testing::Test {
|
||||
const TensorShape& y_shape) {
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
float max_error;
|
||||
TF_ASSERT_OK(ComputeGradientError(scope_, {x}, {x_shape}, {y}, {y_shape},
|
||||
&max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-3);
|
||||
}
|
||||
|
||||
@ -45,8 +45,8 @@ class ArrayGradTest : public ::testing::Test {
|
||||
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-3);
|
||||
}
|
||||
|
||||
|
@ -35,8 +35,8 @@ class DataFlowGradTest : public ::testing::Test {
|
||||
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-4);
|
||||
}
|
||||
|
||||
|
@ -526,6 +526,15 @@ Status ImagGrad(const Scope& scope, const Operation& op,
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Imag", ImagGrad);
|
||||
|
||||
Status ComplexGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto gx_1 = Real(scope, grad_inputs[0]);
|
||||
auto gx_2 = Imag(scope, grad_inputs[0]);
|
||||
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
|
||||
}
|
||||
REGISTER_GRADIENT_OP("Complex", ComplexGrad);
|
||||
|
||||
Status AngleGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
|
@ -737,10 +737,14 @@ TEST_F(CWiseUnaryComplexGradTest, Angle) {
|
||||
Tensor x = test::AsTensor<complex64>(
|
||||
{{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
|
||||
Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
|
||||
Tensor dx_expected = test::AsTensor<complex64>(
|
||||
{{5.5, 5.5}, {3, 3},
|
||||
{2.1666666666666665, 2.1666666666666665}, {1.75, 1.75},
|
||||
{0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3});
|
||||
Tensor dx_expected =
|
||||
test::AsTensor<complex64>({{5.5, 5.5},
|
||||
{3, 3},
|
||||
{2.1666666666666665, 2.1666666666666665},
|
||||
{1.75, 1.75},
|
||||
{0.9375, 0.9375},
|
||||
{0.8888888888888888, 0.8888888888888888}},
|
||||
{2, 3});
|
||||
TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
|
||||
}
|
||||
|
||||
@ -920,8 +924,8 @@ class NaryGradTest : public ::testing::Test {
|
||||
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-3);
|
||||
}
|
||||
|
||||
@ -929,8 +933,8 @@ class NaryGradTest : public ::testing::Test {
|
||||
const TensorShape& y_shape) {
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope_, x, x_init_value, y, y_shape, &max_error)));
|
||||
EXPECT_LT(max_error, 1e-3);
|
||||
}
|
||||
|
||||
|
@ -34,16 +34,16 @@ class NNGradTest : public ::testing::Test {
|
||||
void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
|
||||
const TensorShape& y_shape) {
|
||||
float max_error;
|
||||
TF_ASSERT_OK(ComputeGradientError(scope_, {x}, {x_shape}, {y}, {y_shape},
|
||||
&max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
|
||||
EXPECT_LT(max_error, 2e-4);
|
||||
}
|
||||
|
||||
void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
|
||||
const TensorShape& y_shape) {
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope_, x, x_init_value, y, y_shape, &max_error)));
|
||||
EXPECT_LT(max_error, 2e-4);
|
||||
}
|
||||
|
||||
@ -51,8 +51,8 @@ class NNGradTest : public ::testing::Test {
|
||||
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
|
||||
TF_ASSERT_OK(scope_.status());
|
||||
float max_error;
|
||||
TF_ASSERT_OK(
|
||||
ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
|
||||
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
|
||||
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
|
||||
EXPECT_LT(max_error, 2e-4);
|
||||
}
|
||||
|
||||
@ -71,11 +71,10 @@ TEST_F(NNGradTest, LogSoftmaxGrad) {
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto y = LogSoftmax(scope_, x);
|
||||
// Avoid numerical instability when computing finite differences.
|
||||
Tensor x_init_value = test::AsTensor<float>(
|
||||
{-0.9f, -0.7f, -0.5f, -0.3f, -0.1f,
|
||||
0.1f, 0.3f, 0.5f, 0.7f, 0.8f,
|
||||
-0.1f, 0.1f, 0.1f, 0.1f, 1.2f},
|
||||
{5, 3});
|
||||
Tensor x_init_value =
|
||||
test::AsTensor<float>({-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f,
|
||||
0.5f, 0.7f, 0.8f, -0.1f, 0.1f, 0.1f, 0.1f, 1.2f},
|
||||
{5, 3});
|
||||
RunTest(x, x_init_value, y, shape);
|
||||
}
|
||||
|
||||
@ -136,7 +135,7 @@ TEST_F(NNGradTest, BiasAddGradHelper) {
|
||||
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
|
||||
auto bias = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(bias_shape));
|
||||
auto y = BiasAdd(scope_, x, bias);
|
||||
RunTest({x,bias}, {shape, bias_shape}, {y}, {shape});
|
||||
RunTest({x, bias}, {shape, bias_shape}, {y}, {shape});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user