Handle constant inputs in xla-legalize-tf-with-tf2xla pass

Some of the XlaOpKernels require constant inputs during compilation. These are provided as TensorFlow Tensors wrapped within XlaExpression.

To use these kernels for compilation (a rough sketch of the flow follows this list):
* Figure out which inputs need to be constant for the op compilation.
* Convert constant inputs to XlaExpression of kind Constant.
* Skip the op if a required input is not a constant.
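
A condensed sketch of the per-op logic, mirroring the LegalizeOp changes in the diff below (remark wording and status checks abbreviated; not a drop-in snippet):

    // Query the tf2xla registry for inputs that must be compile-time constants.
    std::vector<int> required_constants;
    status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
        *op_kernel, &required_constants);
    if (!status.ok()) return success();  // emit a remark and skip the op

    llvm::SmallDenseSet<int, 4> required_consts;
    required_consts.insert(required_constants.begin(), required_constants.end());

    for (auto it : llvm::enumerate(op->getOperands())) {
      // Operands defined by a constant op become kConstant expressions backed
      // by a TensorFlow Tensor; all other operands become kXlaOp expressions
      // (see GetExprForOperand below).
      tensorflow::XlaExpression expr = GetExprForOperand(it.value(), op);
      if (required_consts.count(it.index()) &&
          expr.kind() != tensorflow::XlaExpression::Kind::kConstant) {
        // A required compile-time constant input is not a constant: skip the op.
        op->emitRemark() << "lowering requires operand #" << it.index()
                         << " to be a constant";
        return success();
      }
      expressions.push_back(expr);
    }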

Also modified XlaExpression::ResolveConstant so that it no longer requires an xla::Client for constant computation. The MLIR canonicalization pass should already have folded every value that can be constant folded, so we don't need to resolve constants during this pass. Note that we don't have an XlaCompiler available in the pass as we don't create one.
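
With that change, a caller that has no compiler (such as this pass) can pass a null client; resolution still works for expressions that already hold a constant Tensor and only returns an error when an xla::XlaOp value would have to be evaluated through a constant subgraph. A hedged example of such a call site (variable names illustrative):

    // No XlaCompiler/xla::Client is available in the MLIR pass, so pass
    // nullptr. Values that would require building and running a constant
    // subgraph now fail with InvalidArgument instead of needing a client.
    xla::StatusOr<absl::optional<tensorflow::Tensor>> constant_or =
        expr.ResolveConstant(/*client=*/nullptr, dynamic_dimension_is_minus_one);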

Whitelisted XlaPad op for testing.

PiperOrigin-RevId: 309306327
Change-Id: I91f21c09e9fa08070ade68c5dacd3450903a8570
Authored by Smit Hinsu on 2020-04-30 15:01:10 -07:00; committed by TensorFlower Gardener
commit c9572919e6 (parent 679da1ca0e)
5 changed files with 104 additions and 30 deletions


@@ -165,6 +165,7 @@ cc_library(
":mlir_hlo_builder",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:convert_type",
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",


@@ -50,6 +50,15 @@ func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
return %0 : tensor<?xf32>
}
// CHECK-LABEL: unsupported_dtype
func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
// CHECK: tf.AddN
// expected-remark@+1 {{unsupported type: tensor<2x!tf.variant>}}
%0 = "tf.AddN"(%arg0, %arg0) : (tensor<2x!tf.variant>, tensor<2x!tf.variant>) -> tensor<2x!tf.variant>
return %0 : tensor<2x!tf.variant>
}
// CHECK-LABEL: multiple_dialect_ops
func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: xla_hlo.negate
@@ -115,12 +124,27 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
return %0: tensor<2xi1>
}
// TODO(hinsu): Add a test with variant type once one of the ops supporting
// the type is whitelisted. It should be rejected with unsupported type remark.
// CHECK-LABEL: func @const_inputs
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x2xf64>, %[[ARG1:.*]]: tensor<f64>,
func @const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> {
// TODO(hinsu): Add a test with uint8 type once one of the ops supporting the
// type is whitelisted. Unsigned types are not yet added to the HLO dialect so
// it should return an error. See b/130356985
// CHECK: "xla_hlo.pad"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME-DAG: edge_padding_high = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME-DAG: edge_padding_low = dense<[2, 1]> : tensor<2xi64>
// CHECK-SAME-DAG: interior_padding = dense<[1, 0]> : tensor<2xi64>
%0 = xla_hlo.constant dense<[2, 1]> : tensor<2xi32>
%1 = xla_hlo.constant dense<[1, 2]> : tensor<2xi32>
%2 = xla_hlo.constant dense<[1, 0]> : tensor<2xi32>
%3 = "tf.XlaPad"(%arg0, %arg1, %0, %1, %2) : (tensor<2x2xf64>, tensor<f64>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64>
return %3 : tensor<6x5xf64>
}
func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> {
// expected-remark@+1 {{lowering requires operand #2 to be a constant}}
%0 = "tf.XlaPad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x2xf64>, tensor<f64>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64>
return %0 : tensor<6x5xf64>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.


@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
@@ -81,6 +82,7 @@ static bool IsOpWhitelisted(Operation* op) {
// clang-format off
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
TypeID::get<TF::AbsOp>(),
TypeID::get<TF::AddNOp>(),
TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::Atan2Op>(),
TypeID::get<TF::BatchMatMulV2Op>(),
@@ -122,7 +124,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::TruncateDivOp>(),
TypeID::get<TF::TruncateModOp>(),
TypeID::get<TF::UnpackOp>(),
TypeID::get<TF::XlaDotOp>()
TypeID::get<TF::XlaDotOp>(),
TypeID::get<TF::XlaPadOp>()
};
// clang-format on
@@ -170,6 +173,10 @@ class FuncLegalizer {
// legalization.
LogicalResult LegalizeOp(Operation* op);
// Converts the given operand to expression of kind kConstant or kXlaOp.
// Emits a remark and returns expression of kind kInvalid on failure.
tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);
FuncOp func_;
std::string device_type_;
@@ -296,6 +303,17 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
// Transfer ownership of the kernel to a local smart pointer.
auto op_kernel = absl::WrapUnique(op_kernel_raw);
std::vector<int> required_constants;
status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
*op_kernel, &required_constants);
if (!status.ok()) {
op->emitRemark() << "failed to compute required constants: "
<< status.ToString();
return success();
}
llvm::SmallDenseSet<int, 4> required_consts;
required_consts.insert(required_constants.begin(), required_constants.end());
// TensorValue in inputs are backed by tensors which in turn depend on
// expressions. So, pre-allocate them to the required size.
InlinedVector<tensorflow::XlaExpression, 4> expressions;
@@ -306,45 +324,39 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
inputs.reserve(op->getNumOperands());
// Prepare the list of Tensor inputs for the kernel.
for (Value operand : op->getOperands()) {
// Skip this op if XLA doesn't support this operand type.
auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
if (!xla_op_or.ok()) {
op->emitRemark() << "skipping legalization due to "
<< xla_op_or.status().ToString();
for (auto it : llvm::enumerate(op->getOperands())) {
Value operand = it.value();
size_t idx = it.index();
tensorflow::XlaExpression expr = GetExprForOperand(operand, op);
tensorflow::XlaExpression::Kind kind = expr.kind();
if (kind == tensorflow::XlaExpression::Kind::kInvalid) return success();
if (required_consts.count(idx) &&
kind != tensorflow::XlaExpression::Kind::kConstant) {
op->emitRemark() << "lowering requires operand #" << idx
<< " to be a constant";
return success();
}
::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
expressions.push_back(expr);
tensorflow::DataType dtype;
status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
if (!status.ok()) {
op->emitRemark() << "skipping legalization due to " << status.ToString();
return success();
}
auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype);
expressions.push_back(expression);
if (!tensorflow::DataTypeCanUseMemcpy(dtype)) {
if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
op->emitRemark() << "skipping legalization due to unsupported type "
<< operand.getType();
return success();
}
auto shape_or = expression.GetShape();
auto shape_or = expr.GetShape();
if (!shape_or.ok()) {
op->emitRemark() << "failed to get shape for expression. "
<< expression.HumanString();
<< expr.HumanString();
return success();
}
tensors.emplace_back(
device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype,
device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
shape_or.ValueOrDie());
tensorflow::Tensor& tensor = tensors.back();
tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression,
&tensor);
tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor);
inputs.emplace_back(&tensor);
}
@@ -383,6 +395,39 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
return success();
}
tensorflow::XlaExpression FuncLegalizer::GetExprForOperand(Value operand,
Operation* op) {
ElementsAttr const_attr;
auto defining_op = operand.getDefiningOp();
if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
tensorflow::Tensor tensor;
auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
if (!status.ok()) {
op->emitRemark() << "skipping legalization due to failed const conversion"
<< status.ToString();
return tensorflow::XlaExpression::Invalid();
}
return tensorflow::XlaExpression::Constant(tensor);
}
// Skip this op if XLA doesn't support this operand type.
auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
if (!xla_op_or.ok()) {
op->emitRemark() << "skipping legalization due to "
<< xla_op_or.status().ToString();
return tensorflow::XlaExpression::Invalid();
}
::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
tensorflow::DataType dtype;
auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
if (!status.ok()) {
op->emitRemark() << "skipping legalization due to " << status.ToString();
return tensorflow::XlaExpression::Invalid();
}
return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
}
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;


@@ -121,6 +121,9 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
handle().builder()->IsConstant(handle()));
if (!is_constant) return {absl::nullopt};
if (!client)
return errors::InvalidArgument("client is required to resolve constant");
TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
handle().builder()->BuildConstantSubGraph(
handle(), dynamic_dimension_is_minus_one));


@@ -175,8 +175,9 @@ Status XlaOpKernelContext::ConstantInputReshaped(
int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal) {
XlaExpression e = InputExpression(index);
auto* client = compiler() ? compiler()->client() : nullptr;
xla::StatusOr<absl::optional<Tensor>> constant_or_status =
e.ResolveConstant(compiler()->client(), dynamic_dimension_is_minus_one_);
e.ResolveConstant(client, dynamic_dimension_is_minus_one_);
if (!constant_or_status.ok()) {
Status status = constant_or_status.status();
errors::AppendToMessage(&status, "while evaluating input ", index, " of ",