STT-tensorflow/tensorflow/python/eager/context_test.py
Akshay Modi fa5cdeae7e Add a functiondef getter to the context
PiperOrigin-RevId: 296002833
Change-Id: I238a2984a9320c084b7157e6eeb30b30aa132036
2020-02-19 10:48:38 -08:00

115 lines
3.5 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import weakref
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
class ContextTest(test.TestCase):
def testSetGlobalSeed(self):
c = context.Context()
c._set_global_seed(123)
for t in [np.int32, np.int64, np.uint32, np.uint64]:
c._set_global_seed(t(123))
c._set_global_seed(np.array(123, dtype=t))
c._set_global_seed(ops.convert_to_tensor(123, dtype=t))
def testContextIsDestroyedAfterTensors(self):
# Create a new context
new_context = context.Context()
weak_c = weakref.ref(new_context)
new_context.ensure_initialized()
# Create a tensor with the new context as default.
# Make sure to restore the original context.
original_context = context.context()
try:
context._set_context(new_context)
# Use a 2D tensor so that it is not cached.
tensor1 = constant_op.constant([[3.]])
# Produce a tensor as an operation output. This uses a different code path
# from tensors created from Python.
tensor2 = tensor1 * tensor1
context._set_context(original_context)
except:
context._set_context(original_context)
raise
# Deleting our context reference should not delete the underlying object.
del new_context
self.assertIsNot(weak_c(), None)
# Deleting the first tensor should not delete the context since there is
# another tensor.
del tensor1
self.assertIsNot(weak_c(), None)
# Deleting the last tensor should result in deleting its context.
del tensor2
self.assertIs(weak_c(), None)
def testSimpleGraphCollection(self):
@def_function.function
def f(x):
return x + constant_op.constant(1.)
with context.collect_graphs() as graphs:
with ops.device('CPU:0'):
f(constant_op.constant(1.))
self.assertLen(graphs, 1)
graph, = graphs
self.assertIn('CPU:0', graph.node[0].device)
def testGetFunctionDef(self):
@def_function.function
def f():
return constant_op.constant(1.)
concrete = f.get_concrete_function()
function_def = context.get_function_def(concrete.name)
self.assertIsNot(function_def, None)
found_const_node = False
for node_def in function_def.node_def:
if node_def.op == 'Const':
found_const_node = True
break
self.assertTrue(found_const_node)
with self.assertRaises(errors.NotFoundError):
_ = context.get_function_def('this_should_not_be_found')
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()