[MLIR:TF/XLA] Support control flow for tensor array decomposition

Gradients are accessed by source name, and a gradient can be referenced inside a while loop, an if branch, or a called function. We first recursively analyze all gradient sources accessed in nested functions, then lift the gradient creation outside and pass the gradient buffers as explicit operands.
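As a minimal sketch (condensed from the new test cases below, which use a 5x3 f32 tensor array): a gradient created inside a loop body is replaced by a local variable created next to the tensor array and threaded through the loop as an extra operand.

// Before: the loop body creates the gradient of a tensor array defined outside.
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
// After: the gradient buffer is an explicit loop-carried local variable.
%gvar = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%1:3 = "tf.While"(%var, %size, %gvar) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor<!tf.resource<tensor<5x3xf32>>>, tensor<i32>, tensor<!tf.resource<tensor<5x3xf32>>>) -> (tensor<!tf.resource<tensor<5x3xf32>>>, tensor<i32>, tensor<!tf.resource<tensor<5x3xf32>>>)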

PiperOrigin-RevId: 303423766
Change-Id: Icae790de5f4f7a1d8cdec37cac303bf51f2c3306
Yuanzhong Xu 2020-03-27 15:53:42 -07:00 committed by TensorFlower Gardener
parent 353ab1535d
commit 0c6b402cac
3 changed files with 663 additions and 33 deletions


@@ -187,3 +187,199 @@ func @main() {
%write3 = "tf.TensorArrayWriteV3"(%grad3#0, %index, %value, %grad3#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
return
}
// -----
// Tests while loop with access to the tensor array defined outside and its
// gradient defined inside. The gradient creation should be moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]])
%1:2 = "tf.While"(%ta#0, %size) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%1#0, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
%sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[BARG0]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[BARG0]], %[[UPDATE1]])
%write = "tf.TensorArrayWriteV3"(%arg0, %sub, %elem, %flow) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[BARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[BARG2]], %[[UPDATE2]])
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
}
// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return %[[CARG1]]
return %arg1 : tensor<i32>
}
// -----
// Tests If op with access to the tensor array defined outside and its gradient
// defined inside. The gradient creation should be moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
%cond = "tf._SomeOp"() : () -> tensor<i1>
// CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: "tf.If"(%[[COND]], %[[VAR]], %[[GVAR1]], %[[GVAR2]])
%1 = "tf.If"(%cond, %ta#0) {
then_branch = @then_branch, else_branch = @else_branch, device = "", is_stateless = false}
: (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK: func @then_branch(%[[TARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @then_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[TARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[TARG1]], %[[UPDATE1]])
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[TARG0]]
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @else_branch(%[[EARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @else_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[EARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[EARG2]], %[[UPDATE2]])
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[EARG0]]
return %arg0 : tensor<!tf.resource>
}
// -----
// Tests (Stateful)PartitionedCall op with access to the tensor array defined
// outside and its gradient defined inside. The gradient creation should be
// moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
%cond = "tf._SomeOp"() : () -> tensor<i1>
// CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
// CHECK-SAME: f = @callee_tensorarray_decomposed
%call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
// CHECK-SAME: f = @callee_tensorarray_decomposed
%call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK-LABEL: func @callee
// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
%grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]])
// CHECK: return %[[CARG0]]
// -----
// Tests that the pass reports failure on unknown size.
func @main(%arg0: tensor<i32>) -> () {
// expected-error @+1 {{unknown max element count}}
%ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
return
}
// -----
// Tests that the pass reports failure on unknown shape.
func @main(%arg0: tensor<i32>) -> () {
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{unknown element shape}}
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
return
}
// -----
// Tests that the pass reports error on ambiguous tensor array.
func @main(%arg0: tensor<i1>) -> () {
%size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
: (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{unknown tensor array}}
%read = "tf.TensorArrayReadV3"(%if_op, %index, %ta0#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
return %arg0 : tensor<!tf.resource>
}
func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
return %arg1 : tensor<!tf.resource>
}


@@ -19,7 +19,8 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@@ -55,6 +56,8 @@ namespace {
namespace cutil = TF::collection_ops_util;
using std::string;
// A pass that converts tensor array operations to tensor operations and
// read/assign ops on local variables. A later resource lifting pass can further
// remove the local variables.
@@ -85,7 +88,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
return split.emitOpError("unknown or invalid split tensor shape");
}
int64_t length = buffer_type.getDimSize(0) / *count;
-for (auto len : lengths_const.value().getValues<APInt>()) {
for (const auto& len : lengths_const.value().getValues<APInt>()) {
if (length == len.getSExtValue()) continue;
return split.emitOpError("different split lengths are not supported");
}
@@ -145,7 +148,7 @@ struct TensorArrayStats {
// this is a gradient.
bool accumulate_on_write;
// Maps from a gradient source string to the local variable to the gradient.
-llvm::SmallDenseMap<llvm::StringRef, Value> grads;
llvm::StringMap<Value> grads;
};
LogicalResult HandleTensorArrayV3Op(
@@ -224,10 +227,7 @@ LogicalResult HandleTensorArrayWriteV3Op(
cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
/*keep_slice_shape=*/true);
// Add a size-1 leading dimension to elem.
-for (auto dim : buffer.getType().cast<RankedTensorType>().getShape())
-  LOG(ERROR) << " buffer : " << dim;
auto slice_type = original_elem.getType().cast<RankedTensorType>();
-for (auto dim : slice_type.getShape()) LOG(ERROR) << " resahpe : " << dim;
elem = builder.create<TF::ReshapeOp>(
write.getLoc(), ArrayRef<Type>{slice_type},
ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
@@ -339,6 +339,26 @@ LogicalResult HandleTensorArraySizeV3Op(
return success();
}
LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
Operation* op, Value* var) {
OpBuilder builder(op);
*var = builder.create<TF::MlirLocalVarOp>(
op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
Value buffer;
auto buffer_type = getElementTypeOrSelf(local_var_type)
.cast<TF::ResourceType>()
.getSubtypes()[0]
.cast<RankedTensorType>();
if (failed(cutil::CreateInitBufferValue(
buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
buffer_type.getElementType(), builder, &buffer))) {
return failure();
}
cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
return success();
}
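For a local variable whose buffer subtype is tensor<5x3xf32>, this helper produces IR roughly equivalent to the sketch below (SSA names hypothetical; the zero-filled constant stands in for whatever cutil::CreateInitBufferValue actually emits, and cutil::WriteLocalVariable lowers to the assign):

%var = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%zero = "tf.Const"() {value = dense<0.0> : tensor<5x3xf32>} : () -> tensor<5x3xf32>
"tf.AssignVariableOp"(%var, %zero) : (tensor<!tf.resource<tensor<5x3xf32>>>, tensor<5x3xf32>) -> ()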
LogicalResult HandleTensorArrayGradV3Op(
TF::TensorArrayGradV3Op grad,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
@@ -347,26 +367,17 @@ LogicalResult HandleTensorArrayGradV3Op(
Value grad_var;
auto sit = stats->find(local_var);
if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
-auto emplace_res = sit->getSecond().grads.try_emplace(grad.source(), Value());
auto emplace_res =
sit->getSecond().grads.try_emplace(grad.source().str(), Value());
if (!emplace_res.second) {
// If the source has been assigned a grad, use it.
-grad_var = emplace_res.first->getSecond();
grad_var = emplace_res.first->second;
} else {
-grad_var = builder.create<TF::MlirLocalVarOp>(
-    grad.getLoc(), ArrayRef<Type>{local_var.getType()}, ArrayRef<Value>{},
-    ArrayRef<NamedAttribute>{});
-Value buffer;
-auto buffer_type = getElementTypeOrSelf(local_var.getType())
-    .cast<TF::ResourceType>()
-    .getSubtypes()[0]
-    .cast<RankedTensorType>();
-if (failed(cutil::CreateInitBufferValue(
-    buffer_type.getShape().drop_front(), buffer_type.getDimSize(0),
-    grad, buffer_type.getElementType(), builder, &buffer))) {
if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
&grad_var))) {
return failure();
}
-cutil::WriteLocalVariable(grad_var, buffer, builder, grad.getLoc());
-emplace_res.first->getSecond() = grad_var;
emplace_res.first->second = grad_var;
// Write to a grad accumulates with previous writes.
(*stats)[grad_var].accumulate_on_write = true;
}
@@ -409,36 +420,454 @@ LogicalResult HandleTensorArrayScatterV3Op(
return success();
}
-LogicalResult DecomposeTensorArrayOps(Block* block, ModuleOp module) {
-  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
// Updates func's type according to its current arguments and return values.
void UpdateFuncType(FuncOp func) {
llvm::SmallVector<Type, 8> arg_types;
for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
func.setType(FunctionType::get(
arg_types,
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
func.getContext()));
}
// Finds the accessed gradient sources for each tensor array argument.
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
ArrayRef<FuncOp> funcs, ModuleOp module) {
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
auto insert = [&](Value v, const string& source) {
auto arg = v.dyn_cast<BlockArgument>();
if (!arg) return;
auto insert_res = result_sets[arg.getArgNumber()].insert(source);
if (!insert_res.second) return;
result[arg.getArgNumber()].push_back(source);
};
for (FuncOp func : funcs) {
for (auto& op : func.front().getOperations()) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
continue;
}
if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
insert(grad.handle(), grad.source().str());
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
auto body = module.lookupSymbol<FuncOp>(while_op.body());
auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
for (const auto& entry : AccessedGradients({body, cond}, module)) {
for (const string& source : entry.getSecond()) {
insert(while_op.getOperand(entry.getFirst()), source);
}
}
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
for (const auto& entry :
AccessedGradients({then_branch, else_branch}, module)) {
for (const string& source : entry.getSecond()) {
insert(if_op.getOperand(entry.getFirst() + 1), source);
}
}
} else if (auto pc = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!pc.f().isa<FlatSymbolRefAttr>()) continue;
auto callee = module.lookupSymbol<FuncOp>(pc.f().getRootReference());
for (const auto& entry : AccessedGradients({callee}, module)) {
for (const string& source : entry.getSecond()) {
insert(pc.getOperand(entry.getFirst()), source);
}
}
} else if (auto spc =
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
auto callee = module.lookupSymbol<FuncOp>(spc.f());
for (const auto& entry : AccessedGradients({callee}, module)) {
for (const string& source : entry.getSecond()) {
insert(spc.getOperand(entry.getFirst()), source);
}
}
}
}
}
return result;
}
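For example, for the while-loop test above, AccessedGradients({body, cond}, module) maps block argument 0 to the single source "a", because the body contains:

%grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)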
// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallTensorArrayOpsInfo {
bool signature_change;
FuncOp decomposed_callee;
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
arg_grads;
llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
};
// Updates a called function's input signature by adjusting resource types, and
// adding required gradient arguments.
void ChangeFunctionInputSignature(
FuncOp func,
const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
int64_t original_args = func.getNumArguments();
for (int64_t argnum = 0; argnum < original_args; ++argnum) {
auto arg = func.getArgument(argnum);
Type t = ta_arg_buffer_type(argnum);
if (!t) continue;
arg.setType(t);
auto grad_it = grads.find(argnum);
if (grad_it == grads.end()) continue;
llvm::StringMap<Value> grads_map;
for (const string& source : grad_it->getSecond()) {
auto g = func.front().addArgument(t);
(*stats)[g].accumulate_on_write = true;
grads_map[source] = g;
}
auto& stat = (*stats)[arg];
stat.accumulate_on_write = ta_accumulate_on_write(argnum);
stat.grads = std::move(grads_map);
}
UpdateFuncType(func);
}
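In the while-loop test above, this rewrites the body's argument list as sketched below (return types are updated separately once the terminator is rewritten): the tensor array argument takes the concrete buffer type, and one gradient argument is appended for source "a".

// Before:
func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>)
// After:
func @while_body(%arg0: tensor<!tf.resource<tensor<5x3xf32>>>, %arg1: tensor<i32>, %arg2: tensor<!tf.resource<tensor<5x3xf32>>>)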
LogicalResult DecomposeTensorArrayOps(
Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*);
LogicalResult HandleWhileOp(
TF::WhileOp while_op, ModuleOp module,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
auto body = module.lookupSymbol<FuncOp>(while_op.body());
auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
auto grads = AccessedGradients({body, cond}, module);
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
auto it = stats->find(while_op.getOperand(index));
if (it == stats->end()) return nullptr;
return it->getFirst().getType();
};
auto ta_accumulate_on_write = [&](int64_t index) {
auto it = stats->find(while_op.getOperand(index));
if (it == stats->end()) return false;
return it->getSecond().accumulate_on_write;
};
llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &body_stats);
llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &cond_stats);
if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
decomposed_partitioned_call_callees)) ||
failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
decomposed_partitioned_call_callees))) {
return failure();
}
if (body_stats.empty() && cond_stats.empty()) return success();
auto old_body_ret = body.front().getTerminator();
auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
if (!ta_arg_buffer_type(i)) continue;
auto retval = old_body_ret->getOperand(i);
auto arg = retval.dyn_cast<BlockArgument>();
if (!arg) {
return while_op.emitOpError(
"output tensor array does not alias input in a while loop");
}
for (const string& source : grads[i]) {
new_retvals.push_back(body_stats[arg].grads[source]);
}
}
OpBuilder(old_body_ret).create<ReturnOp>(old_body_ret->getLoc(), new_retvals);
old_body_ret->erase();
UpdateFuncType(body);
// Recreate the while op.
auto operands = llvm::to_vector<8>(while_op.getOperands());
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
auto grad_it = grads.find(i);
auto& stat = (*stats)[operands[i]];
if (grad_it == grads.end()) continue;
for (const string& source : grad_it->getSecond()) {
auto it = stat.grads.find(source);
if (it != stat.grads.end()) {
operands.push_back(it->second);
} else {
Value grad_var;
if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
while_op, &grad_var))) {
return failure();
}
stat.grads[source] = grad_var;
operands.push_back(grad_var);
}
}
}
OpBuilder builder(while_op);
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
operands, while_op.getAttrs());
// Clear the output_shapes attribute, as it is not needed for XLA lowering.
new_while.setAttr("output_shapes", builder.getArrayAttr({}));
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
if (ta_arg_buffer_type(i)) {
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
} else {
while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
}
}
while_op.erase();
return success();
}
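The rewritten body terminator for the first test above then forwards the tensor array argument, keeps the loop-carried values, and appends the gradient argument, matching the FileCheck expectation (a condensed sketch):

return %arg0, %sub, %arg2 : tensor<!tf.resource<tensor<5x3xf32>>>, tensor<i32>, tensor<!tf.resource<tensor<5x3xf32>>>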
LogicalResult HandleIfOp(
TF::IfOp if_op, ModuleOp module,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
auto grads = AccessedGradients({then_branch, else_branch}, module);
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
auto it = stats->find(if_op.getOperand(index + 1));
if (it == stats->end()) return nullptr;
return it->getFirst().getType();
};
auto ta_accumulate_on_write = [&](int64_t index) {
auto it = stats->find(if_op.getOperand(index + 1));
if (it == stats->end()) return false;
return it->getSecond().accumulate_on_write;
};
llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &then_stats);
llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &else_stats);
if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
decomposed_partitioned_call_callees)) ||
failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
decomposed_partitioned_call_callees))) {
return failure();
}
if (then_stats.empty() && else_stats.empty()) return success();
// Recreate the if op.
auto operands = llvm::to_vector<8>(if_op.getOperands());
for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
auto grad_it = grads.find(i);
auto& stat = (*stats)[operands[i + 1]];
if (grad_it == grads.end()) continue;
for (const string& source : grad_it->getSecond()) {
auto it = stat.grads.find(source);
if (it != stat.grads.end()) {
operands.push_back(it->second);
} else {
Value grad_var;
if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
if_op, &grad_var))) {
return failure();
}
stat.grads[source] = grad_var;
operands.push_back(grad_var);
}
}
}
OpBuilder builder(if_op);
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
then_branch.getType().getResults(),
operands, if_op.getAttrs());
// Clear the output_shapes attribute, as it is not needed for XLA lowering.
new_if.setAttr("output_shapes", builder.getArrayAttr({}));
auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
auto retval = f.front().getTerminator()->getOperand(ret_ind);
auto arg = retval.dyn_cast<BlockArgument>();
if (!arg) return -1;
return arg.getArgNumber();
};
for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
if (!getElementTypeOrSelf(if_op.getResult(i).getType())
.isa<TF::ResourceType>()) {
if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
continue;
}
int64_t then_forward_input = ret_forwards_input(then_branch, i);
int64_t else_forward_input = ret_forwards_input(else_branch, i);
if (then_forward_input != else_forward_input || then_forward_input < 0) {
return if_op.emitOpError(
"branches do not forward the same input resource");
}
if_op.getResult(i).replaceAllUsesWith(
if_op.getOperand(then_forward_input + 1));
}
if_op.erase();
return success();
}
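The recreated If from the test above, sketched below (%new_if is a hypothetical name): gradient variables are appended after the condition operand (hence the index + 1 offsets), and uses of any resource result are replaced by the operand both branches forward.

%new_if = "tf.If"(%cond, %var, %gvar1, %gvar2) {then_branch = @then_branch, else_branch = @else_branch, is_stateless = false} : (tensor<i1>, tensor<!tf.resource<tensor<5x3xf32>>>, tensor<!tf.resource<tensor<5x3xf32>>>, tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<!tf.resource<tensor<5x3xf32>>>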
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
CallOp call, FuncOp callee, ModuleOp module,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
callee, PartitionedCallTensorArrayOpsInfo());
auto& info = emplace_res.first->getSecond();
// Recreates the call op using the cached decomposition info.
auto recreate_caller = [&]() -> LogicalResult {
auto new_operands = llvm::to_vector<8>(call.getOperands());
for (const auto& entry : info.arg_grads) {
auto it = stats->find(call.getOperand(entry.first));
if (it == stats->end()) return call.emitOpError("unknown tensor array");
for (const string& source : entry.second) {
auto grad_it = it->getSecond().grads.find(source);
if (grad_it != it->getSecond().grads.end()) {
new_operands.push_back(grad_it->second);
} else {
Value grad_var;
if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
call, &grad_var))) {
return failure();
}
it->getSecond().grads[source] = grad_var;
new_operands.push_back(grad_var);
}
}
}
OpBuilder builder(call);
auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs());
new_call.setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName()));
for (const auto& entry : info.ret_forward_input) {
call.getResult(entry.first)
.replaceAllUsesWith(call.getOperand(entry.second));
}
call.replaceAllUsesWith(new_call);
call.erase();
return success();
};
if (!emplace_res.second) {
// This callee was handled before.
if (!info.signature_change) return success();
return recreate_caller();
}
// Rewrite the callee on a cloned function.
info.signature_change = false;
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
auto it = stats->find(call.getOperand(index));
if (it == stats->end()) return nullptr;
info.signature_change = true;
return it->getFirst().getType();
};
auto ta_accumulate_on_write = [&](int64_t index) {
auto it = stats->find(call.getOperand(index));
if (it == stats->end()) return false;
return it->getSecond().accumulate_on_write;
};
auto callee_clone = callee.clone();
auto grads = AccessedGradients({callee_clone}, module);
for (int64_t i = 0; i < callee_clone.getNumArguments(); ++i) {
auto it = grads.find(i);
if (it == grads.end()) continue;
info.arg_grads.emplace_back(i, it->getSecond());
}
llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
ChangeFunctionInputSignature(callee_clone, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &callee_stats);
if (failed(DecomposeTensorArrayOps(&callee_clone.front(), module,
&callee_stats,
decomposed_partitioned_call_callees))) {
return failure();
}
for (int64_t i = 0; i < call.getNumResults(); ++i) {
auto ret = callee_clone.front().getTerminator()->getOperand(i);
if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
auto arg = ret.dyn_cast<BlockArgument>();
if (!arg) continue;
info.ret_forward_input.emplace_back(i, arg.getArgNumber());
}
if (!info.signature_change) {
// Signature is not modified. We do not need to keep two copies.
auto name = callee.getName();
callee.erase();
callee_clone.setName(name);
SymbolTable(module).insert(callee_clone);
} else {
info.decomposed_callee = callee_clone;
// Add the clone with a new name.
auto name =
llvm::formatv("{0}_{1}", callee.getName(), "tensorarray_decomposed")
.str();
callee_clone.setName(name);
SymbolTable(module).insert(callee_clone);
}
if (info.signature_change) return recreate_caller();
return success();
}
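In the partitioned-call test above, the decomposed callee clone is inserted as @callee_tensorarray_decomposed and the caller is recreated against it with the gradient variables appended (a condensed sketch):

%call = "tf.StatefulPartitionedCall"(%var, %gvar1, %gvar2) {f = @callee_tensorarray_decomposed, config = "", config_proto = "", executor_type = ""} : (tensor<!tf.resource<tensor<5x3xf32>>>, tensor<!tf.resource<tensor<5x3xf32>>>, tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<!tf.resource<tensor<5x3xf32>>>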
LogicalResult DecomposeTensorArrayOps(
Block* block, ModuleOp module,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
} else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
-if (failed(HandleTensorArrayV3Op(ta, module, &stats))) {
if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
return failure();
}
} else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
-if (failed(HandleTensorArrayReadV3Op(read, stats))) return failure();
if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
} else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
-if (failed(HandleTensorArrayWriteV3Op(write, stats))) return failure();
if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
} else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
-if (failed(HandleTensorArrayConcatV3Op(concat, stats))) return failure();
if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
} else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
-if (failed(HandleTensorArraySplitV3Op(split, stats))) return failure();
if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
} else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
-if (failed(HandleTensorArraySizeV3Op(size, stats))) return failure();
if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
} else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
-if (failed(HandleTensorArrayGradV3Op(grad, &stats))) return failure();
if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
} else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
-if (failed(HandleTensorArrayGatherV3Op(gather, stats))) return failure();
if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
} else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
-if (failed(HandleTensorArrayScatterV3Op(scatter, stats))) {
if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
return failure();
}
} else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
close.erase();
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
if (failed(HandleWhileOp(while_op, module, stats,
decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
if (failed(HandleIfOp(if_op, module, stats,
decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!pcall.f().isa<FlatSymbolRefAttr>()) {
return pcall.emitOpError(
"TensorArray decomposition does not support call with nested "
"references.");
}
if (failed(HandlePartitionedCallOp(
pcall, module.lookupSymbol<FuncOp>(pcall.f().getRootReference()),
module, stats, decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto spcall =
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
if (failed(HandlePartitionedCallOp(
spcall, module.lookupSymbol<FuncOp>(spcall.f()), module, stats,
decomposed_partitioned_call_callees))) {
return failure();
}
}
}
return success();
@@ -448,7 +877,11 @@ void TensorArrayOpsDecompositionPass::runOnModule() {
auto module = getModule();
auto main = module.lookupSymbol<FuncOp>("main");
if (!main) return;
-if (failed(DecomposeTensorArrayOps(&main.front(), module))) {
llvm::SmallDenseMap<Value, TensorArrayStats> stats;
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>
decomposed_partitioned_call_callees;
if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
&decomposed_partitioned_call_callees))) {
signalPassFailure();
}
}


@@ -261,6 +261,7 @@ Status ConvertMLIRToXlaComputation(
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());
tf2xla.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
// LegalizeTFControlFlow encapsulates arguments for control flow operations