[MLIR:TF/XLA] Constant fold VariableShape and TensorListElementShape

PiperOrigin-RevId: 299217733
Change-Id: I8d567592270f98f3348935599656e7e640401015
This commit is contained in:
Yuanzhong Xu 2020-03-05 16:03:04 -08:00 committed by TensorFlower Gardener
parent 4569e70aa2
commit 44506a0e68
3 changed files with 63 additions and 0 deletions

View File

@ -7054,6 +7054,27 @@ lengths: Output tensor containing sizes of the 0th dimension of tensors in the l
TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_TensorListElementShapeOp : TF_Op<"TensorListElementShape", [NoSideEffect]> {
let summary = "The shape of the elements of the given list, as a tensor.";
let description = [{
input_handle: the list
element_shape: the shape of elements of the list
}];
let arguments = (ins
TF_VariantTensor:$input_handle
);
let results = (outs
TF_I32OrI64Tensor:$element_shape
);
TF_DerivedResultTypeAttr shape_type = TF_DerivedResultTypeAttr<0>;
let hasFolder = 1;
}
def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> {
let summary = [{
Creates a TensorList which, when stacked, has the value of `tensor`.
@ -7822,6 +7843,8 @@ shape(t) ==> [2, 2, 3]
let verifier = [{
return Verify(*this);
}];
let hasFolder = 1;
}
def TF_WhereOp : TF_Op<"Where", [NoSideEffect]> {

View File

@ -2907,6 +2907,19 @@ static LogicalResult Verify(TensorListReserveOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// TensorListElementShapeOp
//===----------------------------------------------------------------------===//
OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
auto variant_type =
getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
if (variant_type.getSubtypes().empty()) return {};
return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
}
//===----------------------------------------------------------------------===//
// TensorListStackOp
//===----------------------------------------------------------------------===//
@ -3177,6 +3190,15 @@ static LogicalResult Verify(VariableShapeOp op) {
}
}
OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
auto resource_type =
getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
if (resource_type.getSubtypes().empty()) return {};
return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
}
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//

View File

@ -213,3 +213,21 @@ func @testRemoteDevice() -> tensor<2x2xi32> {
// CHECK-NEXT: return [[cst]] : tensor<2x2xi32>
return %2: tensor<2x2xi32>
}
// Tests ops that variable shapes are correctly evaluated on static types.
// CHECK-LABEL: func @testVariableShape
func @testVariableShape(%arg0: tensor<!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32> {
%0 = "tf.VariableShape"(%arg0) : (tensor<!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32>
// CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[}}2, 4]> : tensor<2xi32>
// CHECK-NEXT: return [[cst]] : tensor<2xi32>
return %0: tensor<2xi32>
}
// Tests ops that tensor list shapes are correctly evaluated on static types.
// CHECK-LABEL: func @testTensorListElementShape
func @testTensorListElementShape(%arg0: tensor<!tf.variant<tensor<2x4xf32>>>) -> tensor<2xi32> {
%0 = "tf.TensorListElementShape"(%arg0) : (tensor<!tf.variant<tensor<2x4xf32>>>) -> tensor<2xi32>
// CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[}}2, 4]> : tensor<2xi32>
// CHECK-NEXT: return [[cst]] : tensor<2xi32>
return %0: tensor<2xi32>
}