Remove forward_compatibility / forward_compatible checks for dates that have already passed.
- Also, removed print statements from relu_op_test.py PiperOrigin-RevId: 287911742 Change-Id: Ib1763a5a010e5738e4d93e348391839e1e164108
This commit is contained in:
parent
4b7f4c1f09
commit
f75c37faf3
|
@ -21,19 +21,10 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
# LINT.IfChange
|
||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
||||
# LINT.ThenChange(
|
||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
||||
# //tensorflow/python/ops/array_ops.py,
|
||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
||||
# )
|
||||
|
||||
default_v2_alignment = "LEFT_LEFT"
|
||||
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"]
|
||||
|
||||
|
@ -404,25 +395,20 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
|||
|
||||
# From here onwards are v2-only tests.
|
||||
def testSquare(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for _, tests in [square_cases(align)]:
|
||||
for diag_index, (vecs, solution) in tests.items():
|
||||
params = {"diagonal": vecs[0], "k": diag_index, "align": align}
|
||||
self._assertOpOutputMatchesExpected(params, solution[0])
|
||||
for align in alignment_list:
|
||||
for _, tests in [square_cases(align)]:
|
||||
for diag_index, (vecs, solution) in tests.items():
|
||||
params = {"diagonal": vecs[0], "k": diag_index, "align": align}
|
||||
self._assertOpOutputMatchesExpected(params, solution[0])
|
||||
|
||||
def testSquareBatch(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for _, tests in [square_cases(align)]:
|
||||
for diag_index, (vecs, solution) in tests.items():
|
||||
params = {"diagonal": vecs, "k": diag_index, "align": align}
|
||||
self._assertOpOutputMatchesExpected(params, solution)
|
||||
for align in alignment_list:
|
||||
for _, tests in [square_cases(align)]:
|
||||
for diag_index, (vecs, solution) in tests.items():
|
||||
params = {"diagonal": vecs, "k": diag_index, "align": align}
|
||||
self._assertOpOutputMatchesExpected(params, solution)
|
||||
|
||||
def testRectangularBatch(self):
|
||||
if not compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
return
|
||||
|
||||
# Stores expected num_rows and num_cols (when the other is given).
|
||||
# expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
|
||||
test_list = list()
|
||||
|
@ -513,22 +499,21 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
|||
}, solution_given_num_cols)
|
||||
|
||||
def testPadding(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for padding_value, align in zip_to_first_list_length([555, -11],
|
||||
alignment_list):
|
||||
for _, tests in all_tests(align):
|
||||
for diag_index, (vecs, solution) in tests.items():
|
||||
mask = (solution == 0)
|
||||
solution = solution + (mask * padding_value)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"diagonal": vecs,
|
||||
"k": diag_index,
|
||||
"num_rows": solution.shape[-2],
|
||||
"num_cols": solution.shape[-1],
|
||||
"padding_value": padding_value,
|
||||
"align": align
|
||||
}, solution)
|
||||
for padding_value, align in zip_to_first_list_length([555, -11],
|
||||
alignment_list):
|
||||
for _, tests in all_tests(align):
|
||||
for diag_index, (vecs, solution) in tests.items():
|
||||
mask = (solution == 0)
|
||||
solution = solution + (mask * padding_value)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"diagonal": vecs,
|
||||
"k": diag_index,
|
||||
"num_rows": solution.shape[-2],
|
||||
"num_cols": solution.shape[-1],
|
||||
"padding_value": padding_value,
|
||||
"align": align
|
||||
}, solution)
|
||||
|
||||
|
||||
class MatrixSetDiagTest(xla_test.XLATestCase):
|
||||
|
@ -634,36 +619,34 @@ class MatrixSetDiagTest(xla_test.XLATestCase):
|
|||
|
||||
# From here onwards are v2-only tests.
|
||||
def testSingleMatrix(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for _, tests in all_tests(align):
|
||||
for diag_index, (vecs, banded_mat) in tests.items():
|
||||
mask = (banded_mat[0] == 0)
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat[0]
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": input_mat,
|
||||
"diagonal": vecs[0],
|
||||
"k": diag_index,
|
||||
"align": align
|
||||
}, solution)
|
||||
for align in alignment_list:
|
||||
for _, tests in all_tests(align):
|
||||
for diag_index, (vecs, banded_mat) in tests.items():
|
||||
mask = (banded_mat[0] == 0)
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat[0]
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": input_mat,
|
||||
"diagonal": vecs[0],
|
||||
"k": diag_index,
|
||||
"align": align
|
||||
}, solution)
|
||||
|
||||
def testBatch(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for _, tests in all_tests(align):
|
||||
for diag_index, (vecs, banded_mat) in tests.items():
|
||||
mask = (banded_mat == 0)
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": input_mat,
|
||||
"diagonal": vecs,
|
||||
"k": diag_index,
|
||||
"align": align
|
||||
}, solution)
|
||||
for align in alignment_list:
|
||||
for _, tests in all_tests(align):
|
||||
for diag_index, (vecs, banded_mat) in tests.items():
|
||||
mask = (banded_mat == 0)
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": input_mat,
|
||||
"diagonal": vecs,
|
||||
"k": diag_index,
|
||||
"align": align
|
||||
}, solution)
|
||||
|
||||
|
||||
class MatrixDiagPartTest(xla_test.XLATestCase):
|
||||
|
@ -705,45 +688,42 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
|
|||
|
||||
# From here onwards are v2-only tests.
|
||||
def testSingleMatrix(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
|
||||
for mat, tests in test_list:
|
||||
for diag_index, (solution, _) in tests.items():
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": mat[0],
|
||||
"k": diag_index,
|
||||
"align": align
|
||||
}, solution[0])
|
||||
for align in alignment_list:
|
||||
test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
|
||||
for mat, tests in test_list:
|
||||
for diag_index, (solution, _) in tests.items():
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": mat[0],
|
||||
"k": diag_index,
|
||||
"align": align
|
||||
}, solution[0])
|
||||
|
||||
def testBatch(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for mat, tests in all_tests(align):
|
||||
for diag_index, (solution, _) in tests.items():
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": mat,
|
||||
"k": diag_index,
|
||||
"align": align
|
||||
}, solution)
|
||||
for align in alignment_list:
|
||||
for mat, tests in all_tests(align):
|
||||
for diag_index, (solution, _) in tests.items():
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": mat,
|
||||
"k": diag_index,
|
||||
"align": align
|
||||
}, solution)
|
||||
|
||||
def testPadding(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for padding_value, align in zip_to_first_list_length([555, -11],
|
||||
alignment_list):
|
||||
for mat, tests in all_tests(align):
|
||||
for diag_index, (solution, _) in tests.items():
|
||||
mask = (solution == 0)
|
||||
solution = solution + (mask * padding_value)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": mat,
|
||||
"k": diag_index,
|
||||
"padding_value": padding_value,
|
||||
"align": align
|
||||
}, solution)
|
||||
for padding_value, align in zip_to_first_list_length([555, -11],
|
||||
alignment_list):
|
||||
for mat, tests in all_tests(align):
|
||||
for diag_index, (solution, _) in tests.items():
|
||||
mask = (solution == 0)
|
||||
solution = solution + (mask * padding_value)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
{
|
||||
"input": mat,
|
||||
"k": diag_index,
|
||||
"padding_value": padding_value,
|
||||
"align": align
|
||||
}, solution)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
|
|||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import tensorflow_server_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import distribute
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
|
@ -90,50 +89,80 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testExternalStatePolicyIgnore(self):
|
||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy.IGNORE)
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy.IGNORE)
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
|
||||
with ops.device(self._device0):
|
||||
get_next0 = self.getNext(dataset0)
|
||||
with ops.device(self._device1):
|
||||
get_next1 = self.getNext(dataset1)
|
||||
with ops.device(self._device2):
|
||||
get_next2 = self.getNext(dataset2)
|
||||
with ops.device(self._device0):
|
||||
get_next0 = self.getNext(dataset0)
|
||||
with ops.device(self._device1):
|
||||
get_next1 = self.getNext(dataset1)
|
||||
with ops.device(self._device2):
|
||||
get_next2 = self.getNext(dataset2)
|
||||
|
||||
for _ in range(100):
|
||||
self.evaluate(get_next0())
|
||||
self.evaluate(get_next1())
|
||||
self.evaluate(get_next2())
|
||||
for _ in range(100):
|
||||
self.evaluate(get_next0())
|
||||
self.evaluate(get_next1())
|
||||
self.evaluate(get_next2())
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testExternalStatePolicyWarn(self):
|
||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy.WARN)
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy.WARN)
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
|
||||
with ops.device(self._device0):
|
||||
get_next0 = self.getNext(dataset0)
|
||||
with ops.device(self._device1):
|
||||
get_next1 = self.getNext(dataset1)
|
||||
with ops.device(self._device2):
|
||||
get_next2 = self.getNext(dataset2)
|
||||
|
||||
for _ in range(100):
|
||||
self.evaluate(get_next0())
|
||||
self.evaluate(get_next1())
|
||||
self.evaluate(get_next2())
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testExternalStatePolicyFail(self):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy.FAIL)
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
|
@ -151,39 +180,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
self.evaluate(get_next1())
|
||||
self.evaluate(get_next2())
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testExternalStatePolicyFail(self):
|
||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
[],
|
||||
minval=1,
|
||||
maxval=10,
|
||||
dtype=dtypes.float32))
|
||||
opt = dataset_ops.Options()
|
||||
opt.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy.FAIL)
|
||||
dataset0 = dataset0.with_options(opt)
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
replicated_ds = distribute.replicate(dataset0,
|
||||
[self._device1, self._device2])
|
||||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
|
||||
with ops.device(self._device0):
|
||||
get_next0 = self.getNext(dataset0)
|
||||
with ops.device(self._device1):
|
||||
get_next1 = self.getNext(dataset1)
|
||||
with ops.device(self._device2):
|
||||
get_next2 = self.getNext(dataset2)
|
||||
|
||||
for _ in range(100):
|
||||
self.evaluate(get_next0())
|
||||
self.evaluate(get_next1())
|
||||
self.evaluate(get_next2())
|
||||
|
||||
|
||||
JOB_NAME = "remote_device"
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
|
@ -125,8 +124,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
|||
self._scan_func = wrapped_func
|
||||
self._scan_func.function.add_to_graph(ops.get_default_graph())
|
||||
# pylint: disable=protected-access
|
||||
if compat.forward_compatible(2019, 10,
|
||||
15) or use_default_device is not None:
|
||||
if use_default_device is not None:
|
||||
variant_tensor = gen_experimental_dataset_ops.scan_dataset(
|
||||
self._input_dataset._variant_tensor,
|
||||
structure.to_tensor_list(self._state_structure, self._initial_state),
|
||||
|
|
|
@ -25,7 +25,6 @@ import numpy as np
|
|||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat as forward_compat
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
|
@ -464,69 +463,68 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testIteratorStringHandleFuture(self):
|
||||
with forward_compat.forward_compatibility_horizon(2018, 8, 4):
|
||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
|
||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
|
||||
|
||||
iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
|
||||
iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)
|
||||
iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
|
||||
iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)
|
||||
|
||||
handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
|
||||
feedable_iterator = iterator_ops.Iterator.from_string_handle(
|
||||
handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
|
||||
dataset_ops.get_legacy_output_shapes(dataset_3))
|
||||
next_element = feedable_iterator.get_next()
|
||||
handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
|
||||
feedable_iterator = iterator_ops.Iterator.from_string_handle(
|
||||
handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
|
||||
dataset_ops.get_legacy_output_shapes(dataset_3))
|
||||
next_element = feedable_iterator.get_next()
|
||||
|
||||
self.assertTrue(
|
||||
structure.are_compatible(
|
||||
dataset_ops.get_structure(dataset_3),
|
||||
dataset_ops.get_structure(feedable_iterator)))
|
||||
self.assertTrue(
|
||||
structure.are_compatible(
|
||||
dataset_ops.get_structure(dataset_3),
|
||||
dataset_ops.get_structure(feedable_iterator)))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
iterator_3_handle = sess.run(iterator_3.string_handle())
|
||||
iterator_4_handle = sess.run(iterator_4.string_handle())
|
||||
with self.cached_session() as sess:
|
||||
iterator_3_handle = sess.run(iterator_3.string_handle())
|
||||
iterator_4_handle = sess.run(iterator_4.string_handle())
|
||||
|
||||
self.assertEqual(
|
||||
10,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
1,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
20,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
2,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
30,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
3,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
40,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.assertEqual(
|
||||
10,
|
||||
sess.run(
|
||||
next_element, feed_dict={handle_placeholder: iterator_3_handle})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
1,
|
||||
sess.run(
|
||||
next_element, feed_dict={handle_placeholder: iterator_4_handle})
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
20,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
2,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
30,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
3,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
40,
|
||||
sess.run(
|
||||
next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(
|
||||
next_element, feed_dict={handle_placeholder: iterator_3_handle})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(
|
||||
next_element, feed_dict={handle_placeholder: iterator_4_handle})
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testIteratorStringHandleReuseTensorObject(self):
|
||||
|
|
|
@ -30,7 +30,6 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
|||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.experimental.ops import optimization_options
|
||||
from tensorflow.python.data.experimental.ops import stats_options
|
||||
|
@ -223,7 +222,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||
A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
|
||||
serialized graph.
|
||||
"""
|
||||
if compat.forward_compatible(2019, 11, 25) or external_state_policy:
|
||||
if external_state_policy:
|
||||
policy = None
|
||||
if external_state_policy:
|
||||
policy = external_state_policy.value
|
||||
|
@ -231,7 +230,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||
self._variant_tensor,
|
||||
external_state_policy=policy,
|
||||
strip_device_assignment=strip_device_assignment)
|
||||
if compat.forward_compatible(2019, 11, 16) or strip_device_assignment:
|
||||
if strip_device_assignment:
|
||||
return gen_dataset_ops.dataset_to_graph(
|
||||
self._variant_tensor,
|
||||
allow_stateful=allow_stateful,
|
||||
|
|
|
@ -28,7 +28,6 @@ from tensorflow.core.protobuf import config_pb2
|
|||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -395,24 +394,23 @@ class AutoMixedPrecisionTest(test.TestCase):
|
|||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_bn(self):
|
||||
"""Test graph with convolution followed by batch norm."""
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
x = _conv_bn(x)
|
||||
output = _conv_bn(x)
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
x = _conv_bn(x)
|
||||
output = _conv_bn(x)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'Conv2D')
|
||||
self._assert_output_fp16(node_map, 'FusedBatchNormV3')
|
||||
self._assert_output_fp16(node_map, 'Conv2D_1')
|
||||
self.assertEqual(num_to_fp16,
|
||||
3) # Before Conv2D:0, Conv2D:1, Conv2D_1:1
|
||||
self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
self._assert_output_fp16(node_map, 'Conv2D')
|
||||
self._assert_output_fp16(node_map, 'FusedBatchNormV3')
|
||||
self._assert_output_fp16(node_map, 'Conv2D_1')
|
||||
self.assertEqual(num_to_fp16,
|
||||
3) # Before Conv2D:0, Conv2D:1, Conv2D_1:1
|
||||
self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
# TODO: enable these tests when cuDNN is upgraded to >= 7.6.2. Same with the
|
||||
# test_conv3d() below.
|
||||
|
@ -468,31 +466,30 @@ class AutoMixedPrecisionTest(test.TestCase):
|
|||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_bn_dropout(self):
|
||||
"""Test dropout precision of convolution batch norm graph."""
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
y = _conv_bn(x)
|
||||
y = nn.dropout(y, rate=0.5)
|
||||
y = math_ops.add(y, 1, name='addition')
|
||||
y = _conv_bn(y)
|
||||
y = array_ops.identity(y)
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(
|
||||
learning_rate=0.01)
|
||||
g = optimizer.compute_gradients(y, [x])
|
||||
output = (y, g)
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
y = _conv_bn(x)
|
||||
y = nn.dropout(y, rate=0.5)
|
||||
y = math_ops.add(y, 1, name='addition')
|
||||
y = _conv_bn(y)
|
||||
y = array_ops.identity(y)
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(
|
||||
learning_rate=0.01)
|
||||
g = optimizer.compute_gradients(y, [x])
|
||||
output = (y, g)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
self._assert_output_fp16(node_map, 'Conv2D')
|
||||
self._assert_output_fp16(node_map, 'FusedBatchNormV3')
|
||||
# We do not assert dropout's dtype because we do not want to rely on the
|
||||
# node names of dropout's internal implementation.
|
||||
self._assert_output_fp16(node_map, 'addition')
|
||||
self._assert_output_fp16(node_map, 'Conv2D_1')
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
self._assert_output_fp16(node_map, 'Conv2D')
|
||||
self._assert_output_fp16(node_map, 'FusedBatchNormV3')
|
||||
# We do not assert dropout's dtype because we do not want to rely on the
|
||||
# node names of dropout's internal implementation.
|
||||
self._assert_output_fp16(node_map, 'addition')
|
||||
self._assert_output_fp16(node_map, 'Conv2D_1')
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=2e-3)
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=2e-3)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
|
|
|
@ -22,7 +22,6 @@ import numpy as np
|
|||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
@ -143,9 +142,8 @@ def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b,
|
|||
use_static_shape):
|
||||
|
||||
def Test(self):
|
||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
||||
np.random.seed(42)
|
||||
self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
|
||||
np.random.seed(42)
|
||||
self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
|
||||
|
||||
return Test
|
||||
|
||||
|
@ -200,14 +198,13 @@ def _GetBatchMatmulGradientWithBroadcastingTest(dtype, adjoint_a, adjoint_b):
|
|||
def CheckGradients(self, a_shape, b_shape):
|
||||
self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)
|
||||
|
||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
||||
CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
|
||||
CheckGradients(self, [2, 3], [1, 3, 5])
|
||||
CheckGradients(self, [2, 3], [5, 3, 5])
|
||||
CheckGradients(self, [5, 2, 5], [5, 3])
|
||||
CheckGradients(self, [5, 2, 2, 3], [3, 5])
|
||||
CheckGradients(self, [4, 5, 1, 2, 3], [1, 1, 3, 5])
|
||||
CheckGradients(self, [1, 2, 1, 4, 2, 1, 3, 4], [3, 2, 1, 1, 1, 2, 4, 2])
|
||||
CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
|
||||
CheckGradients(self, [2, 3], [1, 3, 5])
|
||||
CheckGradients(self, [2, 3], [5, 3, 5])
|
||||
CheckGradients(self, [5, 2, 5], [5, 3])
|
||||
CheckGradients(self, [5, 2, 2, 3], [3, 5])
|
||||
CheckGradients(self, [4, 5, 1, 2, 3], [1, 1, 3, 5])
|
||||
CheckGradients(self, [1, 2, 1, 4, 2, 1, 3, 4], [3, 2, 1, 1, 1, 2, 4, 2])
|
||||
|
||||
return Test
|
||||
|
||||
|
@ -231,38 +228,37 @@ class BatchMatMulBenchmark(test.Benchmark):
|
|||
|
||||
def benchmarkBatchMatMulBroadcast(self):
|
||||
for (a_shape, b_shape) in self.shape_pairs:
|
||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
||||
with ops.Graph().as_default(), \
|
||||
session.Session(config=benchmark.benchmark_config()) as sess, \
|
||||
ops.device("/cpu:0"):
|
||||
matrix_a = variables.Variable(
|
||||
GetRandomNormalInput(a_shape, np.float32))
|
||||
matrix_b = variables.Variable(
|
||||
GetRandomNormalInput(b_shape, np.float32))
|
||||
variables.global_variables_initializer().run()
|
||||
with ops.Graph().as_default(), \
|
||||
session.Session(config=benchmark.benchmark_config()) as sess, \
|
||||
ops.device("/cpu:0"):
|
||||
matrix_a = variables.Variable(
|
||||
GetRandomNormalInput(a_shape, np.float32))
|
||||
matrix_b = variables.Variable(
|
||||
GetRandomNormalInput(b_shape, np.float32))
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
# Use batch matmul op's internal broadcasting.
|
||||
self.run_op_benchmark(
|
||||
sess,
|
||||
math_ops.matmul(matrix_a, matrix_b),
|
||||
min_iters=50,
|
||||
name="batch_matmul_cpu_{}_{}".format(a_shape, b_shape))
|
||||
# Use batch matmul op's internal broadcasting.
|
||||
self.run_op_benchmark(
|
||||
sess,
|
||||
math_ops.matmul(matrix_a, matrix_b),
|
||||
min_iters=50,
|
||||
name="batch_matmul_cpu_{}_{}".format(a_shape, b_shape))
|
||||
|
||||
# Manually broadcast the input matrices using the broadcast_to op.
|
||||
broadcasted_batch_shape = array_ops.broadcast_static_shape(
|
||||
matrix_a.shape[:-2], matrix_b.shape[:-2])
|
||||
broadcasted_a_shape = broadcasted_batch_shape.concatenate(
|
||||
matrix_a.shape[-2:])
|
||||
broadcasted_b_shape = broadcasted_batch_shape.concatenate(
|
||||
matrix_b.shape[-2:])
|
||||
self.run_op_benchmark(
|
||||
sess,
|
||||
math_ops.matmul(
|
||||
array_ops.broadcast_to(matrix_a, broadcasted_a_shape),
|
||||
array_ops.broadcast_to(matrix_b, broadcasted_b_shape)),
|
||||
min_iters=50,
|
||||
name="batch_matmul_manual_broadcast_cpu_{}_{}".format(
|
||||
a_shape, b_shape))
|
||||
# Manually broadcast the input matrices using the broadcast_to op.
|
||||
broadcasted_batch_shape = array_ops.broadcast_static_shape(
|
||||
matrix_a.shape[:-2], matrix_b.shape[:-2])
|
||||
broadcasted_a_shape = broadcasted_batch_shape.concatenate(
|
||||
matrix_a.shape[-2:])
|
||||
broadcasted_b_shape = broadcasted_batch_shape.concatenate(
|
||||
matrix_b.shape[-2:])
|
||||
self.run_op_benchmark(
|
||||
sess,
|
||||
math_ops.matmul(
|
||||
array_ops.broadcast_to(matrix_a, broadcasted_a_shape),
|
||||
array_ops.broadcast_to(matrix_b, broadcasted_b_shape)),
|
||||
min_iters=50,
|
||||
name="batch_matmul_manual_broadcast_cpu_{}_{}".format(
|
||||
a_shape, b_shape))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -21,7 +21,6 @@ import itertools
|
|||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||
from tensorflow.python.framework import ops
|
||||
|
@ -33,15 +32,6 @@ from tensorflow.python.platform import test
|
|||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
|
||||
# LINT.IfChange
|
||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
||||
# LINT.ThenChange(
|
||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
||||
# //tensorflow/python/ops/array_ops.py,
|
||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
||||
# )
|
||||
|
||||
|
||||
default_v2_alignment = "LEFT_LEFT"
|
||||
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]
|
||||
|
||||
|
@ -391,21 +381,20 @@ class MatrixDiagTest(test.TestCase):
|
|||
self.assertEqual((3, 3), v_diag.get_shape())
|
||||
self.assertAllEqual(v_diag.eval(), mat)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# {Sub,Super}diagonals.
|
||||
for offset in [1, -2, 5]:
|
||||
mat = np.diag(v, offset)
|
||||
v_diag = array_ops.matrix_diag(v, k=offset)
|
||||
self.assertEqual(mat.shape, v_diag.get_shape())
|
||||
self.assertAllEqual(v_diag.eval(), mat)
|
||||
# {Sub,Super}diagonals.
|
||||
for offset in [1, -2, 5]:
|
||||
mat = np.diag(v, offset)
|
||||
v_diag = array_ops.matrix_diag(v, k=offset)
|
||||
self.assertEqual(mat.shape, v_diag.get_shape())
|
||||
self.assertAllEqual(v_diag.eval(), mat)
|
||||
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for _, tests in [self._moreCases(align), square_cases(align)]:
|
||||
for diags, (vecs, solution) in tests.items():
|
||||
v_diags = array_ops.matrix_diag(vecs[0], k=diags, align=align)
|
||||
self.assertEqual(v_diags.get_shape(), solution[0].shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution[0])
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for _, tests in [self._moreCases(align), square_cases(align)]:
|
||||
for diags, (vecs, solution) in tests.items():
|
||||
v_diags = array_ops.matrix_diag(vecs[0], k=diags, align=align)
|
||||
self.assertEqual(v_diags.get_shape(), solution[0].shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution[0])
|
||||
|
||||
def _testVectorBatch(self, dtype):
|
||||
with self.cached_session(use_gpu=True):
|
||||
|
@ -417,31 +406,30 @@ class MatrixDiagTest(test.TestCase):
|
|||
self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
|
||||
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# {Sub,Super}diagonals.
|
||||
for offset in [1, -2, 5]:
|
||||
v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
|
||||
mats = [
|
||||
np.diag(v_batch[i], offset) for i in range(0, v_batch.shape[0])
|
||||
]
|
||||
mat_batch = np.stack(mats, axis=0)
|
||||
self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
|
||||
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
||||
# {Sub,Super}diagonals.
|
||||
for offset in [1, -2, 5]:
|
||||
v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
|
||||
mats = [
|
||||
np.diag(v_batch[i], offset) for i in range(0, v_batch.shape[0])
|
||||
]
|
||||
mat_batch = np.stack(mats, axis=0)
|
||||
self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
|
||||
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
||||
|
||||
# Diagonal bands with padding_value.
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
for _, tests in [self._moreCases(align), square_cases(align)]:
|
||||
for diags, (vecs, solution) in tests.items():
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs.astype(dtype),
|
||||
k=diags,
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = (solution + padding_value * mask).astype(dtype)
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
# Diagonal bands with padding_value.
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
for _, tests in [self._moreCases(align), square_cases(align)]:
|
||||
for diags, (vecs, solution) in tests.items():
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs.astype(dtype),
|
||||
k=diags,
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = (solution + padding_value * mask).astype(dtype)
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testVectorBatch(self):
|
||||
|
@ -453,100 +441,99 @@ class MatrixDiagTest(test.TestCase):
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRectangularBatch(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
with self.cached_session(use_gpu=True):
|
||||
# Stores expected num_rows and num_cols (when the other is given).
|
||||
# expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
|
||||
test_list = list()
|
||||
with self.cached_session(use_gpu=True):
|
||||
# Stores expected num_rows and num_cols (when the other is given).
|
||||
# expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
|
||||
test_list = list()
|
||||
|
||||
# Square cases:
|
||||
expected = {
|
||||
(-1, -1): (5, 4),
|
||||
(-4, -3): (5, 2),
|
||||
(-2, 1): (5, 5),
|
||||
(2, 4): (3, 5),
|
||||
}
|
||||
# Do not change alignment yet. Re-alignment needs to happen after the
|
||||
# solution shape is updated.
|
||||
test_list.append((expected, square_cases()))
|
||||
# Square cases:
|
||||
expected = {
|
||||
(-1, -1): (5, 4),
|
||||
(-4, -3): (5, 2),
|
||||
(-2, 1): (5, 5),
|
||||
(2, 4): (3, 5),
|
||||
}
|
||||
# Do not change alignment yet. Re-alignment needs to happen after the
|
||||
# solution shape is updated.
|
||||
test_list.append((expected, square_cases()))
|
||||
|
||||
# More cases:
|
||||
expected = {(-3, -1): (5, 4), (-1, 1): (4, 4), (2, 4): (4, 6)}
|
||||
test_list.append((expected, self._moreCases()))
|
||||
# More cases:
|
||||
expected = {(-3, -1): (5, 4), (-1, 1): (4, 4), (2, 4): (4, 6)}
|
||||
test_list.append((expected, self._moreCases()))
|
||||
|
||||
# Tall cases
|
||||
expected = {
|
||||
(0, 0): (3, 3),
|
||||
(-4, -3): (5, 2),
|
||||
(-2, -1): (4, 3),
|
||||
(-2, 1): (3, 3),
|
||||
(1, 2): (2, 3)
|
||||
}
|
||||
test_list.append((expected, tall_cases()))
|
||||
# Tall cases
|
||||
expected = {
|
||||
(0, 0): (3, 3),
|
||||
(-4, -3): (5, 2),
|
||||
(-2, -1): (4, 3),
|
||||
(-2, 1): (3, 3),
|
||||
(1, 2): (2, 3)
|
||||
}
|
||||
test_list.append((expected, tall_cases()))
|
||||
|
||||
# Fat cases
|
||||
expected = {
|
||||
(2, 2): (2, 4),
|
||||
(-2, 0): (3, 3),
|
||||
(-1, 1): (3, 3),
|
||||
(0, 3): (3, 3)
|
||||
}
|
||||
test_list.append((expected, fat_cases()))
|
||||
# Fat cases
|
||||
expected = {
|
||||
(2, 2): (2, 4),
|
||||
(-2, 0): (3, 3),
|
||||
(-1, 1): (3, 3),
|
||||
(0, 3): (3, 3)
|
||||
}
|
||||
test_list.append((expected, fat_cases()))
|
||||
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
# Giving both num_rows and num_cols
|
||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, (vecs, solution) in tests.items():
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs,
|
||||
k=diags,
|
||||
num_rows=solution.shape[-2],
|
||||
num_cols=solution.shape[-1],
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
# Giving both num_rows and num_cols
|
||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, (vecs, solution) in tests.items():
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs,
|
||||
k=diags,
|
||||
num_rows=solution.shape[-2],
|
||||
num_cols=solution.shape[-1],
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
|
||||
# Giving just num_rows.
|
||||
for expected, (_, tests) in test_list:
|
||||
for diags, (_, new_num_cols) in expected.items():
|
||||
vecs, solution = tests[diags]
|
||||
solution = solution.take(indices=range(new_num_cols), axis=-1)
|
||||
# Repacks the diagonal input according to the new solution shape.
|
||||
vecs = repack_diagonals(
|
||||
vecs, diags, solution.shape[-2], new_num_cols, align=align)
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs,
|
||||
k=diags,
|
||||
num_rows=solution.shape[-2],
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
# Giving just num_rows.
|
||||
for expected, (_, tests) in test_list:
|
||||
for diags, (_, new_num_cols) in expected.items():
|
||||
vecs, solution = tests[diags]
|
||||
solution = solution.take(indices=range(new_num_cols), axis=-1)
|
||||
# Repacks the diagonal input according to the new solution shape.
|
||||
vecs = repack_diagonals(
|
||||
vecs, diags, solution.shape[-2], new_num_cols, align=align)
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs,
|
||||
k=diags,
|
||||
num_rows=solution.shape[-2],
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
|
||||
# Giving just num_cols.
|
||||
for expected, (_, tests) in test_list:
|
||||
for diags, (new_num_rows, _) in expected.items():
|
||||
vecs, solution = tests[diags]
|
||||
solution = solution.take(indices=range(new_num_rows), axis=-2)
|
||||
# Repacks the diagonal input according to the new solution shape.
|
||||
vecs = repack_diagonals(
|
||||
vecs, diags, new_num_rows, solution.shape[-1], align=align)
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs,
|
||||
k=diags,
|
||||
num_cols=solution.shape[-1],
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
# Giving just num_cols.
|
||||
for expected, (_, tests) in test_list:
|
||||
for diags, (new_num_rows, _) in expected.items():
|
||||
vecs, solution = tests[diags]
|
||||
solution = solution.take(indices=range(new_num_rows), axis=-2)
|
||||
# Repacks the diagonal input according to the new solution shape.
|
||||
vecs = repack_diagonals(
|
||||
vecs, diags, new_num_rows, solution.shape[-1], align=align)
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs,
|
||||
k=diags,
|
||||
num_cols=solution.shape[-1],
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidShape(self):
|
||||
|
@ -574,21 +561,20 @@ class MatrixDiagTest(test.TestCase):
|
|||
y.get_shape().as_list())
|
||||
self.assertLess(error, 1e-4)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# {Sub,super}diagonals/band.
|
||||
tests = dict() # tests[shape] = (d_lower, d_upper)
|
||||
tests[(3,)] = (-1, -1)
|
||||
tests[(7, 3, 4)] = (-1, 1)
|
||||
with self.session(use_gpu=True):
|
||||
for shape, diags in tests.items():
|
||||
x = constant_op.constant(np.random.rand(*shape), np.float32)
|
||||
for align in alignment_list:
|
||||
y = array_ops.matrix_diag(x, k=diags, align=align)
|
||||
error = gradient_checker.compute_gradient_error(
|
||||
x,
|
||||
x.get_shape().as_list(), y,
|
||||
y.get_shape().as_list())
|
||||
self.assertLess(error, 1e-4)
|
||||
# {Sub,super}diagonals/band.
|
||||
tests = dict() # tests[shape] = (d_lower, d_upper)
|
||||
tests[(3,)] = (-1, -1)
|
||||
tests[(7, 3, 4)] = (-1, 1)
|
||||
with self.session(use_gpu=True):
|
||||
for shape, diags in tests.items():
|
||||
x = constant_op.constant(np.random.rand(*shape), np.float32)
|
||||
for align in alignment_list:
|
||||
y = array_ops.matrix_diag(x, k=diags, align=align)
|
||||
error = gradient_checker.compute_gradient_error(
|
||||
x,
|
||||
x.get_shape().as_list(), y,
|
||||
y.get_shape().as_list())
|
||||
self.assertLess(error, 1e-4)
|
||||
|
||||
|
||||
class MatrixSetDiagTest(test.TestCase):
|
||||
|
@ -604,18 +590,17 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((3, 3), output.get_shape())
|
||||
self.assertAllEqual(mat_set_diag, self.evaluate(output))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
_, tests = square_cases(align)
|
||||
for diags, (vecs, banded_mat) in tests.items():
|
||||
mask = banded_mat[0] == 0
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat[0]
|
||||
output = array_ops.matrix_set_diag(
|
||||
input_mat, vecs[0], k=diags, align=align)
|
||||
self.assertEqual(output.get_shape(), solution.shape)
|
||||
self.assertAllEqual(output.eval(), solution)
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
_, tests = square_cases(align)
|
||||
for diags, (vecs, banded_mat) in tests.items():
|
||||
mask = banded_mat[0] == 0
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat[0]
|
||||
output = array_ops.matrix_set_diag(
|
||||
input_mat, vecs[0], k=diags, align=align)
|
||||
self.assertEqual(output.get_shape(), solution.shape)
|
||||
self.assertAllEqual(output.eval(), solution)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRectangular(self):
|
||||
|
@ -634,18 +619,17 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((3, 2), output.get_shape())
|
||||
self.assertAllEqual(expected, self.evaluate(output))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, (vecs, banded_mat) in tests.items():
|
||||
mask = banded_mat[0] == 0
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat[0]
|
||||
output = array_ops.matrix_set_diag(
|
||||
input_mat, vecs[0], k=diags, align=align)
|
||||
self.assertEqual(output.get_shape(), solution.shape)
|
||||
self.assertAllEqual(output.eval(), solution)
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, (vecs, banded_mat) in tests.items():
|
||||
mask = banded_mat[0] == 0
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat[0]
|
||||
output = array_ops.matrix_set_diag(
|
||||
input_mat, vecs[0], k=diags, align=align)
|
||||
self.assertEqual(output.get_shape(), solution.shape)
|
||||
self.assertAllEqual(output.eval(), solution)
|
||||
|
||||
def _testSquareBatch(self, dtype):
|
||||
with self.cached_session(use_gpu=True):
|
||||
|
@ -663,18 +647,17 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((2, 3, 3), output.get_shape())
|
||||
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
_, tests = square_cases(align)
|
||||
for diags, (vecs, banded_mat) in tests.items():
|
||||
mask = banded_mat == 0
|
||||
input_mat = np.random.randint(10, size=mask.shape).astype(dtype)
|
||||
solution = (input_mat * mask + banded_mat).astype(dtype)
|
||||
output = array_ops.matrix_set_diag(
|
||||
input_mat, vecs.astype(dtype), k=diags, align=align)
|
||||
self.assertEqual(output.get_shape(), solution.shape)
|
||||
self.assertAllEqual(output.eval(), solution)
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
_, tests = square_cases(align)
|
||||
for diags, (vecs, banded_mat) in tests.items():
|
||||
mask = banded_mat == 0
|
||||
input_mat = np.random.randint(10, size=mask.shape).astype(dtype)
|
||||
solution = (input_mat * mask + banded_mat).astype(dtype)
|
||||
output = array_ops.matrix_set_diag(
|
||||
input_mat, vecs.astype(dtype), k=diags, align=align)
|
||||
self.assertEqual(output.get_shape(), solution.shape)
|
||||
self.assertAllEqual(output.eval(), solution)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSquareBatch(self):
|
||||
|
@ -697,19 +680,18 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((2, 2, 3), output.get_shape())
|
||||
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, pair in tests.items():
|
||||
vecs, banded_mat = pair
|
||||
mask = banded_mat == 0
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat
|
||||
output = array_ops.matrix_set_diag(
|
||||
input_mat, vecs, k=diags, align=align)
|
||||
self.assertEqual(output.get_shape(), solution.shape)
|
||||
self.assertAllEqual(output.eval(), solution)
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, pair in tests.items():
|
||||
vecs, banded_mat = pair
|
||||
mask = banded_mat == 0
|
||||
input_mat = np.random.randint(10, size=mask.shape)
|
||||
solution = input_mat * mask + banded_mat
|
||||
output = array_ops.matrix_set_diag(
|
||||
input_mat, vecs, k=diags, align=align)
|
||||
self.assertEqual(output.get_shape(), solution.shape)
|
||||
self.assertAllEqual(output.eval(), solution)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidShape(self):
|
||||
|
@ -727,14 +709,13 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
with self.assertRaisesOpError("diagonal must be at least 1-dim"):
|
||||
array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
d = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||
with self.assertRaisesOpError(
|
||||
"first dimensions of diagonal don't match"):
|
||||
array_ops.matrix_set_diag(v, d).eval(feed_dict={
|
||||
v: np.zeros((2, 3, 3)),
|
||||
d: np.ones((2, 4))
|
||||
})
|
||||
d = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||
with self.assertRaisesOpError(
|
||||
"first dimensions of diagonal don't match"):
|
||||
array_ops.matrix_set_diag(v, d).eval(feed_dict={
|
||||
v: np.zeros((2, 3, 3)),
|
||||
d: np.ones((2, 4))
|
||||
})
|
||||
|
||||
def _testGrad(self, input_shape, diag_shape, diags, align):
|
||||
with self.session(use_gpu=True):
|
||||
|
@ -743,10 +724,7 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
x_diag = constant_op.constant(
|
||||
np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
|
||||
else:
|
||||
y = array_ops.matrix_set_diag(x, x_diag)
|
||||
y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
|
||||
error_x = gradient_checker.compute_gradient_error(x,
|
||||
x.get_shape().as_list(),
|
||||
y,
|
||||
|
@ -763,8 +741,7 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
input_shapes = [(3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)]
|
||||
diag_bands = [(0, 0)]
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
diag_bands.append((-1, 1))
|
||||
diag_bands.append((-1, 1))
|
||||
for input_shape, diags, align in itertools.product(input_shapes, diag_bands,
|
||||
alignment_list):
|
||||
lower_diag_index, upper_diag_index = diags
|
||||
|
@ -805,21 +782,20 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
self.assertEqual((3,), mat_diag.get_shape())
|
||||
self.assertAllEqual(mat_diag.eval(), v)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for offset in [-2, 3]:
|
||||
mat = np.diag(v, offset)
|
||||
mat_diag = array_ops.matrix_diag_part(mat, k=offset)
|
||||
self.assertEqual((3,), mat_diag.get_shape())
|
||||
self.assertAllEqual(mat_diag.eval(), v)
|
||||
for offset in [-2, 3]:
|
||||
mat = np.diag(v, offset)
|
||||
mat_diag = array_ops.matrix_diag_part(mat, k=offset)
|
||||
self.assertEqual((3,), mat_diag.get_shape())
|
||||
self.assertAllEqual(mat_diag.eval(), v)
|
||||
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
mat, tests = square_cases(align)
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_diag = array_ops.matrix_diag_part(mat[0], k=diags, align=align)
|
||||
self.assertEqual(mat_diag.get_shape(), solution[0].shape)
|
||||
self.assertAllEqual(mat_diag.eval(), solution[0])
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
mat, tests = square_cases(align)
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_diag = array_ops.matrix_diag_part(mat[0], k=diags, align=align)
|
||||
self.assertEqual(mat_diag.get_shape(), solution[0].shape)
|
||||
self.assertAllEqual(mat_diag.eval(), solution[0])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRectangular(self):
|
||||
|
@ -831,16 +807,15 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
mat_diag = array_ops.matrix_diag_part(mat)
|
||||
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for mat, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_diag = array_ops.matrix_diag_part(
|
||||
mat[0], k=diags, align=align)
|
||||
self.assertEqual(mat_diag.get_shape(), solution[0].shape)
|
||||
self.assertAllEqual(mat_diag.eval(), solution[0])
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for mat, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_diag = array_ops.matrix_diag_part(
|
||||
mat[0], k=diags, align=align)
|
||||
self.assertEqual(mat_diag.get_shape(), solution[0].shape)
|
||||
self.assertAllEqual(mat_diag.eval(), solution[0])
|
||||
|
||||
def _testSquareBatch(self, dtype):
|
||||
with self.cached_session(use_gpu=True):
|
||||
|
@ -853,22 +828,21 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
self.assertEqual((2, 3), mat_batch_diag.get_shape())
|
||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands with padding_value.
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
mat, tests = square_cases(align)
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_batch_diag = array_ops.matrix_diag_part(
|
||||
mat.astype(dtype),
|
||||
k=diags,
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = (solution + padding_value * mask).astype(dtype)
|
||||
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
||||
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
||||
# Diagonal bands with padding_value.
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
mat, tests = square_cases(align)
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_batch_diag = array_ops.matrix_diag_part(
|
||||
mat.astype(dtype),
|
||||
k=diags,
|
||||
padding_value=padding_value,
|
||||
align=align)
|
||||
mask = solution == 0
|
||||
solution = (solution + padding_value * mask).astype(dtype)
|
||||
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
||||
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSquareBatch(self):
|
||||
|
@ -889,29 +863,27 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
self.assertEqual((2, 2), mat_batch_diag.get_shape())
|
||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands with padding_value and align.
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
for mat, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_batch_diag = array_ops.matrix_diag_part(
|
||||
mat, k=diags, padding_value=padding_value, align=align)
|
||||
mask = solution == 0
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
||||
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
||||
# Diagonal bands with padding_value and align.
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
for mat, tests in [tall_cases(align), fat_cases(align)]:
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_batch_diag = array_ops.matrix_diag_part(
|
||||
mat, k=diags, padding_value=padding_value, align=align)
|
||||
mask = solution == 0
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
||||
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testUnknownShape(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
|
||||
result = array_ops.matrix_diag_part(matrix, k=-1)
|
||||
input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
||||
with self.session(use_gpu=True):
|
||||
result_eval = result.eval(feed_dict={matrix: input_matrix})
|
||||
self.assertAllEqual([4, 8], result_eval)
|
||||
matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
|
||||
result = array_ops.matrix_diag_part(matrix, k=-1)
|
||||
input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
||||
with self.session(use_gpu=True):
|
||||
result_eval = result.eval(feed_dict={matrix: input_matrix})
|
||||
self.assertAllEqual([4, 8], result_eval)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidShape(self):
|
||||
|
@ -939,21 +911,20 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
y.get_shape().as_list())
|
||||
self.assertLess(error, 1e-4)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# {Sub,super}diagonals/band.
|
||||
tests = dict() # tests[shape] = (d_lower, d_upper)
|
||||
tests[(3, 3)] = (-1, -1)
|
||||
tests[(7, 3, 4)] = (-1, 1)
|
||||
with self.session(use_gpu=True):
|
||||
for align in alignment_list:
|
||||
for shape, diags in tests.items():
|
||||
x = constant_op.constant(np.random.rand(*shape), np.float32)
|
||||
y = array_ops.matrix_diag_part(input=x, k=diags, align=align)
|
||||
error = gradient_checker.compute_gradient_error(
|
||||
x,
|
||||
x.get_shape().as_list(), y,
|
||||
y.get_shape().as_list())
|
||||
self.assertLess(error, 1e-4)
|
||||
# {Sub,super}diagonals/band.
|
||||
tests = dict() # tests[shape] = (d_lower, d_upper)
|
||||
tests[(3, 3)] = (-1, -1)
|
||||
tests[(7, 3, 4)] = (-1, 1)
|
||||
with self.session(use_gpu=True):
|
||||
for align in alignment_list:
|
||||
for shape, diags in tests.items():
|
||||
x = constant_op.constant(np.random.rand(*shape), np.float32)
|
||||
y = array_ops.matrix_diag_part(input=x, k=diags, align=align)
|
||||
error = gradient_checker.compute_gradient_error(
|
||||
x,
|
||||
x.get_shape().as_list(), y,
|
||||
y.get_shape().as_list())
|
||||
self.assertLess(error, 1e-4)
|
||||
|
||||
|
||||
class DiagTest(test.TestCase):
|
||||
|
|
|
@ -22,7 +22,6 @@ import numpy as np
|
|||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -122,7 +121,6 @@ class ReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
|
||||
print("relu (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
# The gradient for fp16 is inaccurate due to the low-precision.
|
||||
|
@ -171,7 +169,6 @@ class ReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
|
||||
print("relu (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
def testGradGradFloat32(self):
|
||||
|
@ -190,7 +187,6 @@ class ReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("relu (float32) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradGradFloat64(self):
|
||||
|
@ -209,7 +205,6 @@ class ReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("relu (float64) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
def testGradientScalar(self):
|
||||
|
@ -283,7 +278,6 @@ class Relu6Test(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
|
||||
print("relu6 (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradientFloat64(self):
|
||||
|
@ -294,7 +288,6 @@ class Relu6Test(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
|
||||
print("relu6 (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
|
||||
|
@ -345,7 +338,6 @@ class LeakyReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
|
||||
print("leaky_relu (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradientFloat64(self):
|
||||
|
@ -356,48 +348,43 @@ class LeakyReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
|
||||
print("leaky_relu (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
def testGradGradFloat32(self):
|
||||
with compat.forward_compatibility_horizon(2018, 11, 2):
|
||||
with self.cached_session():
|
||||
with self.cached_session():
|
||||
|
||||
def f(x):
|
||||
assert x.dtype == dtypes.float32
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = nn_ops.leaky_relu(x)
|
||||
return tape.gradient(y, x)
|
||||
def f(x):
|
||||
assert x.dtype == dtypes.float32
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = nn_ops.leaky_relu(x)
|
||||
return tape.gradient(y, x)
|
||||
|
||||
x = np.asarray(
|
||||
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
|
||||
dtype=np.float32,
|
||||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("leaky_relu (float32) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
x = np.asarray(
|
||||
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
|
||||
dtype=np.float32,
|
||||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradGradFloat64(self):
|
||||
with compat.forward_compatibility_horizon(2018, 11, 2):
|
||||
with self.cached_session():
|
||||
with self.cached_session():
|
||||
|
||||
def f(x):
|
||||
assert x.dtype == dtypes.float64
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = nn_ops.leaky_relu(x)
|
||||
return tape.gradient(y, x)
|
||||
def f(x):
|
||||
assert x.dtype == dtypes.float64
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = nn_ops.leaky_relu(x)
|
||||
return tape.gradient(y, x)
|
||||
|
||||
x = np.asarray(
|
||||
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
|
||||
dtype=np.float64,
|
||||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("leaky_relu (float64) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
x = np.asarray(
|
||||
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
|
||||
dtype=np.float64,
|
||||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
def testGradientScalar(self):
|
||||
x = variables.Variable(-100.)
|
||||
|
@ -463,7 +450,6 @@ class EluTest(test.TestCase):
|
|||
x = np.asarray(x_val, dtype=np.float32, order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
|
||||
print("elu (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradientFloat64(self):
|
||||
|
@ -472,7 +458,6 @@ class EluTest(test.TestCase):
|
|||
x = np.asarray(x_val, dtype=np.float64, order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
|
||||
print("elu (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-6)
|
||||
|
||||
def testGradGrad(self):
|
||||
|
@ -507,7 +492,6 @@ class EluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("elu (float32) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradGradFloat64(self):
|
||||
|
@ -526,7 +510,6 @@ class EluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("elu (float64) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-6)
|
||||
|
||||
|
||||
|
@ -567,7 +550,6 @@ class SeluTest(test.TestCase):
|
|||
x = np.asarray(x_val, dtype=np.float32, order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
|
||||
print("selu (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradientFloat64(self):
|
||||
|
@ -576,7 +558,6 @@ class SeluTest(test.TestCase):
|
|||
x = np.asarray(x_val, dtype=np.float64, order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
|
||||
print("selu (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-6)
|
||||
|
||||
def testGradGradFloat32(self):
|
||||
|
@ -595,7 +576,6 @@ class SeluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("selu (float32) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradGradFloat64(self):
|
||||
|
@ -614,7 +594,6 @@ class SeluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("selu (float64) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-6)
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
|
@ -54,13 +53,6 @@ tf_export("newaxis").export_constant(__name__, "newaxis")
|
|||
# existing 'slice' for later use in this module.
|
||||
_BaseSlice = slice
|
||||
|
||||
# LINT.IfChange
|
||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
||||
# LINT.ThenChange(
|
||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
||||
# )
|
||||
|
||||
@tf_export("reshape", v1=["reshape", "manip.reshape"])
|
||||
def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
|
||||
|
@ -2362,24 +2354,19 @@ def matrix_diag(diagonal,
|
|||
Returns:
|
||||
A Tensor. Has the same type as `diagonal`.
|
||||
"""
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Special case to sidestep the tf.constant conversion error:
|
||||
# TypeError: Expected bool, got 0 of type 'int' instead.
|
||||
if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
|
||||
padding_value = bool(padding_value)
|
||||
# Special case to sidestep the tf.constant conversion error:
|
||||
# TypeError: Expected bool, got 0 of type 'int' instead.
|
||||
if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
|
||||
padding_value = bool(padding_value)
|
||||
|
||||
return gen_array_ops.matrix_diag_v3(
|
||||
diagonal=diagonal,
|
||||
k=k,
|
||||
num_rows=num_rows,
|
||||
num_cols=num_cols,
|
||||
padding_value=padding_value,
|
||||
align=align,
|
||||
name=name)
|
||||
|
||||
# Call v1 to maintain forward compatibility.
|
||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
||||
return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)
|
||||
return gen_array_ops.matrix_diag_v3(
|
||||
diagonal=diagonal,
|
||||
k=k,
|
||||
num_rows=num_rows,
|
||||
num_cols=num_cols,
|
||||
padding_value=padding_value,
|
||||
align=align,
|
||||
name=name)
|
||||
|
||||
|
||||
@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
|
||||
|
@ -2513,18 +2500,13 @@ def matrix_diag_part(
|
|||
Returns:
|
||||
A Tensor containing diagonals of `input`. Has the same type as `input`.
|
||||
"""
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Special case to sidestep the tf.constant conversion error:
|
||||
# TypeError: Expected bool, got 0 of type 'int' instead.
|
||||
if hasattr(input, "dtype") and input.dtype == "bool":
|
||||
padding_value = bool(padding_value)
|
||||
# Special case to sidestep the tf.constant conversion error:
|
||||
# TypeError: Expected bool, got 0 of type 'int' instead.
|
||||
if hasattr(input, "dtype") and input.dtype == "bool":
|
||||
padding_value = bool(padding_value)
|
||||
|
||||
return gen_array_ops.matrix_diag_part_v3(
|
||||
input=input, k=k, padding_value=padding_value, align=align, name=name)
|
||||
|
||||
# Call v1 to maintain forward compatibility.
|
||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
||||
return gen_array_ops.matrix_diag_part(input=input, name=name)
|
||||
return gen_array_ops.matrix_diag_part_v3(
|
||||
input=input, k=k, padding_value=padding_value, align=align, name=name)
|
||||
|
||||
|
||||
@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
|
||||
|
@ -2659,14 +2641,8 @@ def matrix_set_diag(
|
|||
the left (right-pads the row). It is the packing format LAPACK uses.
|
||||
cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
|
||||
"""
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
return gen_array_ops.matrix_set_diag_v3(
|
||||
input=input, diagonal=diagonal, k=k, align=align, name=name)
|
||||
|
||||
# Call v1 to maintain forward compatibility.
|
||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
||||
return gen_array_ops.matrix_set_diag(
|
||||
input=input, diagonal=diagonal, name=name)
|
||||
return gen_array_ops.matrix_set_diag_v3(
|
||||
input=input, diagonal=diagonal, k=k, align=align, name=name)
|
||||
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
@ -4921,7 +4897,7 @@ def quantize_v2(
|
|||
raise ValueError("input should have known rank to use negative axis.")
|
||||
axis %= input.shape.ndims
|
||||
|
||||
if compat.forward_compatible(2019, 11, 13) or ensure_minimum_range != 0.01:
|
||||
if ensure_minimum_range != 0.01:
|
||||
return gen_array_ops.quantize_v2(
|
||||
input,
|
||||
min_range,
|
||||
|
@ -4965,7 +4941,7 @@ def quantize(
|
|||
axis=None,
|
||||
ensure_minimum_range=0.01):
|
||||
"""Quantize the input tensor."""
|
||||
if compat.forward_compatible(2019, 11, 13) or ensure_minimum_range != 0.01:
|
||||
if ensure_minimum_range != 0.01:
|
||||
return quantize_v2(
|
||||
input,
|
||||
min_range,
|
||||
|
@ -5007,7 +4983,7 @@ def dequantize( # pylint: disable=missing-docstring
|
|||
raise ValueError("input should have known rank to use negative axis.")
|
||||
axis %= input.shape.ndims
|
||||
|
||||
if compat.forward_compatible(2019, 10, 22) or axis >= 0 or narrow_range:
|
||||
if axis >= 0 or narrow_range:
|
||||
return gen_array_ops.dequantize(
|
||||
input, min_range, max_range, mode=mode, name=name,
|
||||
narrow_range=narrow_range, axis=axis)
|
||||
|
|
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
|
@ -383,26 +382,25 @@ class BatchNormalizationTest(test.TestCase):
|
|||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
|
||||
def testInferenceShape5(self):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
use_gpu=True,
|
||||
data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
|
||||
def testTrainingShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
|
@ -450,26 +448,25 @@ class BatchNormalizationTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
|
||||
def testTrainingShape5(self):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
use_gpu=True,
|
||||
data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_training(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradShape1(self):
|
||||
|
@ -586,39 +583,38 @@ class BatchNormalizationTest(test.TestCase):
|
|||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test never passed for XLA')
|
||||
def testBatchNormGradShape5(self):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [0, 7, 11, 4]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [7],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
for is_training in [True, False]:
|
||||
x_shape = [0, 7, 11, 4]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [7],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
use_gpu=True,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [7],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
|
||||
def _testBatchNormGradGrad(self, config):
|
||||
shape = config['shape']
|
||||
|
|
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -33,15 +32,6 @@ from tensorflow.python.ops.parallel_for.test_util import PForTestCase
|
|||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# LINT.IfChange
|
||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
||||
# LINT.ThenChange(
|
||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
||||
# //tensorflow/python/ops/array_ops.py
|
||||
# )
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ArrayTest(PForTestCase):
|
||||
|
||||
|
@ -345,10 +335,8 @@ class ArrayTest(PForTestCase):
|
|||
|
||||
def loop_fn(i):
|
||||
diagonal = array_ops.gather(x, i)
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
return array_ops.matrix_diag(
|
||||
diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
|
||||
return array_ops.matrix_diag(diagonal)
|
||||
return array_ops.matrix_diag(
|
||||
diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
|
@ -357,10 +345,8 @@ class ArrayTest(PForTestCase):
|
|||
|
||||
def loop_fn(i):
|
||||
input = array_ops.gather(x, i) # pylint: disable=redefined-builtin
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
return array_ops.matrix_diag_part(
|
||||
input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
|
||||
return array_ops.matrix_diag_part(input)
|
||||
return array_ops.matrix_diag_part(
|
||||
input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
|
@ -378,17 +364,16 @@ class ArrayTest(PForTestCase):
|
|||
array_ops.matrix_set_diag(matrix_i, diags[0, ...]),
|
||||
]
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
k = (-1, 1)
|
||||
band_i = array_ops.gather(bands, i)
|
||||
for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
|
||||
results.extend([
|
||||
array_ops.matrix_set_diag(matrix_i, band_i, k=k, align=align),
|
||||
array_ops.matrix_set_diag(
|
||||
matrices[0, ...], band_i, k=k, align=align),
|
||||
array_ops.matrix_set_diag(
|
||||
matrix_i, bands[0, ...], k=k, align=align)
|
||||
])
|
||||
k = (-1, 1)
|
||||
band_i = array_ops.gather(bands, i)
|
||||
for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
|
||||
results.extend([
|
||||
array_ops.matrix_set_diag(matrix_i, band_i, k=k, align=align),
|
||||
array_ops.matrix_set_diag(
|
||||
matrices[0, ...], band_i, k=k, align=align),
|
||||
array_ops.matrix_set_diag(
|
||||
matrix_i, bands[0, ...], k=k, align=align)
|
||||
])
|
||||
return results
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
|
|
@ -28,7 +28,6 @@ import numpy as np
|
|||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
|
@ -444,55 +443,54 @@ class NNTest(PForTestCase):
|
|||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_fused_batch_norm(self):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
data_formats = ["NHWC"]
|
||||
if test.is_gpu_available():
|
||||
data_formats.append("NCHW")
|
||||
for is_training in (True, False):
|
||||
for data_format in data_formats:
|
||||
with backprop.GradientTape(persistent=True) as g:
|
||||
if data_format == "NCHW":
|
||||
x = random_ops.random_uniform([3, 1, 2, 5, 5])
|
||||
else:
|
||||
x = random_ops.random_uniform([3, 1, 5, 5, 2])
|
||||
g.watch(x)
|
||||
scale = random_ops.random_uniform([2])
|
||||
g.watch(scale)
|
||||
offset = random_ops.random_uniform([2])
|
||||
g.watch(offset)
|
||||
mean = None if is_training else random_ops.random_uniform([2])
|
||||
variance = None if is_training else random_ops.random_uniform([2])
|
||||
data_formats = ["NHWC"]
|
||||
if test.is_gpu_available():
|
||||
data_formats.append("NCHW")
|
||||
for is_training in (True, False):
|
||||
for data_format in data_formats:
|
||||
with backprop.GradientTape(persistent=True) as g:
|
||||
if data_format == "NCHW":
|
||||
x = random_ops.random_uniform([3, 1, 2, 5, 5])
|
||||
else:
|
||||
x = random_ops.random_uniform([3, 1, 5, 5, 2])
|
||||
g.watch(x)
|
||||
scale = random_ops.random_uniform([2])
|
||||
g.watch(scale)
|
||||
offset = random_ops.random_uniform([2])
|
||||
g.watch(offset)
|
||||
mean = None if is_training else random_ops.random_uniform([2])
|
||||
variance = None if is_training else random_ops.random_uniform([2])
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def loop_fn(i):
|
||||
with g:
|
||||
x1 = array_ops.gather(x, i)
|
||||
outputs = nn.fused_batch_norm(
|
||||
x1,
|
||||
scale,
|
||||
offset,
|
||||
mean=mean,
|
||||
variance=variance,
|
||||
epsilon=0.01,
|
||||
data_format=data_format,
|
||||
is_training=is_training)
|
||||
outputs = list(outputs)
|
||||
# We only test the first value of outputs when is_training is
|
||||
# False. It looks like CPU and GPU have different outputs for
|
||||
# batch_mean and batch_variance for this case.
|
||||
if not is_training:
|
||||
outputs[1] = constant_op.constant(0.)
|
||||
outputs[2] = constant_op.constant(0.)
|
||||
loss = nn.l2_loss(outputs[0])
|
||||
if is_training:
|
||||
gradients = g.gradient(loss, [x1, scale, offset])
|
||||
else:
|
||||
gradients = [constant_op.constant(0.)] * 3
|
||||
return outputs + gradients
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def loop_fn(i):
|
||||
with g:
|
||||
x1 = array_ops.gather(x, i)
|
||||
outputs = nn.fused_batch_norm(
|
||||
x1,
|
||||
scale,
|
||||
offset,
|
||||
mean=mean,
|
||||
variance=variance,
|
||||
epsilon=0.01,
|
||||
data_format=data_format,
|
||||
is_training=is_training)
|
||||
outputs = list(outputs)
|
||||
# We only test the first value of outputs when is_training is
|
||||
# False. It looks like CPU and GPU have different outputs for
|
||||
# batch_mean and batch_variance for this case.
|
||||
if not is_training:
|
||||
outputs[1] = constant_op.constant(0.)
|
||||
outputs[2] = constant_op.constant(0.)
|
||||
loss = nn.l2_loss(outputs[0])
|
||||
if is_training:
|
||||
gradients = g.gradient(loss, [x1, scale, offset])
|
||||
else:
|
||||
gradients = [constant_op.constant(0.)] * 3
|
||||
return outputs + gradients
|
||||
|
||||
# pylint: enable=cell-var-from-loop
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_log_softmax(self):
|
||||
logits = random_ops.random_uniform([3, 2, 4])
|
||||
|
|
|
@ -33,7 +33,6 @@ import six
|
|||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
|
||||
from tensorflow.python.compat import compat as fwd_compat
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
@ -253,10 +252,7 @@ def einsum(equation, *inputs, **kwargs):
|
|||
- the format of `equation` is incorrect,
|
||||
- number of inputs or their shapes are inconsistent with `equation`.
|
||||
"""
|
||||
if fwd_compat.forward_compatible(2019, 10, 18):
|
||||
return _einsum_v2(equation, *inputs, **kwargs)
|
||||
else:
|
||||
return _einsum_v1(equation, *inputs, **kwargs)
|
||||
return _einsum_v2(equation, *inputs, **kwargs)
|
||||
|
||||
|
||||
def _einsum_v1(equation, *inputs, **kwargs):
|
||||
|
|
|
@ -23,7 +23,6 @@ import opt_einsum
|
|||
import six
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -243,20 +242,19 @@ class EinsumTest(test.TestCase):
|
|||
self._check('abc->ca', (3, 4, 5))
|
||||
self._check('abc->cab', (3, 4, 5))
|
||||
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Empty cases.
|
||||
self._check('', ())
|
||||
self._check('->', ())
|
||||
# Empty cases.
|
||||
self._check('', ())
|
||||
self._check('->', ())
|
||||
|
||||
# Repeated indices cases.
|
||||
self._check('aa->', (3, 3))
|
||||
self._check('aa->a', (3, 3))
|
||||
self._check('aaa->', (3, 3, 3))
|
||||
self._check('aaa->a', (3, 3, 3))
|
||||
self._check('aab->a', (3, 3, 4))
|
||||
self._check('aabcc->a', (3, 3, 5, 4, 4))
|
||||
self._check('aabcc->ac', (3, 3, 5, 4, 4))
|
||||
self._check('aabcd->ad', (3, 3, 5, 4, 4))
|
||||
# Repeated indices cases.
|
||||
self._check('aa->', (3, 3))
|
||||
self._check('aa->a', (3, 3))
|
||||
self._check('aaa->', (3, 3, 3))
|
||||
self._check('aaa->a', (3, 3, 3))
|
||||
self._check('aab->a', (3, 3, 4))
|
||||
self._check('aabcc->a', (3, 3, 5, 4, 4))
|
||||
self._check('aabcc->ac', (3, 3, 5, 4, 4))
|
||||
self._check('aabcd->ad', (3, 3, 5, 4, 4))
|
||||
|
||||
def test_unary_ellipsis(self):
|
||||
self._check('...->', ())
|
||||
|
@ -266,17 +264,16 @@ class EinsumTest(test.TestCase):
|
|||
self._check('...ij->...ji', (5, 2, 3)) # batch matrix transpose
|
||||
self._check('...ij->...', (5, 2, 3)) # batch sum
|
||||
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check('...->...', ())
|
||||
self._check('->...', ())
|
||||
self._check('...->...', ())
|
||||
self._check('->...', ())
|
||||
|
||||
# Repeated indices.
|
||||
self._check('i...ii->...i', (3, 2, 3, 3))
|
||||
self._check('i...i->i...', (2, 2))
|
||||
self._check('i...i->', (2, 2))
|
||||
self._check('i...i->...', (2, 5, 1, 2))
|
||||
self._check('i...i->i...', (2, 1, 2))
|
||||
self._check('i...i->i...', (2, 3, 4, 5, 2))
|
||||
# Repeated indices.
|
||||
self._check('i...ii->...i', (3, 2, 3, 3))
|
||||
self._check('i...i->i...', (2, 2))
|
||||
self._check('i...i->', (2, 2))
|
||||
self._check('i...i->...', (2, 5, 1, 2))
|
||||
self._check('i...i->i...', (2, 1, 2))
|
||||
self._check('i...i->i...', (2, 3, 4, 5, 2))
|
||||
|
||||
def test_binary_simple(self):
|
||||
# Binary cases in XLA mode must have either (a) each index appearing exactly
|
||||
|
@ -301,13 +298,12 @@ class EinsumTest(test.TestCase):
|
|||
self._check('ab,ab->', (3, 4), (3, 4))
|
||||
|
||||
def test_repeated_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Repeated indices.
|
||||
self._check('ijj,k->ik', (2, 3, 3), (4,))
|
||||
self._check('aba,a->b', (3, 4, 3), (3,))
|
||||
# From https://github.com/dask/dask/pull/3412#discussion_r182413444
|
||||
self._check('aab,bc->ac', (2, 2, 3), (3, 4))
|
||||
self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
|
||||
# Repeated indices.
|
||||
self._check('ijj,k->ik', (2, 3, 3), (4,))
|
||||
self._check('aba,a->b', (3, 4, 3), (3,))
|
||||
# From https://github.com/dask/dask/pull/3412#discussion_r182413444
|
||||
self._check('aab,bc->ac', (2, 2, 3), (3, 4))
|
||||
self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
|
||||
|
||||
def test_binary_ellipsis(self):
|
||||
# Batch matmul with ellipsis but without broadcasting.
|
||||
|
@ -324,23 +320,22 @@ class EinsumTest(test.TestCase):
|
|||
self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product
|
||||
|
||||
def test_broadcasting(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Batch matmul with broadcasting.
|
||||
self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
|
||||
self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
|
||||
self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
|
||||
self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
|
||||
self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
|
||||
self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
|
||||
# Batch matmul with broadcasting.
|
||||
self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
|
||||
self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
|
||||
self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
|
||||
self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
|
||||
self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
|
||||
self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
|
||||
|
||||
# Broadcasting with repeated indices.
|
||||
self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
|
||||
self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
|
||||
self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
|
||||
self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
|
||||
# Following 2 from https://stackoverflow.com/a/19203475/1611416
|
||||
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
|
||||
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
||||
# Broadcasting with repeated indices.
|
||||
self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
|
||||
self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
|
||||
self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
|
||||
self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
|
||||
# Following 2 from https://stackoverflow.com/a/19203475/1611416
|
||||
self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
|
||||
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
||||
|
||||
def test_dtypes(self):
|
||||
dtypes = []
|
||||
|
@ -388,22 +383,20 @@ class EinsumTest(test.TestCase):
|
|||
((4, 3), (None, 3)))
|
||||
|
||||
# Ellipsis with unknown rank.
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
|
||||
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
|
||||
check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
|
||||
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
|
||||
|
||||
def test_numpy_input(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# In addition to Tensors, we also support raw numpy arrays as inputs.
|
||||
r = np.random.RandomState(0)
|
||||
s = 'ijk,ijl,ikl->i'
|
||||
x = r.randn(1, 2, 3)
|
||||
y = r.randn(1, 2, 4)
|
||||
z = r.randn(1, 3, 4)
|
||||
# In addition to Tensors, we also support raw numpy arrays as inputs.
|
||||
r = np.random.RandomState(0)
|
||||
s = 'ijk,ijl,ikl->i'
|
||||
x = r.randn(1, 2, 3)
|
||||
y = r.randn(1, 2, 4)
|
||||
z = r.randn(1, 3, 4)
|
||||
|
||||
a = np.einsum(s, x, y, z)
|
||||
b = self.evaluate(special_math_ops.einsum(s, x, y, z))
|
||||
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
|
||||
a = np.einsum(s, x, y, z)
|
||||
b = self.evaluate(special_math_ops.einsum(s, x, y, z))
|
||||
self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_long_cases(self):
|
||||
cases = [
|
||||
|
@ -464,58 +457,56 @@ class EinsumTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_long_cases_with_repeated_labels(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
cases = [
|
||||
# Tests from dask.
|
||||
'fdf,cdd,ccd,afe->ae',
|
||||
'fff,fae,bef,def->abd',
|
||||
]
|
||||
dimension_map = dict((c, ord(c) - ord('a') + 1) for c in 'abcdefghij')
|
||||
for equation in cases:
|
||||
inputs = equation.split('->')[0].replace(' ', '')
|
||||
input_shapes = []
|
||||
for input_str in inputs.split(','):
|
||||
input_shapes.append(tuple([dimension_map[c] for c in input_str]))
|
||||
self._check(equation, *input_shapes)
|
||||
cases = [
|
||||
# Tests from dask.
|
||||
'fdf,cdd,ccd,afe->ae',
|
||||
'fff,fae,bef,def->abd',
|
||||
]
|
||||
dimension_map = dict((c, ord(c) - ord('a') + 1) for c in 'abcdefghij')
|
||||
for equation in cases:
|
||||
inputs = equation.split('->')[0].replace(' ', '')
|
||||
input_shapes = []
|
||||
for input_str in inputs.split(','):
|
||||
input_shapes.append(tuple([dimension_map[c] for c in input_str]))
|
||||
self._check(equation, *input_shapes)
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_invalid_equation(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
r = np.random.RandomState(0)
|
||||
cases = [
|
||||
# invalid equation format.
|
||||
('a0->a', r.randn(5, 3)),
|
||||
('a->a,a', r.randn(5)),
|
||||
('a->a->a', r.randn(5)),
|
||||
('ijk ijk', r.randn(1, 2, 3), r.randn(1, 2, 3)),
|
||||
('ij.jk->ik', r.randn(2, 3), r.randn(3, 4)),
|
||||
# output label not present in input.
|
||||
('a->b', r.randn(5)),
|
||||
('ij,jk->im', r.randn(2, 3), r.randn(3, 4)),
|
||||
# wrong shape.
|
||||
('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
|
||||
# inconsistent dimensions.
|
||||
('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
|
||||
# output has repeated subscripts.
|
||||
('ij,jk->iik', r.randn(2, 3), r.randn(3, 4)),
|
||||
# too many ellipses
|
||||
('...ij...,jk...->ik...', r.randn(2, 3), r.randn(3, 4)),
|
||||
('...ij,jk...->...ik...', r.randn(2, 3), r.randn(3, 4)),
|
||||
# invalid broadcast dimensions.
|
||||
('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
|
||||
# output should have ellipsis when broadcasting shape is non-empty.
|
||||
('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
|
||||
]
|
||||
for args in cases:
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
_ = special_math_ops.einsum(*args)
|
||||
r = np.random.RandomState(0)
|
||||
cases = [
|
||||
# invalid equation format.
|
||||
('a0->a', r.randn(5, 3)),
|
||||
('a->a,a', r.randn(5)),
|
||||
('a->a->a', r.randn(5)),
|
||||
('ijk ijk', r.randn(1, 2, 3), r.randn(1, 2, 3)),
|
||||
('ij.jk->ik', r.randn(2, 3), r.randn(3, 4)),
|
||||
# output label not present in input.
|
||||
('a->b', r.randn(5)),
|
||||
('ij,jk->im', r.randn(2, 3), r.randn(3, 4)),
|
||||
# wrong shape.
|
||||
('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
|
||||
# inconsistent dimensions.
|
||||
('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
|
||||
# output has repeated subscripts.
|
||||
('ij,jk->iik', r.randn(2, 3), r.randn(3, 4)),
|
||||
# too many ellipses
|
||||
('...ij...,jk...->ik...', r.randn(2, 3), r.randn(3, 4)),
|
||||
('...ij,jk...->...ik...', r.randn(2, 3), r.randn(3, 4)),
|
||||
# invalid broadcast dimensions.
|
||||
('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
|
||||
# output should have ellipsis when broadcasting shape is non-empty.
|
||||
('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
|
||||
]
|
||||
for args in cases:
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
_ = special_math_ops.einsum(*args)
|
||||
|
||||
placeholders = [
|
||||
array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
|
||||
]
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
_ = self.evaluate(special_math_ops.einsum(args[0], *placeholders))
|
||||
placeholders = [
|
||||
array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
|
||||
]
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
_ = self.evaluate(special_math_ops.einsum(args[0], *placeholders))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_empty(self):
|
||||
|
@ -535,10 +526,9 @@ class EinsumTest(test.TestCase):
|
|||
# From transformer xl.
|
||||
check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))
|
||||
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Generalized traces with zero-sized dimensions.
|
||||
check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
|
||||
check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))
|
||||
# Generalized traces with zero-sized dimensions.
|
||||
check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
|
||||
check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
|
@ -556,122 +546,112 @@ class EinsumGradTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_unary(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('->', ())
|
||||
self._check_gradient('aaa->a', (3, 3, 3))
|
||||
self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
|
||||
self._check_gradient('abcd->da', (3, 5, 4, 2))
|
||||
self._check_gradient('->', ())
|
||||
self._check_gradient('aaa->a', (3, 3, 3))
|
||||
self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
|
||||
self._check_gradient('abcd->da', (3, 5, 4, 2))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_unary_ellipsis(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('...->...', ())
|
||||
self._check_gradient('...->', ())
|
||||
self._check_gradient('->...', ())
|
||||
self._check_gradient('...->...', ())
|
||||
self._check_gradient('...->', ())
|
||||
self._check_gradient('->...', ())
|
||||
|
||||
# Tests from dask
|
||||
self._check_gradient('a...a->a...', (2, 2))
|
||||
self._check_gradient('a...a->', (2, 2))
|
||||
self._check_gradient('a...a->...', (2, 5, 1, 2))
|
||||
self._check_gradient('a...a->a...', (2, 1, 2))
|
||||
self._check_gradient('a...a->a...', (2, 3, 4, 5, 2))
|
||||
# Tests from dask
|
||||
self._check_gradient('a...a->a...', (2, 2))
|
||||
self._check_gradient('a...a->', (2, 2))
|
||||
self._check_gradient('a...a->...', (2, 5, 1, 2))
|
||||
self._check_gradient('a...a->a...', (2, 1, 2))
|
||||
self._check_gradient('a...a->a...', (2, 3, 4, 5, 2))
|
||||
|
||||
self._check_gradient('...ijk->...ki', (3, 4, 5))
|
||||
self._check_gradient('...ijk->...ki', (1, 3, 4, 5))
|
||||
self._check_gradient('...ijk->...ki', (2, 2, 3, 4, 5))
|
||||
self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))
|
||||
self._check_gradient('...ijk->...ki', (3, 4, 5))
|
||||
self._check_gradient('...ijk->...ki', (1, 3, 4, 5))
|
||||
self._check_gradient('...ijk->...ki', (2, 2, 3, 4, 5))
|
||||
self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))
|
||||
|
||||
def test_binary_simple(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Binary cases in XLA mode must have either (a) each index appearing
|
||||
# exactly once in both the inputs (batch or contraction index), or
|
||||
# (b) appearing exactly once in an input and in the output (free index).
|
||||
self._check_gradient(',->', (), ())
|
||||
self._check_gradient('a,a->', (3,), (3,))
|
||||
self._check_gradient('a,a->a', (3,), (3,))
|
||||
self._check_gradient('ab,b->a', (3, 4), (4,))
|
||||
self._check_gradient('ab,ab->', (3, 4), (3, 4))
|
||||
self._check_gradient('ab,bc->ac', (3, 4), (4, 5))
|
||||
self._check_gradient('nij,jk->nik', (5, 2, 3), (3, 4))
|
||||
self._check_gradient('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
|
||||
# Based on https://github.com/google/jax/issues/37#issuecomment-448572187
|
||||
self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))
|
||||
# Binary cases in XLA mode must have either (a) each index appearing
|
||||
# exactly once in both the inputs (batch or contraction index), or
|
||||
# (b) appearing exactly once in an input and in the output (free index).
|
||||
self._check_gradient(',->', (), ())
|
||||
self._check_gradient('a,a->', (3,), (3,))
|
||||
self._check_gradient('a,a->a', (3,), (3,))
|
||||
self._check_gradient('ab,b->a', (3, 4), (4,))
|
||||
self._check_gradient('ab,ab->', (3, 4), (3, 4))
|
||||
self._check_gradient('ab,bc->ac', (3, 4), (4, 5))
|
||||
self._check_gradient('nij,jk->nik', (5, 2, 3), (3, 4))
|
||||
self._check_gradient('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
|
||||
# Based on https://github.com/google/jax/issues/37#issuecomment-448572187
|
||||
self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))
|
||||
|
||||
def test_empty(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# From Transformer XL.
|
||||
self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))
|
||||
# From Transformer XL.
|
||||
self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_reduced_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('ba,b->', (3, 2), (3,))
|
||||
self._check_gradient('ab,ab->', (3, 4), (3, 4))
|
||||
self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
|
||||
self._check_gradient('ba,b->', (3, 2), (3,))
|
||||
self._check_gradient('ab,ab->', (3, 4), (3, 4))
|
||||
self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_repeated_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Repeated indices.
|
||||
self._check_gradient('aba,a->b', (3, 4, 3), (3,))
|
||||
self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
|
||||
self._check_gradient('ill,k->ik', (2, 3, 3), (4,))
|
||||
# From https://github.com/dask/dask/pull/3412#discussion_r182413444
|
||||
self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
|
||||
self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
|
||||
# Repeated indices.
|
||||
self._check_gradient('aba,a->b', (3, 4, 3), (3,))
|
||||
self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
|
||||
self._check_gradient('ill,k->ik', (2, 3, 3), (4,))
|
||||
# From https://github.com/dask/dask/pull/3412#discussion_r182413444
|
||||
self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
|
||||
self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_empty_with_repeated_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
|
||||
self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
|
||||
self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))
|
||||
self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
|
||||
self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
|
||||
self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_broadcasting(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
|
||||
self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
|
||||
self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
|
||||
self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
|
||||
self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
|
||||
# Tests from dask.
|
||||
self._check_gradient('...i,...j,...k->...ijk', (1, 4, 1, 2), (5, 1, 1, 3),
|
||||
(1, 1, 1, 1, 9))
|
||||
self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))
|
||||
self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
|
||||
self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
|
||||
self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
|
||||
self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
|
||||
self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
|
||||
# Tests from dask.
|
||||
self._check_gradient('...i,...j,...k->...ijk', (1, 4, 1, 2), (5, 1, 1, 3),
|
||||
(1, 1, 1, 1, 9))
|
||||
self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))
|
||||
|
||||
def test_long_cases(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
cases = [
|
||||
'abhe,hidj,jgba,hiab,gab->ed',
|
||||
# Tests from dask.
|
||||
'ea,fb,abcd,gc,hd->efgh',
|
||||
]
|
||||
dimension_map = dict(
|
||||
(c, ((ord(c) - ord('a')) % 3) + 1) for c in 'abcdefghij')
|
||||
for equation in cases:
|
||||
inputs = equation.split('->')[0].replace(' ', '')
|
||||
input_shapes = []
|
||||
for input_str in inputs.split(','):
|
||||
input_shapes.append(tuple([dimension_map[c] for c in input_str]))
|
||||
self._check_gradient(equation, *input_shapes)
|
||||
cases = [
|
||||
'abhe,hidj,jgba,hiab,gab->ed',
|
||||
# Tests from dask.
|
||||
'ea,fb,abcd,gc,hd->efgh',
|
||||
]
|
||||
dimension_map = dict(
|
||||
(c, ((ord(c) - ord('a')) % 3) + 1) for c in 'abcdefghij')
|
||||
for equation in cases:
|
||||
inputs = equation.split('->')[0].replace(' ', '')
|
||||
input_shapes = []
|
||||
for input_str in inputs.split(','):
|
||||
input_shapes.append(tuple([dimension_map[c] for c in input_str]))
|
||||
self._check_gradient(equation, *input_shapes)
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_long_cases_with_repeated_labels(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
cases = [
|
||||
# Tests from dask.
|
||||
'fdf,cdd,ccd,afe->ae',
|
||||
'fff,fae,bef,def->abd',
|
||||
]
|
||||
dimension_map = dict(
|
||||
(c, ((ord(c) - ord('a')) % 3) + 1) for c in 'abcdefghij')
|
||||
for equation in cases:
|
||||
inputs = equation.split('->')[0].replace(' ', '')
|
||||
input_shapes = []
|
||||
for input_str in inputs.split(','):
|
||||
input_shapes.append(tuple([dimension_map[c] for c in input_str]))
|
||||
self._check_gradient(equation, *input_shapes)
|
||||
cases = [
|
||||
# Tests from dask.
|
||||
'fdf,cdd,ccd,afe->ae',
|
||||
'fff,fae,bef,def->abd',
|
||||
]
|
||||
dimension_map = dict(
|
||||
(c, ((ord(c) - ord('a')) % 3) + 1) for c in 'abcdefghij')
|
||||
for equation in cases:
|
||||
inputs = equation.split('->')[0].replace(' ', '')
|
||||
input_shapes = []
|
||||
for input_str in inputs.split(','):
|
||||
input_shapes.append(tuple([dimension_map[c] for c in input_str]))
|
||||
self._check_gradient(equation, *input_shapes)
|
||||
|
||||
|
||||
class EinsumBenchmark(test.Benchmark):
|
||||
|
|
Loading…
Reference in New Issue