Fix bug in complex einsum grad.

Discovered another bug, triggered by empty shapes with complex inputs. Along the way, removed the spurious conjugate operations inside the einsum function.

Fixes #37307.

PiperOrigin-RevId: 310069589
Change-Id: I665e4b077c13eddfbad7c5ffbb9f0addd7207cb4
Author: Anudhyan Boral, 2020-05-05 19:43:01 -07:00 (committed by TensorFlower Gardener)
Parent: e58a4820ed
Commit: eaaf5b6b43
4 changed files with 116 additions and 100 deletions
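A hypothetical repro sketch for the failure described in the commit message (not part of this change; the shapes and loss are illustrative): take the gradient of a complex einsum whose contracted dimension is empty.

# Hypothetical repro (not from this commit): complex einsum gradient with an
# empty contracted dimension, i.e. the empty-shape, complex-input case this
# commit addresses (#37307).
import tensorflow as tf

x = tf.zeros([2, 0], dtype=tf.complex64)
y = tf.zeros([0, 3], dtype=tf.complex64)

with tf.GradientTape() as tape:
  tape.watch([x, y])
  out = tf.einsum('ij,jk->ik', x, y)  # shape (2, 3), all zeros
  loss = tf.reduce_sum(tf.math.real(out) + tf.math.imag(out))

grad_x, grad_y = tape.gradient(loss, [x, y])
# With the fix, the gradients are simply zero tensors of shapes (2, 0) and (0, 3).
print(grad_x.shape, grad_y.shape)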


@@ -78,8 +78,9 @@ struct ParallelMatMulKernel {
}
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
const Tensor in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
static_assert(IsComplex, "Complex type expected.");
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
@@ -90,7 +91,7 @@ struct ParallelMatMulKernel {
// to halve the number of cases. The final conjugation of the result is
// done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
contract_pairs[0] = ContractionDims(adj_x, adj_y);
contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
const bool should_bcast = bcast.IsBroadcastingRequired();
@@ -121,13 +122,14 @@ struct ParallelMatMulKernel<Scalar, false> {
static void Conjugate(const OpKernelContext* context, Tensor* out) {}
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
auto Tz = out->tensor<Scalar, 3>();
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
contract_pairs[0] = ContractionDims(adj_x, adj_y);
contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
const bool should_bcast = bcast.IsBroadcastingRequired();
@@ -169,8 +171,8 @@ struct SequentialMatMulKernel {
}
static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
bool adj_y, const MatMulBCast& bcast, Tensor* out, int start,
int limit) {
bool adj_y, bool trans_x, bool trans_y,
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
@@ -180,17 +182,31 @@ struct SequentialMatMulKernel {
auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
auto z = TensorSliceToEigenMatrix(out, i);
if (!adj_x) {
if (!adj_y) {
// Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
// and trans_y.
if (!adj_x && !trans_x) {
if (!adj_y && !trans_y) {
z.noalias() = x * y;
} else {
} else if (adj_y) {
z.noalias() = x * y.adjoint();
} else { // trans_y == true
z.noalias() = x * y.transpose();
}
} else {
if (!adj_y) {
} else if (adj_x) {
if (!adj_y && !trans_y) {
z.noalias() = x.adjoint() * y;
} else {
} else if (adj_y) {
z.noalias() = x.adjoint() * y.adjoint();
} else { // trans_y == true
z.noalias() = x.adjoint() * y.transpose();
}
} else { // trans_x == true
if (!adj_y && !trans_y) {
z.noalias() = x.transpose() * y;
} else if (adj_y) {
z.noalias() = x.transpose() * y.adjoint();
} else { // trans_y == true
z.noalias() = x.transpose() * y.transpose();
}
}
}
@@ -205,8 +221,8 @@ struct LaunchBatchMatMul;
template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out) {
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
ParallelMatMulKernel;
bool conjugate_result = false;
@@ -226,17 +242,19 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
// Parallelize over inner dims.
// For large matrix products it is counter-productive to parallelize
// over the batch dimension.
ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, bcast, out,
0, batch_size);
ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
trans_y, bcast, out, 0, batch_size);
conjugate_result = adj_x;
} else {
// Parallelize over outer dims. For small matrices and large batches, it
// is counter-productive to parallelize the inner matrix multiplies.
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
cost_per_unit,
[&in_x, &in_y, adj_x, adj_y, &bcast, out](int start, int limit) {
[&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
int start, int limit) {
SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
bcast, out, start, limit);
trans_x, trans_y, bcast, out,
start, limit);
});
}
if (conjugate_result) {
@@ -297,19 +315,17 @@ class BlasScratchAllocator : public se::ScratchAllocator {
template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out) {
constexpr se::blas::Transpose kTranspose =
is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
: se::blas::Transpose::kTranspose;
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
kTranspose};
const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
se::blas::Transpose::kTranspose,
se::blas::Transpose::kConjugateTranspose};
const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
const int64 batch_size = bcast.output_batch_size();
auto blas_transpose_a = trans[adj_x];
auto blas_transpose_b = trans[adj_y];
auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
@@ -399,9 +415,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
: se::blas::Transpose::kTranspose;
bool blas_launch_status =
stream
->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
adj_x || trans_x ? k : m,
static_cast<Coefficient>(1.0), *(a_ptrs[0]),
adj_x ? m : k, *(b_ptrs[0]), 1,
adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
static_cast<Coefficient>(0.0), c_ptrs[0], 1)
.ok();
if (!blas_launch_status) {
@@ -415,7 +432,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), *(b_ptrs[0]),
adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
adj_y || trans_y ? k : n, *(a_ptrs[0]),
adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs[0], n)
.ok();
if (!blas_launch_status) {
@@ -430,8 +448,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
stream
->ThenBlasGemmStridedBatched(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
static_cast<Coefficient>(1.0), *b_ptrs[0],
adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
adj_x || trans_x ? m : k, a_stride,
static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
batch_size)
.ok();
@@ -448,9 +467,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
batch_size, &scratch_allocator)
static_cast<Coefficient>(1.0), b_ptrs,
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
&scratch_allocator)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@@ -466,21 +486,18 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out) {
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
typedef Eigen::half Scalar;
constexpr perftools::gputools::blas::Transpose kTranspose =
is_complex<Scalar>::value
? perftools::gputools::blas::Transpose::kConjugateTranspose
: perftools::gputools::blas::Transpose::kTranspose;
perftools::gputools::blas::Transpose trans[] = {
perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
se::blas::Transpose::kTranspose,
se::blas::Transpose::kConjugateTranspose};
const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
const uint64 batch_size = bcast.output_batch_size();
auto blas_transpose_a = trans[adj_x];
auto blas_transpose_b = trans[adj_y];
auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
@@ -563,7 +580,8 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), *(b_ptrs[0]),
adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
adj_y || trans_y ? k : n, *(a_ptrs[0]),
adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs[0], n)
.ok();
if (!blas_launch_status) {
@@ -577,8 +595,9 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
stream
->ThenBlasGemmStridedBatched(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
static_cast<Coefficient>(1.0), *b_ptrs[0],
adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
adj_x || trans_x ? m : k, a_stride,
static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
batch_size)
.ok();
@@ -595,9 +614,10 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
batch_size, &scratch_allocator)
static_cast<Coefficient>(1.0), b_ptrs,
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
&scratch_allocator)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@@ -616,13 +636,14 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out, int start, int limit) {
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
auto Tz = out->tensor<Scalar, 3>();
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
contract_pairs[0] = ContractionDims(adj_x, adj_y);
contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
auto d = context->eigen_sycl_device();
const bool should_bcast = bcast.IsBroadcastingRequired();
@@ -643,12 +664,13 @@ struct LaunchBatchMatMul<SYCLDevice, Scalar> {
template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y,
const MatMulBCast& bcast, Tensor* out) {
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
// Number of matrix multiplies i.e. size of the batch.
const int64 batch_size = bcast.output_batch_size();
ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
bcast, out, 0, batch_size);
trans_x, trans_y, bcast, out, 0,
batch_size);
}
};
#endif // TENSORFLOW_USE_SYCL
@@ -720,7 +742,8 @@ class BaseBatchMatMulOp : public OpKernel {
errors::Internal("Failed to reshape output from ",
out->shape().DebugString()));
LaunchBatchMatMul<Device, Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, bcast, &out_reshaped);
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
/*trans_y=*/false, bcast, &out_reshaped);
}
protected:

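For context on why the kernels above gained separate trans_x/trans_y flags instead of reusing adj_x/adj_y: for complex matrices, a plain transpose and an adjoint (conjugate transpose) are different operations. A small NumPy sketch, purely illustrative and not part of the diff:

# Illustrative only: transpose vs. adjoint for complex data.
import numpy as np

rng = np.random.RandomState(0)
y = rng.randn(2, 3) + 1j * rng.randn(2, 3)

print(np.allclose(y.T, np.conj(y).T))  # False: transpose and adjoint differ
# The Eigen and BLAS branches above therefore pick kTranspose for
# trans_x/trans_y and kConjugateTranspose only for adj_x/adj_y.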

@@ -537,29 +537,13 @@ struct EinsumHelper {
return CopyFrom(input, output_shape, output);
}
// Conjugates the input.
template <typename Device, typename T>
static Status Conjugate(OpKernelContext* ctx, Tensor* input) {
std::vector<int> permutation(input->dims());
std::iota(permutation.begin(), permutation.end(), 0);
Tensor output;
TF_RETURN_IF_ERROR(
ctx->allocate_temp(DataTypeToEnum<T>::value, input->shape(), &output));
const Device& d = ctx->eigen_device<Device>();
TF_RETURN_IF_ERROR(DoConjugateTranspose(d, *input, permutation, &output));
std::swap(*input, output);
return Status::OK();
}
// Contracts the inputs along the last axis. (or the second last if the
// corresponding value of swap_free_and_contract is true). The batch
// dimensions are broadcast to the output shape.
// TODO(anudhyan): Factor this function into a BatchMatMul functor and support
// transpose_x and transpose_y attributes (in addition to adj_x and adj_y).
// Also, the BatchMatMul might devolve into a component-wise multiplication
// when the matrix shape is [1,1]; in this case BatchMatMul functor would be
// very inefficient. The functor should detect if this is the case and perform
// componentwise multiplication functor instead.
// TODO(anudhyan): BatchMatMul might devolve into a component-wise
// multiplication when the matrix shape is [1,1]; in this case BatchMatMul
// functor would be very inefficient. The functor should detect if this is the
// case and perform componentwise multiplication functor instead.
template <typename Device, typename T>
static Status ContractOperands(OpKernelContext* ctx,
absl::Span<const Tensor> inputs,
@@ -584,12 +568,8 @@ struct EinsumHelper {
inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2);
output_shape.AddDim(inputs[i].dim_size(free_axis));
}
bool adj_x = swap_free_and_contract[0];
bool adj_y = !swap_free_and_contract[1];
if (is_complex<T>::value) {
if (adj_x) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &lhs));
if (adj_y) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &rhs));
}
bool trans_x = swap_free_and_contract[0];
bool trans_y = !swap_free_and_contract[1];
TF_RETURN_IF_ERROR(
ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
@@ -600,8 +580,9 @@ struct EinsumHelper {
Tensor output_reshaped;
TF_RETURN_IF_ERROR(
ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, adj_x, adj_y, bcast,
&output_reshaped);
LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false,
/*adj_y=*/false, trans_x, trans_y,
bcast, &output_reshaped);
return Status::OK();
}
};

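The removed Conjugate pass was spurious because einsum never conjugates its operands: contracting over the last (or second-to-last) axis is a plain transpose-style matmul even for complex dtypes, which is exactly what the new trans_x/trans_y flags express. A quick NumPy check of that identity (illustrative, not part of the change):

# Illustrative: an einsum contraction over the last axes equals matmul with a
# plain transpose (no conjugation), even for complex inputs.
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(4, 5) + 1j * rng.randn(4, 5)
y = rng.randn(6, 5) + 1j * rng.randn(6, 5)

lhs = np.einsum('ij,kj->ik', x, y)
print(np.allclose(lhs, x @ y.T))           # True: plain transpose
print(np.allclose(lhs, x @ np.conj(y).T))  # False in general: adjoint is wrong here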

@@ -286,13 +286,21 @@ class EinsumGradTest(test.TestCase):
def _check_gradient(self, s, *input_shapes):
with self.cached_session():
r = np.random.RandomState(0)
inputs = [np.array(r.randn(*shape), np.float64) for shape in input_shapes]
input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
analytical, numerical = gradient_checker_v2.compute_gradient(
lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
self.assertLess(
gradient_checker_v2.max_error(analytical, numerical), 1e-4)
r = np.random.RandomState(seed=0)
for dtype in (np.float32, np.float64, np.complex64, np.complex128):
tol = 10 * np.sqrt(np.finfo(dtype).resolution)
if dtype in (np.complex64, np.complex128):
inputs = [
np.array(r.randn(*shape), dtype) +
1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
]
else:
inputs = [np.array(r.randn(*shape), dtype) for shape in input_shapes]
input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
analytical, numerical = gradient_checker_v2.compute_gradient(
lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
self.assertLess(
gradient_checker_v2.max_error(analytical, numerical), tol)
@test_util.disable_xla('b/131919749')
def testUnary(self):

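For reference, the tolerance 10 * sqrt(finfo(dtype).resolution) used in the updated test works out to roughly 1e-2 for float32/complex64 and 3e-7 for float64/complex128; a quick way to check (illustrative only):

# Approximate values of the per-dtype gradient-check tolerance used above.
import numpy as np

for dtype in (np.float32, np.float64, np.complex64, np.complex128):
  tol = 10 * np.sqrt(np.finfo(dtype).resolution)
  print(dtype.__name__, tol)
# float32/complex64 -> ~1e-2; float64/complex128 -> ~3.2e-7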

@@ -332,6 +332,10 @@ def _EinsumGrad(op, grad):
# Obtain the gradients wrt the inputs x and y, without taking into account
# the unbroadcasting.
x, y = op.inputs[0], op.inputs[1]
if grad.dtype.is_complex:
x = math_ops.conj(x)
y = math_ops.conj(y)
x_shape = array_ops.shape(x)
y_shape = array_ops.shape(y)
grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
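Background on the conj calls added above (my reading, not text from the diff): this mirrors the pattern in TensorFlow's MatMul gradient, which conjugates the inputs and then contracts them with plain transposes. For C = A B and a real-valued loss L, with G the incoming gradient for C, the standard identities are

\[
\frac{\partial L}{\partial A} = G\,\overline{B}^{\mathsf{T}} = G\,B^{\mathsf{H}},
\qquad
\frac{\partial L}{\partial B} = \overline{A}^{\mathsf{T}}\,G = A^{\mathsf{H}}\,G .
\]

Since grad_x is computed here as another einsum of (grad, y) (and grad_y symmetrically from (grad, x)), and einsum never conjugates, the conjugation has to be applied to x and y up front, and only when the dtype is complex.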