From 44506a0e68d7037cabdf9f82466677623fb513ec Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Thu, 5 Mar 2020 16:03:04 -0800 Subject: [PATCH] [MLIR:TF/XLA] Constant fold VariableShape and TensorListElementShape PiperOrigin-RevId: 299217733 Change-Id: I8d567592270f98f3348935599656e7e640401015 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 23 +++++++++++++++++++ .../compiler/mlir/tensorflow/ir/tf_ops.cc | 22 ++++++++++++++++++ .../mlir/tensorflow/tests/constant-fold.mlir | 18 +++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 66cc2ded950..2615128d60e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -7054,6 +7054,27 @@ lengths: Output tensor containing sizes of the 0th dimension of tensors in the l TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>; } +def TF_TensorListElementShapeOp : TF_Op<"TensorListElementShape", [NoSideEffect]> { + let summary = "The shape of the elements of the given list, as a tensor."; + + let description = [{ +input_handle: the list + element_shape: the shape of elements of the list + }]; + + let arguments = (ins + TF_VariantTensor:$input_handle + ); + + let results = (outs + TF_I32OrI64Tensor:$element_shape + ); + + TF_DerivedResultTypeAttr shape_type = TF_DerivedResultTypeAttr<0>; + + let hasFolder = 1; +} + def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> { let summary = [{ Creates a TensorList which, when stacked, has the value of `tensor`. @@ -7822,6 +7843,8 @@ shape(t) ==> [2, 2, 3] let verifier = [{ return Verify(*this); }]; + + let hasFolder = 1; } def TF_WhereOp : TF_Op<"Where", [NoSideEffect]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index d9c93d16bd4..4d172b9cb04 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -2907,6 +2907,19 @@ static LogicalResult Verify(TensorListReserveOp op) { return success(); } +//===----------------------------------------------------------------------===// +// TensorListElementShapeOp +//===----------------------------------------------------------------------===// + +OpFoldResult TensorListElementShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + auto variant_type = + getElementTypeOrSelf(getOperand().getType()).cast(); + if (variant_type.getSubtypes().empty()) return {}; + return ConvertShapeToAttr(variant_type.getSubtypes()[0], width); +} + //===----------------------------------------------------------------------===// // TensorListStackOp //===----------------------------------------------------------------------===// @@ -3177,6 +3190,15 @@ static LogicalResult Verify(VariableShapeOp op) { } } +OpFoldResult VariableShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + auto resource_type = + getElementTypeOrSelf(getOperand().getType()).cast(); + if (resource_type.getSubtypes().empty()) return {}; + return ConvertShapeToAttr(resource_type.getSubtypes()[0], width); +} + //===----------------------------------------------------------------------===// // WhileOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index d9727e94bb6..411599053e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -213,3 +213,21 @@ func @testRemoteDevice() -> tensor<2x2xi32> { // CHECK-NEXT: return [[cst]] : tensor<2x2xi32> return %2: tensor<2x2xi32> } + +// Tests ops that variable shapes are correctly evaluated on static types. +// CHECK-LABEL: func @testVariableShape +func @testVariableShape(%arg0: tensor>>) -> tensor<2xi32> { + %0 = "tf.VariableShape"(%arg0) : (tensor>>) -> tensor<2xi32> + // CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[}}2, 4]> : tensor<2xi32> + // CHECK-NEXT: return [[cst]] : tensor<2xi32> + return %0: tensor<2xi32> +} + +// Tests ops that tensor list shapes are correctly evaluated on static types. +// CHECK-LABEL: func @testTensorListElementShape +func @testTensorListElementShape(%arg0: tensor>>) -> tensor<2xi32> { + %0 = "tf.TensorListElementShape"(%arg0) : (tensor>>) -> tensor<2xi32> + // CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[}}2, 4]> : tensor<2xi32> + // CHECK-NEXT: return [[cst]] : tensor<2xi32> + return %0: tensor<2xi32> +}