fix //tensorflow/python:slot_creator_test UT fail

This commit is contained in:
candy.dc 2019-08-19 09:58:29 +08:00
parent 2a72e99930
commit 8e77c00933

View File

@ -23,7 +23,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
@ -136,11 +135,14 @@ class SlotCreatorTest(test.TestCase):
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
self.assertEqual("scope/scope/var/slot", slot.op.name)
@test_util.run_deprecated_v1
def testCreateSlotFromFirstMDimensionVariable(self):
with self.test_session():
s = variables.Variable([1.0, 2.5], name="var")
p_v = variable_scope.get_variable("var", shape=[2, 2],
partitioner=partitioned_variables.fixed_size_partitioner(2))
p_v = variable_scope.get_variable(
"var",
shape=[2, 2],
partitioner=partitioned_variables.fixed_size_partitioner(2))
for i, v in enumerate(p_v):
slot = slot_creator.create_slot(v, s.initialized_value(), name="slot")
si = slot._save_slice_info
@ -155,11 +157,14 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([i], si.var_offset)
self.assertAllEqual([1], si.var_shape)
@test_util.run_deprecated_v1
def testCreateSlotFromScalarVariable(self):
with self.test_session():
s = variables.Variable(1.0, name="var")
p_v = variable_scope.get_variable("var", shape=[2, 2],
partitioner=partitioned_variables.fixed_size_partitioner(2))
p_v = variable_scope.get_variable(
"var",
shape=[2, 2],
partitioner=partitioned_variables.fixed_size_partitioner(2))
for i, v in enumerate(p_v):
slot = slot_creator.create_slot(v, s.initialized_value(), name="slot")