STT-tensorflow/tensorflow/python/kernel_tests/diag_op_test.py
Srinivas Vasudevan f75c37faf3 Remove forward_compatibility / forward_compatible checks for dates that have already passed.
- Also, removed print statements from relu_op_test.py

PiperOrigin-RevId: 287911742
Change-Id: Ib1763a5a010e5738e4d93e348391839e1e164108
2020-01-02 16:14:15 -08:00

1200 lines
48 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
default_v2_alignment = "LEFT_LEFT"
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]
def zip_to_first_list_length(a, b):
if len(b) > len(a):
return zip(a, b[:len(a)])
return zip(a, b + [None] * (len(a) - len(b)))
def repack_diagonals(packed_diagonals,
diag_index,
num_rows,
num_cols,
align=None):
# The original test cases are LEFT_LEFT aligned.
if align == default_v2_alignment or align is None:
return packed_diagonals
align = align.split("_")
d_lower, d_upper = diag_index
batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1)
max_diag_len = packed_diagonals.shape[-1]
index = (slice(None),) * batch_dims
repacked_diagonals = np.zeros_like(packed_diagonals)
# Aligns each diagonal row-by-row.
for diag_index in range(d_lower, d_upper + 1):
diag_len = min(num_rows + min(0, diag_index), num_cols - max(0, diag_index))
row_index = d_upper - diag_index
padding_len = max_diag_len - diag_len
left_align = (diag_index >= 0 and
align[0] == "LEFT") or (diag_index <= 0 and
align[1] == "LEFT")
# Prepares index tuples.
extra_dim = tuple() if d_lower == d_upper else (row_index,)
packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),)
repacked_last_dim = (slice(None),) if left_align else (slice(
padding_len, max_diag_len, 1),)
packed_index = index + extra_dim + packed_last_dim
repacked_index = index + extra_dim + repacked_last_dim
# Repacks the diagonal.
repacked_diagonals[repacked_index] = packed_diagonals[packed_index]
return repacked_diagonals
def repack_diagonals_in_tests(tests, align=None):
# The original test cases are LEFT_LEFT aligned.
if align == default_v2_alignment or align is None:
return tests
new_tests = dict()
# Loops through each case.
for diag_index, (packed_diagonals, padded_diagonals) in tests.items():
num_rows, num_cols = padded_diagonals.shape[-2:]
repacked_diagonals = repack_diagonals(
packed_diagonals, diag_index, num_rows, num_cols, align=align)
new_tests[diag_index] = (repacked_diagonals, padded_diagonals)
return new_tests
# Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2.
def square_cases(align=None):
# pyformat: disable
mat = np.array([[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 1],
[3, 4, 5, 6, 7],
[8, 9, 1, 2, 3],
[4, 5, 6, 7, 8]],
[[9, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 1],
[2, 3, 4, 5, 6]]])
tests = dict()
# tests[d_lower, d_upper] = (packed_diagonals, padded_diagonals)
tests[-1, -1] = (np.array([[6, 4, 1, 7],
[5, 2, 8, 5]]),
np.array([[[0, 0, 0, 0, 0],
[6, 0, 0, 0, 0],
[0, 4, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 7, 0]],
[[0, 0, 0, 0, 0],
[5, 0, 0, 0, 0],
[0, 2, 0, 0, 0],
[0, 0, 8, 0, 0],
[0, 0, 0, 5, 0]]]))
tests[-4, -3] = (np.array([[[8, 5],
[4, 0]],
[[6, 3],
[2, 0]]]),
np.array([[[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[8, 0, 0, 0, 0],
[4, 5, 0, 0, 0]],
[[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[6, 0, 0, 0, 0],
[2, 3, 0, 0, 0]]]))
tests[-2, 1] = (np.array([[[2, 8, 6, 3, 0],
[1, 7, 5, 2, 8],
[6, 4, 1, 7, 0],
[3, 9, 6, 0, 0]],
[[1, 7, 4, 1, 0],
[9, 6, 3, 9, 6],
[5, 2, 8, 5, 0],
[1, 7, 4, 0, 0]]]),
np.array([[[1, 2, 0, 0, 0],
[6, 7, 8, 0, 0],
[3, 4, 5, 6, 0],
[0, 9, 1, 2, 3],
[0, 0, 6, 7, 8]],
[[9, 1, 0, 0, 0],
[5, 6, 7, 0, 0],
[1, 2, 3, 4, 0],
[0, 7, 8, 9, 1],
[0, 0, 4, 5, 6]]]))
tests[2, 4] = (np.array([[[5, 0, 0],
[4, 1, 0],
[3, 9, 7]],
[[4, 0, 0],
[3, 9, 0],
[2, 8, 5]]]),
np.array([[[0, 0, 3, 4, 5],
[0, 0, 0, 9, 1],
[0, 0, 0, 0, 7],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]],
[[0, 0, 2, 3, 4],
[0, 0, 0, 8, 9],
[0, 0, 0, 0, 5],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]))
# pyformat: enable
return (mat, repack_diagonals_in_tests(tests, align))
def tall_cases(align=None):
# pyformat: disable
mat = np.array([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[9, 8, 7],
[6, 5, 4]],
[[3, 2, 1],
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[9, 8, 7]]])
tests = dict()
tests[0, 0] = (np.array([[1, 5, 9],
[3, 2, 6]]),
np.array([[[1, 0, 0],
[0, 5, 0],
[0, 0, 9],
[0, 0, 0]],
[[3, 0, 0],
[0, 2, 0],
[0, 0, 6],
[0, 0, 0]]]))
tests[-4, -3] = (np.array([[[9, 5],
[6, 0]],
[[7, 8],
[9, 0]]]),
np.array([[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[9, 0, 0],
[6, 5, 0]],
[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[7, 0, 0],
[9, 8, 0]]]))
tests[-2, -1] = (np.array([[[4, 8, 7],
[7, 8, 4]],
[[1, 5, 9],
[4, 8, 7]]]),
np.array([[[0, 0, 0],
[4, 0, 0],
[7, 8, 0],
[0, 8, 7],
[0, 0, 4]],
[[0, 0, 0],
[1, 0, 0],
[4, 5, 0],
[0, 8, 9],
[0, 0, 7]]]))
tests[-2, 1] = (np.array([[[2, 6, 0],
[1, 5, 9],
[4, 8, 7],
[7, 8, 4]],
[[2, 3, 0],
[3, 2, 6],
[1, 5, 9],
[4, 8, 7]]]),
np.array([[[1, 2, 0],
[4, 5, 6],
[7, 8, 9],
[0, 8, 7],
[0, 0, 4]],
[[3, 2, 0],
[1, 2, 3],
[4, 5, 6],
[0, 8, 9],
[0, 0, 7]]]))
tests[1, 2] = (np.array([[[3, 0],
[2, 6]],
[[1, 0],
[2, 3]]]),
np.array([[[0, 2, 3],
[0, 0, 6],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
[[0, 2, 1],
[0, 0, 3],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]]))
# pyformat: enable
return (mat, repack_diagonals_in_tests(tests, align))
def fat_cases(align=None):
# pyformat: disable
mat = np.array([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 1, 2, 3]],
[[4, 5, 6, 7],
[8, 9, 1, 2],
[3, 4, 5, 6]]])
tests = dict()
tests[2, 2] = (np.array([[3, 8],
[6, 2]]),
np.array([[[0, 0, 3, 0],
[0, 0, 0, 8],
[0, 0, 0, 0]],
[[0, 0, 6, 0],
[0, 0, 0, 2],
[0, 0, 0, 0]]]))
tests[-2, 0] = (np.array([[[1, 6, 2],
[5, 1, 0],
[9, 0, 0]],
[[4, 9, 5],
[8, 4, 0],
[3, 0, 0]]]),
np.array([[[1, 0, 0, 0],
[5, 6, 0, 0],
[9, 1, 2, 0]],
[[4, 0, 0, 0],
[8, 9, 0, 0],
[3, 4, 5, 0]]]))
tests[-1, 1] = (np.array([[[2, 7, 3],
[1, 6, 2],
[5, 1, 0]],
[[5, 1, 6],
[4, 9, 5],
[8, 4, 0]]]),
np.array([[[1, 2, 0, 0],
[5, 6, 7, 0],
[0, 1, 2, 3]],
[[4, 5, 0, 0],
[8, 9, 1, 0],
[0, 4, 5, 6]]]))
tests[0, 3] = (np.array([[[4, 0, 0],
[3, 8, 0],
[2, 7, 3],
[1, 6, 2]],
[[7, 0, 0],
[6, 2, 0],
[5, 1, 6],
[4, 9, 5]]]),
np.array([[[1, 2, 3, 4],
[0, 6, 7, 8],
[0, 0, 2, 3]],
[[4, 5, 6, 7],
[0, 9, 1, 2],
[0, 0, 5, 6]]]))
# pyformat: enable
return (mat, repack_diagonals_in_tests(tests, align))
def all_tests(align=None):
return [square_cases(align), tall_cases(align), fat_cases(align)]
class MatrixDiagTest(test.TestCase):
def _moreCases(self, align=None):
# Diagonal bands.
# pyformat: disable
vecs = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4)
[5, 6, 7, 8],
[9, 8, 7, 6]],
[[5, 4, 3, 2],
[1, 2, 3, 4],
[5, 6, 7, 8]]])
tests = dict()
tests[-3, -1] = (vecs,
np.array([[[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[5, 2, 0, 0, 0],
[9, 6, 3, 0, 0],
[0, 8, 7, 4, 0]],
[[0, 0, 0, 0, 0],
[5, 0, 0, 0, 0],
[1, 4, 0, 0, 0],
[5, 2, 3, 0, 0],
[0, 6, 3, 2, 0]]]))
tests[-1, 1] = (vecs,
np.array([[[5, 1, 0, 0],
[9, 6, 2, 0],
[0, 8, 7, 3],
[0, 0, 7, 8]],
[[1, 5, 0, 0],
[5, 2, 4, 0],
[0, 6, 3, 3],
[0, 0, 7, 4]]]))
tests[2, 4] = (vecs,
np.array([[[0, 0, 9, 5, 1, 0],
[0, 0, 0, 8, 6, 2],
[0, 0, 0, 0, 7, 7],
[0, 0, 0, 0, 0, 6],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[0, 0, 5, 1, 5, 0],
[0, 0, 0, 6, 2, 4],
[0, 0, 0, 0, 7, 3],
[0, 0, 0, 0, 0, 8],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]]))
# pyformat: enable
return (None, repack_diagonals_in_tests(tests, align))
@test_util.run_deprecated_v1
def testVector(self):
with self.session(use_gpu=True):
v = np.array([1.0, 2.0, 3.0])
mat = np.diag(v)
v_diag = array_ops.matrix_diag(v)
self.assertEqual((3, 3), v_diag.get_shape())
self.assertAllEqual(v_diag.eval(), mat)
# {Sub,Super}diagonals.
for offset in [1, -2, 5]:
mat = np.diag(v, offset)
v_diag = array_ops.matrix_diag(v, k=offset)
self.assertEqual(mat.shape, v_diag.get_shape())
self.assertAllEqual(v_diag.eval(), mat)
# Diagonal bands.
for align in alignment_list:
for _, tests in [self._moreCases(align), square_cases(align)]:
for diags, (vecs, solution) in tests.items():
v_diags = array_ops.matrix_diag(vecs[0], k=diags, align=align)
self.assertEqual(v_diags.get_shape(), solution[0].shape)
self.assertAllEqual(v_diags.eval(), solution[0])
def _testVectorBatch(self, dtype):
with self.cached_session(use_gpu=True):
v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
[[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
[0.0, 0.0, 6.0]]]).astype(dtype)
v_batch_diag = array_ops.matrix_diag(v_batch)
self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
# {Sub,Super}diagonals.
for offset in [1, -2, 5]:
v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
mats = [
np.diag(v_batch[i], offset) for i in range(0, v_batch.shape[0])
]
mat_batch = np.stack(mats, axis=0)
self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
# Diagonal bands with padding_value.
for padding_value, align in zip_to_first_list_length([0, 555, -11],
alignment_list):
for _, tests in [self._moreCases(align), square_cases(align)]:
for diags, (vecs, solution) in tests.items():
v_diags = array_ops.matrix_diag(
vecs.astype(dtype),
k=diags,
padding_value=padding_value,
align=align)
mask = solution == 0
solution = (solution + padding_value * mask).astype(dtype)
self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution)
@test_util.run_deprecated_v1
def testVectorBatch(self):
self._testVectorBatch(np.float32)
self._testVectorBatch(np.float64)
self._testVectorBatch(np.int32)
self._testVectorBatch(np.int64)
self._testVectorBatch(np.bool)
@test_util.run_deprecated_v1
def testRectangularBatch(self):
with self.cached_session(use_gpu=True):
# Stores expected num_rows and num_cols (when the other is given).
# expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
test_list = list()
# Square cases:
expected = {
(-1, -1): (5, 4),
(-4, -3): (5, 2),
(-2, 1): (5, 5),
(2, 4): (3, 5),
}
# Do not change alignment yet. Re-alignment needs to happen after the
# solution shape is updated.
test_list.append((expected, square_cases()))
# More cases:
expected = {(-3, -1): (5, 4), (-1, 1): (4, 4), (2, 4): (4, 6)}
test_list.append((expected, self._moreCases()))
# Tall cases
expected = {
(0, 0): (3, 3),
(-4, -3): (5, 2),
(-2, -1): (4, 3),
(-2, 1): (3, 3),
(1, 2): (2, 3)
}
test_list.append((expected, tall_cases()))
# Fat cases
expected = {
(2, 2): (2, 4),
(-2, 0): (3, 3),
(-1, 1): (3, 3),
(0, 3): (3, 3)
}
test_list.append((expected, fat_cases()))
for padding_value, align in zip_to_first_list_length([0, 555, -11],
alignment_list):
# Giving both num_rows and num_cols
for _, tests in [tall_cases(align), fat_cases(align)]:
for diags, (vecs, solution) in tests.items():
v_diags = array_ops.matrix_diag(
vecs,
k=diags,
num_rows=solution.shape[-2],
num_cols=solution.shape[-1],
padding_value=padding_value,
align=align)
mask = solution == 0
solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution)
# Giving just num_rows.
for expected, (_, tests) in test_list:
for diags, (_, new_num_cols) in expected.items():
vecs, solution = tests[diags]
solution = solution.take(indices=range(new_num_cols), axis=-1)
# Repacks the diagonal input according to the new solution shape.
vecs = repack_diagonals(
vecs, diags, solution.shape[-2], new_num_cols, align=align)
v_diags = array_ops.matrix_diag(
vecs,
k=diags,
num_rows=solution.shape[-2],
padding_value=padding_value,
align=align)
mask = solution == 0
solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution)
# Giving just num_cols.
for expected, (_, tests) in test_list:
for diags, (new_num_rows, _) in expected.items():
vecs, solution = tests[diags]
solution = solution.take(indices=range(new_num_rows), axis=-2)
# Repacks the diagonal input according to the new solution shape.
vecs = repack_diagonals(
vecs, diags, new_num_rows, solution.shape[-1], align=align)
v_diags = array_ops.matrix_diag(
vecs,
k=diags,
num_cols=solution.shape[-1],
padding_value=padding_value,
align=align)
mask = solution == 0
solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution)
@test_util.run_deprecated_v1
def testInvalidShape(self):
with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
array_ops.matrix_diag(0)
@test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testInvalidShapeAtEval(self):
with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32)
with self.assertRaisesOpError("diagonal must be at least 1-dim"):
array_ops.matrix_diag(v).eval(feed_dict={v: 0.0})
@test_util.run_deprecated_v1
def testGrad(self):
shapes = ((3,), (7, 4))
with self.session(use_gpu=True):
for shape in shapes:
x = constant_op.constant(np.random.rand(*shape), np.float32)
y = array_ops.matrix_diag(x)
error = gradient_checker.compute_gradient_error(x,
x.get_shape().as_list(),
y,
y.get_shape().as_list())
self.assertLess(error, 1e-4)
# {Sub,super}diagonals/band.
tests = dict() # tests[shape] = (d_lower, d_upper)
tests[(3,)] = (-1, -1)
tests[(7, 3, 4)] = (-1, 1)
with self.session(use_gpu=True):
for shape, diags in tests.items():
x = constant_op.constant(np.random.rand(*shape), np.float32)
for align in alignment_list:
y = array_ops.matrix_diag(x, k=diags, align=align)
error = gradient_checker.compute_gradient_error(
x,
x.get_shape().as_list(), y,
y.get_shape().as_list())
self.assertLess(error, 1e-4)
class MatrixSetDiagTest(test.TestCase):
@test_util.run_deprecated_v1
def testSquare(self):
with self.session(use_gpu=True):
v = np.array([1.0, 2.0, 3.0])
mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0],
[1.0, 1.0, 3.0]])
output = array_ops.matrix_set_diag(mat, v)
self.assertEqual((3, 3), output.get_shape())
self.assertAllEqual(mat_set_diag, self.evaluate(output))
# Diagonal bands.
for align in alignment_list:
_, tests = square_cases(align)
for diags, (vecs, banded_mat) in tests.items():
mask = banded_mat[0] == 0
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat[0]
output = array_ops.matrix_set_diag(
input_mat, vecs[0], k=diags, align=align)
self.assertEqual(output.get_shape(), solution.shape)
self.assertAllEqual(output.eval(), solution)
@test_util.run_deprecated_v1
def testRectangular(self):
with self.session(use_gpu=True):
v = np.array([3.0, 4.0])
mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
expected = np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]])
output = array_ops.matrix_set_diag(mat, v)
self.assertEqual((2, 3), output.get_shape())
self.assertAllEqual(expected, self.evaluate(output))
v = np.array([3.0, 4.0])
mat = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
expected = np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]])
output = array_ops.matrix_set_diag(mat, v)
self.assertEqual((3, 2), output.get_shape())
self.assertAllEqual(expected, self.evaluate(output))
# Diagonal bands.
for align in alignment_list:
for _, tests in [tall_cases(align), fat_cases(align)]:
for diags, (vecs, banded_mat) in tests.items():
mask = banded_mat[0] == 0
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat[0]
output = array_ops.matrix_set_diag(
input_mat, vecs[0], k=diags, align=align)
self.assertEqual(output.get_shape(), solution.shape)
self.assertAllEqual(output.eval(), solution)
def _testSquareBatch(self, dtype):
with self.cached_session(use_gpu=True):
v_batch = np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]]).astype(dtype)
mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0],
[2.0, 0.0, 6.0]]]).astype(dtype)
mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0],
[1.0, 0.0, -3.0]],
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0],
[2.0, 0.0, -6.0]]]).astype(dtype)
output = array_ops.matrix_set_diag(mat_batch, v_batch)
self.assertEqual((2, 3, 3), output.get_shape())
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
# Diagonal bands.
for align in alignment_list:
_, tests = square_cases(align)
for diags, (vecs, banded_mat) in tests.items():
mask = banded_mat == 0
input_mat = np.random.randint(10, size=mask.shape).astype(dtype)
solution = (input_mat * mask + banded_mat).astype(dtype)
output = array_ops.matrix_set_diag(
input_mat, vecs.astype(dtype), k=diags, align=align)
self.assertEqual(output.get_shape(), solution.shape)
self.assertAllEqual(output.eval(), solution)
@test_util.run_deprecated_v1
def testSquareBatch(self):
self._testSquareBatch(np.float32)
self._testSquareBatch(np.float64)
self._testSquareBatch(np.int32)
self._testSquareBatch(np.int64)
self._testSquareBatch(np.bool)
@test_util.run_deprecated_v1
def testRectangularBatch(self):
with self.session(use_gpu=True):
v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]])
mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]])
mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]])
output = array_ops.matrix_set_diag(mat_batch, v_batch)
self.assertEqual((2, 2, 3), output.get_shape())
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
# Diagonal bands.
for align in alignment_list:
for _, tests in [tall_cases(align), fat_cases(align)]:
for diags, pair in tests.items():
vecs, banded_mat = pair
mask = banded_mat == 0
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat
output = array_ops.matrix_set_diag(
input_mat, vecs, k=diags, align=align)
self.assertEqual(output.get_shape(), solution.shape)
self.assertAllEqual(output.eval(), solution)
@test_util.run_deprecated_v1
def testInvalidShape(self):
with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
array_ops.matrix_set_diag(0, [0])
with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
array_ops.matrix_set_diag([[0]], 0)
@test_util.run_deprecated_v1
def testInvalidShapeAtEval(self):
with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32)
with self.assertRaisesOpError("input must be at least 2-dim"):
array_ops.matrix_set_diag(v, [v]).eval(feed_dict={v: 0.0})
with self.assertRaisesOpError("diagonal must be at least 1-dim"):
array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
d = array_ops.placeholder(dtype=dtypes_lib.float32)
with self.assertRaisesOpError(
"first dimensions of diagonal don't match"):
array_ops.matrix_set_diag(v, d).eval(feed_dict={
v: np.zeros((2, 3, 3)),
d: np.ones((2, 4))
})
def _testGrad(self, input_shape, diag_shape, diags, align):
with self.session(use_gpu=True):
x = constant_op.constant(
np.random.rand(*input_shape), dtype=dtypes_lib.float32)
x_diag = constant_op.constant(
np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
error_x = gradient_checker.compute_gradient_error(x,
x.get_shape().as_list(),
y,
y.get_shape().as_list())
self.assertLess(error_x, 1e-4)
error_x_diag = gradient_checker.compute_gradient_error(
x_diag,
x_diag.get_shape().as_list(), y,
y.get_shape().as_list())
self.assertLess(error_x_diag, 1e-4)
@test_util.run_deprecated_v1
def testGrad(self):
input_shapes = [(3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)]
diag_bands = [(0, 0)]
diag_bands.append((-1, 1))
for input_shape, diags, align in itertools.product(input_shapes, diag_bands,
alignment_list):
lower_diag_index, upper_diag_index = diags
num_diags = upper_diag_index - lower_diag_index + 1
num_diags_dim = () if num_diags == 1 else (num_diags,)
diag_shape = input_shape[:-2] + num_diags_dim + (min(input_shape[-2:]),)
self._testGrad(input_shape, diag_shape, diags, align)
@test_util.run_deprecated_v1
def testGradWithNoShapeInformation(self):
with self.session(use_gpu=True) as sess:
v = array_ops.placeholder(dtype=dtypes_lib.float32)
mat = array_ops.placeholder(dtype=dtypes_lib.float32)
grad_input = array_ops.placeholder(dtype=dtypes_lib.float32)
output = array_ops.matrix_set_diag(mat, v)
grads = gradients_impl.gradients(output, [mat, v], grad_ys=grad_input)
grad_input_val = np.random.rand(3, 3).astype(np.float32)
grad_vals = sess.run(
grads,
feed_dict={
v: 2 * np.ones(3),
mat: np.ones((3, 3)),
grad_input: grad_input_val
})
self.assertAllEqual(np.diag(grad_input_val), grad_vals[1])
self.assertAllEqual(grad_input_val - np.diag(np.diag(grad_input_val)),
grad_vals[0])
class MatrixDiagPartTest(test.TestCase):
@test_util.run_deprecated_v1
def testSquare(self):
with self.session(use_gpu=True):
v = np.array([1.0, 2.0, 3.0])
mat = np.diag(v)
mat_diag = array_ops.matrix_diag_part(mat)
self.assertEqual((3,), mat_diag.get_shape())
self.assertAllEqual(mat_diag.eval(), v)
for offset in [-2, 3]:
mat = np.diag(v, offset)
mat_diag = array_ops.matrix_diag_part(mat, k=offset)
self.assertEqual((3,), mat_diag.get_shape())
self.assertAllEqual(mat_diag.eval(), v)
# Diagonal bands.
for align in alignment_list:
mat, tests = square_cases(align)
for diags, pair in tests.items():
solution, _ = pair
mat_diag = array_ops.matrix_diag_part(mat[0], k=diags, align=align)
self.assertEqual(mat_diag.get_shape(), solution[0].shape)
self.assertAllEqual(mat_diag.eval(), solution[0])
@test_util.run_deprecated_v1
def testRectangular(self):
with self.session(use_gpu=True):
mat = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mat_diag = array_ops.matrix_diag_part(mat)
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 5.0]))
mat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
mat_diag = array_ops.matrix_diag_part(mat)
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
# Diagonal bands.
for align in alignment_list:
for mat, tests in [tall_cases(align), fat_cases(align)]:
for diags, pair in tests.items():
solution, _ = pair
mat_diag = array_ops.matrix_diag_part(
mat[0], k=diags, align=align)
self.assertEqual(mat_diag.get_shape(), solution[0].shape)
self.assertAllEqual(mat_diag.eval(), solution[0])
def _testSquareBatch(self, dtype):
with self.cached_session(use_gpu=True):
v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
[[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
[0.0, 0.0, 6.0]]]).astype(dtype)
self.assertEqual(mat_batch.shape, (2, 3, 3))
mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
self.assertEqual((2, 3), mat_batch_diag.get_shape())
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
# Diagonal bands with padding_value.
for padding_value, align in zip_to_first_list_length([0, 555, -11],
alignment_list):
mat, tests = square_cases(align)
for diags, pair in tests.items():
solution, _ = pair
mat_batch_diag = array_ops.matrix_diag_part(
mat.astype(dtype),
k=diags,
padding_value=padding_value,
align=align)
mask = solution == 0
solution = (solution + padding_value * mask).astype(dtype)
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
self.assertAllEqual(mat_batch_diag.eval(), solution)
@test_util.run_deprecated_v1
def testSquareBatch(self):
self._testSquareBatch(np.float32)
self._testSquareBatch(np.float64)
self._testSquareBatch(np.int32)
self._testSquareBatch(np.int64)
self._testSquareBatch(np.bool)
@test_util.run_deprecated_v1
def testRectangularBatch(self):
with self.session(use_gpu=True):
v_batch = np.array([[1.0, 2.0], [4.0, 5.0]])
mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
[[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]])
self.assertEqual(mat_batch.shape, (2, 2, 3))
mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
self.assertEqual((2, 2), mat_batch_diag.get_shape())
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
# Diagonal bands with padding_value and align.
for padding_value, align in zip_to_first_list_length([0, 555, -11],
alignment_list):
for mat, tests in [tall_cases(align), fat_cases(align)]:
for diags, pair in tests.items():
solution, _ = pair
mat_batch_diag = array_ops.matrix_diag_part(
mat, k=diags, padding_value=padding_value, align=align)
mask = solution == 0
solution = solution + padding_value * mask
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
self.assertAllEqual(mat_batch_diag.eval(), solution)
@test_util.run_deprecated_v1
def testUnknownShape(self):
matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
result = array_ops.matrix_diag_part(matrix, k=-1)
input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
with self.session(use_gpu=True):
result_eval = result.eval(feed_dict={matrix: input_matrix})
self.assertAllEqual([4, 8], result_eval)
@test_util.run_deprecated_v1
def testInvalidShape(self):
with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
array_ops.matrix_diag_part(0)
@test_util.run_deprecated_v1
@test_util.disable_xla("b/123337890") # Error messages differ
def testInvalidShapeAtEval(self):
with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32)
with self.assertRaisesOpError("input must be at least 2-dim"):
array_ops.matrix_diag_part(v).eval(feed_dict={v: 0.0})
@test_util.run_deprecated_v1
def testGrad(self):
shapes = ((3, 3), (2, 3), (3, 2), (5, 3, 3))
with self.session(use_gpu=True):
for shape in shapes:
x = constant_op.constant(np.random.rand(*shape), dtype=np.float32)
y = array_ops.matrix_diag_part(x)
error = gradient_checker.compute_gradient_error(x,
x.get_shape().as_list(),
y,
y.get_shape().as_list())
self.assertLess(error, 1e-4)
# {Sub,super}diagonals/band.
tests = dict() # tests[shape] = (d_lower, d_upper)
tests[(3, 3)] = (-1, -1)
tests[(7, 3, 4)] = (-1, 1)
with self.session(use_gpu=True):
for align in alignment_list:
for shape, diags in tests.items():
x = constant_op.constant(np.random.rand(*shape), np.float32)
y = array_ops.matrix_diag_part(input=x, k=diags, align=align)
error = gradient_checker.compute_gradient_error(
x,
x.get_shape().as_list(), y,
y.get_shape().as_list())
self.assertLess(error, 1e-4)
class DiagTest(test.TestCase):
def _diagOp(self, diag, dtype, expected_ans, use_gpu):
with self.cached_session(use_gpu=use_gpu):
tf_ans = array_ops.diag(ops.convert_to_tensor(diag.astype(dtype)))
out = self.evaluate(tf_ans)
tf_ans_inv = array_ops.diag_part(expected_ans)
inv_out = self.evaluate(tf_ans_inv)
self.assertAllClose(out, expected_ans)
self.assertAllClose(inv_out, diag)
self.assertShapeEqual(expected_ans, tf_ans)
self.assertShapeEqual(diag, tf_ans_inv)
def diagOp(self, diag, dtype, expected_ans):
self._diagOp(diag, dtype, expected_ans, False)
self._diagOp(diag, dtype, expected_ans, True)
def testEmptyTensor(self):
x = np.array([])
expected_ans = np.empty([0, 0])
self.diagOp(x, np.int32, expected_ans)
def testRankOneIntTensor(self):
x = np.array([1, 2, 3])
expected_ans = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
self.diagOp(x, np.int32, expected_ans)
self.diagOp(x, np.int64, expected_ans)
def testRankOneFloatTensor(self):
x = np.array([1.1, 2.2, 3.3])
expected_ans = np.array([[1.1, 0, 0], [0, 2.2, 0], [0, 0, 3.3]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
def testRankOneComplexTensor(self):
for dtype in [np.complex64, np.complex128]:
x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype)
expected_ans = np.array(
[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 2.2 + 2.2j, 0 + 0j],
[0 + 0j, 0 + 0j, 3.3 + 3.3j]],
dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankTwoIntTensor(self):
x = np.array([[1, 2, 3], [4, 5, 6]])
expected_ans = np.array([[[[1, 0, 0], [0, 0, 0]], [[0, 2, 0], [0, 0, 0]],
[[0, 0, 3], [0, 0, 0]]],
[[[0, 0, 0], [4, 0, 0]], [[0, 0, 0], [0, 5, 0]],
[[0, 0, 0], [0, 0, 6]]]])
self.diagOp(x, np.int32, expected_ans)
self.diagOp(x, np.int64, expected_ans)
def testRankTwoFloatTensor(self):
x = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
expected_ans = np.array(
[[[[1.1, 0, 0], [0, 0, 0]], [[0, 2.2, 0], [0, 0, 0]],
[[0, 0, 3.3], [0, 0, 0]]], [[[0, 0, 0], [4.4, 0, 0]],
[[0, 0, 0], [0, 5.5, 0]], [[0, 0, 0],
[0, 0, 6.6]]]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
def testRankTwoComplexTensor(self):
for dtype in [np.complex64, np.complex128]:
x = np.array(
[[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
[4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]],
dtype=dtype)
expected_ans = np.array(
[[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], [
[0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]
], [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], [[
[0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]
], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]
], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankThreeFloatTensor(self):
x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]])
expected_ans = np.array([[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
[[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
[[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
[[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
[[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
[[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
[[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
[[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
def testRankThreeComplexTensor(self):
for dtype in [np.complex64, np.complex128]:
x = np.array(
[[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
[[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
dtype=dtype)
expected_ans = np.array(
[[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
0 + 0j, 0 + 0j
]]], [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
0 + 0j, 0 + 0j
]]]], [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
0 + 0j, 0 + 0j
]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], [[0 + 0j, 0 + 0j], [
0 + 0j, 0 + 0j
]]]]], [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [
[5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]
]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 6.6 + 6.6j], [
0 + 0j, 0 + 0j
]]]], [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
7.7 + 7.7j, 0 + 0j
]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
[[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankFourNumberTensor(self):
for dtype in [np.float32, np.float64, np.int64, np.int32]:
# Input with shape [2, 1, 2, 3]
x = np.array(
[[[[1, 2, 3], [4, 5, 6]]], [[[7, 8, 9], [10, 11, 12]]]], dtype=dtype)
# Output with shape [2, 1, 2, 3, 2, 1, 2, 3]
expected_ans = np.array(
[[[[[[[[1, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
[[[0, 2, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
], [[[[0, 0, 3], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]]], [[
[[[0, 0, 0], [4, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
], [[[[0, 0, 0], [0, 5, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
[[[0, 0, 0], [0, 0, 6]]], [[[0, 0, 0], [0, 0, 0]]]
]]]], [[[[[[[0, 0, 0], [0, 0, 0]]], [[[7, 0, 0], [0, 0, 0]]]], [
[[[0, 0, 0], [0, 0, 0]]], [[[0, 8, 0], [0, 0, 0]]]
], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 9], [0, 0, 0]]]]], [[
[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [10, 0, 0]]]
], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 11, 0]]]
], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 12]]]]]]]],
dtype=dtype)
self.diagOp(x, dtype, expected_ans)
@test_util.run_deprecated_v1
def testInvalidRank(self):
with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
array_ops.diag(0.0)
class DiagPartOpTest(test.TestCase):
def setUp(self):
np.random.seed(0)
def _diagPartOp(self, tensor, dtype, expected_ans, use_gpu):
with self.cached_session(use_gpu=use_gpu):
tensor = ops.convert_to_tensor(tensor.astype(dtype))
tf_ans_inv = array_ops.diag_part(tensor)
inv_out = self.evaluate(tf_ans_inv)
self.assertAllClose(inv_out, expected_ans)
self.assertShapeEqual(expected_ans, tf_ans_inv)
def diagPartOp(self, tensor, dtype, expected_ans):
self._diagPartOp(tensor, dtype, expected_ans, False)
self._diagPartOp(tensor, dtype, expected_ans, True)
def testRankTwoFloatTensor(self):
x = np.random.rand(3, 3)
i = np.arange(3)
expected_ans = x[i, i]
self.diagPartOp(x, np.float32, expected_ans)
self.diagPartOp(x, np.float64, expected_ans)
def testRankFourFloatTensorUnknownShape(self):
x = np.random.rand(3, 3)
i = np.arange(3)
expected_ans = x[i, i]
for shape in None, (None, 3), (3, None):
with self.cached_session(use_gpu=False):
t = ops.convert_to_tensor(x.astype(np.float32))
t.set_shape(shape)
tf_ans = array_ops.diag_part(t)
out = self.evaluate(tf_ans)
self.assertAllClose(out, expected_ans)
self.assertShapeEqual(expected_ans, tf_ans)
def testRankFourFloatTensor(self):
x = np.random.rand(2, 3, 2, 3)
i = np.arange(2)[:, None]
j = np.arange(3)
expected_ans = x[i, j, i, j]
self.diagPartOp(x, np.float32, expected_ans)
self.diagPartOp(x, np.float64, expected_ans)
def testRankSixFloatTensor(self):
x = np.random.rand(2, 2, 2, 2, 2, 2)
i = np.arange(2)[:, None, None]
j = np.arange(2)[:, None]
k = np.arange(2)
expected_ans = x[i, j, k, i, j, k]
self.diagPartOp(x, np.float32, expected_ans)
self.diagPartOp(x, np.float64, expected_ans)
def testRankEightComplexTensor(self):
x = np.random.rand(2, 2, 2, 3, 2, 2, 2, 3)
i = np.arange(2)[:, None, None, None]
j = np.arange(2)[:, None, None]
k = np.arange(2)[:, None]
l = np.arange(3)
expected_ans = x[i, j, k, l, i, j, k, l]
self.diagPartOp(x, np.complex64, expected_ans)
self.diagPartOp(x, np.complex128, expected_ans)
@test_util.run_deprecated_v1
def testOddRank(self):
w = np.random.rand(2)
x = np.random.rand(2, 2, 2)
self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0)
self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0)
with self.assertRaises(ValueError):
array_ops.diag_part(0.0)
@test_util.run_deprecated_v1
def testUnevenDimensions(self):
w = np.random.rand(2, 5)
x = np.random.rand(2, 1, 2, 3)
self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0)
self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0)
class DiagGradOpTest(test.TestCase):
@test_util.run_deprecated_v1
def testDiagGrad(self):
np.random.seed(0)
shapes = ((3,), (3, 3), (3, 3, 3))
dtypes = (dtypes_lib.float32, dtypes_lib.float64)
with self.session(use_gpu=False):
errors = []
for shape in shapes:
for dtype in dtypes:
x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
y = array_ops.diag(x1)
error = gradient_checker.compute_gradient_error(
x1,
x1.get_shape().as_list(), y,
y.get_shape().as_list())
tf_logging.info("error = %f", error)
self.assertLess(error, 1e-4)
class DiagGradPartOpTest(test.TestCase):
@test_util.run_deprecated_v1
def testDiagPartGrad(self):
np.random.seed(0)
shapes = ((3, 3), (3, 3, 3, 3))
dtypes = (dtypes_lib.float32, dtypes_lib.float64)
with self.session(use_gpu=False):
errors = []
for shape in shapes:
for dtype in dtypes:
x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
y = array_ops.diag_part(x1)
error = gradient_checker.compute_gradient_error(
x1,
x1.get_shape().as_list(), y,
y.get_shape().as_list())
tf_logging.info("error = %f", error)
self.assertLess(error, 1e-4)
if __name__ == "__main__":
test.main()