[MLIR:TF/XLA] Handle remaining tensor list ops in decomposition
PiperOrigin-RevId: 305066660 Change-Id: Ib11fd1af6350eaf80bf87d30c22efef3b55f2e81
This commit is contained in:
parent
1e8fffdfb1
commit
111b0c8401
@ -8437,6 +8437,31 @@ output_handle: The list.
|
||||
TF_DerivedOperandTypeAttr element_dtype = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TensorListGatherOp : TF_Op<"TensorListGather", [NoSideEffect]> {
|
||||
let summary = "Creates a Tensor by indexing into the TensorList.";
|
||||
|
||||
let description = [{
|
||||
Each row in the produced Tensor corresponds to the element in the TensorList
|
||||
specified by the given index (see `tf.gather`).
|
||||
|
||||
input_handle: The input tensor list.
|
||||
indices: The indices used to index into the list.
|
||||
values: The tensor.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_handle,
|
||||
I32Tensor:$indices,
|
||||
I32Tensor:$element_shape
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$values
|
||||
);
|
||||
|
||||
TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> {
|
||||
let summary = "";
|
||||
|
||||
|
@ -94,7 +94,7 @@ func @main(%arg0: tensor<i32>) -> (tensor<f32>, tensor<10xf32>, tensor<i32>) {
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32>) -> tensor<f32> {
|
||||
// CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>}
|
||||
%elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
// CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%arg1) : (tensor<10xf32>) -> tensor<10xf32>
|
||||
// CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG1]]) : (tensor<10xf32>) -> tensor<10xf32>
|
||||
// CHECK-NEXT: %[[SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%tl = "tf.TensorListFromTensor"(%arg1, %elem_shape) : (tensor<10xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
// CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
|
||||
@ -110,6 +110,37 @@ func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32>) -> tensor<f32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Test tensor list element shape op.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main(%arg0: tensor<10x8x9xf32>) -> tensor<2xi64> {
|
||||
%elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf.variant<tensor<8x9xf32>>>
|
||||
// CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
%shape = "tf.TensorListElementShape"(%tl) : (tensor<!tf.variant<tensor<8x9xf32>>>) -> tensor<2xi64>
|
||||
// CHECK-NEXT: return %[[SHAPE]] : tensor<2xi64>
|
||||
return %shape: tensor<2xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test tensor list gather op.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x8x9xf32>, %[[ARG1:.*]]: tensor<3xi32>) -> tensor<3x8x9xf32>
|
||||
func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<3xi32>) -> tensor<3x8x9xf32> {
|
||||
%elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG0]]) : (tensor<10x8x9xf32>) -> tensor<10x8x9xf32>
|
||||
%tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf.variant<tensor<8x9xf32>>>
|
||||
// CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[BUFFER]], %[[ARG1]], %[[AXIS]]) : (tensor<10x8x9xf32>, tensor<3xi32>, tensor<i32>) -> tensor<3x8x9xf32>
|
||||
%gather = "tf.TensorListGather"(%tl, %arg1, %elem_shape) : (tensor<!tf.variant<tensor<8x9xf32>>>, tensor<3xi32>, tensor<2xi32>) -> tensor<3x8x9xf32>
|
||||
// CHECK-NEXT: return %[[GATHER]] : tensor<3x8x9xf32>
|
||||
return %gather: tensor<3x8x9xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests while loop.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
|
@ -53,14 +53,16 @@ Value CreateScalarConst(int value, OpBuilder builder, Location loc) {
|
||||
loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
|
||||
}
|
||||
|
||||
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc) {
|
||||
tensorflow::Tensor shape_tensor(tensorflow::DT_INT32,
|
||||
{static_cast<int64_t>(r1.size())});
|
||||
for (int i = 0; i < r1.size(); ++i) {
|
||||
shape_tensor.vec<tensorflow::int32>()(i) = r1[i];
|
||||
}
|
||||
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
|
||||
int bitwidth) {
|
||||
llvm::SmallVector<APInt, 4> values;
|
||||
int64_t rank = r1.size();
|
||||
values.reserve(rank);
|
||||
for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i]));
|
||||
auto result_type = RankedTensorType::get(
|
||||
{rank}, IntegerType::get(bitwidth, builder.getContext()));
|
||||
return builder.create<TF::ConstOp>(
|
||||
loc, tensorflow::ConvertTensor(shape_tensor, &builder).ValueOrDie());
|
||||
loc, DenseElementsAttr::get(result_type, values));
|
||||
}
|
||||
|
||||
Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder,
|
||||
|
@ -36,8 +36,9 @@ namespace collection_ops_util {
|
||||
// Creates an i32 scalar tf.Const.
|
||||
Value CreateScalarConst(int value, OpBuilder builder, Location loc);
|
||||
|
||||
// Creates an i32 vector tf.Const.
|
||||
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc);
|
||||
// Creates an integer vector tf.Const.
|
||||
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
|
||||
int bitwidth = 32);
|
||||
|
||||
// Returns the type of the size tensor used to track a data structure's element
|
||||
// count. It is a tensor<1xi32>, and we use R1 instead of a scalar because it is
|
||||
|
@ -574,6 +574,37 @@ LogicalResult HandleTensorListLengthOp(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult HandleTensorListElementShapeOp(
|
||||
TF::TensorListElementShapeOp elem_shape,
|
||||
const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
|
||||
if (buffer_to_size.count(elem_shape.input_handle()) == 0) {
|
||||
return elem_shape.emitOpError("unknown tensor list");
|
||||
}
|
||||
auto buffer = elem_shape.input_handle();
|
||||
auto result = cutil::GetR1Const(
|
||||
buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
|
||||
OpBuilder(elem_shape), elem_shape.getLoc(),
|
||||
elem_shape.shape_type().getIntOrFloatBitWidth());
|
||||
elem_shape.element_shape().replaceAllUsesWith(result);
|
||||
elem_shape.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult HandleTensorListGatherOp(
|
||||
TF::TensorListGatherOp gather,
|
||||
const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
|
||||
auto it = buffer_to_size.find(gather.input_handle());
|
||||
if (it == buffer_to_size.end()) {
|
||||
return gather.emitOpError("unknown tensor list");
|
||||
}
|
||||
auto buffer = gather.input_handle();
|
||||
auto result = cutil::GatherElements(gather.indices(), buffer,
|
||||
OpBuilder(gather), gather.getLoc());
|
||||
gather.values().replaceAllUsesWith(result);
|
||||
gather.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult DecomposeTensorListOpsInternal(
|
||||
Block* block, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
|
||||
@ -619,6 +650,15 @@ LogicalResult DecomposeTensorListOpsInternal(
|
||||
} else if (auto stack = llvm::dyn_cast<TF::TensorListStackOp>(&op)) {
|
||||
stack.tensor().replaceAllUsesWith(stack.input_handle());
|
||||
stack.erase();
|
||||
} else if (auto elem_shape =
|
||||
llvm::dyn_cast<TF::TensorListElementShapeOp>(&op)) {
|
||||
if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto gather = llvm::dyn_cast<TF::TensorListGatherOp>(&op)) {
|
||||
if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
|
||||
auto it = buffer_to_size->find(addn.getOperand(0));
|
||||
if (it != buffer_to_size->end()) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user