Infer TensorArray element shape from use in TensorArrayGatherV3
op.
PiperOrigin-RevId: 342991093 Change-Id: I4679d9fc0855e923961ecefac1d098cb2dc4208a
This commit is contained in:
parent
4fe8e6e882
commit
fa933c373a
@ -68,6 +68,35 @@ func @main() -> tensor<i32> {
|
||||
return %size_out : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test inferring shape from the result type of gather.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> tensor<2x3xf32> {
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) : (tensor<!tf.resource>, tensor<2xi32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
return %gather : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test inferring shape from the element_shape attribute of gather.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> tensor<*xf32> {
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = #tf.shape<3>} : (tensor<!tf.resource>, tensor<2xi32>, tensor<f32>) -> tensor<*xf32>
|
||||
return %gather : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Test tensor array concat and split.
|
||||
|
@ -108,7 +108,7 @@ llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
|
||||
auto element_shape = ta.element_shapeAttr().cast<mlir::TF::ShapeAttr>();
|
||||
if (element_shape.hasStaticShape()) {
|
||||
auto shape = element_shape.getShape();
|
||||
// Convert int64 to int64_.
|
||||
// Convert int64 to int64_t.
|
||||
llvm::SmallVector<int64_t, 8> dims(shape.begin(), shape.end());
|
||||
return dims;
|
||||
}
|
||||
@ -141,6 +141,20 @@ llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
|
||||
if (!t || t.getShape().empty()) return llvm::None;
|
||||
return RankedTensorType::get(t.getShape().drop_front(),
|
||||
t.getElementType());
|
||||
} else if (auto gather =
|
||||
llvm::dyn_cast<TF::TensorArrayGatherV3Op>(user)) {
|
||||
// Try to infer from result type of gather.
|
||||
auto t = gather.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (t && !t.getShape().empty())
|
||||
return RankedTensorType::get(t.getShape().drop_front(),
|
||||
t.getElementType());
|
||||
// Try to infer from `element_shape` attribute of gather.
|
||||
auto element_shape = gather.element_shapeAttr()
|
||||
.dyn_cast_or_null<mlir::TF::ShapeAttr>();
|
||||
if (element_shape && element_shape.hasStaticShape()) {
|
||||
return RankedTensorType::get(element_shape.getShape(),
|
||||
gather.dtype());
|
||||
}
|
||||
}
|
||||
return llvm::None;
|
||||
});
|
||||
|
Loading…
x
Reference in New Issue
Block a user