Infer TensorArray element shape from use in TensorArrayGatherV3 op.

PiperOrigin-RevId: 342991093
Change-Id: I4679d9fc0855e923961ecefac1d098cb2dc4208a
This commit is contained in:
Prakalp Srivastava 2020-11-17 18:43:29 -08:00 committed by TensorFlower Gardener
parent 4fe8e6e882
commit fa933c373a
2 changed files with 44 additions and 1 deletions

View File

@ -68,6 +68,35 @@ func @main() -> tensor<i32> {
return %size_out : tensor<i32>
}
// -----
// Test inferring shape from the result type of gather.
// CHECK-LABEL: func @main
func @main() -> tensor<2x3xf32> {
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) : (tensor<!tf.resource>, tensor<2xi32>, tensor<f32>) -> tensor<2x3xf32>
return %gather : tensor<2x3xf32>
}
// -----
// Test inferring shape from the element_shape attribute of gather.
// CHECK-LABEL: func @main
func @main() -> tensor<*xf32> {
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = #tf.shape<3>} : (tensor<!tf.resource>, tensor<2xi32>, tensor<f32>) -> tensor<*xf32>
return %gather : tensor<*xf32>
}
// -----
// Test tensor array concat and split.

View File

@ -108,7 +108,7 @@ llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
auto element_shape = ta.element_shapeAttr().cast<mlir::TF::ShapeAttr>();
if (element_shape.hasStaticShape()) {
auto shape = element_shape.getShape();
// Convert int64 to int64_.
// Convert int64 to int64_t.
llvm::SmallVector<int64_t, 8> dims(shape.begin(), shape.end());
return dims;
}
@ -141,6 +141,20 @@ llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
if (!t || t.getShape().empty()) return llvm::None;
return RankedTensorType::get(t.getShape().drop_front(),
t.getElementType());
} else if (auto gather =
llvm::dyn_cast<TF::TensorArrayGatherV3Op>(user)) {
// Try to infer from result type of gather.
auto t = gather.value().getType().dyn_cast<RankedTensorType>();
if (t && !t.getShape().empty())
return RankedTensorType::get(t.getShape().drop_front(),
t.getElementType());
// Try to infer from `element_shape` attribute of gather.
auto element_shape = gather.element_shapeAttr()
.dyn_cast_or_null<mlir::TF::ShapeAttr>();
if (element_shape && element_shape.hasStaticShape()) {
return RankedTensorType::get(element_shape.getShape(),
gather.dtype());
}
}
return llvm::None;
});