Internal change on log.
PiperOrigin-RevId: 360730979 Change-Id: I3d9feccbef3823e3e2ac454aa653bfd1abf6315c
This commit is contained in:
parent
eee783f657
commit
a1825c953e
@ -1,5 +1,6 @@
|
||||
load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
|
||||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_py_library")
|
||||
load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
|
||||
load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable", "if_portable")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
@ -22,8 +23,10 @@ py_library(
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":metrics",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/util:tf_export",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
@ -459,3 +462,57 @@ py_library(
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "metrics_nonportable",
|
||||
srcs = ["metrics_nonportable.py"],
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"//tensorflow/python/eager:monitoring",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "metrics_nonportable_test",
|
||||
srcs = ["metrics_nonportable_test.py"],
|
||||
python_version = "PY3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":metrics_nonportable",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "metrics_portable",
|
||||
srcs = ["metrics_portable.py"],
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "metrics_portable_test",
|
||||
srcs = ["metrics_portable_test.py"],
|
||||
python_version = "PY3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":metrics_portable",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "metrics",
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//tensorflow/lite:__subpackages__"],
|
||||
deps = if_portable(
|
||||
if_false = [":metrics_nonportable"],
|
||||
if_true = [":metrics_portable"],
|
||||
),
|
||||
)
|
||||
|
@ -40,6 +40,13 @@ else:
|
||||
return lambda x: x
|
||||
|
||||
|
||||
try:
|
||||
from tensorflow.lite.python import metrics_portable as metrics
|
||||
except ImportError:
|
||||
from tensorflow.lite.python import metrics_nonportable as metrics
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
class Delegate(object):
|
||||
"""Python wrapper class to manage TfLiteDelegate objects.
|
||||
|
||||
@ -321,6 +328,9 @@ class Interpreter(object):
|
||||
delegate._get_native_delegate_pointer()) # pylint: disable=protected-access
|
||||
self._signature_defs = self.get_signature_list()
|
||||
|
||||
self._metrics = metrics.TFLiteMetrics()
|
||||
self._metrics.increase_counter_interpreter_creation()
|
||||
|
||||
def __del__(self):
|
||||
# Must make sure the interpreter is destroyed before things that
|
||||
# are used by it like the delegates. NOTE this only works on CPython
|
||||
|
@ -22,6 +22,8 @@ import ctypes
|
||||
import io
|
||||
import sys
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
@ -37,6 +39,12 @@ from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_regi
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
try:
|
||||
from tensorflow.lite.python import metrics_portable
|
||||
metrics = metrics_portable
|
||||
except ImportError:
|
||||
from tensorflow.lite.python import metrics_nonportable
|
||||
metrics = metrics_nonportable
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
@ -284,6 +292,14 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
[0, 2, 3])
|
||||
self.assertAllEqual(s_params['dim_metadata'][1]['array_indices'], [0, 1, 1])
|
||||
|
||||
@mock.patch.object(metrics.TFLiteMetrics,
|
||||
'increase_counter_interpreter_creation')
|
||||
def testCreationCounter(self, increase_call):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'))
|
||||
increase_call.assert_called_once()
|
||||
|
||||
|
||||
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||
|
||||
|
47
tensorflow/lite/python/metrics_nonportable.py
Normal file
47
tensorflow/lite/python/metrics_nonportable.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python TFLite metrics helper."""
|
||||
from typing import Optional, Text
|
||||
|
||||
from tensorflow.python.eager import monitoring
|
||||
|
||||
|
||||
class TFLiteMetrics(object):
|
||||
"""TFLite metrics helper for prod (borg) environment.
|
||||
|
||||
Attributes:
|
||||
md5: A string containing the MD5 hash of the model binary.
|
||||
model_path: A string containing the path of the model for debugging
|
||||
purposes.
|
||||
"""
|
||||
|
||||
_counter_interpreter_creation = monitoring.Counter(
|
||||
'/tensorflow/lite/interpreter/created',
|
||||
'Counter for number of interpreter created in Python.', 'python')
|
||||
|
||||
def __init__(self,
|
||||
md5: Optional[Text] = None,
|
||||
model_path: Optional[Text] = None) -> None:
|
||||
del self # Temporarily removing self until parameter logic is implemented.
|
||||
if md5 and not model_path or not md5 and model_path:
|
||||
raise ValueError('Both model metadata(md5, model_path) should be given '
|
||||
'at the same time.')
|
||||
if md5:
|
||||
# TODO(b/180400857): Create stub once the service is implemented.
|
||||
pass
|
||||
|
||||
def increase_counter_interpreter_creation(self):
|
||||
self._counter_interpreter_creation.get_cell('python').increase_by(1)
|
46
tensorflow/lite/python/metrics_nonportable_test.py
Normal file
46
tensorflow/lite/python/metrics_nonportable_test.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorFlow Lite Python metrics helper TFLiteMetrics check."""
|
||||
from tensorflow.lite.python import metrics_nonportable as metrics
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MetricsNonportableTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_TFLiteMetrics_creation_no_arg_success(self):
|
||||
metrics.TFLiteMetrics()
|
||||
|
||||
def test_TFLiteMetrics_creation_arg_success(self):
|
||||
metrics.TFLiteMetrics('md5', '/path/to/model')
|
||||
|
||||
def test_TFLiteMetrics_creation_fails_with_only_md5(self):
|
||||
with self.assertRaises(ValueError):
|
||||
metrics.TFLiteMetrics(md5='md5')
|
||||
|
||||
def test_TFLiteMetrics_creation_fail2_with_only_model_path(self):
|
||||
with self.assertRaises(ValueError):
|
||||
metrics.TFLiteMetrics(model_path='/path/to/model')
|
||||
|
||||
def test_interpreter_creation_counter_increase_success(self):
|
||||
stub = metrics.TFLiteMetrics()
|
||||
stub.increase_counter_interpreter_creation()
|
||||
self.assertEqual(
|
||||
stub._counter_interpreter_creation.get_cell('python').value(), 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
28
tensorflow/lite/python/metrics_portable.py
Normal file
28
tensorflow/lite/python/metrics_portable.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python TFLite metrics helper."""
|
||||
from typing import Optional, Text
|
||||
|
||||
|
||||
class TFLiteMetrics(object):
|
||||
|
||||
def __init__(self,
|
||||
md5: Optional[Text] = None,
|
||||
model_path: Optional[Text] = None) -> None:
|
||||
pass
|
||||
|
||||
def increase_counter_interpreter_creation(self):
|
||||
pass
|
33
tensorflow/lite/python/metrics_portable_test.py
Normal file
33
tensorflow/lite/python/metrics_portable_test.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Lint as: python2, python3
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorFlow Lite Python metrics helpr TFLiteMetrics check."""
|
||||
from tensorflow.lite.python import metrics_portable as metrics
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MetricsPortableTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_TFLiteMetrics_creation_success(self):
|
||||
metrics.TFLiteMetrics()
|
||||
|
||||
def test_interpreter_creation_counter_increase_success(self):
|
||||
stub = metrics.TFLiteMetrics()
|
||||
stub.increase_counter_interpreter_creation()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -297,6 +297,9 @@ def if_registration_v2(if_true, if_false = []):
|
||||
"//conditions:default": if_false,
|
||||
})
|
||||
|
||||
def if_portable(if_true, if_false = []):
|
||||
return if_true
|
||||
|
||||
# Linux systems may required -lrt linker flag for e.g. clock_gettime
|
||||
# see https://github.com/tensorflow/tensorflow/issues/15129
|
||||
def lrt_if_needed():
|
||||
|
Loading…
x
Reference in New Issue
Block a user