Extend C++ gradient_checker to complex types.

PiperOrigin-RevId: 168392949
A. Unique TensorFlower 2017-09-12 10:16:26 -07:00 committed by TensorFlower Gardener
parent f63aa7f49f
commit e6b011763a
8 changed files with 328 additions and 117 deletions


@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -31,7 +32,74 @@ namespace {
// TODO(andydavis) Vectorize and/or multi-thread Jacobian computations if
// performance becomes an issue.
// BaseUnitsForType provides a list of typed unit values, one for each basis
// of the requested type.
// When T is real,
// BaseUnitsForType<T>::values() is just the single-entry vector [1].
// When T is complex,
// BaseUnitsForType<T>::values() is the two-entry vector [1, i] - the unit
// values in each of its two bases.
template <typename T>
struct BaseUnitsForType {}; // Specializations below
// Template specialization for BaseUnitsForType
#define SET_BASE_UNITS_FOR_TYPE(TYPE, INIT) \
template <> \
struct BaseUnitsForType<TYPE> { \
static const std::vector<TYPE>& values() { \
static std::vector<TYPE>* units = new std::vector<TYPE> INIT; \
return *units; \
} \
}
SET_BASE_UNITS_FOR_TYPE(float, {1});
SET_BASE_UNITS_FOR_TYPE(double, {1});
SET_BASE_UNITS_FOR_TYPE(complex64, ({{1, 0}, {0, 1}}));
SET_BASE_UNITS_FOR_TYPE(complex128, ({{1, 0}, {0, 1}}));
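// Illustrative usage sketch (mirroring the Jacobian loops further below):
// iterating over the basis units visits 1 for real types, and 1 then i for
// complex types.
//
//   for (const complex64& unit : BaseUnitsForType<complex64>::values()) {
//     // First iteration: unit == complex64(1, 0); second: complex64(0, 1).
//   }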
// SetJacobian sets the jacobian value at the provided row and column from a
// tensor entry with type T.
// When T is real, this is a simple assignment that casts the entry into the
// jacobian type.
// When T is complex, it assigns the real and imaginary parts to successive
// rows or columns in the matrix, depending on the expand_by_row parameter.
template <typename T, typename JAC_T>
typename std::enable_if<std::is_floating_point<T>::value>::type SetJacobian(
typename TTypes<JAC_T>::Matrix* jacobian, const int row, const int col,
const T& value, const bool expand_by_row) {
(*jacobian)(row, col) = JAC_T{value};
}
template <typename T, typename JAC_T>
typename std::enable_if<is_complex<T>::value>::type SetJacobian(
typename TTypes<JAC_T>::Matrix* jacobian, const int row, const int col,
const T& value, const bool expand_by_row) {
(*jacobian)(row, col) = JAC_T{value.real()};
if (expand_by_row) {
(*jacobian)(row + 1, col) = JAC_T{value.imag()};
} else {
(*jacobian)(row, col + 1) = JAC_T{value.imag()};
}
}
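// Illustrative example with hypothetical values: for value = complex64(3, 4),
// SetJacobian<complex64, float>(&jacobian, r, c, value, true) writes
// jacobian(r, c) = 3 and jacobian(r + 1, c) = 4, i.e. the real part lands in
// the given row and the imaginary part in the next row; with
// expand_by_row == false the two parts land in columns c and c + 1 instead.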
// JacobianStride<T>::value holds the number of Jacobian elements needed to
// represent one element of the given type.
// When T is real the stride is 1, and when T is complex the stride is 2.
template <typename T>
struct JacobianStride {}; // Specializations below
#define SET_JACOBIAN_STRIDE(TYPE, VALUE) \
template <> \
struct JacobianStride<TYPE> { \
static constexpr int value = VALUE; \
}
SET_JACOBIAN_STRIDE(float, 1);
SET_JACOBIAN_STRIDE(double, 1);
SET_JACOBIAN_STRIDE(complex64, 2);
SET_JACOBIAN_STRIDE(complex128, 2);
template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeTheoreticalJacobianTranspose(
const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
@@ -44,9 +112,9 @@ Status ComputeTheoreticalJacobianTranspose(
OutputList dys;
dys.reserve(y_shapes.size());
for (const auto& y_shape : y_shapes) {
// TODO(suharshs): This currently assumes that all x's are the same type.
// TODO(suharshs): This currently assumes that all y's are the same type.
dys.push_back(
ops::Cast(scope, ops::Const(scope, 1.0, y_shape), xs[0].type()));
ops::Cast(scope, ops::Const(scope, 1.0, y_shape), ys[0].type()));
}
OutputList dxs;
TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, ys, xs, dys, &dxs));
@@ -55,7 +123,7 @@ Status ComputeTheoreticalJacobianTranspose(
std::vector<Tensor> dy_datas(y_num);
for (int i = 0; i < y_num; i++) {
dy_datas[i] = Tensor(ys[i].type(), y_shapes[i]);
auto dy_data_flat = dy_datas[i].flat<T>();
auto dy_data_flat = dy_datas[i].flat<Y_T>();
dy_data_flat.setZero();
}
@@ -68,30 +136,41 @@ Status ComputeTheoreticalJacobianTranspose(
feed_list.insert({dys[i], dy_datas[i]});
}
// x_stride and y_stride are used to calculate the correct jacobian row and
// column position for a pair of elements at positions r, c within the x and y
// tensors respectively.
const int x_stride = JacobianStride<X_T>::value;
const int y_stride = JacobianStride<Y_T>::value;
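// Worked placement example (illustrative): when both x and y are complex the
// strides are 2, so the dy basis unit 'unit_dimension' applied at y element c
// fills jacobian column c * 2 + unit_dimension, and each complex dx element r
// expands into jacobian rows r * 2 and r * 2 + 1 inside SetJacobian.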
ClientSession session(scope);
for (int y_idx = 0; y_idx < y_num; y_idx++) {
auto dy_data_flat = dy_datas[y_idx].flat<T>();
auto dy_data_flat = dy_datas[y_idx].flat<Y_T>();
const int64 dy_size = y_shapes[y_idx].num_elements();
// Compute the theoretical Jacobians one row at a time by back propagating
// '1.0' for each element of 'dy', while holding all other elements of 'dy'
// at zero.
// '1.0' (or '1' and 'i' if y is complex) for each element of 'dy', while
// holding all other elements of 'dy' at zero.
for (int c = 0; c < dy_size; ++c) {
dy_data_flat(c) = 1.0;
int unit_dimension = 0;
for (Y_T unit : BaseUnitsForType<Y_T>::values()) {
dy_data_flat(c) = unit;
std::vector<Tensor> dxout;
TF_RETURN_IF_ERROR(session.Run(feed_list, dxs, &dxout));
std::vector<Tensor> dxout;
TF_RETURN_IF_ERROR(session.Run(feed_list, dxs, &dxout));
for (int x_idx = 0; x_idx < x_num; x_idx++) {
const int64 x_size = x_shapes[x_idx].num_elements();
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>();
auto dx_flat = dxout[x_idx].flat<T>();
for (int r = 0; r < x_size; ++r) {
jacobian(r, c) = dx_flat(r);
for (int x_idx = 0; x_idx < x_num; x_idx++) {
const int64 x_size = x_shapes[x_idx].num_elements();
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
auto dx_flat = dxout[x_idx].flat<X_T>();
for (int r = 0; r < x_size; ++r) {
SetJacobian<X_T, JAC_T>(&jacobian, r * x_stride,
c * y_stride + unit_dimension, dx_flat(r),
true /* expand_by_row=true */);
}
}
}
dy_data_flat(c) = 0.0;
dy_data_flat(c) = Y_T{0};
unit_dimension++;
}
}
}
return Status::OK();
@@ -122,104 +201,154 @@ Status EvaluateGraph(ClientSession* session, const OutputList& xs,
return Status::OK();
}
template <typename T>
template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
const T delta,
const JAC_T delta,
std::vector<Tensor>* x_datas,
std::vector<Tensor>* jacobian_ts) {
size_t y_num = y_shapes.size();
size_t x_num = x_shapes.size();
// x_stride and y_stride are used to calculate the correct jacobian row and
// column position for a pair of elements at positions r, c within the x and y
// tensors respectively.
const int x_stride = JacobianStride<X_T>::value;
const int y_stride = JacobianStride<Y_T>::value;
ClientSession session(scope);
for (int x_idx = 0; x_idx < x_num; x_idx++) {
auto x_data_flat = (*x_datas)[x_idx].flat<T>();
auto x_data_flat = (*x_datas)[x_idx].flat<X_T>();
const int64 x_size = x_shapes[x_idx].num_elements();
// Compute the numeric Jacobian one column at a time by perturbing each
// element of 'x_data' (positively and negatively) by 'delta', and
// updating the jacobian with the centered difference.
// updating the jacobian with the centered difference. When x_data is
// complex-valued, we perturb its real and imaginary parts separately.
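// As a concrete instance (illustrative): for a real scalar entry x and step
// delta, the quotient formed below approximates
//   dy/dx ~= (y(x + delta) - y(x - delta)) / (2 * delta),
// a centered difference with O(delta^2) truncation error; for complex x the
// same quotient is evaluated once for the perturbation delta * 1 and once
// for delta * i.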
for (int r = 0; r < x_size; ++r) {
// Store current value of 'x' at 'r'.
T v = x_data_flat(r);
// Evaluate at positive delta.
x_data_flat(r) = v + delta;
std::vector<Tensor> y_pos;
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos));
// Evaluate at negative delta.
x_data_flat(r) = v - delta;
std::vector<Tensor> y_neg;
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg));
int unit_dimension = 0;
for (X_T unit : BaseUnitsForType<X_T>::values()) {
X_T x_delta = unit * X_T{delta};
// Store current value of 'x' at 'r'.
X_T v = x_data_flat(r);
// Evaluate at positive delta.
x_data_flat(r) = v + x_delta;
std::vector<Tensor> y_pos;
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos));
// Evaluate at negative delta.
x_data_flat(r) = v - x_delta;
std::vector<Tensor> y_neg;
TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg));
for (int y_idx = 0; y_idx < y_num; y_idx++) {
// Compute element-wise centered difference and store in each Jacobian.
auto y_pos_flat = y_pos[y_idx].flat<T>();
auto y_neg_flat = y_neg[y_idx].flat<T>();
const int64 y_size = y_shapes[y_idx].num_elements();
const T scale = 2 * delta;
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>();
for (int c = 0; c < y_size; ++c) {
jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale;
for (int y_idx = 0; y_idx < y_num; y_idx++) {
// Compute element-wise centered difference and store in each
// Jacobian.
auto y_pos_flat = y_pos[y_idx].flat<Y_T>();
auto y_neg_flat = y_neg[y_idx].flat<Y_T>();
const int64 y_size = y_shapes[y_idx].num_elements();
const Y_T scale = Y_T{2 * delta};
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
for (int c = 0; c < y_size; ++c) {
SetJacobian<Y_T, JAC_T>(&jacobian, r * x_stride + unit_dimension,
c * y_stride,
(y_pos_flat(c) - y_neg_flat(c)) / scale,
false /* expand_by_row=false */);
}
}
// Restore pre-perturbation value.
x_data_flat(r) = v;
unit_dimension++;
}
// Restore pre-perturbation value.
x_data_flat(r) = v;
}
}
return Status::OK();
}
template <typename T>
// The Jacobian is always a real-valued matrix.
// Given y = f(x) for tensors y and x, it contains the derivatives dy_i/dx_j for
// every pair y_i in y and x_j in x. Note that the Jacobian is defined directly
// over the elements of tensors y and x, and doesn't depend on their shapes.
//
// If x = (x_1, x_2, ..., x_m) and y = (y_1, y_2, .., y_n) the matrix evaluated
// is actually the Jacobian transpose, defined as this mxn matrix:
// dy_1/dx_1 dy_2/dx_1 ... dy_n/dx_1
// dy_1/dx_2 dy_2/dx_2 ... dy_n/dx_2
// .
// .
// .
// dy_1/dx_m dy_2/dx_m ... dy_n/dx_m
//
// If x or y is complex, each complex entry is "expanded" into a real and
// imaginary entry, and the Jacobian is organized as above on the expanded list.
// e.g.
// [y1, y2] = Square([x1, x2]) where x and y are complex.
// Writing
// x = [x1_real, x1_imag, x2_real, x2_imag]
// y = [y1_real, y1_imag, y2_real, y2_imag]
// the Jacobian transpose is the 4x4 matrix:
// dy1_real/dx1_real dy1_imag/dx1_real dy2_real/dx1_real dy2_imag/dx1_real
// dy1_real/dx1_imag dy1_imag/dx1_imag dy2_real/dx1_imag dy2_imag/dx1_imag
// dy1_real/dx2_real dy1_imag/dx2_real dy2_real/dx2_real dy2_imag/dx2_real
// dy1_real/dx2_imag dy1_imag/dx2_imag dy2_real/dx2_imag dy2_imag/dx2_imag
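// As a worked single-element check (illustrative): for one complex entry
// x = a + b*i and y = Square(x) = (a^2 - b^2) + 2*a*b*i, the 2x2 Jacobian
// transpose laid out as above is
//   dy_real/dx_real  dy_imag/dx_real        2a    2b
//   dy_real/dx_imag  dy_imag/dx_imag   =   -2b    2a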
template <typename X_T, typename Y_T, typename JAC_T>
void InitJacobians(const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const std::vector<TensorShape>& y_shapes,
std::vector<Tensor>* jacobians) {
size_t y_num = y_shapes.size();
size_t x_num = x_shapes.size();
const size_t y_num = y_shapes.size();
const size_t x_num = x_shapes.size();
const DataType jacobian_type = DataTypeToEnum<JAC_T>::v();
jacobians->resize(y_num * x_num);
for (int x_idx = 0; x_idx < x_num; x_idx++) {
const int64 x_size = x_shapes[x_idx].num_elements();
// The number of rows is the number of elements in the x tensor multiplied
// by the number of Jacobian entries needed to represent each x type.
const int64 x_size =
x_shapes[x_idx].num_elements() * JacobianStride<X_T>::value;
for (int y_idx = 0; y_idx < y_num; y_idx++) {
const int64 y_size = y_shapes[y_idx].num_elements();
Tensor jacobian_t(xs[x_idx].type(), {x_size, y_size});
auto jacobian_t_flat = jacobian_t.flat<T>();
// The number of columns is the number of elements in the y tensor
// multiplied by the number of Jacobian entries needed to represent each
// y type.
const int64 y_size =
y_shapes[y_idx].num_elements() * JacobianStride<Y_T>::value;
Tensor jacobian_t(jacobian_type, {x_size, y_size});
auto jacobian_t_flat = jacobian_t.flat<JAC_T>();
jacobian_t_flat.setZero();
(*jacobians)[x_idx * y_num + y_idx] = std::move(jacobian_t);
}
}
}
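// Size illustration (hypothetical shapes): for a complex64 x with 6 elements
// and a float y with 4 elements, the strides are 2 and 1 respectively, so the
// Jacobian tensor allocated above is a 12 x 4 matrix of float zeros.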
template <typename T>
template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
std::vector<Tensor>* x_datas,
T* max_error) {
JAC_T* max_error) {
// Initialize theoretical Jacobians to zeros.
std::vector<Tensor> jacobian_ts;
InitJacobians<T>(xs, x_shapes, y_shapes, &jacobian_ts);
InitJacobians<X_T, Y_T, JAC_T>(xs, x_shapes, y_shapes, &jacobian_ts);
// Compute theoretical Jacobian.
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts));
TF_RETURN_IF_ERROR((ComputeTheoreticalJacobianTranspose<X_T, Y_T, JAC_T>(
scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts)));
// Initialize numeric Jacobian to zeros.
std::vector<Tensor> jacobian_ns;
InitJacobians<T>(xs, x_shapes, y_shapes, &jacobian_ns);
InitJacobians<X_T, Y_T, JAC_T>(xs, x_shapes, y_shapes, &jacobian_ns);
// Compute numeric Jacobian.
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
scope, xs, x_shapes, ys, y_shapes, 1e-3, x_datas, &jacobian_ns));
TF_RETURN_IF_ERROR((ComputeNumericJacobianTranspose<X_T, Y_T, JAC_T>(
scope, xs, x_shapes, ys, y_shapes, JAC_T{1e-3f}, x_datas, &jacobian_ns)));
for (int i = 0; i < jacobian_ts.size(); i++) {
// Compute the maximum error between theoretical and numeric Jacobians.
*max_error = 0.0;
auto jac_t = jacobian_ts[i].matrix<T>();
auto jac_n = jacobian_ns[i].matrix<T>();
auto jac_t = jacobian_ts[i].matrix<JAC_T>();
auto jac_n = jacobian_ns[i].matrix<JAC_T>();
for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) {
for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) {
*max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
@@ -231,12 +360,12 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
} // namespace
template <typename T>
template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientError(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
T* max_error) {
JAC_T* max_error) {
if (xs.size() != x_shapes.size()) {
return errors::InvalidArgument("xs(size ", xs.size(),
") and x_shapes(size ", x_shapes.size(),
@@ -251,35 +380,39 @@ Status ComputeGradientError(const Scope& scope, const OutputList& xs,
std::vector<Tensor> x_datas(x_shapes.size());
for (int i = 0; i < x_shapes.size(); i++) {
x_datas[i] = Tensor(xs[i].type(), x_shapes[i]);
auto x_data_flat = x_datas[i].flat<T>();
auto x_data_flat = x_datas[i].flat<X_T>();
x_data_flat.setRandom();
}
// Compute gradient error.
return ComputeGradientErrorInternal(scope, xs, x_shapes, ys, y_shapes,
&x_datas, max_error);
return ComputeGradientErrorInternal<X_T, Y_T, JAC_T>(
scope, xs, x_shapes, ys, y_shapes, &x_datas, max_error);
}
template <typename T>
template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientError(const Scope& scope, const Output& x,
const Tensor& x_init_value, const Output& y,
const TensorShape& y_shape, T* max_error) {
const TensorShape& y_shape, JAC_T* max_error) {
// Initialize 'x_data' from 'x_init_value'.
std::vector<Tensor> x_datas(1, Tensor(x_init_value));
// Compute gradient error.
return ComputeGradientErrorInternal(scope, {x}, {x_datas[0].shape()}, {y},
{y_shape}, &x_datas, max_error);
return ComputeGradientErrorInternal<X_T, Y_T, JAC_T>(
scope, {x}, {x_datas[0].shape()}, {y}, {y_shape}, &x_datas, max_error);
}
#define INSTANTIATE_GRAD_ERR_TYPE(T) \
template Status ComputeGradientError<T>( \
#define INSTANTIATE_GRAD_ERR_TYPE(X_T, Y_T, JAC_T) \
template Status ComputeGradientError<X_T, Y_T, JAC_T>( \
const Scope& scope, const OutputList& xs, \
const std::vector<TensorShape>& x_shapes, const OutputList& ys, \
const std::vector<TensorShape>& y_shapes, T* max_error); \
template Status ComputeGradientError<T>( \
const std::vector<TensorShape>& y_shapes, JAC_T* max_error); \
template Status ComputeGradientError<X_T, Y_T, JAC_T>( \
const Scope& scope, const Output& x, const Tensor& x_init_value, \
const Output& y, const TensorShape& y_shape, T* max_error);
const Output& y, const TensorShape& y_shape, JAC_T* max_error);
INSTANTIATE_GRAD_ERR_TYPE(float);
INSTANTIATE_GRAD_ERR_TYPE(double);
INSTANTIATE_GRAD_ERR_TYPE(float, float, float);
INSTANTIATE_GRAD_ERR_TYPE(double, double, double);
INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float);
INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float);
INSTANTIATE_GRAD_ERR_TYPE(complex64, complex64, float);
INSTANTIATE_GRAD_ERR_TYPE(complex128, complex128, double);
} // namespace tensorflow


@@ -24,19 +24,39 @@ namespace tensorflow {
/// Returns in 'max_error' the maximum element-wise error for dy/dx between the
/// computed and numeric Jacobian matrices where 'xs' and 'ys' are tensors.
/// X_T and Y_T are the C++ types for the x and y tensors, and JAC_T is a
/// real-valued type to store the Jacobian derivatives dy/dx.
/// This function adds operations to the graph associated with 'scope'.
template <typename T>
///
/// Examples:
/// if y = Square(x), where x (and so y) are DT_FLOAT,
/// <X_T, Y_T, JAC_T> should be <float, float, float>
///
/// if y = Square(x), where x (and so y) are DT_DOUBLE,
/// <X_T, Y_T, JAC_T> should be <double, double, double>
///
/// if y = Square(x), where x (and so y) are DT_COMPLEX64,
/// <X_T, Y_T, JAC_T> should be <complex64, complex64, float>
/// Note that JAC_T is always real-valued, and should have enough precision
/// to hold the partial derivatives dy/dx.
///
/// if y = ComplexAbs(x) where x is DT_COMPLEX64 (so y is DT_FLOAT)
/// <X_T, Y_T, JAC_T> should be <complex64, float, float>
///
/// if y = Complex(x, x) where x is DT_FLOAT (so y is DT_COMPLEX64)
/// <X_T, Y_T, JAC_T> should be <float, complex64, float>
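///
/// A minimal call sketch for the complex64 case, mirroring the unit tests
/// (the shape and the use of TF_ASSERT_OK inside a test are illustrative):
///
///   Scope scope = Scope::NewRootScope();
///   TensorShape shape({2, 4, 3});
///   auto x = ops::Placeholder(scope, DT_COMPLEX64,
///                             ops::Placeholder::Shape(shape));
///   auto y = ops::Square(scope, x);
///   float max_error;
///   TF_ASSERT_OK((ComputeGradientError<complex64, complex64, float>(
///       scope, {x}, {shape}, {y}, {shape}, &max_error)));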
template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientError(const Scope& scope, const OutputList& xs,
const std::vector<TensorShape>& x_shapes,
const OutputList& ys,
const std::vector<TensorShape>& y_shapes,
T* max_error);
JAC_T* max_error);
/// Overload of ComputeGradientError which takes an initial value for 'x'.
template <typename T>
template <typename X_T, typename Y_T, typename JAC_T>
Status ComputeGradientError(const Scope& scope, const Output& x,
const Tensor& x_init_value, const Output& y,
const TensorShape& y_shape, T* max_error);
const TensorShape& y_shape, JAC_T* max_error);
} // namespace tensorflow


@@ -34,8 +34,8 @@ TEST(GradientCheckerTest, BasicFloat) {
auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
auto y = Square(scope, x);
float max_error;
TF_ASSERT_OK(ComputeGradientError<float>(scope, {x}, {shape}, {y}, {shape},
&max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_LT(max_error, 1e-4);
}
@@ -45,11 +45,57 @@ TEST(GradientCheckerTest, BasicDouble) {
auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(shape));
auto y = Square(scope, x);
double max_error;
TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {shape}, {y}, {shape},
&max_error));
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
TEST(GradientCheckerTest, BasicComplex64) {
Scope scope = Scope::NewRootScope();
TensorShape shape({2, 4, 3});
auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape));
auto y = Square(scope, x);
float max_error;
TF_ASSERT_OK((ComputeGradientError<complex64, complex64, float>(
scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_LT(max_error, 1e-4);
}
TEST(GradientCheckerTest, BasicComplex128) {
Scope scope = Scope::NewRootScope();
TensorShape shape({2, 4, 3});
auto x = Placeholder(scope, DT_COMPLEX128, Placeholder::Shape(shape));
auto y = Square(scope, x);
double max_error;
TF_ASSERT_OK((ComputeGradientError<complex128, complex128, double>(
scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
TEST(GradientCheckerTest, FloatToComplex64) {
// Test an op whose inputs are real and outputs are complex
Scope scope = Scope::NewRootScope();
TensorShape shape({2, 4, 3});
auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
auto y = Complex(scope, x, x);
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, complex64, float>(
scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_LT(max_error, 1e-4);
}
TEST(GradientCheckerTest, Complex64ToFloat) {
// Test an op whose inputs are complex and outputs are real
Scope scope = Scope::NewRootScope();
TensorShape shape({2, 4, 3});
auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape));
auto y = Real(scope, x);
float max_error;
TF_ASSERT_OK((ComputeGradientError<complex64, float, float>(
scope, {x}, {shape}, {y}, {shape}, &max_error)));
EXPECT_LT(max_error, 1e-4);
}
TEST(GradientCheckerTest, MatMulGrad) {
Scope scope = Scope::NewRootScope();
@@ -61,8 +107,8 @@ TEST(GradientCheckerTest, MatMulGrad) {
auto y = Const(scope, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, y_shape);
auto z = MatMul(scope, x, y);
double max_error;
TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {x_shape}, {z},
{z_shape}, &max_error));
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
scope, {x}, {x_shape}, {z}, {z_shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
@@ -76,8 +122,8 @@ TEST(GradientCheckerTest, SplitGrad) {
auto y = Split(scope, split_dim, x, /* num_split */ 2);
TensorShape y_shape = TensorShape({5, 1});
double max_error;
TF_ASSERT_OK(ComputeGradientError<double>(scope, {x}, {x_shape}, y.output,
{y_shape, y_shape}, &max_error));
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
scope, {x}, {x_shape}, y.output, {y_shape, y_shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
@@ -91,8 +137,8 @@ TEST(GradientCheckerTest, StackGrad) {
auto y = Stack(scope, xs, Stack::Axis(0));
TensorShape y_shape({2, 1, 2, 3});
double max_error;
TF_ASSERT_OK(ComputeGradientError<double>(scope, xs, {x_shape, x_shape}, {y},
{y_shape}, &max_error));
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
scope, xs, {x_shape, x_shape}, {y}, {y_shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}
@@ -107,8 +153,8 @@ TEST(GradientCheckerTest, StackUnstackGrad) {
auto tmp = Stack(scope, xs, Stack::Axis(0));
auto y = Unstack(scope, tmp, 2, Unstack::Axis(0));
double max_error;
TF_ASSERT_OK(ComputeGradientError<double>(scope, xs, {shape, shape}, y.output,
{shape, shape}, &max_error));
TF_ASSERT_OK((ComputeGradientError<double, double, double>(
scope, xs, {shape, shape}, y.output, {shape, shape}, &max_error)));
EXPECT_LT(max_error, 1e-10);
}


@@ -36,8 +36,8 @@ class ArrayGradTest : public ::testing::Test {
const TensorShape& y_shape) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK(ComputeGradientError(scope_, {x}, {x_shape}, {y}, {y_shape},
&max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
EXPECT_LT(max_error, 1e-3);
}
@@ -45,8 +45,8 @@ class ArrayGradTest : public ::testing::Test {
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
EXPECT_LT(max_error, 1e-3);
}


@@ -35,8 +35,8 @@ class DataFlowGradTest : public ::testing::Test {
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
EXPECT_LT(max_error, 1e-4);
}


@@ -526,6 +526,15 @@ Status ImagGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);
Status ComplexGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
auto gx_1 = Real(scope, grad_inputs[0]);
auto gx_2 = Imag(scope, grad_inputs[0]);
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);
Status AngleGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {


@@ -737,10 +737,14 @@ TEST_F(CWiseUnaryComplexGradTest, Angle) {
Tensor x = test::AsTensor<complex64>(
{{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
Tensor dx_expected = test::AsTensor<complex64>(
{{5.5, 5.5}, {3, 3},
{2.1666666666666665, 2.1666666666666665}, {1.75, 1.75},
{0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3});
Tensor dx_expected =
test::AsTensor<complex64>({{5.5, 5.5},
{3, 3},
{2.1666666666666665, 2.1666666666666665},
{1.75, 1.75},
{0.9375, 0.9375},
{0.8888888888888888, 0.8888888888888888}},
{2, 3});
TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
}
@@ -920,8 +924,8 @@ class NaryGradTest : public ::testing::Test {
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
EXPECT_LT(max_error, 1e-3);
}
@@ -929,8 +933,8 @@ class NaryGradTest : public ::testing::Test {
const TensorShape& y_shape) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, x, x_init_value, y, y_shape, &max_error)));
EXPECT_LT(max_error, 1e-3);
}


@@ -34,16 +34,16 @@ class NNGradTest : public ::testing::Test {
void RunTest(const Output& x, const TensorShape& x_shape, const Output& y,
const TensorShape& y_shape) {
float max_error;
TF_ASSERT_OK(ComputeGradientError(scope_, {x}, {x_shape}, {y}, {y_shape},
&max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
EXPECT_LT(max_error, 2e-4);
}
void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
const TensorShape& y_shape) {
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, x, x_init_value, y, y_shape, &max_error)));
EXPECT_LT(max_error, 2e-4);
}
@@ -51,8 +51,8 @@ class NNGradTest : public ::testing::Test {
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
TF_ASSERT_OK(scope_.status());
float max_error;
TF_ASSERT_OK(
ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
EXPECT_LT(max_error, 2e-4);
}
@@ -71,11 +71,10 @@ TEST_F(NNGradTest, LogSoftmaxGrad) {
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto y = LogSoftmax(scope_, x);
// Avoid numerical instability when computing finite differences.
Tensor x_init_value = test::AsTensor<float>(
{-0.9f, -0.7f, -0.5f, -0.3f, -0.1f,
0.1f, 0.3f, 0.5f, 0.7f, 0.8f,
-0.1f, 0.1f, 0.1f, 0.1f, 1.2f},
{5, 3});
Tensor x_init_value =
test::AsTensor<float>({-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f,
0.5f, 0.7f, 0.8f, -0.1f, 0.1f, 0.1f, 0.1f, 1.2f},
{5, 3});
RunTest(x, x_init_value, y, shape);
}
@@ -136,7 +135,7 @@ TEST_F(NNGradTest, BiasAddGradHelper) {
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto bias = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(bias_shape));
auto y = BiasAdd(scope_, x, bias);
RunTest({x,bias}, {shape, bias_shape}, {y}, {shape});
RunTest({x, bias}, {shape, bias_shape}, {y}, {shape});
}
} // namespace