diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index df3ffd0599c..bf2d8103872 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -142,6 +142,19 @@ cc_library( ], ) +cc_library( + name = "hlo_shape_derivation", + srcs = [], + hdrs = ["transforms/hlo_shape_derivation.h"], + deps = [ + ":hlo", + ":lhlo", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "lhlo_legalize_to_affine", srcs = ["transforms/lhlo_legalize_to_affine.cc"], @@ -218,9 +231,9 @@ cc_library( srcs = ["transforms/hlo_legalize_to_lhlo.cc"], deps = [ ":hlo", + ":hlo_shape_derivation", ":lhlo", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 4b2d76e586a..be6f0e6a949 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal %s -o - | FileCheck %s --dump-input=always +// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal -split-input-file %s -o - | FileCheck %s -dump-input-on-failure // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -11,6 +11,8 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @func_op func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) @@ -20,6 +22,8 @@ func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: 
"xla_lhlo.terminator"() : () -> () } +// ----- + // CHECK-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) @@ -45,6 +49,8 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () } +// ----- + // CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor, %arg1: tensor, %arg2: memref) { %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -58,6 +64,8 @@ func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor, %arg1: t // CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref // CHECK: "xla_lhlo.terminator"() : () -> () +// ----- + // CHECK-LABEL: func @fusion func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -77,6 +85,8 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, "xla_lhlo.terminator"() : () -> () } +// ----- + // CHECK-LABEL: func @copy func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -87,6 +97,8 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @exp func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -97,6 +109,8 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -110,6 +124,8 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, return } +// ----- + // CHECK-LABEL: func @compare func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, 
%result: memref<2x2xi1>) { %tensor_lhs = tensor_load %lhs : memref<2x2xf32> @@ -122,6 +138,8 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x return } +// ----- + // CHECK-LABEL: func @broadcast func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_operand = tensor_load %operand : memref<5xf32> @@ -133,6 +151,8 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { return } +// ----- + // CHECK-LABEL: func @dyn_broadcast func @dyn_broadcast(%operand: memref) { %tensor_operand = tensor_load %operand : memref @@ -157,6 +177,8 @@ func @dyn_broadcast(%operand: memref) { return } +// ----- + // CHECK-LABEL: func @iota func @iota(%result: memref<10xi32>) { %tensor_result = "xla_hlo.iota"() @@ -166,6 +188,8 @@ func @iota(%result: memref<10xi32>) { return } +// ----- + // CHECK-LABEL: func @abs func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -176,6 +200,8 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @ceil func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -186,6 +212,8 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @convert func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -196,6 +224,8 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @cos func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -206,6 +236,8 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { return } +// ----- + // CHECK-LABEL: func @neg func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = 
tensor_load %operand : memref<2x2xf32>
@@ -216,6 +248,8 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @sign
 func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -226,6 +260,8 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @tanh
 func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -236,6 +272,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remainder
 func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
@@ -246,3 +284,49 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
   tensor_store %tensor_result, %result : memref<2x2xf32>
   return
 }
+
+// -----
+
+// Dynamic shape binary element-wise operation.
+// CHECK-LABEL: func @add_dyn
+func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
+  %result = "xla_hlo.add"(%lhs, %rhs)
+      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
+  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+  // CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
+  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+  // CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
+  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
+  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
+  return
+}
+
+// -----
+
+// Dynamic shape unary element-wise operation.
+// CHECK-LABEL: func @tanh_dyn
+func @tanh_dyn(%arg0: tensor<?x?xf32>) {
+  %result = "xla_hlo.tanh"(%arg0)
+      : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
+  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+  // CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
+  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+  // CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
+  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
+  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index 77c361a8ab5..1384abed91c 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -127,9 +128,24 @@ class HloToLhloOpConverter : public ConversionPattern { ConversionPatternRewriter& rewriter) const final { const auto& original_results = op->getResults(); SmallVector buffer_args(operands.begin(), operands.end()); - for (auto result : original_results) { - buffer_args.push_back( - InsertAllocAndDealloc(op->getLoc(), result, &rewriter)); + for (auto result : llvm::enumerate(original_results)) { + RankedTensorType resultType = + result.value().getType().dyn_cast(); + if (!resultType) { + return matchFailure(); + } + if (resultType.hasStaticShape()) { + buffer_args.push_back( + InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter)); + } else { + Value shape_value = ShapeDerivation::impl::deriveShapeFromOp( + op, result.index(), &rewriter); + if (!shape_value) { + return matchFailure(); + } + buffer_args.push_back(InsertDynamicAllocAndDealloc( + op->getLoc(), result.value(), shape_value, &rewriter)); + } } rewriter.create(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); @@ -320,6 +336,7 @@ struct HloLegalizeToLhlo : public ModulePass { target.addIllegalOp(); target.addIllegalOp(); target.addLegalOp(); + target.addLegalOp(); target.addIllegalDialect(); target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h new file mode 100644 index 00000000000..7c6d162632f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h @@ -0,0 +1,130 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_ + +#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project +#include "mlir/IR/Attributes.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Location.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +// This file contains implementations for shape derivation functions that, +// given some operation and a result number, produce IR that computes the +// shape of the given result at runtime based on operands of the provided +// operation. +// These should be generated at some point based on annotations on the HLO +// using the new shape dialect. While this is still in the works, we hardcode +// the expected IR here to unblock progress. +// The implementation is based on templates to allow for using these derivation +// functions in templated code. + +namespace impl { + +struct UnknownShape { + // Default shape derivation function that simply fails with a runtime error. 
+  static Value deriveShapeFromOp(Operation* op, int result_position,
+                                 ConversionPatternRewriter* rewriter) {
+    op->emitOpError()
+        << "dynamic result shapes cannot be derived for this operation";
+    return {};
+  }
+};
+
+struct SameShapeAsFirstOperand {
+  // Shape derivation function that computes the shape of the result based on
+  // the first argument. For a 2-dimensional input tensor, this produces IR of
+  // the form
+  //
+  //   %0 = dim %arg0, 0 : memref<?x?xf32>
+  //   %1 = index_cast %0 : index to i64
+  //   %2 = dim %arg0, 1 : memref<?x?xf32>
+  //   %3 = index_cast %2 : index to i64
+  //   %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
+  //            : (i64, i64) -> tensor<2xi64>
+  //
+  // and returns %4 as the shape value.
+  static Value deriveShapeFromOp(Operation* op, int result_position,
+                                 ConversionPatternRewriter* rewriter) {
+    Value operand = op->getOperand(0);
+    ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
+    if (!operand_type) {
+      op->emitOpError() << "first operand has no shaped type";
+      return {};
+    }
+    auto loc = op->getLoc();
+    SmallVector<Value, 4> shape_values;
+    shape_values.reserve(operand_type.getRank());
+    auto shape_scalar_type = rewriter->getIntegerType(64);
+    for (auto element : llvm::enumerate(operand_type.getShape())) {
+      if (element.value() == ShapedType::kDynamicSize) {
+        Value dim = rewriter->create<DimOp>(loc, operand, element.index());
+        shape_values.push_back(
+            rewriter->create<IndexCastOp>(loc, dim, shape_scalar_type));
+      } else {
+        shape_values.push_back(rewriter->create<ConstantOp>(
+            loc, rewriter->getI64IntegerAttr(element.value())));
+      }
+    }
+    return rewriter->create<ScalarsToDimensionTensorOp>(
+        loc, RankedTensorType::get({operand_type.getRank()}, shape_scalar_type),
+        shape_values);
+  }
+};
+
+}  // namespace impl
+
+// Default template to cover HLO operations whose shape derivation is unknown.
+template <typename HloOpTy>
+struct ShapeDerivation {
+  using impl = impl::UnknownShape;
+};
+
+// Element-wise operations that have the shape of their first operand.
+
+#define SAME_SHAPE_AS_FIRST_OPERAND(Op)           \
+  template <>                                     \
+  struct ShapeDerivation<Op> {                    \
+    using impl = impl::SameShapeAsFirstOperand;   \
+  };
+
+SAME_SHAPE_AS_FIRST_OPERAND(AbsOp)
+SAME_SHAPE_AS_FIRST_OPERAND(AddOp)
+SAME_SHAPE_AS_FIRST_OPERAND(AndOp)
+SAME_SHAPE_AS_FIRST_OPERAND(CeilOp)
+SAME_SHAPE_AS_FIRST_OPERAND(CosOp)
+SAME_SHAPE_AS_FIRST_OPERAND(DivOp)
+SAME_SHAPE_AS_FIRST_OPERAND(ExpOp)
+SAME_SHAPE_AS_FIRST_OPERAND(MaxOp)
+SAME_SHAPE_AS_FIRST_OPERAND(MinOp)
+SAME_SHAPE_AS_FIRST_OPERAND(MulOp)
+SAME_SHAPE_AS_FIRST_OPERAND(NegOp)
+SAME_SHAPE_AS_FIRST_OPERAND(RemOp)
+SAME_SHAPE_AS_FIRST_OPERAND(SubOp)
+SAME_SHAPE_AS_FIRST_OPERAND(TanhOp)
+
+#undef SAME_SHAPE_AS_FIRST_OPERAND
+
+}  // namespace xla_hlo
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_