Internal change
PiperOrigin-RevId: 299514998 Change-Id: I079530f7982fd945137f253d3286319b224bfd24
parent c3d14c434b
commit edd36f52a3
@@ -81,7 +81,6 @@ cc_library(
         "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
         "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
         "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
         "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
         "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
         "//tensorflow/compiler/mlir/xla:xla_lower",
@@ -132,43 +132,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "xla_legalize_tf_with_tf2xla",
-    srcs = [
-        "transforms/legalize_tf_with_tf2xla.cc",
-    ],
-    deps = [
-        ":hlo",
-        ":mlir_hlo_builder",
-        "//tensorflow/compiler/jit:xla_cpu_device",
-        "//tensorflow/compiler/jit:xla_cpu_jit",
-        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow:convert_type",
-        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
-        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow:translate_utils",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/core:core_cpu_lib",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:session_options",
-        "//tensorflow/stream_executor:timer",
-        "//tensorflow/stream_executor/lib",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/memory",
-        "@llvm-project//llvm:support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Support",
-    ],
-    alwayslink = 1,
-)
-
 cc_library(
     name = "map_xla_to_scalar_op",
     hdrs = ["transforms/map_xla_to_scalar_op.h"],
@@ -462,31 +425,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "mlir_hlo_builder",
-    srcs = [
-        "ir/mlir_hlo_builder.cc",
-    ],
-    hdrs = [
-        "ir/mlir_hlo_builder.h",
-    ],
-    deps = [
-        ":hlo",
-        ":hlo_utils",
-        ":type_to_shape",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/service:hlo",
-        "//tensorflow/compiler/xla/service:shape_inference",
-        "//tensorflow/core/platform:types",
-        "//tensorflow/stream_executor/lib",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@llvm-project//llvm:support",
-        "@llvm-project//mlir:IR",
-    ],
-)
-
 cc_library(
     name = "lhlo",
     srcs = [
@@ -1,84 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
-
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
-#include "tensorflow/compiler/xla/service/shape_inference.h"
-
-namespace xla {
-
-static std::string GetMlirOpName(HloOpcode opcode) {
-  std::string op_name = HloOpcodeString(opcode);
-  absl::c_replace(op_name, '-', '_');
-  return mlir::xla_hlo::XlaHloDialect::getDialectNamespace().str() + "." +
-         op_name;
-}
-
-static std::string ToString(mlir::Type ty) {
-  std::string str;
-  llvm::raw_string_ostream sstream(str);
-  ty.print(sstream);
-  sstream.flush();
-  return str;
-}
-
-MlirHloBuilder::~MlirHloBuilder() = default;
-
-StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
-  mlir::Type ty = val.getType();
-  auto shape = std::make_unique<Shape>(TypeToShape(ty));
-  if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
-    return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
-  }
-
-  int64 handle = reinterpret_cast<int64>(val.getAsOpaquePointer());
-  handle_to_shape_[handle] = std::move(shape);
-  return XlaOp(handle, this);
-}
-
-XlaOp MlirHloBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
-  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
-    TF_ASSIGN_OR_RETURN(
-        Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
-
-    mlir::Value value = GetValue(operand);
-    mlir::OperationState state(loc_, GetMlirOpName(unop));
-    state.addOperands(value);
-    TF_ASSIGN_OR_RETURN(
-        mlir::Type ty,
-        ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
-    state.addTypes(ty);
-    mlir::Operation* op = builder_.createOperation(state);
-    return MakeXlaOp(op->getResult(0));
-  });
-}
-
-StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
-  TF_RETURN_IF_ERROR(first_error());
-  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
-  auto it = handle_to_shape_.find(op.handle());
-  if (it == handle_to_shape_.end()) {
-    return InvalidArgument("No XlaOp with handle %d", op.handle());
-  }
-  return it->second.get();
-}
-
-}  // namespace xla
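An aside on the handle scheme in the deleted file above (an illustration, not part of the commit): MakeXlaOp keeps no separate value table for GetValue; the XlaOp handle is the mlir::Value's opaque pointer, so the original SSA value is recovered by reinterpreting the handle. A minimal round-trip sketch, assuming the deleted MlirHloBuilder API and a hypothetical mlir::FuncOp `func` with at least one block argument:

// Sketch only; `func` and its argument are assumptions for illustration.
xla::MlirHloBuilder builder(func);
mlir::Value v = func.getArgument(0);
xla::StatusOr<xla::XlaOp> op_or = builder.MakeXlaOp(v);
if (op_or.ok()) {
  xla::XlaOp op = op_or.ValueOrDie();
  // GetValue turns op.handle() back into the opaque pointer, so no lookup
  // table is needed; only the shape is cached in handle_to_shape_.
  mlir::Value same = builder.GetValue(op);
  assert(same == v);  // needs <cassert>; same underlying SSA value
}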
@@ -1,100 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
-
-#include <memory>
-
-#include "absl/container/flat_hash_map.h"
-#include "llvm/ADT/StringRef.h"
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/Function.h"  // TF:llvm-project
-#include "mlir/IR/Location.h"  // TF:llvm-project
-#include "mlir/IR/Operation.h"  // TF:llvm-project
-#include "mlir/IR/Value.h"  // TF:llvm-project
-#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/shape.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
-
-namespace xla {
-
-// Provides a way to construct xla_hlo dialect ops in MLIR using the
-// XlaBuilder interface.
-//
-// Requires that all XlaOp arguments are either returned by one of the builder
-// methods or constructed using the MakeXlaOp method in this builder.
-//
-// TODO(hinsu): Support more ops and utility functions to set special
-// attributes like OpMetadata and Sharding.
-class MlirHloBuilder : public XlaBuilder {
- public:
-  // Constructs a builder for the given function. New operations are added to
-  // the beginning of the function, if it is non-empty and has a block.
-  explicit MlirHloBuilder(mlir::FuncOp func)
-      : XlaBuilder(func.getName().str()),
-        builder_(&func.getBody()),
-        loc_(builder_.getUnknownLoc()) {}
-
-  // TODO(hinsu): Add a constructor to build a new MLIR function from scratch
-  // and override Build methods.
-
-  MlirHloBuilder(const MlirHloBuilder&) = delete;
-  MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
-
-  ~MlirHloBuilder() override;
-
-  // Wraps the given MLIR value under an XlaOp instance. Note that all HLO
-  // operations return exactly one result, so each op has an XlaOp wrapping
-  // the result of the op.
-  //
-  // Returns an error if the HLO dialect doesn't support the type of the
-  // given value.
-  StatusOr<XlaOp> MakeXlaOp(mlir::Value val);
-
-  // Returns the value corresponding to the given op.
-  //
-  // Requires that the op was created by this builder.
-  mlir::Value GetValue(XlaOp op) {
-    void* ptr = reinterpret_cast<void*>(op.handle());
-    return mlir::Value::getFromOpaquePointer(ptr);
-  }
-
-  // Sets the location for newly built ops, until reset.
-  void SetLocation(mlir::Location loc) { loc_ = loc; }
-
-  // Updates the insertion point so that newly built ops are inserted before
-  // the given op, until reset.
-  void setInsertionPoint(mlir::Operation* op) {
-    builder_.setInsertionPoint(op);
-  }
-
-  // Returns the shape of the given op.
-  StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;
-
- private:
-  XlaOp UnaryOp(HloOpcode unop, XlaOp operand) override;
-
-  mlir::OpBuilder builder_;
-  mlir::Location loc_;
-
-  absl::flat_hash_map<int64, std::unique_ptr<Shape>> handle_to_shape_;
-};
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
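For orientation, a short sketch of how this builder was driven (an illustration, not code from the commit; `func` is a hypothetical mlir::FuncOp, and xla::Abs is the standard client wrapper declared alongside XlaBuilder in xla_builder.h):

// MlirHloBuilder makes XlaBuilder's op-construction API emit xla_hlo
// dialect ops into the MLIR function instead of recording an HLO proto.
xla::MlirHloBuilder hlo_builder(func);
auto arg_or = hlo_builder.MakeXlaOp(func.getArgument(0));
if (arg_or.ok()) {
  // Abs dispatches to XlaBuilder::UnaryOp, which MlirHloBuilder overrides
  // to create an "xla_hlo.abs" operation at the current insertion point.
  xla::XlaOp abs_op = xla::Abs(arg_or.ValueOrDie());
  mlir::Value result = hlo_builder.GetValue(abs_op);
  (void)result;  // usable as an ordinary MLIR value from here on
}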
@@ -1,54 +0,0 @@
-// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU %s | FileCheck %s --dump-input-on-failure
-
-// INVALID_DEVICE: tf-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure
-
-module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
-
-// CHECK-LABEL: abs
-// expected-error@+1 {{unsupported device}}
-func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-  %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-
-  // return %[[RESULT]]
-  return %0 : tensor<2xf32>
-}
-
-// CHECK-LABEL: unknown_op
-func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: tf.CustomTestOp
-  // expected-remark@+1 {{constant 20}}
-  %0 = "tf.CustomTestOp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-
-  return %0 : tensor<2xf32>
-}
-
-// CHECK-LABEL: dynamic_operand
-func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: tf.Abs
-  // expected-remark@+1 {{lowering requires static shaped operands}}
-  %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
-
-  return %0 : tensor<?xf32>
-}
-
-// CHECK-LABEL: multiple_dialect_ops
-func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: xla_hlo.neg
-  %0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-  // CHECK: xla_hlo.abs
-  %1 = "tf.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32>
-
-  return %1 : tensor<2xf32>
-}
-
-// TODO(hinsu): Add a test with variant type once one of the ops supporting
-// the type is whitelisted. It should be rejected with unsupported type remark.
-
-// TODO(hinsu): Add a test with uint8 type once one of the ops supporting the
-// type is whitelisted. Unsigned types are not yet added to the HLO dialect so
-// it should return an error. See b/130356985
-
-// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
-// available but doesn't support this instance.
-}
@@ -1,388 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <string_view>
-#include <utility>
-#include <vector>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/memory/memory.h"
-#include "llvm/ADT/Optional.h"
-#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
-#include "mlir/IR/Function.h"  // TF:llvm-project
-#include "mlir/IR/Location.h"  // TF:llvm-project
-#include "mlir/IR/Module.h"  // TF:llvm-project
-#include "mlir/IR/Operation.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "mlir/IR/Types.h"  // TF:llvm-project
-#include "mlir/IR/Value.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
-#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
-#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
-#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
-#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
-#include "tensorflow/compiler/tf2xla/xla_context.h"
-#include "tensorflow/compiler/tf2xla/xla_expression.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/function.h"
-#include "tensorflow/core/framework/function.pb.h"
-#include "tensorflow/core/framework/node_properties.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/protobuf/config.pb.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
-#include "tensorflow/stream_executor/stream_executor.h"
-
-namespace mlir {
-namespace xla_hlo {
-namespace {
-
-template <typename T, size_t N>
-using InlinedVector = tensorflow::gtl::InlinedVector<T, N>;  // non-absl ok
-
-static bool IsOpWhitelisted(Operation* op) {
-  // White-listed TensorFlow ops are known to have well-behaved tf2xla kernels
-  // that build valid MLIR using MlirHloBuilder.
-  // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
-  // all tf2xla kernels.
-  return isa<TF::AbsOp>(op);
-}
-
-static llvm::Optional<absl::string_view> GetJitDevice(
-    const std::string& device_type, const Location& loc) {
-  if (device_type == "XLA_CPU") return absl::string_view("XLA_CPU_JIT");
-  if (device_type == "TPU") return absl::string_view("XLA_TPU_JIT");
-  // TODO(hinsu): Support GPU device along with a test for it.
-
-  emitError(loc) << "unsupported device for legalization with tf2xla kernels: "
-                 << device_type;
-  return llvm::None;
-}
-
-static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
-    const std::string& device_type, const Location& loc) {
-  auto jit_device_or = GetJitDevice(device_type, loc);
-  if (!jit_device_or) return nullptr;
-
-  auto* factory = tensorflow::DeviceFactory::GetFactory(device_type);
-  if (!factory) {
-    emitError(loc) << "failed to create DeviceFactory for device: "
-                   << device_type;
-    return nullptr;
-  }
-  std::vector<std::unique_ptr<tensorflow::Device>> devices;
-  auto status = factory->CreateDevices(
-      tensorflow::SessionOptions(),
-      /*name_prefix=*/"/job:localhost/replica:0/task:0", &devices);
-  if (!status.ok()) {
-    emitError(loc) << status.ToString();
-    return nullptr;
-  }
-
-  auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
-      tensorflow::SessionOptions(), tensorflow::DeviceType(*jit_device_or));
-  return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
-}
-
-class FuncLegalizer {
- public:
-  static LogicalResult Legalize(FuncOp func, const std::string& device_type) {
-    FuncLegalizer legalizer(func, device_type);
-    if (failed(legalizer.PrepareParams())) return failure();
-    return legalizer.Legalize();
-  }
-
- private:
-  FuncLegalizer(FuncOp func, const std::string& device_type)
-      : func_(func), device_type_(device_type), hlo_builder_(func) {}
-
-  ~FuncLegalizer() { context_->Unref(); }
-
-  // Prepares OpKernelContext params common to all the ops.
-  // Emits an error on failure.
-  LogicalResult PrepareParams();
-
-  // Tries to legalize supported TensorFlow ops.
-  // Emits an error on failure.
-  LogicalResult Legalize();
-
-  // Tries to legalize the specified TensorFlow op, if supported.
-  //
-  // Emits an error and returns failure if an error is encountered during
-  // conversion. Note that a success return value doesn't imply successful
-  // legalization.
-  LogicalResult LegalizeOp(Operation* op);
-
-  FuncOp func_;
-  std::string device_type_;
-
-  ::xla::MlirHloBuilder hlo_builder_;
-  tensorflow::OpOrArgLocNameMapper name_mapper_;
-
-  tensorflow::XlaContext* context_;  // Ref-counted.
-
-  std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
-  tensorflow::Device* device_;  // Owned by device_mgr_.
-  std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
-  std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
-  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
-  tensorflow::OpKernelContext::Params params_;
-};
-
-LogicalResult FuncLegalizer::PrepareParams() {
-  // The XlaCompiler within the context is only used by the functional ops to
-  // compile functions. We are not handling those at the moment so XlaCompiler
-  // is not required.
-  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_);
-  context_->Ref();
-
-  mlir::Location loc = func_.getLoc();
-  device_mgr_ = CreateDeviceMgr(device_type_, loc);
-  if (!device_mgr_) return failure();
-
-  // The type of params_.device is DeviceBase*, so store it as Device* to
-  // access derived-class methods.
-  device_ = device_mgr_->ListDevices().front();
-  params_.device = device_;
-  params_.resource_manager = device_->resource_manager();
-
-  // Resources are cleared at the time of device manager destruction, so pass
-  // a no-op cleanup function.
-  auto cleanup = [](const std::string& name) {};
-  // Use step_id zero as we only have a single context concurrently, and each
-  // concurrently running MLIR function creates a new device.
-  step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
-      /*step_id=*/0, cleanup);
-  tensorflow::Status status = step_container_->Create(
-      device_->resource_manager(),
-      tensorflow::XlaContext::kXlaContextResourceName, context_);
-  if (!status.ok()) {
-    emitError(loc) << "failed to create XlaContext resource: "
-                   << status.ToString();
-    return failure();
-  }
-  params_.step_container = step_container_.get();
-
-  tensorflow::StatusOr<int64_t> version_or =
-      tensorflow::GetTfGraphProducerVersion(
-          func_.getParentOfType<mlir::ModuleOp>());
-  if (!version_or.ok()) {
-    emitError(loc) << version_or.status().ToString();
-    return failure();
-  }
-
-  flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
-      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
-  pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
-      device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
-      version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
-  params_.function_library = pflr_->GetFLR(device_->name());
-  return success();
-}
-
-LogicalResult FuncLegalizer::Legalize() {
-  // TensorFlow functions don't use CFGs.
-  if (func_.getBlocks().size() > 1) {
-    emitError(func_.getLoc()) << "requires at most one block in a TF function";
-    return failure();
-  }
-  if (func_.getBlocks().empty()) return success();
-  Block& block = func_.getBlocks().front();
-
-  std::vector<Operation*> ops;
-  ops.reserve(block.getOperations().size());
-  for (Operation& op : block.getOperations()) {
-    ops.push_back(&op);
-  }
-
-  for (Operation* op : ops) {
-    if (failed(LegalizeOp(op))) return failure();
-  }
-  return success();
-}
-
-LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
-  if (!IsOpWhitelisted(op)) return success();
-
-  // Only static shaped operands are supported in XLA builders for now.
-  for (Type ty : op->getOperandTypes()) {
-    auto ranked_ty = ty.cast<RankedTensorType>();
-    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
-      op->emitRemark() << "lowering requires static shaped operands";
-      return success();
-    }
-  }
-
-  auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
-      op, name_mapper_.GetUniqueName(op), /*ignore_unregistered_attrs=*/true);
-  if (!nodedef_or.ok()) {
-    op->emitRemark() << "failed to convert op to NodeDef: "
-                     << nodedef_or.status().ToString();
-    return success();
-  }
-
-  std::shared_ptr<const tensorflow::NodeProperties> props;
-  tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
-      *nodedef_or.ValueOrDie(),
-      params_.function_library->GetFunctionLibraryDefinition(), &props);
-  if (!status.ok()) {
-    op->emitRemark() << "failed to create NodeProperties: "
-                     << status.ToString();
-    return success();
-  }
-  tensorflow::OpKernel* op_kernel_raw;
-  status = params_.function_library->CreateKernel(props, &op_kernel_raw);
-  if (!status.ok()) {
-    op->emitRemark() << "failed to create tf2xla kernel: " << status.ToString();
-    return success();
-  }
-  // Transfer ownership of the kernel to a local smart pointer.
-  auto op_kernel = absl::WrapUnique(op_kernel_raw);
-
-  // TensorValues in inputs are backed by tensors, which in turn depend on
-  // expressions. So, pre-allocate them to the required size.
-  InlinedVector<tensorflow::XlaExpression, 4> expressions;
-  InlinedVector<tensorflow::Tensor, 4> tensors;
-  InlinedVector<tensorflow::TensorValue, 4> inputs;
-  expressions.reserve(op->getNumOperands());
-  tensors.reserve(op->getNumOperands());
-  inputs.reserve(op->getNumOperands());
-
-  // Prepare the list of Tensor inputs for the kernel.
-  for (Value operand : op->getOperands()) {
-    // Skip this op if XLA doesn't support this operand type.
-    auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
-    if (!xla_op_or.ok()) {
-      op->emitRemark() << "skipping legalization due to "
-                       << xla_op_or.status().ToString();
-      return success();
-    }
-    ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
-
-    tensorflow::DataType dtype;
-    status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
-    if (!status.ok()) {
-      op->emitRemark() << "skipping legalization due to " << status.ToString();
-      return success();
-    }
-
-    auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype);
-    expressions.push_back(expression);
-
-    if (!tensorflow::DataTypeCanUseMemcpy(dtype)) {
-      op->emitRemark() << "skipping legalization due to unsupported type "
-                       << operand.getType();
-      return success();
-    }
-
-    auto shape_or = expression.GetShape();
-    if (!shape_or.ok()) {
-      op->emitRemark() << "failed to get shape for expression. "
-                       << expression.HumanString();
-      return success();
-    }
-
-    tensors.emplace_back(
-        device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype,
-        shape_or.ValueOrDie());
-    tensorflow::Tensor& tensor = tensors.back();
-    tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression,
-                                                             &tensor);
-    inputs.emplace_back(&tensor);
-  }
-
-  params_.inputs = &inputs;
-  params_.op_kernel = op_kernel.get();
-  llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
-      op->getNumResults());
-  params_.output_attr_array = output_attr.data();
-
-  hlo_builder_.setInsertionPoint(op);
-  hlo_builder_.SetLocation(op->getLoc());
-
-  // Execute the kernel.
-  tensorflow::OpKernelContext op_context(&params_, op->getNumResults());
-  device_->Compute(params_.op_kernel, &op_context);
-  if (!op_context.status().ok()) {
-    op->emitRemark() << "compilation to HLO failed: "
-                     << op_context.status().ToString();
-    return success();
-  }
-
-  // Replace uses of the old results with the corresponding values after the
-  // lowering.
-  for (int i = 0, e = op->getNumResults(); i < e; i++) {
-    tensorflow::Tensor* output = op_context.mutable_output(i);
-    const tensorflow::XlaExpression* expr =
-        tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output);
-    if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
-      return op->emitError(
-          "expects XlaExpression of kind kXlaOp in compiled output");
-    auto value = hlo_builder_.GetValue(expr->handle());
-    op->getResult(i).replaceAllUsesWith(value);
-  }
-
-  op->erase();
-  return success();
-}
-
-class LegalizeTF : public FunctionPass<LegalizeTF> {
- public:
-  LegalizeTF() = default;
-
-  LegalizeTF(const LegalizeTF&) {}
-
-  void runOnFunction() override {
-    if (failed(FuncLegalizer::Legalize(getFunction(), device_type_)))
-      signalPassFailure();
-  }
-
- private:
-  // TODO(hinsu): Support finer grained device type assignment instead of a
-  // global device type for all TensorFlow ops.
-  Option<std::string> device_type_{
-      *this, "device-type",
-      llvm::cl::desc("XLA device type for execution of TensorFlow ops. "
-                     "Supports XLA_CPU and TPU for now.")};
-};
-
-static PassRegistration<LegalizeTF> pass(
-    "xla-legalize-tf-with-tf2xla",
-    "Legalize from TensorFlow to the HLO dialect using tf2xla kernels");
-
-}  // end namespace
-
-}  // end namespace xla_hlo
-}  // end namespace mlir
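A note on the deleted pass (an illustration, not part of the commit): LegalizeOp round-trips each whitelisted op through its tf2xla kernel. MakeXlaOp wraps the MLIR operands as XlaOps, device_->Compute runs the kernel against MlirHloBuilder, and CastExpressionFromTensor recovers the produced XlaOps, whose MLIR values then replace the op's results. Widening coverage meant growing IsOpWhitelisted; a hypothetical two-op version is sketched below (TF::NegOp is illustrative only and was not whitelisted in this code):

static bool IsOpWhitelisted(mlir::Operation* op) {
  // Every op listed here must have a tf2xla kernel known to build valid
  // MLIR through MlirHloBuilder.
  return llvm::isa<mlir::TF::AbsOp>(op) || llvm::isa<mlir::TF::NegOp>(op);
}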
@@ -50,8 +50,7 @@ XlaCompiler* XlaOpKernelContext::compiler() const {
 }
 
 // Retrieves an XlaExpression that was allocated by a previous Op.
-const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
-    const Tensor& tensor) {
+static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
   CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
@@ -60,8 +59,8 @@ const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
 }
 
 // Assigns an XlaExpression to a tensor on an XLA compilation device.
-void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value,
-                                                  Tensor* tensor) {
+static void AssignExpressionToTensor(Tensor* tensor,
+                                     const XlaExpression& value) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
   CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
@@ -397,8 +396,7 @@ namespace {
 Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                                const XlaOpKernelContext* ctx,
                                TensorShape* shape, xla::XlaOp* value) {
-  const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+  const XlaExpression* expression = CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
@@ -488,8 +486,7 @@ void XlaOpKernelContext::SetOutputExpression(int index,
       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
     }
-    XlaOpKernelContext::AssignExpressionToTensor(
-        expression, context_->mutable_output(index));
+    AssignExpressionToTensor(context_->mutable_output(index), expression);
     return Status::OK();
   }();
   if (!status.ok()) {
@@ -539,8 +536,7 @@ namespace {
 Status AssignVariableTensor(const Tensor& tensor, DataType type,
                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
                             xla::XlaBuilder* builder) {
-  const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+  const XlaExpression* expression = CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
@@ -278,13 +278,6 @@ class XlaOpKernelContext {
   // separate specialization of the computation for each DataType.
   const xla::XlaComputation* GetOrCreateMul(const DataType type);
 
-  // Assigns an XlaExpression to a tensor on an XLA compilation device.
-  static void AssignExpressionToTensor(const XlaExpression& value,
-                                       Tensor* tensor);
-
-  // Retrieves an XlaExpression that was assigned to the specified tensor.
-  static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
-
  private:
   // Returns the tensor of input `name`.
   const Tensor& GetInputTensorByName(absl::string_view name);
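The two deleted declarations rely on a representation trick that the .cc hunks above also show: on the XLA compilation device, a Tensor's backing buffer holds an XlaExpression rather than numeric data, so assignment and retrieval are just reinterpretations of tensor_data(). A simplified sketch of the pair (the real versions also CHECK the expression kind to catch double writes and uninitialized reads):

static void AssignExpressionToTensor(Tensor* tensor,
                                     const XlaExpression& value) {
  // The buffer was sized by XlaCompilationDevice to hold an XlaExpression.
  const XlaExpression* expression =
      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
  *const_cast<XlaExpression*>(expression) = value;
}

static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
  return reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
}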
@@ -95,7 +95,6 @@ class XlaOp {
   int64 handle() const { return handle_; }
 
   friend class XlaBuilder;
-  friend class MlirHloBuilder;
 
   // < 0 means "invalid handle".
   int64 handle_;
@@ -140,7 +139,7 @@ class XlaBuilder {
   XlaBuilder(const XlaBuilder&) = delete;
   XlaBuilder& operator=(const XlaBuilder&) = delete;
 
-  virtual ~XlaBuilder();
+  ~XlaBuilder();
 
   // Returns the computation name.
   const string& name() const { return name_; }
@@ -278,7 +277,7 @@ class XlaBuilder {
   StatusOr<Shape> GetShape(XlaOp op) const;
 
   // Returns the shape of the given op.
-  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;
+  StatusOr<const Shape*> GetShapePtr(XlaOp op) const;
 
   // Returns the (inferred) result for the current computation's shape. This
   // assumes the root instruction is the last added instruction.
@@ -646,7 +645,7 @@ class XlaBuilder {
   StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
 
   // Internal helper method that does the building for an arbitrary unary op.
-  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);
+  XlaOp UnaryOp(HloOpcode unop, XlaOp operand);
 
   // Internal helper method that does the building for an arbitrary binary op.
   // broadcast_dimensions specifies which dimensions to use for broadcasting
@@ -1057,17 +1056,16 @@ class XlaBuilder {
   friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
 
- protected:
-  // Returns OK status if the given op was built using this builder. Otherwise,
-  // returns an error.
-  Status CheckOpBuilder(XlaOp op) const;
-
  private:
   XlaOp ConditionalImpl(
       XlaOp branch_index,
       absl::Span<const XlaComputation* const> branch_computations,
      absl::Span<const XlaOp> branch_operands);
 
+  // Returns OK status if the given op was built using this builder. Otherwise,
+  // returns an error.
+  Status CheckOpBuilder(XlaOp op) const;
+
   // Here, InstructionType is either const HloInstructionProto* or non-const
   // HloInstructionProto*.
   template <typename InstructionType>