STT-tensorflow/tensorflow/compiler/mlir/python/mlir.i
Mehdi Amini cf3c40795d Implement a basic API: tf.mlir.experimental.convert_graph_def
Load a .pbtxt, convert it to MLIR, and (optionally) optimize the module before
returning it as a string.
This is an early experimental API, intended, for example, for experimenting with
some Colab examples during development.

PiperOrigin-RevId: 268073235
2019-09-09 15:00:00 -07:00

75 lines
2.2 KiB
OpenEdge ABL

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
namespace tensorflow {
namespace swig {

// Backing implementation of tf.mlir.experimental.convert_graph_def.
//
// Parses `proto` into a GraphDef, imports that graph into a freshly created
// MLIR module and returns the module printed as text. If either step fails,
// the failing Status is copied into `status` and the placeholder string
// "// error" is returned instead of a module.
//
// Note: no optimization passes are run here; the module is printed exactly as
// imported.
// TODO: this is an early experimental API; ideally it would return a wrapper
// object around a Python binding to the MLIR module rather than a string.
string ImportGraphDef(const string &proto, TF_Status* status) {
  // Forwards a failing Status to the caller-provided TF_Status and yields the
  // placeholder result returned on every error path.
  auto fail_with = [status](const auto& failure) -> string {
    Set_TF_Status_from_Status(status, failure);
    return "// error";
  };

  GraphDef graph_def;
  const auto load_status = tensorflow::LoadProtoFromBuffer(proto, &graph_def);
  if (!load_status.ok()) return fail_with(load_status);

  // Debug info and node specs are left default-constructed.
  GraphDebugInfo debug_info;
  NodeSpecs specs;
  mlir::MLIRContext context;
  auto module_or =
      ConvertGraphdefToMlir(graph_def, debug_info, specs, &context);
  if (!module_or.ok()) return fail_with(module_or.status());

  return MlirModuleToString(*module_or.ConsumeValueOrDie());
}

}  // namespace swig
}  // namespace tensorflow
%}
// Hide every declaration from SWIG by default (restored by the %unignoreall at
// the end of this file), then re-expose only the namespaces and the single
// entry point that should be wrapped for Python.
%ignoreall
%unignore tensorflow;
%unignore tensorflow::swig;
%unignore tensorflow::swig::ImportGraphDef;
// Declaration SWIG parses to generate the Python binding; the C++ definition
// lives in the %{ ... %} block above. The Python helper below calls this with
// only the serialized graph, so the TF_Status* argument is presumably filled in
// by a typemap from base.i. TODO confirm.
// NOTE(review): declared `static` here but defined without `static` above;
// confirm SWIG ignores the storage class for this declaration.
namespace tensorflow {
namespace swig {
static string ImportGraphDef(const string &graphdef, TF_Status* status);
}  // namespace swig
}  // namespace tensorflow
%insert("python") %{
def import_graphdef(graphdef):
  """Imports a graph into MLIR and returns the module as text.

  Args:
    graphdef: The graph to import. `bytes` are treated as an already
      serialized graph and passed through unchanged; any other object is
      converted with `str()` (e.g. a GraphDef proto or a .pbtxt string) and
      encoded as UTF-8.

  Returns:
    The imported MLIR module printed as a `str`.
  """
  if isinstance(graphdef, bytes):
    # `str(b'...')` would produce the repr "b'...'", not the payload.
    serialized = graphdef
  else:
    serialized = str(graphdef).encode('utf-8')
  result = ImportGraphDef(serialized)
  # The wrapped C++ `string` may come back as `bytes`; `str(bytes)` would give
  # its repr ("b'...'") instead of the module text, so decode it explicitly.
  if not isinstance(result, str):
    result = result.decode('utf-8')
  return result
%}
%unignoreall