Add json translation for tfjs mlir converter.

TFJS ops are registered as TF custom ops, and utilize export_graphdef.cc to build out the GraphDef object that could contain both TF and TFJS dialects.

PiperOrigin-RevId: 311158257
Change-Id: I7313a5a01f12ef742a97fd5e9ff2bbffe8498b0c
This commit is contained in:
A. Unique TensorFlower 2020-05-12 10:48:33 -07:00 committed by TensorFlower Gardener
parent b661070db9
commit 2407170feb
15 changed files with 938 additions and 9 deletions

View File

@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [
]
tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
'xla-opt'
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir',
'tensorflow/compiler/mlir/lite',
'tensorflow/compiler/mlir/tensorflow',
'tensorflow/compiler/mlir/tfjs',
'tensorflow/compiler/mlir/xla',
'tensorflow/compiler/aot',
'tensorflow/compiler/xla/service/mlir_gpu',

View File

@ -59,6 +59,18 @@ limitations under the License.
namespace tensorflow {
namespace {
// static TensorFlow op prefix set.
std::set<std::string>* GlobalOpPrefixes() {
static std::set<std::string>* global_op_prefixes = [] {
std::set<std::string>* result = new std::set<std::string>;
result->insert("tf.");
result->insert("_tf.");
result->insert("tf_executor.");
return result;
}();
return global_op_prefixes;
}
// Converts a location to the debug information for the node def.
Status ConvertLocation(mlir::Location inst_loc,
NodeDef::ExperimentalDebugInfo* debug_info) {
@ -268,8 +280,10 @@ StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
// - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
// don't need to consider ".source"/".Source" because the nodes with this
// suffix are skipped by the caller and will not be added to the graph.
if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") &&
!op_name.consume_front("tf_executor.")) {
auto prefixes = GlobalOpPrefixes();
if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
return op_name.consume_front(prefix);
})) {
return errors::FailedPrecondition("op node '", op_name.str(),
"' was not a TF op!");
}
@ -506,4 +520,9 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) {
inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
}
Status AddTensorFlowOpPrefix(std::string prefix) {
GlobalOpPrefixes()->insert(prefix);
return Status::OK();
}
} // namespace tensorflow

View File

@ -34,10 +34,17 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace mlir {
class ShapedType;
} // namespace mlir
namespace tensorflow {
using stream_executor::port::StatusOr;
// Add custom op prefix for TensorFlow dialects.
Status AddTensorFlowOpPrefix(std::string);
// Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control
// dialect back into a TensorFlow valid op name.
StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef);

View File

@ -1,4 +1,5 @@
load("//third_party/mlir:tblgen.bzl", "gentbl")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
package(
default_visibility = ["//visibility:public"],
@ -131,10 +132,106 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)
cc_library(
name = "json_translate_lib",
srcs = [
"translate/json_translate.cc",
],
hdrs = [
"translate/json_translate.h",
],
deps = [
":tensorflow_js",
":tensorflow_js_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:export_utils",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
],
alwayslink = 1,
)
cc_library(
name = "tf_to_tfjs_json",
srcs = ["translate/tf_to_tfjs_json.cc"],
hdrs = [
"translate/tf_to_tfjs_json.h",
],
deps = [
":json_translate_lib",
":tfjs_optimize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
tf_cc_binary(
name = "json_translate",
deps = [
":json_translate_lib",
"@llvm-project//mlir:MlirTranslateMain",
],
)
filegroup(
name = "tf_tfjs_translate_main",
srcs = [
"translate/tf_tfjs_translate.cc",
],
)
tf_cc_binary(
name = "tf_tfjs_translate",
srcs = [":tf_tfjs_translate_main"],
deps = [
":json_translate_lib",
":tensorflow_js_passes",
":tf_to_tfjs_json",
":tfjs_optimize",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace tfjs {

View File

@ -0,0 +1,23 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])
glob_lit_tests(
data = [
":test_utilities",
],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"pbtxt",
],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -0,0 +1,78 @@
# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure
# Add two tensor<4xi32> inputs and return the result
node {
name: "Add"
op: "Add"
input: "input0"
input: "input1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
node {
name: "input1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
node {
name: "Mul"
op: "Mul"
input: "Add"
input: "Add"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
versions {
producer: 27
}
# CHECK: "name": "input0"
# CHECK-NEXT: "op": "Placeholder"
# CHECK: "type": "DT_INT32"
# CHECK: "name": "input1",
# CHECK-NEXT: "op": "Placeholder"
# CHECK: "type": "DT_INT32"
# CHECK: "name": "Add"
# CHECK-NEXT: "op": "AddV2"
# CHECK-NEXT: "input":
# CHECK-NEXT: "input0"
# CHECK-NEXT: "input1"
# CHECK: "type": "DT_INT32"
# CHECK: "name": "Mul1"
# CHECK-NEXT: "op": "Mul"
# CHECK-NEXT: "input":
# CHECK-NEXT: "Add"
# CHECK-NEXT: "Add"
# CHECK: "type": "DT_INT32"
# CHECK: "name": "Mul"
# CHECK-NEXT: "op": "_Retval"
# CHECK-NEXT: "input":
# CHECK-NEXT: "Mul1"
# CHECK: "type": "DT_INT32"
# CHECK: "library"
# CHECK: "versions"
# CHECK: "producer": 27

View File

@ -0,0 +1,175 @@
# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure
# Add two tensor<4xi32> inputs and return the result
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 10
}
}
}
}
experimental_debug_info {
}
}
node {
name: "alpha"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 0.5
}
}
}
experimental_debug_info {
}
}
node {
name: "Relu"
op: "Relu"
input: "input0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "Neg"
op: "Neg"
input: "input0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "Relu1"
op: "Relu"
input: "Neg"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "Mul"
op: "Mul"
input: "alpha"
input: "Relu1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "Add"
op: "Add"
input: "Relu"
input: "Mul"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "main"
op: "_Retval"
input: "Add"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 0
}
}
}
library {
}
versions {
producer: 344
}
# CHECK: "node":
# CHECK: "name": "input0",
# CHECK-NEXT: "op": "Placeholder",
# CHECK-NEXT: "attr":
# CHECK: "type": "DT_FLOAT"
# CHECK: "name": "Add.Relu.Neg.Relu1.Mul",
# CHECK-NEXT: "op": "Const",
# CHECK-NEXT: "attr":
# CHECK: "value":
# CHECK: "tensor":
# CHECK: "dtype": "DT_FLOAT",
# CHECK: "tensorShape": {},
# CHECK: "floatVal":
# CHECK: -0.5
# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1",
# CHECK-NEXT: "op": "Prelu",
# CHECK-NEXT: "input":
# CHECK: "input0",
# CHECK: "Add.Relu.Neg.Relu1.Mul"
# CHECK: "attr":
# CHECK: "_output_shapes":
# CHECK: "list":
# CHECK: "shape":
# CHECK: "dim":
# CHECK: "size": "10"
# CHECK: "experimentalDebugInfo": {}
# CHECK: "name": "Add",
# CHECK-NEXT: "op": "_Retval",
# CHECK-NEXT: "input":
# CHECK: "Add.Relu.Neg.Relu1.Mul1"
# CHECK: "attr":
# CHECK: "T":
# CHECK: "type": "DT_FLOAT"
# CHECK: "library": {},
# CHECK: "versions":
# CHECK: "producer": 344

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -20,7 +20,6 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
@ -47,6 +46,11 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
// Canonicalize, CSE etc.
pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
// raise to executor dialect in order to use GraphDef converter
pm->addNestedPass<mlir::FuncOp>(
mlir::CreateFunctionalToExecutorDialectConversionPass());
pm->addNestedPass<mlir::FuncOp>(mlir::CreateBreakUpIslandsPass());
}
} // namespace tensorflow

View File

@ -0,0 +1,105 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
#include <memory>
#include <string>
#include <utility>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
using mlir::ModuleOp;
using mlir::TranslateFromMLIRRegistration;
using std::string;
using tensorflow::Status;
using xla::StatusOr;
// Translates the given MLIR module in the TFJS dialect to TFJS JSON
// format. Returns false on success.
//
bool tfjs::MlirToJSONTranslateFunction(ModuleOp module,
std::string* serialized_json) {
string json_output;
// Allow TF to treat TFJS ops as TF ops.
if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) {
LOG(ERROR) << "Failed to add tfjs op prefix.";
return false;
}
tensorflow::GraphExportConfig confs;
confs.export_shapes = true;
confs.export_library = true;
tensorflow::FunctionLibraryDefinition flib_def(
tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;
auto graph = absl::make_unique<tensorflow::Graph>(flib_def);
auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def,
&control_ret_nodes);
if (!status.ok()) {
LOG(ERROR) << "Graph export failed: " << status;
return false;
}
auto graphdef = absl::make_unique<tensorflow::GraphDef>();
graph->ToGraphDef(graphdef.get());
// Replace the _Arg nodes of the main function with Placeholder op.
auto nodes = graphdef->mutable_node();
for (const auto& node : llvm::enumerate(*nodes)) {
if (node.value().op() == "_Arg") {
nodes->Mutable(node.index())->set_op("Placeholder");
}
}
tensorflow::protobuf::util::JsonPrintOptions json_options;
json_options.add_whitespace = true;
auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString(
*graphdef, &json_output, json_options);
if (!jsonStatus.ok()) {
LOG(ERROR) << "Proto2Json failed: " << status;
return false;
}
*serialized_json = std::move(json_output);
return true;
}
static mlir::LogicalResult MlirToJSONFileTranslateFunction(
ModuleOp module, llvm::raw_ostream& output) {
std::string serialized_json;
if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json))
return mlir::failure();
output << serialized_json;
return mlir::success();
}
static TranslateFromMLIRRegistration MLIRToJSONFileTranslate(
"mlir-to-tfjs-json", MlirToJSONFileTranslateFunction);

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
#include <string>
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/core/lib/core/status.h"
namespace tfjs {
// Translates the given MLIR `module` into a JSON string. Returns true if
// translation fails, otherwise returns false.
bool MlirToJSONTranslateFunction(mlir::ModuleOp module,
std::string* serialized_json);
} // namespace tfjs
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_

View File

@ -0,0 +1,173 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iostream>
#include <string>
#include "absl/strings/str_split.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using llvm::cl::opt;
using mlir::MLIRContext;
using stream_executor::port::StatusOr;
// NOLINTNEXTLINE
opt<std::string> input_file_name(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
opt<bool> import_saved_model_object_graph(
"savedmodel-objectgraph-to-mlir",
llvm::cl::desc("Import a saved model to its MLIR representation"),
llvm::cl::value_desc("dir"));
// NOLINTNEXTLINE
opt<bool> import_saved_model_signature_defs(
"savedmodel-signaturedefs-to-mlir",
llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
llvm::cl::value_desc("dir"));
// NOLINTNEXTLINE
opt<std::string> saved_model_tags(
"tf-savedmodel-tags",
llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
"separated by ','"),
llvm::cl::init("serve"));
// NOLINTNEXTLINE
opt<std::string> saved_model_exported_names(
"tf-savedmodel-exported-names",
llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
"(the default) means export all."),
llvm::cl::init(""));
// NOLINTNEXTLINE
opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
opt<bool> input_mlir(
"input-mlir",
llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of "
"GraphDef format"),
llvm::cl::init(false), llvm::cl::Hidden);
// NOLINTNEXTLINE
opt<bool> output_mlir(
"output-mlir",
llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"),
llvm::cl::init(false));
// The following approach allows injecting opdefs in addition
// to those that are already part of the global TF registry to be linked in
// prior to importing the graph. The primary goal is for support of custom ops.
// This is not intended to be a general solution for custom ops for the future
// but mainly for supporting older models like mobilenet_ssd. More appropriate
// mechanisms, such as op hints or using functions to represent composable ops
// like https://github.com/tensorflow/community/pull/113 should be encouraged
// going forward.
// NOLINTNEXTLINE
llvm::cl::list<std::string> custom_opdefs(
"tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing "
"graphdef"));
// Debugging flag to print function mapping in the JSON.
// NOLINTNEXTLINE
static opt<bool> print_function_result_mapping(
"print-function-result-mapping",
llvm::cl::desc(
"Print the mapping of function result to json output buffer"),
llvm::cl::init(false));
enum TranslationStatus { kTrSuccess, kTrFailure };
static int PrintFunctionResultMapping(const std::string& result) {
std::cout << result << std::endl;
return kTrSuccess;
}
int main(int argc, char** argv) {
tensorflow::InitMlir y(&argc, &argv);
llvm::cl::ParseCommandLineOptions(argc, argv,
"TF GraphDef to TFJS JSON converter\n");
MLIRContext context;
llvm::SourceMgr source_mgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
StatusOr<mlir::OwningModuleRef> module;
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
if (input_mlir)
module = tensorflow::errors::InvalidArgument(
"Importing saved model should not have input_mlir set");
module = tensorflow::ImportSavedModel(
import_saved_model_object_graph, import_saved_model_signature_defs,
custom_opdefs, input_file_name, saved_model_tags,
saved_model_exported_names, &context);
} else {
module = tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, custom_opdefs, debug_info_file,
input_arrays, input_dtypes, input_shapes, output_arrays,
/*prune_unused_nodes=*/true, &source_mgr, &context);
}
// If errors occur, the library call in the above already logged the error
// message. So we can just return here.
if (!module.ok()) return kTrFailure;
mlir::PassManager pm(&context);
tensorflow::AddTFToTFJSConversionPasses(&pm);
std::string result;
auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(),
output_mlir, &result, &pm);
if (!status.ok()) return kTrFailure;
std::string error_msg;
auto output = mlir::openOutputFile(output_file_name, &error_msg);
if (output == nullptr) {
llvm::errs() << error_msg << '\n';
return kTrFailure;
}
output->os() << result;
output->keep();
// Print out debugging info related to function mapping.
if (print_function_result_mapping) return PrintFunctionResultMapping(result);
return kTrSuccess;
}

View File

@ -0,0 +1,152 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::OwningModuleRef;
using stream_executor::port::StatusOr;
namespace {
tensorflow::Status RegisterCustomOps(
const std::vector<std::string>& extra_tf_opdefs) {
for (const auto& tf_opdefs_string : extra_tf_opdefs) {
tensorflow::OpDef opdef;
if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
&opdef)) {
LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
return errors::InvalidArgument("fail to parse extra OpDef");
}
// Register extra opdefs.
tensorflow::OpRegistry::Global()->Register(
[opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
*op_reg_data = tensorflow::OpRegistrationData(opdef);
return Status::OK();
});
}
return Status::OK();
}
} // namespace
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
const std::string& input_filename, bool input_mlir,
const std::vector<std::string>& extra_tf_opdefs,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
llvm::SourceMgr* source_mgr, MLIRContext* context) {
// Set up the input file.
std::string error_message;
auto file = mlir::openInputFile(input_filename, &error_message);
if (!file) {
llvm::errs() << error_message << "\n";
return errors::InvalidArgument("fail to open input file");
}
if (input_mlir) {
source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context));
}
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
return tensorflow::GraphdefToMlirTranslateFunction(
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
input_shapes, output_arrays, /*control_output_arrays=*/"",
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
/*enable_shape_inference=*/true, context);
}
Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir,
std::string* result,
mlir::PassManager* pass_manager) {
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
/*propagate=*/true);
if (failed(pass_manager->run(module))) {
return statusHandler.ConsumeStatus();
}
if (export_to_mlir) {
llvm::raw_string_ostream os(*result);
module.print(os);
return Status::OK();
}
return tfjs::MlirToJSONTranslateFunction(module, result)
? Status::OK()
: statusHandler.ConsumeStatus();
}
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1,
const std::vector<std::string>& extra_tf_opdefs,
const std::string& input_filename, const std::string& saved_model_tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names_in_vector =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_in_vector);
if (import_saved_model) {
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names), context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
return module;
} else if (import_saved_model_v1) {
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, exported_names, context);
if (!module)
return tensorflow::errors::InvalidArgument("fail to open input file");
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
return module;
} else {
return tensorflow::errors::InvalidArgument(
"Should be either saved model v1 or v2");
}
}
} // namespace tensorflow

View File

@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
// Load a TF model from a GraphDef definition or a TF control flow dialect MLIR
// source into a MLIR module. If `input_mlir` is true, load from a MLIR source
// file; otherwise, load from a GraphDef.
// Setting prune_unused_nodes to true, would prune unreachable nodes if
// output_arrays is specified.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
LoadFromGraphdefOrMlirSource(
const std::string& input_filename, bool input_mlir,
const std::vector<std::string>& extra_tf_opdefs,
absl::string_view debug_info_file, absl::string_view input_arrays,
absl::string_view input_dtypes, absl::string_view input_shapes,
absl::string_view output_arrays, bool prune_unused_nodes,
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
// Load Saved model (either v1 or v2) into MLIR.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1,
const std::vector<std::string>& extra_tf_opdefs,
const std::string& input_filename, const std::string& saved_model_tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context);
// Taking a MLIR module in TF executor dialect and a set of parameters,
// applies a set of passes to convert the module to TFJS dialect and
// serializes the result to JSON string.
// If `export_to_mlir` is true, the result is exported in MLIR text format,
// otherwise exported in JSON.
Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir,
std::string* result,
mlir::PassManager* pass_manager);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_