Add time-major support for keras_lstm fused ops.
TFL LSTM only supports time-major inputs, so when the Keras layer is not time-major we transpose both the inputs and the outputs around the fused op.

PiperOrigin-RevId: 297051105
Change-Id: Icf576a6451097ed100ddf768263d77290d699380
parent 24b2d0252b
commit be6324180a
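In outline, the conversion now works as follows (a condensed sketch assembled from the ConvertKerasLSTMLayer hunks further down; Transpose, time_majored, final_inputs, and final_output are the names introduced by this change, while the weight setup and the tfl.lstm construction are elided):

// Condensed sketch of the time-major handling in ConvertKerasLSTMLayer.
// Error handling, weight splitting, and the tfl.lstm construction are omitted.
auto time_major_attr = func_op.getAttrOfType<BoolAttr>("tf.time_major");
if (time_major_attr == nullptr) return failure();
bool time_majored = time_major_attr.getValue();

// tfl.lstm expects time-major [time, batch, feature] inputs; if the Keras
// layer is batch-major, swap the leading two dimensions first.
auto final_inputs = input;
auto final_input_type = input_type;
if (!time_majored) {
  SmallVector<int64_t, 4> perm = {1, 0, 2};
  final_inputs =
      Transpose(builder, final_inputs, perm, input_type, func_op.getLoc());
  final_input_type = final_inputs.getType().dyn_cast<RankedTensorType>();
}
int batch = final_input_type.getDimSize(1);
int time = final_input_type.getDimSize(0);

// ... build the fused tfl.lstm op on final_inputs as before ...

// Transpose the result back so callers still see batch-major output.
auto final_output = lstm.getResult();
if (!time_majored) {
  SmallVector<int64_t, 4> perm = {1, 0, 2};
  final_output =
      Transpose(builder, final_output, perm, result_type, func_op.getLoc());
}
builder->create<mlir::ReturnOp>(func_op.getLoc(), final_output);
return success();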
@@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
// -----
module {
func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
@@ -165,7 +165,7 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<?x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@@ -181,7 +181,46 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_19:%.*]] = constant unit
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: return [[VAL_21:%.*]] : tensor<?x8x10xf32>
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
// CHECK: return [[VAL_21:%.*]] : tensor<8x?x10xf32>
// CHECK: }
}
// -----
module {
func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_21:%.*]] = constant unit
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_25:%.*]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
// CHECK: return [[VAL_24]] : tensor<8x8x10xf32>
// CHECK: }
}
@@ -70,14 +70,14 @@ Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
builder->getUnitAttr());
}
Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
RankedTensorType type, mlir::Location location) {
Value Transpose(OpBuilder* builder, Value value_to_transpose,
SmallVector<int64_t, 4> perm, RankedTensorType original_type,
mlir::Location location) {
// Create a constant op for transpose permutation.
SmallVector<int64_t, 2> perm = {1, 0};
auto perm_op = CreateI64DenseConst(builder, perm, perm, location);
// Create tensor type for the transpose result.
auto transpose_type = type;
auto transpose_type = original_type;
auto transpose_shape = functional::map(
[transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); },
perm);
@@ -88,6 +88,13 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
value_to_transpose, perm_op);
}
Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
RankedTensorType type, mlir::Location location) {
// Create a constant op for transpose permutation.
SmallVector<int64_t, 4> perm = {1, 0};
return Transpose(builder, value_to_transpose, perm, type, location);
}
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
return value.getType().cast<RankedTensorType>().getShape();
}
@@ -586,15 +593,30 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
Value recurrent_kernel = func_op.getArgument(4);
Value bias = func_op.getArgument(5);
// Assume it's batch majored.
// TFL lstm only supports time-majored inputs, so if it's not time-majored,
// we will transpose the inputs and outputs.
auto time_major_attr = func_op.getAttrOfType<BoolAttr>("tf.time_major");
if (time_major_attr == nullptr) return failure();
bool time_majored = time_major_attr.getValue();
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
if (!input_type) {
func_op.emitError() << "Input type is not a ranked tensor type";
return failure();
}
int batch = input_type.getDimSize(0);
int time = input_type.getDimSize(1);
auto final_inputs = input;
auto final_input_type = input_type;
// We will transpose the inputs.
if (!time_majored) {
SmallVector<int64_t, 4> perm = {1, 0, 2};
final_inputs =
Transpose(builder, final_inputs, perm, input_type, func_op.getLoc());
final_input_type = final_inputs.getType().dyn_cast<RankedTensorType>();
}
int batch = final_input_type.getDimSize(1);
int time = final_input_type.getDimSize(0);
// Setup correct weights.
RankedTensorType weight_type =
@@ -672,7 +694,13 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0),
builder->getStringAttr("FULL"));
builder->create<mlir::ReturnOp>(func_op.getLoc(), lstm.getResult());
auto final_output = lstm.getResult();
if (!time_majored) {
SmallVector<int64_t, 4> perm = {1, 0, 2};
final_output =
Transpose(builder, final_output, perm, result_type, func_op.getLoc());
}
builder->create<mlir::ReturnOp>(func_op.getLoc(), final_output);
return success();
}