[MLIR:TF/XLA] Support control flow for tensor array decomposition
Gradients are accessed via source names, and a gradient can be referenced inside nested control flow such as a while loop. We first recursively analyze all required gradients in nested functions, and then lift their creation outside.

PiperOrigin-RevId: 303423766
Change-Id: Icae790de5f4f7a1d8cdec37cac303bf51f2c3306
parent 353ab1535d · commit 0c6b402cac
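In essence, the change makes gradient tensor arrays that are created inside control-flow bodies (and identified only by their `source` string) get their backing variable created in the caller and threaded through as extra operands. A minimal before/after sketch, simplified from the first test below (types and attributes abbreviated with `...`, so this is illustrative rather than verbatim pass output):

    // Before: gradient created inside the loop body, found by source name.
    %ta:2 = "tf.TensorArrayV3"(%size) {...} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
    %w:2 = "tf.While"(%ta#0, %size) {body = @body, cond = @cond, ...} : ...
    // inside @body: %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : ...

    // After: the array and its "a" gradient are local variables hoisted to the
    // caller; the gradient variable rides along as an extra while operand.
    %var = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
    %gvar = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
    %w:3 = "tf.While"(%var, %size, %gvar) {body = @body, cond = @cond, ...} : ...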
@@ -187,3 +187,199 @@ func @main() {
   %write3 = "tf.TensorArrayWriteV3"(%grad3#0, %index, %value, %grad3#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
   return
 }
+
+// -----
+
+// Tests while loop with access to the tensor array defined outside and its
+// gradient defined inside. The gradient creation should be moved outside.
+
+// CHECK-LABEL: func @main
+func @main() -> () {
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  // CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]])
+  %1:2 = "tf.While"(%ta#0, %size) {
+    body = @while_body, cond = @while_cond, device = "", is_stateless = false}
+    : (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
+  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: "tf.Slice"(%[[READ]],
+  %read = "tf.TensorArrayReadV3"(%1#0, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
+  return
+}
+// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
+  %sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
+  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[BARG0]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
+  // CHECK: "tf.AssignVariableOp"(%[[BARG0]], %[[UPDATE1]])
+  %write = "tf.TensorArrayWriteV3"(%arg0, %sub, %elem, %flow) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[BARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
+  // CHECK: "tf.AssignVariableOp"(%[[BARG2]], %[[UPDATE2]])
+  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
+  return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
+}
+// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
+  // CHECK-NEXT: return %[[CARG1]]
+  return %arg1 : tensor<i32>
+}
+
+// -----
+
+// Tests If op with access to the tensor array defined outside and its gradient
+// defined inside. The gradient creation should be moved outside.
+
+// CHECK-LABEL: func @main
+func @main() -> () {
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
+  %cond = "tf._SomeOp"() : () -> tensor<i1>
+  // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  // CHECK: "tf.If"(%[[COND]], %[[VAR]], %[[GVAR1]], %[[GVAR2]])
+  %1 = "tf.If"(%cond, %ta#0) {
+    then_branch = @then_branch, else_branch = @else_branch, device = "", is_stateless = false}
+    : (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
+  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: "tf.Slice"(%[[READ]],
+  %read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
+  return
+}
+// CHECK: func @then_branch(%[[TARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+func @then_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
+  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[TARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
+  // CHECK: "tf.AssignVariableOp"(%[[TARG1]], %[[UPDATE1]])
+  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[TARG0]]
+  return %arg0 : tensor<!tf.resource>
+}
+// CHECK: func @else_branch(%[[EARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+func @else_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
+  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[EARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
+  // CHECK: "tf.AssignVariableOp"(%[[EARG2]], %[[UPDATE2]])
+  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[EARG0]]
+  return %arg0 : tensor<!tf.resource>
+}
+
+// -----
+
+// Tests (Stateful)PartitionedCall op with access to the tensor array defined
+// outside and its gradient defined inside. The gradient creation should be
+// moved outside.
+
+// CHECK-LABEL: func @main
+func @main() -> () {
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
+  %cond = "tf._SomeOp"() : () -> tensor<i1>
+  // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  %grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  // CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
+  // CHECK-SAME: f = @callee_tensorarray_decomposed
+  %call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""}
+    : (tensor<!tf.resource>) -> tensor<!tf.resource>
+  // CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
+  // CHECK-SAME: f = @callee_tensorarray_decomposed
+  %call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""}
+    : (tensor<!tf.resource>) -> tensor<!tf.resource>
+  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: "tf.Slice"(%[[READ]],
+  %read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
+  return
+}
+// CHECK-LABEL: func @callee
+// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
+func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
+  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  %grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  return %arg0 : tensor<!tf.resource>
+}
+// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
+// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
+// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
+// CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]])
+// CHECK: return %[[CARG0]]
+
+// -----
+
+// Test the pass reports failure on unknown size.
+
+func @main(%arg0: tensor<i32>) -> () {
+  // expected-error @+1 {{unknown max element count}}
+  %ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  return
+}
+
+// -----
+
+// Test the pass reports failure on unknown shape.
+
+func @main(%arg0: tensor<i32>) -> () {
+  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  // expected-error @+1 {{unknown element shape}}
+  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  return
+}
+
+// -----
+
+// Tests that the pass reports error on ambiguous tensor array.
+
+func @main(%arg0: tensor<i1>) -> () {
+  %size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
+  %ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
+    : (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
+  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // expected-error @+1 {{unknown tensor array}}
+  %read = "tf.TensorArrayReadV3"(%if_op, %index, %ta0#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
+  return
+}
+func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  return %arg0 : tensor<!tf.resource>
+}
+func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  return %arg1 : tensor<!tf.resource>
+}
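The tests above are split-input-file cases driven by FileCheck; the RUN line itself is outside this hunk. For orientation, a typical invocation for such a test file would look roughly like the following (the `-tf-tensor-array-ops-decomposition` flag is an assumption based on the pass name, not something shown in this diff):

    // RUN: tf-opt -split-input-file -verify-diagnostics -tf-tensor-array-ops-decomposition %s | FileCheck %s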
@@ -19,7 +19,8 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@@ -55,6 +56,8 @@ namespace {

 namespace cutil = TF::collection_ops_util;

+using std::string;
+
 // A pass that converts tensor array operations to tensor operations and
 // read/assign ops on local variables. A later resource lifting pass can further
 // remove the local variables.
@@ -85,7 +88,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
     return split.emitOpError("unknown or invalid split tensor shape");
   }
   int64_t length = buffer_type.getDimSize(0) / *count;
-  for (auto len : lengths_const.value().getValues<APInt>()) {
+  for (const auto& len : lengths_const.value().getValues<APInt>()) {
     if (length == len.getSExtValue()) continue;
     return split.emitOpError("different split lengths are not supported");
   }
@@ -145,7 +148,7 @@ struct TensorArrayStats {
   // this is a gradient.
   bool accumulate_on_write;
   // Maps from a gradient source string to the local variable for the gradient.
-  llvm::SmallDenseMap<llvm::StringRef, Value> grads;
+  llvm::StringMap<Value> grads;
 };

 LogicalResult HandleTensorArrayV3Op(
@@ -224,10 +227,7 @@ LogicalResult HandleTensorArrayWriteV3Op(
       cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
                         /*keep_slice_shape=*/true);
   // Add a size-1 leading dimension to elem.
-  for (auto dim : buffer.getType().cast<RankedTensorType>().getShape())
-    LOG(ERROR) << " buffer : " << dim;
   auto slice_type = original_elem.getType().cast<RankedTensorType>();
-  for (auto dim : slice_type.getShape()) LOG(ERROR) << " resahpe : " << dim;
   elem = builder.create<TF::ReshapeOp>(
       write.getLoc(), ArrayRef<Type>{slice_type},
       ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
@@ -339,6 +339,26 @@ LogicalResult HandleTensorArraySizeV3Op(
   return success();
 }

+LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
+                                              Operation* op, Value* var) {
+  OpBuilder builder(op);
+  *var = builder.create<TF::MlirLocalVarOp>(
+      op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{},
+      ArrayRef<NamedAttribute>{});
+  Value buffer;
+  auto buffer_type = getElementTypeOrSelf(local_var_type)
+                         .cast<TF::ResourceType>()
+                         .getSubtypes()[0]
+                         .cast<RankedTensorType>();
+  if (failed(cutil::CreateInitBufferValue(
+          buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
+          buffer_type.getElementType(), builder, &buffer))) {
+    return failure();
+  }
+  cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
+  return success();
+}
+
 LogicalResult HandleTensorArrayGradV3Op(
     TF::TensorArrayGradV3Op grad,
     llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
@@ -347,26 +367,17 @@ LogicalResult HandleTensorArrayGradV3Op(
   Value grad_var;
   auto sit = stats->find(local_var);
   if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
-  auto emplace_res = sit->getSecond().grads.try_emplace(grad.source(), Value());
+  auto emplace_res =
+      sit->getSecond().grads.try_emplace(grad.source().str(), Value());
   if (!emplace_res.second) {
     // If the source has been assigned a grad, use it.
-    grad_var = emplace_res.first->getSecond();
+    grad_var = emplace_res.first->second;
   } else {
-    grad_var = builder.create<TF::MlirLocalVarOp>(
-        grad.getLoc(), ArrayRef<Type>{local_var.getType()}, ArrayRef<Value>{},
-        ArrayRef<NamedAttribute>{});
-    Value buffer;
-    auto buffer_type = getElementTypeOrSelf(local_var.getType())
-                           .cast<TF::ResourceType>()
-                           .getSubtypes()[0]
-                           .cast<RankedTensorType>();
-    if (failed(cutil::CreateInitBufferValue(
-            buffer_type.getShape().drop_front(), buffer_type.getDimSize(0),
-            grad, buffer_type.getElementType(), builder, &buffer))) {
+    if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
+                                               &grad_var))) {
       return failure();
     }
-    cutil::WriteLocalVariable(grad_var, buffer, builder, grad.getLoc());
-    emplace_res.first->getSecond() = grad_var;
+    emplace_res.first->second = grad_var;
     // Write to a grad accumulates with previous writes.
     (*stats)[grad_var].accumulate_on_write = true;
   }
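Note the keying by `grad.source().str()` above: gradient ops sharing a `source` string map to one accumulator variable, while distinct sources get distinct variables. A small illustrative sketch of the intended mapping (value names hypothetical, types elided with `...`):

    %g1:2 = "tf.TensorArrayGradV3"(%ta, %flow) {source = "a"} : ...  // creates the "a" gradient variable
    %g2:2 = "tf.TensorArrayGradV3"(%ta, %flow) {source = "a"} : ...  // reuses the "a" variable
    %g3:2 = "tf.TensorArrayGradV3"(%ta, %flow) {source = "b"} : ...  // creates a separate "b" variable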
@@ -409,36 +420,454 @@ LogicalResult HandleTensorArrayScatterV3Op(
   return success();
 }

-LogicalResult DecomposeTensorArrayOps(Block* block, ModuleOp module) {
-  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
+// Updates func's type according to its current arguments and return values.
+void UpdateFuncType(FuncOp func) {
+  llvm::SmallVector<Type, 8> arg_types;
+  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
+  func.setType(FunctionType::get(
+      arg_types,
+      llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
+      func.getContext()));
+}
+
+// Finds the accessed gradient sources for each tensor array argument.
+llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
+    ArrayRef<FuncOp> funcs, ModuleOp module) {
+  llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
+  llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
+  auto insert = [&](Value v, const string& source) {
+    auto arg = v.dyn_cast<BlockArgument>();
+    if (!arg) return;
+    auto insert_res = result_sets[arg.getArgNumber()].insert(source);
+    if (!insert_res.second) return;
+    result[arg.getArgNumber()].push_back(source);
+  };
+  for (FuncOp func : funcs) {
+    for (auto& op : func.front().getOperations()) {
+      if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
+        op.replaceAllUsesWith(op.getOperands());
+        continue;
+      }
+      if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
+        insert(grad.handle(), grad.source().str());
+      } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
+        auto body = module.lookupSymbol<FuncOp>(while_op.body());
+        auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
+        for (const auto& entry : AccessedGradients({body, cond}, module)) {
+          for (const string& source : entry.getSecond()) {
+            insert(while_op.getOperand(entry.getFirst()), source);
+          }
+        }
+      } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
+        auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
+        auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
+        for (const auto& entry :
+             AccessedGradients({then_branch, else_branch}, module)) {
+          for (const string& source : entry.getSecond()) {
+            insert(if_op.getOperand(entry.getFirst() + 1), source);
+          }
+        }
+      } else if (auto pc = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
+        if (!pc.f().isa<FlatSymbolRefAttr>()) continue;
+        auto callee = module.lookupSymbol<FuncOp>(pc.f().getRootReference());
+        for (const auto& entry : AccessedGradients({callee}, module)) {
+          for (const string& source : entry.getSecond()) {
+            insert(pc.getOperand(entry.getFirst()), source);
+          }
+        }
+      } else if (auto spc =
+                     llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
+        auto callee = module.lookupSymbol<FuncOp>(spc.f());
+        for (const auto& entry : AccessedGradients({callee}, module)) {
+          for (const string& source : entry.getSecond()) {
+            insert(spc.getOperand(entry.getFirst()), source);
+          }
+        }
+      }
+    }
+  }
+  return result;
+}
+
+// Contains cached information for decomposed callee functions for (stateful)
+// partitioned call ops.
+struct PartitionedCallTensorArrayOpsInfo {
+  bool signature_change;
+  FuncOp decomposed_callee;
+  llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
+      arg_grads;
+  llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
+};
+
+// Updates a called function's input signature by adjusting resource types, and
+// adding required gradient arguments.
+void ChangeFunctionInputSignature(
+    FuncOp func,
+    const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
+    llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
+    llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
+  int64_t original_args = func.getNumArguments();
+  for (int64_t argnum = 0; argnum < original_args; ++argnum) {
+    auto arg = func.getArgument(argnum);
+    Type t = ta_arg_buffer_type(argnum);
+    if (!t) continue;
+    arg.setType(t);
+    auto grad_it = grads.find(argnum);
+    if (grad_it == grads.end()) continue;
+    llvm::StringMap<Value> grads_map;
+    for (const string& source : grad_it->getSecond()) {
+      auto g = func.front().addArgument(t);
+      (*stats)[g].accumulate_on_write = true;
+      grads_map[source] = g;
+    }
+    auto& stat = (*stats)[arg];
+    stat.accumulate_on_write = ta_accumulate_on_write(argnum);
+    stat.grads = std::move(grads_map);
+  }
+  UpdateFuncType(func);
+}
+
+LogicalResult DecomposeTensorArrayOps(
+    Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*);
+
+LogicalResult HandleWhileOp(
+    TF::WhileOp while_op, ModuleOp module,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
+        decomposed_partitioned_call_callees) {
+  auto body = module.lookupSymbol<FuncOp>(while_op.body());
+  auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
+  auto grads = AccessedGradients({body, cond}, module);
+  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
+    auto it = stats->find(while_op.getOperand(index));
+    if (it == stats->end()) return nullptr;
+    return it->getFirst().getType();
+  };
+  auto ta_accumulate_on_write = [&](int64_t index) {
+    auto it = stats->find(while_op.getOperand(index));
+    if (it == stats->end()) return false;
+    return it->getSecond().accumulate_on_write;
+  };
+  llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
+  ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &body_stats);
+  llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
+  ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &cond_stats);
+  if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
+                                     decomposed_partitioned_call_callees)) ||
+      failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
+                                     decomposed_partitioned_call_callees))) {
+    return failure();
+  }
+  if (body_stats.empty() && cond_stats.empty()) return success();
+  auto old_body_ret = body.front().getTerminator();
+  auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
+  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
+    if (!ta_arg_buffer_type(i)) continue;
+    auto retval = old_body_ret->getOperand(i);
+    auto arg = retval.dyn_cast<BlockArgument>();
+    if (!arg) {
+      return while_op.emitOpError(
+          "output tensor array does not alias input in a while loop");
+    }
+    for (const string& source : grads[i]) {
+      new_retvals.push_back(body_stats[arg].grads[source]);
+    }
+  }
+  OpBuilder(old_body_ret).create<ReturnOp>(old_body_ret->getLoc(), new_retvals);
+  old_body_ret->erase();
+  UpdateFuncType(body);
+  // Recreate the while op.
+  auto operands = llvm::to_vector<8>(while_op.getOperands());
+  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
+    auto grad_it = grads.find(i);
+    auto& stat = (*stats)[operands[i]];
+    if (grad_it == grads.end()) continue;
+    for (const string& source : grad_it->getSecond()) {
+      auto it = stat.grads.find(source);
+      if (it != stat.grads.end()) {
+        operands.push_back(it->second);
+      } else {
+        Value grad_var;
+        if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
+                                                   while_op, &grad_var))) {
+          return failure();
+        }
+        stat.grads[source] = grad_var;
+        operands.push_back(grad_var);
+      }
+    }
+  }
+  OpBuilder builder(while_op);
+  auto new_while =
+      builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
+                                  operands, while_op.getAttrs());
+  // Clear the output shapes as it is not needed for XLA lowering.
+  new_while.setAttr("output_shapes", builder.getArrayAttr({}));
+  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
+    if (ta_arg_buffer_type(i)) {
+      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
+    } else {
+      while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
+    }
+  }
+  while_op.erase();
+  return success();
+}
+
+LogicalResult HandleIfOp(
+    TF::IfOp if_op, ModuleOp module,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
+        decomposed_partitioned_call_callees) {
+  auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
+  auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
+  auto grads = AccessedGradients({then_branch, else_branch}, module);
+  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
+    auto it = stats->find(if_op.getOperand(index + 1));
+    if (it == stats->end()) return nullptr;
+    return it->getFirst().getType();
+  };
+  auto ta_accumulate_on_write = [&](int64_t index) {
+    auto it = stats->find(if_op.getOperand(index + 1));
+    if (it == stats->end()) return false;
+    return it->getSecond().accumulate_on_write;
+  };
+  llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
+  ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &then_stats);
+  llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
+  ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &else_stats);
+  if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
+                                     decomposed_partitioned_call_callees)) ||
+      failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
+                                     decomposed_partitioned_call_callees))) {
+    return failure();
+  }
+  if (then_stats.empty() && else_stats.empty()) return success();
+  // Recreate the if op.
+  auto operands = llvm::to_vector<8>(if_op.getOperands());
+  for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
+    auto grad_it = grads.find(i);
+    auto& stat = (*stats)[operands[i + 1]];
+    if (grad_it == grads.end()) continue;
+    for (const string& source : grad_it->getSecond()) {
+      auto it = stat.grads.find(source);
+      if (it != stat.grads.end()) {
+        operands.push_back(it->second);
+      } else {
+        Value grad_var;
+        if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
+                                                   if_op, &grad_var))) {
+          return failure();
+        }
+        stat.grads[source] = grad_var;
+        operands.push_back(grad_var);
+      }
+    }
+  }
+  OpBuilder builder(if_op);
+  auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
+                                         then_branch.getType().getResults(),
+                                         operands, if_op.getAttrs());
+  // Clear the output shapes as it is not needed for XLA lowering.
+  new_if.setAttr("output_shapes", builder.getArrayAttr({}));
+  auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
+    auto retval = f.front().getTerminator()->getOperand(ret_ind);
+    auto arg = retval.dyn_cast<BlockArgument>();
+    if (!arg) return -1;
+    return arg.getArgNumber();
+  };
+  for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
+    if (!getElementTypeOrSelf(if_op.getResult(i).getType())
+             .isa<TF::ResourceType>()) {
+      if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
+      continue;
+    }
+    int64_t then_forward_input = ret_forwards_input(then_branch, i);
+    int64_t else_forward_input = ret_forwards_input(else_branch, i);
+    if (then_forward_input != else_forward_input || then_forward_input < 0) {
+      return if_op.emitOpError(
+          "branches do not forward the same input resource");
+    }
+    if_op.getResult(i).replaceAllUsesWith(
+        if_op.getOperand(then_forward_input + 1));
+  }
+  if_op.erase();
+  return success();
+}
+
+template <typename CallOp>
+LogicalResult HandlePartitionedCallOp(
+    CallOp call, FuncOp callee, ModuleOp module,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
+        decomposed_partitioned_call_callees) {
+  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
+      callee, PartitionedCallTensorArrayOpsInfo());
+  auto& info = emplace_res.first->getSecond();
+  // Recreates the call op with info.
+  auto recreate_caller = [&]() -> LogicalResult {
+    auto new_operands = llvm::to_vector<8>(call.getOperands());
+    for (const auto& entry : info.arg_grads) {
+      auto it = stats->find(call.getOperand(entry.first));
+      if (it == stats->end()) return call.emitOpError("unknown tensor array");
+      for (const string& source : entry.second) {
+        auto grad_it = it->getSecond().grads.find(source);
+        if (grad_it != it->getSecond().grads.end()) {
+          new_operands.push_back(grad_it->second);
+        } else {
+          Value grad_var;
+          if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
+                                                     call, &grad_var))) {
+            return failure();
+          }
+          it->getSecond().grads[source] = grad_var;
+          new_operands.push_back(grad_var);
+        }
+      }
+    }
+    OpBuilder builder(call);
+    auto new_call = builder.create<CallOp>(
+        call.getLoc(), info.decomposed_callee.getType().getResults(),
+        new_operands, call.getAttrs());
+    new_call.setAttr(
+        "f", builder.getSymbolRefAttr(
+                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
+    for (const auto& entry : info.ret_forward_input) {
+      call.getResult(entry.first)
+          .replaceAllUsesWith(call.getOperand(entry.second));
+    }
+    call.replaceAllUsesWith(new_call);
+    call.erase();
+    return success();
+  };
+  if (!emplace_res.second) {
+    // This callee was handled before.
+    if (!info.signature_change) return success();
+    return recreate_caller();
+  }
+  // Rewrite the callee on a cloned function.
+  info.signature_change = false;
+  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
+    auto it = stats->find(call.getOperand(index));
+    if (it == stats->end()) return nullptr;
+    info.signature_change = true;
+    return it->getFirst().getType();
+  };
+  auto ta_accumulate_on_write = [&](int64_t index) {
+    auto it = stats->find(call.getOperand(index));
+    if (it == stats->end()) return false;
+    return it->getSecond().accumulate_on_write;
+  };
+  auto callee_clone = callee.clone();
+  auto grads = AccessedGradients({callee_clone}, module);
+  for (int64_t i = 0; i < callee_clone.getNumArguments(); ++i) {
+    auto it = grads.find(i);
+    if (it == grads.end()) continue;
+    info.arg_grads.emplace_back(i, it->getSecond());
+  }
+  llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
+  ChangeFunctionInputSignature(callee_clone, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &callee_stats);
+  if (failed(DecomposeTensorArrayOps(&callee_clone.front(), module,
+                                     &callee_stats,
+                                     decomposed_partitioned_call_callees))) {
+    return failure();
+  }
+  for (int64_t i = 0; i < call.getNumResults(); ++i) {
+    auto ret = callee_clone.front().getTerminator()->getOperand(i);
+    if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
+    auto arg = ret.dyn_cast<BlockArgument>();
+    if (!arg) continue;
+    info.ret_forward_input.emplace_back(i, arg.getArgNumber());
+  }
+
+  if (!info.signature_change) {
+    // Signature is not modified. We do not need to keep two copies.
+    info.signature_change = false;
+    auto name = callee.getName();
+    callee.erase();
+    callee_clone.setName(name);
+    SymbolTable(module).insert(callee_clone);
+  } else {
+    info.decomposed_callee = callee_clone;
+    // Add the clone with a new name.
+    auto name =
+        llvm::formatv("{0}_{1}", callee.getName(), "tensorarray_decomposed")
+            .str();
+    callee_clone.setName(name);
+    SymbolTable(module).insert(callee_clone);
+  }
+  if (info.signature_change) return recreate_caller();
+  return success();
+}
+
+LogicalResult DecomposeTensorArrayOps(
+    Block* block, ModuleOp module,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
+        decomposed_partitioned_call_callees) {
   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
     if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
       op.replaceAllUsesWith(op.getOperands());
       op.erase();
     } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
-      if (failed(HandleTensorArrayV3Op(ta, module, &stats))) {
+      if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
         return failure();
       }
     } else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
-      if (failed(HandleTensorArrayReadV3Op(read, stats))) return failure();
+      if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
     } else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
-      if (failed(HandleTensorArrayWriteV3Op(write, stats))) return failure();
+      if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
     } else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
-      if (failed(HandleTensorArrayConcatV3Op(concat, stats))) return failure();
+      if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
     } else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
-      if (failed(HandleTensorArraySplitV3Op(split, stats))) return failure();
+      if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
     } else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
-      if (failed(HandleTensorArraySizeV3Op(size, stats))) return failure();
+      if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
     } else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
-      if (failed(HandleTensorArrayGradV3Op(grad, &stats))) return failure();
+      if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
     } else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
-      if (failed(HandleTensorArrayGatherV3Op(gather, stats))) return failure();
+      if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
     } else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
-      if (failed(HandleTensorArrayScatterV3Op(scatter, stats))) {
+      if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
         return failure();
       }
+    } else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
+      close.erase();
+    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
+      if (failed(HandleWhileOp(while_op, module, stats,
+                               decomposed_partitioned_call_callees))) {
+        return failure();
+      }
+    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
+      if (failed(HandleIfOp(if_op, module, stats,
+                            decomposed_partitioned_call_callees))) {
+        return failure();
+      }
+    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
+      if (!pcall.f().isa<FlatSymbolRefAttr>()) {
+        return pcall.emitOpError(
+            "TensorArray decomposition does not support call with nested "
+            "references.");
+      }
+      if (failed(HandlePartitionedCallOp(
+              pcall, module.lookupSymbol<FuncOp>(pcall.f().getRootReference()),
+              module, stats, decomposed_partitioned_call_callees))) {
+        return failure();
+      }
+    } else if (auto spcall =
+                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
+      if (failed(HandlePartitionedCallOp(
+              spcall, module.lookupSymbol<FuncOp>(spcall.f()), module, stats,
+              decomposed_partitioned_call_callees))) {
+        return failure();
+      }
+    }
   }
   return success();
@@ -448,7 +877,11 @@ void TensorArrayOpsDecompositionPass::runOnModule() {
   auto module = getModule();
   auto main = module.lookupSymbol<FuncOp>("main");
   if (!main) return;
-  if (failed(DecomposeTensorArrayOps(&main.front(), module))) {
+  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
+  llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>
+      decomposed_partitioned_call_callees;
+  if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
+                                     &decomposed_partitioned_call_callees))) {
     signalPassFailure();
   }
 }
@@ -261,6 +261,7 @@ Status ConvertMLIRToXlaComputation(
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
   tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());
+  tf2xla.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
   tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
   tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
   // LegalizeTFControlFlow encapsulates arguments for control flow operations