[MLIR:TF/XLA] Handle remaining tensor list ops in decomposition

PiperOrigin-RevId: 305066660
Change-Id: Ib11fd1af6350eaf80bf87d30c22efef3b55f2e81
This commit is contained in:
Yuanzhong Xu 2020-04-06 10:49:59 -07:00 committed by TensorFlower Gardener
parent 1e8fffdfb1
commit 111b0c8401
5 changed files with 109 additions and 10 deletions

View File

@ -8437,6 +8437,31 @@ output_handle: The list.
TF_DerivedOperandTypeAttr element_dtype = TF_DerivedOperandTypeAttr<0>;
}
def TF_TensorListGatherOp : TF_Op<"TensorListGather", [NoSideEffect]> {
let summary = "Creates a Tensor by indexing into the TensorList.";
let description = [{
Each row in the produced Tensor corresponds to the element in the TensorList
specified by the given index (see `tf.gather`).
input_handle: The input tensor list.
indices: The indices used to index into the list.
values: The tensor.
}];
let arguments = (ins
TF_VariantTensor:$input_handle,
I32Tensor:$indices,
I32Tensor:$element_shape
);
let results = (outs
TF_Tensor:$values
);
TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> {
let summary = "";

View File

@ -94,7 +94,7 @@ func @main(%arg0: tensor<i32>) -> (tensor<f32>, tensor<10xf32>, tensor<i32>) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32>) -> tensor<f32> {
// CHECK-NEXT: "tf.Const"() {value = dense<[]> : tensor<0xi32>}
%elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
// CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%arg1) : (tensor<10xf32>) -> tensor<10xf32>
// CHECK-NEXT: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG1]]) : (tensor<10xf32>) -> tensor<10xf32>
// CHECK-NEXT: %[[SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
%tl = "tf.TensorListFromTensor"(%arg1, %elem_shape) : (tensor<10xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
// CHECK-NEXT: %[[SIZE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
@ -110,6 +110,37 @@ func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32>) -> tensor<f32> {
// -----
// Test tensor list element shape op.
// CHECK-LABEL: func @main
func @main(%arg0: tensor<10x8x9xf32>) -> tensor<2xi64> {
%elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32>
%tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf.variant<tensor<8x9xf32>>>
// CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi64>} : () -> tensor<2xi64>
%shape = "tf.TensorListElementShape"(%tl) : (tensor<!tf.variant<tensor<8x9xf32>>>) -> tensor<2xi64>
// CHECK-NEXT: return %[[SHAPE]] : tensor<2xi64>
return %shape: tensor<2xi64>
}
// -----
// Test tensor list gather op.
// CHECK-LABEL: func @main
// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x8x9xf32>, %[[ARG1:.*]]: tensor<3xi32>) -> tensor<3x8x9xf32>
func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<3xi32>) -> tensor<3x8x9xf32> {
%elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG0]]) : (tensor<10x8x9xf32>) -> tensor<10x8x9xf32>
%tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf.variant<tensor<8x9xf32>>>
// CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[BUFFER]], %[[ARG1]], %[[AXIS]]) : (tensor<10x8x9xf32>, tensor<3xi32>, tensor<i32>) -> tensor<3x8x9xf32>
%gather = "tf.TensorListGather"(%tl, %arg1, %elem_shape) : (tensor<!tf.variant<tensor<8x9xf32>>>, tensor<3xi32>, tensor<2xi32>) -> tensor<3x8x9xf32>
// CHECK-NEXT: return %[[GATHER]] : tensor<3x8x9xf32>
return %gather: tensor<3x8x9xf32>
}
// -----
// Tests while loop.
// CHECK-LABEL: func @main

View File

@ -53,14 +53,16 @@ Value CreateScalarConst(int value, OpBuilder builder, Location loc) {
loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
}
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc) {
tensorflow::Tensor shape_tensor(tensorflow::DT_INT32,
{static_cast<int64_t>(r1.size())});
for (int i = 0; i < r1.size(); ++i) {
shape_tensor.vec<tensorflow::int32>()(i) = r1[i];
}
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
int bitwidth) {
llvm::SmallVector<APInt, 4> values;
int64_t rank = r1.size();
values.reserve(rank);
for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i]));
auto result_type = RankedTensorType::get(
{rank}, IntegerType::get(bitwidth, builder.getContext()));
return builder.create<TF::ConstOp>(
loc, tensorflow::ConvertTensor(shape_tensor, &builder).ValueOrDie());
loc, DenseElementsAttr::get(result_type, values));
}
Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder,

View File

@ -36,8 +36,9 @@ namespace collection_ops_util {
// Creates an i32 scalar tf.Const.
Value CreateScalarConst(int value, OpBuilder builder, Location loc);
// Creates an i32 vector tf.Const.
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc);
// Creates an integer vector tf.Const.
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
int bitwidth = 32);
// Returns the type of the size tensor used to track a data structure's element
// count. It is a tensor<1xi32>, and we use R1 instead of a scalar because it is

View File

@ -574,6 +574,37 @@ LogicalResult HandleTensorListLengthOp(
return success();
}
LogicalResult HandleTensorListElementShapeOp(
TF::TensorListElementShapeOp elem_shape,
const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
if (buffer_to_size.count(elem_shape.input_handle()) == 0) {
return elem_shape.emitOpError("unknown tensor list");
}
auto buffer = elem_shape.input_handle();
auto result = cutil::GetR1Const(
buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
OpBuilder(elem_shape), elem_shape.getLoc(),
elem_shape.shape_type().getIntOrFloatBitWidth());
elem_shape.element_shape().replaceAllUsesWith(result);
elem_shape.erase();
return success();
}
LogicalResult HandleTensorListGatherOp(
TF::TensorListGatherOp gather,
const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
auto it = buffer_to_size.find(gather.input_handle());
if (it == buffer_to_size.end()) {
return gather.emitOpError("unknown tensor list");
}
auto buffer = gather.input_handle();
auto result = cutil::GatherElements(gather.indices(), buffer,
OpBuilder(gather), gather.getLoc());
gather.values().replaceAllUsesWith(result);
gather.erase();
return success();
}
LogicalResult DecomposeTensorListOpsInternal(
Block* block, ModuleOp module,
llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
@ -619,6 +650,15 @@ LogicalResult DecomposeTensorListOpsInternal(
} else if (auto stack = llvm::dyn_cast<TF::TensorListStackOp>(&op)) {
stack.tensor().replaceAllUsesWith(stack.input_handle());
stack.erase();
} else if (auto elem_shape =
llvm::dyn_cast<TF::TensorListElementShapeOp>(&op)) {
if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) {
return failure();
}
} else if (auto gather = llvm::dyn_cast<TF::TensorListGatherOp>(&op)) {
if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) {
return failure();
}
} else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
auto it = buffer_to_size->find(addn.getOperand(0));
if (it != buffer_to_size->end()) {