diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir
index 35cfb19a80b..1a13338b0ba 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir
@@ -187,3 +187,199 @@ func @main() {
   %write3 = "tf.TensorArrayWriteV3"(%grad3#0, %index, %value, %grad3#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
   return
 }
+
+// -----
+
+// Tests while loop with access to the tensor array defined outside and its
+// gradient defined inside. The gradient creation should be moved outside.
+
+// CHECK-LABEL: func @main
+func @main() -> () {
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  // CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]])
+  %1:2 = "tf.While"(%ta#0, %size) {
+    body = @while_body, cond = @while_cond, device = "", is_stateless = false}
+      : (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
+  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: "tf.Slice"(%[[READ]],
+  %read = "tf.TensorArrayReadV3"(%1#0, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
+  return
+}
+// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
+  %sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
+  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[BARG0]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
+  // CHECK: "tf.AssignVariableOp"(%[[BARG0]], %[[UPDATE1]])
+  %write = "tf.TensorArrayWriteV3"(%arg0, %sub, %elem, %flow) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[BARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
+  // CHECK: "tf.AssignVariableOp"(%[[BARG2]], %[[UPDATE2]])
+  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
+  return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
+}
+// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
+  // CHECK-NEXT: return %[[CARG1]]
+  return %arg1 : tensor<i32>
+}
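Note on the test above: the CHECK lines capture the core lowering. A tensor array of size 5 with element shape 3xf32 becomes a local variable holding a 5x3 buffer; TensorArrayReadV3 turns into ReadVariableOp + Slice, TensorArrayWriteV3 turns into ReadVariableOp + XlaDynamicUpdateSlice + AssignVariableOp, and the gradient created inside the loop is hoisted out as a second variable threaded through the While as an extra operand. A minimal standalone C++ sketch of that buffer discipline, with plain vectors in place of MLIR values and variable ops (all names here are illustrative, not part of the pass):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Models a decomposed tensor array: a dense (size x elem_size) buffer
    // that would live in a tf.MlirLocalVarOp-backed local variable.
    struct TensorArrayBufferModel {
      int64_t size, elem_size;
      std::vector<float> data;  // size * elem_size elements, zero-initialized.
      TensorArrayBufferModel(int64_t n, int64_t e)
          : size(n), elem_size(e), data(n * e, 0.0f) {}

      // TensorArrayReadV3 -> ReadVariableOp + tf.Slice on the leading dim.
      std::vector<float> Read(int64_t index) const {
        assert(index >= 0 && index < size);
        auto begin = data.begin() + index * elem_size;
        return std::vector<float>(begin, begin + elem_size);
      }

      // TensorArrayWriteV3 -> ReadVariableOp + tf.XlaDynamicUpdateSlice +
      // AssignVariableOp. Writes to gradient arrays accumulate instead of
      // overwriting (the pass's accumulate_on_write flag).
      void Write(int64_t index, const std::vector<float>& elem,
                 bool accumulate_on_write) {
        assert(index >= 0 && index < size);
        assert(static_cast<int64_t>(elem.size()) == elem_size);
        for (int64_t i = 0; i < elem_size; ++i) {
          float& slot = data[index * elem_size + i];
          slot = accumulate_on_write ? slot + elem[i] : elem[i];
        }
      }
    };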
+
+// -----
+
+// Tests If op with access to the tensor array defined outside and its
+// gradient defined inside. The gradient creation should be moved outside.
+
+// CHECK-LABEL: func @main
+func @main() -> () {
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
+  %cond = "tf._SomeOp"() : () -> tensor<i1>
+  // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  // CHECK: "tf.If"(%[[COND]], %[[VAR]], %[[GVAR1]], %[[GVAR2]])
+  %1 = "tf.If"(%cond, %ta#0) {
+    then_branch = @then_branch, else_branch = @else_branch, device = "", is_stateless = false}
+      : (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
+  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: "tf.Slice"(%[[READ]],
+  %read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
+  return
+}
+// CHECK: func @then_branch(%[[TARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+func @then_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
+  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[TARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
+  // CHECK: "tf.AssignVariableOp"(%[[TARG1]], %[[UPDATE1]])
+  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[TARG0]]
+  return %arg0 : tensor<!tf.resource>
+}
+// CHECK: func @else_branch(%[[EARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+func @else_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
+  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[EARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
+  // CHECK: "tf.AssignVariableOp"(%[[EARG2]], %[[UPDATE2]])
+  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[EARG0]]
+  return %arg0 : tensor<!tf.resource>
+}
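Note on the test above: the two variables it checks (%[[GVAR1]] for source "a" in the then-branch, %[[GVAR2]] for source "b" in the else-branch) reflect TensorArrayGradV3 semantics: each distinct gradient source string gets exactly one gradient array per tensor array, and a repeated request for the same source returns the existing one. A standalone sketch of that registry, using std::map and std::string rather than the pass's llvm::StringMap<Value> (names illustrative):

    #include <cstdint>
    #include <map>
    #include <string>
    #include <vector>

    // One gradient buffer per (tensor array, source string) pair; repeated
    // lookups for the same source reuse the buffer, mirroring the
    // try_emplace logic in HandleTensorArrayGradV3Op.
    struct GradientRegistryModel {
      std::map<std::string, std::vector<float>> grads;

      std::vector<float>& GetOrCreate(const std::string& source, int64_t size,
                                      int64_t elem_size) {
        // try_emplace only constructs (zero-initializes) on first use.
        return grads.try_emplace(source, size * elem_size, 0.0f).first->second;
      }
    };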
+
+// -----
+
+// Tests (Stateful)PartitionedCall op with access to the tensor array defined
+// outside and its gradient defined inside. The gradient creation should be
+// moved outside.
+
+// CHECK-LABEL: func @main
+func @main() -> () {
+  // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
+  %cond = "tf._SomeOp"() : () -> tensor<i1>
+  // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  %grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
+  // CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
+  // CHECK-SAME: f = @callee_tensorarray_decomposed
+  %call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""}
+    : (tensor<!tf.resource>) -> tensor<!tf.resource>
+  // CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
+  // CHECK-SAME: f = @callee_tensorarray_decomposed
+  %call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""}
+    : (tensor<!tf.resource>) -> tensor<!tf.resource>
+  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+  // CHECK: "tf.Slice"(%[[READ]],
+  %read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
+  return
+}
+// CHECK-LABEL: func @callee
+// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
+func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  %const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %elem = "tf._SomeOp"() : () -> tensor<3xf32>
+  %flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  %grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
+  return %arg0 : tensor<!tf.resource>
+}
+// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
+// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
+// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
+// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
+// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
+// CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]])
+// CHECK: return %[[CARG0]]
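Note on the test above: the @callee_tensorarray_decomposed symbol comes from the callee-rewriting scheme in HandlePartitionedCallOp further below. The first call site clones the callee, decomposes the clone, publishes it under <name>_tensorarray_decomposed, and caches the result, so the second call site (the tf.PartitionedCall) reuses it. A simplified standalone model of that memoization, with strings standing in for FuncOp symbols (hypothetical names):

    #include <map>
    #include <string>

    // Stand-in for PartitionedCallTensorArrayOpsInfo: whether decomposition
    // changed the callee's signature, and the symbol call sites should use.
    struct DecomposedCalleeInfoModel {
      bool signature_change = false;
      std::string decomposed_name;
    };

    // First visit decomposes and caches; later visits reuse the cached
    // entry, so every call site of the same callee is retargeted alike.
    const std::string& GetOrDecomposeCallee(
        const std::string& callee, bool needs_signature_change,
        std::map<std::string, DecomposedCalleeInfoModel>* cache) {
      auto emplace_res = cache->try_emplace(callee);
      DecomposedCalleeInfoModel& info = emplace_res.first->second;
      if (emplace_res.second) {
        info.signature_change = needs_signature_change;
        info.decomposed_name = needs_signature_change
                                   ? callee + "_tensorarray_decomposed"
                                   : callee;
      }
      return info.decomposed_name;
    }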
+
+// -----
+
+// Tests that the pass reports failure on unknown size.
+
+func @main(%arg0: tensor<i32>) -> () {
+  // expected-error @+1 {{unknown max element count}}
+  %ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  return
+}
+
+// -----
+
+// Tests that the pass reports failure on unknown element shape.
+
+func @main(%arg0: tensor<i32>) -> () {
+  %size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
+  // expected-error @+1 {{unknown element shape}}
+  %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  return
+}
+
+// -----
+
+// Tests that the pass reports error on ambiguous tensor array.
+
+func @main(%arg0: tensor<i1>) -> () {
+  %size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
+  %ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
+  %if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
+    : (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
+  %index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // expected-error @+1 {{unknown tensor array}}
+  %read = "tf.TensorArrayReadV3"(%if_op, %index, %ta0#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
+  return
+}
+func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  return %arg0 : tensor<!tf.resource>
+}
+func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
+  return %arg1 : tensor<!tf.resource>
+}
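Before reading the pass changes below, it may help to see the bookkeeping the pass keeps per tensor array. The patch replaces a StringRef-keyed grads map with an owning llvm::StringMap keyed by the gradient source string. A simplified standalone model of TensorArrayStats, with int handles in place of MLIR Values (names illustrative):

    #include <map>
    #include <string>

    // Per-tensor-array bookkeeping, modeled after TensorArrayStats below.
    struct TensorArrayStatsModel {
      // True iff this array is itself a gradient: its writes accumulate
      // instead of overwriting.
      bool accumulate_on_write = false;
      // Gradient source string -> handle of the gradient's local variable.
      std::map<std::string, int> grads;
    };
    // Keyed by the local variable introduced for each tensor array.
    using StatsMapModel = std::map<int, TensorArrayStatsModel>;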
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
index 59dab25c15c..b7efc5aa64b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc
@@ -19,7 +19,8 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@@ -55,6 +56,8 @@ namespace {
 
 namespace cutil = TF::collection_ops_util;
 
+using std::string;
+
 // A pass that converts tensor array operations to tensor operations and
 // read/assign ops on local variables. A later resource lifting pass can further
 // remove the local variables.
@@ -85,7 +88,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
     return split.emitOpError("unknown or invalid split tensor shape");
   }
   int64_t length = buffer_type.getDimSize(0) / *count;
-  for (auto len : lengths_const.value().getValues<APInt>()) {
+  for (const auto& len : lengths_const.value().getValues<APInt>()) {
     if (length == len.getSExtValue()) continue;
     return split.emitOpError("different split lengths are not supported");
   }
@@ -145,7 +148,7 @@ struct TensorArrayStats {
   // this is a gradient.
   bool accumulate_on_write;
   // Maps from a gradient source string to the local variable of the gradient.
-  llvm::SmallDenseMap<StringRef, Value> grads;
+  llvm::StringMap<Value> grads;
 };
 
 LogicalResult HandleTensorArrayV3Op(
@@ -224,10 +227,7 @@ LogicalResult HandleTensorArrayWriteV3Op(
       cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
                         /*keep_slice_shape=*/true);
   // Add a size-1 leading dimension to elem.
-  for (auto dim : buffer.getType().cast<RankedTensorType>().getShape())
-    LOG(ERROR) << " buffer : " << dim;
   auto slice_type = original_elem.getType().cast<RankedTensorType>();
-  for (auto dim : slice_type.getShape()) LOG(ERROR) << " resahpe : " << dim;
   elem = builder.create<TF::ReshapeOp>(
       write.getLoc(), ArrayRef<Type>{slice_type},
       ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
@@ -339,6 +339,26 @@ LogicalResult HandleTensorArraySizeV3Op(
   return success();
 }
 
+LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
+                                              Operation* op, Value* var) {
+  OpBuilder builder(op);
+  *var = builder.create<TF::MlirLocalVarOp>(
+      op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{},
+      ArrayRef<NamedAttribute>{});
+  Value buffer;
+  auto buffer_type = getElementTypeOrSelf(local_var_type)
+                         .cast<TF::ResourceType>()
+                         .getSubtypes()[0]
+                         .cast<RankedTensorType>();
+  if (failed(cutil::CreateInitBufferValue(
+          buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
+          buffer_type.getElementType(), builder, &buffer))) {
+    return failure();
+  }
+  cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
+  return success();
+}
+
 LogicalResult HandleTensorArrayGradV3Op(
     TF::TensorArrayGradV3Op grad,
     llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
@@ -347,26 +367,17 @@ LogicalResult HandleTensorArrayGradV3Op(
   Value grad_var;
   auto sit = stats->find(local_var);
   if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
-  auto emplace_res = sit->getSecond().grads.try_emplace(grad.source(), Value());
+  auto emplace_res =
+      sit->getSecond().grads.try_emplace(grad.source().str(), Value());
   if (!emplace_res.second) {
     // If the source has been assigned a grad, use it.
-    grad_var = emplace_res.first->getSecond();
+    grad_var = emplace_res.first->second;
   } else {
-    grad_var = builder.create<TF::MlirLocalVarOp>(
-        grad.getLoc(), ArrayRef<Type>{local_var.getType()}, ArrayRef<Value>{},
-        ArrayRef<NamedAttribute>{});
-    Value buffer;
-    auto buffer_type = getElementTypeOrSelf(local_var.getType())
-                           .cast<TF::ResourceType>()
-                           .getSubtypes()[0]
-                           .cast<RankedTensorType>();
-    if (failed(cutil::CreateInitBufferValue(
-            buffer_type.getShape().drop_front(), buffer_type.getDimSize(0),
-            grad, buffer_type.getElementType(), builder, &buffer))) {
+    if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
+                                               &grad_var))) {
       return failure();
     }
-    cutil::WriteLocalVariable(grad_var, buffer, builder, grad.getLoc());
-    emplace_res.first->getSecond() = grad_var;
+    emplace_res.first->second = grad_var;
     // Write to a grad accumulates with previous writes.
     (*stats)[grad_var].accumulate_on_write = true;
   }
@@ -409,36 +420,454 @@ LogicalResult HandleTensorArrayScatterV3Op(
   return success();
 }
 
-LogicalResult DecomposeTensorArrayOps(Block* block, ModuleOp module) {
-  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
+// Updates func's type according to its current arguments and return values.
+void UpdateFuncType(FuncOp func) {
+  llvm::SmallVector<Type, 8> arg_types;
+  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
+  func.setType(FunctionType::get(
+      arg_types,
+      llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
+      func.getContext()));
+}
+
+// Finds the accessed gradient sources for each tensor array argument.
+llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
+    ArrayRef<FuncOp> funcs, ModuleOp module) {
+  llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
+  llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
+  auto insert = [&](Value v, const string& source) {
+    auto arg = v.dyn_cast<BlockArgument>();
+    if (!arg) return;
+    auto insert_res = result_sets[arg.getArgNumber()].insert(source);
+    if (!insert_res.second) return;
+    result[arg.getArgNumber()].push_back(source);
+  };
+  for (FuncOp func : funcs) {
+    for (auto& op : func.front().getOperations()) {
+      if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
+        op.replaceAllUsesWith(op.getOperands());
+        continue;
+      }
+      if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
+        insert(grad.handle(), grad.source().str());
+      } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
+        auto body = module.lookupSymbol<FuncOp>(while_op.body());
+        auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
+        for (const auto& entry : AccessedGradients({body, cond}, module)) {
+          for (const string& source : entry.getSecond()) {
+            insert(while_op.getOperand(entry.getFirst()), source);
+          }
+        }
+      } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
+        auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
+        auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
+        for (const auto& entry :
+             AccessedGradients({then_branch, else_branch}, module)) {
+          for (const string& source : entry.getSecond()) {
+            insert(if_op.getOperand(entry.getFirst() + 1), source);
+          }
+        }
+      } else if (auto pc = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
+        if (!pc.f().isa<FlatSymbolRefAttr>()) continue;
+        auto callee = module.lookupSymbol<FuncOp>(pc.f().getRootReference());
+        for (const auto& entry : AccessedGradients({callee}, module)) {
+          for (const string& source : entry.getSecond()) {
+            insert(pc.getOperand(entry.getFirst()), source);
+          }
+        }
+      } else if (auto spc =
+                     llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
+        auto callee = module.lookupSymbol<FuncOp>(spc.f());
+        for (const auto& entry : AccessedGradients({callee}, module)) {
+          for (const string& source : entry.getSecond()) {
+            insert(spc.getOperand(entry.getFirst()), source);
+          }
+        }
+      }
+    }
+  }
+  return result;
+}
+
+// Contains cached information for decomposed callee functions for (stateful)
+// partitioned call ops.
+struct PartitionedCallTensorArrayOpsInfo {
+  bool signature_change;
+  FuncOp decomposed_callee;
+  llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
+      arg_grads;
+  llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
+};
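AccessedGradients above answers, per function argument, which gradient sources are touched anywhere beneath that argument, recursing through While bodies and conditions, If branches (offset by one for the condition operand), and both partitioned-call flavors. A standalone model of the same bottom-up propagation over a toy call graph, with std:: containers in place of MLIR ops (it assumes an acyclic call graph, which holds for these functional control-flow ops; names are illustrative):

    #include <map>
    #include <set>
    #include <string>
    #include <utility>
    #include <vector>

    // Per function: sources touched directly per argument index, plus call
    // edges that forward caller arguments to callee arguments.
    struct FuncModel {
      std::map<int, std::vector<std::string>> direct_sources;
      // (callee name, caller arg index -> callee arg index).
      std::vector<std::pair<std::string, std::map<int, int>>> calls;
    };

    // Mirrors AccessedGradients: union a function's direct sources with the
    // sources its callees reach through forwarded arguments.
    std::map<int, std::set<std::string>> AccessedGradientsModel(
        const std::string& name,
        const std::map<std::string, FuncModel>& module) {
      std::map<int, std::set<std::string>> result;
      const FuncModel& func = module.at(name);
      for (const auto& entry : func.direct_sources)
        result[entry.first].insert(entry.second.begin(), entry.second.end());
      for (const auto& call : func.calls) {
        auto callee_result = AccessedGradientsModel(call.first, module);
        for (const auto& arg_mapping : call.second) {
          auto it = callee_result.find(arg_mapping.second);
          if (it == callee_result.end()) continue;
          result[arg_mapping.first].insert(it->second.begin(),
                                           it->second.end());
        }
      }
      return result;
    }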
+
+// Updates a called function's input signature by adjusting resource types,
+// and adding required gradient arguments.
+void ChangeFunctionInputSignature(
+    FuncOp func,
+    const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
+    llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
+    llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
+  int64_t original_args = func.getNumArguments();
+  for (int64_t argnum = 0; argnum < original_args; ++argnum) {
+    auto arg = func.getArgument(argnum);
+    Type t = ta_arg_buffer_type(argnum);
+    if (!t) continue;
+    arg.setType(t);
+    auto grad_it = grads.find(argnum);
+    if (grad_it == grads.end()) continue;
+    llvm::StringMap<Value> grads_map;
+    for (const string& source : grad_it->getSecond()) {
+      auto g = func.front().addArgument(t);
+      (*stats)[g].accumulate_on_write = true;
+      grads_map[source] = g;
+    }
+    auto& stat = (*stats)[arg];
+    stat.accumulate_on_write = ta_accumulate_on_write(argnum);
+    stat.grads = std::move(grads_map);
+  }
+  UpdateFuncType(func);
+}
+
+LogicalResult DecomposeTensorArrayOps(
+    Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*);
+
+LogicalResult HandleWhileOp(
+    TF::WhileOp while_op, ModuleOp module,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
+        decomposed_partitioned_call_callees) {
+  auto body = module.lookupSymbol<FuncOp>(while_op.body());
+  auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
+  auto grads = AccessedGradients({body, cond}, module);
+  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
+    auto it = stats->find(while_op.getOperand(index));
+    if (it == stats->end()) return nullptr;
+    return it->getFirst().getType();
+  };
+  auto ta_accumulate_on_write = [&](int64_t index) {
+    auto it = stats->find(while_op.getOperand(index));
+    if (it == stats->end()) return false;
+    return it->getSecond().accumulate_on_write;
+  };
+  llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
+  ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &body_stats);
+  llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
+  ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &cond_stats);
+  if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
+                                     decomposed_partitioned_call_callees)) ||
+      failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
+                                     decomposed_partitioned_call_callees))) {
+    return failure();
+  }
+  if (body_stats.empty() && cond_stats.empty()) return success();
+  auto old_body_ret = body.front().getTerminator();
+  auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
+  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
+    if (!ta_arg_buffer_type(i)) continue;
+    auto retval = old_body_ret->getOperand(i);
+    auto arg = retval.dyn_cast<BlockArgument>();
+    if (!arg) {
+      return while_op.emitOpError(
+          "output tensor array does not alias input in a while loop");
+    }
+    for (const string& source : grads[i]) {
+      new_retvals.push_back(body_stats[arg].grads[source]);
+    }
+  }
+  OpBuilder(old_body_ret)
+      .create<ReturnOp>(old_body_ret->getLoc(), new_retvals);
+  old_body_ret->erase();
+  UpdateFuncType(body);
+  // Recreate the while op.
+  auto operands = llvm::to_vector<8>(while_op.getOperands());
+  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
+    auto grad_it = grads.find(i);
+    auto& stat = (*stats)[operands[i]];
+    if (grad_it == grads.end()) continue;
+    for (const string& source : grad_it->getSecond()) {
+      auto it = stat.grads.find(source);
+      if (it != stat.grads.end()) {
+        operands.push_back(it->second);
+      } else {
+        Value grad_var;
+        if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
+                                                   while_op, &grad_var))) {
+          return failure();
+        }
+        stat.grads[source] = grad_var;
+        operands.push_back(grad_var);
+      }
+    }
+  }
+  OpBuilder builder(while_op);
+  auto new_while = builder.create<TF::WhileOp>(
+      while_op.getLoc(), body.getType().getInputs(), operands,
+      while_op.getAttrs());
+  // Clear the output shapes as it is not needed for XLA lowering.
+  new_while.setAttr("output_shapes", builder.getArrayAttr({}));
+  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
+    if (ta_arg_buffer_type(i)) {
+      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
+    } else {
+      while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
+    }
+  }
+  while_op.erase();
+  return success();
+}
+
+LogicalResult HandleIfOp(
+    TF::IfOp if_op, ModuleOp module,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
+        decomposed_partitioned_call_callees) {
+  auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
+  auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
+  auto grads = AccessedGradients({then_branch, else_branch}, module);
+  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
+    auto it = stats->find(if_op.getOperand(index + 1));
+    if (it == stats->end()) return nullptr;
+    return it->getFirst().getType();
+  };
+  auto ta_accumulate_on_write = [&](int64_t index) {
+    auto it = stats->find(if_op.getOperand(index + 1));
+    if (it == stats->end()) return false;
+    return it->getSecond().accumulate_on_write;
+  };
+  llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
+  ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &then_stats);
+  llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
+  ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &else_stats);
+  if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
                                     decomposed_partitioned_call_callees)) ||
+      failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
+                                     decomposed_partitioned_call_callees))) {
+    return failure();
+  }
+  if (then_stats.empty() && else_stats.empty()) return success();
+  // Recreate the if op.
+  auto operands = llvm::to_vector<8>(if_op.getOperands());
+  for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
+    auto grad_it = grads.find(i);
+    auto& stat = (*stats)[operands[i + 1]];
+    if (grad_it == grads.end()) continue;
+    for (const string& source : grad_it->getSecond()) {
+      auto it = stat.grads.find(source);
+      if (it != stat.grads.end()) {
+        operands.push_back(it->second);
+      } else {
+        Value grad_var;
+        if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
+                                                   if_op, &grad_var))) {
+          return failure();
+        }
+        stat.grads[source] = grad_var;
+        operands.push_back(grad_var);
+      }
+    }
+  }
+  OpBuilder builder(if_op);
+  auto new_if = builder.create<TF::IfOp>(
+      if_op.getLoc(), then_branch.getType().getResults(), operands,
+      if_op.getAttrs());
+  // Clear the output shapes as it is not needed for XLA lowering.
+  new_if.setAttr("output_shapes", builder.getArrayAttr({}));
+  auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
+    auto retval = f.front().getTerminator()->getOperand(ret_ind);
+    auto arg = retval.dyn_cast<BlockArgument>();
+    if (!arg) return -1;
+    return arg.getArgNumber();
+  };
+  for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
+    if (!getElementTypeOrSelf(if_op.getResult(i).getType())
+             .isa<TF::ResourceType>()) {
+      if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
+      continue;
+    }
+    int64_t then_forward_input = ret_forwards_input(then_branch, i);
+    int64_t else_forward_input = ret_forwards_input(else_branch, i);
+    if (then_forward_input != else_forward_input || then_forward_input < 0) {
+      return if_op.emitOpError(
+          "branches do not forward the same input resource");
+    }
+    if_op.getResult(i).replaceAllUsesWith(
+        if_op.getOperand(then_forward_input + 1));
+  }
+  if_op.erase();
+  return success();
+}
+
+template <typename CallOp>
+LogicalResult HandlePartitionedCallOp(
+    CallOp call, FuncOp callee, ModuleOp module,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
+        decomposed_partitioned_call_callees) {
+  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
+      callee, PartitionedCallTensorArrayOpsInfo());
+  auto& info = emplace_res.first->getSecond();
+  // Recreates the call op with info.
+  auto recreate_caller = [&]() -> LogicalResult {
+    auto new_operands = llvm::to_vector<8>(call.getOperands());
+    for (const auto& entry : info.arg_grads) {
+      auto it = stats->find(call.getOperand(entry.first));
+      if (it == stats->end()) return call.emitOpError("unknown tensor array");
+      for (const string& source : entry.second) {
+        auto grad_it = it->getSecond().grads.find(source);
+        if (grad_it != it->getSecond().grads.end()) {
+          new_operands.push_back(grad_it->second);
+        } else {
+          Value grad_var;
+          if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
+                                                     call, &grad_var))) {
+            return failure();
+          }
+          it->getSecond().grads[source] = grad_var;
+          new_operands.push_back(grad_var);
+        }
+      }
+    }
+    OpBuilder builder(call);
+    auto new_call = builder.create<CallOp>(
+        call.getLoc(), info.decomposed_callee.getType().getResults(),
+        new_operands, call.getAttrs());
+    new_call.setAttr(
+        "f", builder.getSymbolRefAttr(
+                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
+    for (const auto& entry : info.ret_forward_input) {
+      call.getResult(entry.first)
+          .replaceAllUsesWith(call.getOperand(entry.second));
+    }
+    call.replaceAllUsesWith(new_call);
+    call.erase();
+    return success();
+  };
+  if (!emplace_res.second) {
+    // This callee was handled before.
+    if (!info.signature_change) return success();
+    return recreate_caller();
+  }
+  // Rewrite the callee on a cloned function.
+  info.signature_change = false;
+  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
+    auto it = stats->find(call.getOperand(index));
+    if (it == stats->end()) return nullptr;
+    info.signature_change = true;
+    return it->getFirst().getType();
+  };
+  auto ta_accumulate_on_write = [&](int64_t index) {
+    auto it = stats->find(call.getOperand(index));
+    if (it == stats->end()) return false;
+    return it->getSecond().accumulate_on_write;
+  };
+  auto callee_clone = callee.clone();
+  auto grads = AccessedGradients({callee_clone}, module);
+  for (int64_t i = 0; i < callee_clone.getNumArguments(); ++i) {
+    auto it = grads.find(i);
+    if (it == grads.end()) continue;
+    info.arg_grads.emplace_back(i, it->getSecond());
+  }
+  llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
+  ChangeFunctionInputSignature(callee_clone, grads, ta_arg_buffer_type,
+                               ta_accumulate_on_write, &callee_stats);
+  if (failed(DecomposeTensorArrayOps(&callee_clone.front(), module,
+                                     &callee_stats,
+                                     decomposed_partitioned_call_callees))) {
+    return failure();
+  }
+  for (int64_t i = 0; i < call.getNumResults(); ++i) {
+    auto ret = callee_clone.front().getTerminator()->getOperand(i);
+    if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
+    auto arg = ret.dyn_cast<BlockArgument>();
+    if (!arg) continue;
+    info.ret_forward_input.emplace_back(i, arg.getArgNumber());
+  }
+
+  if (!info.signature_change) {
+    // Signature is not modified. We do not need to keep two copies.
+    auto name = callee.getName();
+    callee.erase();
+    callee_clone.setName(name);
+    SymbolTable(module).insert(callee_clone);
+  } else {
+    info.decomposed_callee = callee_clone;
+    // Add the clone with a new name.
+    auto name =
+        llvm::formatv("{0}_{1}", callee.getName(), "tensorarray_decomposed")
+            .str();
+    callee_clone.setName(name);
+    SymbolTable(module).insert(callee_clone);
+  }
+  if (info.signature_change) return recreate_caller();
+  return success();
+}
+
+LogicalResult DecomposeTensorArrayOps(
+    Block* block, ModuleOp module,
+    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
+        decomposed_partitioned_call_callees) {
   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
     if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
       op.replaceAllUsesWith(op.getOperands());
       op.erase();
     } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
-      if (failed(HandleTensorArrayV3Op(ta, module, &stats))) {
+      if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
         return failure();
       }
     } else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
-      if (failed(HandleTensorArrayReadV3Op(read, stats))) return failure();
+      if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
     } else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
-      if (failed(HandleTensorArrayWriteV3Op(write, stats))) return failure();
+      if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
     } else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
-      if (failed(HandleTensorArrayConcatV3Op(concat, stats))) return failure();
+      if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
     } else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
-      if (failed(HandleTensorArraySplitV3Op(split, stats))) return failure();
+      if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
     } else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
-      if (failed(HandleTensorArraySizeV3Op(size, stats))) return failure();
+      if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
     } else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
-      if (failed(HandleTensorArrayGradV3Op(grad, &stats))) return failure();
+      if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
     } else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
-      if (failed(HandleTensorArrayGatherV3Op(gather, stats))) return failure();
+      if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
     } else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
-      if (failed(HandleTensorArrayScatterV3Op(scatter, stats))) {
+      if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
         return failure();
       }
     } else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
       close.erase();
+    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
+      if (failed(HandleWhileOp(while_op, module, stats,
+                               decomposed_partitioned_call_callees))) {
+        return failure();
+      }
+    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
+      if (failed(HandleIfOp(if_op, module, stats,
+                            decomposed_partitioned_call_callees))) {
+        return failure();
+      }
+    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
+      if (!pcall.f().isa<FlatSymbolRefAttr>()) {
+        return pcall.emitOpError(
+            "TensorArray decomposition does not support call with nested "
+            "references.");
+      }
+      if (failed(HandlePartitionedCallOp(
+              pcall, module.lookupSymbol<FuncOp>(pcall.f().getRootReference()),
+              module, stats, decomposed_partitioned_call_callees))) {
+        return failure();
+      }
+    } else if (auto spcall =
+                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
+      if (failed(HandlePartitionedCallOp(
+              spcall, module.lookupSymbol<FuncOp>(spcall.f()), module, stats,
+              decomposed_partitioned_call_callees))) {
+        return failure();
+      }
     }
   }
   return success();
@@ -448,7 +877,11 @@ void TensorArrayOpsDecompositionPass::runOnModule() {
   auto module = getModule();
   auto main = module.lookupSymbol<FuncOp>("main");
   if (!main) return;
-  if (failed(DecomposeTensorArrayOps(&main.front(), module))) {
+  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
+  llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>
+      decomposed_partitioned_call_callees;
+  if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
+                                     &decomposed_partitioned_call_callees))) {
    signalPassFailure();
   }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index f2a1cc13b01..3e250ec287b 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -261,6 +261,7 @@ Status ConvertMLIRToXlaComputation(
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
   tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());
+  tf2xla.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
   tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
   tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
   // LegalizeTFControlFlow encapsulates arguments for control flow operations
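For context on where the new pass runs: the compile_mlir_util.cc hunk above places tensor array decomposition with the other collection decompositions, after tensor list and stack decomposition and before resource-op decomposition and promotion, so the local variables it introduces can still be lifted. A minimal sketch of that ordering as a helper; the pass-creation functions are the ones used in the patch, while the helper itself and the header path are illustrative assumptions:

    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

    // Mirrors the ordering in ConvertMLIRToXlaComputation: collection-type
    // decompositions first, then resource decomposition and promotion.
    void AddCollectionDecompositionPasses(mlir::OpPassManager& pm) {
      pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
      pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());
      pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
      pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
      pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
    }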