Replace usage of Graph._building_function with a public one
Not sure why people use it in the first place. PiperOrigin-RevId: 293885535 Change-Id: Ibdd3bc050b29b3fe88aebee4e3f6c804fe21813e
This commit is contained in:
parent
5fc54e2153
commit
e05fdf7a7a
@ -105,8 +105,7 @@ class DatasetTestBase(test.TestCase):
|
|||||||
|
|
||||||
# Create an anonymous iterator if we are in eager-mode or are graph inside
|
# Create an anonymous iterator if we are in eager-mode or are graph inside
|
||||||
# of a tf.function.
|
# of a tf.function.
|
||||||
building_function = ops.get_default_graph()._building_function # pylint: disable=protected-access
|
if context.executing_eagerly() or ops.inside_function():
|
||||||
if context.executing_eagerly() or building_function:
|
|
||||||
iterator = iter(dataset)
|
iterator = iter(dataset)
|
||||||
return ta_wrapper(iterator._next_internal) # pylint: disable=protected-access
|
return ta_wrapper(iterator._next_internal) # pylint: disable=protected-access
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -400,8 +400,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If not inside of tf.function and not executing eagerly.
|
RuntimeError: If not inside of tf.function and not executing eagerly.
|
||||||
"""
|
"""
|
||||||
if (context.executing_eagerly()
|
if context.executing_eagerly() or ops.inside_function():
|
||||||
or ops.get_default_graph()._building_function): # pylint: disable=protected-access
|
|
||||||
return iterator_ops.OwnedIterator(self)
|
return iterator_ops.OwnedIterator(self)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||||
@ -3446,8 +3445,7 @@ class CacheDataset(UnaryUnchangedStructureDataset):
|
|||||||
self._input_dataset = input_dataset
|
self._input_dataset = input_dataset
|
||||||
self._filename = ops.convert_to_tensor(
|
self._filename = ops.convert_to_tensor(
|
||||||
filename, dtype=dtypes.string, name="filename")
|
filename, dtype=dtypes.string, name="filename")
|
||||||
if tf2.enabled() and (context.executing_eagerly() or
|
if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()):
|
||||||
ops.get_default_graph()._building_function): # pylint: disable=protected-access
|
|
||||||
self._cache = _MemoryCache()
|
self._cache = _MemoryCache()
|
||||||
variant_tensor = gen_dataset_ops.cache_dataset_v2(
|
variant_tensor = gen_dataset_ops.cache_dataset_v2(
|
||||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
@ -3542,8 +3540,7 @@ class ShuffleDataset(UnaryUnchangedStructureDataset):
|
|||||||
self._reshuffle_each_iteration = reshuffle_each_iteration
|
self._reshuffle_each_iteration = reshuffle_each_iteration
|
||||||
|
|
||||||
if tf2.enabled() and self._reshuffle_each_iteration and (
|
if tf2.enabled() and self._reshuffle_each_iteration and (
|
||||||
context.executing_eagerly() or
|
context.executing_eagerly() or ops.inside_function()):
|
||||||
ops.get_default_graph()._building_function): # pylint: disable=protected-access
|
|
||||||
self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2)
|
self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2)
|
||||||
variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
|
variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
|
||||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
|
|||||||
@ -490,8 +490,7 @@ class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor):
|
|||||||
RuntimeError: If executed in graph mode or outside of function building
|
RuntimeError: If executed in graph mode or outside of function building
|
||||||
mode.
|
mode.
|
||||||
"""
|
"""
|
||||||
if (not context.executing_eagerly() and
|
if not context.executing_eagerly() and not ops.inside_function():
|
||||||
not ops.get_default_graph()._building_function): # pylint: disable=protected-access
|
|
||||||
raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of "
|
raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of "
|
||||||
"tf.function or when eager execution is enabled.")
|
"tf.function or when eager execution is enabled.")
|
||||||
if devices is None:
|
if devices is None:
|
||||||
|
|||||||
@ -723,7 +723,7 @@ class Tensor(_TensorLike):
|
|||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
g = getattr(self, "graph", None)
|
g = getattr(self, "graph", None)
|
||||||
if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
|
if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
|
||||||
(g is None or g._building_function)): # pylint: disable=protected-access
|
(g is None or g.building_function)):
|
||||||
raise TypeError("Tensor is unhashable if Tensor equality is enabled. "
|
raise TypeError("Tensor is unhashable if Tensor equality is enabled. "
|
||||||
"Instead, use tensor.experimental_ref() as the key.")
|
"Instead, use tensor.experimental_ref() as the key.")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1147,7 +1147,7 @@ def run_in_graph_and_eager_modes(func=None,
|
|||||||
def py_func_if_in_function(f):
|
def py_func_if_in_function(f):
|
||||||
|
|
||||||
def decorated(*args, **kwds):
|
def decorated(*args, **kwds):
|
||||||
if not ops.get_default_graph()._building_function:
|
if not ops.inside_function():
|
||||||
return f(*args, **kwds)
|
return f(*args, **kwds)
|
||||||
|
|
||||||
tensor_args = []
|
tensor_args = []
|
||||||
|
|||||||
@ -756,8 +756,7 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@test_util.build_as_function_and_v1_graph
|
@test_util.build_as_function_and_v1_graph
|
||||||
def test_modes(inner_self): # pylint: disable=no-self-argument
|
def test_modes(inner_self): # pylint: disable=no-self-argument
|
||||||
is_building_function = ops.get_default_graph().building_function
|
if ops.inside_function():
|
||||||
if is_building_function:
|
|
||||||
self.assertFalse(inner_self.inside_function_tested)
|
self.assertFalse(inner_self.inside_function_tested)
|
||||||
inner_self.inside_function_tested = True
|
inner_self.inside_function_tested = True
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -659,7 +659,7 @@ class _ControlFlowState(object):
|
|||||||
"""
|
"""
|
||||||
if util.IsLoopSwitch(op):
|
if util.IsLoopSwitch(op):
|
||||||
return None
|
return None
|
||||||
if op.graph._building_function: # pylint: disable=protected-access
|
if op.graph.building_function:
|
||||||
# The optimization here is tricky to apply to functions
|
# The optimization here is tricky to apply to functions
|
||||||
return array_ops.zeros_like(op.outputs[index])
|
return array_ops.zeros_like(op.outputs[index])
|
||||||
dead_branch = util.IsSwitch(op)
|
dead_branch = util.IsSwitch(op)
|
||||||
|
|||||||
@ -1487,7 +1487,7 @@ def tensor_equals(self, other):
|
|||||||
return False
|
return False
|
||||||
g = getattr(self, "graph", None)
|
g = getattr(self, "graph", None)
|
||||||
if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
|
if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
|
||||||
(g is None or g._building_function)): # pylint: disable=protected-access
|
(g is None or g.building_function)):
|
||||||
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
|
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
|
||||||
else:
|
else:
|
||||||
# In legacy graph mode, tensor equality is object equality
|
# In legacy graph mode, tensor equality is object equality
|
||||||
|
|||||||
@ -1839,8 +1839,7 @@ class _UnreadVariable(BaseResourceVariable):
|
|||||||
# Only create a graph_element if we're in session.run-land as only
|
# Only create a graph_element if we're in session.run-land as only
|
||||||
# session.run requires a preexisting tensor to evaluate. Otherwise we can
|
# session.run requires a preexisting tensor to evaluate. Otherwise we can
|
||||||
# avoid accidentally reading the variable.
|
# avoid accidentally reading the variable.
|
||||||
if (context.executing_eagerly()
|
if context.executing_eagerly() or ops.inside_function():
|
||||||
or ops.get_default_graph()._building_function): # pylint: disable=protected-access
|
|
||||||
graph_element = None
|
graph_element = None
|
||||||
else:
|
else:
|
||||||
with ops.control_dependencies([parent_op]):
|
with ops.control_dependencies([parent_op]):
|
||||||
|
|||||||
@ -49,7 +49,7 @@ class _TFShouldUseHelper(object):
|
|||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
# If warn_in_eager, sated == False. Otherwise true.
|
# If warn_in_eager, sated == False. Otherwise true.
|
||||||
self._sated = not warn_in_eager
|
self._sated = not warn_in_eager
|
||||||
elif ops.get_default_graph()._building_function: # pylint: disable=protected-access
|
elif ops.inside_function():
|
||||||
if error_in_function:
|
if error_in_function:
|
||||||
self._sated = False
|
self._sated = False
|
||||||
ops.add_exit_callback_to_default_func_graph(
|
ops.add_exit_callback_to_default_func_graph(
|
||||||
@ -182,7 +182,7 @@ def _add_should_use_warning(x, error_in_function=False, warn_in_eager=False):
|
|||||||
if context.executing_eagerly() and not warn_in_eager:
|
if context.executing_eagerly() and not warn_in_eager:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
if ops.get_default_graph()._building_function and not error_in_function: # pylint: disable=protected-access
|
if ops.inside_function() and not error_in_function:
|
||||||
# We don't currently log warnings in tf.function calls, so just skip it.
|
# We don't currently log warnings in tf.function calls, so just skip it.
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user