Remove forward_compatibility / forward_compatible checks for dates that have already passed.
- Also, removed print statements from relu_op_test.py PiperOrigin-RevId: 287911742 Change-Id: Ib1763a5a010e5738e4d93e348391839e1e164108
This commit is contained in:
parent
4b7f4c1f09
commit
f75c37faf3
|
@ -21,19 +21,10 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
# LINT.IfChange
|
||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
||||
# LINT.ThenChange(
|
||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
||||
# //tensorflow/python/ops/array_ops.py,
|
||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
||||
# )
|
||||
|
||||
default_v2_alignment = "LEFT_LEFT"
|
||||
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"]
|
||||
|
||||
|
@ -404,7 +395,6 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
|||
|
||||
# From here onwards are v2-only tests.
|
||||
def testSquare(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for _, tests in [square_cases(align)]:
|
||||
for diag_index, (vecs, solution) in tests.items():
|
||||
|
@ -412,7 +402,6 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
|||
self._assertOpOutputMatchesExpected(params, solution[0])
|
||||
|
||||
def testSquareBatch(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for _, tests in [square_cases(align)]:
|
||||
for diag_index, (vecs, solution) in tests.items():
|
||||
|
@ -420,9 +409,6 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
|||
self._assertOpOutputMatchesExpected(params, solution)
|
||||
|
||||
def testRectangularBatch(self):
|
||||
if not compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
return
|
||||
|
||||
# Stores expected num_rows and num_cols (when the other is given).
|
||||
# expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
|
||||
test_list = list()
|
||||
|
@ -513,7 +499,6 @@ class MatrixDiagTest(xla_test.XLATestCase):
|
|||
}, solution_given_num_cols)
|
||||
|
||||
def testPadding(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for padding_value, align in zip_to_first_list_length([555, -11],
|
||||
alignment_list):
|
||||
for _, tests in all_tests(align):
|
||||
|
@ -634,7 +619,6 @@ class MatrixSetDiagTest(xla_test.XLATestCase):
|
|||
|
||||
# From here onwards are v2-only tests.
|
||||
def testSingleMatrix(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for _, tests in all_tests(align):
|
||||
for diag_index, (vecs, banded_mat) in tests.items():
|
||||
|
@ -650,7 +634,6 @@ class MatrixSetDiagTest(xla_test.XLATestCase):
|
|||
}, solution)
|
||||
|
||||
def testBatch(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for _, tests in all_tests(align):
|
||||
for diag_index, (vecs, banded_mat) in tests.items():
|
||||
|
@ -705,7 +688,6 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
|
|||
|
||||
# From here onwards are v2-only tests.
|
||||
def testSingleMatrix(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
|
||||
for mat, tests in test_list:
|
||||
|
@ -718,7 +700,6 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
|
|||
}, solution[0])
|
||||
|
||||
def testBatch(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for align in alignment_list:
|
||||
for mat, tests in all_tests(align):
|
||||
for diag_index, (solution, _) in tests.items():
|
||||
|
@ -730,7 +711,6 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
|
|||
}, solution)
|
||||
|
||||
def testPadding(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for padding_value, align in zip_to_first_list_length([555, -11],
|
||||
alignment_list):
|
||||
for mat, tests in all_tests(align):
|
||||
|
|
|
@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
|
|||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import tensorflow_server_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import distribute
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
|
@ -90,7 +89,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testExternalStatePolicyIgnore(self):
|
||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
|
@ -122,7 +120,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testExternalStatePolicyWarn(self):
|
||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
|
@ -154,7 +151,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
@combinations.generate(
|
||||
combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
|
||||
def testExternalStatePolicyFail(self):
|
||||
with compat.forward_compatibility_horizon(2019, 11, 30):
|
||||
with ops.device(self._device0):
|
||||
dataset0 = dataset_ops.Dataset.range(100).map(
|
||||
lambda _: random_ops.random_uniform( # pylint:disable=g-long-lambda
|
||||
|
|
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
|
@ -125,8 +124,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
|
|||
self._scan_func = wrapped_func
|
||||
self._scan_func.function.add_to_graph(ops.get_default_graph())
|
||||
# pylint: disable=protected-access
|
||||
if compat.forward_compatible(2019, 10,
|
||||
15) or use_default_device is not None:
|
||||
if use_default_device is not None:
|
||||
variant_tensor = gen_experimental_dataset_ops.scan_dataset(
|
||||
self._input_dataset._variant_tensor,
|
||||
structure.to_tensor_list(self._state_structure, self._initial_state),
|
||||
|
|
|
@ -25,7 +25,6 @@ import numpy as np
|
|||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat as forward_compat
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
|
@ -464,7 +463,6 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testIteratorStringHandleFuture(self):
|
||||
with forward_compat.forward_compatibility_horizon(2018, 8, 4):
|
||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
|||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.experimental.ops import optimization_options
|
||||
from tensorflow.python.data.experimental.ops import stats_options
|
||||
|
@ -223,7 +222,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||
A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
|
||||
serialized graph.
|
||||
"""
|
||||
if compat.forward_compatible(2019, 11, 25) or external_state_policy:
|
||||
if external_state_policy:
|
||||
policy = None
|
||||
if external_state_policy:
|
||||
policy = external_state_policy.value
|
||||
|
@ -231,7 +230,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||
self._variant_tensor,
|
||||
external_state_policy=policy,
|
||||
strip_device_assignment=strip_device_assignment)
|
||||
if compat.forward_compatible(2019, 11, 16) or strip_device_assignment:
|
||||
if strip_device_assignment:
|
||||
return gen_dataset_ops.dataset_to_graph(
|
||||
self._variant_tensor,
|
||||
allow_stateful=allow_stateful,
|
||||
|
|
|
@ -28,7 +28,6 @@ from tensorflow.core.protobuf import config_pb2
|
|||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -395,7 +394,6 @@ class AutoMixedPrecisionTest(test.TestCase):
|
|||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_bn(self):
|
||||
"""Test graph with convolution followed by batch norm."""
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
|
@ -468,7 +466,6 @@ class AutoMixedPrecisionTest(test.TestCase):
|
|||
@test_util.disable_xla('This test does not pass with XLA')
|
||||
def test_conv_bn_dropout(self):
|
||||
"""Test dropout precision of convolution batch norm graph."""
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
|
|
|
@ -22,7 +22,6 @@ import numpy as np
|
|||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
@ -143,7 +142,6 @@ def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b,
|
|||
use_static_shape):
|
||||
|
||||
def Test(self):
|
||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
||||
np.random.seed(42)
|
||||
self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
|
||||
|
||||
|
@ -200,7 +198,6 @@ def _GetBatchMatmulGradientWithBroadcastingTest(dtype, adjoint_a, adjoint_b):
|
|||
def CheckGradients(self, a_shape, b_shape):
|
||||
self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)
|
||||
|
||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
||||
CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
|
||||
CheckGradients(self, [2, 3], [1, 3, 5])
|
||||
CheckGradients(self, [2, 3], [5, 3, 5])
|
||||
|
@ -231,7 +228,6 @@ class BatchMatMulBenchmark(test.Benchmark):
|
|||
|
||||
def benchmarkBatchMatMulBroadcast(self):
|
||||
for (a_shape, b_shape) in self.shape_pairs:
|
||||
with compat.forward_compatibility_horizon(2019, 4, 26):
|
||||
with ops.Graph().as_default(), \
|
||||
session.Session(config=benchmark.benchmark_config()) as sess, \
|
||||
ops.device("/cpu:0"):
|
||||
|
|
|
@ -21,7 +21,6 @@ import itertools
|
|||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||
from tensorflow.python.framework import ops
|
||||
|
@ -33,15 +32,6 @@ from tensorflow.python.platform import test
|
|||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
|
||||
# LINT.IfChange
|
||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
||||
# LINT.ThenChange(
|
||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
||||
# //tensorflow/python/ops/array_ops.py,
|
||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
||||
# )
|
||||
|
||||
|
||||
default_v2_alignment = "LEFT_LEFT"
|
||||
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]
|
||||
|
||||
|
@ -391,7 +381,6 @@ class MatrixDiagTest(test.TestCase):
|
|||
self.assertEqual((3, 3), v_diag.get_shape())
|
||||
self.assertAllEqual(v_diag.eval(), mat)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# {Sub,Super}diagonals.
|
||||
for offset in [1, -2, 5]:
|
||||
mat = np.diag(v, offset)
|
||||
|
@ -417,7 +406,6 @@ class MatrixDiagTest(test.TestCase):
|
|||
self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
|
||||
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# {Sub,Super}diagonals.
|
||||
for offset in [1, -2, 5]:
|
||||
v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
|
||||
|
@ -453,7 +441,6 @@ class MatrixDiagTest(test.TestCase):
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRectangularBatch(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
with self.cached_session(use_gpu=True):
|
||||
# Stores expected num_rows and num_cols (when the other is given).
|
||||
# expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
|
||||
|
@ -574,7 +561,6 @@ class MatrixDiagTest(test.TestCase):
|
|||
y.get_shape().as_list())
|
||||
self.assertLess(error, 1e-4)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# {Sub,super}diagonals/band.
|
||||
tests = dict() # tests[shape] = (d_lower, d_upper)
|
||||
tests[(3,)] = (-1, -1)
|
||||
|
@ -604,7 +590,6 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((3, 3), output.get_shape())
|
||||
self.assertAllEqual(mat_set_diag, self.evaluate(output))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
_, tests = square_cases(align)
|
||||
|
@ -634,7 +619,6 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((3, 2), output.get_shape())
|
||||
self.assertAllEqual(expected, self.evaluate(output))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||
|
@ -663,7 +647,6 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((2, 3, 3), output.get_shape())
|
||||
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
_, tests = square_cases(align)
|
||||
|
@ -697,7 +680,6 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((2, 2, 3), output.get_shape())
|
||||
self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for _, tests in [tall_cases(align), fat_cases(align)]:
|
||||
|
@ -727,7 +709,6 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
with self.assertRaisesOpError("diagonal must be at least 1-dim"):
|
||||
array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
d = array_ops.placeholder(dtype=dtypes_lib.float32)
|
||||
with self.assertRaisesOpError(
|
||||
"first dimensions of diagonal don't match"):
|
||||
|
@ -743,10 +724,7 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
x_diag = constant_op.constant(
|
||||
np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
|
||||
else:
|
||||
y = array_ops.matrix_set_diag(x, x_diag)
|
||||
error_x = gradient_checker.compute_gradient_error(x,
|
||||
x.get_shape().as_list(),
|
||||
y,
|
||||
|
@ -763,7 +741,6 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
input_shapes = [(3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)]
|
||||
diag_bands = [(0, 0)]
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
diag_bands.append((-1, 1))
|
||||
for input_shape, diags, align in itertools.product(input_shapes, diag_bands,
|
||||
alignment_list):
|
||||
|
@ -805,7 +782,6 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
self.assertEqual((3,), mat_diag.get_shape())
|
||||
self.assertAllEqual(mat_diag.eval(), v)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
for offset in [-2, 3]:
|
||||
mat = np.diag(v, offset)
|
||||
mat_diag = array_ops.matrix_diag_part(mat, k=offset)
|
||||
|
@ -831,7 +807,6 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
mat_diag = array_ops.matrix_diag_part(mat)
|
||||
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands.
|
||||
for align in alignment_list:
|
||||
for mat, tests in [tall_cases(align), fat_cases(align)]:
|
||||
|
@ -853,7 +828,6 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
self.assertEqual((2, 3), mat_batch_diag.get_shape())
|
||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands with padding_value.
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
|
@ -889,7 +863,6 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
self.assertEqual((2, 2), mat_batch_diag.get_shape())
|
||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Diagonal bands with padding_value and align.
|
||||
for padding_value, align in zip_to_first_list_length([0, 555, -11],
|
||||
alignment_list):
|
||||
|
@ -905,7 +878,6 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testUnknownShape(self):
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
|
||||
result = array_ops.matrix_diag_part(matrix, k=-1)
|
||||
input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
|
||||
|
@ -939,7 +911,6 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
y.get_shape().as_list())
|
||||
self.assertLess(error, 1e-4)
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# {Sub,super}diagonals/band.
|
||||
tests = dict() # tests[shape] = (d_lower, d_upper)
|
||||
tests[(3, 3)] = (-1, -1)
|
||||
|
|
|
@ -22,7 +22,6 @@ import numpy as np
|
|||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -122,7 +121,6 @@ class ReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
|
||||
print("relu (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
# The gradient for fp16 is inaccurate due to the low-precision.
|
||||
|
@ -171,7 +169,6 @@ class ReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
|
||||
print("relu (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
def testGradGradFloat32(self):
|
||||
|
@ -190,7 +187,6 @@ class ReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("relu (float32) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradGradFloat64(self):
|
||||
|
@ -209,7 +205,6 @@ class ReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("relu (float64) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
def testGradientScalar(self):
|
||||
|
@ -283,7 +278,6 @@ class Relu6Test(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
|
||||
print("relu6 (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradientFloat64(self):
|
||||
|
@ -294,7 +288,6 @@ class Relu6Test(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
|
||||
print("relu6 (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
|
||||
|
@ -345,7 +338,6 @@ class LeakyReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
|
||||
print("leaky_relu (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradientFloat64(self):
|
||||
|
@ -356,11 +348,9 @@ class LeakyReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
|
||||
print("leaky_relu (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
def testGradGradFloat32(self):
|
||||
with compat.forward_compatibility_horizon(2018, 11, 2):
|
||||
with self.cached_session():
|
||||
|
||||
def f(x):
|
||||
|
@ -376,11 +366,9 @@ class LeakyReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("leaky_relu (float32) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradGradFloat64(self):
|
||||
with compat.forward_compatibility_horizon(2018, 11, 2):
|
||||
with self.cached_session():
|
||||
|
||||
def f(x):
|
||||
|
@ -396,7 +384,6 @@ class LeakyReluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("leaky_relu (float64) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-10)
|
||||
|
||||
def testGradientScalar(self):
|
||||
|
@ -463,7 +450,6 @@ class EluTest(test.TestCase):
|
|||
x = np.asarray(x_val, dtype=np.float32, order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
|
||||
print("elu (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradientFloat64(self):
|
||||
|
@ -472,7 +458,6 @@ class EluTest(test.TestCase):
|
|||
x = np.asarray(x_val, dtype=np.float64, order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
|
||||
print("elu (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-6)
|
||||
|
||||
def testGradGrad(self):
|
||||
|
@ -507,7 +492,6 @@ class EluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("elu (float32) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradGradFloat64(self):
|
||||
|
@ -526,7 +510,6 @@ class EluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("elu (float64) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-6)
|
||||
|
||||
|
||||
|
@ -567,7 +550,6 @@ class SeluTest(test.TestCase):
|
|||
x = np.asarray(x_val, dtype=np.float32, order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
|
||||
print("selu (float32) gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradientFloat64(self):
|
||||
|
@ -576,7 +558,6 @@ class SeluTest(test.TestCase):
|
|||
x = np.asarray(x_val, dtype=np.float64, order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
|
||||
print("selu (float64) gradient err = ", err)
|
||||
self.assertLess(err, 1e-6)
|
||||
|
||||
def testGradGradFloat32(self):
|
||||
|
@ -595,7 +576,6 @@ class SeluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("selu (float32) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-4)
|
||||
|
||||
def testGradGradFloat64(self):
|
||||
|
@ -614,7 +594,6 @@ class SeluTest(test.TestCase):
|
|||
order="F")
|
||||
err = gradient_checker_v2.max_error(
|
||||
*gradient_checker_v2.compute_gradient(f, [x]))
|
||||
print("selu (float64) gradient of gradient err = ", err)
|
||||
self.assertLess(err, 1e-6)
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
|
@ -54,13 +53,6 @@ tf_export("newaxis").export_constant(__name__, "newaxis")
|
|||
# existing 'slice' for later use in this module.
|
||||
_BaseSlice = slice
|
||||
|
||||
# LINT.IfChange
|
||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
||||
# LINT.ThenChange(
|
||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
||||
# //tensorflow/python/ops/parallel_for/array_test.py
|
||||
# )
|
||||
|
||||
@tf_export("reshape", v1=["reshape", "manip.reshape"])
|
||||
def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
|
||||
|
@ -2362,7 +2354,6 @@ def matrix_diag(diagonal,
|
|||
Returns:
|
||||
A Tensor. Has the same type as `diagonal`.
|
||||
"""
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Special case to sidestep the tf.constant conversion error:
|
||||
# TypeError: Expected bool, got 0 of type 'int' instead.
|
||||
if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
|
||||
|
@ -2377,10 +2368,6 @@ def matrix_diag(diagonal,
|
|||
align=align,
|
||||
name=name)
|
||||
|
||||
# Call v1 to maintain forward compatibility.
|
||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
||||
return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)
|
||||
|
||||
|
||||
@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
|
||||
@deprecation.deprecated_endpoints("matrix_diag_part")
|
||||
|
@ -2513,7 +2500,6 @@ def matrix_diag_part(
|
|||
Returns:
|
||||
A Tensor containing diagonals of `input`. Has the same type as `input`.
|
||||
"""
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
# Special case to sidestep the tf.constant conversion error:
|
||||
# TypeError: Expected bool, got 0 of type 'int' instead.
|
||||
if hasattr(input, "dtype") and input.dtype == "bool":
|
||||
|
@ -2522,10 +2508,6 @@ def matrix_diag_part(
|
|||
return gen_array_ops.matrix_diag_part_v3(
|
||||
input=input, k=k, padding_value=padding_value, align=align, name=name)
|
||||
|
||||
# Call v1 to maintain forward compatibility.
|
||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
||||
return gen_array_ops.matrix_diag_part(input=input, name=name)
|
||||
|
||||
|
||||
@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
|
||||
@deprecation.deprecated_endpoints("matrix_set_diag")
|
||||
|
@ -2659,15 +2641,9 @@ def matrix_set_diag(
|
|||
the left (right-pads the row). It is the packing format LAPACK uses.
|
||||
cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
|
||||
"""
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
return gen_array_ops.matrix_set_diag_v3(
|
||||
input=input, diagonal=diagonal, k=k, align=align, name=name)
|
||||
|
||||
# Call v1 to maintain forward compatibility.
|
||||
# (We skip v2 because its alignment conflicts with v3's default alignment.)
|
||||
return gen_array_ops.matrix_set_diag(
|
||||
input=input, diagonal=diagonal, name=name)
|
||||
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
|
@ -4921,7 +4897,7 @@ def quantize_v2(
|
|||
raise ValueError("input should have known rank to use negative axis.")
|
||||
axis %= input.shape.ndims
|
||||
|
||||
if compat.forward_compatible(2019, 11, 13) or ensure_minimum_range != 0.01:
|
||||
if ensure_minimum_range != 0.01:
|
||||
return gen_array_ops.quantize_v2(
|
||||
input,
|
||||
min_range,
|
||||
|
@ -4965,7 +4941,7 @@ def quantize(
|
|||
axis=None,
|
||||
ensure_minimum_range=0.01):
|
||||
"""Quantize the input tensor."""
|
||||
if compat.forward_compatible(2019, 11, 13) or ensure_minimum_range != 0.01:
|
||||
if ensure_minimum_range != 0.01:
|
||||
return quantize_v2(
|
||||
input,
|
||||
min_range,
|
||||
|
@ -5007,7 +4983,7 @@ def dequantize( # pylint: disable=missing-docstring
|
|||
raise ValueError("input should have known rank to use negative axis.")
|
||||
axis %= input.shape.ndims
|
||||
|
||||
if compat.forward_compatible(2019, 10, 22) or axis >= 0 or narrow_range:
|
||||
if axis >= 0 or narrow_range:
|
||||
return gen_array_ops.dequantize(
|
||||
input, min_range, max_range, mode=mode, name=name,
|
||||
narrow_range=narrow_range, axis=axis)
|
||||
|
|
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
|
@ -383,7 +382,6 @@ class BatchNormalizationTest(test.TestCase):
|
|||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
|
||||
def testInferenceShape5(self):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
|
@ -450,7 +448,6 @@ class BatchNormalizationTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
|
||||
def testTrainingShape5(self):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
|
@ -586,7 +583,6 @@ class BatchNormalizationTest(test.TestCase):
|
|||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test never passed for XLA')
|
||||
def testBatchNormGradShape5(self):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [0, 7, 11, 4]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
|
|
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -33,15 +32,6 @@ from tensorflow.python.ops.parallel_for.test_util import PForTestCase
|
|||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# LINT.IfChange
|
||||
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
|
||||
# LINT.ThenChange(
|
||||
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
|
||||
# //tensorflow/python/kernel_tests/diag_op_test.py,
|
||||
# //tensorflow/python/ops/array_ops.py
|
||||
# )
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ArrayTest(PForTestCase):
|
||||
|
||||
|
@ -345,10 +335,8 @@ class ArrayTest(PForTestCase):
|
|||
|
||||
def loop_fn(i):
|
||||
diagonal = array_ops.gather(x, i)
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
return array_ops.matrix_diag(
|
||||
diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
|
||||
return array_ops.matrix_diag(diagonal)
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
|
@ -357,10 +345,8 @@ class ArrayTest(PForTestCase):
|
|||
|
||||
def loop_fn(i):
|
||||
input = array_ops.gather(x, i) # pylint: disable=redefined-builtin
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
return array_ops.matrix_diag_part(
|
||||
input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
|
||||
return array_ops.matrix_diag_part(input)
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
|
@ -378,7 +364,6 @@ class ArrayTest(PForTestCase):
|
|||
array_ops.matrix_set_diag(matrix_i, diags[0, ...]),
|
||||
]
|
||||
|
||||
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
|
||||
k = (-1, 1)
|
||||
band_i = array_ops.gather(bands, i)
|
||||
for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
|
||||
|
|
|
@ -28,7 +28,6 @@ import numpy as np
|
|||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
|
@ -444,7 +443,6 @@ class NNTest(PForTestCase):
|
|||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_fused_batch_norm(self):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 7):
|
||||
data_formats = ["NHWC"]
|
||||
if test.is_gpu_available():
|
||||
data_formats.append("NCHW")
|
||||
|
|
|
@ -33,7 +33,6 @@ import six
|
|||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
|
||||
from tensorflow.python.compat import compat as fwd_compat
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
@ -253,10 +252,7 @@ def einsum(equation, *inputs, **kwargs):
|
|||
- the format of `equation` is incorrect,
|
||||
- number of inputs or their shapes are inconsistent with `equation`.
|
||||
"""
|
||||
if fwd_compat.forward_compatible(2019, 10, 18):
|
||||
return _einsum_v2(equation, *inputs, **kwargs)
|
||||
else:
|
||||
return _einsum_v1(equation, *inputs, **kwargs)
|
||||
|
||||
|
||||
def _einsum_v1(equation, *inputs, **kwargs):
|
||||
|
|
|
@ -23,7 +23,6 @@ import opt_einsum
|
|||
import six
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
@ -243,7 +242,6 @@ class EinsumTest(test.TestCase):
|
|||
self._check('abc->ca', (3, 4, 5))
|
||||
self._check('abc->cab', (3, 4, 5))
|
||||
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Empty cases.
|
||||
self._check('', ())
|
||||
self._check('->', ())
|
||||
|
@ -266,7 +264,6 @@ class EinsumTest(test.TestCase):
|
|||
self._check('...ij->...ji', (5, 2, 3)) # batch matrix transpose
|
||||
self._check('...ij->...', (5, 2, 3)) # batch sum
|
||||
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check('...->...', ())
|
||||
self._check('->...', ())
|
||||
|
||||
|
@ -301,7 +298,6 @@ class EinsumTest(test.TestCase):
|
|||
self._check('ab,ab->', (3, 4), (3, 4))
|
||||
|
||||
def test_repeated_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Repeated indices.
|
||||
self._check('ijj,k->ik', (2, 3, 3), (4,))
|
||||
self._check('aba,a->b', (3, 4, 3), (3,))
|
||||
|
@ -324,7 +320,6 @@ class EinsumTest(test.TestCase):
|
|||
self._check('...i,...j->...ij', (5, 2), (5, 3)) # outer product
|
||||
|
||||
def test_broadcasting(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Batch matmul with broadcasting.
|
||||
self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
|
||||
self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
|
||||
|
@ -388,12 +383,10 @@ class EinsumTest(test.TestCase):
|
|||
((4, 3), (None, 3)))
|
||||
|
||||
# Ellipsis with unknown rank.
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
|
||||
check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
|
||||
|
||||
def test_numpy_input(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# In addition to Tensors, we also support raw numpy arrays as inputs.
|
||||
r = np.random.RandomState(0)
|
||||
s = 'ijk,ijl,ikl->i'
|
||||
|
@ -464,7 +457,6 @@ class EinsumTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_long_cases_with_repeated_labels(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
cases = [
|
||||
# Tests from dask.
|
||||
'fdf,cdd,ccd,afe->ae',
|
||||
|
@ -481,7 +473,6 @@ class EinsumTest(test.TestCase):
|
|||
@test_util.disable_xla('b/131919749')
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_invalid_equation(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
r = np.random.RandomState(0)
|
||||
cases = [
|
||||
# invalid equation format.
|
||||
|
@ -535,7 +526,6 @@ class EinsumTest(test.TestCase):
|
|||
# From transformer xl.
|
||||
check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))
|
||||
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Generalized traces with zero-sized dimensions.
|
||||
check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
|
||||
check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))
|
||||
|
@ -556,7 +546,6 @@ class EinsumGradTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_unary(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('->', ())
|
||||
self._check_gradient('aaa->a', (3, 3, 3))
|
||||
self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
|
||||
|
@ -564,7 +553,6 @@ class EinsumGradTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_unary_ellipsis(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('...->...', ())
|
||||
self._check_gradient('...->', ())
|
||||
self._check_gradient('->...', ())
|
||||
|
@ -582,7 +570,6 @@ class EinsumGradTest(test.TestCase):
|
|||
self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))
|
||||
|
||||
def test_binary_simple(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Binary cases in XLA mode must have either (a) each index appearing
|
||||
# exactly once in both the inputs (batch or contraction index), or
|
||||
# (b) appearing exactly once in an input and in the output (free index).
|
||||
|
@ -598,20 +585,17 @@ class EinsumGradTest(test.TestCase):
|
|||
self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))
|
||||
|
||||
def test_empty(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# From Transformer XL.
|
||||
self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_reduced_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('ba,b->', (3, 2), (3,))
|
||||
self._check_gradient('ab,ab->', (3, 4), (3, 4))
|
||||
self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_repeated_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
# Repeated indices.
|
||||
self._check_gradient('aba,a->b', (3, 4, 3), (3,))
|
||||
self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
|
||||
|
@ -622,14 +606,12 @@ class EinsumGradTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_empty_with_repeated_indices(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
|
||||
self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
|
||||
self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_broadcasting(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
|
||||
self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
|
||||
self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
|
||||
|
@ -641,7 +623,6 @@ class EinsumGradTest(test.TestCase):
|
|||
self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))
|
||||
|
||||
def test_long_cases(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
cases = [
|
||||
'abhe,hidj,jgba,hiab,gab->ed',
|
||||
# Tests from dask.
|
||||
|
@ -658,7 +639,6 @@ class EinsumGradTest(test.TestCase):
|
|||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def test_long_cases_with_repeated_labels(self):
|
||||
with compat.forward_compatibility_horizon(2019, 10, 19):
|
||||
cases = [
|
||||
# Tests from dask.
|
||||
'fdf,cdd,ccd,afe->ae',
|
||||
|
|
Loading…
Reference in New Issue