Remove forward_compatibility / forward_compatible checks for dates that have already passed.
- Also, removed print statements from relu_op_test.py PiperOrigin-RevId: 287911742 Change-Id: Ib1763a5a010e5738e4d93e348391839e1e164108
This commit is contained in:
parent
4b7f4c1f09
commit
f75c37faf3
|
@ -21,19 +21,10 @@ from __future__ import print_function
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.compiler.tests import xla_test
|
from tensorflow.compiler.tests import xla_test
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
|
||||||
# LINT.IfChange
|
|
||||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
|
||||||
# LINT.ThenChange(
|
|
||||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
|
||||||
# //tensorflow/python/ops/array_ops.py,
|
|
||||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
|
||||||
# )
|
|
||||||
|
|
||||||
default_v2_alignment = "LEFT_LEFT"
|
default_v2_alignment = "LEFT_LEFT"
|
||||||
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"]
|
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"]
|
||||||
|
|
||||||
|
@ -404,7 +395,6 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
# From here onwards are v2-only tests.
|
# From here onwards are v2-only tests.
|
||||||
def testSquare(self):
|
def testSquare(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
for _, tests in [square_cases(align)]:
|
for _, tests in [square_cases(align)]:
|
||||||
for diag_index, (vecs, solution) in tests.items():
|
for diag_index, (vecs, solution) in tests.items():
|
||||||
|
@ -412,7 +402,6 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
||||||
self._assertOpOutputMatchesExpected(params, solution[0])
|
self._assertOpOutputMatchesExpected(params, solution[0])
|
||||||
|
|
||||||
def testSquareBatch(self):
|
def testSquareBatch(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
for _, tests in [square_cases(align)]:
|
for _, tests in [square_cases(align)]:
|
||||||
for diag_index, (vecs, solution) in tests.items():
|
for diag_index, (vecs, solution) in tests.items():
|
||||||
|
@ -420,9 +409,6 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
||||||
self._assertOpOutputMatchesExpected(params, solution)
|
self._assertOpOutputMatchesExpected(params, solution)
|
||||||
|
|
||||||
def testRectangularBatch(self):
|
def testRectangularBatch(self):
|
||||||
if not compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Stores expected num_rows and num_cols (when the other is given).
|
# Stores expected num_rows and num_cols (when the other is given).
|
||||||
# expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
|
# expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
|
||||||
test_list = list()
|
test_list = list()
|
||||||
|
@ -513,7 +499,6 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
||||||
}, solution_given_num_cols)
|
}, solution_given_num_cols)
|
||||||
|
|
||||||
def testPadding(self):
|
def testPadding(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for padding_value, align in zip_to_first_list_length([555, -11],
|
for padding_value, align in zip_to_first_list_length([555, -11],
|
||||||
alignment_list):
|
alignment_list):
|
||||||
for _, tests in all_tests(align):
|
for _, tests in all_tests(align):
|
||||||
|
@ -634,7 +619,6 @@ class MatrixSetDiagTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
# From here onwards are v2-only tests.
|
# From here onwards are v2-only tests.
|
||||||
def testSingleMatrix(self):
|
def testSingleMatrix(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
for _, tests in all_tests(align):
|
for _, tests in all_tests(align):
|
||||||
for diag_index, (vecs, banded_mat) in tests.items():
|
for diag_index, (vecs, banded_mat) in tests.items():
|
||||||
|
@ -650,7 +634,6 @@ class MatrixSetDiagTest(xla_test.XLATestCase):
|
||||||
}, solution)
|
}, solution)
|
||||||
|
|
||||||
def testBatch(self):
|
def testBatch(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
for _, tests in all_tests(align):
|
for _, tests in all_tests(align):
|
||||||
for diag_index, (vecs, banded_mat) in tests.items():
|
for diag_index, (vecs, banded_mat) in tests.items():
|
||||||
|
@ -705,7 +688,6 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
# From here onwards are v2-only tests.
|
# From here onwards are v2-only tests.
|
||||||
def testSingleMatrix(self):
|
def testSingleMatrix(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
|
test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
|
||||||
for mat, tests in test_list:
|
for mat, tests in test_list:
|
||||||
|
@ -718,7 +700,6 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
|
||||||
}, solution[0])
|
}, solution[0])
|
||||||
|
|
||||||
def testBatch(self):
|
def testBatch(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
for mat, tests in all_tests(align):
|
for mat, tests in all_tests(align):
|
||||||
for diag_index, (solution, _) in tests.items():
|
for diag_index, (solution, _) in tests.items():
|
||||||
|
@ -730,7 +711,6 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
|
||||||
}, solution)
|
}, solution)
|
||||||
|
|
||||||
def testPadding(self):
|
def testPadding(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for padding_value, align in zip_to_first_list_length([555, -11],
|
for padding_value, align in zip_to_first_list_length([555, -11],
|
||||||
alignment_list):
|
alignment_list):
|
||||||
for mat, tests in all_tests(align):
|
for mat, tests in all_tests(align):
|
||||||
|
|
|
@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.protobuf import tensorflow_server_pb2
|
from tensorflow.core.protobuf import tensorflow_server_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.data.experimental.ops import distribute
|
from tensorflow.python.data.experimental.ops import distribute
|
||||||
from tensorflow.python.data.experimental.ops import distribute_options
|
from tensorflow.python.data.experimental.ops import distribute_options
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
|
@ -90,7 +89,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||||
def testExternalStatePolicyIgnore(self):
|
def testExternalStatePolicyIgnore(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
|
||||||
with ops.device(self._device0):
|
with ops.device(self._device0):
|
||||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||||
|
@ -122,7 +120,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||||
def testExternalStatePolicyWarn(self):
|
def testExternalStatePolicyWarn(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
|
||||||
with ops.device(self._device0):
|
with ops.device(self._device0):
|
||||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||||
|
@ -154,7 +151,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||||
def testExternalStatePolicyFail(self):
|
def testExternalStatePolicyFail(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
|
||||||
with ops.device(self._device0):
|
with ops.device(self._device0):
|
||||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||||
|
|
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.util import nest
|
from tensorflow.python.data.util import nest
|
||||||
from tensorflow.python.data.util import structure
|
from tensorflow.python.data.util import structure
|
||||||
|
@ -125,8 +124,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
||||||
self._scan_func = wrapped_func
|
self._scan_func = wrapped_func
|
||||||
self._scan_func.function.add_to_graph(ops.get_default_graph())
|
self._scan_func.function.add_to_graph(ops.get_default_graph())
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if compat.forward_compatible(2019, 10,
|
if use_default_device is not None:
|
||||||
15) or use_default_device is not None:
|
|
||||||
variant_tensor = gen_experimental_dataset_ops.scan_dataset(
|
variant_tensor = gen_experimental_dataset_ops.scan_dataset(
|
||||||
self._input_dataset._variant_tensor,
|
self._input_dataset._variant_tensor,
|
||||||
structure.to_tensor_list(self._state_structure, self._initial_state),
|
structure.to_tensor_list(self._state_structure, self._initial_state),
|
||||||
|
|
|
@ -25,7 +25,6 @@ import numpy as np
|
||||||
from tensorflow.core.protobuf import cluster_pb2
|
from tensorflow.core.protobuf import cluster_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.compat import compat as forward_compat
|
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
|
@ -464,7 +463,6 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
@combinations.generate(test_base.graph_only_combinations())
|
@combinations.generate(test_base.graph_only_combinations())
|
||||||
def testIteratorStringHandleFuture(self):
|
def testIteratorStringHandleFuture(self):
|
||||||
with forward_compat.forward_compatibility_horizon(2018, 8, 4):
|
|
||||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||||
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
|
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,6 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.data.experimental.ops import distribute_options
|
from tensorflow.python.data.experimental.ops import distribute_options
|
||||||
from tensorflow.python.data.experimental.ops import optimization_options
|
from tensorflow.python.data.experimental.ops import optimization_options
|
||||||
from tensorflow.python.data.experimental.ops import stats_options
|
from tensorflow.python.data.experimental.ops import stats_options
|
||||||
|
@ -223,7 +222,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||||
A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
|
A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
|
||||||
serialized graph.
|
serialized graph.
|
||||||
"""
|
"""
|
||||||
if compat.forward_compatible(2019, 11, 25) or external_state_policy:
|
if external_state_policy:
|
||||||
policy = None
|
policy = None
|
||||||
if external_state_policy:
|
if external_state_policy:
|
||||||
policy = external_state_policy.value
|
policy = external_state_policy.value
|
||||||
|
@ -231,7 +230,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||||
self._variant_tensor,
|
self._variant_tensor,
|
||||||
external_state_policy=policy,
|
external_state_policy=policy,
|
||||||
strip_device_assignment=strip_device_assignment)
|
strip_device_assignment=strip_device_assignment)
|
||||||
if compat.forward_compatible(2019, 11, 16) or strip_device_assignment:
|
if strip_device_assignment:
|
||||||
return gen_dataset_ops.dataset_to_graph(
|
return gen_dataset_ops.dataset_to_graph(
|
||||||
self._variant_tensor,
|
self._variant_tensor,
|
||||||
allow_stateful=allow_stateful,
|
allow_stateful=allow_stateful,
|
||||||
|
|
|
@ -28,7 +28,6 @@ from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
@ -395,7 +394,6 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||||
@test_util.disable_xla('This test does not pass with XLA')
|
@test_util.disable_xla('This test does not pass with XLA')
|
||||||
def test_conv_bn(self):
|
def test_conv_bn(self):
|
||||||
"""Test graph with convolution followed by batch norm."""
|
"""Test graph with convolution followed by batch norm."""
|
||||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
|
||||||
if test.is_gpu_available(cuda_only=True):
|
if test.is_gpu_available(cuda_only=True):
|
||||||
random_seed.set_random_seed(0)
|
random_seed.set_random_seed(0)
|
||||||
x = _input([2, 8, 8, 1])
|
x = _input([2, 8, 8, 1])
|
||||||
|
@ -468,7 +466,6 @@ class AutoMixedPrecisionTest(test.TestCase):
|
||||||
@test_util.disable_xla('This test does not pass with XLA')
|
@test_util.disable_xla('This test does not pass with XLA')
|
||||||
def test_conv_bn_dropout(self):
|
def test_conv_bn_dropout(self):
|
||||||
"""Test dropout precision of convolution batch norm graph."""
|
"""Test dropout precision of convolution batch norm graph."""
|
||||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
|
||||||
if test.is_gpu_available(cuda_only=True):
|
if test.is_gpu_available(cuda_only=True):
|
||||||
random_seed.set_random_seed(0)
|
random_seed.set_random_seed(0)
|
||||||
x = _input([2, 8, 8, 1])
|
x = _input([2, 8, 8, 1])
|
||||||
|
|
|
@ -22,7 +22,6 @@ import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
@ -143,7 +142,6 @@ def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b,
|
||||||
use_static_shape):
|
use_static_shape):
|
||||||
|
|
||||||
def Test(self):
|
def Test(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
|
||||||
np.random.seed(42)
|
np.random.seed(42)
|
||||||
self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
|
self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
|
||||||
|
|
||||||
|
@ -200,7 +198,6 @@ def _GetBatchMatmulGradientWithBroadcastingTest(dtype, adjoint_a, adjoint_b):
|
||||||
def CheckGradients(self, a_shape, b_shape):
|
def CheckGradients(self, a_shape, b_shape):
|
||||||
self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)
|
self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)
|
||||||
|
|
||||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
|
||||||
CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
|
CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
|
||||||
CheckGradients(self, [2, 3], [1, 3, 5])
|
CheckGradients(self, [2, 3], [1, 3, 5])
|
||||||
CheckGradients(self, [2, 3], [5, 3, 5])
|
CheckGradients(self, [2, 3], [5, 3, 5])
|
||||||
|
@ -231,7 +228,6 @@ class BatchMatMulBenchmark(test.Benchmark):
|
||||||
|
|
||||||
def benchmarkBatchMatMulBroadcast(self):
|
def benchmarkBatchMatMulBroadcast(self):
|
||||||
for (a_shape, b_shape) in self.shape_pairs:
|
for (a_shape, b_shape) in self.shape_pairs:
|
||||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
|
||||||
with ops.Graph().as_default(), \
|
with ops.Graph().as_default(), \
|
||||||
session.Session(config=benchmark.benchmark_config()) as sess, \
|
session.Session(config=benchmark.benchmark_config()) as sess, \
|
||||||
ops.device("/cpu:0"):
|
ops.device("/cpu:0"):
|
||||||
|
|
|
@ -21,7 +21,6 @@ import itertools
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
@ -33,15 +32,6 @@ from tensorflow.python.platform import test
|
||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
|
||||||
# LINT.IfChange
|
|
||||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
|
||||||
# LINT.ThenChange(
|
|
||||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
|
||||||
# //tensorflow/python/ops/array_ops.py,
|
|
||||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
default_v2_alignment = "LEFT_LEFT"
|
default_v2_alignment = "LEFT_LEFT"
|
||||||
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]
|
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]
|
||||||
|
|
||||||
|
@ -391,7 +381,6 @@ class MatrixDiagTest(test.TestCase):
|
||||||
self.assertEqual((3, 3), v_diag.get_shape())
|
self.assertEqual((3, 3), v_diag.get_shape())
|
||||||
self.assertAllEqual(v_diag.eval(), mat)
|
self.assertAllEqual(v_diag.eval(), mat)
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# {Sub,Super}diagonals.
|
# {Sub,Super}diagonals.
|
||||||
for offset in [1, -2, 5]:
|
for offset in [1, -2, 5]:
|
||||||
mat = np.diag(v, offset)
|
mat = np.diag(v, offset)
|
||||||
|
@ -417,7 +406,6 @@ class MatrixDiagTest(test.TestCase):
|
||||||
self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
|
self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
|
||||||
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# {Sub,Super}diagonals.
|
# {Sub,Super}diagonals.
|
||||||
for offset in [1, -2, 5]:
|
for offset in [1, -2, 5]:
|
||||||
v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
|
v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
|
||||||
|
@ -453,7 +441,6 @@ class MatrixDiagTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testRectangularBatch(self):
|
def testRectangularBatch(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
with self.cached_session(use_gpu=True):
|
with self.cached_session(use_gpu=True):
|
||||||
# Stores expected num_rows and num_cols (when the other is given).
|
# Stores expected num_rows and num_cols (when the other is given).
|
||||||
# expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
|
# expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
|
||||||
|
@ -574,7 +561,6 @@ class MatrixDiagTest(test.TestCase):
|
||||||
y.get_shape().as_list())
|
y.get_shape().as_list())
|
||||||
self.assertLess(error, 1e-4)
|
self.assertLess(error, 1e-4)
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# {Sub,super}diagonals/band.
|
# {Sub,super}diagonals/band.
|
||||||
tests = dict() # tests[shape] = (d_lower, d_upper)
|
tests = dict() # tests[shape] = (d_lower, d_upper)
|
||||||
tests[(3,)] = (-1, -1)
|
tests[(3,)] = (-1, -1)
|
||||||
|
@ -604,7 +590,6 @@ class MatrixSetDiagTest(test.TestCase):
|
||||||
self.assertEqual((3, 3), output.get_shape())
|
self.assertEqual((3, 3), output.get_shape())
|
||||||
self.assertAllEqual(mat_set_diag, self.evaluate(output))
|
self.assertAllEqual(mat_set_diag, self.evaluate(output))
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Diagonal bands.
|
# Diagonal bands.
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
_, tests = square_cases(align)
|
_, tests = square_cases(align)
|
||||||
|
@ -634,7 +619,6 @@ class MatrixSetDiagTest(test.TestCase):
|
||||||
self.assertEqual((3, 2), output.get_shape())
|
self.assertEqual((3, 2), output.get_shape())
|
||||||
self.assertAllEqual(expected, self.evaluate(output))
|
self.assertAllEqual(expected, self.evaluate(output))
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Diagonal bands.
|
# Diagonal bands.
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||||
|
@ -663,7 +647,6 @@ class MatrixSetDiagTest(test.TestCase):
|
||||||
self.assertEqual((2, 3, 3), output.get_shape())
|
self.assertEqual((2, 3, 3), output.get_shape())
|
||||||
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
|
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Diagonal bands.
|
# Diagonal bands.
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
_, tests = square_cases(align)
|
_, tests = square_cases(align)
|
||||||
|
@ -697,7 +680,6 @@ class MatrixSetDiagTest(test.TestCase):
|
||||||
self.assertEqual((2, 2, 3), output.get_shape())
|
self.assertEqual((2, 2, 3), output.get_shape())
|
||||||
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
|
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Diagonal bands.
|
# Diagonal bands.
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||||
|
@ -727,7 +709,6 @@ class MatrixSetDiagTest(test.TestCase):
|
||||||
with self.assertRaisesOpError("diagonal must be at least 1-dim"):
|
with self.assertRaisesOpError("diagonal must be at least 1-dim"):
|
||||||
array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
|
array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
d = array_ops.placeholder(dtype=dtypes_lib.float32)
|
d = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
"first dimensions of diagonal don't match"):
|
"first dimensions of diagonal don't match"):
|
||||||
|
@ -743,10 +724,7 @@ class MatrixSetDiagTest(test.TestCase):
|
||||||
x_diag = constant_op.constant(
|
x_diag = constant_op.constant(
|
||||||
np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
|
np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
|
y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
|
||||||
else:
|
|
||||||
y = array_ops.matrix_set_diag(x, x_diag)
|
|
||||||
error_x = gradient_checker.compute_gradient_error(x,
|
error_x = gradient_checker.compute_gradient_error(x,
|
||||||
x.get_shape().as_list(),
|
x.get_shape().as_list(),
|
||||||
y,
|
y,
|
||||||
|
@ -763,7 +741,6 @@ class MatrixSetDiagTest(test.TestCase):
|
||||||
input_shapes = [(3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)]
|
input_shapes = [(3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)]
|
||||||
diag_bands = [(0, 0)]
|
diag_bands = [(0, 0)]
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
diag_bands.append((-1, 1))
|
diag_bands.append((-1, 1))
|
||||||
for input_shape, diags, align in itertools.product(input_shapes, diag_bands,
|
for input_shape, diags, align in itertools.product(input_shapes, diag_bands,
|
||||||
alignment_list):
|
alignment_list):
|
||||||
|
@ -805,7 +782,6 @@ class MatrixDiagPartTest(test.TestCase):
|
||||||
self.assertEqual((3,), mat_diag.get_shape())
|
self.assertEqual((3,), mat_diag.get_shape())
|
||||||
self.assertAllEqual(mat_diag.eval(), v)
|
self.assertAllEqual(mat_diag.eval(), v)
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
for offset in [-2, 3]:
|
for offset in [-2, 3]:
|
||||||
mat = np.diag(v, offset)
|
mat = np.diag(v, offset)
|
||||||
mat_diag = array_ops.matrix_diag_part(mat, k=offset)
|
mat_diag = array_ops.matrix_diag_part(mat, k=offset)
|
||||||
|
@ -831,7 +807,6 @@ class MatrixDiagPartTest(test.TestCase):
|
||||||
mat_diag = array_ops.matrix_diag_part(mat)
|
mat_diag = array_ops.matrix_diag_part(mat)
|
||||||
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
|
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Diagonal bands.
|
# Diagonal bands.
|
||||||
for align in alignment_list:
|
for align in alignment_list:
|
||||||
for mat, tests in [tall_cases(align), fat_cases(align)]:
|
for mat, tests in [tall_cases(align), fat_cases(align)]:
|
||||||
|
@ -853,7 +828,6 @@ class MatrixDiagPartTest(test.TestCase):
|
||||||
self.assertEqual((2, 3), mat_batch_diag.get_shape())
|
self.assertEqual((2, 3), mat_batch_diag.get_shape())
|
||||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Diagonal bands with padding_value.
|
# Diagonal bands with padding_value.
|
||||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||||
alignment_list):
|
alignment_list):
|
||||||
|
@ -889,7 +863,6 @@ class MatrixDiagPartTest(test.TestCase):
|
||||||
self.assertEqual((2, 2), mat_batch_diag.get_shape())
|
self.assertEqual((2, 2), mat_batch_diag.get_shape())
|
||||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Diagonal bands with padding_value and align.
|
# Diagonal bands with padding_value and align.
|
||||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||||
alignment_list):
|
alignment_list):
|
||||||
|
@ -905,7 +878,6 @@ class MatrixDiagPartTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testUnknownShape(self):
|
def testUnknownShape(self):
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
|
matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
|
||||||
result = array_ops.matrix_diag_part(matrix, k=-1)
|
result = array_ops.matrix_diag_part(matrix, k=-1)
|
||||||
input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
||||||
|
@ -939,7 +911,6 @@ class MatrixDiagPartTest(test.TestCase):
|
||||||
y.get_shape().as_list())
|
y.get_shape().as_list())
|
||||||
self.assertLess(error, 1e-4)
|
self.assertLess(error, 1e-4)
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# {Sub,super}diagonals/band.
|
# {Sub,super}diagonals/band.
|
||||||
tests = dict() # tests[shape] = (d_lower, d_upper)
|
tests = dict() # tests[shape] = (d_lower, d_upper)
|
||||||
tests[(3, 3)] = (-1, -1)
|
tests[(3, 3)] = (-1, -1)
|
||||||
|
|
|
@ -22,7 +22,6 @@ import numpy as np
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
@ -122,7 +121,6 @@ class ReluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
|
||||||
print("relu (float32) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
# The gradient for fp16 is inaccurate due to the low-precision.
|
# The gradient for fp16 is inaccurate due to the low-precision.
|
||||||
|
@ -171,7 +169,6 @@ class ReluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
|
||||||
print("relu (float64) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-10)
|
self.assertLess(err, 1e-10)
|
||||||
|
|
||||||
def testGradGradFloat32(self):
|
def testGradGradFloat32(self):
|
||||||
|
@ -190,7 +187,6 @@ class ReluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||||
print("relu (float32) gradient of gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
def testGradGradFloat64(self):
|
def testGradGradFloat64(self):
|
||||||
|
@ -209,7 +205,6 @@ class ReluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||||
print("relu (float64) gradient of gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-10)
|
self.assertLess(err, 1e-10)
|
||||||
|
|
||||||
def testGradientScalar(self):
|
def testGradientScalar(self):
|
||||||
|
@ -283,7 +278,6 @@ class Relu6Test(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
|
||||||
print("relu6 (float32) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
def testGradientFloat64(self):
|
def testGradientFloat64(self):
|
||||||
|
@ -294,7 +288,6 @@ class Relu6Test(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
|
||||||
print("relu6 (float64) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-10)
|
self.assertLess(err, 1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
@ -345,7 +338,6 @@ class LeakyReluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
|
||||||
print("leaky_relu (float32) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
def testGradientFloat64(self):
|
def testGradientFloat64(self):
|
||||||
|
@ -356,11 +348,9 @@ class LeakyReluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
|
||||||
print("leaky_relu (float64) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-10)
|
self.assertLess(err, 1e-10)
|
||||||
|
|
||||||
def testGradGradFloat32(self):
|
def testGradGradFloat32(self):
|
||||||
with compat.forward_compatibility_horizon(2018, 11, 2):
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
|
|
||||||
def f(x):
|
def f(x):
|
||||||
|
@ -376,11 +366,9 @@ class LeakyReluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||||
print("leaky_relu (float32) gradient of gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
def testGradGradFloat64(self):
|
def testGradGradFloat64(self):
|
||||||
with compat.forward_compatibility_horizon(2018, 11, 2):
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
|
|
||||||
def f(x):
|
def f(x):
|
||||||
|
@ -396,7 +384,6 @@ class LeakyReluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||||
print("leaky_relu (float64) gradient of gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-10)
|
self.assertLess(err, 1e-10)
|
||||||
|
|
||||||
def testGradientScalar(self):
|
def testGradientScalar(self):
|
||||||
|
@ -463,7 +450,6 @@ class EluTest(test.TestCase):
|
||||||
x = np.asarray(x_val, dtype=np.float32, order="F")
|
x = np.asarray(x_val, dtype=np.float32, order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
|
||||||
print("elu (float32) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
def testGradientFloat64(self):
|
def testGradientFloat64(self):
|
||||||
|
@ -472,7 +458,6 @@ class EluTest(test.TestCase):
|
||||||
x = np.asarray(x_val, dtype=np.float64, order="F")
|
x = np.asarray(x_val, dtype=np.float64, order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
|
||||||
print("elu (float64) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-6)
|
self.assertLess(err, 1e-6)
|
||||||
|
|
||||||
def testGradGrad(self):
|
def testGradGrad(self):
|
||||||
|
@ -507,7 +492,6 @@ class EluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||||
print("elu (float32) gradient of gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
def testGradGradFloat64(self):
|
def testGradGradFloat64(self):
|
||||||
|
@ -526,7 +510,6 @@ class EluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||||
print("elu (float64) gradient of gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-6)
|
self.assertLess(err, 1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@ -567,7 +550,6 @@ class SeluTest(test.TestCase):
|
||||||
x = np.asarray(x_val, dtype=np.float32, order="F")
|
x = np.asarray(x_val, dtype=np.float32, order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
|
||||||
print("selu (float32) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
def testGradientFloat64(self):
|
def testGradientFloat64(self):
|
||||||
|
@ -576,7 +558,6 @@ class SeluTest(test.TestCase):
|
||||||
x = np.asarray(x_val, dtype=np.float64, order="F")
|
x = np.asarray(x_val, dtype=np.float64, order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
|
*gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
|
||||||
print("selu (float64) gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-6)
|
self.assertLess(err, 1e-6)
|
||||||
|
|
||||||
def testGradGradFloat32(self):
|
def testGradGradFloat32(self):
|
||||||
|
@ -595,7 +576,6 @@ class SeluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||||
print("selu (float32) gradient of gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-4)
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
def testGradGradFloat64(self):
|
def testGradGradFloat64(self):
|
||||||
|
@ -614,7 +594,6 @@ class SeluTest(test.TestCase):
|
||||||
order="F")
|
order="F")
|
||||||
err = gradient_checker_v2.max_error(
|
err = gradient_checker_v2.max_error(
|
||||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||||
print("selu (float64) gradient of gradient err = ", err)
|
|
||||||
self.assertLess(err, 1e-6)
|
self.assertLess(err, 1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ from __future__ import print_function
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import common_shapes
|
from tensorflow.python.framework import common_shapes
|
||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
|
@ -54,13 +53,6 @@ tf_export("newaxis").export_constant(__name__, "newaxis")
|
||||||
# existing 'slice' for later use in this module.
|
# existing 'slice' for later use in this module.
|
||||||
_BaseSlice = slice
|
_BaseSlice = slice
|
||||||
|
|
||||||
# LINT.IfChange
|
|
||||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
|
||||||
# LINT.ThenChange(
|
|
||||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
|
||||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
|
||||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
|
||||||
# )
|
|
||||||
|
|
||||||
@tf_export("reshape", v1=["reshape", "manip.reshape"])
|
@tf_export("reshape", v1=["reshape", "manip.reshape"])
|
||||||
def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
|
def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
|
||||||
|
@ -2362,7 +2354,6 @@ def matrix_diag(diagonal,
|
||||||
Returns:
|
Returns:
|
||||||
A Tensor. Has the same type as `diagonal`.
|
A Tensor. Has the same type as `diagonal`.
|
||||||
"""
|
"""
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Special case to sidestep the tf.constant conversion error:
|
# Special case to sidestep the tf.constant conversion error:
|
||||||
# TypeError: Expected bool, got 0 of type 'int' instead.
|
# TypeError: Expected bool, got 0 of type 'int' instead.
|
||||||
if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
|
if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
|
||||||
|
@ -2377,10 +2368,6 @@ def matrix_diag(diagonal,
|
||||||
align=align,
|
align=align,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
# Call v1 to maintain forward compatibility.
|
|
||||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
|
||||||
return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
|
@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
|
||||||
@deprecation.deprecated_endpoints("matrix_diag_part")
|
@deprecation.deprecated_endpoints("matrix_diag_part")
|
||||||
|
@ -2513,7 +2500,6 @@ def matrix_diag_part(
|
||||||
Returns:
|
Returns:
|
||||||
A Tensor containing diagonals of `input`. Has the same type as `input`.
|
A Tensor containing diagonals of `input`. Has the same type as `input`.
|
||||||
"""
|
"""
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
# Special case to sidestep the tf.constant conversion error:
|
# Special case to sidestep the tf.constant conversion error:
|
||||||
# TypeError: Expected bool, got 0 of type 'int' instead.
|
# TypeError: Expected bool, got 0 of type 'int' instead.
|
||||||
if hasattr(input, "dtype") and input.dtype == "bool":
|
if hasattr(input, "dtype") and input.dtype == "bool":
|
||||||
|
@ -2522,10 +2508,6 @@ def matrix_diag_part(
|
||||||
return gen_array_ops.matrix_diag_part_v3(
|
return gen_array_ops.matrix_diag_part_v3(
|
||||||
input=input, k=k, padding_value=padding_value, align=align, name=name)
|
input=input, k=k, padding_value=padding_value, align=align, name=name)
|
||||||
|
|
||||||
# Call v1 to maintain forward compatibility.
|
|
||||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
|
||||||
return gen_array_ops.matrix_diag_part(input=input, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
|
@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
|
||||||
@deprecation.deprecated_endpoints("matrix_set_diag")
|
@deprecation.deprecated_endpoints("matrix_set_diag")
|
||||||
|
@ -2659,15 +2641,9 @@ def matrix_set_diag(
|
||||||
the left (right-pads the row). It is the packing format LAPACK uses.
|
the left (right-pads the row). It is the packing format LAPACK uses.
|
||||||
cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
|
cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
|
||||||
"""
|
"""
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
return gen_array_ops.matrix_set_diag_v3(
|
return gen_array_ops.matrix_set_diag_v3(
|
||||||
input=input, diagonal=diagonal, k=k, align=align, name=name)
|
input=input, diagonal=diagonal, k=k, align=align, name=name)
|
||||||
|
|
||||||
# Call v1 to maintain forward compatibility.
|
|
||||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
|
||||||
return gen_array_ops.matrix_set_diag(
|
|
||||||
input=input, diagonal=diagonal, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: enable=invalid-name
|
# pylint: enable=invalid-name
|
||||||
|
|
||||||
|
@ -4921,7 +4897,7 @@ def quantize_v2(
|
||||||
raise ValueError("input should have known rank to use negative axis.")
|
raise ValueError("input should have known rank to use negative axis.")
|
||||||
axis %= input.shape.ndims
|
axis %= input.shape.ndims
|
||||||
|
|
||||||
if compat.forward_compatible(2019, 11, 13) or ensure_minimum_range != 0.01:
|
if ensure_minimum_range != 0.01:
|
||||||
return gen_array_ops.quantize_v2(
|
return gen_array_ops.quantize_v2(
|
||||||
input,
|
input,
|
||||||
min_range,
|
min_range,
|
||||||
|
@ -4965,7 +4941,7 @@ def quantize(
|
||||||
axis=None,
|
axis=None,
|
||||||
ensure_minimum_range=0.01):
|
ensure_minimum_range=0.01):
|
||||||
"""Quantize the input tensor."""
|
"""Quantize the input tensor."""
|
||||||
if compat.forward_compatible(2019, 11, 13) or ensure_minimum_range != 0.01:
|
if ensure_minimum_range != 0.01:
|
||||||
return quantize_v2(
|
return quantize_v2(
|
||||||
input,
|
input,
|
||||||
min_range,
|
min_range,
|
||||||
|
@ -5007,7 +4983,7 @@ def dequantize( # pylint: disable=missing-docstring
|
||||||
raise ValueError("input should have known rank to use negative axis.")
|
raise ValueError("input should have known rank to use negative axis.")
|
||||||
axis %= input.shape.ndims
|
axis %= input.shape.ndims
|
||||||
|
|
||||||
if compat.forward_compatible(2019, 10, 22) or axis >= 0 or narrow_range:
|
if axis >= 0 or narrow_range:
|
||||||
return gen_array_ops.dequantize(
|
return gen_array_ops.dequantize(
|
||||||
input, min_range, max_range, mode=mode, name=name,
|
input, min_range, max_range, mode=mode, name=name,
|
||||||
narrow_range=narrow_range, axis=axis)
|
narrow_range=narrow_range, axis=axis)
|
||||||
|
|
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
@ -383,7 +382,6 @@ class BatchNormalizationTest(test.TestCase):
|
||||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||||
|
|
||||||
def testInferenceShape5(self):
|
def testInferenceShape5(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
|
||||||
x_shape = [0, 131, 127, 6]
|
x_shape = [0, 131, 127, 6]
|
||||||
for dtype in [np.float16, np.float32]:
|
for dtype in [np.float16, np.float32]:
|
||||||
if test.is_gpu_available(cuda_only=True):
|
if test.is_gpu_available(cuda_only=True):
|
||||||
|
@ -450,7 +448,6 @@ class BatchNormalizationTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
|
@test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
|
||||||
def testTrainingShape5(self):
|
def testTrainingShape5(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
|
||||||
x_shape = [0, 131, 127, 6]
|
x_shape = [0, 131, 127, 6]
|
||||||
for dtype in [np.float16, np.float32]:
|
for dtype in [np.float16, np.float32]:
|
||||||
if test.is_gpu_available(cuda_only=True):
|
if test.is_gpu_available(cuda_only=True):
|
||||||
|
@ -586,7 +583,6 @@ class BatchNormalizationTest(test.TestCase):
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@test_util.disable_xla('This test never passed for XLA')
|
@test_util.disable_xla('This test never passed for XLA')
|
||||||
def testBatchNormGradShape5(self):
|
def testBatchNormGradShape5(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
|
||||||
for is_training in [True, False]:
|
for is_training in [True, False]:
|
||||||
x_shape = [0, 7, 11, 4]
|
x_shape = [0, 7, 11, 4]
|
||||||
for dtype in [np.float16, np.float32]:
|
for dtype in [np.float16, np.float32]:
|
||||||
|
|
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
@ -33,15 +32,6 @@ from tensorflow.python.ops.parallel_for.test_util import PForTestCase
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
# LINT.IfChange
|
|
||||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
|
||||||
# LINT.ThenChange(
|
|
||||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
|
||||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
|
||||||
# //tensorflow/python/ops/array_ops.py
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class ArrayTest(PForTestCase):
|
class ArrayTest(PForTestCase):
|
||||||
|
|
||||||
|
@ -345,10 +335,8 @@ class ArrayTest(PForTestCase):
|
||||||
|
|
||||||
def loop_fn(i):
|
def loop_fn(i):
|
||||||
diagonal = array_ops.gather(x, i)
|
diagonal = array_ops.gather(x, i)
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
return array_ops.matrix_diag(
|
return array_ops.matrix_diag(
|
||||||
diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
|
diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
|
||||||
return array_ops.matrix_diag(diagonal)
|
|
||||||
|
|
||||||
self._test_loop_fn(loop_fn, 3)
|
self._test_loop_fn(loop_fn, 3)
|
||||||
|
|
||||||
|
@ -357,10 +345,8 @@ class ArrayTest(PForTestCase):
|
||||||
|
|
||||||
def loop_fn(i):
|
def loop_fn(i):
|
||||||
input = array_ops.gather(x, i) # pylint: disable=redefined-builtin
|
input = array_ops.gather(x, i) # pylint: disable=redefined-builtin
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
return array_ops.matrix_diag_part(
|
return array_ops.matrix_diag_part(
|
||||||
input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
|
input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
|
||||||
return array_ops.matrix_diag_part(input)
|
|
||||||
|
|
||||||
self._test_loop_fn(loop_fn, 3)
|
self._test_loop_fn(loop_fn, 3)
|
||||||
|
|
||||||
|
@ -378,7 +364,6 @@ class ArrayTest(PForTestCase):
|
||||||
array_ops.matrix_set_diag(matrix_i, diags[0, ...]),
|
array_ops.matrix_set_diag(matrix_i, diags[0, ...]),
|
||||||
]
|
]
|
||||||
|
|
||||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
|
||||||
k = (-1, 1)
|
k = (-1, 1)
|
||||||
band_i = array_ops.gather(bands, i)
|
band_i = array_ops.gather(bands, i)
|
||||||
for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
|
for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
|
||||||
|
|
|
@ -28,7 +28,6 @@ import numpy as np
|
||||||
from tensorflow.core.example import example_pb2
|
from tensorflow.core.example import example_pb2
|
||||||
from tensorflow.core.example import feature_pb2
|
from tensorflow.core.example import feature_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
@ -444,7 +443,6 @@ class NNTest(PForTestCase):
|
||||||
self._test_loop_fn(loop_fn, 3)
|
self._test_loop_fn(loop_fn, 3)
|
||||||
|
|
||||||
def test_fused_batch_norm(self):
|
def test_fused_batch_norm(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
|
||||||
data_formats = ["NHWC"]
|
data_formats = ["NHWC"]
|
||||||
if test.is_gpu_available():
|
if test.is_gpu_available():
|
||||||
data_formats.append("NCHW")
|
data_formats.append("NCHW")
|
||||||
|
|
|
@ -33,7 +33,6 @@ import six
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
|
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
|
||||||
from tensorflow.python.compat import compat as fwd_compat
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
@ -253,10 +252,7 @@ def einsum(equation, *inputs, **kwargs):
|
||||||
- the format of `equation` is incorrect,
|
- the format of `equation` is incorrect,
|
||||||
- number of inputs or their shapes are inconsistent with `equation`.
|
- number of inputs or their shapes are inconsistent with `equation`.
|
||||||
"""
|
"""
|
||||||
if fwd_compat.forward_compatible(2019, 10, 18):
|
|
||||||
return _einsum_v2(equation, *inputs, **kwargs)
|
return _einsum_v2(equation, *inputs, **kwargs)
|
||||||
else:
|
|
||||||
return _einsum_v1(equation, *inputs, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _einsum_v1(equation, *inputs, **kwargs):
|
def _einsum_v1(equation, *inputs, **kwargs):
|
||||||
|
|
|
@ -23,7 +23,6 @@ import opt_einsum
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
@ -243,7 +242,6 @@ class EinsumTest(test.TestCase):
|
||||||
self._check('abc->ca', (3, 4, 5))
|
self._check('abc->ca', (3, 4, 5))
|
||||||
self._check('abc->cab', (3, 4, 5))
|
self._check('abc->cab', (3, 4, 5))
|
||||||
|
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
# Empty cases.
|
# Empty cases.
|
||||||
self._check('', ())
|
self._check('', ())
|
||||||
self._check('->', ())
|
self._check('->', ())
|
||||||
|
@ -266,7 +264,6 @@ class EinsumTest(test.TestCase):
|
||||||
self._check('...ij->...ji', (5, 2, 3)) # batch matrix transpose
|
self._check('...ij->...ji', (5, 2, 3)) # batch matrix transpose
|
||||||
self._check('...ij->...', (5, 2, 3)) # batch sum
|
self._check('...ij->...', (5, 2, 3)) # batch sum
|
||||||
|
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
self._check('...->...', ())
|
self._check('...->...', ())
|
||||||
self._check('->...', ())
|
self._check('->...', ())
|
||||||
|
|
||||||
|
@ -301,7 +298,6 @@ class EinsumTest(test.TestCase):
|
||||||
self._check('ab,ab->', (3, 4), (3, 4))
|
self._check('ab,ab->', (3, 4), (3, 4))
|
||||||
|
|
||||||
def test_repeated_indices(self):
|
def test_repeated_indices(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
# Repeated indices.
|
# Repeated indices.
|
||||||
self._check('ijj,k->ik', (2, 3, 3), (4,))
|
self._check('ijj,k->ik', (2, 3, 3), (4,))
|
||||||
self._check('aba,a->b', (3, 4, 3), (3,))
|
self._check('aba,a->b', (3, 4, 3), (3,))
|
||||||
|
@ -324,7 +320,6 @@ class EinsumTest(test.TestCase):
|
||||||
self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product
|
self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product
|
||||||
|
|
||||||
def test_broadcasting(self):
|
def test_broadcasting(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
# Batch matmul with broadcasting.
|
# Batch matmul with broadcasting.
|
||||||
self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
|
self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
|
||||||
self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
|
self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
|
||||||
|
@ -388,12 +383,10 @@ class EinsumTest(test.TestCase):
|
||||||
((4, 3), (None, 3)))
|
((4, 3), (None, 3)))
|
||||||
|
|
||||||
# Ellipsis with unknown rank.
|
# Ellipsis with unknown rank.
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
|
check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
|
||||||
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
|
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
|
||||||
|
|
||||||
def test_numpy_input(self):
|
def test_numpy_input(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
# In addition to Tensors, we also support raw numpy arrays as inputs.
|
# In addition to Tensors, we also support raw numpy arrays as inputs.
|
||||||
r = np.random.RandomState(0)
|
r = np.random.RandomState(0)
|
||||||
s = 'ijk,ijl,ikl->i'
|
s = 'ijk,ijl,ikl->i'
|
||||||
|
@ -464,7 +457,6 @@ class EinsumTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
def test_long_cases_with_repeated_labels(self):
|
def test_long_cases_with_repeated_labels(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
cases = [
|
cases = [
|
||||||
# Tests from dask.
|
# Tests from dask.
|
||||||
'fdf,cdd,ccd,afe->ae',
|
'fdf,cdd,ccd,afe->ae',
|
||||||
|
@ -481,7 +473,6 @@ class EinsumTest(test.TestCase):
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_invalid_equation(self):
|
def test_invalid_equation(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
r = np.random.RandomState(0)
|
r = np.random.RandomState(0)
|
||||||
cases = [
|
cases = [
|
||||||
# invalid equation format.
|
# invalid equation format.
|
||||||
|
@ -535,7 +526,6 @@ class EinsumTest(test.TestCase):
|
||||||
# From transformer xl.
|
# From transformer xl.
|
||||||
check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))
|
check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))
|
||||||
|
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
# Generalized traces with zero-sized dimensions.
|
# Generalized traces with zero-sized dimensions.
|
||||||
check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
|
check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
|
||||||
check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))
|
check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))
|
||||||
|
@ -556,7 +546,6 @@ class EinsumGradTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
def test_unary(self):
|
def test_unary(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
self._check_gradient('->', ())
|
self._check_gradient('->', ())
|
||||||
self._check_gradient('aaa->a', (3, 3, 3))
|
self._check_gradient('aaa->a', (3, 3, 3))
|
||||||
self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
|
self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
|
||||||
|
@ -564,7 +553,6 @@ class EinsumGradTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
def test_unary_ellipsis(self):
|
def test_unary_ellipsis(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
self._check_gradient('...->...', ())
|
self._check_gradient('...->...', ())
|
||||||
self._check_gradient('...->', ())
|
self._check_gradient('...->', ())
|
||||||
self._check_gradient('->...', ())
|
self._check_gradient('->...', ())
|
||||||
|
@ -582,7 +570,6 @@ class EinsumGradTest(test.TestCase):
|
||||||
self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))
|
self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))
|
||||||
|
|
||||||
def test_binary_simple(self):
|
def test_binary_simple(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
# Binary cases in XLA mode must have either (a) each index appearing
|
# Binary cases in XLA mode must have either (a) each index appearing
|
||||||
# exactly once in both the inputs (batch or contraction index), or
|
# exactly once in both the inputs (batch or contraction index), or
|
||||||
# (b) appearing exactly once in an input and in the output (free index).
|
# (b) appearing exactly once in an input and in the output (free index).
|
||||||
|
@ -598,20 +585,17 @@ class EinsumGradTest(test.TestCase):
|
||||||
self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))
|
self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))
|
||||||
|
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
# From Transformer XL.
|
# From Transformer XL.
|
||||||
self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))
|
self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
def test_reduced_indices(self):
|
def test_reduced_indices(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
self._check_gradient('ba,b->', (3, 2), (3,))
|
self._check_gradient('ba,b->', (3, 2), (3,))
|
||||||
self._check_gradient('ab,ab->', (3, 4), (3, 4))
|
self._check_gradient('ab,ab->', (3, 4), (3, 4))
|
||||||
self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
|
self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
def test_repeated_indices(self):
|
def test_repeated_indices(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
# Repeated indices.
|
# Repeated indices.
|
||||||
self._check_gradient('aba,a->b', (3, 4, 3), (3,))
|
self._check_gradient('aba,a->b', (3, 4, 3), (3,))
|
||||||
self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
|
self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
|
||||||
|
@ -622,14 +606,12 @@ class EinsumGradTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
def test_empty_with_repeated_indices(self):
|
def test_empty_with_repeated_indices(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
|
self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
|
||||||
self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
|
self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
|
||||||
self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))
|
self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
def test_broadcasting(self):
|
def test_broadcasting(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
|
self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
|
||||||
self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
|
self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
|
||||||
self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
|
self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
|
||||||
|
@ -641,7 +623,6 @@ class EinsumGradTest(test.TestCase):
|
||||||
self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))
|
self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))
|
||||||
|
|
||||||
def test_long_cases(self):
|
def test_long_cases(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
cases = [
|
cases = [
|
||||||
'abhe,hidj,jgba,hiab,gab->ed',
|
'abhe,hidj,jgba,hiab,gab->ed',
|
||||||
# Tests from dask.
|
# Tests from dask.
|
||||||
|
@ -658,7 +639,6 @@ class EinsumGradTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.disable_xla('b/131919749')
|
@test_util.disable_xla('b/131919749')
|
||||||
def test_long_cases_with_repeated_labels(self):
|
def test_long_cases_with_repeated_labels(self):
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
|
||||||
cases = [
|
cases = [
|
||||||
# Tests from dask.
|
# Tests from dask.
|
||||||
'fdf,cdd,ccd,afe->ae',
|
'fdf,cdd,ccd,afe->ae',
|
||||||
|
|
Loading…
Reference in New Issue