From c999815aac22950911e05a864e7d26008b13e8da Mon Sep 17 00:00:00 2001 From: Mehmet Deveci Date: Mon, 11 Jan 2021 17:36:43 -0800 Subject: [PATCH] TensorTracer: Ignore calls for tracing tf function graphs. PiperOrigin-RevId: 351271347 Change-Id: I04b1ebe3900982cc2c2003272dbd90dadbbfdbf7 --- tensorflow/python/tpu/tensor_tracer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index fbb66accfc3..c2430967e57 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -30,6 +30,8 @@ import six from tensorflow.core.framework import summary_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import func_graph +from tensorflow.python.framework import function from tensorflow.python.framework import graph_io from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -1960,6 +1962,12 @@ class TensorTracer(object): RuntimeError: If num_replicas_per_host > 8. RuntimeError: If tensor_fetches is None or empty. """ + if isinstance(graph, func_graph.FuncGraph) or isinstance( + graph, function._FuncGraph): # pylint: disable=protected-access + logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. ' + 'Ignoring tracing.') + return tensor_fetches + if graph in TensorTracer._traced_graphs: logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 'multiple calls.') @@ -2010,6 +2018,11 @@ class TensorTracer(object): Raises: RuntimeError: If tensor_fetches is None or empty. """ + if isinstance(graph, func_graph.FuncGraph) or isinstance( + graph, function._FuncGraph): # pylint: disable=protected-access + logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. ' + 'Ignoring tracing.') + return tensor_fetches if graph in TensorTracer._traced_graphs: logging.warning('Graph is already rewritten with tensor tracer, ignoring '