STT-tensorflow/tensorflow/python/compiler/mlir/mlir_test.py
Andy Ly ba167d161e Add python wrapper mlir.experimental.convert_function for importing ConcreteFunctions into TF MLIR.
This takes a ConcreteFunction, collects a FunctionDef for the function and an associated FunctionDefLibrary, and imports the FunctionDef and FunctionDefLibrary via `ConvertFunctionToMlir`.
Control rets/target nodes of the entry function are also now supported in `ConvertFunctionToMlir`.

PiperOrigin-RevId: 331195841
Change-Id: Ib3a7264e90ca303ab7a850bf18c8d5e330063a4f
2020-09-11 12:21:46 -07:00

87 lines
2.9 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for python.compiler.mlir."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compiler.mlir import mlir
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import logging_ops
from tensorflow.python.platform import test
class MLIRGraphDefImportTest(test.TestCase):
  """Exercises `mlir.convert_graph_def` on empty and malformed inputs."""

  def testImport(self):
    """Checks the basic `tf.mlir.experimental.convert_graph_def` flow."""
    # Even an empty GraphDef text is expected to produce a module that holds
    # at least an (empty) `main` function.
    module_text = mlir.convert_graph_def('')
    self.assertIn('func @main', module_text)

  def testInvalidPbtxt(self):
    """Checks that unparsable text input raises InvalidArgumentError."""
    expected_message = 'Could not parse input proto'
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
      mlir.convert_graph_def('some invalid proto')
class MLIRConcreteFunctionImportTest(test.TestCase):
  """Tests `mlir.convert_function` on `ConcreteFunction` inputs."""

  def testImport(self):
    """Imports a single ConcreteFunction into a TF MLIR module."""

    @def_function.function
    def identity(i):
      return i

    concrete_function = identity.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32))
    mlir_module = mlir.convert_function(concrete_function)
    # The regex only requires the emitted symbol name to contain the Python
    # function name; surrounding characters are allowed.
    self.assertRegex(mlir_module, r'func @.*identity.*\(')

  def testImportWithCall(self):
    """Checks that a function called by the entry function is imported too."""

    @def_function.function
    def callee(i):
      return i

    @def_function.function
    def caller(i):
      return callee(i)

    concrete_function = caller.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32))
    mlir_module = mlir.convert_function(concrete_function)
    # Both the entry function and the function it calls must be present.
    self.assertRegex(mlir_module, r'func @.*caller.*\(')
    self.assertRegex(mlir_module, r'func @.*callee.*\(')

  def testImportWithControlRet(self):
    """Checks import of a function whose only output is a control return."""

    @def_function.function
    def logging():
      logging_ops.print_v2('some message')

    concrete_function = logging.get_concrete_function()
    # An empty pass pipeline is requested; presumably this keeps the raw
    # `tf_executor` form of the import visible to the assertions below.
    mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
    self.assertRegex(mlir_module, r'tf\.PrintV2')
    # Dots are escaped so they match literally rather than any character.
    self.assertRegex(mlir_module,
                     r'tf_executor\.fetch.*: !tf_executor\.control')
if __name__ == '__main__':
  # Hand control to TensorFlow's test entry point when run as a script.
  test.main()