Add MlirXlaOpKernel, which is used to implement XlaOpKernels using MLIR legalization.
PiperOrigin-RevId: 360664123 Change-Id: Ic72c880496fe405675a8740559e13db62d195f18
This commit is contained in:
parent
c180f35a45
commit
cfec367771
tensorflow/compiler
@ -213,9 +213,7 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
|
||||
execution_count < kMinExecutionsPerCompile * compile_count;
|
||||
}
|
||||
|
||||
// Creates a simple graph using the specified op as the only op apart from the
|
||||
// arg and retval nodes.
|
||||
static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
|
||||
xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
|
||||
const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
|
||||
absl::Span<const DataType> result_types) {
|
||||
// TODO(b/74182462): We implement this by creating a new dummy Graph including
|
||||
|
@ -196,6 +196,12 @@ class XlaCompilationCache : public ResourceBase {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
|
||||
};
|
||||
|
||||
// Creates a single-node graph using the specified node_def as the only op apart
|
||||
// from the arg and retval nodes.
|
||||
xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
|
||||
const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
|
||||
absl::Span<const DataType> result_types);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
|
||||
|
@ -1124,6 +1124,18 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_xla_op_kernel",
|
||||
srcs = ["mlir_xla_op_kernel.cc"],
|
||||
hdrs = ["mlir_xla_op_kernel.h"],
|
||||
deps = [
|
||||
":xla_compiler",
|
||||
"//tensorflow/compiler/jit:xla_compilation_cache",
|
||||
"//tensorflow/compiler/mlir:array_container_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "resource_util",
|
||||
srcs = ["resource_util.cc"],
|
||||
|
@ -150,6 +150,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/jit:xla_activity_listener",
|
||||
"//tensorflow/compiler/jit:xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:mlir_xla_op_kernel",
|
||||
"//tensorflow/compiler/tf2xla:xla_compilation_device",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_context",
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -35,15 +36,7 @@ XlaOp Relu6(XlaOp x) {
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class ReluOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
// Computes the max of the scalar input x and 0.
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
ctx->SetOutput(0, xla::Relu(ctx->Input(0)));
|
||||
}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("Relu"), ReluOp);
|
||||
REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel);
|
||||
|
||||
class Relu6Op : public XlaOpKernel {
|
||||
public:
|
||||
|
109
tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
Normal file
109
tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
Normal file
@ -0,0 +1,109 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
Status ContextToXlaArgs(XlaOpKernelContext* ctx,
|
||||
std::vector<XlaCompiler::Argument>& xla_args) {
|
||||
int num_inputs = ctx->num_inputs();
|
||||
xla_args.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
// TODO(b/180448676): If the input `XlaExpression` kind is `kConstant`, then
|
||||
// create a constant `XlaArgument`.
|
||||
// TODO(b/180448774): Handle kResource and kTensorList.
|
||||
XlaExpression::Kind ctx_kind_i = ctx->InputExpression(i).kind();
|
||||
if (ctx_kind_i != XlaExpression::Kind::kXlaOp &&
|
||||
ctx_kind_i != XlaExpression::Kind::kConstant)
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
absl::StrCat("Input ", i, " to an MlirXlaOpKernel is invalid: ",
|
||||
ctx->InputExpression(i).HumanString()));
|
||||
XlaCompiler::Argument arg;
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
arg.type = ctx->input_type(i);
|
||||
arg.shape = ctx->InputXlaShape(i).ValueOrDie();
|
||||
arg.name = absl::StrCat("_arg", i);
|
||||
xla_args.push_back(arg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) {
|
||||
// Create input XlaArguments.
|
||||
std::vector<XlaCompiler::Argument> xla_args;
|
||||
TF_RETURN_IF_ERROR(ContextToXlaArgs(ctx, xla_args));
|
||||
|
||||
// Create input XlaOps.
|
||||
llvm::SmallVector<xla::XlaOp, 4> xla_params(ctx->num_inputs());
|
||||
for (int i = 0, end = ctx->num_inputs(); i < end; ++i) {
|
||||
xla_params[i] = ctx->Input(i);
|
||||
}
|
||||
|
||||
// Create outputs.
|
||||
std::vector<DataType> result_dtypes(ctx->num_outputs());
|
||||
for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
|
||||
result_dtypes[i] = ctx->expected_output_dtype(i);
|
||||
}
|
||||
|
||||
// When there are no data-flow outputs from the node, the node is used as a
|
||||
// control output by the graph to TensorflowDialect importer.
|
||||
std::vector<std::string> control_rets;
|
||||
if (result_dtypes.empty()) {
|
||||
control_rets.push_back(def().name());
|
||||
}
|
||||
|
||||
// Get the context's device.
|
||||
auto device = dynamic_cast<Device*>(ctx->op_kernel_context()->device());
|
||||
if (!device) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Expected the XlaOpKernelContext argument's device to have type "
|
||||
"Device.");
|
||||
}
|
||||
|
||||
// Create a graph that wraps the kernel.
|
||||
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(def(), xla_args, result_dtypes));
|
||||
|
||||
// Compile the graph to HLO.
|
||||
GraphDebugInfo debug_info;
|
||||
std::vector<xla::XlaOp> returns(1);
|
||||
TF_RETURN_IF_ERROR(BuildHloFromGraph(
|
||||
*graph, *ctx->builder(), xla_params, returns,
|
||||
mlir::SpanToArrayRef<XlaCompiler::Argument>(xla_args), control_rets,
|
||||
device->device_type(),
|
||||
*ctx->function_library()->GetFunctionLibraryDefinition(), debug_info,
|
||||
{}));
|
||||
|
||||
// Set context outputs.
|
||||
for (int i = 0, end = returns.size(); i < end; ++i) {
|
||||
ctx->SetOutput(i, returns[i]);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MlirXlaOpKernel::Compile(XlaOpKernelContext* ctx) {
|
||||
OP_REQUIRES_OK(ctx, ConstructXlaOp(ctx));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
36
tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
Normal file
36
tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// An XlaOpKernel that's implemented by lowering using MLIR TensorFlow to HLO
|
||||
// legalization.
|
||||
class MlirXlaOpKernel : public XlaOpKernel {
|
||||
public:
|
||||
explicit MlirXlaOpKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
private:
|
||||
void Compile(XlaOpKernelContext* ctx) override;
|
||||
Status ConstructXlaOp(XlaOpKernelContext* ctx);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
|
Loading…
Reference in New Issue
Block a user