Add MlirXlaOpKernel, which is used to implement XlaOpKernels using MLIR legalization.
PiperOrigin-RevId: 360664123
Change-Id: Ic72c880496fe405675a8740559e13db62d195f18
parent c180f35a45
commit cfec367771
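For ops that opt in to this path, the hand-written XlaOpKernel is replaced by a one-line registration of MlirXlaOpKernel, which wraps the node in a single-op graph and lowers it through the MLIR TensorFlow-to-HLO legalization pipeline (see the relu_op.cc and mlir_xla_op_kernel.cc changes below). A minimal sketch of such a registration follows; it is not part of this commit, and the op name "MyOp" is hypothetical, used only for illustration (the commit itself switches "Relu" to this mechanism):

// Minimal sketch (not part of this commit): register an op so that its XLA
// lowering is produced by MLIR legalization instead of a hand-written
// Compile() body. "MyOp" is a hypothetical op name used for illustration.
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {
namespace {

// No per-op Compile() implementation is needed; MlirXlaOpKernel builds a
// single-node Graph around the op and emits HLO via BuildHloFromGraph.
REGISTER_XLA_OP(Name("MyOp"), MlirXlaOpKernel);

}  // namespace
}  // namespace tensorflow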
tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -213,9 +213,7 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
          execution_count < kMinExecutionsPerCompile * compile_count;
 }
 
-// Creates a simple graph using the specified op as the only op apart from the
-// arg and retval nodes.
-static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
     const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
     absl::Span<const DataType> result_types) {
   // TODO(b/74182462): We implement this by creating a new dummy Graph including
tensorflow/compiler/jit/xla_compilation_cache.h
@@ -196,6 +196,12 @@ class XlaCompilationCache : public ResourceBase {
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
 };
 
+// Creates a single-node graph using the specified node_def as the only op apart
+// from the arg and retval nodes.
+xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
+    absl::Span<const DataType> result_types);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
tensorflow/compiler/tf2xla/BUILD
@@ -1124,6 +1124,18 @@ tf_cuda_cc_test(
     ],
 )
 
+cc_library(
+    name = "mlir_xla_op_kernel",
+    srcs = ["mlir_xla_op_kernel.cc"],
+    hdrs = ["mlir_xla_op_kernel.h"],
+    deps = [
+        ":xla_compiler",
+        "//tensorflow/compiler/jit:xla_compilation_cache",
+        "//tensorflow/compiler/mlir:array_container_utils",
+        "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
+    ],
+)
+
 cc_library(
     name = "resource_util",
     srcs = ["resource_util.cc"],
tensorflow/compiler/tf2xla/kernels/BUILD
@@ -150,6 +150,7 @@ tf_kernel_library(
         "//tensorflow/compiler/jit:xla_activity_listener",
         "//tensorflow/compiler/jit:xla_activity_proto_cc",
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel",
         "//tensorflow/compiler/tf2xla:xla_compilation_device",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla:xla_context",
tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
 
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -35,15 +36,7 @@ XlaOp Relu6(XlaOp x) {
 namespace tensorflow {
 namespace {
 
-class ReluOp : public XlaOpKernel {
- public:
-  explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
-  // Computes the max of the scalar input x and 0.
-  void Compile(XlaOpKernelContext* ctx) override {
-    ctx->SetOutput(0, xla::Relu(ctx->Input(0)));
-  }
-};
-REGISTER_XLA_OP(Name("Relu"), ReluOp);
+REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel);
 
 class Relu6Op : public XlaOpKernel {
  public:
tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc (new file, 109 lines)
@@ -0,0 +1,109 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status ContextToXlaArgs(XlaOpKernelContext* ctx,
+                        std::vector<XlaCompiler::Argument>& xla_args) {
+  int num_inputs = ctx->num_inputs();
+  xla_args.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    // TODO(b/180448676): If the input `XlaExpression` kind is `kConstant`, then
+    // create a constant `XlaArgument`.
+    // TODO(b/180448774): Handle kResource and kTensorList.
+    XlaExpression::Kind ctx_kind_i = ctx->InputExpression(i).kind();
+    if (ctx_kind_i != XlaExpression::Kind::kXlaOp &&
+        ctx_kind_i != XlaExpression::Kind::kConstant)
+      return tensorflow::errors::InvalidArgument(
+          absl::StrCat("Input ", i, " to an MlirXlaOpKernel is invalid: ",
+                       ctx->InputExpression(i).HumanString()));
+    XlaCompiler::Argument arg;
+    arg.kind = XlaCompiler::Argument::kParameter;
+    arg.type = ctx->input_type(i);
+    arg.shape = ctx->InputXlaShape(i).ValueOrDie();
+    arg.name = absl::StrCat("_arg", i);
+    xla_args.push_back(arg);
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) {
+  // Create input XlaArguments.
+  std::vector<XlaCompiler::Argument> xla_args;
+  TF_RETURN_IF_ERROR(ContextToXlaArgs(ctx, xla_args));
+
+  // Create input XlaOps.
+  llvm::SmallVector<xla::XlaOp, 4> xla_params(ctx->num_inputs());
+  for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
+    xla_params[i] = ctx->Input(i);
+  }
+
+  // Create outputs.
+  std::vector<DataType> result_dtypes(ctx->num_outputs());
+  for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
+    result_dtypes[i] = ctx->expected_output_dtype(i);
+  }
+
+  // When there are no data-flow outputs from the node, the node is used as a
+  // control output by the graph to TensorflowDialect importer.
+  std::vector<std::string> control_rets;
+  if (result_dtypes.empty()) {
+    control_rets.push_back(def().name());
+  }
+
+  // Get the context's device.
+  auto device = dynamic_cast<Device*>(ctx->op_kernel_context()->device());
+  if (!device) {
+    return tensorflow::errors::InvalidArgument(
+        "Expected the XlaOpKernelContext argument's device to have type "
+        "Device.");
+  }
+
+  // Create a graph that wraps the kernel.
+  TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(def(), xla_args, result_dtypes));
+
+  // Compile the graph to HLO.
+  GraphDebugInfo debug_info;
+  std::vector<xla::XlaOp> returns(1);
+  TF_RETURN_IF_ERROR(BuildHloFromGraph(
+      *graph, *ctx->builder(), xla_params, returns,
+      mlir::SpanToArrayRef<XlaCompiler::Argument>(xla_args), control_rets,
+      device->device_type(),
+      *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info,
+      {}));
+
+  // Set context outputs.
+  for (int i = 0, end = returns.size(); i < end; ++i) {
+    ctx->SetOutput(i, returns[i]);
+  }
+
+  return Status::OK();
+}
+
+void MlirXlaOpKernel::Compile(XlaOpKernelContext* ctx) {
+  OP_REQUIRES_OK(ctx, ConstructXlaOp(ctx));
+}
+
+}  // namespace tensorflow
tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h (new file, 36 lines)
@@ -0,0 +1,36 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+
+namespace tensorflow {
+
+// An XlaOpKernel that's implemented by lowering using MLIR TensorFlow to HLO
+// legalization.
+class MlirXlaOpKernel : public XlaOpKernel {
+ public:
+  explicit MlirXlaOpKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ private:
+  void Compile(XlaOpKernelContext* ctx) override;
+  Status ConstructXlaOp(XlaOpKernelContext* ctx);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_