From e05fdf7a7a1d5dcf434f23504ac868ab7740014e Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Fri, 7 Feb 2020 13:43:09 -0800 Subject: [PATCH] Replace usage of Graph._building_function with a public one Not sure why people use it in the first place. PiperOrigin-RevId: 293885535 Change-Id: Ibdd3bc050b29b3fe88aebee4e3f6c804fe21813e --- tensorflow/python/data/kernel_tests/test_base.py | 3 +-- tensorflow/python/data/ops/dataset_ops.py | 9 +++------ tensorflow/python/data/ops/multi_device_iterator_ops.py | 3 +-- tensorflow/python/framework/ops.py | 2 +- tensorflow/python/framework/test_util.py | 2 +- tensorflow/python/framework/test_util_test.py | 3 +-- tensorflow/python/ops/control_flow_state.py | 2 +- tensorflow/python/ops/math_ops.py | 2 +- tensorflow/python/ops/resource_variable_ops.py | 3 +-- tensorflow/python/util/tf_should_use.py | 4 ++-- 10 files changed, 13 insertions(+), 20 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index 4b92fea9feb..10a533f8166 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -105,8 +105,7 @@ class DatasetTestBase(test.TestCase): # Create an anonymous iterator if we are in eager-mode or are graph inside # of a tf.function. - building_function = ops.get_default_graph()._building_function # pylint: disable=protected-access - if context.executing_eagerly() or building_function: + if context.executing_eagerly() or ops.inside_function(): iterator = iter(dataset) return ta_wrapper(iterator._next_internal) # pylint: disable=protected-access else: diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index caedbb2996f..f796556202e 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -400,8 +400,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): Raises: RuntimeError: If not inside of tf.function and not executing eagerly. """ - if (context.executing_eagerly() - or ops.get_default_graph()._building_function): # pylint: disable=protected-access + if context.executing_eagerly() or ops.inside_function(): return iterator_ops.OwnedIterator(self) else: raise RuntimeError("__iter__() is only supported inside of tf.function " @@ -3446,8 +3445,7 @@ class CacheDataset(UnaryUnchangedStructureDataset): self._input_dataset = input_dataset self._filename = ops.convert_to_tensor( filename, dtype=dtypes.string, name="filename") - if tf2.enabled() and (context.executing_eagerly() or - ops.get_default_graph()._building_function): # pylint: disable=protected-access + if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()): self._cache = _MemoryCache() variant_tensor = gen_dataset_ops.cache_dataset_v2( input_dataset._variant_tensor, # pylint: disable=protected-access @@ -3542,8 +3540,7 @@ class ShuffleDataset(UnaryUnchangedStructureDataset): self._reshuffle_each_iteration = reshuffle_each_iteration if tf2.enabled() and self._reshuffle_each_iteration and ( - context.executing_eagerly() or - ops.get_default_graph()._building_function): # pylint: disable=protected-access + context.executing_eagerly() or ops.inside_function()): self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2) variant_tensor = gen_dataset_ops.shuffle_dataset_v2( input_dataset._variant_tensor, # pylint: disable=protected-access diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py index ec6ff52a7aa..bd79be2b352 100644 --- a/tensorflow/python/data/ops/multi_device_iterator_ops.py +++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py @@ -490,8 +490,7 @@ class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor): RuntimeError: If executed in graph mode or outside of function building mode. """ - if (not context.executing_eagerly() and - not ops.get_default_graph()._building_function): # pylint: disable=protected-access + if not context.executing_eagerly() and not ops.inside_function(): raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of " "tf.function or when eager execution is enabled.") if devices is None: diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 053d34c8da6..a49966c9f85 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -723,7 +723,7 @@ class Tensor(_TensorLike): def __hash__(self): g = getattr(self, "graph", None) if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and - (g is None or g._building_function)): # pylint: disable=protected-access + (g is None or g.building_function)): raise TypeError("Tensor is unhashable if Tensor equality is enabled. " "Instead, use tensor.experimental_ref() as the key.") else: diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 5d5a445261e..b57a68eb059 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1147,7 +1147,7 @@ def run_in_graph_and_eager_modes(func=None, def py_func_if_in_function(f): def decorated(*args, **kwds): - if not ops.get_default_graph()._building_function: + if not ops.inside_function(): return f(*args, **kwds) tensor_args = [] diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 8213f5b321c..f18e6e9cb21 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -756,8 +756,7 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): @test_util.build_as_function_and_v1_graph def test_modes(inner_self): # pylint: disable=no-self-argument - is_building_function = ops.get_default_graph().building_function - if is_building_function: + if ops.inside_function(): self.assertFalse(inner_self.inside_function_tested) inner_self.inside_function_tested = True else: diff --git a/tensorflow/python/ops/control_flow_state.py b/tensorflow/python/ops/control_flow_state.py index 29c55c4d60c..5f0838aa63b 100644 --- a/tensorflow/python/ops/control_flow_state.py +++ b/tensorflow/python/ops/control_flow_state.py @@ -659,7 +659,7 @@ class _ControlFlowState(object): """ if util.IsLoopSwitch(op): return None - if op.graph._building_function: # pylint: disable=protected-access + if op.graph.building_function: # The optimization here is tricky to apply to functions return array_ops.zeros_like(op.outputs[index]) dead_branch = util.IsSwitch(op) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 3a289286212..6c2fe01f73f 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1487,7 +1487,7 @@ def tensor_equals(self, other): return False g = getattr(self, "graph", None) if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and - (g is None or g._building_function)): # pylint: disable=protected-access + (g is None or g.building_function)): return gen_math_ops.equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 8fa90507eaf..6e02531fccb 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -1839,8 +1839,7 @@ class _UnreadVariable(BaseResourceVariable): # Only create a graph_element if we're in session.run-land as only # session.run requires a preexisting tensor to evaluate. Otherwise we can # avoid accidentally reading the variable. - if (context.executing_eagerly() - or ops.get_default_graph()._building_function): # pylint: disable=protected-access + if context.executing_eagerly() or ops.inside_function(): graph_element = None else: with ops.control_dependencies([parent_op]): diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index f8c480ed1ac..0c11b08131c 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -49,7 +49,7 @@ class _TFShouldUseHelper(object): if context.executing_eagerly(): # If warn_in_eager, sated == False. Otherwise true. self._sated = not warn_in_eager - elif ops.get_default_graph()._building_function: # pylint: disable=protected-access + elif ops.inside_function(): if error_in_function: self._sated = False ops.add_exit_callback_to_default_func_graph( @@ -182,7 +182,7 @@ def _add_should_use_warning(x, error_in_function=False, warn_in_eager=False): if context.executing_eagerly() and not warn_in_eager: return x - if ops.get_default_graph()._building_function and not error_in_function: # pylint: disable=protected-access + if ops.inside_function() and not error_in_function: # We don't currently log warnings in tf.function calls, so just skip it. return x