Update trace, matrix_set_diag, matrix_diag_part and their gradients to work for rectangular matrices.

Generalize trace to work like numpy.trace(x, axis1=-2, axis2=-1), including for rank > 2.
Fix bad doc string for matrix_band_part.
Change: 134700928
This commit is contained in:
A. Unique TensorFlower 2016-09-29 12:28:50 -08:00 committed by TensorFlower Gardener
parent b7d5df182b
commit 2a5a96976d
9 changed files with 193 additions and 141 deletions

View File

@ -226,7 +226,7 @@ class OperatorPDCholeskyTest(tf.test.TestCase):
# should raise.
with self.test_session():
batch_vec = [[1.0], [2.0]] # shape 2 x 1
with self.assertRaisesRegexp(ValueError, ".*Dimensions.*"):
with self.assertRaisesOpError("x == y did not hold"):
operator = operator_pd_cholesky.OperatorPDCholesky(batch_vec)
operator.to_dense().eval()

View File

@ -58,23 +58,19 @@ class MatrixDiagPartOp : public OpKernel {
"input must be at least 2-dim, received shape: ",
input.shape().DebugString()));
// Check to make sure the last two dimensions have the same value
const int64 k = input_shape.dim_size(rank - 1);
OP_REQUIRES(
context, k == input_shape.dim_size(rank - 2),
errors::InvalidArgument(
"input's last two dimensions must be equal, received shape: ",
input.shape().DebugString()));
auto input_reshaped = input.flat_inner_dims<T, 3>();
TensorShape output_shape = input_shape;
output_shape.RemoveDim(rank - 1);
TensorShape output_shape;
for (int i = 0; i < rank - 2; ++i) {
output_shape.AddDim(input_shape.dim_size(i));
}
const int64 min_dim = std::min(input_shape.dim_size(rank - 2),
input_shape.dim_size(rank - 1));
output_shape.AddDim(min_dim);
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
auto output_reshaped = output->flat_inner_dims<T, 2>();
auto input_reshaped = input.flat_inner_dims<T, 3>();
functor::MatrixDiagPart<Device, T>::Compute(
context->eigen_device<Device>(), input_reshaped, output_reshaped);
@ -101,7 +97,6 @@ class MatrixDiagOp : public OpKernel {
"input must be at least 1-dim, received shape: ",
input.shape().DebugString()));
// Check to make sure the last two dimensions have the same value
const int64 k = input_shape.dim_size(rank - 1);
auto input_reshaped = input.flat_inner_dims<T, 2>();

View File

@ -59,21 +59,18 @@ class MatrixSetDiagOp : public OpKernel {
"input must be at least 2-dim, received shape: ",
input.shape().DebugString()));
// Check to make sure the last two dimensions have the same value
const int64 k = input_shape.dim_size(rank - 1);
OP_REQUIRES(
context, k == input_shape.dim_size(rank - 2),
errors::InvalidArgument(
"input's last two dimensions must be equal, received shape: ",
input.shape().DebugString()));
TensorShape input_shape_but_one = input_shape;
input_shape_but_one.RemoveDim(rank - 1);
OP_REQUIRES(context, input_shape_but_one == diag_shape,
// Check to make sure the last dimension of diag is equal to the smaller of
// the last two dimensions of input.
const int64 min_dim = std::min(input_shape.dim_size(rank - 1),
input_shape.dim_size(rank - 2));
TensorShape expected_diag_shape = input_shape;
expected_diag_shape.RemoveDim(rank - 1);
expected_diag_shape.RemoveDim(rank - 2);
expected_diag_shape.AddDim(min_dim);
OP_REQUIRES(context, expected_diag_shape == diag_shape,
errors::InvalidArgument(
"must have diagonal.shape == input.shape[:-1], but "
"received input shape: ",
"must have diagonal.shape == input.shape[:-2] + "
"min(input.shape[-2:]), but received input shape: ",
input_shape.DebugString(), " and diagonal shape: ",
diag_shape.DebugString()));
@ -127,7 +124,7 @@ struct MatrixSetDiag<CPUDevice, T> {
typename TTypes<T, 3>::Tensor output) {
output.device(d) = input;
for (int64 r = 0; r < output.dimension(0); ++r) {
for (int64 d = 0; d < output.dimension(1); ++d) {
for (int64 d = 0; d < diag.dimension(1); ++d) {
output(r, d, d) = diag(r, d);
}
}

View File

@ -554,16 +554,24 @@ REGISTER_OP("MatrixSetDiag")
ShapeHandle diag;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
DimensionHandle square_dim;
if (c->RankKnown(input)) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
}
DimensionHandle smallest_dim;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input, -2), c->Dim(input, -1), &square_dim));
TF_RETURN_IF_ERROR(c->Merge(square_dim, c->Dim(diag, -1), &square_dim));
ShapeHandle output;
TF_RETURN_IF_ERROR(c->Concatenate(diag, c->Vector(square_dim), &output));
TF_RETURN_IF_ERROR(c->Merge(input, output, &output));
c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
TF_RETURN_IF_ERROR(
c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));
ShapeHandle output = input;
if (c->RankKnown(diag) && !c->FullyDefined(input)) {
// Try to infer parts of shape from diag.
ShapeHandle diag_prefix;
TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_prefix));
TF_RETURN_IF_ERROR(
c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag));
TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
}
c->set_output(0, output);
return Status::OK();
})
@ -571,15 +579,14 @@ REGISTER_OP("MatrixSetDiag")
Returns a batched matrix tensor with new batched diagonal values.
Given `input` and `diagonal`, this operation returns a tensor with the
same shape and values as `input`, except for the diagonals of the innermost
matrices. These will be overwritten by the values in `diagonal`.
The batched matrices must be square.
same shape and values as `input`, except for the main diagonal of the
innermost matrices. These will be overwritten by the values in `diagonal`.
The output is computed as follows:
Assume `input` has `k+1` dimensions `[I, J, K, ..., N, N]` and `diagonal` has
`k` dimensions `[I, J, K, ..., N]`. Then the output is a
tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where:
Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has
`k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a
tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where:
* `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`.
* `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`.
@ -602,14 +609,13 @@ REGISTER_OP("MatrixDiagPart")
return Status::OK();
}
const int32 rank = c->Rank(in);
// Last two dims must match.
DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(in, rank - 1), c->Dim(in, rank - 2), &unused));
// Output shape has all dims but last of input.
std::vector<DimensionHandle> dims;
for (int i = 0; i < rank - 1; ++i) dims.push_back(c->Dim(in, i));
for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));
DimensionHandle min_dim;
TF_RETURN_IF_ERROR(
c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
dims.push_back(min_dim);
c->set_output(0, c->MakeShape(dims));
return Status::OK();
})
@ -619,8 +625,8 @@ Returns the batched diagonal part of a batched tensor.
This operation returns a tensor with the `diagonal` part
of the batched `input`. The `diagonal` part is computed as follows:
Assume `input` has `k` dimensions `[I, J, K, ..., N, N]`, then the output is a
tensor of rank `k - 1` with dimensions `[I, J, K, ..., N]` where:
Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where:
`diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`.
@ -645,9 +651,9 @@ tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]]
which has shape (2, 4)
```
input: Rank `k` tensor where `k >= 2` and the last two dimensions are equal.
input: Rank `k` tensor where `k >= 2`.
diagonal: The extracted diagonal(s) having shape
`diagonal.shape = input.shape[:-1]`.
`diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`.
)doc");
// --------------------------------------------------------------------------
@ -668,9 +674,10 @@ tensor with the same shape where
`band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
The indicator function 'in_band(m, n)` is one if
`(num_lower < 0 || (m-n) <= num_lower)) &&
(num_upper < 0 || (n-m) <= num_upper)`, and zero otherwise.
The indicator function is defined as
`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
(num_upper < 0 || (n-m) <= num_upper)`.
For example:
@ -681,14 +688,14 @@ For example:
[-3, -2, -1, 0]],
tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
[-1, 0, 1, 2]
[ 0, -1, 0, 1]
[ 0, 0, -1, 0]],
[-1, 0, 1, 2]
[ 0, -1, 0, 1]
[ 0, 0, -1, 0]],
tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
[-1, 0, 1, 0]
[-2, -1, 0, 1]
[ 0, -2, -1, 0]]
[-1, 0, 1, 0]
[-2, -1, 0, 1]
[ 0, -2, -1, 0]]
```
Useful special cases:

View File

@ -185,7 +185,8 @@ TEST(ArrayOpsTest, MatrixDiagPart_ShapeFn) {
INFER_OK(op, "?", "?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?]");
INFER_OK(op, "[?,1,2,2]", "[d0_0,d0_1,d0_2|d0_3]");
INFER_ERROR("Dimensions must be equal, but are 3 and 2", op, "[1,2,3]");
INFER_OK(op, "[?,1,2,3]", "[d0_0,d0_1,d0_2]");
INFER_OK(op, "[?,1,3,2]", "[d0_0,d0_1,d0_3]");
}
TEST(ArrayOpsTest, Reverse_ShapeFn) {
@ -364,19 +365,25 @@ TEST(ArrayOpsTest, MatrixSetDiag_ShapeFn) {
// Rank checks.
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "?;[]");
INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[2,2];[]");
INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,2];[2,2]");
// Output matches input, and also matches diagonal + diagonal.dim(-1).
// diagonal[-1] must match smallest matrix dimension.
INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[2,3];[3]");
// Output matches input.
INFER_OK(op, "?;?", "?");
INFER_OK(op, "?;[1,2]", "[d1_0,d1_1,d1_1]");
INFER_OK(op, "[1,2,2];?", "in0");
INFER_OK(op, "[1,2,2];[1,2]", "in0");
INFER_OK(op, "[1,2,3];?", "in0");
INFER_OK(op, "[1,3,2];?", "in0");
INFER_OK(op, "[1,?,2];[?,?]", "in0");
INFER_OK(op, "[1,?,?];[?,2]", "[d0_0,d1_1,d1_1]");
INFER_OK(op, "[1,?,?];[?,2]", "in0");
// Last 2 dims of input must match.
INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[1,2,3];?");
// Dims matches prefix of input.
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]");
// Infer batch shape from diag when input is not fully specified.
INFER_OK(op, "?;[1,2]", "[d1_0,?,?]");
INFER_OK(op, "[?,?,3];[1,2]", "[d1_0,d0_1,d0_2]");
INFER_OK(op, "[?,3,?];[1,2]", "[d1_0,d0_1,d0_2]");
INFER_OK(op, "[?,3,2];[1,2]", "[d1_0,d0_1,d0_2]");
}
TEST(ArrayOpsTest, ExpandDims_ShapeFn) {

View File

@ -75,7 +75,7 @@ class MatrixDiagGpuTest(MatrixDiagTest):
class MatrixSetDiagTest(tf.test.TestCase):
_use_gpu = False
def testVector(self):
def testSquare(self):
with self.test_session(use_gpu=self._use_gpu):
v = np.array([1.0, 2.0, 3.0])
mat = np.array([[0.0, 1.0, 0.0],
@ -88,7 +88,23 @@ class MatrixSetDiagTest(tf.test.TestCase):
self.assertEqual((3, 3), output.get_shape())
self.assertAllEqual(mat_set_diag, output.eval())
def testBatchVector(self):
def testRectangular(self):
with self.test_session(use_gpu=self._use_gpu):
v = np.array([3.0, 4.0])
mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
expected = np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]])
output = tf.matrix_set_diag(mat, v)
self.assertEqual((2, 3), output.get_shape())
self.assertAllEqual(expected, output.eval())
v = np.array([3.0, 4.0])
mat = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
expected = np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]])
output = tf.matrix_set_diag(mat, v)
self.assertEqual((3, 2), output.get_shape())
self.assertAllEqual(expected, output.eval())
def testSquareBatch(self):
with self.test_session(use_gpu=self._use_gpu):
v_batch = np.array([[-1.0, -2.0, -3.0],
[-4.0, -5.0, -6.0]])
@ -111,6 +127,25 @@ class MatrixSetDiagTest(tf.test.TestCase):
self.assertEqual((2, 3, 3), output.get_shape())
self.assertAllEqual(mat_set_diag_batch, output.eval())
def testRectangularBatch(self):
with self.test_session(use_gpu=self._use_gpu):
v_batch = np.array([[-1.0, -2.0],
[-4.0, -5.0]])
mat_batch = np.array(
[[[1.0, 0.0, 3.0],
[0.0, 2.0, 0.0]],
[[4.0, 0.0, 4.0],
[0.0, 5.0, 0.0]]])
mat_set_diag_batch = np.array(
[[[-1.0, 0.0, 3.0],
[0.0, -2.0, 0.0]],
[[-4.0, 0.0, 4.0],
[0.0, -5.0, 0.0]]])
output = tf.matrix_set_diag(mat_batch, v_batch)
self.assertEqual((2, 2, 3), output.get_shape())
self.assertAllEqual(mat_set_diag_batch, output.eval())
def testInvalidShape(self):
with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
tf.matrix_set_diag(0, [0])
@ -127,11 +162,12 @@ class MatrixSetDiagTest(tf.test.TestCase):
tf.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
def testGrad(self):
shapes = ((3, 4, 4), (7, 4, 8, 8))
shapes = ((3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8))
with self.test_session(use_gpu=self._use_gpu):
for shape in shapes:
x = tf.constant(np.random.rand(*shape), dtype=tf.float32)
x_diag = tf.constant(np.random.rand(*shape[:-1]), dtype=tf.float32)
diag_shape = shape[:-2] + (min(shape[-2:]),)
x_diag = tf.constant(np.random.rand(*diag_shape), dtype=tf.float32)
y = tf.matrix_set_diag(x, x_diag)
error_x = tf.test.compute_gradient_error(x, x.get_shape().as_list(),
y, y.get_shape().as_list())
@ -164,7 +200,7 @@ class MatrixSetDiagGpuTest(MatrixSetDiagTest):
class MatrixDiagPartTest(tf.test.TestCase):
_use_gpu = False
def testMatrix(self):
def testSquare(self):
with self.test_session(use_gpu=self._use_gpu):
v = np.array([1.0, 2.0, 3.0])
mat = np.diag(v)
@ -172,7 +208,16 @@ class MatrixDiagPartTest(tf.test.TestCase):
self.assertEqual((3,), mat_diag.get_shape())
self.assertAllEqual(mat_diag.eval(), v)
def testBatchMatrix(self):
def testRectangular(self):
with self.test_session(use_gpu=self._use_gpu):
mat = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mat_diag = tf.matrix_diag_part(mat)
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 5.0]))
mat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
mat_diag = tf.matrix_diag_part(mat)
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
def testSquareBatch(self):
with self.test_session(use_gpu=self._use_gpu):
v_batch = np.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
@ -188,22 +233,32 @@ class MatrixDiagPartTest(tf.test.TestCase):
self.assertEqual((2, 3), mat_batch_diag.get_shape())
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
def testRectangularBatch(self):
with self.test_session(use_gpu=self._use_gpu):
v_batch = np.array([[1.0, 2.0],
[4.0, 5.0]])
mat_batch = np.array(
[[[1.0, 0.0, 0.0],
[0.0, 2.0, 0.0]],
[[4.0, 0.0, 0.0],
[0.0, 5.0, 0.0]]])
self.assertEqual(mat_batch.shape, (2, 2, 3))
mat_batch_diag = tf.matrix_diag_part(mat_batch)
self.assertEqual((2, 2), mat_batch_diag.get_shape())
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
def testInvalidShape(self):
with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
tf.matrix_diag_part(0)
with self.assertRaisesRegexp(ValueError, r"Dimensions must be equal"):
tf.matrix_diag_part([[0, 1], [1, 0], [0, 0]])
def testInvalidShapeAtEval(self):
with self.test_session(use_gpu=self._use_gpu):
v = tf.placeholder(dtype=tf.float32)
with self.assertRaisesOpError("input must be at least 2-dim"):
tf.matrix_diag_part(v).eval(feed_dict={v: 0.0})
with self.assertRaisesOpError("last two dimensions must be equal"):
tf.matrix_diag_part(v).eval(feed_dict={v: [[0, 1], [1, 0], [0, 0]]})
def testGrad(self):
shapes = ((3, 3), (5, 3, 3))
shapes = ((3, 3), (2, 3), (3, 2), (5, 3, 3))
with self.test_session(use_gpu=self._use_gpu):
for shape in shapes:
x = tf.constant(np.random.rand(*shape), dtype=np.float32)

View File

@ -17,54 +17,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy
import numpy as np
import tensorflow as tf
class TraceTest(tf.test.TestCase):
def setUp(self):
x = numpy.random.seed(0)
x = np.random.seed(0)
def traceOp(self, x, dtype, expected_ans, use_gpu=False):
with self.test_session(use_gpu=use_gpu):
tf_ans = tf.trace(x.astype(dtype))
out = tf_ans.eval()
self.assertAllClose(out, expected_ans)
def compare(self, x):
np_ans = np.trace(x, axis1=-2, axis2=-1)
with self.test_session(use_gpu=True):
tf_ans = tf.trace(x).eval()
self.assertAllClose(tf_ans, np_ans)
def testEmptyTensor(self):
x = numpy.array([])
self.assertRaises(ValueError, self.traceOp, x, numpy.float32, 0)
def testRankOneTensor(self):
x = numpy.array([1,2,3])
self.assertRaises(ValueError, self.traceOp, x, numpy.float32, 0)
def testRankTwoIntTensor(self):
x = numpy.array(
[[1, 0, 0],
[0, 2, 0],
[0, 0, 3]])
expected_ans = 6
self.traceOp(x, numpy.int32, expected_ans)
self.traceOp(x, numpy.int64, expected_ans)
def testRankTwoFloatTensor(self):
x = numpy.array(
[[1.1, 0, 0],
[0, 2.2, 0],
[0, 0, 3.3]])
expected_ans = 6.6
self.traceOp(x, numpy.float32, expected_ans)
self.traceOp(x, numpy.float64, expected_ans)
def testRankThreeFloatTensor(self):
x = numpy.random.rand(2, 2, 2)
self.assertRaises(ValueError, self.traceOp, x, numpy.float32, 0)
def testRankFourFloatTensor(self):
x = numpy.random.rand(2, 2, 2, 2)
self.assertRaises(ValueError, self.traceOp, x, numpy.float32, 0)
def testTrace(self):
for dtype in [np.int32, np.float32, np.float64]:
for shape in [[2, 2], [2, 3], [3, 2], [2, 3, 2], [2, 2, 2, 3]]:
x = np.random.rand(np.prod(shape)).astype(dtype).reshape(shape)
self.compare(x)
if __name__ == "__main__":

View File

@ -233,20 +233,30 @@ def _MatrixDiagGrad(_, grad):
@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(_, grad):
return array_ops.matrix_diag(grad)
def _MatrixDiagPartGrad(op, grad):
matrix_shape = op.inputs[0].get_shape()[-2:]
if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
return array_ops.matrix_diag(grad)
else:
return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)
@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
diag_shape = op.inputs[1].get_shape()
diag_shape = diag_shape.merge_with(op.inputs[0].get_shape()[:-1])
diag_shape = diag_shape.merge_with(grad.get_shape()[:-1])
if diag_shape.is_fully_defined():
diag_shape = diag_shape.as_list()
batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
matrix_shape = input_shape[-2:]
if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
else:
diag_shape = array_ops.shape(grad)
diag_shape = array_ops.slice(diag_shape, [0], [array_ops.rank(grad) - 1])
with ops.colocate_with(grad):
grad_shape = array_ops.shape(grad)
grad_rank = array_ops.rank(grad)
batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
min_dim = math_ops.reduce_min(matrix_shape)
diag_shape = array_ops.concat(0, [batch_shape, [min_dim]])
grad_input = array_ops.matrix_set_diag(
grad, array_ops.zeros(
diag_shape, dtype=grad.dtype))

View File

@ -1299,23 +1299,35 @@ def reduce_logsumexp(input_tensor, reduction_indices=None, keep_dims=False,
def trace(x, name=None):
""" Compute the trace of a tensor `x`.
`trace(x)` returns the sum of along the diagonal.
`trace(x)` returns the sum along the main diagonal of each inner-most matrix
in x. If x is of rank `k` with shape `[I, J, K, ..., L, M, N]`, then output
is a tensor of rank `k-2` with dimensions `[I, J, K, ..., L]` where
`output[i, j, k, ..., l] = trace(x[i, j, k, ..., l, :, :])`
For example:
```python
# 'x' is [[1, 1],
# [1, 1]]
tf.trace(x) ==> 2
# 'x' is [[1, 2],
# [3, 4]]
tf.trace(x) ==> 5
# 'x' is [[1,2,3],
# [4,5,6],
# [7,8,9]]
tf.trace(x) ==> 15
# 'x' is [[[1,2,3],
# [4,5,6],
# [7,8,9]],
# [[-1,-2,-3],
# [-4,-5,-6],
# [-7,-8,-9]]]
tf.trace(x) ==> [15,-15]
```
Args:
x: 2-D tensor.
x: tensor.
name: A name for the operation (optional).
Returns:
@ -1323,10 +1335,7 @@ def trace(x, name=None):
"""
with ops.name_scope(name, "Trace", [x]) as name:
x = ops.convert_to_tensor(x, name="x")
if len(x.get_shape()) != 2:
raise ValueError("Expected a tensor with rank 2, rank %d tensor received"
% len(x.get_shape()))
return reduce_sum(array_ops.diag_part(x), name=name)
return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
def matmul(a, b,