[MLIR:TF/XLA] Constant fold VariableShape and TensorListElementShape
PiperOrigin-RevId: 299217733 Change-Id: I8d567592270f98f3348935599656e7e640401015
This commit is contained in:
parent
4569e70aa2
commit
44506a0e68
@ -7054,6 +7054,27 @@ lengths: Output tensor containing sizes of the 0th dimension of tensors in the l
|
||||
TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TensorListElementShapeOp : TF_Op<"TensorListElementShape", [NoSideEffect]> {
|
||||
let summary = "The shape of the elements of the given list, as a tensor.";
|
||||
|
||||
let description = [{
|
||||
input_handle: the list
|
||||
element_shape: the shape of elements of the list
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_handle
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_I32OrI64Tensor:$element_shape
|
||||
);
|
||||
|
||||
TF_DerivedResultTypeAttr shape_type = TF_DerivedResultTypeAttr<0>;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Creates a TensorList which, when stacked, has the value of `tensor`.
|
||||
@ -7822,6 +7843,8 @@ shape(t) ==> [2, 2, 3]
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_WhereOp : TF_Op<"Where", [NoSideEffect]> {
|
||||
|
@ -2907,6 +2907,19 @@ static LogicalResult Verify(TensorListReserveOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorListElementShapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult TensorListElementShapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
int width =
|
||||
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
|
||||
auto variant_type =
|
||||
getElementTypeOrSelf(getOperand().getType()).cast<TF::VariantType>();
|
||||
if (variant_type.getSubtypes().empty()) return {};
|
||||
return ConvertShapeToAttr(variant_type.getSubtypes()[0], width);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorListStackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -3177,6 +3190,15 @@ static LogicalResult Verify(VariableShapeOp op) {
|
||||
}
|
||||
}
|
||||
|
||||
OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
int width =
|
||||
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
|
||||
auto resource_type =
|
||||
getElementTypeOrSelf(getOperand().getType()).cast<TF::ResourceType>();
|
||||
if (resource_type.getSubtypes().empty()) return {};
|
||||
return ConvertShapeToAttr(resource_type.getSubtypes()[0], width);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WhileOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -213,3 +213,21 @@ func @testRemoteDevice() -> tensor<2x2xi32> {
|
||||
// CHECK-NEXT: return [[cst]] : tensor<2x2xi32>
|
||||
return %2: tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// Tests ops that variable shapes are correctly evaluated on static types.
|
||||
// CHECK-LABEL: func @testVariableShape
|
||||
func @testVariableShape(%arg0: tensor<!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32> {
|
||||
%0 = "tf.VariableShape"(%arg0) : (tensor<!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32>
|
||||
// CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[}}2, 4]> : tensor<2xi32>
|
||||
// CHECK-NEXT: return [[cst]] : tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
// Tests ops that tensor list shapes are correctly evaluated on static types.
|
||||
// CHECK-LABEL: func @testTensorListElementShape
|
||||
func @testTensorListElementShape(%arg0: tensor<!tf.variant<tensor<2x4xf32>>>) -> tensor<2xi32> {
|
||||
%0 = "tf.TensorListElementShape"(%arg0) : (tensor<!tf.variant<tensor<2x4xf32>>>) -> tensor<2xi32>
|
||||
// CHECK: [[cst:%.*]] = "tf.Const{{.*}} dense<{{\[}}2, 4]> : tensor<2xi32>
|
||||
// CHECK-NEXT: return [[cst]] : tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user