Automated rollback of commit 18b2f26832dabd864b5a79403f1d92b45e081e9d

PiperOrigin-RevId: 258844085
Andrew Selle 2019-07-18 14:21:00 -07:00 committed by TensorFlower Gardener
parent a377701899
commit 4117e2129f
4 changed files with 491 additions and 8 deletions

tensorflow/lite/python/BUILD

@@ -140,6 +140,22 @@ py_test(
     ],
 )
 
+py_test(
+    name = "lite_mlir_test",
+    srcs = ["lite_mlir_test.py"],
+    python_version = "PY2",
+    srcs_version = "PY2AND3",
+    tags = [
+        "no_oss",
+        "no_windows",
+    ],
+    deps = [
+        ":lite",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
 py_library(
     name = "util",
     srcs = ["util.py"],
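Assuming a standard Bazel checkout of TensorFlow, the new target can be invoked directly (the `no_oss` tag above only excludes it from the default OSS test suites, not from explicit runs):

  bazel test //tensorflow/lite/python:lite_mlir_test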

tensorflow/lite/python/lite_mlir_test.py (new file)

@@ -0,0 +1,461 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to MLIR-TFLite converter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import tracking
def mlir_convert_and_check_for_unsupported(test_object, converter):
"""Run the converter but don't fail MLIR was not built.
Args:
test_object: PyTest object.
converter: A TFLiteConverter
Returns:
The converted TF lite model or None if mlir support is not builtinto the
binary.
"""
try:
model = converter.convert()
test_object.assertTrue(model)
return model
except lite.ConverterError as e:
if not e.message.startswith('This flag is not supported by this version'):
raise e
else:
return None
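  # NOTE: `e.message` relies on the Python 2 `Exception.message` attribute; a
  # Python 3-portable form of the same check (a sketch, not converter API):
  #   message = getattr(e, 'message', '') or str(e)
  #   if not message.startswith('This flag is not supported by this version'):
  #     raise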
@test_util.run_v1_only('Incompatible with 2.0.')
class FromSessionTest(test_util.TensorFlowTestCase):
def testFloat(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.experimental_enable_mlir_converter = True
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
def testString(self):
in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
    converter.experimental_enable_mlir_converter = True
    tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.string_, input_details[0]['dtype'])
self.assertTrue(([4] == input_details[0]['shape']).all())
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('Reshape', output_details[0]['name'])
self.assertEqual(np.string_, output_details[0]['dtype'])
self.assertTrue(([2, 2] == output_details[0]['shape']).all())
def testQuantization(self):
in_tensor_1 = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
in_tensor_2 = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
out_tensor = array_ops.fake_quant_with_min_max_args(
in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess,
[in_tensor_1, in_tensor_2],
[out_tensor])
converter.experimental_enable_mlir_converter = True
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {
'inputA': (0., 1.),
'inputB': (0., 1.)
} # mean, std_dev
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(2, len(input_details))
self.assertEqual('inputA', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((1., 0.),
input_details[0]['quantization']) # scale, zero_point
self.assertEqual('inputB', input_details[1]['name'])
self.assertEqual(np.uint8, input_details[1]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
self.assertEqual((1., 0.),
input_details[1]['quantization']) # scale, zero_point
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
self.assertEqual(np.uint8, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertGreater(output_details[0]['quantization'][0], 0) # scale
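    # The input values asserted above follow from the assumed stats mapping
    # real_value = (quantized_value - mean) / std_dev, which gives
    # scale = 1.0 / std_dev and zero_point = mean; stats of (0., 1.) therefore
    # yield a quantization of (1., 0.).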
def testScalarValid(self):
# Construct a graph using a scalar (empty shape) input.
in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Test conversion with the scalar input shape.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
    converter.experimental_enable_mlir_converter = True
    tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertEqual(len(input_details[0]['shape']), 0)
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertEqual(len(output_details[0]['shape']), 0)
# Validate inference using the scalar inputs/outputs.
test_input = np.array(4.0, dtype=np.float32)
expected_output = np.array(8.0, dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
def testPostTrainingQuantize(self):
np.random.seed(0)
# We need the tensor to have more than 1024 elements for quantize_weights
# to kick in. Thus, the [33, 33] shape.
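    # (33 * 33 = 1089 elements, just above that threshold.)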
in_tensor_1 = array_ops.placeholder(
shape=[33, 33], dtype=dtypes.float32, name='inputA')
in_tensor_2 = constant_op.constant(
np.random.uniform(low=-10., high=10., size=(33, 33)),
shape=[33, 33],
dtype=dtypes.float32,
name='inputB')
out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
sess = session.Session()
# Convert float model.
float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
[out_tensor])
    float_converter.experimental_enable_mlir_converter = True
    float_tflite = mlir_convert_and_check_for_unsupported(self, float_converter)
if float_tflite is None:
return
# Convert quantized weights model.
quantized_converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1], [out_tensor])
    quantized_converter.experimental_enable_mlir_converter = True
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
quantized_tflite = mlir_convert_and_check_for_unsupported(
self, quantized_converter)
if quantized_tflite is None:
return
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
@test_util.run_in_graph_and_eager_modes
def testFunctions(self):
"""Tests tf.function in 1.X."""
@def_function.function
def plus_placeholder(x, placeholder):
return x + placeholder
with ops.Graph().as_default():
placeholder = array_ops.placeholder(
dtype=dtypes.float32, shape=[1], name='input')
variable_node = variables.Variable(1.0, name='variable_node')
defun_node = plus_placeholder(variable_node, placeholder)
output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
# Initialize variables in the model.
sess = session.Session()
sess.run(variables.variables_initializer([variable_node]))
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [placeholder],
[output_node])
      converter.experimental_enable_mlir_converter = True
      tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('input', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1] == input_details[0]['shape']).all())
self.assertEqual((0., 0.), input_details[0]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output_node', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
class FromConcreteFunctionTest(test_util.TensorFlowTestCase):
def _evaluateTFLiteModel(self, tflite_model, input_data):
"""Evaluates the model on the `input_data`."""
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for input_tensor, tensor_data in zip(input_details, input_data):
interpreter.set_tensor(input_tensor['index'], tensor_data.numpy())
interpreter.invoke()
return [
interpreter.get_tensor(details['index']) for details in output_details
]
def _getSimpleVariableModel(self):
root = tracking.AutoTrackable()
root.v1 = variables.Variable(3.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
return root
@test_util.run_v2_only
def testFloat(self):
root = self._getSimpleVariableModel()
input_data = constant_op.constant(1., shape=[1])
concrete_func = root.f.get_concrete_function(input_data)
# Convert model.
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
converter.experimental_enable_mlir_converter = True
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Check values from converted model.
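    # With v1 = 3. and v2 = 2., root.f(1.) evaluates to 3. * 2. * 1. = 6.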
expected_value = root.f(input_data)
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
self.assertEqual(expected_value.numpy(), actual_value)
@test_util.run_v2_only
def testControlFlow(self):
input_data = {
'x': constant_op.constant([1., 2.], shape=[1, 2]),
'b': constant_op.constant(True)
}
weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=dtypes.float32)
def true_fn(x):
return math_ops.matmul(x, weights)
def false_fn(x):
return math_ops.add(x, weights)
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=[1, 2], dtype=dtypes.float32),
tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)
])
def model(x, b):
return control_flow_ops.cond(
b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))
concrete_func = model.get_concrete_function()
# Convert model.
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
converter.experimental_enable_mlir_converter = True
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Check values from converted model.
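    # For these inputs b is True, so the expected result is the matmul branch:
    # [[1., 2.]] x [[0.1, 0.2], [0.3, 0.4]] = [[0.7, 1.0]].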
expected_value = concrete_func(**input_data)
actual_value = self._evaluateTFLiteModel(
tflite_model, [input_data['x'], input_data['b']])[0]
np.testing.assert_almost_equal(expected_value.numpy(), actual_value)
@test_util.run_v2_only
def testStaticRnn(self):
input_data = constant_op.constant(
np.array(np.random.random_sample((3, 10)), dtype=np.float32))
cell = rnn_cell_impl.LSTMCell(10)
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=[3, 10], dtype=dtypes.float32)
])
def model(x):
seq = array_ops.split(x, 3, 0)
return rnn.static_rnn(
cell, seq, dtype=dtypes.float32, sequence_length=[1])
concrete_func = model.get_concrete_function()
# Convert model.
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
converter.experimental_enable_mlir_converter = True
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Check values from converted model.
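    # static_rnn returns (outputs, state); indexing with [0] selects outputs.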
expected_value = concrete_func(input_data)[0]
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
for expected, actual in zip(expected_value, actual_value):
np.testing.assert_almost_equal(expected.numpy(), actual)
class TestFlexMode(test_util.TensorFlowTestCase):
@test_util.run_v1_only('Incompatible with 2.0.')
def testSession(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.experimental_enable_mlir_converter = True
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Ensures the model contains TensorFlow ops.
# TODO(nupurgarg): Check values once there is a Python delegate interface.
interpreter = Interpreter(model_content=tflite_model)
with self.assertRaises(RuntimeError) as error:
interpreter.allocate_tensors()
self.assertIn(
'Regular TensorFlow ops are not supported by this interpreter. Make '
'sure you invoke the Flex delegate before inference.',
str(error.exception))
@test_util.run_v2_only
def testConcreteFunc(self):
input_data = constant_op.constant(1., shape=[1])
root = tracking.AutoTrackable()
root.v1 = variables.Variable(3.)
root.v2 = variables.Variable(2.)
root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
concrete_func = root.f.get_concrete_function(input_data)
# Convert model.
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
converter.experimental_enable_mlir_converter = True
converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
tflite_model = mlir_convert_and_check_for_unsupported(self, converter)
if tflite_model is None:
return
# Ensures the model contains TensorFlow ops.
# TODO(nupurgarg): Check values once there is a Python delegate interface.
interpreter = Interpreter(model_content=tflite_model)
with self.assertRaises(RuntimeError) as error:
interpreter.allocate_tensors()
self.assertIn(
'Regular TensorFlow ops are not supported by this interpreter. Make '
'sure you invoke the Flex delegate before inference.',
str(error.exception))
if __name__ == '__main__':
test.main()

tensorflow/lite/toco/python/BUILD

@@ -1,5 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
-load("//tensorflow:tensorflow.bzl", "py_binary", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "if_mlir", "py_binary", "tf_py_test")
 
 package(
     default_visibility = [
@@ -22,19 +22,23 @@ cc_library(
     name = "toco_python_api",
     srcs = ["toco_python_api.cc"],
     hdrs = ["toco_python_api.h"],
+    defines = if_mlir(
+        if_false = [],
+        if_true = ["TFLITE_BUILD_WITH_MLIR_CONVERTER"],
+    ),
     visibility = [
         "//tensorflow/python:__subpackages__",
     ],
     deps = [
         "//third_party/python_runtime:headers",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/lite/python/interpreter_wrapper:python_utils",
         "//tensorflow/lite/toco:model_flags_proto_cc",
         "//tensorflow/lite/toco:toco_flags_proto_cc",
         "//tensorflow/lite/toco:toco_graphviz_dump_options",
         "//tensorflow/lite/toco:toco_port",
         "//tensorflow/lite/toco:toco_tooling",
-        "//tensorflow/core:protos_all_cc",
     ] + select({
         # This is required when running `tflite_convert` from `bazel`.
         # It requires linking with TensorFlow ops to get the op definitions.
@@ -42,7 +46,10 @@ cc_library(
             "//tensorflow/core:ops",
         ],
         "//conditions:default": [],
-    }),
+    }) + if_mlir(
+        if_false = [],
+        if_true = ["//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer"],
+    ),
 )
 
 # Compatibility stub. Remove when internal customers moved.
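For context, `if_mlir` is a configuration macro loaded from tensorflow/tensorflow.bzl above. A minimal sketch of its likely shape (the config_setting label here is an assumption, not copied from source):

def if_mlir(if_true, if_false = []):
    return select({
        "//tensorflow:with_mlir_support": if_true,
        "//conditions:default": if_false,
    })

This makes the TFLITE_BUILD_WITH_MLIR_CONVERTER define and the graphdef_to_tfl_flatbuffer dependency conditional on building TensorFlow with MLIR support.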

tensorflow/lite/toco/python/toco_python_api.cc

@@ -28,11 +28,10 @@ limitations under the License.
 #include "tensorflow/lite/toco/toco_tooling.h"
 #include "tensorflow/lite/toco/toco_types.h"
 
-#if defined(PLATFORM_GOOGLE)
+#if defined(TFLITE_BUILD_WITH_MLIR_CONVERTER)
 #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
-#else
-#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #endif
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 
 namespace toco {
@@ -125,13 +124,13 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
   // Convert model.
   if (enable_mlir_converter) {
-#if defined(PLATFORM_GOOGLE)
+#if defined(TFLITE_BUILD_WITH_MLIR_CONVERTER)
     status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
         model_flags, toco_flags, debug_info, graph_def,
         &output_file_contents_txt);
 #else
     // TODO(b/124314620): Remove this condition.
-    PyErr_SetString(PyExc_Exception,
+    PyErr_SetString(PyExc_RuntimeError,
                     "This flag is not supported by this version of the "
                     "TFLite converter. This functionality is being "
                     "actively worked on.");