Implement a basic API: tf.mlir.experimental.convert_graph_def

Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
returning it as a string.
This is an early experimental API, intended for example to play with some
some colab examples during development.

PiperOrigin-RevId: 268073235
This commit is contained in:
Mehdi Amini 2019-09-09 14:08:10 -07:00 committed by TensorFlower Gardener
parent 3e696f8865
commit cf3c40795d
21 changed files with 284 additions and 19 deletions

View File

@ -0,0 +1,11 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(
["mlir.i"],
visibility = [
"//tensorflow/python:__subpackages__",
],
)

View File

@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
namespace tensorflow {
namespace swig {
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
// returning it as a string.
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
string ImportGraphDef(const string &proto, TF_Status* status) {
GraphDef graphdef;
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
GraphDebugInfo debug_info;
NodeSpecs specs;
mlir::MLIRContext context;
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
return "// error";
}
return MlirModuleToString(*module.ConsumeValueOrDie());
}
} // namespace swig
} // namespace tensorflow
%}
%ignoreall
%unignore tensorflow;
%unignore tensorflow::swig;
%unignore tensorflow::swig::ImportGraphDef;
// Wrap this function
namespace tensorflow {
namespace swig {
static string ImportGraphDef(const string &graphdef, TF_Status* status);
} // namespace swig
} // namespace tensorflow
%insert("python") %{
def import_graphdef(graphdef):
return str(ImportGraphDef(str(graphdef).encode('utf-8')));
%}
%unignoreall

View File

@ -1774,4 +1774,13 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
add_default_attributes, context);
}
std::string MlirModuleToString(mlir::ModuleOp module) {
std::string txt_module;
{
llvm::raw_string_ostream os{txt_module};
module.print(os);
}
return txt_module;
}
} // namespace tensorflow

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
#include <string>
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/cc/saved_model/loader.h"
@ -48,6 +50,9 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
const SavedModelBundle& saved_model, const GraphDebugInfo& debug_info,
mlir::MLIRContext* context, bool add_default_attributes = true);
// Serialize a MLIR module to a string.
std::string MlirModuleToString(mlir::ModuleOp m);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_

View File

@ -37,6 +37,30 @@ inline llvm::StringRef StringViewToRef(absl::string_view view) {
namespace tensorflow {
Status LoadProtoFromBuffer(absl::string_view input,
tensorflow::protobuf::Message* proto) {
tensorflow::protobuf::TextFormat::Parser parser;
// Don't produce errors when attempting to parse text format as it would fail
// when the input is actually a binary file.
NoOpErrorCollector collector;
parser.RecordErrorsTo(&collector);
// Attempt to parse as text.
tensorflow::protobuf::io::ArrayInputStream input_stream(input.data(),
input.size());
if (parser.Parse(&input_stream, proto)) {
return Status::OK();
}
// Else attempt to parse as binary.
proto->Clear();
tensorflow::protobuf::io::ArrayInputStream binary_stream(input.data(),
input.size());
if (proto->ParseFromZeroCopyStream(&binary_stream)) {
return Status::OK();
}
LOG(ERROR) << "Error parsing Protobuf";
return errors::InvalidArgument("Could not parse input proto");
}
Status LoadProtoFromFile(absl::string_view input_filename,
tensorflow::protobuf::Message* proto) {
auto file_or_err =
@ -45,26 +69,10 @@ Status LoadProtoFromFile(absl::string_view input_filename,
return errors::InvalidArgument("Could not open input file");
auto& input_file = *file_or_err;
std::string content(input_file->getBufferStart(),
input_file->getBufferSize());
absl::string_view content(input_file->getBufferStart(),
input_file->getBufferSize());
tensorflow::protobuf::TextFormat::Parser parser;
// Don't produce errors when attempting to parse text format as it would fail
// when the input is actually a binary file.
NoOpErrorCollector collector;
parser.RecordErrorsTo(&collector);
// Attempt to parse as text.
if (parser.ParseFromString(content, proto)) {
return Status::OK();
}
// Else attempt to parse as binary.
proto->Clear();
std::istringstream istream(content);
if (proto->ParseFromIstream(&istream)) {
return Status::OK();
}
LOG(ERROR) << "Error parsing Protobuf: " << input_filename;
return errors::InvalidArgument("Could not parse input file");
return LoadProtoFromBuffer(content, proto);
}
} // namespace tensorflow

View File

@ -22,6 +22,11 @@ limitations under the License.
namespace tensorflow {
// Reads text (.pbtext) or binary (.pb) format of a proto message from the given
// buffer. Returns error status of the file is not found or malformed proto.
Status LoadProtoFromBuffer(absl::string_view input,
tensorflow::protobuf::Message* proto);
// Reads text (.pbtext) or binary (.pb) format of a proto message from the given
// file path. Returns error status of the file is not found or malformed proto.
Status LoadProtoFromFile(absl::string_view input_filename,

View File

@ -5029,6 +5029,7 @@ tf_py_wrap_cc(
"util/scoped_annotation.i",
"util/traceme.i",
"util/transform_graph.i",
"//tensorflow/compiler/mlir/python:mlir.i",
"//tensorflow/lite/toco/python:toco.i",
],
# add win_def_file for pywrap_tensorflow

View File

@ -160,6 +160,9 @@ from tensorflow.python.ops import rnn_cell
from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla
# MLIR APIs.
from tensorflow.python.compiler.mlir import mlir
# Required due to `rnn` and `rnn_cell` not being imported in `nn` directly
# (due to a circular dependency issue: rnn depends on layers).
nn.dynamic_rnn = rnn.dynamic_rnn

View File

@ -17,6 +17,7 @@ py_library(
deps = if_not_windows([
"//tensorflow/python/compiler/tensorrt:init_py",
]) + [
"//tensorflow/python/compiler/mlir",
"//tensorflow/python/compiler/xla:compiler_py",
],
)

View File

@ -0,0 +1,27 @@
load("//tensorflow:tensorflow.bzl", "py_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "mlir",
srcs = ["mlir.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
],
)
py_test(
name = "mlir_test",
srcs = ["mlir_test.py"],
python_version = "PY3",
deps = [
":mlir",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
],
)

View File

@ -0,0 +1,38 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""mlir is an experimental library that provides support APIs for MLIR."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as import_graphdef
from tensorflow.python.util.tf_export import tf_export
@tf_export('mlir.experimental.convert_graph_def')
def convert_graph_def(graph_def):
"""Import a GraphDef and convert it to a textual MLIR module.
Args:
graph_def: An object of type graph_pb2.GraphDef or a textual proto
representation of a valid GraphDef.
Returns:
A textual representation of the MLIR module corresponding to the graphdef.
Raises a RuntimeError on error.
"""
return import_graphdef.import_graphdef(graph_def)

View File

@ -0,0 +1,41 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for python.compiler.mlir."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compiler.mlir import mlir
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
class MLIRImportTest(test.TestCase):
def test_import_graph_def(self):
"""Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
mlir_module = mlir.convert_graph_def('')
# An empty graph should contain at least an empty main function.
self.assertIn('func @main', mlir_module)
def test_invalid_pbtxt(self):
with self.assertRaisesRegexp(errors.InvalidArgumentError,
'Could not parse input proto'):
mlir.convert_graph_def('some invalid proto')
if __name__ == '__main__':
test.main()

View File

@ -55,3 +55,5 @@ limitations under the License.
%include "tensorflow/python/util/traceme.i"
%include "tensorflow/python/util/scoped_annotation.i"
%include "tensorflow/compiler/mlir/python/mlir.i"

View File

@ -39,6 +39,8 @@ TENSORFLOW_API_INIT_FILES = [
"math/__init__.py",
"mixed_precision/__init__.py",
"mixed_precision/experimental/__init__.py",
"mlir/__init__.py",
"mlir/experimental/__init__.py",
"nest/__init__.py",
"nn/__init__.py",
"quantization/__init__.py",

View File

@ -49,6 +49,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"manip/__init__.py",
"math/__init__.py",
"metrics/__init__.py",
"mlir/__init__.py",
"mlir/experimental/__init__.py",
"nest/__init__.py",
"nn/__init__.py",
"nn/rnn_cell/__init__.py",

View File

@ -0,0 +1,7 @@
path: "tensorflow.mlir.experimental"
tf_module {
member_method {
name: "convert_graph_def"
argspec: "args=[\'graph_def\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.mlir"
tf_module {
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}

View File

@ -508,6 +508,10 @@ tf_module {
name: "metrics"
mtype: "<type \'module\'>"
}
member {
name: "mlir"
mtype: "<type \'module\'>"
}
member {
name: "name_scope"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,7 @@
path: "tensorflow.mlir.experimental"
tf_module {
member_method {
name: "convert_graph_def"
argspec: "args=[\'graph_def\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.mlir"
tf_module {
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}

View File

@ -260,6 +260,10 @@ tf_module {
name: "mixed_precision"
mtype: "<type \'module\'>"
}
member {
name: "mlir"
mtype: "<type \'module\'>"
}
member {
name: "name_scope"
mtype: "<type \'type\'>"