From 2407170febcdc37fbe90d9f5d8968f2b94ec17dc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 May 2020 10:48:33 -0700 Subject: [PATCH] Add json translation for tfjs mlir converter. TFJS ops are registered as TF custom ops, and utilize export_graphdef.cc to build out the GraphDef object that could contain both TF and TFJS dialects. PiperOrigin-RevId: 311158257 Change-Id: I7313a5a01f12ef742a97fd5e9ff2bbffe8498b0c --- tensorflow/compiler/mlir/runlit.cfg.py | 6 +- tensorflow/compiler/mlir/runlit.site.cfg.py | 1 + .../mlir/tensorflow/utils/export_utils.cc | 23 ++- .../mlir/tensorflow/utils/export_utils.h | 7 + tensorflow/compiler/mlir/tfjs/BUILD | 101 +++++++++- tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h | 1 + tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD | 23 +++ .../compiler/mlir/tfjs/tests/e2e/add.pbtxt | 78 ++++++++ .../compiler/mlir/tfjs/tests/e2e/prelu.pbtxt | 175 ++++++++++++++++++ .../compiler/mlir/tfjs/tf_tfjs_passes.cc | 8 +- .../mlir/tfjs/translate/json_translate.cc | 105 +++++++++++ .../mlir/tfjs/translate/json_translate.h | 31 ++++ .../mlir/tfjs/translate/tf_tfjs_translate.cc | 173 +++++++++++++++++ .../mlir/tfjs/translate/tf_to_tfjs_json.cc | 152 +++++++++++++++ .../mlir/tfjs/translate/tf_to_tfjs_json.h | 63 +++++++ 15 files changed, 938 insertions(+), 9 deletions(-) create mode 100644 tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD create mode 100644 tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt create mode 100644 tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt create mode 100644 tensorflow/compiler/mlir/tfjs/translate/json_translate.cc create mode 100644 tensorflow/compiler/mlir/tfjs/translate/json_translate.h create mode 100644 tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc create mode 100644 tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc create mode 100644 tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 6d3131a781c..f1271d0da24 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [ ] tool_names = [ 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', - 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', - 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', - 'xla-opt' + 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', + 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', + 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 661e6200df3..3e7596c75d7 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [ 'tensorflow/compiler/mlir', 'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/tensorflow', + 'tensorflow/compiler/mlir/tfjs', 'tensorflow/compiler/mlir/xla', 'tensorflow/compiler/aot', 'tensorflow/compiler/xla/service/mlir_gpu', diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index cc795259893..4877cbc4a44 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -59,6 +59,18 @@ limitations under the License. namespace tensorflow { namespace { +// static TensorFlow op prefix set. +std::set* GlobalOpPrefixes() { + static std::set* global_op_prefixes = [] { + std::set* result = new std::set; + result->insert("tf."); + result->insert("_tf."); + result->insert("tf_executor."); + return result; + }(); + return global_op_prefixes; +} + // Converts a location to the debug information for the node def. Status ConvertLocation(mlir::Location inst_loc, NodeDef::ExperimentalDebugInfo* debug_info) { @@ -268,8 +280,10 @@ StatusOr GetTensorFlowOpName(llvm::StringRef op_name) { // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We // don't need to consider ".source"/".Source" because the nodes with this // suffix are skipped by the caller and will not be added to the graph. - if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") && - !op_name.consume_front("tf_executor.")) { + auto prefixes = GlobalOpPrefixes(); + if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) { + return op_name.consume_front(prefix); + })) { return errors::FailedPrecondition("op node '", op_name.str(), "' was not a TF op!"); } @@ -506,4 +520,9 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) { inst->getName().getStringRef().compare("_tf.LegacyCall") == 0; } +Status AddTensorFlowOpPrefix(std::string prefix) { + GlobalOpPrefixes()->insert(prefix); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 32ed528bd0d..58fe39fa4e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -34,10 +34,17 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/stream_executor/lib/statusor.h" +namespace mlir { +class ShapedType; +} // namespace mlir + namespace tensorflow { using stream_executor::port::StatusOr; +// Add custom op prefix for TensorFlow dialects. +Status AddTensorFlowOpPrefix(std::string); + // Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control // dialect back into a TensorFlow valid op name. StatusOr GetTensorFlowOpName(llvm::StringRef); diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD index 9b731d2c912..806a77e9c38 100644 --- a/tensorflow/compiler/mlir/tfjs/BUILD +++ b/tensorflow/compiler/mlir/tfjs/BUILD @@ -1,4 +1,5 @@ load("//third_party/mlir:tblgen.bzl", "gentbl") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( default_visibility = ["//visibility:public"], @@ -131,10 +132,106 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", ], ) + +cc_library( + name = "json_translate_lib", + srcs = [ + "translate/json_translate.cc", + ], + hdrs = [ + "translate/json_translate.h", + ], + deps = [ + ":tensorflow_js", + ":tensorflow_js_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:export_utils", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], + alwayslink = 1, +) + +cc_library( + name = "tf_to_tfjs_json", + srcs = ["translate/tf_to_tfjs_json.cc"], + hdrs = [ + "translate/tf_to_tfjs_json.h", + ], + deps = [ + ":json_translate_lib", + ":tfjs_optimize", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + +tf_cc_binary( + name = "json_translate", + deps = [ + ":json_translate_lib", + "@llvm-project//mlir:MlirTranslateMain", + ], +) + +filegroup( + name = "tf_tfjs_translate_main", + srcs = [ + "translate/tf_tfjs_translate.cc", + ], +) + +tf_cc_binary( + name = "tf_tfjs_translate", + srcs = [":tf_tfjs_translate_main"], + deps = [ + ":json_translate_lib", + ":tensorflow_js_passes", + ":tf_to_tfjs_json", + ":tfjs_optimize", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h index 318895de79c..545183a052b 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Interfaces/SideEffects.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project + namespace mlir { namespace tfjs { diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD new file mode 100644 index 00000000000..5c8d37da2f0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD @@ -0,0 +1,23 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +licenses(["notice"]) + +glob_lit_tests( + data = [ + ":test_utilities", + ], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "pbtxt", + ], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt new file mode 100644 index 00000000000..f6a324fdc13 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt @@ -0,0 +1,78 @@ +# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure +# Add two tensor<4xi32> inputs and return the result + +node { + name: "Add" + op: "Add" + input: "input0" + input: "input1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "Mul" + op: "Mul" + input: "Add" + input: "Add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +versions { + producer: 27 +} + +# CHECK: "name": "input0" +# CHECK-NEXT: "op": "Placeholder" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "input1", +# CHECK-NEXT: "op": "Placeholder" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Add" +# CHECK-NEXT: "op": "AddV2" +# CHECK-NEXT: "input": +# CHECK-NEXT: "input0" +# CHECK-NEXT: "input1" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Mul1" +# CHECK-NEXT: "op": "Mul" +# CHECK-NEXT: "input": +# CHECK-NEXT: "Add" +# CHECK-NEXT: "Add" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Mul" +# CHECK-NEXT: "op": "_Retval" +# CHECK-NEXT: "input": +# CHECK-NEXT: "Mul1" +# CHECK: "type": "DT_INT32" +# CHECK: "library" +# CHECK: "versions" +# CHECK: "producer": 27 + diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt new file mode 100644 index 00000000000..810db71f5e0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt @@ -0,0 +1,175 @@ +# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure +# Add two tensor<4xi32> inputs and return the result + +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + experimental_debug_info { + } +} +node { + name: "alpha" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + experimental_debug_info { + } +} +node { + name: "Relu" + op: "Relu" + input: "input0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Neg" + op: "Neg" + input: "input0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Relu1" + op: "Relu" + input: "Neg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Mul" + op: "Mul" + input: "alpha" + input: "Relu1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Add" + op: "Add" + input: "Relu" + input: "Mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "main" + op: "_Retval" + input: "Add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { +} +versions { + producer: 344 +} + +# CHECK: "node": +# CHECK: "name": "input0", +# CHECK-NEXT: "op": "Placeholder", +# CHECK-NEXT: "attr": +# CHECK: "type": "DT_FLOAT" +# CHECK: "name": "Add.Relu.Neg.Relu1.Mul", +# CHECK-NEXT: "op": "Const", +# CHECK-NEXT: "attr": +# CHECK: "value": +# CHECK: "tensor": +# CHECK: "dtype": "DT_FLOAT", +# CHECK: "tensorShape": {}, +# CHECK: "floatVal": +# CHECK: -0.5 +# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1", +# CHECK-NEXT: "op": "Prelu", +# CHECK-NEXT: "input": +# CHECK: "input0", +# CHECK: "Add.Relu.Neg.Relu1.Mul" +# CHECK: "attr": +# CHECK: "_output_shapes": +# CHECK: "list": +# CHECK: "shape": +# CHECK: "dim": +# CHECK: "size": "10" +# CHECK: "experimentalDebugInfo": {} +# CHECK: "name": "Add", +# CHECK-NEXT: "op": "_Retval", +# CHECK-NEXT: "input": +# CHECK: "Add.Relu.Neg.Relu1.Mul1" +# CHECK: "attr": +# CHECK: "T": +# CHECK: "type": "DT_FLOAT" +# CHECK: "library": {}, +# CHECK: "versions": +# CHECK: "producer": 344 + diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc index 631bb1ae2af..a445937570e 100644 --- a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc +++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" @@ -47,6 +46,11 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) { // Canonicalize, CSE etc. pm->addNestedPass(mlir::createCanonicalizerPass()); pm->addNestedPass(mlir::createCSEPass()); + + // raise to executor dialect in order to use GraphDef converter + pm->addNestedPass( + mlir::CreateFunctionalToExecutorDialectConversionPass()); + pm->addNestedPass(mlir::CreateBreakUpIslandsPass()); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc new file mode 100644 index 00000000000..7f4b8ffae09 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" + +using mlir::ModuleOp; +using mlir::TranslateFromMLIRRegistration; +using std::string; +using tensorflow::Status; +using xla::StatusOr; + +// Translates the given MLIR module in the TFJS dialect to TFJS JSON +// format. Returns false on success. +// +bool tfjs::MlirToJSONTranslateFunction(ModuleOp module, + std::string* serialized_json) { + string json_output; + // Allow TF to treat TFJS ops as TF ops. + if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) { + LOG(ERROR) << "Failed to add tfjs op prefix."; + return false; + } + tensorflow::GraphExportConfig confs; + confs.export_shapes = true; + confs.export_library = true; + tensorflow::FunctionLibraryDefinition flib_def( + tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); + absl::flat_hash_set control_ret_nodes; + auto graph = absl::make_unique(flib_def); + auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def, + &control_ret_nodes); + if (!status.ok()) { + LOG(ERROR) << "Graph export failed: " << status; + return false; + } + auto graphdef = absl::make_unique(); + graph->ToGraphDef(graphdef.get()); + + // Replace the _Arg nodes of the main function with Placeholder op. + auto nodes = graphdef->mutable_node(); + for (const auto& node : llvm::enumerate(*nodes)) { + if (node.value().op() == "_Arg") { + nodes->Mutable(node.index())->set_op("Placeholder"); + } + } + + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString( + *graphdef, &json_output, json_options); + if (!jsonStatus.ok()) { + LOG(ERROR) << "Proto2Json failed: " << status; + return false; + } + *serialized_json = std::move(json_output); + return true; +} + +static mlir::LogicalResult MlirToJSONFileTranslateFunction( + ModuleOp module, llvm::raw_ostream& output) { + std::string serialized_json; + if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json)) + return mlir::failure(); + + output << serialized_json; + return mlir::success(); +} + +static TranslateFromMLIRRegistration MLIRToJSONFileTranslate( + "mlir-to-tfjs-json", MlirToJSONFileTranslateFunction); diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.h b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h new file mode 100644 index 00000000000..0a931f770ad --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ + +#include + +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/core/lib/core/status.h" + +namespace tfjs { + +// Translates the given MLIR `module` into a JSON string. Returns true if +// translation fails, otherwise returns false. +bool MlirToJSONTranslateFunction(mlir::ModuleOp module, + std::string* serialized_json); +} // namespace tfjs + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc new file mode 100644 index 00000000000..e735a3c7b8c --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc @@ -0,0 +1,173 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/strings/str_split.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h" +#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +using llvm::cl::opt; +using mlir::MLIRContext; +using stream_executor::port::StatusOr; + +// NOLINTNEXTLINE +opt input_file_name(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +opt import_saved_model_object_graph( + "savedmodel-objectgraph-to-mlir", + llvm::cl::desc("Import a saved model to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt import_saved_model_signature_defs( + "savedmodel-signaturedefs-to-mlir", + llvm::cl::desc("Import a saved model V1 to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt saved_model_tags( + "tf-savedmodel-tags", + llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, " + "separated by ','"), + llvm::cl::init("serve")); + +// NOLINTNEXTLINE +opt saved_model_exported_names( + "tf-savedmodel-exported-names", + llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty " + "(the default) means export all."), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt output_file_name("o", llvm::cl::desc(""), + llvm::cl::value_desc("filename"), + llvm::cl::init("-")); +// NOLINTNEXTLINE +opt input_mlir( + "input-mlir", + llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of " + "GraphDef format"), + llvm::cl::init(false), llvm::cl::Hidden); +// NOLINTNEXTLINE +opt output_mlir( + "output-mlir", + llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"), + llvm::cl::init(false)); + +// The following approach allows injecting opdefs in addition +// to those that are already part of the global TF registry to be linked in +// prior to importing the graph. The primary goal is for support of custom ops. +// This is not intended to be a general solution for custom ops for the future +// but mainly for supporting older models like mobilenet_ssd. More appropriate +// mechanisms, such as op hints or using functions to represent composable ops +// like https://github.com/tensorflow/community/pull/113 should be encouraged +// going forward. +// NOLINTNEXTLINE +llvm::cl::list custom_opdefs( + "tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing " + "graphdef")); + +// Debugging flag to print function mapping in the JSON. +// NOLINTNEXTLINE +static opt print_function_result_mapping( + "print-function-result-mapping", + llvm::cl::desc( + "Print the mapping of function result to json output buffer"), + llvm::cl::init(false)); + +enum TranslationStatus { kTrSuccess, kTrFailure }; + +static int PrintFunctionResultMapping(const std::string& result) { + std::cout << result << std::endl; + return kTrSuccess; +} + +int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + + llvm::cl::ParseCommandLineOptions(argc, argv, + "TF GraphDef to TFJS JSON converter\n"); + + MLIRContext context; + llvm::SourceMgr source_mgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); + + StatusOr module; + + if (import_saved_model_object_graph || import_saved_model_signature_defs) { + if (input_mlir) + module = tensorflow::errors::InvalidArgument( + "Importing saved model should not have input_mlir set"); + module = tensorflow::ImportSavedModel( + import_saved_model_object_graph, import_saved_model_signature_defs, + custom_opdefs, input_file_name, saved_model_tags, + saved_model_exported_names, &context); + } else { + module = tensorflow::LoadFromGraphdefOrMlirSource( + input_file_name, input_mlir, custom_opdefs, debug_info_file, + input_arrays, input_dtypes, input_shapes, output_arrays, + /*prune_unused_nodes=*/true, &source_mgr, &context); + } + + // If errors occur, the library call in the above already logged the error + // message. So we can just return here. + if (!module.ok()) return kTrFailure; + + mlir::PassManager pm(&context); + + tensorflow::AddTFToTFJSConversionPasses(&pm); + + std::string result; + auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(), + output_mlir, &result, &pm); + if (!status.ok()) return kTrFailure; + + std::string error_msg; + auto output = mlir::openOutputFile(output_file_name, &error_msg); + if (output == nullptr) { + llvm::errs() << error_msg << '\n'; + return kTrFailure; + } + output->os() << result; + output->keep(); + + // Print out debugging info related to function mapping. + if (print_function_result_mapping) return PrintFunctionResultMapping(result); + return kTrSuccess; +} diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc new file mode 100644 index 00000000000..7dc9ea049ba --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::OwningModuleRef; +using stream_executor::port::StatusOr; + +namespace { +tensorflow::Status RegisterCustomOps( + const std::vector& extra_tf_opdefs) { + for (const auto& tf_opdefs_string : extra_tf_opdefs) { + tensorflow::OpDef opdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, + &opdef)) { + LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; + return errors::InvalidArgument("fail to parse extra OpDef"); + } + // Register extra opdefs. + tensorflow::OpRegistry::Global()->Register( + [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { + *op_reg_data = tensorflow::OpRegistrationData(opdef); + return Status::OK(); + }); + } + return Status::OK(); +} +} // namespace + +StatusOr LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + const std::vector& extra_tf_opdefs, + absl::string_view debug_info_file, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, bool prune_unused_nodes, + llvm::SourceMgr* source_mgr, MLIRContext* context) { + // Set up the input file. + std::string error_message; + auto file = mlir::openInputFile(input_filename, &error_message); + if (!file) { + llvm::errs() << error_message << "\n"; + return errors::InvalidArgument("fail to open input file"); + } + + if (input_mlir) { + source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc()); + return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context)); + } + + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + + return tensorflow::GraphdefToMlirTranslateFunction( + file->getBuffer(), debug_info_file, input_arrays, input_dtypes, + input_shapes, output_arrays, /*control_output_arrays=*/"", + prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, + /*graph_as_function=*/false, /*upgrade_legacy=*/true, + /*enable_shape_inference=*/true, context); +} + +Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir, + std::string* result, + mlir::PassManager* pass_manager) { + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), + /*propagate=*/true); + if (failed(pass_manager->run(module))) { + return statusHandler.ConsumeStatus(); + } + + if (export_to_mlir) { + llvm::raw_string_ostream os(*result); + module.print(os); + return Status::OK(); + } + + return tfjs::MlirToJSONTranslateFunction(module, result) + ? Status::OK() + : statusHandler.ConsumeStatus(); +} + +StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::vector& extra_tf_opdefs, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context) { + std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); + std::vector exported_names_in_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span exported_names(exported_names_in_vector); + if (import_saved_model) { + auto module = tensorflow::SavedModelObjectGraphToMlirImport( + input_filename, tags, absl::Span(exported_names), context); + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + return module; + } else if (import_saved_model_v1) { + auto module = tensorflow::SavedModelSignatureDefsToMlirImport( + input_filename, tags, exported_names, context); + + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + return module; + } else { + return tensorflow::errors::InvalidArgument( + "Should be either saved model v1 or v2"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h new file mode 100644 index 00000000000..d68f0e7d46e --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +// Load a TF model from a GraphDef definition or a TF control flow dialect MLIR +// source into a MLIR module. If `input_mlir` is true, load from a MLIR source +// file; otherwise, load from a GraphDef. +// Setting prune_unused_nodes to true, would prune unreachable nodes if +// output_arrays is specified. +stream_executor::port::StatusOr +LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + const std::vector& extra_tf_opdefs, + absl::string_view debug_info_file, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, bool prune_unused_nodes, + llvm::SourceMgr* source_mgr, mlir::MLIRContext* context); + +// Load Saved model (either v1 or v2) into MLIR. +stream_executor::port::StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::vector& extra_tf_opdefs, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context); + +// Taking a MLIR module in TF executor dialect and a set of parameters, +// applies a set of passes to convert the module to TFJS dialect and +// serializes the result to JSON string. +// If `export_to_mlir` is true, the result is exported in MLIR text format, +// otherwise exported in JSON. +Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir, + std::string* result, + mlir::PassManager* pass_manager); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_