From e6b011763a60d239972c8c6c0f36536ab6f885a3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 Sep 2017 10:16:26 -0700 Subject: [PATCH] Extend c++ gradient_checker to complex types. PiperOrigin-RevId: 168392949 --- tensorflow/cc/framework/gradient_checker.cc | 283 +++++++++++++----- tensorflow/cc/framework/gradient_checker.h | 28 +- .../cc/framework/gradient_checker_test.cc | 70 ++++- tensorflow/cc/gradients/array_grad_test.cc | 8 +- .../cc/gradients/data_flow_grad_test.cc | 4 +- tensorflow/cc/gradients/math_grad.cc | 9 + tensorflow/cc/gradients/math_grad_test.cc | 20 +- tensorflow/cc/gradients/nn_grad_test.cc | 23 +- 8 files changed, 328 insertions(+), 117 deletions(-) diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index f3a7c138c4e..de2645cb440 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -31,7 +32,74 @@ namespace { // TODO(andydavis) Vectorize and/or multi-thread Jacobian computations if // performance becomes an issue. +// BaseUnitsForType provides a list of typed unit values for each basis in the +// requested type. +// When T is real, +// BaseUnitsForType::values() is just a single-entry vector [1] +// When T is complex, +// BaseUnitsForType::values() is a two-entry vector [1, i] - the unit +// values in each of its two bases. template +struct BaseUnitsForType {}; // Specializations below + +// Template specialization for BaseUnitsForType +#define SET_BASE_UNITS_FOR_TYPE(TYPE, INIT) \ + template <> \ + struct BaseUnitsForType { \ + static const std::vector& values() { \ + static std::vector* units = new std::vector INIT; \ + return *units; \ + } \ + } + +SET_BASE_UNITS_FOR_TYPE(float, {1}); +SET_BASE_UNITS_FOR_TYPE(double, {1}); +SET_BASE_UNITS_FOR_TYPE(complex64, ({{1, 0}, {0, 1}})); +SET_BASE_UNITS_FOR_TYPE(complex128, ({{1, 0}, {0, 1}})); + +// SetJacobian sets the jacobian value at the provided row and column from a +// tensor entry with type T. +// When T is real, this is a simple assignment that casts the entry into the +// jacobian type. +// When T is complex, it assigns the real and complex values to successive rows +// or columns in the matrix depending on the expand_by_row parameter +template +typename std::enable_if::value>::type SetJacobian( + typename TTypes::Matrix* jacobian, const int row, const int col, + const T& value, const bool expand_by_row) { + (*jacobian)(row, col) = JAC_T{value}; +} + +template +typename std::enable_if::value>::type SetJacobian( + typename TTypes::Matrix* jacobian, const int row, const int col, + const T& value, const bool expand_by_row) { + (*jacobian)(row, col) = JAC_T{value.real()}; + if (expand_by_row) { + (*jacobian)(row + 1, col) = JAC_T{value.imag()}; + } else { + (*jacobian)(row, col + 1) = JAC_T{value.imag()}; + } +} + +// JacobianStride::value holds the number of Jacobian elements needed to +// represent one element of the given type. +// When T is real the stride is 1, and when T is complex the stride is 2. 
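// For example (a small illustration of the expansion): a single complex64
// derivative value v = complex64(0.5f, -1.5f) written at position (r, c)
// with expand_by_row == true ends up as
//   jacobian(r,     c) == 0.5f    // real part
//   jacobian(r + 1, c) == -1.5f   // imaginary part
// and with expand_by_row == false as
//   jacobian(r, c)     == 0.5f
//   jacobian(r, c + 1) == -1.5f
// which is why the stride below is 2 for complex types and 1 otherwise.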
+template +struct JacobianStride {}; // Specializations below + +#define SET_JACOBIAN_STRIDE(TYPE, VALUE) \ + template <> \ + struct JacobianStride { \ + static constexpr int value = VALUE; \ + } + +SET_JACOBIAN_STRIDE(float, 1); +SET_JACOBIAN_STRIDE(double, 1); +SET_JACOBIAN_STRIDE(complex64, 2); +SET_JACOBIAN_STRIDE(complex128, 2); + +template Status ComputeTheoreticalJacobianTranspose( const Scope& scope, const OutputList& xs, const std::vector& x_shapes, @@ -44,9 +112,9 @@ Status ComputeTheoreticalJacobianTranspose( OutputList dys; dys.reserve(y_shapes.size()); for (const auto& y_shape : y_shapes) { - // TODO(suharshs): This currently assumes that all x's are the same type. + // TODO(suharshs): This currently assumes that all y's are the same type. dys.push_back( - ops::Cast(scope, ops::Const(scope, 1.0, y_shape), xs[0].type())); + ops::Cast(scope, ops::Const(scope, 1.0, y_shape), ys[0].type())); } OutputList dxs; TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, ys, xs, dys, &dxs)); @@ -55,7 +123,7 @@ Status ComputeTheoreticalJacobianTranspose( std::vector dy_datas(y_num); for (int i = 0; i < y_num; i++) { dy_datas[i] = Tensor(ys[i].type(), y_shapes[i]); - auto dy_data_flat = dy_datas[i].flat(); + auto dy_data_flat = dy_datas[i].flat(); dy_data_flat.setZero(); } @@ -68,30 +136,41 @@ Status ComputeTheoreticalJacobianTranspose( feed_list.insert({dys[i], dy_datas[i]}); } + // x_stride and y_stride are used to calculate the correct jacobian row and + // column position for a pair of elements at positions r, c within the x and y + // tensors respectively. + const int x_stride = JacobianStride::value; + const int y_stride = JacobianStride::value; ClientSession session(scope); for (int y_idx = 0; y_idx < y_num; y_idx++) { - auto dy_data_flat = dy_datas[y_idx].flat(); + auto dy_data_flat = dy_datas[y_idx].flat(); const int64 dy_size = y_shapes[y_idx].num_elements(); // Compute the theoretical Jacobians one row at a time by back propagating - // '1.0' for each element of 'dy', while holding all other elements of 'dy' - // at zero. + // '1.0' (or '1' and 'i' if y is complex) for each element of 'dy', while + // holding all other elements of 'dy' at zero. 
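// Worked example (for exposition): with y = Square(x) on a single complex
// scalar x = a + b*i, Re(y) = a^2 - b^2 and Im(y) = 2ab, so the expanded
// Jacobian transpose that this function and the numeric path should agree on
// is the 2x2 block
//   [ dRe(y)/da  dIm(y)/da ]   [  2a  2b ]
//   [ dRe(y)/db  dIm(y)/db ] = [ -2b  2a ]
// whose two columns come from back propagating the units 1 and i
// respectively.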
for (int c = 0; c < dy_size; ++c) { - dy_data_flat(c) = 1.0; + int unit_dimension = 0; + for (Y_T unit : BaseUnitsForType::values()) { + dy_data_flat(c) = unit; - std::vector dxout; - TF_RETURN_IF_ERROR(session.Run(feed_list, dxs, &dxout)); + std::vector dxout; + TF_RETURN_IF_ERROR(session.Run(feed_list, dxs, &dxout)); - for (int x_idx = 0; x_idx < x_num; x_idx++) { - const int64 x_size = x_shapes[x_idx].num_elements(); - auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); - auto dx_flat = dxout[x_idx].flat(); - for (int r = 0; r < x_size; ++r) { - jacobian(r, c) = dx_flat(r); + for (int x_idx = 0; x_idx < x_num; x_idx++) { + const int64 x_size = x_shapes[x_idx].num_elements(); + auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); + auto dx_flat = dxout[x_idx].flat(); + for (int r = 0; r < x_size; ++r) { + SetJacobian(&jacobian, r * x_stride, + c * y_stride + unit_dimension, dx_flat(r), + true /* expand_by_row=true */); + } } - } - dy_data_flat(c) = 0.0; + dy_data_flat(c) = Y_T{0}; + unit_dimension++; + } } } return Status::OK(); @@ -122,104 +201,154 @@ Status EvaluateGraph(ClientSession* session, const OutputList& xs, return Status::OK(); } -template +template Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, const std::vector& x_shapes, const OutputList& ys, const std::vector& y_shapes, - const T delta, + const JAC_T delta, std::vector* x_datas, std::vector* jacobian_ts) { size_t y_num = y_shapes.size(); size_t x_num = x_shapes.size(); + // x_stride and y_stride are used to calculate the correct jacobian row and + // column position for a pair of elements at positions r, c within the x and y + // tensors respectively. + const int x_stride = JacobianStride::value; + const int y_stride = JacobianStride::value; ClientSession session(scope); for (int x_idx = 0; x_idx < x_num; x_idx++) { - auto x_data_flat = (*x_datas)[x_idx].flat(); + auto x_data_flat = (*x_datas)[x_idx].flat(); const int64 x_size = x_shapes[x_idx].num_elements(); // Compute the numeric Jacobian one column at a time by perturbing each // element of 'x_data' (positively and negatively) by 'delta', and - // updating the jacobian with the centered difference. + // updating the jacobian with the centered difference. When x_data is + // complex-valued, we perturb its real and complex parts separately. for (int r = 0; r < x_size; ++r) { - // Store current value of 'x' at 'r'. - T v = x_data_flat(r); - // Evaluate at positive delta. - x_data_flat(r) = v + delta; - std::vector y_pos; - TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos)); - // Evaluate at negative delta. - x_data_flat(r) = v - delta; - std::vector y_neg; - TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg)); + int unit_dimension = 0; + for (X_T unit : BaseUnitsForType::values()) { + X_T x_delta = unit * X_T{delta}; + // Store current value of 'x' at 'r'. + X_T v = x_data_flat(r); + // Evaluate at positive delta. + x_data_flat(r) = v + x_delta; + std::vector y_pos; + TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos)); + // Evaluate at negative delta. + x_data_flat(r) = v - x_delta; + std::vector y_neg; + TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg)); - for (int y_idx = 0; y_idx < y_num; y_idx++) { - // Compute element-wise centered difference and store in each Jacobian. 
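// Concretely, the update performed below can be written as
//   jacobian(r * x_stride + unit_dimension, c * y_stride) =
//       (y_pos(c) - y_neg(c)) / (2 * delta)
// where y_pos and y_neg are the outputs evaluated at x + unit * delta and
// x - unit * delta (unit is 1 when perturbing the real part of x and i when
// perturbing the imaginary part), and a complex-valued difference is split
// into real/imaginary cells by SetJacobian with expand_by_row == false.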
- auto y_pos_flat = y_pos[y_idx].flat(); - auto y_neg_flat = y_neg[y_idx].flat(); - const int64 y_size = y_shapes[y_idx].num_elements(); - const T scale = 2 * delta; - auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); - for (int c = 0; c < y_size; ++c) { - jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale; + for (int y_idx = 0; y_idx < y_num; y_idx++) { + // Compute element-wise centered difference and store in each + // Jacobian. + auto y_pos_flat = y_pos[y_idx].flat(); + auto y_neg_flat = y_neg[y_idx].flat(); + const int64 y_size = y_shapes[y_idx].num_elements(); + const Y_T scale = Y_T{2 * delta}; + auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); + for (int c = 0; c < y_size; ++c) { + SetJacobian(&jacobian, r * x_stride + unit_dimension, + c * y_stride, + (y_pos_flat(c) - y_neg_flat(c)) / scale, + false /* expand_by_row=false */); + } } + // Restore pre-perturbation value. + x_data_flat(r) = v; + unit_dimension++; } - // Restore pre-perturbation value. - x_data_flat(r) = v; } } return Status::OK(); } -template +// The Jacobian is always a real-valued matrix. +// Given y = f(x) for tensors y and x, it contains the derivatives dy_i/dx_j for +// every pair y_i in y and x_j in x. Note that the Jacobian is defined directly +// over the elements of tensors y and x, and doesn't depend on their shapes. +// +// If x = (x_1, x_2, ..., x_m) and y = (y_1, y_2, .., y_n) the matrix evaluated +// is actually the Jacobian transpose, defined as this mxn matrix: +// dy_1/d_x1 dy_2/dx_1 ... dy_n/dx_1 +// dy_1/dx_2 dy_2/dx_2 ... dy_n/dx_2 +// . +// . +// . +// dy_1/dx_m dy_2/dx_m ... dy_n/dx_m +// +// If x or y is complex, each complex entry is "expanded" into a real and +// imaginary entry, and the Jacobian is organized as above on the expanded list. +// e.g. +// [y1, y2] = Square([x1, x2]) where x and y are complex. +// Writing +// x = [x1_real, x1_imag, x2_real, x2_imag] +// y = [y1_real, y1_imag, y2_real, y2_imag] +// the Jacobian transpose is +// the 4x4 matrix: +// dy1_real/dx1_real dy1_imag/dx1_real dy2_real/dx1_real dy2_imag/dx1_real +// dy1_real/dx1_imag dy1_imag/dx1_imag dy2_real/dx1_imag dy2_imag/dx1_imag +// dy1_real/dx2_real dy1_imag/dx2_real dy2_real/dx2_real dy2_imag/dx2_real +// dy1_real/dx2_imag dy1_imag/dx2_imag dy2_real/dx2_imag dy2_imag/dx2_imag +template void InitJacobians(const OutputList& xs, const std::vector& x_shapes, const std::vector& y_shapes, std::vector* jacobians) { - size_t y_num = y_shapes.size(); - size_t x_num = x_shapes.size(); + const size_t y_num = y_shapes.size(); + const size_t x_num = x_shapes.size(); + const DataType jacobian_type = DataTypeToEnum::v(); jacobians->resize(y_num * x_num); for (int x_idx = 0; x_idx < x_num; x_idx++) { - const int64 x_size = x_shapes[x_idx].num_elements(); + // The number of rows is the number of elements in the x tensor multiplied + // by the number of Jacobian entries needed to represent each x type. + const int64 x_size = + x_shapes[x_idx].num_elements() * JacobianStride::value; for (int y_idx = 0; y_idx < y_num; y_idx++) { - const int64 y_size = y_shapes[y_idx].num_elements(); - Tensor jacobian_t(xs[x_idx].type(), {x_size, y_size}); - auto jacobian_t_flat = jacobian_t.flat(); + // The number of columns is the number of elements in the y tensor + // multiplied by the number of Jacobian entries needed to represent each + // y type. 
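// For example (illustrative sizes): an x tensor holding 4 complex64 elements
// paired with a y tensor holding 3 float elements yields an 8 x 3 Jacobian of
// JAC_T values: 4 * 2 expanded rows by 3 * 1 columns.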
+ const int64 y_size = + y_shapes[y_idx].num_elements() * JacobianStride::value; + Tensor jacobian_t(jacobian_type, {x_size, y_size}); + auto jacobian_t_flat = jacobian_t.flat(); jacobian_t_flat.setZero(); (*jacobians)[x_idx * y_num + y_idx] = std::move(jacobian_t); } } } -template +template Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, const std::vector& x_shapes, const OutputList& ys, const std::vector& y_shapes, std::vector* x_datas, - T* max_error) { + JAC_T* max_error) { // Initialize theoretical Jacobians to zeros. std::vector jacobian_ts; - InitJacobians(xs, x_shapes, y_shapes, &jacobian_ts); + InitJacobians(xs, x_shapes, y_shapes, &jacobian_ts); // Compute theoretical Jacobian. - TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose( - scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts)); + TF_RETURN_IF_ERROR((ComputeTheoreticalJacobianTranspose( + scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts))); // Initialize numeric Jacobian to zeros. std::vector jacobian_ns; - InitJacobians(xs, x_shapes, y_shapes, &jacobian_ns); + InitJacobians(xs, x_shapes, y_shapes, &jacobian_ns); // Compute numeric Jacobian. - TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose( - scope, xs, x_shapes, ys, y_shapes, 1e-3, x_datas, &jacobian_ns)); + TF_RETURN_IF_ERROR((ComputeNumericJacobianTranspose( + scope, xs, x_shapes, ys, y_shapes, JAC_T{1e-3f}, x_datas, &jacobian_ns))); for (int i = 0; i < jacobian_ts.size(); i++) { // Compute the maximum error between theoretical and numeric Jacobians. *max_error = 0.0; - auto jac_t = jacobian_ts[i].matrix(); - auto jac_n = jacobian_ns[i].matrix(); + auto jac_t = jacobian_ts[i].matrix(); + auto jac_n = jacobian_ns[i].matrix(); for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) { for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) { *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c))); @@ -231,12 +360,12 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, } // namespace -template +template Status ComputeGradientError(const Scope& scope, const OutputList& xs, const std::vector& x_shapes, const OutputList& ys, const std::vector& y_shapes, - T* max_error) { + JAC_T* max_error) { if (xs.size() != x_shapes.size()) { return errors::InvalidArgument("xs(size ", xs.size(), ") and x_shapes(size ", x_shapes.size(), @@ -251,35 +380,39 @@ Status ComputeGradientError(const Scope& scope, const OutputList& xs, std::vector x_datas(x_shapes.size()); for (int i = 0; i < x_shapes.size(); i++) { x_datas[i] = Tensor(xs[i].type(), x_shapes[i]); - auto x_data_flat = x_datas[i].flat(); + auto x_data_flat = x_datas[i].flat(); x_data_flat.setRandom(); } // Compute gradient error. - return ComputeGradientErrorInternal(scope, xs, x_shapes, ys, y_shapes, - &x_datas, max_error); + return ComputeGradientErrorInternal( + scope, xs, x_shapes, ys, y_shapes, &x_datas, max_error); } -template +template Status ComputeGradientError(const Scope& scope, const Output& x, const Tensor& x_init_value, const Output& y, - const TensorShape& y_shape, T* max_error) { + const TensorShape& y_shape, JAC_T* max_error) { // Initialize 'x_data' from 'x_init_value'. std::vector x_datas(1, Tensor(x_init_value)); // Compute gradient error. 
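// ComputeGradientErrorInternal fills both Jacobians and returns in
// 'max_error' the maximum absolute difference found between matching
// theoretical and numeric entries. Unlike the OutputList overload above,
// 'x' is evaluated at the caller-supplied 'x_init_value' rather than at
// random data, which helps when finite differences are only well behaved
// for particular inputs (see the LogSoftmax test in nn_grad_test.cc).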
- return ComputeGradientErrorInternal(scope, {x}, {x_datas[0].shape()}, {y}, - {y_shape}, &x_datas, max_error); + return ComputeGradientErrorInternal( + scope, {x}, {x_datas[0].shape()}, {y}, {y_shape}, &x_datas, max_error); } -#define INSTANTIATE_GRAD_ERR_TYPE(T) \ - template Status ComputeGradientError( \ +#define INSTANTIATE_GRAD_ERR_TYPE(X_T, Y_T, JAC_T) \ + template Status ComputeGradientError( \ const Scope& scope, const OutputList& xs, \ const std::vector& x_shapes, const OutputList& ys, \ - const std::vector& y_shapes, T* max_error); \ - template Status ComputeGradientError( \ + const std::vector& y_shapes, JAC_T* max_error); \ + template Status ComputeGradientError( \ const Scope& scope, const Output& x, const Tensor& x_init_value, \ - const Output& y, const TensorShape& y_shape, T* max_error); + const Output& y, const TensorShape& y_shape, JAC_T* max_error); -INSTANTIATE_GRAD_ERR_TYPE(float); -INSTANTIATE_GRAD_ERR_TYPE(double); +INSTANTIATE_GRAD_ERR_TYPE(float, float, float); +INSTANTIATE_GRAD_ERR_TYPE(double, double, double); +INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float); +INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float); +INSTANTIATE_GRAD_ERR_TYPE(complex64, complex64, float); +INSTANTIATE_GRAD_ERR_TYPE(complex128, complex128, double); } // namespace tensorflow diff --git a/tensorflow/cc/framework/gradient_checker.h b/tensorflow/cc/framework/gradient_checker.h index 2e61213615a..d055c60d09c 100644 --- a/tensorflow/cc/framework/gradient_checker.h +++ b/tensorflow/cc/framework/gradient_checker.h @@ -24,19 +24,39 @@ namespace tensorflow { /// Returns in 'max_error' the maximum element-wise error for dy/dx between the /// computed and numeric Jacobian matrices where 'xs' and 'ys' are tensors. +/// X_T and Y_T are the c++ types for the x and y tensors, and JAC_T is a +/// real-valued type to store the Jacobian derivatives dy/dx. /// This function adds operations to the graph associated with 'scope'. -template +/// +/// Examples: +/// if y = Square(x), where x (and so y) are DT_FLOAT, +/// should be +/// +/// if y = Square(x), where x (and so y) are DT_DOUBLE, +/// should be +/// +/// if y = Square(x), where x (and so y) are DT_COMPLEX64, +/// should be +/// Note that JAC_T is always real-valued, and should be an appropriate +/// precision to host the partial derivatives for dy/dx +/// +/// if y = ComplexAbs(x) where x is DT_COMPLEX64 (so y is DT_FLOAT) +/// should be +/// +/// if y = Complex(x, x) where x is DT_FLOAT (so y is DT_COMPLEX64) +/// should be +template Status ComputeGradientError(const Scope& scope, const OutputList& xs, const std::vector& x_shapes, const OutputList& ys, const std::vector& y_shapes, - T* max_error); + JAC_T* max_error); /// Overload of ComputeGradientError which takes an initial value for 'x'. 
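/// For example (an illustrative sketch mirroring the tests in
/// gradient_checker_test.cc), checking Square over DT_COMPLEX64 inputs at a
/// caller-chosen starting point could look like:
///
///   TensorShape shape({2, 3});
///   auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape));
///   auto y = Square(scope, x);
///   Tensor x_init_value(DT_COMPLEX64, shape);
///   // ... fill x_init_value ...
///   float max_error;
///   Status s = ComputeGradientError<complex64, complex64, float>(
///       scope, x, x_init_value, y, shape, &max_error);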
-template +template Status ComputeGradientError(const Scope& scope, const Output& x, const Tensor& x_init_value, const Output& y, - const TensorShape& y_shape, T* max_error); + const TensorShape& y_shape, JAC_T* max_error); } // namespace tensorflow diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc index c5bddc50fcc..fdc457f40af 100644 --- a/tensorflow/cc/framework/gradient_checker_test.cc +++ b/tensorflow/cc/framework/gradient_checker_test.cc @@ -34,8 +34,8 @@ TEST(GradientCheckerTest, BasicFloat) { auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape)); auto y = Square(scope, x); float max_error; - TF_ASSERT_OK(ComputeGradientError(scope, {x}, {shape}, {y}, {shape}, - &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); EXPECT_LT(max_error, 1e-4); } @@ -45,11 +45,57 @@ TEST(GradientCheckerTest, BasicDouble) { auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(shape)); auto y = Square(scope, x); double max_error; - TF_ASSERT_OK(ComputeGradientError(scope, {x}, {shape}, {y}, {shape}, - &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); EXPECT_LT(max_error, 1e-10); } +TEST(GradientCheckerTest, BasicComplex64) { + Scope scope = Scope::NewRootScope(); + TensorShape shape({2, 4, 3}); + auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape)); + auto y = Square(scope, x); + float max_error; + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); + EXPECT_LT(max_error, 1e-4); +} + +TEST(GradientCheckerTest, BasicComplex128) { + Scope scope = Scope::NewRootScope(); + TensorShape shape({2, 4, 3}); + auto x = Placeholder(scope, DT_COMPLEX128, Placeholder::Shape(shape)); + auto y = Square(scope, x); + double max_error; + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); + EXPECT_LT(max_error, 1e-10); +} + +TEST(GradientCheckerTest, FloatToComplex64) { + // Test an op whose inputs are real and outputs are complex + Scope scope = Scope::NewRootScope(); + TensorShape shape({2, 4, 3}); + auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Complex(scope, x, x); + float max_error; + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); + EXPECT_LT(max_error, 1e-4); +} + +TEST(GradientCheckerTest, Complex64ToFloat) { + // Test an op whose inputs are complex and outputs are real + Scope scope = Scope::NewRootScope(); + TensorShape shape({2, 4, 3}); + auto x = Placeholder(scope, DT_COMPLEX64, Placeholder::Shape(shape)); + auto y = Real(scope, x); + float max_error; + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); + EXPECT_LT(max_error, 1e-4); +} + TEST(GradientCheckerTest, MatMulGrad) { Scope scope = Scope::NewRootScope(); @@ -61,8 +107,8 @@ TEST(GradientCheckerTest, MatMulGrad) { auto y = Const(scope, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, y_shape); auto z = MatMul(scope, x, y); double max_error; - TF_ASSERT_OK(ComputeGradientError(scope, {x}, {x_shape}, {z}, - {z_shape}, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {x_shape}, {z}, {z_shape}, &max_error))); EXPECT_LT(max_error, 1e-10); } @@ -76,8 +122,8 @@ TEST(GradientCheckerTest, SplitGrad) { auto y = Split(scope, split_dim, x, /* num_split */ 2); TensorShape y_shape = TensorShape({5, 1}); double max_error; - TF_ASSERT_OK(ComputeGradientError(scope, {x}, {x_shape}, y.output, - 
{y_shape, y_shape}, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {x_shape}, y.output, {y_shape, y_shape}, &max_error))); EXPECT_LT(max_error, 1e-10); } @@ -91,8 +137,8 @@ TEST(GradientCheckerTest, StackGrad) { auto y = Stack(scope, xs, Stack::Axis(0)); TensorShape y_shape({2, 1, 2, 3}); double max_error; - TF_ASSERT_OK(ComputeGradientError(scope, xs, {x_shape, x_shape}, {y}, - {y_shape}, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope, xs, {x_shape, x_shape}, {y}, {y_shape}, &max_error))); EXPECT_LT(max_error, 1e-10); } @@ -107,8 +153,8 @@ TEST(GradientCheckerTest, StackUnstackGrad) { auto tmp = Stack(scope, xs, Stack::Axis(0)); auto y = Unstack(scope, tmp, 2, Unstack::Axis(0)); double max_error; - TF_ASSERT_OK(ComputeGradientError(scope, xs, {shape, shape}, y.output, - {shape, shape}, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope, xs, {shape, shape}, y.output, {shape, shape}, &max_error))); EXPECT_LT(max_error, 1e-10); } diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 1777e181451..455d7330c10 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -36,8 +36,8 @@ class ArrayGradTest : public ::testing::Test { const TensorShape& y_shape) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK(ComputeGradientError(scope_, {x}, {x_shape}, {y}, {y_shape}, - &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error))); EXPECT_LT(max_error, 1e-3); } @@ -45,8 +45,8 @@ class ArrayGradTest : public ::testing::Test { const OutputList& ys, const std::vector& y_shapes) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope_, xs, x_shapes, ys, y_shapes, &max_error))); EXPECT_LT(max_error, 1e-3); } diff --git a/tensorflow/cc/gradients/data_flow_grad_test.cc b/tensorflow/cc/gradients/data_flow_grad_test.cc index 3d027909f08..734dfd3af97 100644 --- a/tensorflow/cc/gradients/data_flow_grad_test.cc +++ b/tensorflow/cc/gradients/data_flow_grad_test.cc @@ -35,8 +35,8 @@ class DataFlowGradTest : public ::testing::Test { const OutputList& ys, const std::vector& y_shapes) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope_, xs, x_shapes, ys, y_shapes, &max_error))); EXPECT_LT(max_error, 1e-4); } diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index d90654f2e9a..b88332ebc74 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -526,6 +526,15 @@ Status ImagGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Imag", ImagGrad); +Status ComplexGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto gx_1 = Real(scope, grad_inputs[0]); + auto gx_2 = Imag(scope, grad_inputs[0]); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); +} +REGISTER_GRADIENT_OP("Complex", ComplexGrad); + Status AngleGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 5b1558dd820..97cd86eacba 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc 
+++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -737,10 +737,14 @@ TEST_F(CWiseUnaryComplexGradTest, Angle) { Tensor x = test::AsTensor( {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); Tensor dy = test::AsTensor({11, -12, 13, -14, 15, -16}, {2, 3}); - Tensor dx_expected = test::AsTensor( - {{5.5, 5.5}, {3, 3}, - {2.1666666666666665, 2.1666666666666665}, {1.75, 1.75}, - {0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3}); + Tensor dx_expected = + test::AsTensor({{5.5, 5.5}, + {3, 3}, + {2.1666666666666665, 2.1666666666666665}, + {1.75, 1.75}, + {0.9375, 0.9375}, + {0.8888888888888888, 0.8888888888888888}}, + {2, 3}); TestCWiseGradComplex(ANGLE, x, dy, dx_expected); } @@ -920,8 +924,8 @@ class NaryGradTest : public ::testing::Test { const OutputList& ys, const std::vector& y_shapes) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope_, xs, x_shapes, ys, y_shapes, &max_error))); EXPECT_LT(max_error, 1e-3); } @@ -929,8 +933,8 @@ class NaryGradTest : public ::testing::Test { const TensorShape& y_shape) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_init_value, y, y_shape, &max_error))); EXPECT_LT(max_error, 1e-3); } diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index affc1e1dbe6..64f1f760662 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -34,16 +34,16 @@ class NNGradTest : public ::testing::Test { void RunTest(const Output& x, const TensorShape& x_shape, const Output& y, const TensorShape& y_shape) { float max_error; - TF_ASSERT_OK(ComputeGradientError(scope_, {x}, {x_shape}, {y}, {y_shape}, - &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error))); EXPECT_LT(max_error, 2e-4); } void RunTest(const Output& x, const Tensor& x_init_value, const Output& y, const TensorShape& y_shape) { float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_init_value, y, y_shape, &max_error))); EXPECT_LT(max_error, 2e-4); } @@ -51,8 +51,8 @@ class NNGradTest : public ::testing::Test { const OutputList& ys, const std::vector& y_shapes) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error)); + TF_ASSERT_OK((ComputeGradientError( + scope_, xs, x_shapes, ys, y_shapes, &max_error))); EXPECT_LT(max_error, 2e-4); } @@ -71,11 +71,10 @@ TEST_F(NNGradTest, LogSoftmaxGrad) { auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); auto y = LogSoftmax(scope_, x); // Avoid numerical instability when computing finite differences. 
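  // The numeric pass perturbs each element by +/- 1e-3 (the delta used in
  // ComputeNumericJacobianTranspose), so a fixed set of values is used here
  // instead of random initialization.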
- Tensor x_init_value = test::AsTensor( - {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, - 0.1f, 0.3f, 0.5f, 0.7f, 0.8f, - -0.1f, 0.1f, 0.1f, 0.1f, 1.2f}, - {5, 3}); + Tensor x_init_value = + test::AsTensor({-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, + 0.5f, 0.7f, 0.8f, -0.1f, 0.1f, 0.1f, 0.1f, 1.2f}, + {5, 3}); RunTest(x, x_init_value, y, shape); } @@ -136,7 +135,7 @@ TEST_F(NNGradTest, BiasAddGradHelper) { auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); auto bias = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(bias_shape)); auto y = BiasAdd(scope_, x, bias); - RunTest({x,bias}, {shape, bias_shape}, {y}, {shape}); + RunTest({x, bias}, {shape, bias_shape}, {y}, {shape}); } } // namespace
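For reference, a minimal end-to-end use of the extended checker on an op with
complex inputs and real outputs (a sketch along the lines of the
Complex64ToFloat test added above; scope setup and shape are illustrative):

  Scope scope = Scope::NewRootScope();
  TensorShape shape({2, 4, 3});
  auto x = ops::Placeholder(scope, DT_COMPLEX64, ops::Placeholder::Shape(shape));
  auto y = ops::Real(scope, x);  // y is DT_FLOAT
  float max_error;
  TF_CHECK_OK((ComputeGradientError<complex64, float, float>(
      scope, {x}, {shape}, {y}, {shape}, &max_error)));
  // max_error now holds the largest |theoretical - numeric| Jacobian entry.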