From a1825c953ec74f62284d8d9a5071758cd3344c11 Mon Sep 17 00:00:00 2001 From: Hyeonjong Ryu Date: Wed, 3 Mar 2021 12:25:53 -0800 Subject: [PATCH] Internal change on log. PiperOrigin-RevId: 360730979 Change-Id: I3d9feccbef3823e3e2ac454aa653bfd1abf6315c --- tensorflow/lite/python/BUILD | 59 ++++++++++++++++++- tensorflow/lite/python/interpreter.py | 10 ++++ tensorflow/lite/python/interpreter_test.py | 16 +++++ tensorflow/lite/python/metrics_nonportable.py | 47 +++++++++++++++ .../lite/python/metrics_nonportable_test.py | 46 +++++++++++++++ tensorflow/lite/python/metrics_portable.py | 28 +++++++++ .../lite/python/metrics_portable_test.py | 33 +++++++++++ tensorflow/tensorflow.bzl | 3 + 8 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 tensorflow/lite/python/metrics_nonportable.py create mode 100644 tensorflow/lite/python/metrics_nonportable_test.py create mode 100644 tensorflow/lite/python/metrics_portable.py create mode 100644 tensorflow/lite/python/metrics_portable_test.py diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index ae94e5a12e2..28b8541b9b4 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -1,5 +1,6 @@ +load("//tensorflow:tensorflow.bzl", "pytype_strict_library") load("@flatbuffers//:build_defs.bzl", "flatbuffer_py_library") -load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable", "if_portable") package( default_visibility = ["//tensorflow:internal"], @@ -22,8 +23,10 @@ py_library( srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ + ":metrics", "//tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper", "//tensorflow/python:util", + "//tensorflow/python/util:tf_export", "//third_party/py/numpy", ], ) @@ -459,3 +462,57 @@ py_library( "//tensorflow/python:util", ], ) + +pytype_strict_library( + name = "metrics_nonportable", + srcs = ["metrics_nonportable.py"], + srcs_version = "PY3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/python/eager:monitoring", + ], +) + +py_test( + name = "metrics_nonportable_test", + srcs = ["metrics_nonportable_test.py"], + python_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metrics_nonportable", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +pytype_strict_library( + name = "metrics_portable", + srcs = ["metrics_portable.py"], + compatible_with = get_compatible_with_portable(), + srcs_version = "PY3", + visibility = ["//visibility:private"], + deps = [], +) + +py_test( + name = "metrics_portable_test", + srcs = ["metrics_portable_test.py"], + python_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":metrics_portable", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +pytype_strict_library( + name = "metrics", + compatible_with = get_compatible_with_portable(), + srcs_version = "PY3", + visibility = ["//tensorflow/lite:__subpackages__"], + deps = if_portable( + if_false = [":metrics_nonportable"], + if_true = [":metrics_portable"], + ), +) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index f7ef3b34ba6..5c5898b6d4d 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -40,6 +40,13 @@ else: return lambda x: x +try: + from tensorflow.lite.python import metrics_portable as metrics +except ImportError: + from tensorflow.lite.python import metrics_nonportable as metrics +# pylint: enable=g-import-not-at-top + + class Delegate(object): """Python wrapper class to manage TfLiteDelegate objects. @@ -321,6 +328,9 @@ class Interpreter(object): delegate._get_native_delegate_pointer()) # pylint: disable=protected-access self._signature_defs = self.get_signature_list() + self._metrics = metrics.TFLiteMetrics() + self._metrics.increase_counter_interpreter_creation() + def __del__(self): # Must make sure the interpreter is destroyed before things that # are used by it like the delegates. NOTE this only works on CPython diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index 2c43ba2dd4f..5a15e460a9e 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -22,6 +22,8 @@ import ctypes import io import sys +from unittest import mock + import numpy as np import six @@ -37,6 +39,12 @@ from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_regi from tensorflow.python.framework import test_util from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test +try: + from tensorflow.lite.python import metrics_portable + metrics = metrics_portable +except ImportError: + from tensorflow.lite.python import metrics_nonportable + metrics = metrics_nonportable # pylint: enable=g-import-not-at-top @@ -284,6 +292,14 @@ class InterpreterTest(test_util.TensorFlowTestCase): [0, 2, 3]) self.assertAllEqual(s_params['dim_metadata'][1]['array_indices'], [0, 1, 1]) + @mock.patch.object(metrics.TFLiteMetrics, + 'increase_counter_interpreter_creation') + def testCreationCounter(self, increase_call): + interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + increase_call.assert_called_once() + class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): diff --git a/tensorflow/lite/python/metrics_nonportable.py b/tensorflow/lite/python/metrics_nonportable.py new file mode 100644 index 00000000000..145b1c05da8 --- /dev/null +++ b/tensorflow/lite/python/metrics_nonportable.py @@ -0,0 +1,47 @@ +# Lint as: python2, python3 +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python TFLite metrics helper.""" +from typing import Optional, Text + +from tensorflow.python.eager import monitoring + + +class TFLiteMetrics(object): + """TFLite metrics helper for prod (borg) environment. + + Attributes: + md5: A string containing the MD5 hash of the model binary. + model_path: A string containing the path of the model for debugging + purposes. + """ + + _counter_interpreter_creation = monitoring.Counter( + '/tensorflow/lite/interpreter/created', + 'Counter for number of interpreter created in Python.', 'python') + + def __init__(self, + md5: Optional[Text] = None, + model_path: Optional[Text] = None) -> None: + del self # Temporarily removing self until parameter logic is implemented. + if md5 and not model_path or not md5 and model_path: + raise ValueError('Both model metadata(md5, model_path) should be given ' + 'at the same time.') + if md5: + # TODO(b/180400857): Create stub once the service is implemented. + pass + + def increase_counter_interpreter_creation(self): + self._counter_interpreter_creation.get_cell('python').increase_by(1) diff --git a/tensorflow/lite/python/metrics_nonportable_test.py b/tensorflow/lite/python/metrics_nonportable_test.py new file mode 100644 index 00000000000..b58f8e3eef4 --- /dev/null +++ b/tensorflow/lite/python/metrics_nonportable_test.py @@ -0,0 +1,46 @@ +# Lint as: python2, python3 +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite Python metrics helper TFLiteMetrics check.""" +from tensorflow.lite.python import metrics_nonportable as metrics +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class MetricsNonportableTest(test_util.TensorFlowTestCase): + + def test_TFLiteMetrics_creation_no_arg_success(self): + metrics.TFLiteMetrics() + + def test_TFLiteMetrics_creation_arg_success(self): + metrics.TFLiteMetrics('md5', '/path/to/model') + + def test_TFLiteMetrics_creation_fails_with_only_md5(self): + with self.assertRaises(ValueError): + metrics.TFLiteMetrics(md5='md5') + + def test_TFLiteMetrics_creation_fail2_with_only_model_path(self): + with self.assertRaises(ValueError): + metrics.TFLiteMetrics(model_path='/path/to/model') + + def test_interpreter_creation_counter_increase_success(self): + stub = metrics.TFLiteMetrics() + stub.increase_counter_interpreter_creation() + self.assertEqual( + stub._counter_interpreter_creation.get_cell('python').value(), 1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/lite/python/metrics_portable.py b/tensorflow/lite/python/metrics_portable.py new file mode 100644 index 00000000000..bcca306f15c --- /dev/null +++ b/tensorflow/lite/python/metrics_portable.py @@ -0,0 +1,28 @@ +# Lint as: python2, python3 +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python TFLite metrics helper.""" +from typing import Optional, Text + + +class TFLiteMetrics(object): + + def __init__(self, + md5: Optional[Text] = None, + model_path: Optional[Text] = None) -> None: + pass + + def increase_counter_interpreter_creation(self): + pass diff --git a/tensorflow/lite/python/metrics_portable_test.py b/tensorflow/lite/python/metrics_portable_test.py new file mode 100644 index 00000000000..e88b597f110 --- /dev/null +++ b/tensorflow/lite/python/metrics_portable_test.py @@ -0,0 +1,33 @@ +# Lint as: python2, python3 +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite Python metrics helpr TFLiteMetrics check.""" +from tensorflow.lite.python import metrics_portable as metrics +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class MetricsPortableTest(test_util.TensorFlowTestCase): + + def test_TFLiteMetrics_creation_success(self): + metrics.TFLiteMetrics() + + def test_interpreter_creation_counter_increase_success(self): + stub = metrics.TFLiteMetrics() + stub.increase_counter_interpreter_creation() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 3dd9ad42d18..bac172990f1 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -297,6 +297,9 @@ def if_registration_v2(if_true, if_false = []): "//conditions:default": if_false, }) +def if_portable(if_true, if_false = []): + return if_true + # Linux systems may required -lrt linker flag for e.g. clock_gettime # see https://github.com/tensorflow/tensorflow/issues/15129 def lrt_if_needed():