Add json translation for tfjs mlir converter.
TFJS ops are registered as TF custom ops, and utilize export_graphdef.cc to build out the GraphDef object that could contain both TF and TFJS dialects. PiperOrigin-RevId: 311158257 Change-Id: I7313a5a01f12ef742a97fd5e9ff2bbffe8498b0c
This commit is contained in:
parent
b661070db9
commit
2407170feb
@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [
|
||||
]
|
||||
tool_names = [
|
||||
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
|
||||
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
|
||||
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
|
||||
'xla-opt'
|
||||
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
|
||||
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [
|
||||
'tensorflow/compiler/mlir',
|
||||
'tensorflow/compiler/mlir/lite',
|
||||
'tensorflow/compiler/mlir/tensorflow',
|
||||
'tensorflow/compiler/mlir/tfjs',
|
||||
'tensorflow/compiler/mlir/xla',
|
||||
'tensorflow/compiler/aot',
|
||||
'tensorflow/compiler/xla/service/mlir_gpu',
|
||||
|
@ -59,6 +59,18 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// static TensorFlow op prefix set.
|
||||
std::set<std::string>* GlobalOpPrefixes() {
|
||||
static std::set<std::string>* global_op_prefixes = [] {
|
||||
std::set<std::string>* result = new std::set<std::string>;
|
||||
result->insert("tf.");
|
||||
result->insert("_tf.");
|
||||
result->insert("tf_executor.");
|
||||
return result;
|
||||
}();
|
||||
return global_op_prefixes;
|
||||
}
|
||||
|
||||
// Converts a location to the debug information for the node def.
|
||||
Status ConvertLocation(mlir::Location inst_loc,
|
||||
NodeDef::ExperimentalDebugInfo* debug_info) {
|
||||
@ -268,8 +280,10 @@ StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
|
||||
// - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
|
||||
// don't need to consider ".source"/".Source" because the nodes with this
|
||||
// suffix are skipped by the caller and will not be added to the graph.
|
||||
if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") &&
|
||||
!op_name.consume_front("tf_executor.")) {
|
||||
auto prefixes = GlobalOpPrefixes();
|
||||
if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
|
||||
return op_name.consume_front(prefix);
|
||||
})) {
|
||||
return errors::FailedPrecondition("op node '", op_name.str(),
|
||||
"' was not a TF op!");
|
||||
}
|
||||
@ -506,4 +520,9 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) {
|
||||
inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
|
||||
}
|
||||
|
||||
Status AddTensorFlowOpPrefix(std::string prefix) {
|
||||
GlobalOpPrefixes()->insert(prefix);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -34,10 +34,17 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace mlir {
|
||||
class ShapedType;
|
||||
} // namespace mlir
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// Add custom op prefix for TensorFlow dialects.
|
||||
Status AddTensorFlowOpPrefix(std::string);
|
||||
|
||||
// Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control
|
||||
// dialect back into a TensorFlow valid op name.
|
||||
StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef);
|
||||
|
@ -1,4 +1,5 @@
|
||||
load("//third_party/mlir:tblgen.bzl", "gentbl")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
@ -131,10 +132,106 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "json_translate_lib",
|
||||
srcs = [
|
||||
"translate/json_translate.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"translate/json_translate.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_js",
|
||||
":tensorflow_js_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:export_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/status",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Translation",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_to_tfjs_json",
|
||||
srcs = ["translate/tf_to_tfjs_json.cc"],
|
||||
hdrs = [
|
||||
"translate/tf_to_tfjs_json.h",
|
||||
],
|
||||
deps = [
|
||||
":json_translate_lib",
|
||||
":tfjs_optimize",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "json_translate",
|
||||
deps = [
|
||||
":json_translate_lib",
|
||||
"@llvm-project//mlir:MlirTranslateMain",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tf_tfjs_translate_main",
|
||||
srcs = [
|
||||
"translate/tf_tfjs_translate.cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf_tfjs_translate",
|
||||
srcs = [":tf_tfjs_translate_main"],
|
||||
deps = [
|
||||
":json_translate_lib",
|
||||
":tensorflow_js_passes",
|
||||
":tf_to_tfjs_json",
|
||||
":tfjs_optimize",
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace tfjs {
|
||||
|
||||
|
23
tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD
Normal file
23
tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD
Normal file
@ -0,0 +1,23 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = [
|
||||
"pbtxt",
|
||||
],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
78
tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt
Normal file
78
tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt
Normal file
@ -0,0 +1,78 @@
|
||||
# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure
|
||||
# Add two tensor<4xi32> inputs and return the result
|
||||
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "input0"
|
||||
input: "input1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Mul"
|
||||
op: "Mul"
|
||||
input: "Add"
|
||||
input: "Add"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: "name": "input0"
|
||||
# CHECK-NEXT: "op": "Placeholder"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "name": "input1",
|
||||
# CHECK-NEXT: "op": "Placeholder"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "name": "Add"
|
||||
# CHECK-NEXT: "op": "AddV2"
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK-NEXT: "input0"
|
||||
# CHECK-NEXT: "input1"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "name": "Mul1"
|
||||
# CHECK-NEXT: "op": "Mul"
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK-NEXT: "Add"
|
||||
# CHECK-NEXT: "Add"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "name": "Mul"
|
||||
# CHECK-NEXT: "op": "_Retval"
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK-NEXT: "Mul1"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "library"
|
||||
# CHECK: "versions"
|
||||
# CHECK: "producer": 27
|
||||
|
175
tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt
Normal file
175
tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt
Normal file
@ -0,0 +1,175 @@
|
||||
# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure
|
||||
# Add two tensor<4xi32> inputs and return the result
|
||||
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "alpha"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Relu"
|
||||
op: "Relu"
|
||||
input: "input0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Neg"
|
||||
op: "Neg"
|
||||
input: "input0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Relu1"
|
||||
op: "Relu"
|
||||
input: "Neg"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Mul"
|
||||
op: "Mul"
|
||||
input: "alpha"
|
||||
input: "Relu1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "Relu"
|
||||
input: "Mul"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main"
|
||||
op: "_Retval"
|
||||
input: "Add"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
}
|
||||
versions {
|
||||
producer: 344
|
||||
}
|
||||
|
||||
# CHECK: "node":
|
||||
# CHECK: "name": "input0",
|
||||
# CHECK-NEXT: "op": "Placeholder",
|
||||
# CHECK-NEXT: "attr":
|
||||
# CHECK: "type": "DT_FLOAT"
|
||||
# CHECK: "name": "Add.Relu.Neg.Relu1.Mul",
|
||||
# CHECK-NEXT: "op": "Const",
|
||||
# CHECK-NEXT: "attr":
|
||||
# CHECK: "value":
|
||||
# CHECK: "tensor":
|
||||
# CHECK: "dtype": "DT_FLOAT",
|
||||
# CHECK: "tensorShape": {},
|
||||
# CHECK: "floatVal":
|
||||
# CHECK: -0.5
|
||||
# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1",
|
||||
# CHECK-NEXT: "op": "Prelu",
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK: "input0",
|
||||
# CHECK: "Add.Relu.Neg.Relu1.Mul"
|
||||
# CHECK: "attr":
|
||||
# CHECK: "_output_shapes":
|
||||
# CHECK: "list":
|
||||
# CHECK: "shape":
|
||||
# CHECK: "dim":
|
||||
# CHECK: "size": "10"
|
||||
# CHECK: "experimentalDebugInfo": {}
|
||||
# CHECK: "name": "Add",
|
||||
# CHECK-NEXT: "op": "_Retval",
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK: "Add.Relu.Neg.Relu1.Mul1"
|
||||
# CHECK: "attr":
|
||||
# CHECK: "T":
|
||||
# CHECK: "type": "DT_FLOAT"
|
||||
# CHECK: "library": {},
|
||||
# CHECK: "versions":
|
||||
# CHECK: "producer": 344
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
|
||||
|
||||
@ -47,6 +46,11 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
|
||||
// Canonicalize, CSE etc.
|
||||
pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pm->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
|
||||
|
||||
// raise to executor dialect in order to use GraphDef converter
|
||||
pm->addNestedPass<mlir::FuncOp>(
|
||||
mlir::CreateFunctionalToExecutorDialectConversionPass());
|
||||
pm->addNestedPass<mlir::FuncOp>(mlir::CreateBreakUpIslandsPass());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
105
tensorflow/compiler/mlir/tfjs/translate/json_translate.cc
Normal file
105
tensorflow/compiler/mlir/tfjs/translate/json_translate.cc
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Translation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
using mlir::ModuleOp;
|
||||
using mlir::TranslateFromMLIRRegistration;
|
||||
using std::string;
|
||||
using tensorflow::Status;
|
||||
using xla::StatusOr;
|
||||
|
||||
// Translates the given MLIR module in the TFJS dialect to TFJS JSON
|
||||
// format. Returns false on success.
|
||||
//
|
||||
bool tfjs::MlirToJSONTranslateFunction(ModuleOp module,
|
||||
std::string* serialized_json) {
|
||||
string json_output;
|
||||
// Allow TF to treat TFJS ops as TF ops.
|
||||
if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) {
|
||||
LOG(ERROR) << "Failed to add tfjs op prefix.";
|
||||
return false;
|
||||
}
|
||||
tensorflow::GraphExportConfig confs;
|
||||
confs.export_shapes = true;
|
||||
confs.export_library = true;
|
||||
tensorflow::FunctionLibraryDefinition flib_def(
|
||||
tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
|
||||
absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;
|
||||
auto graph = absl::make_unique<tensorflow::Graph>(flib_def);
|
||||
auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def,
|
||||
&control_ret_nodes);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Graph export failed: " << status;
|
||||
return false;
|
||||
}
|
||||
auto graphdef = absl::make_unique<tensorflow::GraphDef>();
|
||||
graph->ToGraphDef(graphdef.get());
|
||||
|
||||
// Replace the _Arg nodes of the main function with Placeholder op.
|
||||
auto nodes = graphdef->mutable_node();
|
||||
for (const auto& node : llvm::enumerate(*nodes)) {
|
||||
if (node.value().op() == "_Arg") {
|
||||
nodes->Mutable(node.index())->set_op("Placeholder");
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::protobuf::util::JsonPrintOptions json_options;
|
||||
json_options.add_whitespace = true;
|
||||
auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString(
|
||||
*graphdef, &json_output, json_options);
|
||||
if (!jsonStatus.ok()) {
|
||||
LOG(ERROR) << "Proto2Json failed: " << status;
|
||||
return false;
|
||||
}
|
||||
*serialized_json = std::move(json_output);
|
||||
return true;
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirToJSONFileTranslateFunction(
|
||||
ModuleOp module, llvm::raw_ostream& output) {
|
||||
std::string serialized_json;
|
||||
if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json))
|
||||
return mlir::failure();
|
||||
|
||||
output << serialized_json;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
static TranslateFromMLIRRegistration MLIRToJSONFileTranslate(
|
||||
"mlir-to-tfjs-json", MlirToJSONFileTranslateFunction);
|
31
tensorflow/compiler/mlir/tfjs/translate/json_translate.h
Normal file
31
tensorflow/compiler/mlir/tfjs/translate/json_translate.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tfjs {
|
||||
|
||||
// Translates the given MLIR `module` into a JSON string. Returns true if
|
||||
// translation fails, otherwise returns false.
|
||||
bool MlirToJSONTranslateFunction(mlir::ModuleOp module,
|
||||
std::string* serialized_json);
|
||||
} // namespace tfjs
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
|
173
tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc
Normal file
173
tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc
Normal file
@ -0,0 +1,173 @@
|
||||
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
using llvm::cl::opt;
|
||||
using mlir::MLIRContext;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> input_file_name(llvm::cl::Positional,
|
||||
llvm::cl::desc("<input file>"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> import_saved_model_object_graph(
|
||||
"savedmodel-objectgraph-to-mlir",
|
||||
llvm::cl::desc("Import a saved model to its MLIR representation"),
|
||||
llvm::cl::value_desc("dir"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> import_saved_model_signature_defs(
|
||||
"savedmodel-signaturedefs-to-mlir",
|
||||
llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
|
||||
llvm::cl::value_desc("dir"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> saved_model_tags(
|
||||
"tf-savedmodel-tags",
|
||||
llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
|
||||
"separated by ','"),
|
||||
llvm::cl::init("serve"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> saved_model_exported_names(
|
||||
"tf-savedmodel-exported-names",
|
||||
llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
|
||||
"(the default) means export all."),
|
||||
llvm::cl::init(""));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
|
||||
llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> input_mlir(
|
||||
"input-mlir",
|
||||
llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of "
|
||||
"GraphDef format"),
|
||||
llvm::cl::init(false), llvm::cl::Hidden);
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> output_mlir(
|
||||
"output-mlir",
|
||||
llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// The following approach allows injecting opdefs in addition
|
||||
// to those that are already part of the global TF registry to be linked in
|
||||
// prior to importing the graph. The primary goal is for support of custom ops.
|
||||
// This is not intended to be a general solution for custom ops for the future
|
||||
// but mainly for supporting older models like mobilenet_ssd. More appropriate
|
||||
// mechanisms, such as op hints or using functions to represent composable ops
|
||||
// like https://github.com/tensorflow/community/pull/113 should be encouraged
|
||||
// going forward.
|
||||
// NOLINTNEXTLINE
|
||||
llvm::cl::list<std::string> custom_opdefs(
|
||||
"tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing "
|
||||
"graphdef"));
|
||||
|
||||
// Debugging flag to print function mapping in the JSON.
|
||||
// NOLINTNEXTLINE
|
||||
static opt<bool> print_function_result_mapping(
|
||||
"print-function-result-mapping",
|
||||
llvm::cl::desc(
|
||||
"Print the mapping of function result to json output buffer"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
enum TranslationStatus { kTrSuccess, kTrFailure };
|
||||
|
||||
static int PrintFunctionResultMapping(const std::string& result) {
|
||||
std::cout << result << std::endl;
|
||||
return kTrSuccess;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::InitMlir y(&argc, &argv);
|
||||
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
||||
"TF GraphDef to TFJS JSON converter\n");
|
||||
|
||||
MLIRContext context;
|
||||
llvm::SourceMgr source_mgr;
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> module;
|
||||
|
||||
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
|
||||
if (input_mlir)
|
||||
module = tensorflow::errors::InvalidArgument(
|
||||
"Importing saved model should not have input_mlir set");
|
||||
module = tensorflow::ImportSavedModel(
|
||||
import_saved_model_object_graph, import_saved_model_signature_defs,
|
||||
custom_opdefs, input_file_name, saved_model_tags,
|
||||
saved_model_exported_names, &context);
|
||||
} else {
|
||||
module = tensorflow::LoadFromGraphdefOrMlirSource(
|
||||
input_file_name, input_mlir, custom_opdefs, debug_info_file,
|
||||
input_arrays, input_dtypes, input_shapes, output_arrays,
|
||||
/*prune_unused_nodes=*/true, &source_mgr, &context);
|
||||
}
|
||||
|
||||
// If errors occur, the library call in the above already logged the error
|
||||
// message. So we can just return here.
|
||||
if (!module.ok()) return kTrFailure;
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
tensorflow::AddTFToTFJSConversionPasses(&pm);
|
||||
|
||||
std::string result;
|
||||
auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(),
|
||||
output_mlir, &result, &pm);
|
||||
if (!status.ok()) return kTrFailure;
|
||||
|
||||
std::string error_msg;
|
||||
auto output = mlir::openOutputFile(output_file_name, &error_msg);
|
||||
if (output == nullptr) {
|
||||
llvm::errs() << error_msg << '\n';
|
||||
return kTrFailure;
|
||||
}
|
||||
output->os() << result;
|
||||
output->keep();
|
||||
|
||||
// Print out debugging info related to function mapping.
|
||||
if (print_function_result_mapping) return PrintFunctionResultMapping(result);
|
||||
return kTrSuccess;
|
||||
}
|
152
tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc
Normal file
152
tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc
Normal file
@ -0,0 +1,152 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using mlir::MLIRContext;
|
||||
using mlir::ModuleOp;
|
||||
using mlir::OwningModuleRef;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
namespace {
|
||||
tensorflow::Status RegisterCustomOps(
|
||||
const std::vector<std::string>& extra_tf_opdefs) {
|
||||
for (const auto& tf_opdefs_string : extra_tf_opdefs) {
|
||||
tensorflow::OpDef opdef;
|
||||
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
|
||||
&opdef)) {
|
||||
LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
|
||||
return errors::InvalidArgument("fail to parse extra OpDef");
|
||||
}
|
||||
// Register extra opdefs.
|
||||
tensorflow::OpRegistry::Global()->Register(
|
||||
[opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
|
||||
*op_reg_data = tensorflow::OpRegistrationData(opdef);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
const std::string& input_filename, bool input_mlir,
|
||||
const std::vector<std::string>& extra_tf_opdefs,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, MLIRContext* context) {
|
||||
// Set up the input file.
|
||||
std::string error_message;
|
||||
auto file = mlir::openInputFile(input_filename, &error_message);
|
||||
if (!file) {
|
||||
llvm::errs() << error_message << "\n";
|
||||
return errors::InvalidArgument("fail to open input file");
|
||||
}
|
||||
|
||||
if (input_mlir) {
|
||||
source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
|
||||
return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
|
||||
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||
/*enable_shape_inference=*/true, context);
|
||||
}
|
||||
|
||||
Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir,
|
||||
std::string* result,
|
||||
mlir::PassManager* pass_manager) {
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
|
||||
/*propagate=*/true);
|
||||
if (failed(pass_manager->run(module))) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
|
||||
if (export_to_mlir) {
|
||||
llvm::raw_string_ostream os(*result);
|
||||
module.print(os);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return tfjs::MlirToJSONTranslateFunction(module, result)
|
||||
? Status::OK()
|
||||
: statusHandler.ConsumeStatus();
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
bool import_saved_model, bool import_saved_model_v1,
|
||||
const std::vector<std::string>& extra_tf_opdefs,
|
||||
const std::string& input_filename, const std::string& saved_model_tags,
|
||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
|
||||
std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names_in_vector =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||
if (import_saved_model) {
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names), context);
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
|
||||
return module;
|
||||
} else if (import_saved_model_v1) {
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, exported_names, context);
|
||||
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
|
||||
return module;
|
||||
} else {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Should be either saved model v1 or v2");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
63
tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h
Normal file
63
tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Load a TF model from a GraphDef definition or a TF control flow dialect MLIR
|
||||
// source into a MLIR module. If `input_mlir` is true, load from a MLIR source
|
||||
// file; otherwise, load from a GraphDef.
|
||||
// Setting prune_unused_nodes to true, would prune unreachable nodes if
|
||||
// output_arrays is specified.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef>
|
||||
LoadFromGraphdefOrMlirSource(
|
||||
const std::string& input_filename, bool input_mlir,
|
||||
const std::vector<std::string>& extra_tf_opdefs,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
|
||||
|
||||
// Load Saved model (either v1 or v2) into MLIR.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
bool import_saved_model, bool import_saved_model_v1,
|
||||
const std::vector<std::string>& extra_tf_opdefs,
|
||||
const std::string& input_filename, const std::string& saved_model_tags,
|
||||
const std::string& saved_model_exported_names, mlir::MLIRContext* context);
|
||||
|
||||
// Taking a MLIR module in TF executor dialect and a set of parameters,
|
||||
// applies a set of passes to convert the module to TFJS dialect and
|
||||
// serializes the result to JSON string.
|
||||
// If `export_to_mlir` is true, the result is exported in MLIR text format,
|
||||
// otherwise exported in JSON.
|
||||
Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir,
|
||||
std::string* result,
|
||||
mlir::PassManager* pass_manager);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
|
Loading…
Reference in New Issue
Block a user