From edd36f52a31fa60f42387c8ee4d0d12ac9fd4eac Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Sat, 7 Mar 2020 00:02:33 -0800
Subject: [PATCH] Internal change

PiperOrigin-RevId: 299514998
Change-Id: I079530f7982fd945137f253d3286319b224bfd24
---
 tensorflow/compiler/mlir/BUILD                |   1 -
 tensorflow/compiler/mlir/xla/BUILD            |  62 ---
 .../compiler/mlir/xla/ir/mlir_hlo_builder.cc  |  84 ----
 .../compiler/mlir/xla/ir/mlir_hlo_builder.h   | 100 -----
 .../xla/tests/legalize-tf-with-tf2xla.mlir    |  54 ---
 .../xla/transforms/legalize_tf_with_tf2xla.cc | 388 ------------------
 tensorflow/compiler/tf2xla/xla_op_kernel.cc   |  16 +-
 tensorflow/compiler/tf2xla/xla_op_kernel.h    |   7 -
 tensorflow/compiler/xla/client/xla_builder.h  |  16 +-
 9 files changed, 13 insertions(+), 715 deletions(-)
 delete mode 100644 tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
 delete mode 100644 tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
 delete mode 100644 tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
 delete mode 100644 tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc

diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index ef4ddb619a8..22e665cc8ce 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -81,7 +81,6 @@ cc_library(
         "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
         "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
         "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
        "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
        "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
        "//tensorflow/compiler/mlir/xla:xla_lower",
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 9e30992a7ac..5a05a0f60f3 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -132,43 +132,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "xla_legalize_tf_with_tf2xla",
-    srcs = [
-        "transforms/legalize_tf_with_tf2xla.cc",
-    ],
-    deps = [
-        ":hlo",
-        ":mlir_hlo_builder",
-        "//tensorflow/compiler/jit:xla_cpu_device",
-        "//tensorflow/compiler/jit:xla_cpu_jit",
-        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow:convert_type",
-        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
-        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_inc_gen",
-        "//tensorflow/compiler/mlir/tensorflow:translate_utils",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/core:core_cpu_lib",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:session_options",
-        "//tensorflow/stream_executor:timer",
-        "//tensorflow/stream_executor/lib",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/memory",
-        "@llvm-project//llvm:support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Support",
-    ],
-    alwayslink = 1,
-)
-
 cc_library(
     name = "map_xla_to_scalar_op",
     hdrs = ["transforms/map_xla_to_scalar_op.h"],
@@ -462,31 +425,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "mlir_hlo_builder",
-    srcs = [
-        "ir/mlir_hlo_builder.cc",
-    ],
-    hdrs = [
-        "ir/mlir_hlo_builder.h",
-    ],
-    deps = [
-        ":hlo",
-        ":hlo_utils",
-        ":type_to_shape",
"//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/core/platform:types", - "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/container:flat_hash_map", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - ], -) - cc_library( name = "lhlo", srcs = [ diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc deleted file mode 100644 index 79c2dc71b9b..00000000000 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h" - -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/StandardTypes.h" // TF:llvm-project -#include "tensorflow/compiler/mlir/xla/hlo_utils.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" -#include "tensorflow/compiler/mlir/xla/type_to_shape.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" - -namespace xla { - -static std::string GetMlirOpName(HloOpcode opcode) { - std::string op_name = HloOpcodeString(opcode); - absl::c_replace(op_name, '-', '_'); - return mlir::xla_hlo::XlaHloDialect::getDialectNamespace().str() + "." 
-         op_name;
-}
-
-static std::string ToString(mlir::Type ty) {
-  std::string str;
-  llvm::raw_string_ostream sstream(str);
-  ty.print(sstream);
-  sstream.flush();
-  return str;
-}
-
-MlirHloBuilder::~MlirHloBuilder() = default;
-
-StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
-  mlir::Type ty = val.getType();
-  auto shape = std::make_unique<Shape>(TypeToShape(ty));
-  if (shape->element_type() == PrimitiveType::PRIMITIVE_TYPE_INVALID) {
-    return InvalidArgument("unsupported type: %s", ToString(ty).c_str());
-  }
-
-  int64 handle = reinterpret_cast<int64>(val.getAsOpaquePointer());
-  handle_to_shape_[handle] = std::move(shape);
-  return XlaOp(handle, this);
-}
-
-XlaOp MlirHloBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
-  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
-    TF_ASSIGN_OR_RETURN(
-        Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
-
-    mlir::Value value = GetValue(operand);
-    mlir::OperationState state(loc_, GetMlirOpName(unop));
-    state.addOperands(value);
-    TF_ASSIGN_OR_RETURN(
-        mlir::Type ty,
-        ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
-    state.addTypes(ty);
-    mlir::Operation* op = builder_.createOperation(state);
-    return MakeXlaOp(op->getResult(0));
-  });
-}
-
-StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
-  TF_RETURN_IF_ERROR(first_error());
-  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
-  auto it = handle_to_shape_.find(op.handle());
-  if (it == handle_to_shape_.end()) {
-    return InvalidArgument("No XlaOp with handle %d", op.handle());
-  }
-  return it->second.get();
-}
-
-}  // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
deleted file mode 100644
index 0061dc1d68b..00000000000
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ /dev/null
@@ -1,100 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
-
-#include <memory>
-
-#include "absl/container/flat_hash_map.h"
-#include "llvm/ADT/StringRef.h"
-#include "mlir/IR/Builders.h"  // TF:llvm-project
-#include "mlir/IR/Function.h"  // TF:llvm-project
-#include "mlir/IR/Location.h"  // TF:llvm-project
-#include "mlir/IR/Operation.h"  // TF:llvm-project
-#include "mlir/IR/Value.h"  // TF:llvm-project
-#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/shape.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
-
-namespace xla {
-
-// Provides a way to construct xla_hlo dialect ops in MLIR using the
-// XlaBuilder interface.
-//
-// Requires that all XlaOp arguments are either returned by one of the builder
-// methods or constructed using the MakeXlaOp method in this builder.
-//
-// TODO(hinsu): Support more ops and utility functions to set special
-// attributes like OpMetadata and Sharding.
-class MlirHloBuilder : public XlaBuilder {
- public:
-  // Constructs a builder for the given function. New operations are added to
-  // the beginning of the function, if it is non-empty and has a block.
-  explicit MlirHloBuilder(mlir::FuncOp func)
-      : XlaBuilder(func.getName().str()),
-        builder_(&func.getBody()),
-        loc_(builder_.getUnknownLoc()) {}
-
-  // TODO(hinsu): Add a constructor to build a new MLIR function from scratch
-  // and override Build methods.
-
-  MlirHloBuilder(const MlirHloBuilder&) = delete;
-  MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
-
-  ~MlirHloBuilder() override;
-
-  // Wraps the given MLIR value under an XlaOp instance. Note that all HLO
-  // operations return exactly one result, so each op has an XlaOp wrapping
-  // the result of the op.
-  //
-  // Returns an error if the HLO dialect doesn't support the type of the given
-  // value.
-  StatusOr<XlaOp> MakeXlaOp(mlir::Value val);
-
-  // Returns the value corresponding to the given op.
-  //
-  // Requires that the op was created by this builder.
-  mlir::Value GetValue(XlaOp op) {
-    void* ptr = reinterpret_cast<void*>(op.handle());
-    return mlir::Value::getFromOpaquePointer(ptr);
-  }
-
-  // Sets the location for newly built ops, until reset.
-  void SetLocation(mlir::Location loc) { loc_ = loc; }
-
-  // Updates the insertion point so that newly built ops are inserted before
-  // the given op, until reset.
-  void setInsertionPoint(mlir::Operation* op) {
-    builder_.setInsertionPoint(op);
-  }
-
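For context, the MakeXlaOp/GetValue pair above round-trips an mlir::Value through XlaOp's integer handle. A minimal standalone sketch of that pattern (the names below are illustrative stand-ins, not part of the patch):

    // Pointer <-> handle round-trip, assuming pointers fit in 64 bits, as
    // MlirHloBuilder assumes for mlir::Value's opaque pointer.
    #include <cassert>
    #include <cstdint>

    int main() {
      int dummy = 0;
      void* opaque = &dummy;  // stand-in for mlir::Value::getAsOpaquePointer()
      int64_t handle = reinterpret_cast<int64_t>(opaque);  // MakeXlaOp direction
      void* back = reinterpret_cast<void*>(handle);        // GetValue direction
      assert(back == opaque);  // the round-trip is lossless
      return 0;
    }
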
-  // Returns the shape of the given op.
-  StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;
-
- private:
-  XlaOp UnaryOp(HloOpcode unop, XlaOp operand) override;
-
-  mlir::OpBuilder builder_;
-  mlir::Location loc_;
-
-  absl::flat_hash_map<int64, std::unique_ptr<Shape>> handle_to_shape_;
-};
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
deleted file mode 100644
index 695279badb9..00000000000
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ /dev/null
@@ -1,54 +0,0 @@
-// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU %s | FileCheck %s --dump-input-on-failure
-
-// INVALID_DEVICE: tf-opt -xla-legalize-tf-with-tf2xla=device-type=INVALID_DEVICE %s | FileCheck %s --dump-input-on-failure
-
-module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
-
-// CHECK-LABEL: abs
-// expected-error@+1 {{unsupported device}}
-func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-  %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-
-  // return %[[RESULT]]
-  return %0 : tensor<2xf32>
-}
-
-// CHECK-LABEL: unknown_op
-func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: tf.CustomTestOp
-  // expected-remark@+1 {{constant 20}}
-  %0 = "tf.CustomTestOp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-
-  return %0 : tensor<2xf32>
-}
-
-// CHECK-LABEL: dynamic_operand
-func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: tf.Abs
-  // expected-remark@+1 {{lowering requires static shaped operands}}
-  %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
-
-  return %0 : tensor<?xf32>
-}
-
-// CHECK-LABEL: multiple_dialect_ops
-func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: xla_hlo.neg
-  %0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
-  // CHECK: xla_hlo.abs
-  %1 = "tf.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32>
-
-  return %1 : tensor<2xf32>
-}
-
-// TODO(hinsu): Add a test with variant type once one of the ops supporting
-// the type is whitelisted. It should be rejected with an unsupported type
-// remark.
-
-// TODO(hinsu): Add a test with uint8 type once one of the ops supporting the
-// type is whitelisted. Unsigned types are not yet added to the HLO dialect so
-// it should return an error. See b/130356985
-
-// TODO(hinsu): Add a test with a valid TF op for which a tf2xla kernel is
-// available but doesn't support this instance.
-}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
deleted file mode 100644
index acc668dfb02..00000000000
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ /dev/null
@@ -1,388 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "absl/container/inlined_vector.h"
-#include "absl/memory/memory.h"
-#include "llvm/ADT/Optional.h"
-#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
-#include "mlir/IR/Function.h"  // TF:llvm-project
-#include "mlir/IR/Location.h"  // TF:llvm-project
-#include "mlir/IR/Module.h"  // TF:llvm-project
-#include "mlir/IR/Operation.h"  // TF:llvm-project
-#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
-#include "mlir/IR/Types.h"  // TF:llvm-project
-#include "mlir/IR/Value.h"  // TF:llvm-project
-#include "mlir/Pass/Pass.h"  // TF:llvm-project
-#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
-#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
-#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
-#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
-#include "tensorflow/compiler/tf2xla/xla_context.h"
-#include "tensorflow/compiler/tf2xla/xla_expression.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/function.h"
-#include "tensorflow/core/framework/function.pb.h"
-#include "tensorflow/core/framework/node_properties.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/protobuf/config.pb.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
-#include "tensorflow/stream_executor/stream_executor.h"
-
-namespace mlir {
-namespace xla_hlo {
-namespace {
-
-template <typename T, size_t N>
-using InlinedVector = tensorflow::gtl::InlinedVector<T, N>;  // non-absl ok
-
-static bool IsOpWhitelisted(Operation* op) {
-  // White-listed TensorFlow ops are known to have well-behaved tf2xla kernels
-  // building valid MLIR using MlirHloBuilder.
-  // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
-  // all tf2xla kernels.
-  return isa<TF::AbsOp>(op);
-}
-
-static llvm::Optional<absl::string_view> GetJitDevice(
-    const std::string& device_type, const Location& loc) {
-  if (device_type == "XLA_CPU") return absl::string_view("XLA_CPU_JIT");
-  if (device_type == "TPU") return absl::string_view("XLA_TPU_JIT");
-  // TODO(hinsu): Support GPU device along with a test for it.
-
-  emitError(loc) << "unsupported device for legalization with tf2xla kernels: "
-                 << device_type;
-  return llvm::None;
-}
-
-static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
-    const std::string& device_type, const Location& loc) {
-  auto jit_device_or = GetJitDevice(device_type, loc);
-  if (!jit_device_or) return nullptr;
-
-  auto* factory = tensorflow::DeviceFactory::GetFactory(device_type);
-  if (!factory) {
-    emitError(loc) << "failed to create DeviceFactory for device: "
-                   << device_type;
-    return nullptr;
-  }
-  std::vector<std::unique_ptr<tensorflow::Device>> devices;
-  auto status = factory->CreateDevices(
-      tensorflow::SessionOptions(),
-      /*name_prefix=*/"/job:localhost/replica:0/task:0", &devices);
-  if (!status.ok()) {
-    emitError(loc) << status.ToString();
-    return nullptr;
-  }
-
-  auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
-      tensorflow::SessionOptions(), tensorflow::DeviceType(*jit_device_or));
-  return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
-}
-
-class FuncLegalizer {
- public:
-  static LogicalResult Legalize(FuncOp func, const std::string& device_type) {
-    FuncLegalizer legalizer(func, device_type);
-    if (failed(legalizer.PrepareParams())) return failure();
-    return legalizer.Legalize();
-  }
-
- private:
-  FuncLegalizer(FuncOp func, const std::string& device_type)
-      : func_(func), device_type_(device_type), hlo_builder_(func) {}
-
-  ~FuncLegalizer() { context_->Unref(); }
-
-  // Prepares OpKernelContext params common to all the ops.
-  // Emits an error on failure.
-  LogicalResult PrepareParams();
-
-  // Tries to legalize supported TensorFlow ops.
-  // Emits an error on failure.
-  LogicalResult Legalize();
-
-  // Tries to legalize the specified TensorFlow op, if supported.
-  //
-  // Emits an error and returns failure if an error is encountered during
-  // conversion. Note that a success return value doesn't mean successful
-  // legalization.
-  LogicalResult LegalizeOp(Operation* op);
-
-  FuncOp func_;
-  std::string device_type_;
-
-  ::xla::MlirHloBuilder hlo_builder_;
-  tensorflow::OpOrArgLocNameMapper name_mapper_;
-
-  tensorflow::XlaContext* context_;  // Ref-counted.
-
-  std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
-  tensorflow::Device* device_;  // Owned by device_mgr_;
-  std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
-  std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
-  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
-  tensorflow::OpKernelContext::Params params_;
-};
-
-LogicalResult FuncLegalizer::PrepareParams() {
-  // XlaCompiler within the context is only used by the functional ops to
-  // compile functions. We are not handling those at the moment so XlaCompiler
-  // is not required.
-  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_);
-  context_->Ref();
-
-  mlir::Location loc = func_.getLoc();
-  device_mgr_ = CreateDeviceMgr(device_type_, loc);
-  if (!device_mgr_) return failure();
-
-  // Type of params_.device is DeviceBase* so store it as Device* to access
-  // derived class methods.
-  device_ = device_mgr_->ListDevices().front();
-  params_.device = device_;
-  params_.resource_manager = device_->resource_manager();
-
-  // Resources are cleared at the time of device manager destruction so pass
-  // a no-op cleanup function.
-  auto cleanup = [](const std::string& name) {};
-  // Use step_id zero as we only have a single context concurrently;
-  // MLIR functions that run concurrently each create a new device.
-  step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
-      /*step_id=*/0, cleanup);
-  tensorflow::Status status = step_container_->Create(
-      device_->resource_manager(),
-      tensorflow::XlaContext::kXlaContextResourceName, context_);
-  if (!status.ok()) {
-    emitError(loc) << "failed to create XlaContext resource: "
-                   << status.ToString();
-    return failure();
-  }
-  params_.step_container = step_container_.get();
-
-  tensorflow::StatusOr<int64> version_or =
-      tensorflow::GetTfGraphProducerVersion(
-          func_.getParentOfType<mlir::ModuleOp>());
-  if (!version_or.ok()) {
-    emitError(loc) << version_or.status().ToString();
-    return failure();
-  }
-
-  flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
-      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
-  pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
-      device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
-      version_or.ValueOrDie(), flib_def_.get(),
-      tensorflow::OptimizerOptions());
-  params_.function_library = pflr_->GetFLR(device_->name());
-  return success();
-}
-
-LogicalResult FuncLegalizer::Legalize() {
-  // TensorFlow functions don't use CFGs.
-  if (func_.getBlocks().size() > 1) {
-    emitError(func_.getLoc()) << "requires at most one block in a TF function";
-    return failure();
-  }
-  if (func_.getBlocks().empty()) return success();
-  Block& block = func_.getBlocks().front();
-
-  std::vector<Operation*> ops;
-  ops.reserve(block.getOperations().size());
-  for (Operation& op : block.getOperations()) {
-    ops.push_back(&op);
-  }
-
-  for (Operation* op : ops) {
-    if (failed(LegalizeOp(op))) return failure();
-  }
-  return success();
-}
-
-LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
-  if (!IsOpWhitelisted(op)) return success();
-
-  // Only static shaped operands are supported in XLA builders for now.
-  for (Type ty : op->getOperandTypes()) {
-    auto ranked_ty = ty.dyn_cast<RankedTensorType>();
-    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
-      op->emitRemark() << "lowering requires static shaped operands";
-      return success();
-    }
-  }
-
-  auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
-      op, name_mapper_.GetUniqueName(op), /*ignore_unregistered_attrs=*/true);
-  if (!nodedef_or.ok()) {
-    op->emitRemark() << "failed to convert op to NodeDef: "
-                     << nodedef_or.status().ToString();
-    return success();
-  }
-
-  std::shared_ptr<const tensorflow::NodeProperties> props;
-  tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
-      *nodedef_or.ValueOrDie(),
-      params_.function_library->GetFunctionLibraryDefinition(), &props);
-  if (!status.ok()) {
-    op->emitRemark() << "failed to create NodeProperties: "
-                     << status.ToString();
-    return success();
-  }
-  tensorflow::OpKernel* op_kernel_raw;
-  status = params_.function_library->CreateKernel(props, &op_kernel_raw);
-  if (!status.ok()) {
-    op->emitRemark() << "failed to create tf2xla kernel: "
-                     << status.ToString();
-    return success();
-  }
-  // Transfer ownership of the kernel to a local smart pointer.
-  auto op_kernel = absl::WrapUnique(op_kernel_raw);
-
-  // TensorValues in inputs are backed by tensors which in turn depend on
-  // expressions. So, pre-allocate them to the required size.
-  InlinedVector<tensorflow::XlaExpression, 4> expressions;
-  InlinedVector<tensorflow::Tensor, 4> tensors;
-  InlinedVector<tensorflow::TensorValue, 4> inputs;
-  expressions.reserve(op->getNumOperands());
-  tensors.reserve(op->getNumOperands());
-  inputs.reserve(op->getNumOperands());
-
-  // Prepare the list of Tensor inputs for the kernel.
-  for (Value operand : op->getOperands()) {
-    // Skip this op if XLA doesn't support this operand type.
-    auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
-    if (!xla_op_or.ok()) {
-      op->emitRemark() << "skipping legalization due to "
-                       << xla_op_or.status().ToString();
-      return success();
-    }
-    ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
-
-    tensorflow::DataType dtype;
-    status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
-    if (!status.ok()) {
-      op->emitRemark() << "skipping legalization due to " << status.ToString();
-      return success();
-    }
-
-    auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype);
-    expressions.push_back(expression);
-
-    if (!tensorflow::DataTypeCanUseMemcpy(dtype)) {
-      op->emitRemark() << "skipping legalization due to unsupported type "
-                       << operand.getType();
-      return success();
-    }
-
-    auto shape_or = expression.GetShape();
-    if (!shape_or.ok()) {
-      op->emitRemark() << "failed to get shape for expression. "
-                       << expression.HumanString();
-      return success();
-    }
-
-    tensors.emplace_back(
-        device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype,
-        shape_or.ValueOrDie());
-    tensorflow::Tensor& tensor = tensors.back();
-    tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression,
-                                                             &tensor);
-    inputs.emplace_back(&tensor);
-  }
-
-  params_.inputs = &inputs;
-  params_.op_kernel = op_kernel.get();
-  llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
-      op->getNumResults());
-  params_.output_attr_array = output_attr.data();
-
-  hlo_builder_.setInsertionPoint(op);
-  hlo_builder_.SetLocation(op->getLoc());
-
-  // Execute the kernel.
-  tensorflow::OpKernelContext op_context(&params_, op->getNumResults());
-  device_->Compute(params_.op_kernel, &op_context);
-  if (!op_context.status().ok()) {
-    op->emitRemark() << "compilation to HLO failed: "
-                     << op_context.status().ToString();
-    return success();
-  }
-
-  // Replace uses of old results using the corresponding value after the
-  // lowering.
-  for (int i = 0, e = op->getNumResults(); i < e; i++) {
-    tensorflow::Tensor* output = op_context.mutable_output(i);
-    const tensorflow::XlaExpression* expr =
-        tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output);
-    if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
-      return op->emitError(
-          "expects XlaExpression of kind kXlaOp in compiled output");
-    auto value = hlo_builder_.GetValue(expr->handle());
-    op->getResult(i).replaceAllUsesWith(value);
-  }
-
-  op->erase();
-  return success();
-}
-
-class LegalizeTF : public FunctionPass<LegalizeTF> {
- public:
-  LegalizeTF() = default;
-
-  LegalizeTF(const LegalizeTF&) {}
-
-  void runOnFunction() override {
-    if (failed(FuncLegalizer::Legalize(getFunction(), device_type_)))
-      signalPassFailure();
-  }
-
- private:
-  // TODO(hinsu): Support finer grained device type assignment instead of a
-  // global device type for all TensorFlow ops.
-  Option<std::string> device_type_{
-      *this, "device-type",
-      llvm::cl::desc("XLA device type for execution of TensorFlow ops. "
-                     "Supports XLA_CPU and TPU for now.")};
-};
-
-static PassRegistration<LegalizeTF> pass(
-    "xla-legalize-tf-with-tf2xla",
-    "Legalize from TensorFlow to the HLO dialect using tf2xla kernels");
-
-}  // end namespace
-
-}  // end namespace xla_hlo
-}  // end namespace mlir
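The xla_op_kernel.cc hunks below fold the expression/tensor helpers back into file-local functions. The trick they rely on: on the XLA compilation device, a Tensor's backing buffer is just large enough to hold an XlaExpression, which is written in place and later read back through a cast. A hedged sketch of that storage pattern with stand-in types (Expression and buffer are hypothetical, not the real tf2xla types):

    #include <iostream>
    #include <new>

    struct Expression {
      int handle;  // stand-in for XlaExpression's payload
    };

    int main() {
      // Plays the role of Tensor::tensor_data() on the compilation device.
      alignas(Expression) unsigned char buffer[sizeof(Expression)];
      new (buffer) Expression{42};  // AssignExpressionToTensor's role
      // CastExpressionFromTensor's role: reinterpret the buffer contents.
      const auto* expr = reinterpret_cast<const Expression*>(buffer);
      std::cout << expr->handle << "\n";  // prints 42
      return 0;
    }
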
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index a1c45a4bf30..a1941cc5fdf 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -50,8 +50,7 @@ XlaCompiler* XlaOpKernelContext::compiler() const {
 }
 
 // Retrieves an XlaExpression that was allocated by a previous Op.
-const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
-    const Tensor& tensor) {
+static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
   CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
@@ -60,8 +59,8 @@ const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
 }
 
 // Assigns an XlaExpression to a tensor on an XLA compilation device.
-void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value,
-                                                  Tensor* tensor) {
+static void AssignExpressionToTensor(Tensor* tensor,
+                                     const XlaExpression& value) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
   CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
@@ -397,8 +396,7 @@ namespace {
 Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                                const XlaOpKernelContext* ctx,
                                TensorShape* shape, xla::XlaOp* value) {
-  const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+  const XlaExpression* expression = CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
@@ -488,8 +486,7 @@ void XlaOpKernelContext::SetOutputExpression(int index,
       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
     }
-    XlaOpKernelContext::AssignExpressionToTensor(
-        expression, context_->mutable_output(index));
+    AssignExpressionToTensor(context_->mutable_output(index), expression);
     return Status::OK();
   }();
   if (!status.ok()) {
@@ -539,8 +536,7 @@ namespace {
 Status AssignVariableTensor(const Tensor& tensor, DataType type,
                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
                             xla::XlaBuilder* builder) {
-  const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+  const XlaExpression* expression = CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index d72dd3972d3..27b198f8bee 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -278,13 +278,6 @@ class XlaOpKernelContext {
   // separate specialization of the computation for each DataType.
   const xla::XlaComputation* GetOrCreateMul(const DataType type);
 
-  // Assigns an XlaExpression to a tensor on an XLA compilation device.
-  static void AssignExpressionToTensor(const XlaExpression& value,
-                                       Tensor* tensor);
-
-  // Retrieves an XlaExpression that was assigned to the specified tensor.
-  static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
-
  private:
   // Returns the tensor of input `name`.
   const Tensor& GetInputTensorByName(absl::string_view name);
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 73be2c11c5b..dc5c83e0bfb 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -95,7 +95,6 @@ class XlaOp {
   int64 handle() const { return handle_; }
 
   friend class XlaBuilder;
-  friend class MlirHloBuilder;
 
   // < 0 means "invalid handle".
   int64 handle_;
 
@@ -140,7 +139,7 @@ class XlaBuilder {
   XlaBuilder(const XlaBuilder&) = delete;
   XlaBuilder& operator=(const XlaBuilder&) = delete;
 
-  virtual ~XlaBuilder();
+  ~XlaBuilder();
 
   // Returns the computation name.
   const string& name() const { return name_; }
@@ -278,7 +277,7 @@ class XlaBuilder {
   StatusOr<Shape> GetShape(XlaOp op) const;
 
   // Returns the shape of the given op.
-  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;
+  StatusOr<const Shape*> GetShapePtr(XlaOp op) const;
 
   // Returns the (inferred) result for the current computation's shape. This
   // assumes the root instruction is the last added instruction.
@@ -646,7 +645,7 @@ StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
 
   // Internal helper method that does the building for an arbitrary unary op.
-  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);
+  XlaOp UnaryOp(HloOpcode unop, XlaOp operand);
 
   // Internal helper method that does the building for an arbitrary binary op.
   // broadcast_dimensions specifies which dimensions to use for broadcasting
@@ -1057,17 +1056,16 @@ class XlaBuilder {
   friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
 
- protected:
-  // Returns OK status if the given op was built using this builder. Otherwise,
-  // returns an error.
-  Status CheckOpBuilder(XlaOp op) const;
-
  private:
   XlaOp ConditionalImpl(
       XlaOp branch_index,
       absl::Span<const XlaComputation* const> branch_computations,
       absl::Span<const XlaOp> branch_operands);
 
+  // Returns OK status if the given op was built using this builder. Otherwise,
+  // returns an error.
+  Status CheckOpBuilder(XlaOp op) const;
+
   // Here, InstructionType is either const HloInstructionProto* or non-const
   // HloInstructionProto*.
   template
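With the virtual hooks above removed, XlaBuilder is again driven only through its public API. For reference, a minimal hedged sketch of that path using real xla:: entry points (BuildAbs is an illustrative name; error handling is elided):

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Builds a computation returning abs(x) for an f32[2] parameter, the same
    // op the deleted pass lowered from tf.Abs.
    xla::StatusOr<xla::XlaComputation> BuildAbs() {
      xla::XlaBuilder b("abs_example");
      xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
      xla::XlaOp x = xla::Parameter(&b, /*parameter_number=*/0, shape, "x");
      xla::Abs(x);       // the root defaults to the last added instruction
      return b.Build();  // surfaces any deferred builder errors as a status
    }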