Update trace, matrix_set_diag, matrix_diag_part and their gradients to work for rectangular matrices.
Generalize trace to work like numpy.trace(x, axis1=-2, axis2=-1), including for rank > 2.
Fix bad doc string for matrix_band_part.

Change: 134700928
parent b7d5df182b
commit 2a5a96976d
@@ -226,7 +226,7 @@ class OperatorPDCholeskyTest(tf.test.TestCase):
       # should raise.
       with self.test_session():
         batch_vec = [[1.0], [2.0]]  # shape 2 x 1
-        with self.assertRaisesRegexp(ValueError, ".*Dimensions.*"):
+        with self.assertRaisesOpError("x == y did not hold"):
           operator = operator_pd_cholesky.OperatorPDCholesky(batch_vec)
           operator.to_dense().eval()
@@ -58,23 +58,19 @@ class MatrixDiagPartOp : public OpKernel {
                     "input must be at least 2-dim, received shape: ",
                     input.shape().DebugString()));
 
-    // Check to make sure the last two dimensions have the same value
-    const int64 k = input_shape.dim_size(rank - 1);
-    OP_REQUIRES(
-        context, k == input_shape.dim_size(rank - 2),
-        errors::InvalidArgument(
-            "input's last two dimensions must be equal, received shape: ",
-            input.shape().DebugString()));
-
-    auto input_reshaped = input.flat_inner_dims<T, 3>();
-
-    TensorShape output_shape = input_shape;
-    output_shape.RemoveDim(rank - 1);
+    TensorShape output_shape;
+    for (int i = 0; i < rank - 2; ++i) {
+      output_shape.AddDim(input_shape.dim_size(i));
+    }
+    const int64 min_dim = std::min(input_shape.dim_size(rank - 2),
+                                   input_shape.dim_size(rank - 1));
+    output_shape.AddDim(min_dim);
 
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
 
     auto output_reshaped = output->flat_inner_dims<T, 2>();
+    auto input_reshaped = input.flat_inner_dims<T, 3>();
 
     functor::MatrixDiagPart<Device, T>::Compute(
         context->eigen_device<Device>(), input_reshaped, output_reshaped);
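The new output shape keeps the batch dimensions and appends the smaller of the two inner dimensions. A minimal NumPy-only sketch of that rule (the helper name is hypothetical, not part of this change):

```python
import numpy as np

def diag_part_shape(input_shape):
    # Batch dims are kept; the diagonal length is min of the two inner dims.
    return list(input_shape[:-2]) + [min(input_shape[-2:])]

assert diag_part_shape([4, 2, 3]) == [4, 2]  # rectangular inner matrices
assert diag_part_shape([7, 5, 5]) == [7, 5]  # square case is unchanged
```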
@@ -101,7 +97,6 @@ class MatrixDiagOp : public OpKernel {
                     "input must be at least 1-dim, received shape: ",
                     input.shape().DebugString()));
 
-    // Check to make sure the last two dimensions have the same value
     const int64 k = input_shape.dim_size(rank - 1);
     auto input_reshaped = input.flat_inner_dims<T, 2>();
 
@@ -59,21 +59,18 @@ class MatrixSetDiagOp : public OpKernel {
                     "input must be at least 2-dim, received shape: ",
                     input.shape().DebugString()));
 
-    // Check to make sure the last two dimensions have the same value
-    const int64 k = input_shape.dim_size(rank - 1);
-    OP_REQUIRES(
-        context, k == input_shape.dim_size(rank - 2),
-        errors::InvalidArgument(
-            "input's last two dimensions must be equal, received shape: ",
-            input.shape().DebugString()));
-
-    TensorShape input_shape_but_one = input_shape;
-    input_shape_but_one.RemoveDim(rank - 1);
-
-    OP_REQUIRES(context, input_shape_but_one == diag_shape,
+    // Check to make sure the last dimension of diag is equal to the smaller of
+    // the last two dimensions of input.
+    const int64 min_dim = std::min(input_shape.dim_size(rank - 1),
+                                   input_shape.dim_size(rank - 2));
+    TensorShape expected_diag_shape = input_shape;
+    expected_diag_shape.RemoveDim(rank - 1);
+    expected_diag_shape.RemoveDim(rank - 2);
+    expected_diag_shape.AddDim(min_dim);
+    OP_REQUIRES(context, expected_diag_shape == diag_shape,
                 errors::InvalidArgument(
-                    "must have diagonal.shape == input.shape[:-1], but "
-                    "received input shape: ",
+                    "must have diagonal.shape == input.shape[:-2] + "
+                    "min(input.shape[-2:]), but received input shape: ",
                     input_shape.DebugString(), " and diagonal shape: ",
                     diag_shape.DebugString()));
 
@@ -127,7 +124,7 @@ struct MatrixSetDiag<CPUDevice, T> {
                       typename TTypes<T, 3>::Tensor output) {
     output.device(d) = input;
     for (int64 r = 0; r < output.dimension(0); ++r) {
-      for (int64 d = 0; d < output.dimension(1); ++d) {
+      for (int64 d = 0; d < diag.dimension(1); ++d) {
         output(r, d, d) = diag(r, d);
       }
     }
@@ -554,16 +554,24 @@ REGISTER_OP("MatrixSetDiag")
       ShapeHandle diag;
       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag));
-
-      DimensionHandle square_dim;
+      if (c->RankKnown(input)) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(1), c->Rank(input) - 1, &diag));
+      }
+      DimensionHandle smallest_dim;
       TF_RETURN_IF_ERROR(
-          c->Merge(c->Dim(input, -2), c->Dim(input, -1), &square_dim));
-      TF_RETURN_IF_ERROR(c->Merge(square_dim, c->Dim(diag, -1), &square_dim));
+          c->Min(c->Dim(input, -2), c->Dim(input, -1), &smallest_dim));
+      TF_RETURN_IF_ERROR(
+          c->Merge(smallest_dim, c->Dim(diag, -1), &smallest_dim));
 
-      ShapeHandle output;
-      TF_RETURN_IF_ERROR(c->Concatenate(diag, c->Vector(square_dim), &output));
-      TF_RETURN_IF_ERROR(c->Merge(input, output, &output));
+      ShapeHandle output = input;
+      if (c->RankKnown(diag) && !c->FullyDefined(input)) {
+        // Try to infer parts of shape from diag.
+        ShapeHandle diag_prefix;
+        TF_RETURN_IF_ERROR(c->Subshape(diag, 0, -1, &diag_prefix));
+        TF_RETURN_IF_ERROR(
+            c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag));
+        TF_RETURN_IF_ERROR(c->Merge(input, diag, &output));
+      }
       c->set_output(0, output);
      return Status::OK();
    })
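To see what the new shape function buys from the Python side, a short sketch assuming the TF 1.x-era placeholder API: when the input shape is only partially known, batch dimensions can now be inferred from `diagonal`.

```python
import tensorflow as tf

# Unknown batch and row dims on the input; diag pins the batch dim to 1.
matrix = tf.placeholder(tf.float32, shape=[None, None, 3])
diag = tf.placeholder(tf.float32, shape=[1, 2])
output = tf.matrix_set_diag(matrix, diag)
print(output.get_shape())  # (1, ?, 3): matches the shape-function tests below
```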
@@ -571,15 +579,14 @@ REGISTER_OP("MatrixSetDiag")
 Returns a batched matrix tensor with new batched diagonal values.
 
 Given `input` and `diagonal`, this operation returns a tensor with the
-same shape and values as `input`, except for the diagonals of the innermost
-matrices. These will be overwritten by the values in `diagonal`.
-The batched matrices must be square.
+same shape and values as `input`, except for the main diagonal of the
+innermost matrices. These will be overwritten by the values in `diagonal`.
 
 The output is computed as follows:
 
-Assume `input` has `k+1` dimensions `[I, J, K, ..., N, N]` and `diagonal` has
-`k` dimensions `[I, J, K, ..., N]`. Then the output is a
-tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where:
+Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has
+`k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a
+tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where:
 
   * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`.
   * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`.
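As a quick usage sketch of the generalized op (assuming the TF 1.x session API used by the tests in this change):

```python
import numpy as np
import tensorflow as tf

mat = np.array([[0.0, 1.0, 0.0],
                [1.0, 0.0, 1.0]])  # 2 x 3, so the diagonal has min(2, 3) == 2 entries
with tf.Session() as sess:
    out = sess.run(tf.matrix_set_diag(mat, np.array([3.0, 4.0])))
print(out)  # [[3., 1., 0.],
            #  [1., 4., 1.]]
```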
@@ -602,14 +609,13 @@ REGISTER_OP("MatrixDiagPart")
         return Status::OK();
       }
       const int32 rank = c->Rank(in);
-      // Last two dims must match.
-      DimensionHandle unused;
-      TF_RETURN_IF_ERROR(
-          c->Merge(c->Dim(in, rank - 1), c->Dim(in, rank - 2), &unused));
-
-      // Output shape has all dims but last of input.
       std::vector<DimensionHandle> dims;
-      for (int i = 0; i < rank - 1; ++i) dims.push_back(c->Dim(in, i));
+      for (int i = 0; i < rank - 2; ++i) dims.push_back(c->Dim(in, i));
+
+      DimensionHandle min_dim;
+      TF_RETURN_IF_ERROR(
+          c->Min(c->Dim(in, rank - 2), c->Dim(in, rank - 1), &min_dim));
+      dims.push_back(min_dim);
       c->set_output(0, c->MakeShape(dims));
       return Status::OK();
     })
@@ -619,8 +625,8 @@ Returns the batched diagonal part of a batched tensor.
 This operation returns a tensor with the `diagonal` part
 of the batched `input`. The `diagonal` part is computed as follows:
 
-Assume `input` has `k` dimensions `[I, J, K, ..., N, N]`, then the output is a
-tensor of rank `k - 1` with dimensions `[I, J, K, ..., N]` where:
+Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
+tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where:
 
 `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`.
 
@@ -645,9 +651,9 @@ tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]]
 which has shape (2, 4)
 ```
 
-input: Rank `k` tensor where `k >= 2` and the last two dimensions are equal.
+input: Rank `k` tensor where `k >= 2`.
 diagonal: The extracted diagonal(s) having shape
-  `diagonal.shape = input.shape[:-1]`.
+  `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`.
 )doc");
 
 // --------------------------------------------------------------------------
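And the mirror-image sketch for extraction (same TF 1.x session assumptions as above):

```python
import numpy as np
import tensorflow as tf

mat = np.array([[1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0]])  # 2 x 3 input
with tf.Session() as sess:
    diag = sess.run(tf.matrix_diag_part(mat))
print(diag)  # [1. 5.], shape (2,) == (min(2, 3),)
```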
@@ -668,9 +674,10 @@ tensor with the same shape where
 
 `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
 
-The indicator function 'in_band(m, n)` is one if
-`(num_lower < 0 || (m-n) <= num_lower)) &&
-(num_upper < 0 || (n-m) <= num_upper)`, and zero otherwise.
+The indicator function
+
+`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) &&
+                 (num_upper < 0 || (n-m) <= num_upper)`.
 
 For example:
 
@@ -681,14 +688,14 @@ For example:
                  [-3, -2, -1, 0]],
 
 tf.matrix_band_part(input, 1, -1) ==> [[ 0,  1,  2, 3]
-                                        [-1,  0,  1, 2]
-                                        [ 0, -1,  0, 1]
-                                        [ 0,  0, -1, 0]],
+                                       [-1,  0,  1, 2]
+                                       [ 0, -1,  0, 1]
+                                       [ 0,  0, -1, 0]],
 
 tf.matrix_band_part(input, 2, 1) ==> [[ 0,  1,  0, 0]
-                                       [-1,  0,  1, 0]
-                                       [-2, -1,  0, 1]
-                                       [ 0, -2, -1, 0]]
+                                      [-1,  0,  1, 0]
+                                      [-2, -1,  0, 1]
+                                      [ 0, -2, -1, 0]]
 ```
 
 Useful special cases:
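The corrected indicator can be cross-checked against NumPy; a sketch with a hypothetical `band_part` helper (not the kernel itself):

```python
import numpy as np

def band_part(x, num_lower, num_upper):
    # Zero out entries outside the band, per the in_band(m, n) definition.
    m, n = np.indices(x.shape[-2:])
    in_band = ((num_lower < 0) | (m - n <= num_lower)) & \
              ((num_upper < 0) | (n - m <= num_upper))
    return x * in_band

x = np.arange(16).reshape(4, 4)
print(band_part(x, 1, -1))  # keeps one subdiagonal and every superdiagonal
```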
@@ -185,7 +185,8 @@ TEST(ArrayOpsTest, MatrixDiagPart_ShapeFn) {
   INFER_OK(op, "?", "?");
   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?]");
   INFER_OK(op, "[?,1,2,2]", "[d0_0,d0_1,d0_2|d0_3]");
-  INFER_ERROR("Dimensions must be equal, but are 3 and 2", op, "[1,2,3]");
+  INFER_OK(op, "[?,1,2,3]", "[d0_0,d0_1,d0_2]");
+  INFER_OK(op, "[?,1,3,2]", "[d0_0,d0_1,d0_3]");
 }
 
 TEST(ArrayOpsTest, Reverse_ShapeFn) {
@@ -364,19 +365,25 @@ TEST(ArrayOpsTest, MatrixSetDiag_ShapeFn) {
   // Rank checks.
   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
-  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "?;[]");
+  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[2,2];[]");
+  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,2];[2,2]");
 
-  // Output matches input, and also matches diagonal + diagonal.dim(-1).
+  // diagonal[-1] must match smallest matrix dimension.
+  INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[2,3];[3]");
+
+  // Output matches input.
   INFER_OK(op, "?;?", "?");
-  INFER_OK(op, "?;[1,2]", "[d1_0,d1_1,d1_1]");
   INFER_OK(op, "[1,2,2];?", "in0");
   INFER_OK(op, "[1,2,2];[1,2]", "in0");
+  INFER_OK(op, "[1,2,3];?", "in0");
+  INFER_OK(op, "[1,3,2];?", "in0");
   INFER_OK(op, "[1,?,2];[?,?]", "in0");
-  INFER_OK(op, "[1,?,?];[?,2]", "[d0_0,d1_1,d1_1]");
+  INFER_OK(op, "[1,?,?];[?,2]", "in0");
 
-  // Last 2 dims of input must match.
-  INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[1,2,3];?");
-
-  // Dims matches prefix of input.
-  INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]");
+  // Infer batch shape from diag when input is not fully specified.
+  INFER_OK(op, "?;[1,2]", "[d1_0,?,?]");
+  INFER_OK(op, "[?,?,3];[1,2]", "[d1_0,d0_1,d0_2]");
+  INFER_OK(op, "[?,3,?];[1,2]", "[d1_0,d0_1,d0_2]");
+  INFER_OK(op, "[?,3,2];[1,2]", "[d1_0,d0_1,d0_2]");
 }
 
 TEST(ArrayOpsTest, ExpandDims_ShapeFn) {
@@ -75,7 +75,7 @@ class MatrixDiagGpuTest(MatrixDiagTest):
 class MatrixSetDiagTest(tf.test.TestCase):
   _use_gpu = False
 
-  def testVector(self):
+  def testSquare(self):
     with self.test_session(use_gpu=self._use_gpu):
       v = np.array([1.0, 2.0, 3.0])
       mat = np.array([[0.0, 1.0, 0.0],
@@ -88,7 +88,23 @@ class MatrixSetDiagTest(tf.test.TestCase):
       self.assertEqual((3, 3), output.get_shape())
       self.assertAllEqual(mat_set_diag, output.eval())
 
-  def testBatchVector(self):
+  def testRectangular(self):
+    with self.test_session(use_gpu=self._use_gpu):
+      v = np.array([3.0, 4.0])
+      mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
+      expected = np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]])
+      output = tf.matrix_set_diag(mat, v)
+      self.assertEqual((2, 3), output.get_shape())
+      self.assertAllEqual(expected, output.eval())
+
+      v = np.array([3.0, 4.0])
+      mat = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
+      expected = np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]])
+      output = tf.matrix_set_diag(mat, v)
+      self.assertEqual((3, 2), output.get_shape())
+      self.assertAllEqual(expected, output.eval())
+
+  def testSquareBatch(self):
     with self.test_session(use_gpu=self._use_gpu):
       v_batch = np.array([[-1.0, -2.0, -3.0],
                           [-4.0, -5.0, -6.0]])
@@ -111,6 +127,25 @@ class MatrixSetDiagTest(tf.test.TestCase):
       self.assertEqual((2, 3, 3), output.get_shape())
       self.assertAllEqual(mat_set_diag_batch, output.eval())
 
+  def testRectangularBatch(self):
+    with self.test_session(use_gpu=self._use_gpu):
+      v_batch = np.array([[-1.0, -2.0],
+                          [-4.0, -5.0]])
+      mat_batch = np.array(
+          [[[1.0, 0.0, 3.0],
+            [0.0, 2.0, 0.0]],
+           [[4.0, 0.0, 4.0],
+            [0.0, 5.0, 0.0]]])
+
+      mat_set_diag_batch = np.array(
+          [[[-1.0, 0.0, 3.0],
+            [0.0, -2.0, 0.0]],
+           [[-4.0, 0.0, 4.0],
+            [0.0, -5.0, 0.0]]])
+      output = tf.matrix_set_diag(mat_batch, v_batch)
+      self.assertEqual((2, 2, 3), output.get_shape())
+      self.assertAllEqual(mat_set_diag_batch, output.eval())
+
   def testInvalidShape(self):
     with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
       tf.matrix_set_diag(0, [0])
@@ -127,11 +162,12 @@ class MatrixSetDiagTest(tf.test.TestCase):
         tf.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
 
   def testGrad(self):
-    shapes = ((3, 4, 4), (7, 4, 8, 8))
+    shapes = ((3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8))
     with self.test_session(use_gpu=self._use_gpu):
       for shape in shapes:
         x = tf.constant(np.random.rand(*shape), dtype=tf.float32)
-        x_diag = tf.constant(np.random.rand(*shape[:-1]), dtype=tf.float32)
+        diag_shape = shape[:-2] + (min(shape[-2:]),)
+        x_diag = tf.constant(np.random.rand(*diag_shape), dtype=tf.float32)
         y = tf.matrix_set_diag(x, x_diag)
         error_x = tf.test.compute_gradient_error(x, x.get_shape().as_list(),
                                                  y, y.get_shape().as_list())
@@ -164,7 +200,7 @@ class MatrixSetDiagGpuTest(MatrixSetDiagTest):
 class MatrixDiagPartTest(tf.test.TestCase):
   _use_gpu = False
 
-  def testMatrix(self):
+  def testSquare(self):
     with self.test_session(use_gpu=self._use_gpu):
       v = np.array([1.0, 2.0, 3.0])
       mat = np.diag(v)
@@ -172,7 +208,16 @@ class MatrixDiagPartTest(tf.test.TestCase):
       self.assertEqual((3,), mat_diag.get_shape())
       self.assertAllEqual(mat_diag.eval(), v)
 
-  def testBatchMatrix(self):
+  def testRectangular(self):
+    with self.test_session(use_gpu=self._use_gpu):
+      mat = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+      mat_diag = tf.matrix_diag_part(mat)
+      self.assertAllEqual(mat_diag.eval(), np.array([1.0, 5.0]))
+      mat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+      mat_diag = tf.matrix_diag_part(mat)
+      self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
+
+  def testSquareBatch(self):
     with self.test_session(use_gpu=self._use_gpu):
       v_batch = np.array([[1.0, 2.0, 3.0],
                           [4.0, 5.0, 6.0]])
@@ -188,22 +233,32 @@ class MatrixDiagPartTest(tf.test.TestCase):
       self.assertEqual((2, 3), mat_batch_diag.get_shape())
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
 
+  def testRectangularBatch(self):
+    with self.test_session(use_gpu=self._use_gpu):
+      v_batch = np.array([[1.0, 2.0],
+                          [4.0, 5.0]])
+      mat_batch = np.array(
+          [[[1.0, 0.0, 0.0],
+            [0.0, 2.0, 0.0]],
+           [[4.0, 0.0, 0.0],
+            [0.0, 5.0, 0.0]]])
+      self.assertEqual(mat_batch.shape, (2, 2, 3))
+      mat_batch_diag = tf.matrix_diag_part(mat_batch)
+      self.assertEqual((2, 2), mat_batch_diag.get_shape())
+      self.assertAllEqual(mat_batch_diag.eval(), v_batch)
+
   def testInvalidShape(self):
     with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
       tf.matrix_diag_part(0)
-    with self.assertRaisesRegexp(ValueError, r"Dimensions must be equal"):
-      tf.matrix_diag_part([[0, 1], [1, 0], [0, 0]])
 
   def testInvalidShapeAtEval(self):
     with self.test_session(use_gpu=self._use_gpu):
       v = tf.placeholder(dtype=tf.float32)
       with self.assertRaisesOpError("input must be at least 2-dim"):
         tf.matrix_diag_part(v).eval(feed_dict={v: 0.0})
-      with self.assertRaisesOpError("last two dimensions must be equal"):
-        tf.matrix_diag_part(v).eval(feed_dict={v: [[0, 1], [1, 0], [0, 0]]})
 
   def testGrad(self):
-    shapes = ((3, 3), (5, 3, 3))
+    shapes = ((3, 3), (2, 3), (3, 2), (5, 3, 3))
     with self.test_session(use_gpu=self._use_gpu):
       for shape in shapes:
         x = tf.constant(np.random.rand(*shape), dtype=np.float32)
@@ -17,54 +17,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy
+import numpy as np
 import tensorflow as tf
 
 
 class TraceTest(tf.test.TestCase):
 
   def setUp(self):
-    x = numpy.random.seed(0)
+    x = np.random.seed(0)
 
-  def traceOp(self, x, dtype, expected_ans, use_gpu=False):
-    with self.test_session(use_gpu=use_gpu):
-      tf_ans = tf.trace(x.astype(dtype))
-      out = tf_ans.eval()
-    self.assertAllClose(out, expected_ans)
+  def compare(self, x):
+    np_ans = np.trace(x, axis1=-2, axis2=-1)
+    with self.test_session(use_gpu=True):
+      tf_ans = tf.trace(x).eval()
+    self.assertAllClose(tf_ans, np_ans)
 
-  def testEmptyTensor(self):
-    x = numpy.array([])
-    self.assertRaises(ValueError, self.traceOp, x, numpy.float32, 0)
-
-  def testRankOneTensor(self):
-    x = numpy.array([1,2,3])
-    self.assertRaises(ValueError, self.traceOp, x, numpy.float32, 0)
-
-  def testRankTwoIntTensor(self):
-    x = numpy.array(
-        [[1, 0, 0],
-         [0, 2, 0],
-         [0, 0, 3]])
-    expected_ans = 6
-    self.traceOp(x, numpy.int32, expected_ans)
-    self.traceOp(x, numpy.int64, expected_ans)
-
-  def testRankTwoFloatTensor(self):
-    x = numpy.array(
-        [[1.1, 0, 0],
-         [0, 2.2, 0],
-         [0, 0, 3.3]])
-    expected_ans = 6.6
-    self.traceOp(x, numpy.float32, expected_ans)
-    self.traceOp(x, numpy.float64, expected_ans)
-
-  def testRankThreeFloatTensor(self):
-    x = numpy.random.rand(2, 2, 2)
-    self.assertRaises(ValueError, self.traceOp, x, numpy.float32, 0)
-
-  def testRankFourFloatTensor(self):
-    x = numpy.random.rand(2, 2, 2, 2)
-    self.assertRaises(ValueError, self.traceOp, x, numpy.float32, 0)
+  def testTrace(self):
+    for dtype in [np.int32, np.float32, np.float64]:
+      for shape in [[2, 2], [2, 3], [3, 2], [2, 3, 2], [2, 2, 2, 3]]:
+        x = np.random.rand(np.prod(shape)).astype(dtype).reshape(shape)
+        self.compare(x)
 
 
 if __name__ == "__main__":
@@ -233,20 +233,30 @@ def _MatrixDiagGrad(_, grad):
 
 
 @ops.RegisterGradient("MatrixDiagPart")
-def _MatrixDiagPartGrad(_, grad):
-  return array_ops.matrix_diag(grad)
+def _MatrixDiagPartGrad(op, grad):
+  matrix_shape = op.inputs[0].get_shape()[-2:]
+  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
+    return array_ops.matrix_diag(grad)
+  else:
+    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)
 
 
 @ops.RegisterGradient("MatrixSetDiag")
 def _MatrixSetDiagGrad(op, grad):
   input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
   diag_shape = op.inputs[1].get_shape()
-  diag_shape = diag_shape.merge_with(op.inputs[0].get_shape()[:-1])
-  diag_shape = diag_shape.merge_with(grad.get_shape()[:-1])
-  if diag_shape.is_fully_defined():
-    diag_shape = diag_shape.as_list()
+  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
+  matrix_shape = input_shape[-2:]
+  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
+    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
   else:
-    diag_shape = array_ops.shape(grad)
-    diag_shape = array_ops.slice(diag_shape, [0], [array_ops.rank(grad) - 1])
+    with ops.colocate_with(grad):
+      grad_shape = array_ops.shape(grad)
+      grad_rank = array_ops.rank(grad)
+      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
+      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
+      min_dim = math_ops.reduce_min(matrix_shape)
+      diag_shape = array_ops.concat(0, [batch_shape, [min_dim]])
   grad_input = array_ops.matrix_set_diag(
       grad, array_ops.zeros(
           diag_shape, dtype=grad.dtype))
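A sketch of the gradient shapes the rewritten registrations produce for a rectangular input, assuming the TF 1.x `tf.gradients` API:

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(3, 4, 3), dtype=tf.float32)  # rectangular batch
diag = tf.constant(np.random.rand(3, 3), dtype=tf.float32)  # min(4, 3) == 3
y = tf.matrix_set_diag(x, diag)
grad_x, grad_diag = tf.gradients(y, [x, diag])
print(grad_x.get_shape(), grad_diag.get_shape())  # (3, 4, 3) (3, 3)
```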
@@ -1299,23 +1299,35 @@ def reduce_logsumexp(input_tensor, reduction_indices=None, keep_dims=False,
 def trace(x, name=None):
   """ Compute the trace of a tensor `x`.
 
-  `trace(x)` returns the sum of along the diagonal.
+  `trace(x)` returns the sum along the main diagonal of each inner-most matrix
+  in x. If x is of rank `k` with shape `[I, J, K, ..., L, M, N]`, then output
+  is a tensor of rank `k-2` with dimensions `[I, J, K, ..., L]` where
+
+  `output[i, j, k, ..., l] = trace(x[i, j, k, ..., l, :, :])`
 
   For example:
 
   ```python
-  # 'x' is [[1, 1],
-  #         [1, 1]]
-  tf.trace(x) ==> 2
+  # 'x' is [[1, 2],
+  #         [3, 4]]
+  tf.trace(x) ==> 5
 
   # 'x' is [[1,2,3],
   #         [4,5,6],
   #         [7,8,9]]
   tf.trace(x) ==> 15
+
+  # 'x' is [[[1,2,3],
+  #          [4,5,6],
+  #          [7,8,9]],
+  #         [[-1,-2,-3],
+  #          [-4,-5,-6],
+  #          [-7,-8,-9]]]
+  tf.trace(x) ==> [15,-15]
   ```
 
   Args:
-    x: 2-D tensor.
+    x: tensor.
     name: A name for the operation (optional).
 
   Returns:
@@ -1323,10 +1335,7 @@ def trace(x, name=None):
   """
   with ops.name_scope(name, "Trace", [x]) as name:
     x = ops.convert_to_tensor(x, name="x")
-    if len(x.get_shape()) != 2:
-      raise ValueError("Expected a tensor with rank 2, rank %d tensor received"
-                       % len(x.get_shape()))
-    return reduce_sum(array_ops.diag_part(x), name=name)
+    return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
 
 
 def matmul(a, b,
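Finally, a sketch of the numpy.trace equivalence the new implementation targets (same TF 1.x session assumptions as above):

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(2, 2, 4, 3)  # rank 4, rectangular inner-most matrices
with tf.Session() as sess:
    tf_ans = sess.run(tf.trace(x))
np.testing.assert_allclose(tf_ans, np.trace(x, axis1=-2, axis2=-1))
print(tf_ans.shape)  # (2, 2): one trace per inner-most matrix
```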