Add the op expansion graph optimization pass to tensorflow
PiperOrigin-RevId: 336770266 Change-Id: Iadc1917c9be25bc9010129af551814ae72160347
This commit is contained in:
parent
e5f30136a1
commit
a05f8bdf0e
@ -18,7 +18,8 @@ package_group(
|
|||||||
includes = ["//third_party/mlir:subpackages"],
|
includes = ["//third_party/mlir:subpackages"],
|
||||||
packages = [
|
packages = [
|
||||||
"//learning/brain/experimental/mlir/tfr/...",
|
"//learning/brain/experimental/mlir/tfr/...",
|
||||||
"//tensorflow/compiler/mlir/...",
|
"//tensorflow/c/...",
|
||||||
|
"//tensorflow/compiler/...",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -179,7 +180,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tfr:passes",
|
"//tensorflow/compiler/mlir/tfr:passes",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:graph",
|
"//tensorflow/core:graph",
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/common_runtime:optimization_registry",
|
"//tensorflow/core/common_runtime:optimization_registry",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
@ -201,7 +201,6 @@ tf_cc_test(
|
|||||||
":tfr_decompose_ctx",
|
":tfr_decompose_ctx",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
@ -216,7 +215,6 @@ cc_library(
|
|||||||
name = "graph_decompose_pass",
|
name = "graph_decompose_pass",
|
||||||
srcs = ["integration/graph_decompose_pass.cc"],
|
srcs = ["integration/graph_decompose_pass.cc"],
|
||||||
hdrs = ["integration/graph_decompose_pass.h"],
|
hdrs = ["integration/graph_decompose_pass.h"],
|
||||||
data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
|
|
||||||
deps = [
|
deps = [
|
||||||
":tfr_decompose_ctx",
|
":tfr_decompose_ctx",
|
||||||
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
|
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
|
||||||
@ -228,6 +226,22 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "graph_decompose_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["integration/graph_decompose_test.py"],
|
||||||
|
data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
|
||||||
|
python_version = "PY3",
|
||||||
|
tags = [
|
||||||
|
"no_oss",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/mlir/tfr/resources:composite_ops",
|
||||||
|
"//tensorflow/python/eager:def_function",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_python_pybind_extension(
|
tf_python_pybind_extension(
|
||||||
name = "tfr_wrapper",
|
name = "tfr_wrapper",
|
||||||
srcs = ["python/tfr_wrapper.cc"],
|
srcs = ["python/tfr_wrapper.cc"],
|
||||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h"
|
#include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
|
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/path.h"
|
#include "tensorflow/core/platform/path.h"
|
||||||
#include "tensorflow/core/util/env_var.h"
|
#include "tensorflow/core/util/env_var.h"
|
||||||
@ -23,8 +22,21 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
constexpr const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR";
|
||||||
|
|
||||||
|
bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const {
|
||||||
|
const char* tfr_lib_env_val = getenv(string(kTFRLibEnv).c_str());
|
||||||
|
return tfr_lib_env_val != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
Status GraphDecomposePass::Run(const ConfigProto& config_proto,
|
Status GraphDecomposePass::Run(const ConfigProto& config_proto,
|
||||||
mlir::ModuleOp module) {
|
mlir::ModuleOp module) {
|
||||||
|
if (!IsEnabled(config_proto)) {
|
||||||
|
VLOG(1) << "Skipping Graph Decomposition Pass, decompositin library was "
|
||||||
|
"not found";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
VLOG(1) << "Run Graph Decomposition Passes";
|
||||||
TF_ASSIGN_OR_RETURN(ctx_, LoadDecompositionLib(module.getContext()));
|
TF_ASSIGN_OR_RETURN(ctx_, LoadDecompositionLib(module.getContext()));
|
||||||
TF_RETURN_IF_ERROR(ctx_->Decompose(module));
|
TF_RETURN_IF_ERROR(ctx_->Decompose(module));
|
||||||
return ctx_->Destroy();
|
return ctx_->Destroy();
|
||||||
@ -35,8 +47,7 @@ GraphDecomposePass::LoadDecompositionLib(mlir::MLIRContext* mlir_ctx) {
|
|||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
std::string tfr_lib_dir;
|
std::string tfr_lib_dir;
|
||||||
TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
|
TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
|
||||||
"TF_MLIR_TFR_LIB_DIR", "tensorflow/compiler/mlir/tfr/resources",
|
kTFRLibEnv, "tensorflow/compiler/mlir/tfr/resources", &tfr_lib_dir));
|
||||||
&tfr_lib_dir));
|
|
||||||
string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir);
|
string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir);
|
||||||
std::vector<string> files;
|
std::vector<string> files;
|
||||||
TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files));
|
TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files));
|
||||||
@ -59,7 +70,7 @@ GraphDecomposePass::LoadDecompositionLib(mlir::MLIRContext* mlir_ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int kMlirGraphDecomposePassPriority = 1;
|
constexpr int kMlirGraphDecomposePassPriority = -1;
|
||||||
|
|
||||||
static mlir_pass_registration::MlirOptimizationPassRegistration
|
static mlir_pass_registration::MlirOptimizationPassRegistration
|
||||||
register_mlir_graph_decompose_pass(kMlirGraphDecomposePassPriority,
|
register_mlir_graph_decompose_pass(kMlirGraphDecomposePassPriority,
|
||||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_
|
#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_
|
||||||
|
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
|
|
||||||
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
|
||||||
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -30,10 +30,9 @@ class GraphDecomposePass : public MlirOptimizationPass {
|
|||||||
public:
|
public:
|
||||||
llvm::StringRef name() const override { return "tfr"; }
|
llvm::StringRef name() const override { return "tfr"; }
|
||||||
|
|
||||||
bool IsEnabled(const ConfigProto& config_proto) const override {
|
// Whether to run this pass. If this is enabled, the GraphDef will be imported
|
||||||
// TODO(fengliuai): make a new flag in config_proto.experimental()
|
// to MLIR even no tf composition file is found.
|
||||||
return true;
|
bool IsEnabled(const ConfigProto& config_proto) const override;
|
||||||
}
|
|
||||||
|
|
||||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||||
// API integrated with the Tensorflow runtime.
|
// API integrated with the Tensorflow runtime.
|
||||||
|
@ -0,0 +1,90 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for third_party.tensorflow.compiler.mlir.tfr.integrattion.graph_decompose."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import load_library
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import nn_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
_lib_dir = os.path.dirname(gen_composite_ops.__file__)
|
||||||
|
_lib_name = os.path.basename(gen_composite_ops.__file__)[4:].replace(
|
||||||
|
'.py', '.so')
|
||||||
|
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))
|
||||||
|
|
||||||
|
|
||||||
|
class GraphDecomposeTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources'
|
||||||
|
super(GraphDecomposeTest, self).setUp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del os.environ['TF_MLIR_TFR_LIB_DIR']
|
||||||
|
super(GraphDecomposeTest, self).tearDown()
|
||||||
|
|
||||||
|
def testAddN(self):
|
||||||
|
add = def_function.function(gen_composite_ops.my_add_n)
|
||||||
|
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
t3 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
sq1 = add([t1])
|
||||||
|
sq2 = add([t1, t2])
|
||||||
|
sq3 = add([t1, t2, t3])
|
||||||
|
self.assertAllEqual(sq1.numpy().reshape(-1), [1, 2, 3, 4])
|
||||||
|
self.assertAllEqual(sq2.numpy().reshape(-1), [2, 4, 6, 8])
|
||||||
|
self.assertAllEqual(sq3.numpy().reshape(-1), [3, 6, 9, 12])
|
||||||
|
|
||||||
|
def testBiasedDense(self):
|
||||||
|
biased_dense = def_function.function(gen_composite_ops.my_biased_dense)
|
||||||
|
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
|
||||||
|
sq = biased_dense(t1, t2, t3)
|
||||||
|
self.assertAllEqual(sq.numpy().reshape(-1), [-3, 0, 5, 12])
|
||||||
|
|
||||||
|
def testBiasedDenseRelu(self):
|
||||||
|
biased_dense = def_function.function(gen_composite_ops.my_biased_dense)
|
||||||
|
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
|
||||||
|
sq = biased_dense(t1, t2, t3, act='relu')
|
||||||
|
self.assertAllEqual(sq.numpy().reshape(-1), [0, 0, 5, 12])
|
||||||
|
|
||||||
|
def testWithKnownKernel(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def biasd_dense_elu(x, y, z):
|
||||||
|
dot = gen_composite_ops.my_biased_dense(x, y, z)
|
||||||
|
return nn_ops.elu(dot) # with known kernel, should not expand.
|
||||||
|
|
||||||
|
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
|
||||||
|
sq = biasd_dense_elu(t1, t2, t3)
|
||||||
|
self.assertAllClose(sq.numpy().reshape(-1), [-0.950213, 0, 5, 12])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ops.enable_eager_execution()
|
||||||
|
test.main()
|
@ -36,18 +36,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||||
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
|
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
|
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/graph/tensor_id.h"
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
|
||||||
#include "tensorflow/core/platform/errors.h"
|
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
|
||||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
@ -20,13 +20,8 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
|
||||||
#include "tensorflow/core/platform/stringpiece.h"
|
|
||||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
|
||||||
|
@ -29,7 +29,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
@ -33,14 +33,31 @@ cc_library(
|
|||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "composite_ops.so",
|
||||||
|
srcs = [
|
||||||
|
"composite_ops.cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_gen_op_wrapper_py(
|
tf_gen_op_wrapper_py(
|
||||||
name = "composite_ops",
|
name = "gen_composite_ops",
|
||||||
out = "composite_ops.py",
|
out = "gen_composite_ops.py",
|
||||||
deps = [
|
deps = [
|
||||||
":composite_ops_cc",
|
":composite_ops_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_custom_op_py_library(
|
||||||
|
name = "composite_ops",
|
||||||
|
dso = [":composite_ops.so"],
|
||||||
|
kernels = [":composite_ops_cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":gen_composite_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "test_ops_cc",
|
name = "test_ops_cc",
|
||||||
srcs = ["test_ops.cc"],
|
srcs = ["test_ops.cc"],
|
||||||
|
120
tensorflow/compiler/mlir/tfr/tfr.bzl
Normal file
120
tensorflow/compiler/mlir/tfr/tfr.bzl
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
"""BUILD extension for TF composition project."""
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "py_binary", "tf_gen_op_wrapper_py")
|
||||||
|
load("//tensorflow:tensorflow.google.bzl", "pytype_library")
|
||||||
|
|
||||||
|
def gen_op_libraries(
|
||||||
|
name,
|
||||||
|
src,
|
||||||
|
deps,
|
||||||
|
tags = [],
|
||||||
|
test = False):
|
||||||
|
"""gen_op_libraries() generates all cc and py libraries for composite op source.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: used as the name component of all the generated libraries.
|
||||||
|
src: File contains the composite ops.
|
||||||
|
deps: Libraries the 'src' depends on.
|
||||||
|
tags:
|
||||||
|
test:
|
||||||
|
"""
|
||||||
|
if not src.endswith(".py") or name == src[:-3]:
|
||||||
|
fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src)
|
||||||
|
|
||||||
|
gen_op_lib_exec = src[:-3]
|
||||||
|
py_binary(
|
||||||
|
name = gen_op_lib_exec,
|
||||||
|
srcs = [src],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
] + deps,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_op = "register_" + name
|
||||||
|
native.genrule(
|
||||||
|
name = register_op,
|
||||||
|
srcs = [],
|
||||||
|
outs = [name + ".inc.cc"],
|
||||||
|
cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec,
|
||||||
|
exec_tools = [":" + gen_op_lib_exec],
|
||||||
|
local = 1,
|
||||||
|
tags = tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
native.cc_library(
|
||||||
|
name = name + "_cc",
|
||||||
|
testonly = test,
|
||||||
|
srcs = [":" + register_op],
|
||||||
|
copts = [
|
||||||
|
"-Wno-unused-result",
|
||||||
|
"-Wno-unused-variable",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_gen_op_wrapper_py(
|
||||||
|
name = name,
|
||||||
|
out = name + ".py",
|
||||||
|
deps = [
|
||||||
|
":%s_cc" % name,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
pytype_library(
|
||||||
|
name = name + "_grads",
|
||||||
|
srcs = [
|
||||||
|
src,
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
] + deps,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytype_library(
|
||||||
|
name = name + "_lib",
|
||||||
|
srcs = [
|
||||||
|
name + ".py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":%s" % name,
|
||||||
|
":%s_cc" % name,
|
||||||
|
":%s_grads" % name,
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//third_party/py/tensorflow",
|
||||||
|
] + deps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Link the register op and rebuild the binary
|
||||||
|
gen_tfr_lib_exec = gen_op_lib_exec + "_registered"
|
||||||
|
py_binary(
|
||||||
|
name = gen_tfr_lib_exec,
|
||||||
|
main = src,
|
||||||
|
srcs = [src],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
":%s" % name + "_cc",
|
||||||
|
] + deps,
|
||||||
|
)
|
||||||
|
|
||||||
|
op_tfr = "composite_" + name
|
||||||
|
native.genrule(
|
||||||
|
name = op_tfr,
|
||||||
|
srcs = [],
|
||||||
|
outs = [name + ".mlir"],
|
||||||
|
cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec,
|
||||||
|
exec_tools = [":" + gen_tfr_lib_exec],
|
||||||
|
local = 1,
|
||||||
|
tags = tags,
|
||||||
|
)
|
Loading…
x
Reference in New Issue
Block a user