Fix bug in complex einsum grad.
Discovered another bug triggered on empty shapes with complex inputs. Along the way, removed spurious conjugate operations inside the einsum function.

Fixes #37307.

PiperOrigin-RevId: 310069589
Change-Id: I665e4b077c13eddfbad7c5ffbb9f0addd7207cb4
parent e58a4820ed
commit eaaf5b6b43
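Before the diff, a minimal sketch (my illustration, not part of the commit or of issue #37307) of the class of case the message describes: taking the gradient of a complex-valued einsum whose operands have an empty (zero-sized) dimension. The equation and shapes are assumptions chosen for demonstration.

import tensorflow as tf

# Hypothetical reproduction-style snippet; shapes and equation are assumed.
x = tf.complex(tf.random.normal([2, 0]), tf.random.normal([2, 0]))
y = tf.complex(tf.random.normal([0, 3]), tf.random.normal([0, 3]))
with tf.GradientTape() as tape:
  tape.watch([x, y])
  z = tf.einsum('ij,jk->ik', x, y)  # contraction over an empty axis
  loss = tf.reduce_sum(tf.math.real(z * tf.math.conj(z)))
# Exercises the complex, empty-shape gradient path this change targets.
dx, dy = tape.gradient(loss, [x, y])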
@@ -78,8 +78,9 @@ struct ParallelMatMulKernel {
   }
 
   static void Run(const OpKernelContext* context, const Tensor& in_x,
-                  const Tensor in_y, bool adj_x, bool adj_y,
-                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
+                  const Tensor in_y, bool adj_x, bool adj_y, bool trans_x,
+                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
+                  int start, int limit) {
     static_assert(IsComplex, "Complex type expected.");
     auto Tx = in_x.tensor<Scalar, 3>();
     auto Ty = in_y.tensor<Scalar, 3>();
@@ -90,7 +91,7 @@ struct ParallelMatMulKernel {
     // to halve the number of cases. The final conjugation of the result is
     // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
-    contract_pairs[0] = ContractionDims(adj_x, adj_y);
+    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
     const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
 
     const bool should_bcast = bcast.IsBroadcastingRequired();
@@ -121,13 +122,14 @@ struct ParallelMatMulKernel<Scalar, false> {
   static void Conjugate(const OpKernelContext* context, Tensor* out) {}
 
   static void Run(const OpKernelContext* context, const Tensor& in_x,
-                  const Tensor& in_y, bool adj_x, bool adj_y,
-                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
+                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
+                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
+                  int start, int limit) {
     auto Tx = in_x.tensor<Scalar, 3>();
     auto Ty = in_y.tensor<Scalar, 3>();
     auto Tz = out->tensor<Scalar, 3>();
     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
-    contract_pairs[0] = ContractionDims(adj_x, adj_y);
+    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
     const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
 
     const bool should_bcast = bcast.IsBroadcastingRequired();
@@ -169,8 +171,8 @@ struct SequentialMatMulKernel {
   }
 
   static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
-                  bool adj_y, const MatMulBCast& bcast, Tensor* out, int start,
-                  int limit) {
+                  bool adj_y, bool trans_x, bool trans_y,
+                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
     const bool should_bcast = bcast.IsBroadcastingRequired();
     const auto& x_batch_indices = bcast.x_batch_indices();
     const auto& y_batch_indices = bcast.y_batch_indices();
@@ -180,17 +182,31 @@ struct SequentialMatMulKernel {
       auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
       auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
       auto z = TensorSliceToEigenMatrix(out, i);
-      if (!adj_x) {
-        if (!adj_y) {
+      // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
+      // and trans_y.
+      if (!adj_x && !trans_x) {
+        if (!adj_y && !trans_y) {
           z.noalias() = x * y;
-        } else {
+        } else if (adj_y) {
           z.noalias() = x * y.adjoint();
+        } else {  // trans_y == true
+          z.noalias() = x * y.transpose();
         }
-      } else {
-        if (!adj_y) {
+      } else if (adj_x) {
+        if (!adj_y && !trans_y) {
           z.noalias() = x.adjoint() * y;
-        } else {
+        } else if (adj_y) {
           z.noalias() = x.adjoint() * y.adjoint();
+        } else {  // trans_y == true
+          z.noalias() = x.adjoint() * y.transpose();
+        }
+      } else {  // trans_x == true
+        if (!adj_y && !trans_y) {
+          z.noalias() = x.transpose() * y;
+        } else if (adj_y) {
+          z.noalias() = x.transpose() * y.adjoint();
+        } else {  // trans_y == true
+          z.noalias() = x.transpose() * y.transpose();
+        }
       }
     }
@@ -205,8 +221,8 @@ struct LaunchBatchMatMul;
 template <typename Scalar>
 struct LaunchBatchMatMul<CPUDevice, Scalar> {
   static void Launch(OpKernelContext* context, const Tensor& in_x,
-                     const Tensor& in_y, bool adj_x, bool adj_y,
-                     const MatMulBCast& bcast, Tensor* out) {
+                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
+                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
     typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
         ParallelMatMulKernel;
     bool conjugate_result = false;
@@ -226,17 +242,19 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
-      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, bcast, out,
-                                0, batch_size);
+      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
+                                trans_y, bcast, out, 0, batch_size);
       conjugate_result = adj_x;
     } else {
       // Parallelize over outer dims. For small matrices and large batches, it
       // is counter-productive to parallelize the inner matrix multiplies.
       Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
             cost_per_unit,
-            [&in_x, &in_y, adj_x, adj_y, &bcast, out](int start, int limit) {
+            [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
+                int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
-                                                 bcast, out, start, limit);
+                                                 trans_x, trans_y, bcast, out,
+                                                 start, limit);
            });
     }
     if (conjugate_result) {
@@ -297,19 +315,17 @@ class BlasScratchAllocator : public se::ScratchAllocator {
 template <typename Scalar>
 struct LaunchBatchMatMul<GPUDevice, Scalar> {
   static void Launch(OpKernelContext* context, const Tensor& in_x,
-                     const Tensor& in_y, bool adj_x, bool adj_y,
-                     const MatMulBCast& bcast, Tensor* out) {
-    constexpr se::blas::Transpose kTranspose =
-        is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
-                                  : se::blas::Transpose::kTranspose;
+                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
+                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
     se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
-                                   kTranspose};
-    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
-    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
-    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+                                   se::blas::Transpose::kTranspose,
+                                   se::blas::Transpose::kConjugateTranspose};
+    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
+    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
+    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
     const int64 batch_size = bcast.output_batch_size();
-    auto blas_transpose_a = trans[adj_x];
-    auto blas_transpose_b = trans[adj_y];
+    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
+    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
 
     auto* stream = context->op_device_context()->stream();
     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
@@ -399,9 +415,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
                                   : se::blas::Transpose::kTranspose;
       bool blas_launch_status =
           stream
-              ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
+              ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
+                             adj_x || trans_x ? k : m,
                              static_cast<Coefficient>(1.0), *(a_ptrs[0]),
-                             adj_x ? m : k, *(b_ptrs[0]), 1,
+                             adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
                              static_cast<Coefficient>(0.0), c_ptrs[0], 1)
              .ok();
       if (!blas_launch_status) {
@@ -415,7 +432,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
          stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
-                            adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
+                            adj_y || trans_y ? k : n, *(a_ptrs[0]),
+                            adj_x || trans_x ? m : k,
                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
              .ok();
       if (!blas_launch_status) {
@@ -430,8 +448,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
          stream
              ->ThenBlasGemmStridedBatched(
                  blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
-                  b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
+                  static_cast<Coefficient>(1.0), *b_ptrs[0],
+                  adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
+                  adj_x || trans_x ? m : k, a_stride,
                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
                  batch_size)
              .ok();
@@ -448,9 +467,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
-                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
-                  batch_size, &scratch_allocator)
+                  static_cast<Coefficient>(1.0), b_ptrs,
+                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
+                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
+                  &scratch_allocator)
              .ok();
       if (!blas_launch_status) {
         context->SetStatus(errors::Internal(
@@ -466,21 +486,18 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
 template <>
 struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
   static void Launch(OpKernelContext* context, const Tensor& in_x,
-                     const Tensor& in_y, bool adj_x, bool adj_y,
-                     const MatMulBCast& bcast, Tensor* out) {
+                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
+                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
     typedef Eigen::half Scalar;
-    constexpr perftools::gputools::blas::Transpose kTranspose =
-        is_complex<Scalar>::value
-            ? perftools::gputools::blas::Transpose::kConjugateTranspose
-            : perftools::gputools::blas::Transpose::kTranspose;
-    perftools::gputools::blas::Transpose trans[] = {
-        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
-    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
-    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
-    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
+                                   se::blas::Transpose::kTranspose,
+                                   se::blas::Transpose::kConjugateTranspose};
+    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
+    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
+    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
     const uint64 batch_size = bcast.output_batch_size();
-    auto blas_transpose_a = trans[adj_x];
-    auto blas_transpose_b = trans[adj_y];
+    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
+    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
 
     auto* stream = context->op_device_context()->stream();
     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
@@ -563,7 +580,8 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
          stream
              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
-                            adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
+                            adj_y || trans_y ? k : n, *(a_ptrs[0]),
+                            adj_x || trans_x ? m : k,
                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
              .ok();
       if (!blas_launch_status) {
@@ -577,8 +595,9 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
          stream
              ->ThenBlasGemmStridedBatched(
                  blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Coefficient>(1.0), *b_ptrs[0], adj_y ? k : n,
-                  b_stride, *a_ptrs[0], adj_x ? m : k, a_stride,
+                  static_cast<Coefficient>(1.0), *b_ptrs[0],
+                  adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
+                  adj_x || trans_x ? m : k, a_stride,
                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
                  batch_size)
              .ok();
@@ -595,9 +614,10 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
-                  adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
-                  batch_size, &scratch_allocator)
+                  static_cast<Coefficient>(1.0), b_ptrs,
+                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
+                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
+                  &scratch_allocator)
              .ok();
       if (!blas_launch_status) {
         context->SetStatus(errors::Internal(
@@ -616,13 +636,14 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
 template <typename Scalar>
 struct ParallelMatMulKernelSYCL {
   static void Run(const OpKernelContext* context, const Tensor& in_x,
-                  const Tensor& in_y, bool adj_x, bool adj_y,
-                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
+                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
+                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
+                  int start, int limit) {
     auto Tx = in_x.tensor<Scalar, 3>();
     auto Ty = in_y.tensor<Scalar, 3>();
     auto Tz = out->tensor<Scalar, 3>();
     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
-    contract_pairs[0] = ContractionDims(adj_x, adj_y);
+    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
     auto d = context->eigen_sycl_device();
 
     const bool should_bcast = bcast.IsBroadcastingRequired();
@@ -643,12 +664,13 @@ struct ParallelMatMulKernelSYCL {
 template <typename Scalar>
 struct LaunchBatchMatMul<SYCLDevice, Scalar> {
   static void Launch(OpKernelContext* context, const Tensor& in_x,
-                     const Tensor& in_y, bool adj_x, bool adj_y,
-                     const MatMulBCast& bcast, Tensor* out) {
+                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
+                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
     // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = bcast.output_batch_size();
     ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
-                                          bcast, out, 0, batch_size);
+                                          trans_x, trans_y, bcast, out, 0,
+                                          batch_size);
   }
 };
 #endif  // TENSORFLOW_USE_SYCL
@@ -720,7 +742,8 @@ class BaseBatchMatMulOp : public OpKernel {
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchMatMul<Device, Scalar>::Launch(
-        ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, bcast, &out_reshaped);
+        ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
+        /*trans_y=*/false, bcast, &out_reshaped);
   }
 
  protected:
@@ -537,29 +537,13 @@ struct EinsumHelper {
     return CopyFrom(input, output_shape, output);
   }
 
-  // Conjugates the input.
-  template <typename Device, typename T>
-  static Status Conjugate(OpKernelContext* ctx, Tensor* input) {
-    std::vector<int> permutation(input->dims());
-    std::iota(permutation.begin(), permutation.end(), 0);
-    Tensor output;
-    TF_RETURN_IF_ERROR(
-        ctx->allocate_temp(DataTypeToEnum<T>::value, input->shape(), &output));
-    const Device& d = ctx->eigen_device<Device>();
-    TF_RETURN_IF_ERROR(DoConjugateTranspose(d, *input, permutation, &output));
-    std::swap(*input, output);
-    return Status::OK();
-  }
-
   // Contracts the inputs along the last axis. (or the second last if the
   // corresponding value of swap_free_and_contract is true). The batch
   // dimensions are broadcast to the output shape.
-  // TODO(anudhyan): Factor this function into a BatchMatMul functor and support
-  // transpose_x and transpose_y attributes (in addition to adj_x and adj_y).
-  // Also, the BatchMatMul might devolve into a component-wise multiplication
-  // when the matrix shape is [1,1]; in this case BatchMatMul functor would be
-  // very inefficient. The functor should detect if this is the case and perform
-  // componentwise multiplication functor instead.
+  // TODO(anudhyan): BatchMatMul might devolve into a component-wise
+  // multiplication when the matrix shape is [1,1]; in this case BatchMatMul
+  // functor would be very inefficient. The functor should detect if this is the
+  // case and perform componentwise multiplication functor instead.
   template <typename Device, typename T>
   static Status ContractOperands(OpKernelContext* ctx,
                                  absl::Span<const Tensor> inputs,
@@ -584,12 +568,8 @@ struct EinsumHelper {
           inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2);
       output_shape.AddDim(inputs[i].dim_size(free_axis));
     }
-    bool adj_x = swap_free_and_contract[0];
-    bool adj_y = !swap_free_and_contract[1];
-    if (is_complex<T>::value) {
-      if (adj_x) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &lhs));
-      if (adj_y) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &rhs));
-    }
+    bool trans_x = swap_free_and_contract[0];
+    bool trans_y = !swap_free_and_contract[1];
     TF_RETURN_IF_ERROR(
         ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
     if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
@@ -600,8 +580,9 @@ struct EinsumHelper {
     Tensor output_reshaped;
     TF_RETURN_IF_ERROR(
         ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
-    LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, adj_x, adj_y, bcast,
-                                         &output_reshaped);
+    LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false,
+                                         /*adj_y=*/false, trans_x, trans_y,
+                                         bcast, &output_reshaped);
     return Status::OK();
   }
 };
@@ -286,13 +286,21 @@ class EinsumGradTest(test.TestCase):
 
   def _check_gradient(self, s, *input_shapes):
     with self.cached_session():
-      r = np.random.RandomState(0)
-      inputs = [np.array(r.randn(*shape), np.float64) for shape in input_shapes]
-      input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
-      analytical, numerical = gradient_checker_v2.compute_gradient(
-          lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
-      self.assertLess(
-          gradient_checker_v2.max_error(analytical, numerical), 1e-4)
+      r = np.random.RandomState(seed=0)
+      for dtype in (np.float32, np.float64, np.complex64, np.complex128):
+        tol = 10 * np.sqrt(np.finfo(dtype).resolution)
+        if dtype in (np.complex64, np.complex128):
+          inputs = [
+              np.array(r.randn(*shape), dtype) +
+              1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
+          ]
+        else:
+          inputs = [np.array(r.randn(*shape), dtype) for shape in input_shapes]
+        input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
+        analytical, numerical = gradient_checker_v2.compute_gradient(
+            lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
+        self.assertLess(
+            gradient_checker_v2.max_error(analytical, numerical), tol)
 
   @test_util.disable_xla('b/131919749')
   def testUnary(self):
@@ -332,6 +332,10 @@ def _EinsumGrad(op, grad):
   # Obtain the gradients wrt the inputs x and y, without taking into account
   # the unbroadcasting.
   x, y = op.inputs[0], op.inputs[1]
+  if grad.dtype.is_complex:
+    x = math_ops.conj(x)
+    y = math_ops.conj(y)
+
   x_shape = array_ops.shape(x)
   y_shape = array_ops.shape(y)
   grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
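A closing note on why the explicit Conjugate pass in EinsumHelper could be dropped: conjugating an operand and then contracting it with the adjoint (conjugate-transpose) flag is the same as contracting the original operand with a plain transpose, which is what ContractOperands now requests via trans_x/trans_y. The conjugation that complex gradients actually need is applied once in _EinsumGrad via math_ops.conj, as the last hunk shows. A minimal NumPy check of that identity (my illustration, not code from the commit):

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(4, 3) + 1j * rng.randn(4, 3)
y = rng.randn(4, 5) + 1j * rng.randn(4, 5)

old_path = np.conj(x).conj().T @ y  # old path: Conjugate(x), then adj_x=True (conjugate-transpose)
new_path = x.T @ y                  # new path: trans_x=True, no conjugation anywhere
assert np.allclose(old_path, new_path)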