diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 7326192f418..86427932c1e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -8437,6 +8437,31 @@ output_handle: The list. TF_DerivedOperandTypeAttr element_dtype = TF_DerivedOperandTypeAttr<0>; } +def TF_TensorListGatherOp : TF_Op<"TensorListGather", [NoSideEffect]> { + let summary = "Creates a Tensor by indexing into the TensorList."; + + let description = [{ +Each row in the produced Tensor corresponds to the element in the TensorList +specified by the given index (see `tf.gather`). + +input_handle: The input tensor list. +indices: The indices used to index into the list. +values: The tensor. + }]; + + let arguments = (ins + TF_VariantTensor:$input_handle, + I32Tensor:$indices, + I32Tensor:$element_shape + ); + + let results = (outs + TF_Tensor:$values + ); + + TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> { let summary = ""; diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir index 9e43cea1003..682da38fc56 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir @@ -94,7 +94,7 @@ func @main(%arg0: tensor) -> (tensor, tensor<10xf32>, tensor) { func @main(%arg0: tensor, %arg1: tensor<10xf32>) -> tensor { // CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>} %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> - // CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%arg1) : (tensor<10xf32>) -> tensor<10xf32> + // CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG1]]) : (tensor<10xf32>) -> tensor<10xf32> // CHECK-NEXT: %[[SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %tl = "tf.TensorListFromTensor"(%arg1, %elem_shape) : (tensor<10xf32>, tensor<0xi32>) -> tensor>> // CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} @@ -110,6 +110,37 @@ func @main(%arg0: tensor, %arg1: tensor<10xf32>) -> tensor { // ----- +// Test tensor list element shape op. + +// CHECK-LABEL: func @main +func @main(%arg0: tensor<10x8x9xf32>) -> tensor<2xi64> { + %elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32> + %tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor>> + // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi64>} : () -> tensor<2xi64> + %shape = "tf.TensorListElementShape"(%tl) : (tensor>>) -> tensor<2xi64> + // CHECK-NEXT: return %[[SHAPE]] : tensor<2xi64> + return %shape: tensor<2xi64> +} + +// ----- + +// Test tensor list gather op. + +// CHECK-LABEL: func @main +// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x8x9xf32>, %[[ARG1:.*]]: tensor<3xi32>) -> tensor<3x8x9xf32> +func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<3xi32>) -> tensor<3x8x9xf32> { + %elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG0]]) : (tensor<10x8x9xf32>) -> tensor<10x8x9xf32> + %tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor>> + // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[BUFFER]], %[[ARG1]], %[[AXIS]]) : (tensor<10x8x9xf32>, tensor<3xi32>, tensor) -> tensor<3x8x9xf32> + %gather = "tf.TensorListGather"(%tl, %arg1, %elem_shape) : (tensor>>, tensor<3xi32>, tensor<2xi32>) -> tensor<3x8x9xf32> + // CHECK-NEXT: return %[[GATHER]] : tensor<3x8x9xf32> + return %gather: tensor<3x8x9xf32> +} + +// ----- + // Tests while loop. // CHECK-LABEL: func @main diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 3a645fe9bf6..d9715d11922 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -53,14 +53,16 @@ Value CreateScalarConst(int value, OpBuilder builder, Location loc) { loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie()); } -Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc) { - tensorflow::Tensor shape_tensor(tensorflow::DT_INT32, - {static_cast(r1.size())}); - for (int i = 0; i < r1.size(); ++i) { - shape_tensor.vec()(i) = r1[i]; - } +Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc, + int bitwidth) { + llvm::SmallVector values; + int64_t rank = r1.size(); + values.reserve(rank); + for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i])); + auto result_type = RankedTensorType::get( + {rank}, IntegerType::get(bitwidth, builder.getContext())); return builder.create( - loc, tensorflow::ConvertTensor(shape_tensor, &builder).ValueOrDie()); + loc, DenseElementsAttr::get(result_type, values)); } Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h index f0e0eda3b5d..423797279d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h @@ -36,8 +36,9 @@ namespace collection_ops_util { // Creates an i32 scalar tf.Const. Value CreateScalarConst(int value, OpBuilder builder, Location loc); -// Creates an i32 vector tf.Const. -Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc); +// Creates an integer vector tf.Const. +Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc, + int bitwidth = 32); // Returns the type of the size tensor used to track a data structure's element // count. It is a tensor<1xi32>, and we use R1 instead of a scalar because it is diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 277146d5c42..d8ae4fb534a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -574,6 +574,37 @@ LogicalResult HandleTensorListLengthOp( return success(); } +LogicalResult HandleTensorListElementShapeOp( + TF::TensorListElementShapeOp elem_shape, + const llvm::SmallDenseMap& buffer_to_size) { + if (buffer_to_size.count(elem_shape.input_handle()) == 0) { + return elem_shape.emitOpError("unknown tensor list"); + } + auto buffer = elem_shape.input_handle(); + auto result = cutil::GetR1Const( + buffer.getType().cast().getShape().drop_front(), + OpBuilder(elem_shape), elem_shape.getLoc(), + elem_shape.shape_type().getIntOrFloatBitWidth()); + elem_shape.element_shape().replaceAllUsesWith(result); + elem_shape.erase(); + return success(); +} + +LogicalResult HandleTensorListGatherOp( + TF::TensorListGatherOp gather, + const llvm::SmallDenseMap& buffer_to_size) { + auto it = buffer_to_size.find(gather.input_handle()); + if (it == buffer_to_size.end()) { + return gather.emitOpError("unknown tensor list"); + } + auto buffer = gather.input_handle(); + auto result = cutil::GatherElements(gather.indices(), buffer, + OpBuilder(gather), gather.getLoc()); + gather.values().replaceAllUsesWith(result); + gather.erase(); + return success(); +} + LogicalResult DecomposeTensorListOpsInternal( Block* block, ModuleOp module, llvm::SmallDenseMap* buffer_to_size, @@ -619,6 +650,15 @@ LogicalResult DecomposeTensorListOpsInternal( } else if (auto stack = llvm::dyn_cast(&op)) { stack.tensor().replaceAllUsesWith(stack.input_handle()); stack.erase(); + } else if (auto elem_shape = + llvm::dyn_cast(&op)) { + if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) { + return failure(); + } + } else if (auto gather = llvm::dyn_cast(&op)) { + if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) { + return failure(); + } } else if (auto addn = llvm::dyn_cast(&op)) { auto it = buffer_to_size->find(addn.getOperand(0)); if (it != buffer_to_size->end()) {