Internal change

PiperOrigin-RevId: 299514998
Change-Id: I079530f7982fd945137f253d3286319b224bfd24
A. Unique TensorFlower 2020-03-07 00:02:33 -08:00 committed by TensorFlower Gardener
parent c3d14c434b
commit edd36f52a3
9 changed files with 13 additions and 715 deletions

View File

@@ -81,7 +81,6 @@ cc_library(
         "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
         "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
         "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
         "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
         "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
         "//tensorflow/compiler/mlir/xla:xla_lower",

View File

@@ -132,43 +132,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "xla_legalize_tf_with_tf2xla",
-    srcs = [
-        "transforms/legalize_tf_with_tf2xla.cc",
-    ],
-    deps = [
-        ":hlo",
-        ":mlir_hlo_builder",
-        "//tensorflow/compiler/jit:xla_cpu_device",
-        "//tensorflow/compiler/jit:xla_cpu_jit",
-        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow:convert_type",
-        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
-        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow:translate_utils",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/core:core_cpu_lib",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:session_options",
-        "//tensorflow/stream_executor:timer",
-        "//tensorflow/stream_executor/lib",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/memory",
-        "@llvm-project//llvm:support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Support",
-    ],
-    alwayslink = 1,
-)
-
 cc_library(
     name = "map_xla_to_scalar_op",
     hdrs = ["transforms/map_xla_to_scalar_op.h"],
@@ -462,31 +425,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "mlir_hlo_builder",
-    srcs = [
-        "ir/mlir_hlo_builder.cc",
-    ],
-    hdrs = [
-        "ir/mlir_hlo_builder.h",
-    ],
-    deps = [
-        ":hlo",
-        ":hlo_utils",
-        ":type_to_shape",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/service:hlo",
-        "//tensorflow/compiler/xla/service:shape_inference",
-        "//tensorflow/core/platform:types",
-        "//tensorflow/stream_executor/lib",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@llvm-project//llvm:support",
-        "@llvm-project//mlir:IR",
-    ],
-)
-
 cc_library(
     name = "lhlo",
     srcs = [

View File

@@ -1,84 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
namespace xla {
static std::string GetMlirOpName(HloOpcode opcode) {
std::string op_name = HloOpcodeString(opcode);
absl::c_replace(op_name, '-', '_');
return mlir::xla_hlo::XlaHloDialect::getDialectNamespace().str() + "." +
op_name;
}
static std::string ToString(mlir::Type ty) {
std::string str;
llvm::raw_string_ostream sstream(str);
ty.print(sstream);
sstream.flush();
return str;
}
MlirHloBuilder::~MlirHloBuilder() = default;
StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
mlir::Type ty = val.getType();
auto shape = std::make_unique<Shape>(TypeToShape(ty));
if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
}
int64 handle = reinterpret_cast<int64>(val.getAsOpaquePointer());
handle_to_shape_[handle] = std::move(shape);
return XlaOp(handle, this);
}
XlaOp MlirHloBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
mlir::Value value = GetValue(operand);
mlir::OperationState state(loc_, GetMlirOpName(unop));
state.addOperands(value);
TF_ASSIGN_OR_RETURN(
mlir::Type ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
state.addTypes(ty);
mlir::Operation* op = builder_.createOperation(state);
return MakeXlaOp(op->getResult(0));
});
}
StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
TF_RETURN_IF_ERROR(first_error());
TF_RETURN_IF_ERROR(CheckOpBuilder(op));
auto it = handle_to_shape_.find(op.handle());
if (it == handle_to_shape_.end()) {
return InvalidArgument("No XlaOp with handle %d", op.handle());
}
return it->second.get();
}
} // namespace xla
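The MakeXlaOp / GetValue pair above works because XlaOp identifies values by an int64 handle, which on the 64-bit platforms this code targets is wide enough to hold a pointer: MakeXlaOp bit-casts the mlir::Value's opaque pointer into the handle, and GetValue casts it back. A standalone sketch of that round-trip, with a toy FakeValue type standing in for the real MLIR classes:

#include <cstdint>
#include <iostream>

struct FakeValue {  // hypothetical stand-in for mlir::Value's storage
  const char* name;
};

// Pointer -> handle, as in MakeXlaOp (assumes pointers fit in 64 bits).
int64_t MakeHandle(FakeValue* val) {
  return reinterpret_cast<int64_t>(val);
}

// Handle -> pointer, as in GetValue.
FakeValue* GetValue(int64_t handle) {
  return reinterpret_cast<FakeValue*>(handle);
}

int main() {
  FakeValue v{"%0 = xla_hlo.abs"};
  int64_t handle = MakeHandle(&v);
  std::cout << GetValue(handle)->name << "\n";  // prints the original value
  return 0;
}

The handle doubles as the key of handle_to_shape_, so GetShapePtr can answer shape queries without touching MLIR again.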

View File

@@ -1,100 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
// Provides a way to construct xla_hlo dialect ops in MLIR using the
// XlaBuilder interface.
//
// Requires that all XlaOp arguments are either returned by one of this
// builder's methods or constructed using its MakeXlaOp method.
//
// TODO(hinsu): Support more ops and utility functions to set special attributes
// like OpMetadata and Sharding.
class MlirHloBuilder : public XlaBuilder {
public:
// Constructs a builder for the given function. New operations are added at
// the beginning of the function, if it is non-empty and has a block.
explicit MlirHloBuilder(mlir::FuncOp func)
: XlaBuilder(func.getName().str()),
builder_(&func.getBody()),
loc_(builder_.getUnknownLoc()) {}
// TODO(hinsu): Add a constructor to build a new MLIR function from scratch
// and override Build methods.
MlirHloBuilder(const MlirHloBuilder&) = delete;
MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
~MlirHloBuilder() override;
// Wraps the given MLIR value under an XlaOp instance. Note that all HLO
// operations return exactly one result, so each op has an XlaOp wrapping
// its result.
//
// Returns an error if the HLO dialect doesn't support the type of the
// given value.
StatusOr<XlaOp> MakeXlaOp(mlir::Value val);
// Returns the value corresponding to the given op.
//
// Requires that the op was created by this builder.
mlir::Value GetValue(XlaOp op) {
void* ptr = reinterpret_cast<void*>(op.handle());
return mlir::Value::getFromOpaquePointer(ptr);
}
// Sets location for newly built ops, until reset.
void SetLocation(mlir::Location loc) { loc_ = loc; }
// Updates the insertion point so that newly built ops are inserted, in
// order, before the given op, until reset.
void setInsertionPoint(mlir::Operation* op) {
builder_.setInsertionPoint(op);
}
// Returns the shape of the given op.
StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;
private:
XlaOp UnaryOp(HloOpcode unop, XlaOp operand) override;
mlir::OpBuilder builder_;
mlir::Location loc_;
absl::flat_hash_map<int64, std::unique_ptr<Shape>> handle_to_shape_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
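For orientation, here is a hypothetical driver (EmitAbs is not part of the deleted code) showing how this builder was meant to be used: callers keep calling the ordinary XlaBuilder client API, and the overridden hooks reroute op construction into the MLIR function.

#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sketch only; assumes the in-tree APIs at this revision.
xla::StatusOr<mlir::Value> EmitAbs(mlir::FuncOp func, mlir::Value operand) {
  xla::MlirHloBuilder builder(func);  // new ops land in func's body
  xla::StatusOr<xla::XlaOp> arg = builder.MakeXlaOp(operand);
  if (!arg.ok()) return arg.status();
  xla::XlaOp result = xla::Abs(arg.ValueOrDie());  // dispatches to UnaryOp
  // Builder errors are deferred; GetShapePtr surfaces any failure.
  auto shape_or = builder.GetShapePtr(result);
  if (!shape_or.ok()) return shape_or.status();
  return builder.GetValue(result);  // unwrap back to an mlir::Value
}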

View File

@@ -1,54 +0,0 @@
// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU %s | FileCheck %s --dump-input-on-failure
// INVALID_DEVICE: tf-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: abs
// expected-error@+1 {{unsupported device}}
func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: return %[[RESULT]]
return %0 : tensor<2xf32>
}
// CHECK-LABEL: unknown_op
func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: tf.CustomTestOp
// expected-remark@+1 {{constant 20}}
%0 = "tf.CustomTestOp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: dynamic_operand
func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: tf.Abs
// expected-remark@+1 {{lowering requires static shaped operands}}
%0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: multiple_dialect_ops
func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: xla_hlo.neg
%0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: xla_hlo.abs
%1 = "tf.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32>
}
// TODO(hinsu): Add a test with variant type once one of the ops supporting
// the type is whitelisted. It should be rejected with unsupported type remark.
// TODO(hinsu): Add a test with uint8 type once one of the ops supporting the
// type is whitelisted. Unsigned types are not yet added to the HLO dialect so
// it should return an error. See b/130356985
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@@ -1,388 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace mlir {
namespace xla_hlo {
namespace {
template <typename T, size_t N>
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok
static bool IsOpWhitelisted(Operation* op) {
// Whitelisted TensorFlow ops are known to have well-behaved tf2xla kernels
// that build valid MLIR using MlirHloBuilder.
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
// all tf2xla kernels.
return isa<TF::AbsOp>(op);
}
static llvm::Optional<absl::string_view> GetJitDevice(
const std::string& device_type, const Location& loc) {
if (device_type == "XLA_CPU") return absl::string_view("XLA_CPU_JIT");
if (device_type == "TPU") return absl::string_view("XLA_TPU_JIT");
// TODO(hinsu): Support GPU device along with a test for it.
emitError(loc) << "unsupported device for legalization with tf2xla kernels: "
<< device_type;
return llvm::None;
}
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
const std::string& device_type, const Location& loc) {
auto jit_device_or = GetJitDevice(device_type, loc);
if (!jit_device_or) return nullptr;
auto* factory = tensorflow::DeviceFactory::GetFactory(device_type);
if (!factory) {
emitError(loc) << "failed to create DeviceFactory for device: "
<< device_type;
return nullptr;
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
auto status = factory->CreateDevices(
tensorflow::SessionOptions(),
/*name_prefix=*/"/job:localhost/replica:0/task:0", &devices);
if (!status.ok()) {
emitError(loc) << status.ToString();
return nullptr;
}
auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
tensorflow::SessionOptions(), tensorflow::DeviceType(*jit_device_or));
return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
}
class FuncLegalizer {
public:
static LogicalResult Legalize(FuncOp func, const std::string& device_type) {
FuncLegalizer legalizer(func, device_type);
if (failed(legalizer.PrepareParams())) return failure();
return legalizer.Legalize();
}
private:
FuncLegalizer(FuncOp func, const std::string& device_type)
: func_(func), device_type_(device_type), hlo_builder_(func) {}
~FuncLegalizer() { context_->Unref(); }
// Prepares OpKernelContext params common to all the ops.
// Emits an error on failure.
LogicalResult PrepareParams();
// Tries to legalize supported TensorFlow ops.
// Emits an error on failure.
LogicalResult Legalize();
// Tries to legalize the specified TensorFlow op, if supported.
//
// Emits an error and returns failure if an error is encountered during
// conversion. Note that a success return value doesn't mean the op was
// legalized; unsupported ops are skipped and left untouched.
LogicalResult LegalizeOp(Operation* op);
FuncOp func_;
std::string device_type_;
::xla::MlirHloBuilder hlo_builder_;
tensorflow::OpOrArgLocNameMapper name_mapper_;
tensorflow::XlaContext* context_; // Ref-counted.
std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
tensorflow::Device* device_; // Owned by device_mgr_;
std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
tensorflow::OpKernelContext::Params params_;
};
LogicalResult FuncLegalizer::PrepareParams() {
// XlaCompiler within the context is only used by the functional ops to
// compile functions. We are not handling those at the moment so XlaCompiler
// is not required.
context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_);
context_->Ref();
mlir::Location loc = func_.getLoc();
device_mgr_ = CreateDeviceMgr(device_type_, loc);
if (!device_mgr_) return failure();
// The type of params_.device is DeviceBase*, so store the pointer as
// Device* to access derived-class methods.
device_ = device_mgr_->ListDevices().front();
params_.device = device_;
params_.resource_manager = device_->resource_manager();
// Resources are cleared when the device manager is destroyed, so pass a
// no-op cleanup function.
auto cleanup = [](const std::string& name) {};
// Use step_id zero as we only have a single context at a time, and each
// concurrently running MLIR function creates its own device.
step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
/*step_id=*/0, cleanup);
tensorflow::Status status = step_container_->Create(
device_->resource_manager(),
tensorflow::XlaContext::kXlaContextResourceName, context_);
if (!status.ok()) {
emitError(loc) << "failed to create XlaContext resource: "
<< status.ToString();
return failure();
}
params_.step_container = step_container_.get();
tensorflow::StatusOr<int64_t> version_or =
tensorflow::GetTfGraphProducerVersion(
func_.getParentOfType<mlir::ModuleOp>());
if (!version_or.ok()) {
emitError(loc) << version_or.status().ToString();
return failure();
}
flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
params_.function_library = pflr_->GetFLR(device_->name());
return success();
}
LogicalResult FuncLegalizer::Legalize() {
// TensorFlow functions don't use CFGs.
if (func_.getBlocks().size() > 1) {
emitError(func_.getLoc()) << "requires at most one block in a TF function";
return failure();
}
if (func_.getBlocks().empty()) return success();
Block& block = func_.getBlocks().front();
std::vector<Operation*> ops;
ops.reserve(block.getOperations().size());
for (Operation& op : block.getOperations()) {
ops.push_back(&op);
}
for (Operation* op : ops) {
if (failed(LegalizeOp(op))) return failure();
}
return success();
}
LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
if (!IsOpWhitelisted(op)) return success();
// Only statically shaped operands are supported by the XLA builders for now.
for (Type ty : op->getOperandTypes()) {
// dyn_cast, not cast: unranked types must reach the check below instead of
// asserting.
auto ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty || !ranked_ty.hasStaticShape()) {
op->emitRemark() << "lowering requires static shaped operands";
return success();
}
}
auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
op, name_mapper_.GetUniqueName(op), /*ignore_unregistered_attrs=*/true);
if (!nodedef_or.ok()) {
op->emitRemark() << "failed to convert op to NodeDef: "
<< nodedef_or.status().ToString();
return success();
}
std::shared_ptr<const tensorflow::NodeProperties> props;
tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
*nodedef_or.ValueOrDie(),
params_.function_library->GetFunctionLibraryDefinition(), &props);
if (!status.ok()) {
op->emitRemark() << "failed to create NodeProperties: "
<< status.ToString();
return success();
}
tensorflow::OpKernel* op_kernel_raw;
status = params_.function_library->CreateKernel(props, &op_kernel_raw);
if (!status.ok()) {
op->emitRemark() << "failed to create tf2xla kernel: " << status.ToString();
return success();
}
// Transfer ownership of the kernel to a local smart pointer.
auto op_kernel = absl::WrapUnique(op_kernel_raw);
// TensorValues in `inputs` are backed by tensors, which in turn depend on
// expressions. Pre-allocate all three containers to the required size so
// that later push_backs never reallocate and invalidate the pointers taken
// to earlier elements.
InlinedVector<tensorflow::XlaExpression, 4> expressions;
InlinedVector<tensorflow::Tensor, 4> tensors;
InlinedVector<tensorflow::TensorValue, 4> inputs;
expressions.reserve(op->getNumOperands());
tensors.reserve(op->getNumOperands());
inputs.reserve(op->getNumOperands());
// Prepare the list of Tensor inputs for the kernel.
for (Value operand : op->getOperands()) {
// Skip this op if XLA doesn't support this operand type.
auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
if (!xla_op_or.ok()) {
op->emitRemark() << "skipping legalization due to "
<< xla_op_or.status().ToString();
return success();
}
::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
tensorflow::DataType dtype;
status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
if (!status.ok()) {
op->emitRemark() << "skipping legalization due to " << status.ToString();
return success();
}
auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype);
expressions.push_back(expression);
if (!tensorflow::DataTypeCanUseMemcpy(dtype)) {
op->emitRemark() << "skipping legalization due to unsupported type "
<< operand.getType();
return success();
}
auto shape_or = expression.GetShape();
if (!shape_or.ok()) {
op->emitRemark() << "failed to get shape for expression. "
<< expression.HumanString();
return success();
}
tensors.emplace_back(
device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype,
shape_or.ValueOrDie());
tensorflow::Tensor& tensor = tensors.back();
tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression,
&tensor);
inputs.emplace_back(&tensor);
}
params_.inputs = &inputs;
params_.op_kernel = op_kernel.get();
llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
op->getNumResults());
params_.output_attr_array = output_attr.data();
hlo_builder_.setInsertionPoint(op);
hlo_builder_.SetLocation(op->getLoc());
// Execute the kernel.
tensorflow::OpKernelContext op_context(&params_, op->getNumResults());
device_->Compute(params_.op_kernel, &op_context);
if (!op_context.status().ok()) {
op->emitRemark() << "compilation to HLO failed: "
<< op_context.status().ToString();
return success();
}
// Replace uses of the old results with the corresponding values produced
// by the lowering.
for (int i = 0, e = op->getNumResults(); i < e; i++) {
tensorflow::Tensor* output = op_context.mutable_output(i);
const tensorflow::XlaExpression* expr =
tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output);
if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
return op->emitError(
"expects XlaExpression of kind kXlaOp in compiled output");
auto value = hlo_builder_.GetValue(expr->handle());
op->getResult(i).replaceAllUsesWith(value);
}
op->erase();
return success();
}
class LegalizeTF : public FunctionPass<LegalizeTF> {
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF&) {}
void runOnFunction() override {
if (failed(FuncLegalizer::Legalize(getFunction(), device_type_)))
signalPassFailure();
}
private:
// TODO(hinsu): Support finer grained device type assignment instead of a
// global device type for all TensorFlow ops.
Option<std::string> device_type_{
*this, "device-type",
llvm::cl::desc("XLA device type for execution of TensorFlow ops. "
"Supports XLA_CPU and TPU for now.")};
};
static PassRegistration<LegalizeTF> pass(
"xla-legalize-tf-with-tf2xla",
"Legalize from TensorFlow to the HLO dialect using tf2xla kernels");
} // end namespace
} // end namespace xla_hlo
} // end namespace mlir
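One detail of LegalizeOp worth spelling out: the three reserve() calls are load-bearing, since `inputs` holds raw Tensor* pointers into `tensors`, and a reallocation of `tensors` mid-loop would leave those pointers dangling. A standalone sketch of the pattern with plain std::vector (toy types, not TensorFlow code):

#include <cassert>
#include <vector>

struct Tensor { int payload; };
struct TensorValue { Tensor* tensor; };  // non-owning, like TF's TensorValue

int main() {
  const int num_operands = 4;
  std::vector<Tensor> tensors;
  std::vector<TensorValue> inputs;
  tensors.reserve(num_operands);  // guarantees no reallocation below
  inputs.reserve(num_operands);
  for (int i = 0; i < num_operands; ++i) {
    tensors.push_back(Tensor{i});
    inputs.push_back(TensorValue{&tensors.back()});  // pointer stays valid
  }
  for (int i = 0; i < num_operands; ++i) {
    assert(inputs[i].tensor->payload == i);  // no dangling pointers
  }
  return 0;
}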

View File

@@ -50,8 +50,7 @@ XlaCompiler* XlaOpKernelContext::compiler() const {
 }
 
 // Retrieves an XlaExpression that was allocated by a previous Op.
-const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
-    const Tensor& tensor) {
+static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
   CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
@@ -60,8 +59,8 @@ const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
 }
 
 // Assigns an XlaExpression to a tensor on an XLA compilation device.
-void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value,
-                                                  Tensor* tensor) {
+static void AssignExpressionToTensor(Tensor* tensor,
+                                     const XlaExpression& value) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
   CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
@@ -397,8 +396,7 @@ namespace {
 Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                                const XlaOpKernelContext* ctx,
                                TensorShape* shape, xla::XlaOp* value) {
-  const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+  const XlaExpression* expression = CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
@@ -488,8 +486,7 @@ void XlaOpKernelContext::SetOutputExpression(int index,
       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
     }
-    XlaOpKernelContext::AssignExpressionToTensor(
-        expression, context_->mutable_output(index));
+    AssignExpressionToTensor(context_->mutable_output(index), expression);
     return Status::OK();
   }();
   if (!status.ok()) {
@@ -539,8 +536,7 @@ namespace {
 Status AssignVariableTensor(const Tensor& tensor, DataType type,
                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
                             xla::XlaBuilder* builder) {
-  const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+  const XlaExpression* expression = CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
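Both helpers rely on the same trick: tensors on the XLA compilation device are allocated with a buffer big enough to hold an XlaExpression, so the expression object lives inside the tensor's data and is recovered with a reinterpret_cast. A standalone sketch with a toy Expression type (not TensorFlow code):

#include <cassert>
#include <new>
#include <string>
#include <vector>

struct Expression {  // hypothetical stand-in for XlaExpression
  int kind = 0;      // 0 == kInvalid, mirroring the CHECKs above
  std::string repr;
};

int main() {
  // The "tensor buffer", sized to hold an Expression, like the buffers the
  // XLA compilation device hands out.
  std::vector<char> buffer(sizeof(Expression));
  new (buffer.data()) Expression();  // default-construct in place

  // Assign: cast the raw bytes and write through the pointer.
  auto* slot = reinterpret_cast<Expression*>(buffer.data());
  assert(slot->kind == 0);  // must still be invalid, as the CHECK insists
  *slot = Expression{1, "xla_op #42"};

  // Retrieve: cast the same bytes back.
  const auto* read = reinterpret_cast<const Expression*>(buffer.data());
  assert(read->kind != 0 && read->repr == "xla_op #42");

  slot->~Expression();  // manual destruction for the placement-new'd object
  return 0;
}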

View File

@@ -278,13 +278,6 @@ class XlaOpKernelContext {
   // separate specialization of the computation for each DataType.
   const xla::XlaComputation* GetOrCreateMul(const DataType type);
 
-  // Assigns an XlaExpression to a tensor on an XLA compilation device.
-  static void AssignExpressionToTensor(const XlaExpression& value,
-                                       Tensor* tensor);
-
-  // Retrieves an XlaExpression that was assigned to the specified tensor.
-  static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
-
 private:
   // Returns the tensor of input `name`.
   const Tensor& GetInputTensorByName(absl::string_view name);

View File

@@ -95,7 +95,6 @@ class XlaOp {
   int64 handle() const { return handle_; }
 
   friend class XlaBuilder;
-  friend class MlirHloBuilder;
 
   // < 0 means "invalid handle".
   int64 handle_;
@@ -140,7 +139,7 @@ class XlaBuilder {
   XlaBuilder(const XlaBuilder&) = delete;
   XlaBuilder& operator=(const XlaBuilder&) = delete;
 
-  virtual ~XlaBuilder();
+  ~XlaBuilder();
 
   // Returns the computation name.
   const string& name() const { return name_; }
@@ -278,7 +277,7 @@ class XlaBuilder {
   StatusOr<Shape> GetShape(XlaOp op) const;
 
   // Returns the shape of the given op.
-  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;
+  StatusOr<const Shape*> GetShapePtr(XlaOp op) const;
 
   // Returns the (inferred) result for the current computation's shape. This
   // assumes the root instruction is the last added instruction.
@@ -646,7 +645,7 @@ class XlaBuilder {
   StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
 
   // Internal helper method that does the building for an arbitrary unary op.
-  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);
+  XlaOp UnaryOp(HloOpcode unop, XlaOp operand);
 
   // Internal helper method that does the building for an arbitrary binary op.
   // broadcast_dimensions specifies which dimensions to use for broadcasting
@@ -1057,17 +1056,16 @@
   friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
 
- protected:
-  // Returns OK status if the given op was built using this builder. Otherwise,
-  // returns an error.
-  Status CheckOpBuilder(XlaOp op) const;
-
 private:
   XlaOp ConditionalImpl(
       XlaOp branch_index,
       absl::Span<const XlaComputation* const> branch_computations,
       absl::Span<const XlaOp> branch_operands);
 
+  // Returns OK status if the given op was built using this builder. Otherwise,
+  // returns an error.
+  Status CheckOpBuilder(XlaOp op) const;
+
   // Here, InstructionType is either const HloInstructionProto* or non-const
   // HloInstructionProto*.
   template <typename InstructionType>
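The virtual qualifiers removed in this file were the hook that made MlirHloBuilder possible: XlaBuilder's client entry points funnel into internal helpers such as UnaryOp, and marking those helpers virtual let a subclass replace op construction wholesale. A simplified sketch of the pattern (a member method stands in for XlaBuilder's friend free functions):

#include <iostream>
#include <string>

class Builder {
 public:
  virtual ~Builder() = default;
  // Public entry point, analogous to xla::Abs.
  std::string Abs(const std::string& x) { return UnaryOp("abs", x); }

 protected:
  // The helper this commit de-virtualizes; without `virtual`, a subclass
  // cannot intercept op construction.
  virtual std::string UnaryOp(const std::string& op, const std::string& x) {
    return "proto:" + op + "(" + x + ")";
  }
};

class MlirBuilder : public Builder {
 protected:
  std::string UnaryOp(const std::string& op, const std::string& x) override {
    return "mlir:xla_hlo." + op + "(" + x + ")";
  }
};

int main() {
  MlirBuilder b;
  Builder& base = b;  // callers only see the base interface
  std::cout << base.Abs("%arg0") << "\n";  // prints the MLIR-flavored result
  return 0;
}

With the override deleted along with the virtual keywords, every builder call once again produces HloInstructionProtos only.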