Remove forward_compatibility / forward_compatible checks for dates that have already passed.

- Also removed print statements from relu_op_test.py.

PiperOrigin-RevId: 287911742
Change-Id: Ib1763a5a010e5738e4d93e348391839e1e164108
Srinivas Vasudevan authored on 2020-01-02 16:10:50 -08:00; committed by TensorFlower Gardener
parent 4b7f4c1f09
commit f75c37faf3
15 changed files with 887 additions and 1042 deletions
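
For context, the mechanism behind these checks: tensorflow/python/compat/compat.py exposes forward_compatible(year, month, day), which returns True once the binary's forward-compatibility horizon (three weeks past the build date by default) is later than the given date, and forward_compatibility_horizon(...), a context manager that tests use to pin the horizon. Once a gate date has passed for every supported build, the guard is unconditionally true, so the guarded branch can be inlined and the fallback deleted — which is all this commit does. A minimal, self-contained sketch of the pattern (new_v3_op and legacy_op are hypothetical stand-ins, not ops from this commit):

# Sketch of the gating pattern deleted by this commit.
# `new_v3_op` and `legacy_op` are hypothetical stand-ins.
from tensorflow.python.compat import compat


def new_v3_op(x):  # hypothetical: kernel guarded by the compat window
  return x + 1


def legacy_op(x):  # hypothetical: pre-window fallback
  return x


def apply_op(x):
  # True once the horizon (build date + 3 weeks by default) is past the
  # given date; after that the guard is dead code and can be removed.
  if compat.forward_compatible(2019, 12, 6):
    return new_v3_op(x)
  return legacy_op(x)


# Tests pinned the horizon explicitly instead of waiting for the date:
with compat.forward_compatibility_horizon(2019, 12, 7):
  assert compat.forward_compatible(2019, 12, 6)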

View File

@@ -21,19 +21,10 @@ from __future__ import print_function

 import numpy as np

 from tensorflow.compiler.tests import xla_test
-from tensorflow.python.compat import compat
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import googletest

-# LINT.IfChange
-matrix_diag_v3_forward_compat_date = (2019, 12, 6)
-# LINT.ThenChange(
-#   //tensorflow/python/kernel_tests/diag_op_test.py,
-#   //tensorflow/python/ops/array_ops.py,
-#   //tensorflow/python/ops/parallel_for/array_test.py
-# )
-
 default_v2_alignment = "LEFT_LEFT"
 alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"]
@@ -404,25 +395,20 @@ class MatrixDiagTest(xla_test.XLATestCase):
   # From here onwards are v2-only tests.
   def testSquare(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      for align in alignment_list:
-        for _, tests in [square_cases(align)]:
-          for diag_index, (vecs, solution) in tests.items():
-            params = {"diagonal": vecs[0], "k": diag_index, "align": align}
-            self._assertOpOutputMatchesExpected(params, solution[0])
+    for align in alignment_list:
+      for _, tests in [square_cases(align)]:
+        for diag_index, (vecs, solution) in tests.items():
+          params = {"diagonal": vecs[0], "k": diag_index, "align": align}
+          self._assertOpOutputMatchesExpected(params, solution[0])

   def testSquareBatch(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      for align in alignment_list:
-        for _, tests in [square_cases(align)]:
-          for diag_index, (vecs, solution) in tests.items():
-            params = {"diagonal": vecs, "k": diag_index, "align": align}
-            self._assertOpOutputMatchesExpected(params, solution)
+    for align in alignment_list:
+      for _, tests in [square_cases(align)]:
+        for diag_index, (vecs, solution) in tests.items():
+          params = {"diagonal": vecs, "k": diag_index, "align": align}
+          self._assertOpOutputMatchesExpected(params, solution)

   def testRectangularBatch(self):
-    if not compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      return
     # Stores expected num_rows and num_cols (when the other is given).
     # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
     test_list = list()
@@ -513,22 +499,21 @@ class MatrixDiagTest(xla_test.XLATestCase):
           }, solution_given_num_cols)

   def testPadding(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      for padding_value, align in zip_to_first_list_length([555, -11],
-                                                           alignment_list):
-        for _, tests in all_tests(align):
-          for diag_index, (vecs, solution) in tests.items():
-            mask = (solution == 0)
-            solution = solution + (mask * padding_value)
-            self._assertOpOutputMatchesExpected(
-                {
-                    "diagonal": vecs,
-                    "k": diag_index,
-                    "num_rows": solution.shape[-2],
-                    "num_cols": solution.shape[-1],
-                    "padding_value": padding_value,
-                    "align": align
-                }, solution)
+    for padding_value, align in zip_to_first_list_length([555, -11],
+                                                         alignment_list):
+      for _, tests in all_tests(align):
+        for diag_index, (vecs, solution) in tests.items():
+          mask = (solution == 0)
+          solution = solution + (mask * padding_value)
+          self._assertOpOutputMatchesExpected(
+              {
+                  "diagonal": vecs,
+                  "k": diag_index,
+                  "num_rows": solution.shape[-2],
+                  "num_cols": solution.shape[-1],
+                  "padding_value": padding_value,
+                  "align": align
+              }, solution)


 class MatrixSetDiagTest(xla_test.XLATestCase):
@@ -634,36 +619,34 @@ class MatrixSetDiagTest(xla_test.XLATestCase):
   # From here onwards are v2-only tests.
   def testSingleMatrix(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      for align in alignment_list:
-        for _, tests in all_tests(align):
-          for diag_index, (vecs, banded_mat) in tests.items():
-            mask = (banded_mat[0] == 0)
-            input_mat = np.random.randint(10, size=mask.shape)
-            solution = input_mat * mask + banded_mat[0]
-            self._assertOpOutputMatchesExpected(
-                {
-                    "input": input_mat,
-                    "diagonal": vecs[0],
-                    "k": diag_index,
-                    "align": align
-                }, solution)
+    for align in alignment_list:
+      for _, tests in all_tests(align):
+        for diag_index, (vecs, banded_mat) in tests.items():
+          mask = (banded_mat[0] == 0)
+          input_mat = np.random.randint(10, size=mask.shape)
+          solution = input_mat * mask + banded_mat[0]
+          self._assertOpOutputMatchesExpected(
+              {
+                  "input": input_mat,
+                  "diagonal": vecs[0],
+                  "k": diag_index,
+                  "align": align
+              }, solution)

   def testBatch(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      for align in alignment_list:
-        for _, tests in all_tests(align):
-          for diag_index, (vecs, banded_mat) in tests.items():
-            mask = (banded_mat == 0)
-            input_mat = np.random.randint(10, size=mask.shape)
-            solution = input_mat * mask + banded_mat
-            self._assertOpOutputMatchesExpected(
-                {
-                    "input": input_mat,
-                    "diagonal": vecs,
-                    "k": diag_index,
-                    "align": align
-                }, solution)
+    for align in alignment_list:
+      for _, tests in all_tests(align):
+        for diag_index, (vecs, banded_mat) in tests.items():
+          mask = (banded_mat == 0)
+          input_mat = np.random.randint(10, size=mask.shape)
+          solution = input_mat * mask + banded_mat
+          self._assertOpOutputMatchesExpected(
+              {
+                  "input": input_mat,
+                  "diagonal": vecs,
+                  "k": diag_index,
+                  "align": align
+              }, solution)


 class MatrixDiagPartTest(xla_test.XLATestCase):
@@ -705,45 +688,42 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
   # From here onwards are v2-only tests.
   def testSingleMatrix(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      for align in alignment_list:
-        test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
-        for mat, tests in test_list:
-          for diag_index, (solution, _) in tests.items():
-            self._assertOpOutputMatchesExpected(
-                {
-                    "input": mat[0],
-                    "k": diag_index,
-                    "align": align
-                }, solution[0])
+    for align in alignment_list:
+      test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
+      for mat, tests in test_list:
+        for diag_index, (solution, _) in tests.items():
+          self._assertOpOutputMatchesExpected(
+              {
+                  "input": mat[0],
+                  "k": diag_index,
+                  "align": align
+              }, solution[0])

   def testBatch(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      for align in alignment_list:
-        for mat, tests in all_tests(align):
-          for diag_index, (solution, _) in tests.items():
-            self._assertOpOutputMatchesExpected(
-                {
-                    "input": mat,
-                    "k": diag_index,
-                    "align": align
-                }, solution)
+    for align in alignment_list:
+      for mat, tests in all_tests(align):
+        for diag_index, (solution, _) in tests.items():
+          self._assertOpOutputMatchesExpected(
+              {
+                  "input": mat,
+                  "k": diag_index,
+                  "align": align
+              }, solution)

   def testPadding(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      for padding_value, align in zip_to_first_list_length([555, -11],
-                                                           alignment_list):
-        for mat, tests in all_tests(align):
-          for diag_index, (solution, _) in tests.items():
-            mask = (solution == 0)
-            solution = solution + (mask * padding_value)
-            self._assertOpOutputMatchesExpected(
-                {
-                    "input": mat,
-                    "k": diag_index,
-                    "padding_value": padding_value,
-                    "align": align
-                }, solution)
+    for padding_value, align in zip_to_first_list_length([555, -11],
+                                                         alignment_list):
+      for mat, tests in all_tests(align):
+        for diag_index, (solution, _) in tests.items():
+          mask = (solution == 0)
+          solution = solution + (mask * padding_value)
+          self._assertOpOutputMatchesExpected(
+              {
+                  "input": mat,
+                  "k": diag_index,
+                  "padding_value": padding_value,
+                  "align": align
+              }, solution)


 if __name__ == "__main__":

View File

@@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import tensorflow_server_pb2
 from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.kernel_tests import test_base
@@ -90,50 +89,80 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
   def testExternalStatePolicyIgnore(self):
-    with compat.forward_compatibility_horizon(2019, 11, 30):
-      with ops.device(self._device0):
-        dataset0 = dataset_ops.Dataset.range(100).map(
-            lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
-                [],
-                minval=1,
-                maxval=10,
-                dtype=dtypes.float32))
-        opt = dataset_ops.Options()
-        opt.experimental_external_state_policy = (
-            distribute_options.ExternalStatePolicy.IGNORE)
-        dataset0 = dataset0.with_options(opt)
-      replicated_ds = distribute.replicate(dataset0,
-                                           [self._device1, self._device2])
-      dataset1 = replicated_ds[self._device1]
-      dataset2 = replicated_ds[self._device2]
-      with ops.device(self._device0):
-        get_next0 = self.getNext(dataset0)
-      with ops.device(self._device1):
-        get_next1 = self.getNext(dataset1)
-      with ops.device(self._device2):
-        get_next2 = self.getNext(dataset2)
-      for _ in range(100):
-        self.evaluate(get_next0())
-        self.evaluate(get_next1())
-        self.evaluate(get_next2())
+    with ops.device(self._device0):
+      dataset0 = dataset_ops.Dataset.range(100).map(
+          lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
+              [],
+              minval=1,
+              maxval=10,
+              dtype=dtypes.float32))
+      opt = dataset_ops.Options()
+      opt.experimental_external_state_policy = (
+          distribute_options.ExternalStatePolicy.IGNORE)
+      dataset0 = dataset0.with_options(opt)
+    replicated_ds = distribute.replicate(dataset0,
+                                         [self._device1, self._device2])
+    dataset1 = replicated_ds[self._device1]
+    dataset2 = replicated_ds[self._device2]
+    with ops.device(self._device0):
+      get_next0 = self.getNext(dataset0)
+    with ops.device(self._device1):
+      get_next1 = self.getNext(dataset1)
+    with ops.device(self._device2):
+      get_next2 = self.getNext(dataset2)
+    for _ in range(100):
+      self.evaluate(get_next0())
+      self.evaluate(get_next1())
+      self.evaluate(get_next2())

   @combinations.generate(
       combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
   def testExternalStatePolicyWarn(self):
-    with compat.forward_compatibility_horizon(2019, 11, 30):
-      with ops.device(self._device0):
-        dataset0 = dataset_ops.Dataset.range(100).map(
-            lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
-                [],
-                minval=1,
-                maxval=10,
-                dtype=dtypes.float32))
-        opt = dataset_ops.Options()
-        opt.experimental_external_state_policy = (
-            distribute_options.ExternalStatePolicy.WARN)
-        dataset0 = dataset0.with_options(opt)
+    with ops.device(self._device0):
+      dataset0 = dataset_ops.Dataset.range(100).map(
+          lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
+              [],
+              minval=1,
+              maxval=10,
+              dtype=dtypes.float32))
+      opt = dataset_ops.Options()
+      opt.experimental_external_state_policy = (
+          distribute_options.ExternalStatePolicy.WARN)
+      dataset0 = dataset0.with_options(opt)
+    replicated_ds = distribute.replicate(dataset0,
+                                         [self._device1, self._device2])
+    dataset1 = replicated_ds[self._device1]
+    dataset2 = replicated_ds[self._device2]
+    with ops.device(self._device0):
+      get_next0 = self.getNext(dataset0)
+    with ops.device(self._device1):
+      get_next1 = self.getNext(dataset1)
+    with ops.device(self._device2):
+      get_next2 = self.getNext(dataset2)
+    for _ in range(100):
+      self.evaluate(get_next0())
+      self.evaluate(get_next1())
+      self.evaluate(get_next2())
+
+  @combinations.generate(
+      combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
+  def testExternalStatePolicyFail(self):
+    with ops.device(self._device0):
+      dataset0 = dataset_ops.Dataset.range(100).map(
+          lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
+              [],
+              minval=1,
+              maxval=10,
+              dtype=dtypes.float32))
+      opt = dataset_ops.Options()
+      opt.experimental_external_state_policy = (
+          distribute_options.ExternalStatePolicy.FAIL)
+      dataset0 = dataset0.with_options(opt)
+    with self.assertRaises(errors.FailedPreconditionError):
       replicated_ds = distribute.replicate(dataset0,
                                            [self._device1, self._device2])
       dataset1 = replicated_ds[self._device1]

@@ -151,39 +180,6 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
         self.evaluate(get_next0())
         self.evaluate(get_next1())
         self.evaluate(get_next2())

-  @combinations.generate(
-      combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
-  def testExternalStatePolicyFail(self):
-    with compat.forward_compatibility_horizon(2019, 11, 30):
-      with ops.device(self._device0):
-        dataset0 = dataset_ops.Dataset.range(100).map(
-            lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
-                [],
-                minval=1,
-                maxval=10,
-                dtype=dtypes.float32))
-        opt = dataset_ops.Options()
-        opt.experimental_external_state_policy = (
-            distribute_options.ExternalStatePolicy.FAIL)
-        dataset0 = dataset0.with_options(opt)
-      with self.assertRaises(errors.FailedPreconditionError):
-        replicated_ds = distribute.replicate(dataset0,
-                                             [self._device1, self._device2])
-        dataset1 = replicated_ds[self._device1]
-        dataset2 = replicated_ds[self._device2]
-        with ops.device(self._device0):
-          get_next0 = self.getNext(dataset0)
-        with ops.device(self._device1):
-          get_next1 = self.getNext(dataset1)
-        with ops.device(self._device2):
-          get_next2 = self.getNext(dataset2)
-        for _ in range(100):
-          self.evaluate(get_next0())
-          self.evaluate(get_next1())
-          self.evaluate(get_next2())

 JOB_NAME = "remote_device"

View File

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -125,8 +124,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
     self._scan_func = wrapped_func
     self._scan_func.function.add_to_graph(ops.get_default_graph())
     # pylint: disable=protected-access
-    if compat.forward_compatible(2019, 10,
-                                 15) or use_default_device is not None:
+    if use_default_device is not None:
       variant_tensor = gen_experimental_dataset_ops.scan_dataset(
           self._input_dataset._variant_tensor,
           structure.to_tensor_list(self._state_structure, self._initial_state),

View File

@@ -25,7 +25,6 @@ import numpy as np
 from tensorflow.core.protobuf import cluster_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
-from tensorflow.python.compat import compat as forward_compat
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
@@ -464,69 +463,68 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
   @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandleFuture(self):
-    with forward_compat.forward_compatibility_horizon(2018, 8, 4):
-      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
-      dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
-
-      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
-      iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)
-
-      handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
-      feedable_iterator = iterator_ops.Iterator.from_string_handle(
-          handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
-          dataset_ops.get_legacy_output_shapes(dataset_3))
-      next_element = feedable_iterator.get_next()
-
-      self.assertTrue(
-          structure.are_compatible(
-              dataset_ops.get_structure(dataset_3),
-              dataset_ops.get_structure(feedable_iterator)))
-
-      with self.cached_session() as sess:
-        iterator_3_handle = sess.run(iterator_3.string_handle())
-        iterator_4_handle = sess.run(iterator_4.string_handle())
-
-        self.assertEqual(
-            10,
-            sess.run(
-                next_element,
-                feed_dict={handle_placeholder: iterator_4_handle}))
-        self.assertEqual(
-            1,
-            sess.run(
-                next_element,
-                feed_dict={handle_placeholder: iterator_3_handle}))
-        self.assertEqual(
-            20,
-            sess.run(
-                next_element,
-                feed_dict={handle_placeholder: iterator_4_handle}))
-        self.assertEqual(
-            2,
-            sess.run(
-                next_element,
-                feed_dict={handle_placeholder: iterator_3_handle}))
-        self.assertEqual(
-            30,
-            sess.run(
-                next_element,
-                feed_dict={handle_placeholder: iterator_4_handle}))
-        self.assertEqual(
-            3,
-            sess.run(
-                next_element,
-                feed_dict={handle_placeholder: iterator_3_handle}))
-        self.assertEqual(
-            40,
-            sess.run(
-                next_element,
-                feed_dict={handle_placeholder: iterator_4_handle}))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(
-              next_element, feed_dict={handle_placeholder: iterator_3_handle})
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(
-              next_element, feed_dict={handle_placeholder: iterator_4_handle})
+    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
+    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
+
+    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
+    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)
+
+    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
+    feedable_iterator = iterator_ops.Iterator.from_string_handle(
+        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
+        dataset_ops.get_legacy_output_shapes(dataset_3))
+    next_element = feedable_iterator.get_next()
+
+    self.assertTrue(
+        structure.are_compatible(
+            dataset_ops.get_structure(dataset_3),
+            dataset_ops.get_structure(feedable_iterator)))
+
+    with self.cached_session() as sess:
+      iterator_3_handle = sess.run(iterator_3.string_handle())
+      iterator_4_handle = sess.run(iterator_4.string_handle())
+
+      self.assertEqual(
+          10,
+          sess.run(
+              next_element,
+              feed_dict={handle_placeholder: iterator_4_handle}))
+      self.assertEqual(
+          1,
+          sess.run(
+              next_element,
+              feed_dict={handle_placeholder: iterator_3_handle}))
+      self.assertEqual(
+          20,
+          sess.run(
+              next_element,
+              feed_dict={handle_placeholder: iterator_4_handle}))
+      self.assertEqual(
+          2,
+          sess.run(
+              next_element,
+              feed_dict={handle_placeholder: iterator_3_handle}))
+      self.assertEqual(
+          30,
+          sess.run(
+              next_element,
+              feed_dict={handle_placeholder: iterator_4_handle}))
+      self.assertEqual(
+          3,
+          sess.run(
+              next_element,
+              feed_dict={handle_placeholder: iterator_3_handle}))
+      self.assertEqual(
+          40,
+          sess.run(
+              next_element,
+              feed_dict={handle_placeholder: iterator_4_handle}))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(
+            next_element, feed_dict={handle_placeholder: iterator_3_handle})
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(
+            next_element, feed_dict={handle_placeholder: iterator_4_handle})

   @combinations.generate(test_base.graph_only_combinations())
   def testIteratorStringHandleReuseTensorObject(self):

View File

@@ -30,7 +30,6 @@ from six.moves import queue as Queue  # pylint: disable=redefined-builtin
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python import tf2
-from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.experimental.ops import optimization_options
 from tensorflow.python.data.experimental.ops import stats_options
@@ -223,7 +222,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
       A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
       serialized graph.
     """
-    if compat.forward_compatible(2019, 11, 25) or external_state_policy:
+    if external_state_policy:
       policy = None
       if external_state_policy:
         policy = external_state_policy.value
@@ -231,7 +230,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
           self._variant_tensor,
           external_state_policy=policy,
           strip_device_assignment=strip_device_assignment)
-    if compat.forward_compatible(2019, 11, 16) or strip_device_assignment:
+    if strip_device_assignment:
       return gen_dataset_ops.dataset_to_graph(
           self._variant_tensor,
           allow_stateful=allow_stateful,

View File

@@ -28,7 +28,6 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import tf2
 from tensorflow.python.client import session
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -395,24 +394,23 @@ class AutoMixedPrecisionTest(test.TestCase):
   @test_util.disable_xla('This test does not pass with XLA')
   def test_conv_bn(self):
     """Test graph with convolution followed by batch norm."""
-    with compat.forward_compatibility_horizon(2019, 6, 7):
-      if test.is_gpu_available(cuda_only=True):
-        random_seed.set_random_seed(0)
-        x = _input([2, 8, 8, 1])
-        x = _conv_bn(x)
-        output = _conv_bn(x)
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = _input([2, 8, 8, 1])
+      x = _conv_bn(x)
+      output = _conv_bn(x)

-        output_val_ref, output_val, cost_graph = self._run(output)
-        node_map = _build_node_map(cost_graph.node)
-        num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
+      output_val_ref, output_val, cost_graph = self._run(output)
+      node_map = _build_node_map(cost_graph.node)
+      num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)

-        self._assert_output_fp16(node_map, 'Conv2D')
-        self._assert_output_fp16(node_map, 'FusedBatchNormV3')
-        self._assert_output_fp16(node_map, 'Conv2D_1')
-        self.assertEqual(num_to_fp16,
-                         3)  # Before Conv2D:0, Conv2D:1, Conv2D_1:1
-        self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
-        self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+      self._assert_output_fp16(node_map, 'Conv2D')
+      self._assert_output_fp16(node_map, 'FusedBatchNormV3')
+      self._assert_output_fp16(node_map, 'Conv2D_1')
+      self.assertEqual(num_to_fp16,
+                       3)  # Before Conv2D:0, Conv2D:1, Conv2D_1:1
+      self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)

   # TODO: enable these tests when cuDNN is upgraded to >= 7.6.2. Same with the
   # test_conv3d() below.
@@ -468,31 +466,30 @@ class AutoMixedPrecisionTest(test.TestCase):
   @test_util.disable_xla('This test does not pass with XLA')
   def test_conv_bn_dropout(self):
     """Test dropout precision of convolution batch norm graph."""
-    with compat.forward_compatibility_horizon(2019, 6, 7):
-      if test.is_gpu_available(cuda_only=True):
-        random_seed.set_random_seed(0)
-        x = _input([2, 8, 8, 1])
-        y = _conv_bn(x)
-        y = nn.dropout(y, rate=0.5)
-        y = math_ops.add(y, 1, name='addition')
-        y = _conv_bn(y)
-        y = array_ops.identity(y)
-        optimizer = gradient_descent.GradientDescentOptimizer(
-            learning_rate=0.01)
-        g = optimizer.compute_gradients(y, [x])
-        output = (y, g)
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = _input([2, 8, 8, 1])
+      y = _conv_bn(x)
+      y = nn.dropout(y, rate=0.5)
+      y = math_ops.add(y, 1, name='addition')
+      y = _conv_bn(y)
+      y = array_ops.identity(y)
+      optimizer = gradient_descent.GradientDescentOptimizer(
+          learning_rate=0.01)
+      g = optimizer.compute_gradients(y, [x])
+      output = (y, g)

-        output_val_ref, output_val, cost_graph = self._run(output)
-        node_map = _build_node_map(cost_graph.node)
+      output_val_ref, output_val, cost_graph = self._run(output)
+      node_map = _build_node_map(cost_graph.node)

-        self._assert_output_fp16(node_map, 'Conv2D')
-        self._assert_output_fp16(node_map, 'FusedBatchNormV3')
-        # We do not assert dropout's dtype because we do not want to rely on the
-        # node names of dropout's internal implementation.
-        self._assert_output_fp16(node_map, 'addition')
-        self._assert_output_fp16(node_map, 'Conv2D_1')
+      self._assert_output_fp16(node_map, 'Conv2D')
+      self._assert_output_fp16(node_map, 'FusedBatchNormV3')
+      # We do not assert dropout's dtype because we do not want to rely on the
+      # node names of dropout's internal implementation.
+      self._assert_output_fp16(node_map, 'addition')
+      self._assert_output_fp16(node_map, 'Conv2D_1')

-        output_val_ref, output_val, cost_graph = self._run(output)
-        self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=2e-3)
+      output_val_ref, output_val, cost_graph = self._run(output)
+      self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=2e-3)

   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
View File

@@ -22,7 +22,6 @@ import numpy as np
 from tensorflow.python import tf2
 from tensorflow.python.client import session
-from tensorflow.python.compat import compat
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
@@ -143,9 +142,8 @@ def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b,
                                       use_static_shape):

   def Test(self):
-    with compat.forward_compatibility_horizon(2019, 4, 26):
-      np.random.seed(42)
-      self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
+    np.random.seed(42)
+    self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)

   return Test
@@ -200,14 +198,13 @@ def _GetBatchMatmulGradientWithBroadcastingTest(dtype, adjoint_a, adjoint_b):
     def CheckGradients(self, a_shape, b_shape):
       self._compare(a_shape, b_shape, dtype, adjoint_a, adjoint_b)

-    with compat.forward_compatibility_horizon(2019, 4, 26):
-      CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
-      CheckGradients(self, [2, 3], [1, 3, 5])
-      CheckGradients(self, [2, 3], [5, 3, 5])
-      CheckGradients(self, [5, 2, 5], [5, 3])
-      CheckGradients(self, [5, 2, 2, 3], [3, 5])
-      CheckGradients(self, [4, 5, 1, 2, 3], [1, 1, 3, 5])
-      CheckGradients(self, [1, 2, 1, 4, 2, 1, 3, 4], [3, 2, 1, 1, 1, 2, 4, 2])
+    CheckGradients(self, [1, 5, 2, 3], [7, 1, 3, 2])
+    CheckGradients(self, [2, 3], [1, 3, 5])
+    CheckGradients(self, [2, 3], [5, 3, 5])
+    CheckGradients(self, [5, 2, 5], [5, 3])
+    CheckGradients(self, [5, 2, 2, 3], [3, 5])
+    CheckGradients(self, [4, 5, 1, 2, 3], [1, 1, 3, 5])
+    CheckGradients(self, [1, 2, 1, 4, 2, 1, 3, 4], [3, 2, 1, 1, 1, 2, 4, 2])

   return Test
@@ -231,38 +228,37 @@ class BatchMatMulBenchmark(test.Benchmark):
   def benchmarkBatchMatMulBroadcast(self):
     for (a_shape, b_shape) in self.shape_pairs:
-      with compat.forward_compatibility_horizon(2019, 4, 26):
-        with ops.Graph().as_default(), \
-            session.Session(config=benchmark.benchmark_config()) as sess, \
-            ops.device("/cpu:0"):
-          matrix_a = variables.Variable(
-              GetRandomNormalInput(a_shape, np.float32))
-          matrix_b = variables.Variable(
-              GetRandomNormalInput(b_shape, np.float32))
-          variables.global_variables_initializer().run()
+      with ops.Graph().as_default(), \
+          session.Session(config=benchmark.benchmark_config()) as sess, \
+          ops.device("/cpu:0"):
+        matrix_a = variables.Variable(
+            GetRandomNormalInput(a_shape, np.float32))
+        matrix_b = variables.Variable(
+            GetRandomNormalInput(b_shape, np.float32))
+        variables.global_variables_initializer().run()

-          # Use batch matmul op's internal broadcasting.
-          self.run_op_benchmark(
-              sess,
-              math_ops.matmul(matrix_a, matrix_b),
-              min_iters=50,
-              name="batch_matmul_cpu_{}_{}".format(a_shape, b_shape))
+        # Use batch matmul op's internal broadcasting.
+        self.run_op_benchmark(
+            sess,
+            math_ops.matmul(matrix_a, matrix_b),
+            min_iters=50,
+            name="batch_matmul_cpu_{}_{}".format(a_shape, b_shape))

-          # Manually broadcast the input matrices using the broadcast_to op.
-          broadcasted_batch_shape = array_ops.broadcast_static_shape(
-              matrix_a.shape[:-2], matrix_b.shape[:-2])
-          broadcasted_a_shape = broadcasted_batch_shape.concatenate(
-              matrix_a.shape[-2:])
-          broadcasted_b_shape = broadcasted_batch_shape.concatenate(
-              matrix_b.shape[-2:])
+        # Manually broadcast the input matrices using the broadcast_to op.
+        broadcasted_batch_shape = array_ops.broadcast_static_shape(
+            matrix_a.shape[:-2], matrix_b.shape[:-2])
+        broadcasted_a_shape = broadcasted_batch_shape.concatenate(
+            matrix_a.shape[-2:])
+        broadcasted_b_shape = broadcasted_batch_shape.concatenate(
+            matrix_b.shape[-2:])

-          self.run_op_benchmark(
-              sess,
-              math_ops.matmul(
-                  array_ops.broadcast_to(matrix_a, broadcasted_a_shape),
-                  array_ops.broadcast_to(matrix_b, broadcasted_b_shape)),
-              min_iters=50,
-              name="batch_matmul_manual_broadcast_cpu_{}_{}".format(
-                  a_shape, b_shape))
+        self.run_op_benchmark(
+            sess,
+            math_ops.matmul(
+                array_ops.broadcast_to(matrix_a, broadcasted_a_shape),
+                array_ops.broadcast_to(matrix_b, broadcasted_b_shape)),
+            min_iters=50,
+            name="batch_matmul_manual_broadcast_cpu_{}_{}".format(
+                a_shape, b_shape))

 if __name__ == "__main__":

View File

@@ -21,7 +21,6 @@ import itertools

 import numpy as np

-from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import ops
@@ -33,15 +32,6 @@ from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging

-# LINT.IfChange
-matrix_diag_v3_forward_compat_date = (2019, 12, 6)
-# LINT.ThenChange(
-#   //tensorflow/compiler/tests/matrix_diag_ops_test.py,
-#   //tensorflow/python/ops/array_ops.py,
-#   //tensorflow/python/ops/parallel_for/array_test.py
-# )
-
 default_v2_alignment = "LEFT_LEFT"
 alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]
@@ -391,21 +381,20 @@ class MatrixDiagTest(test.TestCase):
       self.assertEqual((3, 3), v_diag.get_shape())
       self.assertAllEqual(v_diag.eval(), mat)

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # {Sub,Super}diagonals.
-        for offset in [1, -2, 5]:
-          mat = np.diag(v, offset)
-          v_diag = array_ops.matrix_diag(v, k=offset)
-          self.assertEqual(mat.shape, v_diag.get_shape())
-          self.assertAllEqual(v_diag.eval(), mat)
+      # {Sub,Super}diagonals.
+      for offset in [1, -2, 5]:
+        mat = np.diag(v, offset)
+        v_diag = array_ops.matrix_diag(v, k=offset)
+        self.assertEqual(mat.shape, v_diag.get_shape())
+        self.assertAllEqual(v_diag.eval(), mat)

-        # Diagonal bands.
-        for align in alignment_list:
-          for _, tests in [self._moreCases(align), square_cases(align)]:
-            for diags, (vecs, solution) in tests.items():
-              v_diags = array_ops.matrix_diag(vecs[0], k=diags, align=align)
-              self.assertEqual(v_diags.get_shape(), solution[0].shape)
-              self.assertAllEqual(v_diags.eval(), solution[0])
+      # Diagonal bands.
+      for align in alignment_list:
+        for _, tests in [self._moreCases(align), square_cases(align)]:
+          for diags, (vecs, solution) in tests.items():
+            v_diags = array_ops.matrix_diag(vecs[0], k=diags, align=align)
+            self.assertEqual(v_diags.get_shape(), solution[0].shape)
+            self.assertAllEqual(v_diags.eval(), solution[0])

   def _testVectorBatch(self, dtype):
     with self.cached_session(use_gpu=True):
@@ -417,31 +406,30 @@ class MatrixDiagTest(test.TestCase):
       self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
       self.assertAllEqual(v_batch_diag.eval(), mat_batch)

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # {Sub,Super}diagonals.
-        for offset in [1, -2, 5]:
-          v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
-          mats = [
-              np.diag(v_batch[i], offset) for i in range(0, v_batch.shape[0])
-          ]
-          mat_batch = np.stack(mats, axis=0)
-          self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
-          self.assertAllEqual(v_batch_diag.eval(), mat_batch)
+      # {Sub,Super}diagonals.
+      for offset in [1, -2, 5]:
+        v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
+        mats = [
+            np.diag(v_batch[i], offset) for i in range(0, v_batch.shape[0])
+        ]
+        mat_batch = np.stack(mats, axis=0)
+        self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
+        self.assertAllEqual(v_batch_diag.eval(), mat_batch)

-        # Diagonal bands with padding_value.
-        for padding_value, align in zip_to_first_list_length([0, 555, -11],
-                                                             alignment_list):
-          for _, tests in [self._moreCases(align), square_cases(align)]:
-            for diags, (vecs, solution) in tests.items():
-              v_diags = array_ops.matrix_diag(
-                  vecs.astype(dtype),
-                  k=diags,
-                  padding_value=padding_value,
-                  align=align)
-              mask = solution == 0
-              solution = (solution + padding_value * mask).astype(dtype)
-              self.assertEqual(v_diags.get_shape(), solution.shape)
-              self.assertAllEqual(v_diags.eval(), solution)
+      # Diagonal bands with padding_value.
+      for padding_value, align in zip_to_first_list_length([0, 555, -11],
+                                                           alignment_list):
+        for _, tests in [self._moreCases(align), square_cases(align)]:
+          for diags, (vecs, solution) in tests.items():
+            v_diags = array_ops.matrix_diag(
+                vecs.astype(dtype),
+                k=diags,
+                padding_value=padding_value,
+                align=align)
+            mask = solution == 0
+            solution = (solution + padding_value * mask).astype(dtype)
+            self.assertEqual(v_diags.get_shape(), solution.shape)
+            self.assertAllEqual(v_diags.eval(), solution)

   @test_util.run_deprecated_v1
   def testVectorBatch(self):
@@ -453,100 +441,99 @@ class MatrixDiagTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testRectangularBatch(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      with self.cached_session(use_gpu=True):
-        # Stores expected num_rows and num_cols (when the other is given).
-        # expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
-        test_list = list()
-
-        # Square cases:
-        expected = {
-            (-1, -1): (5, 4),
-            (-4, -3): (5, 2),
-            (-2, 1): (5, 5),
-            (2, 4): (3, 5),
-        }
-        # Do not change alignment yet. Re-alignment needs to happen after the
-        # solution shape is updated.
-        test_list.append((expected, square_cases()))
-
-        # More cases:
-        expected = {(-3, -1): (5, 4), (-1, 1): (4, 4), (2, 4): (4, 6)}
-        test_list.append((expected, self._moreCases()))
-
-        # Tall cases
-        expected = {
-            (0, 0): (3, 3),
-            (-4, -3): (5, 2),
-            (-2, -1): (4, 3),
-            (-2, 1): (3, 3),
-            (1, 2): (2, 3)
-        }
-        test_list.append((expected, tall_cases()))
-
-        # Fat cases
-        expected = {
-            (2, 2): (2, 4),
-            (-2, 0): (3, 3),
-            (-1, 1): (3, 3),
-            (0, 3): (3, 3)
-        }
-        test_list.append((expected, fat_cases()))
-
-        for padding_value, align in zip_to_first_list_length([0, 555, -11],
-                                                             alignment_list):
-          # Giving both num_rows and num_cols
-          for _, tests in [tall_cases(align), fat_cases(align)]:
-            for diags, (vecs, solution) in tests.items():
-              v_diags = array_ops.matrix_diag(
-                  vecs,
-                  k=diags,
-                  num_rows=solution.shape[-2],
-                  num_cols=solution.shape[-1],
-                  padding_value=padding_value,
-                  align=align)
-              mask = solution == 0
-              solution = solution + padding_value * mask
-              self.assertEqual(v_diags.get_shape(), solution.shape)
-              self.assertAllEqual(v_diags.eval(), solution)
-
-          # Giving just num_rows.
-          for expected, (_, tests) in test_list:
-            for diags, (_, new_num_cols) in expected.items():
-              vecs, solution = tests[diags]
-              solution = solution.take(indices=range(new_num_cols), axis=-1)
-              # Repacks the diagonal input according to the new solution shape.
-              vecs = repack_diagonals(
-                  vecs, diags, solution.shape[-2], new_num_cols, align=align)
-              v_diags = array_ops.matrix_diag(
-                  vecs,
-                  k=diags,
-                  num_rows=solution.shape[-2],
-                  padding_value=padding_value,
-                  align=align)
-              mask = solution == 0
-              solution = solution + padding_value * mask
-              self.assertEqual(v_diags.get_shape(), solution.shape)
-              self.assertAllEqual(v_diags.eval(), solution)
-
-          # Giving just num_cols.
-          for expected, (_, tests) in test_list:
-            for diags, (new_num_rows, _) in expected.items():
-              vecs, solution = tests[diags]
-              solution = solution.take(indices=range(new_num_rows), axis=-2)
-              # Repacks the diagonal input according to the new solution shape.
-              vecs = repack_diagonals(
-                  vecs, diags, new_num_rows, solution.shape[-1], align=align)
-              v_diags = array_ops.matrix_diag(
-                  vecs,
-                  k=diags,
-                  num_cols=solution.shape[-1],
-                  padding_value=padding_value,
-                  align=align)
-              mask = solution == 0
-              solution = solution + padding_value * mask
-              self.assertEqual(v_diags.get_shape(), solution.shape)
-              self.assertAllEqual(v_diags.eval(), solution)
+    with self.cached_session(use_gpu=True):
+      # Stores expected num_rows and num_cols (when the other is given).
+      # expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
+      test_list = list()
+
+      # Square cases:
+      expected = {
+          (-1, -1): (5, 4),
+          (-4, -3): (5, 2),
+          (-2, 1): (5, 5),
+          (2, 4): (3, 5),
+      }
+      # Do not change alignment yet. Re-alignment needs to happen after the
+      # solution shape is updated.
+      test_list.append((expected, square_cases()))
+
+      # More cases:
+      expected = {(-3, -1): (5, 4), (-1, 1): (4, 4), (2, 4): (4, 6)}
+      test_list.append((expected, self._moreCases()))
+
+      # Tall cases
+      expected = {
+          (0, 0): (3, 3),
+          (-4, -3): (5, 2),
+          (-2, -1): (4, 3),
+          (-2, 1): (3, 3),
+          (1, 2): (2, 3)
+      }
+      test_list.append((expected, tall_cases()))
+
+      # Fat cases
+      expected = {
+          (2, 2): (2, 4),
+          (-2, 0): (3, 3),
+          (-1, 1): (3, 3),
+          (0, 3): (3, 3)
+      }
+      test_list.append((expected, fat_cases()))
+
+      for padding_value, align in zip_to_first_list_length([0, 555, -11],
+                                                           alignment_list):
+        # Giving both num_rows and num_cols
+        for _, tests in [tall_cases(align), fat_cases(align)]:
+          for diags, (vecs, solution) in tests.items():
+            v_diags = array_ops.matrix_diag(
+                vecs,
+                k=diags,
+                num_rows=solution.shape[-2],
+                num_cols=solution.shape[-1],
+                padding_value=padding_value,
+                align=align)
+            mask = solution == 0
+            solution = solution + padding_value * mask
+            self.assertEqual(v_diags.get_shape(), solution.shape)
+            self.assertAllEqual(v_diags.eval(), solution)
+
+        # Giving just num_rows.
+        for expected, (_, tests) in test_list:
+          for diags, (_, new_num_cols) in expected.items():
+            vecs, solution = tests[diags]
+            solution = solution.take(indices=range(new_num_cols), axis=-1)
+            # Repacks the diagonal input according to the new solution shape.
+            vecs = repack_diagonals(
+                vecs, diags, solution.shape[-2], new_num_cols, align=align)
+            v_diags = array_ops.matrix_diag(
+                vecs,
+                k=diags,
+                num_rows=solution.shape[-2],
+                padding_value=padding_value,
+                align=align)
+            mask = solution == 0
+            solution = solution + padding_value * mask
+            self.assertEqual(v_diags.get_shape(), solution.shape)
+            self.assertAllEqual(v_diags.eval(), solution)
+
+        # Giving just num_cols.
+        for expected, (_, tests) in test_list:
+          for diags, (new_num_rows, _) in expected.items():
+            vecs, solution = tests[diags]
+            solution = solution.take(indices=range(new_num_rows), axis=-2)
+            # Repacks the diagonal input according to the new solution shape.
+            vecs = repack_diagonals(
+                vecs, diags, new_num_rows, solution.shape[-1], align=align)
+            v_diags = array_ops.matrix_diag(
+                vecs,
+                k=diags,
+                num_cols=solution.shape[-1],
+                padding_value=padding_value,
+                align=align)
+            mask = solution == 0
+            solution = solution + padding_value * mask
+            self.assertEqual(v_diags.get_shape(), solution.shape)
+            self.assertAllEqual(v_diags.eval(), solution)

   @test_util.run_deprecated_v1
   def testInvalidShape(self):
@@ -574,21 +561,20 @@ class MatrixDiagTest(test.TestCase):
                                                        y.get_shape().as_list())
       self.assertLess(error, 1e-4)

-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      # {Sub,super}diagonals/band.
-      tests = dict()  # tests[shape] = (d_lower, d_upper)
-      tests[(3,)] = (-1, -1)
-      tests[(7, 3, 4)] = (-1, 1)
-      with self.session(use_gpu=True):
-        for shape, diags in tests.items():
-          x = constant_op.constant(np.random.rand(*shape), np.float32)
-          for align in alignment_list:
-            y = array_ops.matrix_diag(x, k=diags, align=align)
-            error = gradient_checker.compute_gradient_error(
-                x,
-                x.get_shape().as_list(), y,
-                y.get_shape().as_list())
-            self.assertLess(error, 1e-4)
+    # {Sub,super}diagonals/band.
+    tests = dict()  # tests[shape] = (d_lower, d_upper)
+    tests[(3,)] = (-1, -1)
+    tests[(7, 3, 4)] = (-1, 1)
+    with self.session(use_gpu=True):
+      for shape, diags in tests.items():
+        x = constant_op.constant(np.random.rand(*shape), np.float32)
+        for align in alignment_list:
+          y = array_ops.matrix_diag(x, k=diags, align=align)
+          error = gradient_checker.compute_gradient_error(
+              x,
+              x.get_shape().as_list(), y,
+              y.get_shape().as_list())
+          self.assertLess(error, 1e-4)

 class MatrixSetDiagTest(test.TestCase):
@@ -604,18 +590,17 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertEqual((3, 3), output.get_shape())
       self.assertAllEqual(mat_set_diag, self.evaluate(output))

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # Diagonal bands.
-        for align in alignment_list:
-          _, tests = square_cases(align)
-          for diags, (vecs, banded_mat) in tests.items():
-            mask = banded_mat[0] == 0
-            input_mat = np.random.randint(10, size=mask.shape)
-            solution = input_mat * mask + banded_mat[0]
-            output = array_ops.matrix_set_diag(
-                input_mat, vecs[0], k=diags, align=align)
-            self.assertEqual(output.get_shape(), solution.shape)
-            self.assertAllEqual(output.eval(), solution)
+      # Diagonal bands.
+      for align in alignment_list:
+        _, tests = square_cases(align)
+        for diags, (vecs, banded_mat) in tests.items():
+          mask = banded_mat[0] == 0
+          input_mat = np.random.randint(10, size=mask.shape)
+          solution = input_mat * mask + banded_mat[0]
+          output = array_ops.matrix_set_diag(
+              input_mat, vecs[0], k=diags, align=align)
+          self.assertEqual(output.get_shape(), solution.shape)
+          self.assertAllEqual(output.eval(), solution)

   @test_util.run_deprecated_v1
   def testRectangular(self):
@@ -634,18 +619,17 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertEqual((3, 2), output.get_shape())
       self.assertAllEqual(expected, self.evaluate(output))

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # Diagonal bands.
-        for align in alignment_list:
-          for _, tests in [tall_cases(align), fat_cases(align)]:
-            for diags, (vecs, banded_mat) in tests.items():
-              mask = banded_mat[0] == 0
-              input_mat = np.random.randint(10, size=mask.shape)
-              solution = input_mat * mask + banded_mat[0]
-              output = array_ops.matrix_set_diag(
-                  input_mat, vecs[0], k=diags, align=align)
-              self.assertEqual(output.get_shape(), solution.shape)
-              self.assertAllEqual(output.eval(), solution)
+      # Diagonal bands.
+      for align in alignment_list:
+        for _, tests in [tall_cases(align), fat_cases(align)]:
+          for diags, (vecs, banded_mat) in tests.items():
+            mask = banded_mat[0] == 0
+            input_mat = np.random.randint(10, size=mask.shape)
+            solution = input_mat * mask + banded_mat[0]
+            output = array_ops.matrix_set_diag(
+                input_mat, vecs[0], k=diags, align=align)
+            self.assertEqual(output.get_shape(), solution.shape)
+            self.assertAllEqual(output.eval(), solution)

   def _testSquareBatch(self, dtype):
     with self.cached_session(use_gpu=True):
@@ -663,18 +647,17 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertEqual((2, 3, 3), output.get_shape())
       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # Diagonal bands.
-        for align in alignment_list:
-          _, tests = square_cases(align)
-          for diags, (vecs, banded_mat) in tests.items():
-            mask = banded_mat == 0
-            input_mat = np.random.randint(10, size=mask.shape).astype(dtype)
-            solution = (input_mat * mask + banded_mat).astype(dtype)
-            output = array_ops.matrix_set_diag(
-                input_mat, vecs.astype(dtype), k=diags, align=align)
-            self.assertEqual(output.get_shape(), solution.shape)
-            self.assertAllEqual(output.eval(), solution)
+      # Diagonal bands.
+      for align in alignment_list:
+        _, tests = square_cases(align)
+        for diags, (vecs, banded_mat) in tests.items():
+          mask = banded_mat == 0
+          input_mat = np.random.randint(10, size=mask.shape).astype(dtype)
+          solution = (input_mat * mask + banded_mat).astype(dtype)
+          output = array_ops.matrix_set_diag(
+              input_mat, vecs.astype(dtype), k=diags, align=align)
+          self.assertEqual(output.get_shape(), solution.shape)
+          self.assertAllEqual(output.eval(), solution)

   @test_util.run_deprecated_v1
   def testSquareBatch(self):
@@ -697,19 +680,18 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertEqual((2, 2, 3), output.get_shape())
       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # Diagonal bands.
-        for align in alignment_list:
-          for _, tests in [tall_cases(align), fat_cases(align)]:
-            for diags, pair in tests.items():
-              vecs, banded_mat = pair
-              mask = banded_mat == 0
-              input_mat = np.random.randint(10, size=mask.shape)
-              solution = input_mat * mask + banded_mat
-              output = array_ops.matrix_set_diag(
-                  input_mat, vecs, k=diags, align=align)
-              self.assertEqual(output.get_shape(), solution.shape)
-              self.assertAllEqual(output.eval(), solution)
+      # Diagonal bands.
+      for align in alignment_list:
+        for _, tests in [tall_cases(align), fat_cases(align)]:
+          for diags, pair in tests.items():
+            vecs, banded_mat = pair
+            mask = banded_mat == 0
+            input_mat = np.random.randint(10, size=mask.shape)
+            solution = input_mat * mask + banded_mat
+            output = array_ops.matrix_set_diag(
+                input_mat, vecs, k=diags, align=align)
+            self.assertEqual(output.get_shape(), solution.shape)
+            self.assertAllEqual(output.eval(), solution)

   @test_util.run_deprecated_v1
   def testInvalidShape(self):
@@ -727,14 +709,13 @@ class MatrixSetDiagTest(test.TestCase):
       with self.assertRaisesOpError("diagonal must be at least 1-dim"):
         array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        d = array_ops.placeholder(dtype=dtypes_lib.float32)
-        with self.assertRaisesOpError(
-            "first dimensions of diagonal don't match"):
-          array_ops.matrix_set_diag(v, d).eval(feed_dict={
-              v: np.zeros((2, 3, 3)),
-              d: np.ones((2, 4))
-          })
+      d = array_ops.placeholder(dtype=dtypes_lib.float32)
+      with self.assertRaisesOpError(
+          "first dimensions of diagonal don't match"):
+        array_ops.matrix_set_diag(v, d).eval(feed_dict={
+            v: np.zeros((2, 3, 3)),
+            d: np.ones((2, 4))
+        })

   def _testGrad(self, input_shape, diag_shape, diags, align):
     with self.session(use_gpu=True):
@@ -743,10 +724,7 @@ class MatrixSetDiagTest(test.TestCase):
       x_diag = constant_op.constant(
          np.random.rand(*diag_shape), dtype=dtypes_lib.float32)

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
-      else:
-        y = array_ops.matrix_set_diag(x, x_diag)
+      y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
       error_x = gradient_checker.compute_gradient_error(x,
                                                         x.get_shape().as_list(),
                                                         y,
@@ -763,8 +741,7 @@ class MatrixSetDiagTest(test.TestCase):
     input_shapes = [(3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)]
     diag_bands = [(0, 0)]

-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      diag_bands.append((-1, 1))
+    diag_bands.append((-1, 1))
     for input_shape, diags, align in itertools.product(input_shapes, diag_bands,
                                                        alignment_list):
       lower_diag_index, upper_diag_index = diags
@@ -805,21 +782,20 @@ class MatrixDiagPartTest(test.TestCase):
       self.assertEqual((3,), mat_diag.get_shape())
       self.assertAllEqual(mat_diag.eval(), v)

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        for offset in [-2, 3]:
-          mat = np.diag(v, offset)
-          mat_diag = array_ops.matrix_diag_part(mat, k=offset)
-          self.assertEqual((3,), mat_diag.get_shape())
-          self.assertAllEqual(mat_diag.eval(), v)
+      for offset in [-2, 3]:
+        mat = np.diag(v, offset)
+        mat_diag = array_ops.matrix_diag_part(mat, k=offset)
+        self.assertEqual((3,), mat_diag.get_shape())
+        self.assertAllEqual(mat_diag.eval(), v)

-        # Diagonal bands.
-        for align in alignment_list:
-          mat, tests = square_cases(align)
-          for diags, pair in tests.items():
-            solution, _ = pair
-            mat_diag = array_ops.matrix_diag_part(mat[0], k=diags, align=align)
-            self.assertEqual(mat_diag.get_shape(), solution[0].shape)
-            self.assertAllEqual(mat_diag.eval(), solution[0])
+      # Diagonal bands.
+      for align in alignment_list:
+        mat, tests = square_cases(align)
+        for diags, pair in tests.items():
+          solution, _ = pair
+          mat_diag = array_ops.matrix_diag_part(mat[0], k=diags, align=align)
+          self.assertEqual(mat_diag.get_shape(), solution[0].shape)
+          self.assertAllEqual(mat_diag.eval(), solution[0])

   @test_util.run_deprecated_v1
   def testRectangular(self):
@@ -831,16 +807,15 @@ class MatrixDiagPartTest(test.TestCase):
       mat_diag = array_ops.matrix_diag_part(mat)
       self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # Diagonal bands.
-        for align in alignment_list:
-          for mat, tests in [tall_cases(align), fat_cases(align)]:
-            for diags, pair in tests.items():
-              solution, _ = pair
-              mat_diag = array_ops.matrix_diag_part(
-                  mat[0], k=diags, align=align)
-              self.assertEqual(mat_diag.get_shape(), solution[0].shape)
-              self.assertAllEqual(mat_diag.eval(), solution[0])
+      # Diagonal bands.
+      for align in alignment_list:
+        for mat, tests in [tall_cases(align), fat_cases(align)]:
+          for diags, pair in tests.items():
+            solution, _ = pair
+            mat_diag = array_ops.matrix_diag_part(
+                mat[0], k=diags, align=align)
+            self.assertEqual(mat_diag.get_shape(), solution[0].shape)
+            self.assertAllEqual(mat_diag.eval(), solution[0])

   def _testSquareBatch(self, dtype):
     with self.cached_session(use_gpu=True):
@@ -853,22 +828,21 @@ class MatrixDiagPartTest(test.TestCase):
       self.assertEqual((2, 3), mat_batch_diag.get_shape())
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # Diagonal bands with padding_value.
-        for padding_value, align in zip_to_first_list_length([0, 555, -11],
-                                                             alignment_list):
-          mat, tests = square_cases(align)
-          for diags, pair in tests.items():
-            solution, _ = pair
-            mat_batch_diag = array_ops.matrix_diag_part(
-                mat.astype(dtype),
-                k=diags,
-                padding_value=padding_value,
-                align=align)
-            mask = solution == 0
-            solution = (solution + padding_value * mask).astype(dtype)
-            self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
-            self.assertAllEqual(mat_batch_diag.eval(), solution)
+      # Diagonal bands with padding_value.
+      for padding_value, align in zip_to_first_list_length([0, 555, -11],
+                                                           alignment_list):
+        mat, tests = square_cases(align)
+        for diags, pair in tests.items():
+          solution, _ = pair
+          mat_batch_diag = array_ops.matrix_diag_part(
+              mat.astype(dtype),
+              k=diags,
+              padding_value=padding_value,
+              align=align)
+          mask = solution == 0
+          solution = (solution + padding_value * mask).astype(dtype)
+          self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
+          self.assertAllEqual(mat_batch_diag.eval(), solution)

   @test_util.run_deprecated_v1
   def testSquareBatch(self):
@@ -889,29 +863,27 @@ class MatrixDiagPartTest(test.TestCase):
       self.assertEqual((2, 2), mat_batch_diag.get_shape())
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)

-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        # Diagonal bands with padding_value and align.
-        for padding_value, align in zip_to_first_list_length([0, 555, -11],
-                                                             alignment_list):
-          for mat, tests in [tall_cases(align), fat_cases(align)]:
-            for diags, pair in tests.items():
-              solution, _ = pair
-              mat_batch_diag = array_ops.matrix_diag_part(
-                  mat, k=diags, padding_value=padding_value, align=align)
-              mask = solution == 0
-              solution = solution + padding_value * mask
-              self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
-              self.assertAllEqual(mat_batch_diag.eval(), solution)
+      # Diagonal bands with padding_value and align.
+      for padding_value, align in zip_to_first_list_length([0, 555, -11],
+                                                           alignment_list):
+        for mat, tests in [tall_cases(align), fat_cases(align)]:
+          for diags, pair in tests.items():
+            solution, _ = pair
+            mat_batch_diag = array_ops.matrix_diag_part(
+                mat, k=diags, padding_value=padding_value, align=align)
+            mask = solution == 0
+            solution = solution + padding_value * mask
+            self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
+            self.assertAllEqual(mat_batch_diag.eval(), solution)

   @test_util.run_deprecated_v1
   def testUnknownShape(self):
-    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-      matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
-      result = array_ops.matrix_diag_part(matrix, k=-1)
-      input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
-      with self.session(use_gpu=True):
-        result_eval = result.eval(feed_dict={matrix: input_matrix})
-      self.assertAllEqual([4, 8], result_eval)
+    matrix = array_ops.placeholder(dtypes_lib.int32, shape=[None, None])
+    result = array_ops.matrix_diag_part(matrix, k=-1)
+    input_matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
+    with self.session(use_gpu=True):
+      result_eval = result.eval(feed_dict={matrix: input_matrix})
+    self.assertAllEqual([4, 8], result_eval)

   @test_util.run_deprecated_v1
   def testInvalidShape(self):
@ -939,21 +911,20 @@ class MatrixDiagPartTest(test.TestCase):
y.get_shape().as_list()) y.get_shape().as_list())
self.assertLess(error, 1e-4) self.assertLess(error, 1e-4)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # {Sub,super}diagonals/band.
# {Sub,super}diagonals/band. tests = dict() # tests[shape] = (d_lower, d_upper)
tests = dict() # tests[shape] = (d_lower, d_upper) tests[(3, 3)] = (-1, -1)
tests[(3, 3)] = (-1, -1) tests[(7, 3, 4)] = (-1, 1)
tests[(7, 3, 4)] = (-1, 1) with self.session(use_gpu=True):
with self.session(use_gpu=True): for align in alignment_list:
for align in alignment_list: for shape, diags in tests.items():
for shape, diags in tests.items(): x = constant_op.constant(np.random.rand(*shape), np.float32)
x = constant_op.constant(np.random.rand(*shape), np.float32) y = array_ops.matrix_diag_part(input=x, k=diags, align=align)
y = array_ops.matrix_diag_part(input=x, k=diags, align=align) error = gradient_checker.compute_gradient_error(
error = gradient_checker.compute_gradient_error( x,
x, x.get_shape().as_list(), y,
x.get_shape().as_list(), y, y.get_shape().as_list())
y.get_shape().as_list()) self.assertLess(error, 1e-4)
self.assertLess(error, 1e-4)
class DiagTest(test.TestCase): class DiagTest(test.TestCase):
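The band extraction these tests now exercise unconditionally can be sketched as follows (a minimal standalone example, not part of the change; with the documented "RIGHT_LEFT" alignment, the shorter subdiagonal is left-aligned and padded with padding_value):

import tensorflow as tf

mat = tf.constant([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
# Main diagonal plus first subdiagonal as a band; the subdiagonal has only
# two entries, so it is padded on the right with padding_value.
band = tf.linalg.diag_part(mat, k=(-1, 0), padding_value=0,
                           align="RIGHT_LEFT")
# band == [[1, 5, 9],
#          [4, 8, 0]]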

View File

@@ -22,7 +22,6 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python import tf2
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -122,7 +121,6 @@ class ReluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
-    print("relu (float32) gradient err = ", err)
     self.assertLess(err, 1e-4)
 
   # The gradient for fp16 is inaccurate due to the low-precision.
@@ -171,7 +169,6 @@ class ReluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
-    print("relu (float64) gradient err = ", err)
     self.assertLess(err, 1e-10)
 
   def testGradGradFloat32(self):
@@ -190,7 +187,6 @@ class ReluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(f, [x]))
-    print("relu (float32) gradient of gradient err = ", err)
     self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
@@ -209,7 +205,6 @@ class ReluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(f, [x]))
-    print("relu (float64) gradient of gradient err = ", err)
     self.assertLess(err, 1e-10)
 
   def testGradientScalar(self):
@@ -283,7 +278,6 @@ class Relu6Test(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
-    print("relu6 (float32) gradient err = ", err)
     self.assertLess(err, 1e-4)
 
   def testGradientFloat64(self):
@@ -294,7 +288,6 @@ class Relu6Test(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.relu6, [x]))
-    print("relu6 (float64) gradient err = ", err)
     self.assertLess(err, 1e-10)
 
 
@@ -345,7 +338,6 @@ class LeakyReluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
-    print("leaky_relu (float32) gradient err = ", err)
     self.assertLess(err, 1e-4)
 
   def testGradientFloat64(self):
@@ -356,48 +348,43 @@ class LeakyReluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.leaky_relu, [x]))
-    print("leaky_relu (float64) gradient err = ", err)
     self.assertLess(err, 1e-10)
 
   def testGradGradFloat32(self):
-    with compat.forward_compatibility_horizon(2018, 11, 2):
-      with self.cached_session():
-
-        def f(x):
-          assert x.dtype == dtypes.float32
-          with backprop.GradientTape() as tape:
-            tape.watch(x)
-            y = nn_ops.leaky_relu(x)
-          return tape.gradient(y, x)
-
-        x = np.asarray(
-            [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
-            dtype=np.float32,
-            order="F")
-        err = gradient_checker_v2.max_error(
-            *gradient_checker_v2.compute_gradient(f, [x]))
-        print("leaky_relu (float32) gradient of gradient err = ", err)
-        self.assertLess(err, 1e-4)
+    with self.cached_session():
+
+      def f(x):
+        assert x.dtype == dtypes.float32
+        with backprop.GradientTape() as tape:
+          tape.watch(x)
+          y = nn_ops.leaky_relu(x)
+        return tape.gradient(y, x)
+
+      x = np.asarray(
+          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+          dtype=np.float32,
+          order="F")
+      err = gradient_checker_v2.max_error(
+          *gradient_checker_v2.compute_gradient(f, [x]))
+      self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
-    with compat.forward_compatibility_horizon(2018, 11, 2):
-      with self.cached_session():
-
-        def f(x):
-          assert x.dtype == dtypes.float64
-          with backprop.GradientTape() as tape:
-            tape.watch(x)
-            y = nn_ops.leaky_relu(x)
-          return tape.gradient(y, x)
-
-        x = np.asarray(
-            [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
-            dtype=np.float64,
-            order="F")
-        err = gradient_checker_v2.max_error(
-            *gradient_checker_v2.compute_gradient(f, [x]))
-        print("leaky_relu (float64) gradient of gradient err = ", err)
-        self.assertLess(err, 1e-10)
+    with self.cached_session():
+
+      def f(x):
+        assert x.dtype == dtypes.float64
+        with backprop.GradientTape() as tape:
+          tape.watch(x)
+          y = nn_ops.leaky_relu(x)
+        return tape.gradient(y, x)
+
+      x = np.asarray(
+          [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+          dtype=np.float64,
+          order="F")
+      err = gradient_checker_v2.max_error(
+          *gradient_checker_v2.compute_gradient(f, [x]))
+      self.assertLess(err, 1e-10)
 
   def testGradientScalar(self):
     x = variables.Variable(-100.)
@@ -463,7 +450,6 @@ class EluTest(test.TestCase):
     x = np.asarray(x_val, dtype=np.float32, order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
-    print("elu (float32) gradient err = ", err)
     self.assertLess(err, 1e-4)
 
   def testGradientFloat64(self):
@@ -472,7 +458,6 @@ class EluTest(test.TestCase):
     x = np.asarray(x_val, dtype=np.float64, order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.elu, [x]))
-    print("elu (float64) gradient err = ", err)
     self.assertLess(err, 1e-6)
 
   def testGradGrad(self):
@@ -507,7 +492,6 @@ class EluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(f, [x]))
-    print("elu (float32) gradient of gradient err = ", err)
     self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
@@ -526,7 +510,6 @@ class EluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(f, [x]))
-    print("elu (float64) gradient of gradient err = ", err)
     self.assertLess(err, 1e-6)
 
 
@@ -567,7 +550,6 @@ class SeluTest(test.TestCase):
     x = np.asarray(x_val, dtype=np.float32, order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
-    print("selu (float32) gradient err = ", err)
     self.assertLess(err, 1e-4)
 
   def testGradientFloat64(self):
@@ -576,7 +558,6 @@ class SeluTest(test.TestCase):
     x = np.asarray(x_val, dtype=np.float64, order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(nn_ops.selu, [x]))
-    print("selu (float64) gradient err = ", err)
     self.assertLess(err, 1e-6)
 
   def testGradGradFloat32(self):
@@ -595,7 +576,6 @@ class SeluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(f, [x]))
-    print("selu (float32) gradient of gradient err = ", err)
     self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
@@ -614,7 +594,6 @@ class SeluTest(test.TestCase):
         order="F")
     err = gradient_checker_v2.max_error(
         *gradient_checker_v2.compute_gradient(f, [x]))
-    print("selu (float64) gradient of gradient err = ", err)
     self.assertLess(err, 1e-6)
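The deleted print statements only echoed the error value that the following assertLess already checks; the gradient-check pattern itself is unchanged. A minimal standalone sketch of that pattern, using the same internal gradient_checker_v2 module the tests import (illustrative, not part of the commit):

import numpy as np

from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import nn_ops

# Sample points are kept away from 0, where relu is not differentiable.
x = np.asarray([[-0.9, -0.7, 0.7, 0.9]], dtype=np.float32, order="F")
theoretical, numerical = gradient_checker_v2.compute_gradient(nn_ops.relu, [x])
err = gradient_checker_v2.max_error(theoretical, numerical)
assert err < 1e-4  # the bound the tests assert instead of printing the error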

View File

@@ -22,7 +22,6 @@ from __future__ import print_function
 import numpy as np
 import six
 
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import composite_tensor
@@ -54,13 +53,6 @@ tf_export("newaxis").export_constant(__name__, "newaxis")
 # existing 'slice' for later use in this module.
 _BaseSlice = slice
 
-# LINT.IfChange
-matrix_diag_v3_forward_compat_date = (2019, 12, 6)
-# LINT.ThenChange(
-#   //tensorflow/compiler/tests/matrix_diag_ops_test.py,
-#   //tensorflow/python/kernel_tests/diag_op_test.py,
-#   //tensorflow/python/ops/parallel_for/array_test.py
-# )
 
 @tf_export("reshape", v1=["reshape", "manip.reshape"])
 def reshape(tensor, shape, name=None):  # pylint: disable=redefined-outer-name
@@ -2362,24 +2354,19 @@ def matrix_diag(diagonal,
   Returns:
     A Tensor. Has the same type as `diagonal`.
   """
-  if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-    # Special case to sidestep the tf.constant conversion error:
-    # TypeError: Expected bool, got 0 of type 'int' instead.
-    if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
-      padding_value = bool(padding_value)
-
-    return gen_array_ops.matrix_diag_v3(
-        diagonal=diagonal,
-        k=k,
-        num_rows=num_rows,
-        num_cols=num_cols,
-        padding_value=padding_value,
-        align=align,
-        name=name)
-
-  # Call v1 to maintain forward compatibility.
-  # (We skip v2 because its alignment conflicts with v3's default alignment.)
-  return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)
+  # Special case to sidestep the tf.constant conversion error:
+  # TypeError: Expected bool, got 0 of type 'int' instead.
+  if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
+    padding_value = bool(padding_value)
+
+  return gen_array_ops.matrix_diag_v3(
+      diagonal=diagonal,
+      k=k,
+      num_rows=num_rows,
+      num_cols=num_cols,
+      padding_value=padding_value,
+      align=align,
+      name=name)
 
 
 @tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
@@ -2513,18 +2500,13 @@ def matrix_diag_part(
   Returns:
     A Tensor containing diagonals of `input`. Has the same type as `input`.
   """
-  if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-    # Special case to sidestep the tf.constant conversion error:
-    # TypeError: Expected bool, got 0 of type 'int' instead.
-    if hasattr(input, "dtype") and input.dtype == "bool":
-      padding_value = bool(padding_value)
-
-    return gen_array_ops.matrix_diag_part_v3(
-        input=input, k=k, padding_value=padding_value, align=align, name=name)
-
-  # Call v1 to maintain forward compatibility.
-  # (We skip v2 because its alignment conflicts with v3's default alignment.)
-  return gen_array_ops.matrix_diag_part(input=input, name=name)
+  # Special case to sidestep the tf.constant conversion error:
+  # TypeError: Expected bool, got 0 of type 'int' instead.
+  if hasattr(input, "dtype") and input.dtype == "bool":
+    padding_value = bool(padding_value)
+
+  return gen_array_ops.matrix_diag_part_v3(
+      input=input, k=k, padding_value=padding_value, align=align, name=name)
 
 
 @tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
@@ -2659,14 +2641,8 @@ def matrix_set_diag(
     the left (right-pads the row). It is the packing format LAPACK uses.
     cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
   """
-  if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-    return gen_array_ops.matrix_set_diag_v3(
-        input=input, diagonal=diagonal, k=k, align=align, name=name)
-
-  # Call v1 to maintain forward compatibility.
-  # (We skip v2 because its alignment conflicts with v3's default alignment.)
-  return gen_array_ops.matrix_set_diag(
-      input=input, diagonal=diagonal, name=name)
+  return gen_array_ops.matrix_set_diag_v3(
+      input=input, diagonal=diagonal, k=k, align=align, name=name)
 
 
 # pylint: enable=invalid-name
@@ -4921,7 +4897,7 @@ def quantize_v2(
       raise ValueError("input should have known rank to use negative axis.")
     axis %= input.shape.ndims
 
-  if compat.forward_compatible(2019, 11, 13) or ensure_minimum_range != 0.01:
+  if ensure_minimum_range != 0.01:
     return gen_array_ops.quantize_v2(
         input,
         min_range,
@@ -4965,7 +4941,7 @@ def quantize(
     axis=None,
    ensure_minimum_range=0.01):
   """Quantize the input tensor."""
-  if compat.forward_compatible(2019, 11, 13) or ensure_minimum_range != 0.01:
+  if ensure_minimum_range != 0.01:
    return quantize_v2(
        input,
        min_range,
@@ -5007,7 +4983,7 @@ def dequantize(  # pylint: disable=missing-docstring
      raise ValueError("input should have known rank to use negative axis.")
    axis %= input.shape.ndims
 
-  if compat.forward_compatible(2019, 10, 22) or axis >= 0 or narrow_range:
+  if axis >= 0 or narrow_range:
    return gen_array_ops.dequantize(
        input, min_range, max_range, mode=mode, name=name,
        narrow_range=narrow_range, axis=axis)
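With the date gate gone, matrix_diag and its relatives always lower to the V3 kernels, so k, num_rows, num_cols, padding_value, and align are honored unconditionally. A small usage sketch of the public endpoint (illustrative only):

import tensorflow as tf

v = tf.constant([1, 2, 3])
# Place v on the first superdiagonal of a 4x4 matrix; every other entry is
# filled with padding_value.
m = tf.linalg.diag(v, k=1, num_rows=4, num_cols=4, padding_value=0)
# m == [[0, 1, 0, 0],
#       [0, 0, 2, 0],
#       [0, 0, 0, 3],
#       [0, 0, 0, 0]]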

View File

@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
@@ -383,26 +382,25 @@ class BatchNormalizationTest(test.TestCase):
         x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   def testInferenceShape5(self):
-    with compat.forward_compatibility_horizon(2019, 6, 7):
-      x_shape = [0, 131, 127, 6]
-      for dtype in [np.float16, np.float32]:
-        if test.is_gpu_available(cuda_only=True):
-          self._test_inference(
-              x_shape,
-              dtype, [131],
-              np.float32,
-              use_gpu=True,
-              data_format='NCHW')
-          self._test_inference(
-              x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
-        self._test_inference(
-            x_shape,
-            dtype, [131],
-            np.float32,
-            use_gpu=False,
-            data_format='NCHW')
-        self._test_inference(
-            x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
+    x_shape = [0, 131, 127, 6]
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_inference(
+            x_shape,
+            dtype, [131],
+            np.float32,
+            use_gpu=True,
+            data_format='NCHW')
+        self._test_inference(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_inference(
+          x_shape,
+          dtype, [131],
+          np.float32,
+          use_gpu=False,
+          data_format='NCHW')
+      self._test_inference(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -450,26 +448,25 @@ class BatchNormalizationTest(test.TestCase):
 
   @test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
   def testTrainingShape5(self):
-    with compat.forward_compatibility_horizon(2019, 6, 7):
-      x_shape = [0, 131, 127, 6]
-      for dtype in [np.float16, np.float32]:
-        if test.is_gpu_available(cuda_only=True):
-          self._test_training(
-              x_shape,
-              dtype, [131],
-              np.float32,
-              use_gpu=True,
-              data_format='NCHW')
-          self._test_training(
-              x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
-        self._test_training(
-            x_shape,
-            dtype, [131],
-            np.float32,
-            use_gpu=False,
-            data_format='NCHW')
-        self._test_training(
-            x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
+    x_shape = [0, 131, 127, 6]
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape,
+            dtype, [131],
+            np.float32,
+            use_gpu=True,
+            data_format='NCHW')
+        self._test_training(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_training(
+          x_shape,
+          dtype, [131],
+          np.float32,
+          use_gpu=False,
+          data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   @test_util.run_deprecated_v1
   def testBatchNormGradShape1(self):
@@ -586,39 +583,38 @@ class BatchNormalizationTest(test.TestCase):
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test never passed for XLA')
   def testBatchNormGradShape5(self):
-    with compat.forward_compatibility_horizon(2019, 6, 7):
-      for is_training in [True, False]:
-        x_shape = [0, 7, 11, 4]
-        for dtype in [np.float16, np.float32]:
-          if test.is_gpu_available(cuda_only=True):
-            self._test_gradient(
-                x_shape,
-                dtype, [7],
-                np.float32,
-                use_gpu=True,
-                data_format='NCHW',
-                is_training=is_training)
-            self._test_gradient(
-                x_shape,
-                dtype, [4],
-                np.float32,
-                use_gpu=True,
-                data_format='NHWC',
-                is_training=is_training)
-          self._test_gradient(
-              x_shape,
-              dtype, [4],
-              np.float32,
-              use_gpu=False,
-              data_format='NHWC',
-              is_training=is_training)
-          self._test_gradient(
-              x_shape,
-              dtype, [7],
-              np.float32,
-              use_gpu=False,
-              data_format='NCHW',
-              is_training=is_training)
+    for is_training in [True, False]:
+      x_shape = [0, 7, 11, 4]
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [7],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+          self._test_gradient(
+              x_shape,
+              dtype, [4],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [4],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [7],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   def _testBatchNormGradGrad(self, config):
     shape = config['shape']
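These tests now call the fused batch norm op directly rather than under a compatibility horizon. For reference, a minimal sketch of the op being exercised, via the public v1 endpoint (shapes are illustrative, not taken from the tests):

import tensorflow as tf

x = tf.random.uniform([2, 4, 4, 3])  # NHWC
scale = tf.ones([3])
offset = tf.zeros([3])
# In training mode the op also returns the batch mean and variance it
# computed from x.
y, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
    x, scale, offset, epsilon=0.001, data_format='NHWC', is_training=True)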

View File

@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -33,15 +32,6 @@ from tensorflow.python.ops.parallel_for.test_util import PForTestCase
 from tensorflow.python.platform import test
 
 
-# LINT.IfChange
-matrix_diag_v3_forward_compat_date = (2019, 12, 6)
-# LINT.ThenChange(
-#   //tensorflow/compiler/tests/matrix_diag_ops_test.py,
-#   //tensorflow/python/kernel_tests/diag_op_test.py,
-#   //tensorflow/python/ops/array_ops.py
-# )
-
-
 @test_util.run_all_in_graph_and_eager_modes
 class ArrayTest(PForTestCase):
@@ -345,10 +335,8 @@ class ArrayTest(PForTestCase):
 
     def loop_fn(i):
       diagonal = array_ops.gather(x, i)
-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        return array_ops.matrix_diag(
-            diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
-      return array_ops.matrix_diag(diagonal)
+      return array_ops.matrix_diag(
+          diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
 
     self._test_loop_fn(loop_fn, 3)
@@ -357,10 +345,8 @@ class ArrayTest(PForTestCase):
 
    def loop_fn(i):
      input = array_ops.gather(x, i)  # pylint: disable=redefined-builtin
-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        return array_ops.matrix_diag_part(
-            input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
-      return array_ops.matrix_diag_part(input)
+      return array_ops.matrix_diag_part(
+          input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
 
    self._test_loop_fn(loop_fn, 3)
@@ -378,17 +364,16 @@ class ArrayTest(PForTestCase):
          array_ops.matrix_set_diag(matrix_i, diags[0, ...]),
      ]
-      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
-        k = (-1, 1)
-        band_i = array_ops.gather(bands, i)
-        for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
-          results.extend([
-              array_ops.matrix_set_diag(matrix_i, band_i, k=k, align=align),
-              array_ops.matrix_set_diag(
-                  matrices[0, ...], band_i, k=k, align=align),
-              array_ops.matrix_set_diag(
-                  matrix_i, bands[0, ...], k=k, align=align)
-          ])
+      k = (-1, 1)
+      band_i = array_ops.gather(bands, i)
+      for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
+        results.extend([
+            array_ops.matrix_set_diag(matrix_i, band_i, k=k, align=align),
+            array_ops.matrix_set_diag(
+                matrices[0, ...], band_i, k=k, align=align),
+            array_ops.matrix_set_diag(
+                matrix_i, bands[0, ...], k=k, align=align)
+        ])
       return results
 
    self._test_loop_fn(loop_fn, 3)
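For context, _test_loop_fn vectorizes these loop bodies with pfor. A standalone sketch of the now-ungated matrix_diag_part call inside such a loop; the internal pfor entry point used here is an assumption based on what PForTestCase wraps:

import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops

mats = tf.random.uniform([3, 4, 5])

def loop_fn(i):
  m = tf.gather(mats, i)
  return tf.linalg.diag_part(m, k=(-2, 0), padding_value=3., align="RIGHT_LEFT")

# Vectorizes loop_fn over i = 0, 1, 2 as a single batched computation.
out = control_flow_ops.pfor(loop_fn, 3)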

View File

@@ -28,7 +28,6 @@ import numpy as np
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
 from tensorflow.python.client import session
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
@@ -444,55 +443,54 @@ class NNTest(PForTestCase):
     self._test_loop_fn(loop_fn, 3)
 
   def test_fused_batch_norm(self):
-    with compat.forward_compatibility_horizon(2019, 6, 7):
-      data_formats = ["NHWC"]
-      if test.is_gpu_available():
-        data_formats.append("NCHW")
-      for is_training in (True, False):
-        for data_format in data_formats:
-          with backprop.GradientTape(persistent=True) as g:
-            if data_format == "NCHW":
-              x = random_ops.random_uniform([3, 1, 2, 5, 5])
-            else:
-              x = random_ops.random_uniform([3, 1, 5, 5, 2])
-            g.watch(x)
-            scale = random_ops.random_uniform([2])
-            g.watch(scale)
-            offset = random_ops.random_uniform([2])
-            g.watch(offset)
-            mean = None if is_training else random_ops.random_uniform([2])
-            variance = None if is_training else random_ops.random_uniform([2])
-
-          # pylint: disable=cell-var-from-loop
-          def loop_fn(i):
-            with g:
-              x1 = array_ops.gather(x, i)
-              outputs = nn.fused_batch_norm(
-                  x1,
-                  scale,
-                  offset,
-                  mean=mean,
-                  variance=variance,
-                  epsilon=0.01,
-                  data_format=data_format,
-                  is_training=is_training)
-              outputs = list(outputs)
-              # We only test the first value of outputs when is_training is
-              # False. It looks like CPU and GPU have different outputs for
-              # batch_mean and batch_variance for this case.
-              if not is_training:
-                outputs[1] = constant_op.constant(0.)
-                outputs[2] = constant_op.constant(0.)
-              loss = nn.l2_loss(outputs[0])
-            if is_training:
-              gradients = g.gradient(loss, [x1, scale, offset])
-            else:
-              gradients = [constant_op.constant(0.)] * 3
-            return outputs + gradients
-
-          # pylint: enable=cell-var-from-loop
-          self._test_loop_fn(loop_fn, 3)
+    data_formats = ["NHWC"]
+    if test.is_gpu_available():
+      data_formats.append("NCHW")
+    for is_training in (True, False):
+      for data_format in data_formats:
+        with backprop.GradientTape(persistent=True) as g:
+          if data_format == "NCHW":
+            x = random_ops.random_uniform([3, 1, 2, 5, 5])
+          else:
+            x = random_ops.random_uniform([3, 1, 5, 5, 2])
+          g.watch(x)
+          scale = random_ops.random_uniform([2])
+          g.watch(scale)
+          offset = random_ops.random_uniform([2])
+          g.watch(offset)
+          mean = None if is_training else random_ops.random_uniform([2])
+          variance = None if is_training else random_ops.random_uniform([2])
+
+        # pylint: disable=cell-var-from-loop
+        def loop_fn(i):
+          with g:
+            x1 = array_ops.gather(x, i)
+            outputs = nn.fused_batch_norm(
+                x1,
+                scale,
+                offset,
+                mean=mean,
+                variance=variance,
+                epsilon=0.01,
+                data_format=data_format,
+                is_training=is_training)
+            outputs = list(outputs)
+            # We only test the first value of outputs when is_training is
+            # False. It looks like CPU and GPU have different outputs for
+            # batch_mean and batch_variance for this case.
+            if not is_training:
+              outputs[1] = constant_op.constant(0.)
+              outputs[2] = constant_op.constant(0.)
+            loss = nn.l2_loss(outputs[0])
+          if is_training:
+            gradients = g.gradient(loss, [x1, scale, offset])
+          else:
+            gradients = [constant_op.constant(0.)] * 3
+          return outputs + gradients
+
+        # pylint: enable=cell-var-from-loop
+        self._test_loop_fn(loop_fn, 3)
 
   def test_log_softmax(self):
     logits = random_ops.random_uniform([3, 2, 4])

View File

@@ -33,7 +33,6 @@ import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.compiler.tf2xla.ops import gen_xla_ops
-from tensorflow.python.compat import compat as fwd_compat
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
@@ -253,10 +252,7 @@ def einsum(equation, *inputs, **kwargs):
       - the format of `equation` is incorrect,
       - number of inputs or their shapes are inconsistent with `equation`.
   """
-  if fwd_compat.forward_compatible(2019, 10, 18):
-    return _einsum_v2(equation, *inputs, **kwargs)
-  else:
-    return _einsum_v1(equation, *inputs, **kwargs)
+  return _einsum_v2(equation, *inputs, **kwargs)
 
 
 def _einsum_v1(equation, *inputs, **kwargs):
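Since einsum now always dispatches to _einsum_v2, the capabilities that used to sit behind the horizon (repeated labels and ellipsis broadcasting among them, per the tests below) are available unconditionally. An illustrative sketch:

import tensorflow as tf

x = tf.random.uniform([7, 2, 3])
y = tf.random.uniform([3, 5])

# Ellipsis broadcasting: the leading dimension of x is carried through.
z = tf.einsum('...ij,jk->...ik', x, y)  # shape [7, 2, 5]

# Repeated labels: extract the diagonal of a square matrix.
d = tf.einsum('aa->a', tf.eye(3))  # [1., 1., 1.]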

View File

@@ -23,7 +23,6 @@ import opt_einsum
 import six
 
 from tensorflow.python.client import session
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -243,20 +242,19 @@ class EinsumTest(test.TestCase):
     self._check('abc->ca', (3, 4, 5))
     self._check('abc->cab', (3, 4, 5))
 
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      # Empty cases.
-      self._check('', ())
-      self._check('->', ())
-
-      # Repeated indices cases.
-      self._check('aa->', (3, 3))
-      self._check('aa->a', (3, 3))
-      self._check('aaa->', (3, 3, 3))
-      self._check('aaa->a', (3, 3, 3))
-      self._check('aab->a', (3, 3, 4))
-      self._check('aabcc->a', (3, 3, 5, 4, 4))
-      self._check('aabcc->ac', (3, 3, 5, 4, 4))
-      self._check('aabcd->ad', (3, 3, 5, 4, 4))
+    # Empty cases.
+    self._check('', ())
+    self._check('->', ())
+
+    # Repeated indices cases.
+    self._check('aa->', (3, 3))
+    self._check('aa->a', (3, 3))
+    self._check('aaa->', (3, 3, 3))
+    self._check('aaa->a', (3, 3, 3))
+    self._check('aab->a', (3, 3, 4))
+    self._check('aabcc->a', (3, 3, 5, 4, 4))
+    self._check('aabcc->ac', (3, 3, 5, 4, 4))
+    self._check('aabcd->ad', (3, 3, 5, 4, 4))
 
   def test_unary_ellipsis(self):
     self._check('...->', ())
@@ -266,17 +264,16 @@ class EinsumTest(test.TestCase):
     self._check('...ij->...ji', (5, 2, 3))  # batch matrix transpose
     self._check('...ij->...', (5, 2, 3))  # batch sum
 
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      self._check('...->...', ())
-      self._check('->...', ())
-
-      # Repeated indices.
-      self._check('i...ii->...i', (3, 2, 3, 3))
-      self._check('i...i->i...', (2, 2))
-      self._check('i...i->', (2, 2))
-      self._check('i...i->...', (2, 5, 1, 2))
-      self._check('i...i->i...', (2, 1, 2))
-      self._check('i...i->i...', (2, 3, 4, 5, 2))
+    self._check('...->...', ())
+    self._check('->...', ())
+
+    # Repeated indices.
+    self._check('i...ii->...i', (3, 2, 3, 3))
+    self._check('i...i->i...', (2, 2))
+    self._check('i...i->', (2, 2))
+    self._check('i...i->...', (2, 5, 1, 2))
+    self._check('i...i->i...', (2, 1, 2))
+    self._check('i...i->i...', (2, 3, 4, 5, 2))
 
   def test_binary_simple(self):
     # Binary cases in XLA mode must have either (a) each index appearing exactly
@@ -301,13 +298,12 @@ class EinsumTest(test.TestCase):
     self._check('ab,ab->', (3, 4), (3, 4))
 
   def test_repeated_indices(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      # Repeated indices.
-      self._check('ijj,k->ik', (2, 3, 3), (4,))
-      self._check('aba,a->b', (3, 4, 3), (3,))
-      # From https://github.com/dask/dask/pull/3412#discussion_r182413444
-      self._check('aab,bc->ac', (2, 2, 3), (3, 4))
-      self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
+    # Repeated indices.
+    self._check('ijj,k->ik', (2, 3, 3), (4,))
+    self._check('aba,a->b', (3, 4, 3), (3,))
+    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
+    self._check('aab,bc->ac', (2, 2, 3), (3, 4))
+    self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
 
   def test_binary_ellipsis(self):
     # Batch matmul with ellipsis but without broadcasting.
@@ -324,23 +320,22 @@ class EinsumTest(test.TestCase):
     self._check('...i,...j->...ij', (5, 2), (5, 3))  # outer product
 
   def test_broadcasting(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      # Batch matmul with broadcasting.
-      self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
-      self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
-      self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
-      self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
-      self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
-      self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
-
-      # Broadcasting with repeated indices.
-      self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
-      self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
-      self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
-      self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
-      # Following 2 from https://stackoverflow.com/a/19203475/1611416
-      self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
-      self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
+    # Batch matmul with broadcasting.
+    self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
+    self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
+    self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
+    self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
+    self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
+    self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
+
+    # Broadcasting with repeated indices.
+    self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
+    self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
+    self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
+    self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
+    # Following 2 from https://stackoverflow.com/a/19203475/1611416
+    self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
+    self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
 
   def test_dtypes(self):
     dtypes = []
@@ -388,22 +383,20 @@ class EinsumTest(test.TestCase):
               ((4, 3), (None, 3)))
 
     # Ellipsis with unknown rank.
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
-      check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
+    check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
+    check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
 
   def test_numpy_input(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      # In addition to Tensors, we also support raw numpy arrays as inputs.
-      r = np.random.RandomState(0)
-      s = 'ijk,ijl,ikl->i'
-      x = r.randn(1, 2, 3)
-      y = r.randn(1, 2, 4)
-      z = r.randn(1, 3, 4)
-
-      a = np.einsum(s, x, y, z)
-      b = self.evaluate(special_math_ops.einsum(s, x, y, z))
-      self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
+    # In addition to Tensors, we also support raw numpy arrays as inputs.
+    r = np.random.RandomState(0)
+    s = 'ijk,ijl,ikl->i'
+    x = r.randn(1, 2, 3)
+    y = r.randn(1, 2, 4)
+    z = r.randn(1, 3, 4)
+
+    a = np.einsum(s, x, y, z)
+    b = self.evaluate(special_math_ops.einsum(s, x, y, z))
+    self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
 
   def test_long_cases(self):
     cases = [
@@ -464,58 +457,56 @@ class EinsumTest(test.TestCase):
 
   @test_util.disable_xla('b/131919749')
   def test_long_cases_with_repeated_labels(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      cases = [
-          # Tests from dask.
-          'fdf,cdd,ccd,afe->ae',
-          'fff,fae,bef,def->abd',
-      ]
-      dimension_map = dict((c, ord(c) - ord('a') + 1) for c in 'abcdefghij')
-      for equation in cases:
-        inputs = equation.split('->')[0].replace(' ', '')
-        input_shapes = []
-        for input_str in inputs.split(','):
-          input_shapes.append(tuple([dimension_map[c] for c in input_str]))
-        self._check(equation, *input_shapes)
+    cases = [
+        # Tests from dask.
+        'fdf,cdd,ccd,afe->ae',
+        'fff,fae,bef,def->abd',
+    ]
+    dimension_map = dict((c, ord(c) - ord('a') + 1) for c in 'abcdefghij')
+    for equation in cases:
+      inputs = equation.split('->')[0].replace(' ', '')
+      input_shapes = []
+      for input_str in inputs.split(','):
+        input_shapes.append(tuple([dimension_map[c] for c in input_str]))
+      self._check(equation, *input_shapes)
 
   @test_util.disable_xla('b/131919749')
   @test_util.run_in_graph_and_eager_modes
   def test_invalid_equation(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      r = np.random.RandomState(0)
-      cases = [
-          # invalid equation format.
-          ('a0->a', r.randn(5, 3)),
-          ('a->a,a', r.randn(5)),
-          ('a->a->a', r.randn(5)),
-          ('ijk ijk', r.randn(1, 2, 3), r.randn(1, 2, 3)),
-          ('ij.jk->ik', r.randn(2, 3), r.randn(3, 4)),
-          # output label not present in input.
-          ('a->b', r.randn(5)),
-          ('ij,jk->im', r.randn(2, 3), r.randn(3, 4)),
-          # wrong shape.
-          ('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
-          # inconsistent dimensions.
-          ('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
-          # output has repeated subscripts.
-          ('ij,jk->iik', r.randn(2, 3), r.randn(3, 4)),
-          # too many ellipses
-          ('...ij...,jk...->ik...', r.randn(2, 3), r.randn(3, 4)),
-          ('...ij,jk...->...ik...', r.randn(2, 3), r.randn(3, 4)),
-          # invalid broadcast dimensions.
-          ('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
-          # output should have ellipsis when broadcasting shape is non-empty.
-          ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
-      ]
-      for args in cases:
-        with self.assertRaises((ValueError, errors.InvalidArgumentError)):
-          _ = special_math_ops.einsum(*args)
-
-        placeholders = [
-            array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
-        ]
-        with self.assertRaises((ValueError, errors.InvalidArgumentError)):
-          _ = self.evaluate(special_math_ops.einsum(args[0], *placeholders))
+    r = np.random.RandomState(0)
+    cases = [
+        # invalid equation format.
+        ('a0->a', r.randn(5, 3)),
+        ('a->a,a', r.randn(5)),
+        ('a->a->a', r.randn(5)),
+        ('ijk ijk', r.randn(1, 2, 3), r.randn(1, 2, 3)),
+        ('ij.jk->ik', r.randn(2, 3), r.randn(3, 4)),
+        # output label not present in input.
+        ('a->b', r.randn(5)),
+        ('ij,jk->im', r.randn(2, 3), r.randn(3, 4)),
+        # wrong shape.
+        ('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
+        # inconsistent dimensions.
+        ('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
+        # output has repeated subscripts.
+        ('ij,jk->iik', r.randn(2, 3), r.randn(3, 4)),
+        # too many ellipses
+        ('...ij...,jk...->ik...', r.randn(2, 3), r.randn(3, 4)),
+        ('...ij,jk...->...ik...', r.randn(2, 3), r.randn(3, 4)),
+        # invalid broadcast dimensions.
+        ('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
+        # output should have ellipsis when broadcasting shape is non-empty.
+        ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
+    ]
+    for args in cases:
+      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+        _ = special_math_ops.einsum(*args)
+
+      placeholders = [
+          array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
+      ]
+      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+        _ = self.evaluate(special_math_ops.einsum(args[0], *placeholders))
 
   @test_util.disable_xla('b/131919749')
   def test_empty(self):
@@ -535,10 +526,9 @@ class EinsumTest(test.TestCase):
     # From transformer xl.
     check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))
 
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      # Generalized traces with zero-sized dimensions.
-      check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
-      check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))
+    # Generalized traces with zero-sized dimensions.
+    check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
+    check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))
 
 
 @test_util.run_all_in_graph_and_eager_modes
@@ -556,122 +546,112 @@ class EinsumGradTest(test.TestCase):
 
   @test_util.disable_xla('b/131919749')
   def test_unary(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      self._check_gradient('->', ())
-      self._check_gradient('aaa->a', (3, 3, 3))
-      self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
-      self._check_gradient('abcd->da', (3, 5, 4, 2))
+    self._check_gradient('->', ())
+    self._check_gradient('aaa->a', (3, 3, 3))
+    self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
+    self._check_gradient('abcd->da', (3, 5, 4, 2))
 
   @test_util.disable_xla('b/131919749')
   def test_unary_ellipsis(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      self._check_gradient('...->...', ())
-      self._check_gradient('...->', ())
-      self._check_gradient('->...', ())
-
-      # Tests from dask
-      self._check_gradient('a...a->a...', (2, 2))
-      self._check_gradient('a...a->', (2, 2))
-      self._check_gradient('a...a->...', (2, 5, 1, 2))
-      self._check_gradient('a...a->a...', (2, 1, 2))
-      self._check_gradient('a...a->a...', (2, 3, 4, 5, 2))
-      self._check_gradient('...ijk->...ki', (3, 4, 5))
-      self._check_gradient('...ijk->...ki', (1, 3, 4, 5))
-      self._check_gradient('...ijk->...ki', (2, 2, 3, 4, 5))
-      self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))
+    self._check_gradient('...->...', ())
+    self._check_gradient('...->', ())
+    self._check_gradient('->...', ())
+
+    # Tests from dask
+    self._check_gradient('a...a->a...', (2, 2))
+    self._check_gradient('a...a->', (2, 2))
+    self._check_gradient('a...a->...', (2, 5, 1, 2))
+    self._check_gradient('a...a->a...', (2, 1, 2))
+    self._check_gradient('a...a->a...', (2, 3, 4, 5, 2))
+    self._check_gradient('...ijk->...ki', (3, 4, 5))
+    self._check_gradient('...ijk->...ki', (1, 3, 4, 5))
+    self._check_gradient('...ijk->...ki', (2, 2, 3, 4, 5))
+    self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))
 
   def test_binary_simple(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      # Binary cases in XLA mode must have either (a) each index appearing
-      # exactly once in both the inputs (batch or contraction index), or
-      # (b) appearing exactly once in an input and in the output (free index).
-      self._check_gradient(',->', (), ())
-      self._check_gradient('a,a->', (3,), (3,))
-      self._check_gradient('a,a->a', (3,), (3,))
-      self._check_gradient('ab,b->a', (3, 4), (4,))
-      self._check_gradient('ab,ab->', (3, 4), (3, 4))
-      self._check_gradient('ab,bc->ac', (3, 4), (4, 5))
-      self._check_gradient('nij,jk->nik', (5, 2, 3), (3, 4))
-      self._check_gradient('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
-      # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
-      self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))
+    # Binary cases in XLA mode must have either (a) each index appearing
+    # exactly once in both the inputs (batch or contraction index), or
+    # (b) appearing exactly once in an input and in the output (free index).
+    self._check_gradient(',->', (), ())
+    self._check_gradient('a,a->', (3,), (3,))
+    self._check_gradient('a,a->a', (3,), (3,))
+    self._check_gradient('ab,b->a', (3, 4), (4,))
+    self._check_gradient('ab,ab->', (3, 4), (3, 4))
+    self._check_gradient('ab,bc->ac', (3, 4), (4, 5))
+    self._check_gradient('nij,jk->nik', (5, 2, 3), (3, 4))
+    self._check_gradient('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
+    # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
+    self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))
 
   def test_empty(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      # From Transformer XL.
-      self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))
+    # From Transformer XL.
+    self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))
 
   @test_util.disable_xla('b/131919749')
   def test_reduced_indices(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      self._check_gradient('ba,b->', (3, 2), (3,))
-      self._check_gradient('ab,ab->', (3, 4), (3, 4))
-      self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
+    self._check_gradient('ba,b->', (3, 2), (3,))
+    self._check_gradient('ab,ab->', (3, 4), (3, 4))
+    self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
 
   @test_util.disable_xla('b/131919749')
   def test_repeated_indices(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      # Repeated indices.
-      self._check_gradient('aba,a->b', (3, 4, 3), (3,))
-      self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
-      self._check_gradient('ill,k->ik', (2, 3, 3), (4,))
-      # From https://github.com/dask/dask/pull/3412#discussion_r182413444
-      self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
-      self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
+    # Repeated indices.
+    self._check_gradient('aba,a->b', (3, 4, 3), (3,))
+    self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
+    self._check_gradient('ill,k->ik', (2, 3, 3), (4,))
+    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
+    self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
+    self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
 
   @test_util.disable_xla('b/131919749')
   def test_empty_with_repeated_indices(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
-      self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
-      self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))
+    self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
+    self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
+    self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))
 
   @test_util.disable_xla('b/131919749')
   def test_broadcasting(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
-      self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
-      self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
-      self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
-      self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
-      # Tests from dask.
-      self._check_gradient('...i,...j,...k->...ijk', (1, 4, 1, 2), (5, 1, 1, 3),
-                           (1, 1, 1, 1, 9))
-      self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))
+    self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
+    self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
+    self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
+    self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
+    self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
+    # Tests from dask.
+    self._check_gradient('...i,...j,...k->...ijk', (1, 4, 1, 2), (5, 1, 1, 3),
+                         (1, 1, 1, 1, 9))
+    self._check_gradient('...i,...j,...k->...ijk', (1,), (1,), (1,))
 
   def test_long_cases(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      cases = [
-          'abhe,hidj,jgba,hiab,gab->ed',
-          # Tests from dask.
-          'ea,fb,abcd,gc,hd->efgh',
-      ]
-      dimension_map = dict(
-          (c, ((ord(c) - ord('a')) % 3) + 1) for c in 'abcdefghij')
-      for equation in cases:
-        inputs = equation.split('->')[0].replace(' ', '')
-        input_shapes = []
-        for input_str in inputs.split(','):
-          input_shapes.append(tuple([dimension_map[c] for c in input_str]))
-        self._check_gradient(equation, *input_shapes)
+    cases = [
+        'abhe,hidj,jgba,hiab,gab->ed',
+        # Tests from dask.
+        'ea,fb,abcd,gc,hd->efgh',
+    ]
+    dimension_map = dict(
+        (c, ((ord(c) - ord('a')) % 3) + 1) for c in 'abcdefghij')
+    for equation in cases:
+      inputs = equation.split('->')[0].replace(' ', '')
+      input_shapes = []
+      for input_str in inputs.split(','):
+        input_shapes.append(tuple([dimension_map[c] for c in input_str]))
+      self._check_gradient(equation, *input_shapes)
 
   @test_util.disable_xla('b/131919749')
   def test_long_cases_with_repeated_labels(self):
-    with compat.forward_compatibility_horizon(2019, 10, 19):
-      cases = [
-          # Tests from dask.
-          'fdf,cdd,ccd,afe->ae',
-          'fff,fae,bef,def->abd',
-      ]
-      dimension_map = dict(
-          (c, ((ord(c) - ord('a')) % 3) + 1) for c in 'abcdefghij')
-      for equation in cases:
-        inputs = equation.split('->')[0].replace(' ', '')
-        input_shapes = []
-        for input_str in inputs.split(','):
-          input_shapes.append(tuple([dimension_map[c] for c in input_str]))
-        self._check_gradient(equation, *input_shapes)
+    cases = [
+        # Tests from dask.
+        'fdf,cdd,ccd,afe->ae',
+        'fff,fae,bef,def->abd',
+    ]
+    dimension_map = dict(
+        (c, ((ord(c) - ord('a')) % 3) + 1) for c in 'abcdefghij')
+    for equation in cases:
+      inputs = equation.split('->')[0].replace(' ', '')
+      input_shapes = []
+      for input_str in inputs.split(','):
+        input_shapes.append(tuple([dimension_map[c] for c in input_str]))
+      self._check_gradient(equation, *input_shapes)
 
 
 class EinsumBenchmark(test.Benchmark):