minor changes in tensor tracer.

PiperOrigin-RevId: 274900258
Change-Id: If6f5485139e61c1448054741051118467e8f4e79
This commit is contained in:
A. Unique TensorFlower 2019-10-15 15:21:46 -07:00 committed by TensorFlower Gardener
parent 4120ac4755
commit b3422c3b59
4 changed files with 43 additions and 4 deletions
tensorflow/python

View File

@ -212,6 +212,7 @@ py_library(
exclude = [
"**/*test.py",
"**/benchmark.py", # In platform_benchmark.
"**/analytics.py", # In platform_analytics.
],
) + ["platform/build_info.py"],
srcs_version = "PY2AND3",
@ -238,6 +239,12 @@ py_library(
],
)
py_library(
name = "platform_analytics",
srcs = ["platform/analytics.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "platform_test",
srcs = ["platform/googletest.py"],

View File

@ -0,0 +1,28 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Analytics helpers library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def track_usage(tool_id, tags):
"""No usage tracking for external library.
Args:
tool_id: A string identifier for tool to be tracked.
tags: list of string tags that will be added to the tracking.
"""
del tool_id, tags # Unused externally.

View File

@ -222,6 +222,7 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform_analytics",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tpu_ops_gen",
"//tensorflow/python:training",

View File

@ -43,6 +43,7 @@ from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2 as summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import analytics
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary_iterator
@ -1302,9 +1303,12 @@ class TensorTracer(object):
return math_ops.cast(tensor, dtypes.float32)
return tensor
TensorTracer.check_device_type(self._tt_config.device_type)
TensorTracer.check_trace_mode(self._tt_config.device_type,
self._parameters.trace_mode)
trace_mode = self._parameters.trace_mode
device_type = self._tt_config.device_type
analytics.track_usage('tensor_tracer', [trace_mode, device_type])
TensorTracer.check_device_type(device_type)
TensorTracer.check_trace_mode(device_type, trace_mode)
# Check in_tensor_fetches, and op_fetches and convert them to lists.
processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
op_fetches = self._process_op_fetches(op_fetches)
@ -1468,7 +1472,6 @@ class TensorTracer(object):
RuntimeError: If num_replicas_per_host > 8.
RuntimeError: If tensor_fetches is None or empty.
"""
if graph in TensorTracer._traced_graphs:
logging.warning('Graph is already rewritten with tensor tracer, ignoring '
'multiple calls.')