Fix up tests to work with TensorShapeV2
PiperOrigin-RevId: 225049315
This commit is contained in:
parent
39b6e1924e
commit
6756eee557
@ -93,7 +93,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
|
|||||||
str(error.exception))
|
str(error.exception))
|
||||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def testSetTensorShapeDimensionInvalid(self):
|
def testSetTensorShapeDimensionInvalid(self):
|
||||||
# Tests set_tensor_shape where the shape passed in is incompatiable.
|
# Tests set_tensor_shape where the shape passed in is incompatiable.
|
||||||
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
|
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
|
||||||
@ -102,9 +102,8 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(ValueError) as error:
|
with self.assertRaises(ValueError) as error:
|
||||||
convert_saved_model.set_tensor_shapes([tensor],
|
convert_saved_model.set_tensor_shapes([tensor],
|
||||||
{"Placeholder": [1, 5, 5]})
|
{"Placeholder": [1, 5, 5]})
|
||||||
self.assertIn(
|
self.assertIn("The shape of tensor 'Placeholder' cannot be changed",
|
||||||
"The shape of tensor 'Placeholder' cannot be changed from "
|
str(error.exception))
|
||||||
"(?, 3, 5) to [1, 5, 5].", str(error.exception))
|
|
||||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
@ -1104,8 +1104,13 @@ class GradientTape(object):
|
|||||||
dimension of `target` and `source` do not match.
|
dimension of `target` and `source` do not match.
|
||||||
"""
|
"""
|
||||||
target_shape = target.shape
|
target_shape = target.shape
|
||||||
if not target_shape.with_rank_at_least(2)[0].is_compatible_with(
|
if target_shape.rank is None:
|
||||||
source.shape.with_rank_at_least(2)[0]):
|
dim = Dimension(None)
|
||||||
|
else:
|
||||||
|
dim = target_shape.dims[0]
|
||||||
|
if not (target_shape.with_rank_at_least(2) and
|
||||||
|
source.shape.with_rank_at_least(2) and
|
||||||
|
dim.is_compatible_with(source.shape[0])):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Need first dimension of target shape (%s) and "
|
"Need first dimension of target shape (%s) and "
|
||||||
"source shape (%s) to match." % (target.shape, source.shape))
|
"source shape (%s) to match." % (target.shape, source.shape))
|
||||||
|
@ -1338,17 +1338,14 @@ class BatchJacobianTest(test.TestCase):
|
|||||||
array_ops.diag(2 * x[1] * y[1])])
|
array_ops.diag(2 * x[1] * y[1])])
|
||||||
return batch_jacobian, answer
|
return batch_jacobian, answer
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testPfor(self):
|
def testPfor(self):
|
||||||
batch_jacobian, answer = self._batch_jacobian(experimental_use_pfor=True)
|
batch_jacobian, answer = self._batch_jacobian(experimental_use_pfor=True)
|
||||||
self.assertAllEqual(answer, batch_jacobian)
|
self.assertAllEqual(answer, batch_jacobian)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testWhileLoop(self):
|
def testWhileLoop(self):
|
||||||
batch_jacobian, answer = self._batch_jacobian(experimental_use_pfor=False)
|
batch_jacobian, answer = self._batch_jacobian(experimental_use_pfor=False)
|
||||||
self.assertAllEqual(answer, batch_jacobian)
|
self.assertAllEqual(answer, batch_jacobian)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testPforDefun(self):
|
def testPforDefun(self):
|
||||||
|
|
||||||
@function.defun
|
@function.defun
|
||||||
@ -1358,7 +1355,6 @@ class BatchJacobianTest(test.TestCase):
|
|||||||
batch_jacobian, answer = _f()
|
batch_jacobian, answer = _f()
|
||||||
self.assertAllEqual(answer, batch_jacobian)
|
self.assertAllEqual(answer, batch_jacobian)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testWhileLoopDefun(self):
|
def testWhileLoopDefun(self):
|
||||||
|
|
||||||
@function.defun
|
@function.defun
|
||||||
@ -1368,7 +1364,6 @@ class BatchJacobianTest(test.TestCase):
|
|||||||
batch_jacobian, answer = _f()
|
batch_jacobian, answer = _f()
|
||||||
self.assertAllEqual(answer, batch_jacobian)
|
self.assertAllEqual(answer, batch_jacobian)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testPersistentTape(self):
|
def testPersistentTape(self):
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
return
|
return
|
||||||
@ -1379,7 +1374,6 @@ class BatchJacobianTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(RuntimeError, 'persistent'):
|
with self.assertRaisesRegexp(RuntimeError, 'persistent'):
|
||||||
g.batch_jacobian(y, x, experimental_use_pfor=False)
|
g.batch_jacobian(y, x, experimental_use_pfor=False)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testBadShape(self):
|
def testBadShape(self):
|
||||||
x = random_ops.random_uniform([2, 3])
|
x = random_ops.random_uniform([2, 3])
|
||||||
with backprop.GradientTape() as g:
|
with backprop.GradientTape() as g:
|
||||||
@ -1387,7 +1381,6 @@ class BatchJacobianTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'Need first dimension'):
|
with self.assertRaisesRegexp(ValueError, 'Need first dimension'):
|
||||||
g.batch_jacobian(y, x)
|
g.batch_jacobian(y, x)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testBadInputRank(self):
|
def testBadInputRank(self):
|
||||||
x = random_ops.random_uniform([2])
|
x = random_ops.random_uniform([2])
|
||||||
with backprop.GradientTape() as g:
|
with backprop.GradientTape() as g:
|
||||||
@ -1402,7 +1395,6 @@ class BatchJacobianTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'must have rank at least 2'):
|
with self.assertRaisesRegexp(ValueError, 'must have rank at least 2'):
|
||||||
g.batch_jacobian(y, x)
|
g.batch_jacobian(y, x)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testPforException(self):
|
def testPforException(self):
|
||||||
var = variables.Variable([1.])
|
var = variables.Variable([1.])
|
||||||
|
|
||||||
@ -1423,7 +1415,6 @@ class BatchJacobianTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'No converter'):
|
with self.assertRaisesRegexp(ValueError, 'No converter'):
|
||||||
g.batch_jacobian(y, x, experimental_use_pfor=True)
|
g.batch_jacobian(y, x, experimental_use_pfor=True)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def test_parallel_iterations(self):
|
def test_parallel_iterations(self):
|
||||||
with backprop.GradientTape(persistent=True) as g:
|
with backprop.GradientTape(persistent=True) as g:
|
||||||
x = constant_op.constant([[1., 2], [3, 4]])
|
x = constant_op.constant([[1., 2], [3, 4]])
|
||||||
|
@ -330,7 +330,6 @@ class OpsTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEquals(t, dtypes.string)
|
self.assertEquals(t, dtypes.string)
|
||||||
self.assertEquals(r[0].dtype, dtypes.string)
|
self.assertEquals(r[0].dtype, dtypes.string)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def testFlattenLayer(self):
|
def testFlattenLayer(self):
|
||||||
flatten_layer = core.Flatten()
|
flatten_layer = core.Flatten()
|
||||||
x = constant_op.constant([[[-10, -20], [-30, -40]], [[10, 20], [30, 40]]])
|
x = constant_op.constant([[[-10, -20], [-30, -40]], [[10, 20], [30, 40]]])
|
||||||
|
@ -134,7 +134,6 @@ class KerasIntegrationTest(test.TestCase):
|
|||||||
verbose=2)
|
verbose=2)
|
||||||
self.assertGreater(history.history['val_acc'][-1], 0.7)
|
self.assertGreater(history.history['val_acc'][-1], 0.7)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
def test_image_classification_sequential(self):
|
def test_image_classification_sequential(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
np.random.seed(1337)
|
np.random.seed(1337)
|
||||||
|
@ -549,8 +549,8 @@ class Flatten(Layer):
|
|||||||
inputs = array_ops.transpose(inputs, perm=permutation)
|
inputs = array_ops.transpose(inputs, perm=permutation)
|
||||||
|
|
||||||
outputs = array_ops.reshape(
|
outputs = array_ops.reshape(
|
||||||
inputs, (tensor_shape.dimension_value(inputs.shape[0])
|
inputs, (tensor_shape.dimension_value(inputs.shape[0]) or
|
||||||
or array_ops.shape(inputs)[0], -1))
|
array_ops.shape(inputs)[0], -1))
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
|
outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
|
||||||
return outputs
|
return outputs
|
||||||
|
@ -135,7 +135,6 @@ class CoreLayersTest(test.TestCase):
|
|||||||
kwargs={'dims': (1, 4, 2)}, input_shape=(3, 2, 4))
|
kwargs={'dims': (1, 4, 2)}, input_shape=(3, 2, 4))
|
||||||
|
|
||||||
@tf_test_util.run_in_graph_and_eager_modes
|
@tf_test_util.run_in_graph_and_eager_modes
|
||||||
@tf_test_util.run_v1_only('b/120545219')
|
|
||||||
def test_flatten(self):
|
def test_flatten(self):
|
||||||
testing_utils.layer_test(
|
testing_utils.layer_test(
|
||||||
keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4))
|
keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4))
|
||||||
@ -151,7 +150,6 @@ class CoreLayersTest(test.TestCase):
|
|||||||
self.assertAllClose(outputs, target_outputs)
|
self.assertAllClose(outputs, target_outputs)
|
||||||
|
|
||||||
@tf_test_util.run_in_graph_and_eager_modes
|
@tf_test_util.run_in_graph_and_eager_modes
|
||||||
@tf_test_util.run_v1_only('b/120545219')
|
|
||||||
def test_flatten_scalar_channels(self):
|
def test_flatten_scalar_channels(self):
|
||||||
testing_utils.layer_test(
|
testing_utils.layer_test(
|
||||||
keras.layers.Flatten, kwargs={}, input_shape=(3,))
|
keras.layers.Flatten, kwargs={}, input_shape=(3,))
|
||||||
|
@ -1516,12 +1516,12 @@ class ControlFlowTest(test.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
_, r = control_flow_ops.while_loop(c, b, [i, x])
|
_, r = control_flow_ops.while_loop(c, b, [i, x])
|
||||||
self.assertEqual(r.dense_shape.get_shape()[0].value, 1)
|
self.assertEqual(r.dense_shape.get_shape()[0], 1)
|
||||||
|
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
c, b, [i, x],
|
c, b, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None])])
|
[i.get_shape(), tensor_shape.TensorShape([None])])
|
||||||
self.assertTrue(r.dense_shape.get_shape()[0].value is None)
|
self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
@ -1548,15 +1548,14 @@ class ControlFlowTest(test.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
_, r = control_flow_ops.while_loop(c, b, [i, x])
|
_, r = control_flow_ops.while_loop(c, b, [i, x])
|
||||||
self.assertEqual(r.dense_shape.get_shape()[0].value, 2)
|
self.assertEqual(r.dense_shape.get_shape()[0], 2)
|
||||||
self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))
|
self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))
|
||||||
|
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
c, b, [i, x],
|
c, b, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
|
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
|
||||||
self.assertEqual(r.dense_shape.get_shape()[0].value, 2)
|
self.assertEqual(r.dense_shape.get_shape()[0], 2)
|
||||||
self.assertTrue(r.values.get_shape()[0].value is None)
|
self.assertEqual(r.values.get_shape().as_list(), [None, 2])
|
||||||
self.assertEqual(r.values.get_shape()[1].value, 2)
|
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
@ -1925,7 +1924,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
|
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
|
@test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testWhileUpdateVariable_3(self):
|
def testWhileUpdateVariable_3(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
select = variables.Variable([3.0, 4.0, 5.0])
|
select = variables.Variable([3.0, 4.0, 5.0])
|
||||||
|
@ -242,7 +242,6 @@ class CTCLossTest(test.TestCase):
|
|||||||
|
|
||||||
self._testCTCLoss(inputs, seq_lens, labels, loss_truth, grad_truth)
|
self._testCTCLoss(inputs, seq_lens, labels, loss_truth, grad_truth)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def test_time_major(self):
|
def test_time_major(self):
|
||||||
"""Testing time_major param.
|
"""Testing time_major param.
|
||||||
|
|
||||||
@ -565,7 +564,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
rtol=2e-06,
|
rtol=2e-06,
|
||||||
atol=2e-06)
|
atol=2e-06)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCollapseRepeated(self):
|
def testCollapseRepeated(self):
|
||||||
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
||||||
labels=[[1, 3, 3, 3, 0],
|
labels=[[1, 3, 3, 3, 0],
|
||||||
@ -579,7 +577,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
[1, 4, 0, 0],
|
[1, 4, 0, 0],
|
||||||
[4, 2, 9, 4]])
|
[4, 2, 9, 4]])
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCollapseRepeatedPreservesDtypes(self):
|
def testCollapseRepeatedPreservesDtypes(self):
|
||||||
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
||||||
labels=constant_op.constant(
|
labels=constant_op.constant(
|
||||||
@ -597,7 +594,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
[1, 4, 0, 0],
|
[1, 4, 0, 0],
|
||||||
[4, 2, 9, 4]])
|
[4, 2, 9, 4]])
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCollapseRepeatedExtraPadding(self):
|
def testCollapseRepeatedExtraPadding(self):
|
||||||
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
||||||
labels=[[1, 3, 3, 3, 0, 0, 0],
|
labels=[[1, 3, 3, 3, 0, 0, 0],
|
||||||
@ -611,7 +607,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
[1, 4, 0, 0],
|
[1, 4, 0, 0],
|
||||||
[4, 2, 9, 4]])
|
[4, 2, 9, 4]])
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCollapseRepeatedFrontRepeats(self):
|
def testCollapseRepeatedFrontRepeats(self):
|
||||||
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
||||||
labels=[[1, 1, 1, 2, 2],
|
labels=[[1, 1, 1, 2, 2],
|
||||||
@ -625,7 +620,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
[1, 2],
|
[1, 2],
|
||||||
[1, 0]])
|
[1, 0]])
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testCollapseRepeatedAllLabelsTheSame(self):
|
def testCollapseRepeatedAllLabelsTheSame(self):
|
||||||
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
|
||||||
labels=[[1, 1, 1, 1, 1],
|
labels=[[1, 1, 1, 1, 1],
|
||||||
@ -658,7 +652,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(padded_dense, new_dense)
|
self.assertAllEqual(padded_dense, new_dense)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testUnique(self):
|
def testUnique(self):
|
||||||
labels = [
|
labels = [
|
||||||
[3, 4, 4, 3],
|
[3, 4, 4, 3],
|
||||||
@ -674,7 +667,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
[0, 0, 0, 1],
|
[0, 0, 0, 1],
|
||||||
], idx)
|
], idx)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testSumStates(self):
|
def testSumStates(self):
|
||||||
idx = [
|
idx = [
|
||||||
[0, 1, 0, 1],
|
[0, 1, 0, 1],
|
||||||
@ -694,7 +686,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
[1.8, 0.8, 0.0, 0.0]]
|
[1.8, 0.8, 0.0, 0.0]]
|
||||||
], sum_of_states)
|
], sum_of_states)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testStateToOlabel(self):
|
def testStateToOlabel(self):
|
||||||
labels = [
|
labels = [
|
||||||
[3, 4, 3, 4],
|
[3, 4, 3, 4],
|
||||||
@ -733,7 +724,6 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
[22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
|
[22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
|
||||||
])
|
])
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testStateToOlabelUnique(self):
|
def testStateToOlabelUnique(self):
|
||||||
labels = [
|
labels = [
|
||||||
[3, 4, 3, 4],
|
[3, 4, 3, 4],
|
||||||
|
@ -214,7 +214,7 @@ class LinearOperatorTest(test.TestCase):
|
|||||||
operator = LinearOperatorMatmulSolve(matrix, is_square=True)
|
operator = LinearOperatorMatmulSolve(matrix, is_square=True)
|
||||||
self.assertTrue(operator.is_square)
|
self.assertTrue(operator.is_square)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_operator_matmul_hints_closed(self):
|
def test_linear_operator_matmul_hints_closed(self):
|
||||||
matrix = array_ops.placeholder(dtypes.float32)
|
matrix = array_ops.placeholder(dtypes.float32)
|
||||||
operator1 = LinearOperatorMatmulSolve(matrix)
|
operator1 = LinearOperatorMatmulSolve(matrix)
|
||||||
@ -241,7 +241,7 @@ class LinearOperatorTest(test.TestCase):
|
|||||||
self.assertTrue(operator_matmul.is_self_adjoint)
|
self.assertTrue(operator_matmul.is_self_adjoint)
|
||||||
self.assertEqual(None, operator_matmul.is_positive_definite)
|
self.assertEqual(None, operator_matmul.is_positive_definite)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_operator_matmul_hints_false(self):
|
def test_linear_operator_matmul_hints_false(self):
|
||||||
matrix = array_ops.placeholder(dtypes.float32)
|
matrix = array_ops.placeholder(dtypes.float32)
|
||||||
operator1 = LinearOperatorMatmulSolve(
|
operator1 = LinearOperatorMatmulSolve(
|
||||||
@ -274,7 +274,7 @@ class LinearOperatorTest(test.TestCase):
|
|||||||
self.assertEqual(None, operator_matmul.is_self_adjoint)
|
self.assertEqual(None, operator_matmul.is_self_adjoint)
|
||||||
self.assertEqual(None, operator_matmul.is_positive_definite)
|
self.assertEqual(None, operator_matmul.is_positive_definite)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_deprecated_v1
|
||||||
def test_linear_operator_matmul_hint_infer_square(self):
|
def test_linear_operator_matmul_hint_infer_square(self):
|
||||||
matrix1 = array_ops.placeholder(shape=[2, 3], dtype=dtypes.float32)
|
matrix1 = array_ops.placeholder(shape=[2, 3], dtype=dtypes.float32)
|
||||||
matrix2 = array_ops.placeholder(shape=[3, 2], dtype=dtypes.float32)
|
matrix2 = array_ops.placeholder(shape=[3, 2], dtype=dtypes.float32)
|
||||||
|
@ -463,9 +463,9 @@ class DropoutTest(test.TestCase):
|
|||||||
self.assertAllClose(np.ones((5, 5)), np_output)
|
self.assertAllClose(np.ones((5, 5)), np_output)
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
|
||||||
class FlattenTest(test.TestCase):
|
class FlattenTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testCreateFlatten(self):
|
def testCreateFlatten(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
|
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
|
||||||
@ -490,6 +490,7 @@ class FlattenTest(test.TestCase):
|
|||||||
shape = core_layers.Flatten().compute_output_shape((None, 3, None))
|
shape = core_layers.Flatten().compute_output_shape((None, 3, None))
|
||||||
self.assertEqual(shape.as_list(), [None, None])
|
self.assertEqual(shape.as_list(), [None, None])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDataFormat5d(self):
|
def testDataFormat5d(self):
|
||||||
np_input_channels_last = np.arange(
|
np_input_channels_last = np.arange(
|
||||||
120, dtype='float32').reshape([1, 5, 4, 3, 2])
|
120, dtype='float32').reshape([1, 5, 4, 3, 2])
|
||||||
@ -507,6 +508,7 @@ class FlattenTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(np_output_cl, np_output_cf)
|
self.assertAllEqual(np_output_cl, np_output_cf)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testDataFormat4d(self):
|
def testDataFormat4d(self):
|
||||||
np_input_channels_last = np.arange(
|
np_input_channels_last = np.arange(
|
||||||
24, dtype='float32').reshape([1, 4, 3, 2])
|
24, dtype='float32').reshape([1, 4, 3, 2])
|
||||||
@ -524,11 +526,13 @@ class FlattenTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(np_output_cl, np_output_cf)
|
self.assertAllEqual(np_output_cl, np_output_cf)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFunctionalFlatten(self):
|
def testFunctionalFlatten(self):
|
||||||
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
|
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
|
||||||
y = core_layers.flatten(x, name='flatten')
|
y = core_layers.flatten(x, name='flatten')
|
||||||
self.assertEqual(y.get_shape().as_list(), [None, 6])
|
self.assertEqual(y.get_shape().as_list(), [None, 6])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFlatten0D(self):
|
def testFlatten0D(self):
|
||||||
x = array_ops.placeholder(shape=(None,), dtype='float32')
|
x = array_ops.placeholder(shape=(None,), dtype='float32')
|
||||||
y = core_layers.Flatten()(x)
|
y = core_layers.Flatten()(x)
|
||||||
@ -537,6 +541,7 @@ class FlattenTest(test.TestCase):
|
|||||||
self.assertEqual(list(np_output.shape), [5, 1])
|
self.assertEqual(list(np_output.shape), [5, 1])
|
||||||
self.assertEqual(y.shape.as_list(), [None, 1])
|
self.assertEqual(y.shape.as_list(), [None, 1])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testFlattenUnknownAxes(self):
|
def testFlattenUnknownAxes(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
x = array_ops.placeholder(shape=(5, None, None), dtype='float32')
|
x = array_ops.placeholder(shape=(5, None, None), dtype='float32')
|
||||||
|
@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import functional_ops
|
from tensorflow.python.ops import functional_ops
|
||||||
@ -1127,4 +1128,5 @@ def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
|
|||||||
|
|
||||||
def _get_dim(tensor, i):
|
def _get_dim(tensor, i):
|
||||||
"""Get value of tensor shape[i] preferring static value if available."""
|
"""Get value of tensor shape[i] preferring static value if available."""
|
||||||
return tensor.shape[i].value or array_ops.shape(tensor)[i]
|
return tensor_shape.dimension_value(
|
||||||
|
tensor.shape[i]) or array_ops.shape(tensor)[i]
|
||||||
|
@ -381,7 +381,10 @@ class LinearOperator(object):
|
|||||||
`Dimension` object.
|
`Dimension` object.
|
||||||
"""
|
"""
|
||||||
# Derived classes get this "for free" once .shape is implemented.
|
# Derived classes get this "for free" once .shape is implemented.
|
||||||
return self.shape[-1]
|
if self.shape.rank is None:
|
||||||
|
return tensor_shape.Dimension(None)
|
||||||
|
else:
|
||||||
|
return self.shape.dims[-1]
|
||||||
|
|
||||||
def domain_dimension_tensor(self, name="domain_dimension_tensor"):
|
def domain_dimension_tensor(self, name="domain_dimension_tensor"):
|
||||||
"""Dimension (in the sense of vector spaces) of the domain of this operator.
|
"""Dimension (in the sense of vector spaces) of the domain of this operator.
|
||||||
|
Loading…
Reference in New Issue
Block a user