Implement a basic API: tf.mlir.experimental.convert_graph_def
Load a .pbtxt, convert to MLIR, and (optionally) optimize the module before returning it as a string. This is an early experimental API, intended for example to play with some colab examples during development. PiperOrigin-RevId: 268073235
This commit is contained in:
parent
3e696f8865
commit
cf3c40795d
11
tensorflow/compiler/mlir/python/BUILD
Normal file
11
tensorflow/compiler/mlir/python/BUILD
Normal file
@ -0,0 +1,11 @@
|
||||
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

# The SWIG interface is exported so tensorflow/python's tf_py_wrap_cc target
# can include it in the monolithic pywrap_tensorflow wrapper.
exports_files(
    ["mlir.i"],
    visibility = [
        "//tensorflow/python:__subpackages__",
    ],
)
|
74
tensorflow/compiler/mlir/python/mlir.i
Normal file
74
tensorflow/compiler/mlir/python/mlir.i
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{

#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"

namespace tensorflow {
namespace swig {

// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
// Load a .pbtxt, convert to MLIR, and (optionally) optimize the module before
// returning it as a string.
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
string ImportGraphDef(const string &proto, TF_Status* status) {
  GraphDef graphdef;
  // LoadProtoFromBuffer accepts either text-format or binary-serialized
  // GraphDef content (it tries text first, then binary).
  auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    // Return a valid-MLIR comment so callers that ignore `status` still get a
    // printable string rather than garbage.
    return "// error";
  }
  GraphDebugInfo debug_info;
  NodeSpecs specs;
  mlir::MLIRContext context;
  auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
  if (!module.ok()) {
    Set_TF_Status_from_Status(status, module.status());
    return "// error";
  }

  return MlirModuleToString(*module.ConsumeValueOrDie());
}

} // namespace swig
} // namespace tensorflow

%}
|
||||
|
||||
%ignoreall

%unignore tensorflow;
%unignore tensorflow::swig;
%unignore tensorflow::swig::ImportGraphDef;

// Wrap this function
namespace tensorflow {
namespace swig {
static string ImportGraphDef(const string &graphdef, TF_Status* status);
}  // namespace swig
}  // namespace tensorflow

%insert("python") %{
def import_graphdef(graphdef):
  # The TF_Status* argument is consumed by the SWIG typemap and surfaces as a
  # Python exception, so only the serialized module string flows through here.
  # NOTE: dropped the stray C-style trailing semicolon from the original.
  return str(ImportGraphDef(str(graphdef).encode('utf-8')))
%}

%unignoreall
|
@ -1774,4 +1774,13 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
add_default_attributes, context);
|
||||
}
|
||||
|
||||
// Serialize a MLIR module to its textual IR form.
std::string MlirModuleToString(mlir::ModuleOp module) {
  std::string txt_module;
  llvm::raw_string_ostream os(txt_module);
  module.print(os);
  // str() flushes the stream into the backing string before handing it back,
  // so no explicit scope/destructor dance is needed.
  return os.str();
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
@ -48,6 +50,9 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
const SavedModelBundle& saved_model, const GraphDebugInfo& debug_info,
|
||||
mlir::MLIRContext* context, bool add_default_attributes = true);
|
||||
|
||||
// Serialize a MLIR module to a string.
|
||||
std::string MlirModuleToString(mlir::ModuleOp m);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
|
||||
|
@ -37,6 +37,30 @@ inline llvm::StringRef StringViewToRef(absl::string_view view) {
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Parses `input` into `proto`, first as text format and, failing that, as the
// binary wire format. Returns InvalidArgument if neither parse succeeds.
Status LoadProtoFromBuffer(absl::string_view input,
                           tensorflow::protobuf::Message* proto) {
  tensorflow::protobuf::TextFormat::Parser parser;
  // Don't produce errors when attempting to parse text format as it would fail
  // when the input is actually a binary file.
  NoOpErrorCollector collector;
  parser.RecordErrorsTo(&collector);
  // Attempt to parse as text.
  tensorflow::protobuf::io::ArrayInputStream input_stream(input.data(),
                                                          input.size());
  if (parser.Parse(&input_stream, proto)) {
    return Status::OK();
  }
  // Else attempt to parse as binary.
  // Clear first: the failed text parse may have partially populated fields.
  proto->Clear();
  tensorflow::protobuf::io::ArrayInputStream binary_stream(input.data(),
                                                           input.size());
  if (proto->ParseFromZeroCopyStream(&binary_stream)) {
    return Status::OK();
  }
  LOG(ERROR) << "Error parsing Protobuf";
  return errors::InvalidArgument("Could not parse input proto");
}
|
||||
|
||||
Status LoadProtoFromFile(absl::string_view input_filename,
|
||||
tensorflow::protobuf::Message* proto) {
|
||||
auto file_or_err =
|
||||
@ -45,26 +69,10 @@ Status LoadProtoFromFile(absl::string_view input_filename,
|
||||
return errors::InvalidArgument("Could not open input file");
|
||||
|
||||
auto& input_file = *file_or_err;
|
||||
std::string content(input_file->getBufferStart(),
|
||||
input_file->getBufferSize());
|
||||
absl::string_view content(input_file->getBufferStart(),
|
||||
input_file->getBufferSize());
|
||||
|
||||
tensorflow::protobuf::TextFormat::Parser parser;
|
||||
// Don't produce errors when attempting to parse text format as it would fail
|
||||
// when the input is actually a binary file.
|
||||
NoOpErrorCollector collector;
|
||||
parser.RecordErrorsTo(&collector);
|
||||
// Attempt to parse as text.
|
||||
if (parser.ParseFromString(content, proto)) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Else attempt to parse as binary.
|
||||
proto->Clear();
|
||||
std::istringstream istream(content);
|
||||
if (proto->ParseFromIstream(&istream)) {
|
||||
return Status::OK();
|
||||
}
|
||||
LOG(ERROR) << "Error parsing Protobuf: " << input_filename;
|
||||
return errors::InvalidArgument("Could not parse input file");
|
||||
return LoadProtoFromBuffer(content, proto);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -22,6 +22,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Reads text (.pbtxt) or binary (.pb) format of a proto message from the given
|
||||
// buffer. Returns an error status if the buffer cannot be parsed as a proto.
|
||||
Status LoadProtoFromBuffer(absl::string_view input,
|
||||
tensorflow::protobuf::Message* proto);
|
||||
|
||||
// Reads text (.pbtxt) or binary (.pb) format of a proto message from the given
|
||||
// file path. Returns an error status if the file is not found or the proto is malformed.
|
||||
Status LoadProtoFromFile(absl::string_view input_filename,
|
||||
|
@ -5029,6 +5029,7 @@ tf_py_wrap_cc(
|
||||
"util/scoped_annotation.i",
|
||||
"util/traceme.i",
|
||||
"util/transform_graph.i",
|
||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||
"//tensorflow/lite/toco/python:toco.i",
|
||||
],
|
||||
# add win_def_file for pywrap_tensorflow
|
||||
|
@ -160,6 +160,9 @@ from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.compiler.xla import jit
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
|
||||
# MLIR APIs.
|
||||
from tensorflow.python.compiler.mlir import mlir
|
||||
|
||||
# Required due to `rnn` and `rnn_cell` not being imported in `nn` directly
|
||||
# (due to a circular dependency issue: rnn depends on layers).
|
||||
nn.dynamic_rnn = rnn.dynamic_rnn
|
||||
|
@ -17,6 +17,7 @@ py_library(
|
||||
deps = if_not_windows([
|
||||
"//tensorflow/python/compiler/tensorrt:init_py",
|
||||
]) + [
|
||||
"//tensorflow/python/compiler/mlir",
|
||||
"//tensorflow/python/compiler/xla:compiler_py",
|
||||
],
|
||||
)
|
||||
|
27
tensorflow/python/compiler/mlir/BUILD
Normal file
27
tensorflow/python/compiler/mlir/BUILD
Normal file
@ -0,0 +1,27 @@
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

# Python surface for the experimental MLIR APIs (tf.mlir.experimental.*).
py_library(
    name = "mlir",
    srcs = ["mlir.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:util",
    ],
)

py_test(
    name = "mlir_test",
    srcs = ["mlir_test.py"],
    python_version = "PY3",
    deps = [
        ":mlir",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:platform",
    ],
)
|
38
tensorflow/python/compiler/mlir/mlir.py
Normal file
38
tensorflow/python/compiler/mlir/mlir.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""mlir is an experimental library that provides support APIs for MLIR."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as import_graphdef
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export('mlir.experimental.convert_graph_def')
def convert_graph_def(graph_def):
  """Import a GraphDef and convert it to a textual MLIR module.

  Args:
    graph_def: An object of type graph_pb2.GraphDef or a textual proto
      representation of a valid GraphDef.

  Returns:
    A textual representation of the MLIR module corresponding to the graphdef.

  Raises:
    InvalidArgumentError: if the input cannot be parsed as a GraphDef (the
      companion test pins this exception type; surfaced from the C++ binding's
      status argument).
  """
  # `import_graphdef` here is the pywrap_tensorflow module (aliased at import
  # time); the SWIG-wrapped entry point of the same name does the conversion.
  return import_graphdef.import_graphdef(graph_def)
|
41
tensorflow/python/compiler/mlir/mlir_test.py
Normal file
41
tensorflow/python/compiler/mlir/mlir_test.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Tests for python.compiler.mlir."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compiler.mlir import mlir
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MLIRImportTest(test.TestCase):
  """Smoke tests for the experimental GraphDef -> MLIR conversion API."""

  def test_import_graph_def(self):
    """Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
    mlir_module = mlir.convert_graph_def('')
    # An empty graph should contain at least an empty main function.
    self.assertIn('func @main', mlir_module)

  def test_invalid_pbtxt(self):
    """A malformed proto must surface InvalidArgumentError, not crash."""
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 'Could not parse input proto'):
      mlir.convert_graph_def('some invalid proto')
|
||||
|
||||
|
||||
# Standard test entry point.
if __name__ == '__main__':
  test.main()
|
@ -55,3 +55,5 @@ limitations under the License.
|
||||
|
||||
%include "tensorflow/python/util/traceme.i"
|
||||
%include "tensorflow/python/util/scoped_annotation.i"
|
||||
|
||||
%include "tensorflow/compiler/mlir/python/mlir.i"
|
||||
|
@ -39,6 +39,8 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"math/__init__.py",
|
||||
"mixed_precision/__init__.py",
|
||||
"mixed_precision/experimental/__init__.py",
|
||||
"mlir/__init__.py",
|
||||
"mlir/experimental/__init__.py",
|
||||
"nest/__init__.py",
|
||||
"nn/__init__.py",
|
||||
"quantization/__init__.py",
|
||||
|
@ -49,6 +49,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [
|
||||
"manip/__init__.py",
|
||||
"math/__init__.py",
|
||||
"metrics/__init__.py",
|
||||
"mlir/__init__.py",
|
||||
"mlir/experimental/__init__.py",
|
||||
"nest/__init__.py",
|
||||
"nn/__init__.py",
|
||||
"nn/rnn_cell/__init__.py",
|
||||
|
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.mlir.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "convert_graph_def"
|
||||
argspec: "args=[\'graph_def\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
7
tensorflow/tools/api/golden/v1/tensorflow.mlir.pbtxt
Normal file
7
tensorflow/tools/api/golden/v1/tensorflow.mlir.pbtxt
Normal file
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.mlir"
|
||||
tf_module {
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
}
|
@ -508,6 +508,10 @@ tf_module {
|
||||
name: "metrics"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "mlir"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "name_scope"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.mlir.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "convert_graph_def"
|
||||
argspec: "args=[\'graph_def\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
7
tensorflow/tools/api/golden/v2/tensorflow.mlir.pbtxt
Normal file
7
tensorflow/tools/api/golden/v2/tensorflow.mlir.pbtxt
Normal file
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.mlir"
|
||||
tf_module {
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
}
|
@ -260,6 +260,10 @@ tf_module {
|
||||
name: "mixed_precision"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "mlir"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "name_scope"
|
||||
mtype: "<type \'type\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user