Handle constant inputs in xla-legalize-tf-with-tf2xla pass
Some of the XlaOpKernels require constant inputs during compilation. These are provided as TensorFlow Tensors wrapped within XlaExpression. To use these kernels for compilation, * Figure out the inputs that needs to constant for the op compilation. * Convert constant inputs to XlaExpression of kind Constant. * Skip the op if a required input is not a constant. Also, modified XlaExpression::ResolveConstant to not require xla::Client for constant computation. MLIR canonicalization pass should have converted all values that can be constant folded so we don't need to resolve constants during this pass. Note that we don't have XlaCompiler available in the pass as we don't create it. Whitelisted XlaPad op for testing. PiperOrigin-RevId: 309306327 Change-Id: I91f21c09e9fa08070ade68c5dacd3450903a8570
This commit is contained in:
parent
679da1ca0e
commit
c9572919e6
@ -165,6 +165,7 @@ cc_library(
|
|||||||
":mlir_hlo_builder",
|
":mlir_hlo_builder",
|
||||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
|
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
||||||
|
@ -50,6 +50,15 @@ func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
|||||||
return %0 : tensor<?xf32>
|
return %0 : tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: unsupported_dtype
|
||||||
|
func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
|
||||||
|
// CHECK: tf.AddN
|
||||||
|
// expected-remark@+1 {{unsupported type: tensor<2x!tf.variant>}}
|
||||||
|
%0 = "tf.AddN"(%arg0, %arg0) : (tensor<2x!tf.variant>, tensor<2x!tf.variant>) -> tensor<2x!tf.variant>
|
||||||
|
|
||||||
|
return %0 : tensor<2x!tf.variant>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: multiple_dialect_ops
|
// CHECK-LABEL: multiple_dialect_ops
|
||||||
func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
// CHECK: xla_hlo.negate
|
// CHECK: xla_hlo.negate
|
||||||
@ -115,12 +124,27 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
|||||||
return %0: tensor<2xi1>
|
return %0: tensor<2xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(hinsu): Add a test with variant type once one of the ops supporting
|
// CHECK-LABEL: func @const_inputs
|
||||||
// the type is whitelisted. It should be rejected with unsupported type remark.
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x2xf64>, %[[ARG1:.*]]: tensor<f64>,
|
||||||
|
func @const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> {
|
||||||
|
|
||||||
// TODO(hinsu): Add a test with uint8 type once one of the ops supporting the
|
// CHECK: "xla_hlo.pad"(%[[ARG0]], %[[ARG1]])
|
||||||
// type is whitelisted. Unsigned types are not yet added to the HLO dialect so
|
// CHECK-SAME-DAG: edge_padding_high = dense<[1, 2]> : tensor<2xi64>
|
||||||
// it should return an error. See b/130356985
|
// CHECK-SAME-DAG: edge_padding_low = dense<[2, 1]> : tensor<2xi64>
|
||||||
|
// CHECK-SAME-DAG: interior_padding = dense<[1, 0]> : tensor<2xi64>
|
||||||
|
|
||||||
|
%0 = xla_hlo.constant dense<[2, 1]> : tensor<2xi32>
|
||||||
|
%1 = xla_hlo.constant dense<[1, 2]> : tensor<2xi32>
|
||||||
|
%2 = xla_hlo.constant dense<[1, 0]> : tensor<2xi32>
|
||||||
|
%3 = "tf.XlaPad"(%arg0, %arg1, %0, %1, %2) : (tensor<2x2xf64>, tensor<f64>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64>
|
||||||
|
return %3 : tensor<6x5xf64>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> {
|
||||||
|
// expected-remark@+1 {{lowering requires operand #2 to be a constant}}
|
||||||
|
%0 = "tf.XlaPad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x2xf64>, tensor<f64>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64>
|
||||||
|
return %0 : tensor<6x5xf64>
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
||||||
// available but doesn't support this instance.
|
// available but doesn't support this instance.
|
||||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
|
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
|
||||||
@ -81,6 +82,7 @@ static bool IsOpWhitelisted(Operation* op) {
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
|
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
|
||||||
TypeID::get<TF::AbsOp>(),
|
TypeID::get<TF::AbsOp>(),
|
||||||
|
TypeID::get<TF::AddNOp>(),
|
||||||
TypeID::get<TF::AddV2Op>(),
|
TypeID::get<TF::AddV2Op>(),
|
||||||
TypeID::get<TF::Atan2Op>(),
|
TypeID::get<TF::Atan2Op>(),
|
||||||
TypeID::get<TF::BatchMatMulV2Op>(),
|
TypeID::get<TF::BatchMatMulV2Op>(),
|
||||||
@ -122,7 +124,8 @@ static bool IsOpWhitelisted(Operation* op) {
|
|||||||
TypeID::get<TF::TruncateDivOp>(),
|
TypeID::get<TF::TruncateDivOp>(),
|
||||||
TypeID::get<TF::TruncateModOp>(),
|
TypeID::get<TF::TruncateModOp>(),
|
||||||
TypeID::get<TF::UnpackOp>(),
|
TypeID::get<TF::UnpackOp>(),
|
||||||
TypeID::get<TF::XlaDotOp>()
|
TypeID::get<TF::XlaDotOp>(),
|
||||||
|
TypeID::get<TF::XlaPadOp>()
|
||||||
};
|
};
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
@ -170,6 +173,10 @@ class FuncLegalizer {
|
|||||||
// legalization.
|
// legalization.
|
||||||
LogicalResult LegalizeOp(Operation* op);
|
LogicalResult LegalizeOp(Operation* op);
|
||||||
|
|
||||||
|
// Converts the given operand to expression of kind kConstant or kXlaOp.
|
||||||
|
// Emits a remark and returns expression of kind kInvalid on failure.
|
||||||
|
tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);
|
||||||
|
|
||||||
FuncOp func_;
|
FuncOp func_;
|
||||||
std::string device_type_;
|
std::string device_type_;
|
||||||
|
|
||||||
@ -296,6 +303,17 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
|
|||||||
// Transfer ownership of the kernel to a local smart pointer.
|
// Transfer ownership of the kernel to a local smart pointer.
|
||||||
auto op_kernel = absl::WrapUnique(op_kernel_raw);
|
auto op_kernel = absl::WrapUnique(op_kernel_raw);
|
||||||
|
|
||||||
|
std::vector<int> required_constants;
|
||||||
|
status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
|
||||||
|
*op_kernel, &required_constants);
|
||||||
|
if (!status.ok()) {
|
||||||
|
op->emitRemark() << "failed to compute required constants: "
|
||||||
|
<< status.ToString();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
llvm::SmallDenseSet<int, 4> required_consts;
|
||||||
|
required_consts.insert(required_constants.begin(), required_constants.end());
|
||||||
|
|
||||||
// TensorValue in inputs are backed by tensors which in turn depend on
|
// TensorValue in inputs are backed by tensors which in turn depend on
|
||||||
// expressions. So, pre-allocate them to the required size.
|
// expressions. So, pre-allocate them to the required size.
|
||||||
InlinedVector<tensorflow::XlaExpression, 4> expressions;
|
InlinedVector<tensorflow::XlaExpression, 4> expressions;
|
||||||
@ -306,45 +324,39 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
|
|||||||
inputs.reserve(op->getNumOperands());
|
inputs.reserve(op->getNumOperands());
|
||||||
|
|
||||||
// Prepare the list of Tensor inputs for the kernel.
|
// Prepare the list of Tensor inputs for the kernel.
|
||||||
for (Value operand : op->getOperands()) {
|
for (auto it : llvm::enumerate(op->getOperands())) {
|
||||||
// Skip this op if XLA doesn't support this operand type.
|
Value operand = it.value();
|
||||||
auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
|
size_t idx = it.index();
|
||||||
if (!xla_op_or.ok()) {
|
|
||||||
op->emitRemark() << "skipping legalization due to "
|
tensorflow::XlaExpression expr = GetExprForOperand(operand, op);
|
||||||
<< xla_op_or.status().ToString();
|
tensorflow::XlaExpression::Kind kind = expr.kind();
|
||||||
|
if (kind == tensorflow::XlaExpression::Kind::kInvalid) return success();
|
||||||
|
if (required_consts.count(idx) &&
|
||||||
|
kind != tensorflow::XlaExpression::Kind::kConstant) {
|
||||||
|
op->emitRemark() << "lowering requires operand #" << idx
|
||||||
|
<< " to be a constant";
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
|
expressions.push_back(expr);
|
||||||
|
|
||||||
tensorflow::DataType dtype;
|
if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
|
||||||
status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
|
|
||||||
if (!status.ok()) {
|
|
||||||
op->emitRemark() << "skipping legalization due to " << status.ToString();
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype);
|
|
||||||
expressions.push_back(expression);
|
|
||||||
|
|
||||||
if (!tensorflow::DataTypeCanUseMemcpy(dtype)) {
|
|
||||||
op->emitRemark() << "skipping legalization due to unsupported type "
|
op->emitRemark() << "skipping legalization due to unsupported type "
|
||||||
<< operand.getType();
|
<< operand.getType();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto shape_or = expression.GetShape();
|
auto shape_or = expr.GetShape();
|
||||||
if (!shape_or.ok()) {
|
if (!shape_or.ok()) {
|
||||||
op->emitRemark() << "failed to get shape for expression. "
|
op->emitRemark() << "failed to get shape for expression. "
|
||||||
<< expression.HumanString();
|
<< expr.HumanString();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
tensors.emplace_back(
|
tensors.emplace_back(
|
||||||
device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype,
|
device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
|
||||||
shape_or.ValueOrDie());
|
shape_or.ValueOrDie());
|
||||||
tensorflow::Tensor& tensor = tensors.back();
|
tensorflow::Tensor& tensor = tensors.back();
|
||||||
tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression,
|
tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor);
|
||||||
&tensor);
|
|
||||||
inputs.emplace_back(&tensor);
|
inputs.emplace_back(&tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,6 +395,39 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::XlaExpression FuncLegalizer::GetExprForOperand(Value operand,
|
||||||
|
Operation* op) {
|
||||||
|
ElementsAttr const_attr;
|
||||||
|
auto defining_op = operand.getDefiningOp();
|
||||||
|
if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
|
||||||
|
tensorflow::Tensor tensor;
|
||||||
|
auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
|
||||||
|
if (!status.ok()) {
|
||||||
|
op->emitRemark() << "skipping legalization due to failed const conversion"
|
||||||
|
<< status.ToString();
|
||||||
|
return tensorflow::XlaExpression::Invalid();
|
||||||
|
}
|
||||||
|
return tensorflow::XlaExpression::Constant(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip this op if XLA doesn't support this operand type.
|
||||||
|
auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
|
||||||
|
if (!xla_op_or.ok()) {
|
||||||
|
op->emitRemark() << "skipping legalization due to "
|
||||||
|
<< xla_op_or.status().ToString();
|
||||||
|
return tensorflow::XlaExpression::Invalid();
|
||||||
|
}
|
||||||
|
::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
|
||||||
|
|
||||||
|
tensorflow::DataType dtype;
|
||||||
|
auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
|
||||||
|
if (!status.ok()) {
|
||||||
|
op->emitRemark() << "skipping legalization due to " << status.ToString();
|
||||||
|
return tensorflow::XlaExpression::Invalid();
|
||||||
|
}
|
||||||
|
return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
|
||||||
|
}
|
||||||
|
|
||||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||||
public:
|
public:
|
||||||
LegalizeTF() = default;
|
LegalizeTF() = default;
|
||||||
|
@ -121,6 +121,9 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
|
|||||||
handle().builder()->IsConstant(handle()));
|
handle().builder()->IsConstant(handle()));
|
||||||
if (!is_constant) return {absl::nullopt};
|
if (!is_constant) return {absl::nullopt};
|
||||||
|
|
||||||
|
if (!client)
|
||||||
|
return errors::InvalidArgument("client is required to resolve constant");
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
|
TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
|
||||||
handle().builder()->BuildConstantSubGraph(
|
handle().builder()->BuildConstantSubGraph(
|
||||||
handle(), dynamic_dimension_is_minus_one));
|
handle(), dynamic_dimension_is_minus_one));
|
||||||
|
@ -175,8 +175,9 @@ Status XlaOpKernelContext::ConstantInputReshaped(
|
|||||||
int index, absl::Span<const int64> new_dims,
|
int index, absl::Span<const int64> new_dims,
|
||||||
xla::Literal* constant_literal) {
|
xla::Literal* constant_literal) {
|
||||||
XlaExpression e = InputExpression(index);
|
XlaExpression e = InputExpression(index);
|
||||||
|
auto* client = compiler() ? compiler()->client() : nullptr;
|
||||||
xla::StatusOr<absl::optional<Tensor>> constant_or_status =
|
xla::StatusOr<absl::optional<Tensor>> constant_or_status =
|
||||||
e.ResolveConstant(compiler()->client(), dynamic_dimension_is_minus_one_);
|
e.ResolveConstant(client, dynamic_dimension_is_minus_one_);
|
||||||
if (!constant_or_status.ok()) {
|
if (!constant_or_status.ok()) {
|
||||||
Status status = constant_or_status.status();
|
Status status = constant_or_status.status();
|
||||||
errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
|
errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
|
||||||
|
Loading…
Reference in New Issue
Block a user