TensorTracer: Ignore calls for tracing tf function graphs.

PiperOrigin-RevId: 351271347
Change-Id: I04b1ebe3900982cc2c2003272dbd90dadbbfdbf7
This commit is contained in:
Mehmet Deveci 2021-01-11 17:36:43 -08:00 committed by TensorFlower Gardener
parent d431a5faa1
commit c999815aac

View File

@ -30,6 +30,8 @@ import six
from tensorflow.core.framework import summary_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@ -1960,6 +1962,12 @@ class TensorTracer(object):
RuntimeError: If num_replicas_per_host > 8.
RuntimeError: If tensor_fetches is None or empty.
"""
if isinstance(graph, func_graph.FuncGraph) or isinstance(
graph, function._FuncGraph): # pylint: disable=protected-access
logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
'Ignoring tracing.')
return tensor_fetches
if graph in TensorTracer._traced_graphs:
logging.warning('Graph is already rewritten with tensor tracer, ignoring '
'multiple calls.')
@ -2010,6 +2018,11 @@ class TensorTracer(object):
Raises:
RuntimeError: If tensor_fetches is None or empty.
"""
if isinstance(graph, func_graph.FuncGraph) or isinstance(
graph, function._FuncGraph): # pylint: disable=protected-access
logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
'Ignoring tracing.')
return tensor_fetches
if graph in TensorTracer._traced_graphs:
logging.warning('Graph is already rewritten with tensor tracer, ignoring '