eager/function_test to PY3
Fix test to not be sensitive to tuple order since it is generated from a set. PiperOrigin-RevId: 284923887 Change-Id: I6ca521003032e9bead025f99887d089affd4dc9b
This commit is contained in:
parent
47f9db3b0e
commit
115ea3db34
@ -383,7 +383,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:test_ops",
|
||||
],
|
||||
python_version = "PY2",
|
||||
python_version = "PY3",
|
||||
shard_count = 15,
|
||||
)
|
||||
|
||||
|
||||
@ -338,10 +338,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
cf = f.get_concrete_function()
|
||||
c = cc[0]
|
||||
|
||||
self.assertEqual(cf.variables, (a, b, c))
|
||||
self.assertEqual(cf.trainable_variables, (b, c))
|
||||
self.assertEqual(cf.graph.variables, (a, b, c))
|
||||
self.assertEqual(cf.graph.trainable_variables, (b, c))
|
||||
captured_variables = {v.experimental_ref() for v in (a, b, c)}
|
||||
trainable_variables = {v.experimental_ref() for v in (b, c)}
|
||||
self.assertEqual({v.experimental_ref() for v in cf.variables},
|
||||
captured_variables)
|
||||
self.assertEqual({v.experimental_ref() for v in cf.trainable_variables},
|
||||
trainable_variables)
|
||||
self.assertEqual(cf.variables, cf.graph.variables)
|
||||
self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)
|
||||
|
||||
def testNestedInputShapeFunctionRelaxation(self):
|
||||
unknown_dim = [False]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user