From a05f8bdf0e964d27cab4dfbe36576b7692f055da Mon Sep 17 00:00:00 2001 From: Feng Liu <fengliuai@google.com> Date: Mon, 12 Oct 2020 16:36:09 -0700 Subject: [PATCH] Add the op expansion graph optimization pass to tensorflow PiperOrigin-RevId: 336770266 Change-Id: Iadc1917c9be25bc9010129af551814ae72160347 --- tensorflow/compiler/mlir/tfr/BUILD | 22 +++- .../tfr/integration/graph_decompose_pass.cc | 19 ++- .../tfr/integration/graph_decompose_pass.h | 11 +- .../tfr/integration/graph_decompose_test.py | 90 +++++++++++++ .../mlir/tfr/integration/tfr_decompose_ctx.cc | 9 -- .../mlir/tfr/integration/tfr_decompose_ctx.h | 5 - .../tfr/integration/tfr_decompose_ctx_test.cc | 1 - tensorflow/compiler/mlir/tfr/resources/BUILD | 21 ++- tensorflow/compiler/mlir/tfr/tfr.bzl | 120 ++++++++++++++++++ 9 files changed, 267 insertions(+), 31 deletions(-) create mode 100644 tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py create mode 100644 tensorflow/compiler/mlir/tfr/tfr.bzl diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 420648dff54..63ce46b853c 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -18,7 +18,8 @@ package_group( includes = ["//third_party/mlir:subpackages"], packages = [ "//learning/brain/experimental/mlir/tfr/...", - "//tensorflow/compiler/mlir/...", + "//tensorflow/c/...", + "//tensorflow/compiler/...", ], ) @@ -179,7 +180,6 @@ cc_library( "//tensorflow/compiler/mlir/tfr:passes", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:optimization_registry", "//tensorflow/stream_executor/lib", @@ -201,7 +201,6 @@ tf_cc_test( ":tfr_decompose_ctx", "//tensorflow/compiler/xla:test", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -216,7 +215,6 @@ cc_library( name = "graph_decompose_pass", srcs = ["integration/graph_decompose_pass.cc"], hdrs = ["integration/graph_decompose_pass.h"], - data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"], deps = [ ":tfr_decompose_ctx", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", @@ -228,6 +226,22 @@ cc_library( alwayslink = 1, ) +tf_py_test( + name = "graph_decompose_test", + size = "small", + srcs = ["integration/graph_decompose_test.py"], + data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"], + python_version = "PY3", + tags = [ + "no_oss", + "notap", + ], + deps = [ + "//tensorflow/compiler/mlir/tfr/resources:composite_ops", + "//tensorflow/python/eager:def_function", + ], +) + tf_python_pybind_extension( name = "tfr_wrapper", srcs = ["python/tfr_wrapper.cc"], diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc index 9fd7ee03cb9..be0fab13021 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/util/env_var.h" @@ -23,8 +22,21 @@ limitations under the License. namespace tensorflow { +constexpr const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR"; + +bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const { + const char* tfr_lib_env_val = getenv(string(kTFRLibEnv).c_str()); + return tfr_lib_env_val != nullptr; +} + Status GraphDecomposePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module) { + if (!IsEnabled(config_proto)) { + VLOG(1) << "Skipping Graph Decomposition Pass, decompositin library was " + "not found"; + return Status::OK(); + } + VLOG(1) << "Run Graph Decomposition Passes"; TF_ASSIGN_OR_RETURN(ctx_, LoadDecompositionLib(module.getContext())); TF_RETURN_IF_ERROR(ctx_->Decompose(module)); return ctx_->Destroy(); @@ -35,8 +47,7 @@ GraphDecomposePass::LoadDecompositionLib(mlir::MLIRContext* mlir_ctx) { Env* env = Env::Default(); std::string tfr_lib_dir; TF_RETURN_IF_ERROR(ReadStringFromEnvVar( - "TF_MLIR_TFR_LIB_DIR", "tensorflow/compiler/mlir/tfr/resources", - &tfr_lib_dir)); + kTFRLibEnv, "tensorflow/compiler/mlir/tfr/resources", &tfr_lib_dir)); string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir); std::vector<string> files; TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files)); @@ -59,7 +70,7 @@ GraphDecomposePass::LoadDecompositionLib(mlir::MLIRContext* mlir_ctx) { } namespace { -constexpr int kMlirGraphDecomposePassPriority = 1; +constexpr int kMlirGraphDecomposePassPriority = -1; static mlir_pass_registration::MlirOptimizationPassRegistration register_mlir_graph_decompose_pass(kMlirGraphDecomposePassPriority, diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h index f0963379928..89db74c72ed 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h @@ -16,9 +16,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -30,10 +30,9 @@ class GraphDecomposePass : public MlirOptimizationPass { public: llvm::StringRef name() const override { return "tfr"; } - bool IsEnabled(const ConfigProto& config_proto) const override { - // TODO(fengliuai): make a new flag in config_proto.experimental() - return true; - } + // Whether to run this pass. If this is enabled, the GraphDef will be imported + // to MLIR even no tf composition file is found. + bool IsEnabled(const ConfigProto& config_proto) const override; // This should be used as a thin mapper around mlir::ModulePass::runOnModule // API integrated with the Tensorflow runtime. diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py new file mode 100644 index 00000000000..a03de698cb1 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py @@ -0,0 +1,90 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for third_party.tensorflow.compiler.mlir.tfr.integrattion.graph_decompose.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_composite_ops.__file__) +_lib_name = os.path.basename(gen_composite_ops.__file__)[4:].replace( + '.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class GraphDecomposeTest(test.TestCase): + + def setUp(self): + os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' + super(GraphDecomposeTest, self).setUp() + + def tearDown(self): + del os.environ['TF_MLIR_TFR_LIB_DIR'] + super(GraphDecomposeTest, self).tearDown() + + def testAddN(self): + add = def_function.function(gen_composite_ops.my_add_n) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq1 = add([t1]) + sq2 = add([t1, t2]) + sq3 = add([t1, t2, t3]) + self.assertAllEqual(sq1.numpy().reshape(-1), [1, 2, 3, 4]) + self.assertAllEqual(sq2.numpy().reshape(-1), [2, 4, 6, 8]) + self.assertAllEqual(sq3.numpy().reshape(-1), [3, 6, 9, 12]) + + def testBiasedDense(self): + biased_dense = def_function.function(gen_composite_ops.my_biased_dense) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biased_dense(t1, t2, t3) + self.assertAllEqual(sq.numpy().reshape(-1), [-3, 0, 5, 12]) + + def testBiasedDenseRelu(self): + biased_dense = def_function.function(gen_composite_ops.my_biased_dense) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biased_dense(t1, t2, t3, act='relu') + self.assertAllEqual(sq.numpy().reshape(-1), [0, 0, 5, 12]) + + def testWithKnownKernel(self): + + @def_function.function + def biasd_dense_elu(x, y, z): + dot = gen_composite_ops.my_biased_dense(x, y, z) + return nn_ops.elu(dot) # with known kernel, should not expand. + + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biasd_dense_elu(t1, t2, t3) + self.assertAllClose(sq.numpy().reshape(-1), [-0.950213, 0, 5, 12]) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc index 7a2962c7b67..4cc7d90f17b 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc @@ -36,18 +36,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" #include "tensorflow/compiler/mlir/tfr/passes/passes.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h index b51d6158eb2..9ff314ebc92 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h @@ -20,13 +20,8 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc index 8b1b0453cff..5d10936e092 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/compiler/mlir/tfr/resources/BUILD b/tensorflow/compiler/mlir/tfr/resources/BUILD index 0d0705b1da0..62ca65c5b57 100644 --- a/tensorflow/compiler/mlir/tfr/resources/BUILD +++ b/tensorflow/compiler/mlir/tfr/resources/BUILD @@ -33,14 +33,31 @@ cc_library( alwayslink = 1, ) +tf_custom_op_library( + name = "composite_ops.so", + srcs = [ + "composite_ops.cc", + ], +) + tf_gen_op_wrapper_py( - name = "composite_ops", - out = "composite_ops.py", + name = "gen_composite_ops", + out = "gen_composite_ops.py", deps = [ ":composite_ops_cc", ], ) +tf_custom_op_py_library( + name = "composite_ops", + dso = [":composite_ops.so"], + kernels = [":composite_ops_cc"], + visibility = ["//visibility:public"], + deps = [ + ":gen_composite_ops", + ], +) + cc_library( name = "test_ops_cc", srcs = ["test_ops.cc"], diff --git a/tensorflow/compiler/mlir/tfr/tfr.bzl b/tensorflow/compiler/mlir/tfr/tfr.bzl new file mode 100644 index 00000000000..cc1b617f932 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tfr.bzl @@ -0,0 +1,120 @@ +"""BUILD extension for TF composition project.""" + +load("//tensorflow:tensorflow.bzl", "py_binary", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.google.bzl", "pytype_library") + +def gen_op_libraries( + name, + src, + deps, + tags = [], + test = False): + """gen_op_libraries() generates all cc and py libraries for composite op source. + + Args: + name: used as the name component of all the generated libraries. + src: File contains the composite ops. + deps: Libraries the 'src' depends on. + tags: + test: + """ + if not src.endswith(".py") or name == src[:-3]: + fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src) + + gen_op_lib_exec = src[:-3] + py_binary( + name = gen_op_lib_exec, + srcs = [src], + srcs_version = "PY2AND3", + python_version = "PY3", + deps = [ + "//tensorflow/python:platform", + ] + deps, + ) + + register_op = "register_" + name + native.genrule( + name = register_op, + srcs = [], + outs = [name + ".inc.cc"], + cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec, + exec_tools = [":" + gen_op_lib_exec], + local = 1, + tags = tags, + ) + + native.cc_library( + name = name + "_cc", + testonly = test, + srcs = [":" + register_op], + copts = [ + "-Wno-unused-result", + "-Wno-unused-variable", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, + ) + + tf_gen_op_wrapper_py( + name = name, + out = name + ".py", + deps = [ + ":%s_cc" % name, + ], + ) + + pytype_library( + name = name + "_grads", + srcs = [ + src, + ], + srcs_version = "PY2AND3", + deps = [ + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ] + deps, + ) + + pytype_library( + name = name + "_lib", + srcs = [ + name + ".py", + ], + srcs_version = "PY2AND3", + deps = [ + ":%s" % name, + ":%s_cc" % name, + ":%s_grads" % name, + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ] + deps, + ) + + # Link the register op and rebuild the binary + gen_tfr_lib_exec = gen_op_lib_exec + "_registered" + py_binary( + name = gen_tfr_lib_exec, + main = src, + srcs = [src], + srcs_version = "PY2AND3", + python_version = "PY3", + deps = [ + "//tensorflow/python:platform", + ":%s" % name + "_cc", + ] + deps, + ) + + op_tfr = "composite_" + name + native.genrule( + name = op_tfr, + srcs = [], + outs = [name + ".mlir"], + cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec, + exec_tools = [":" + gen_tfr_lib_exec], + local = 1, + tags = tags, + )