[tf.data] Add map and batch fusion rewrite in MLIR

PiperOrigin-RevId: 310260964
Change-Id: I6505bcd35f21a3f9ff520f1900038c3c4be15536
This commit is contained in:
Rachel Lim 2020-05-06 17:27:08 -07:00 committed by TensorFlower Gardener
parent 469de83a9c
commit a967cad22b
9 changed files with 646 additions and 0 deletions

View File

@ -342,6 +342,38 @@ cc_library(
],
)
# Runs mlir-tblgen with -gen-rewriters over tf_data_optimization.td to emit
# the C++ rewrite patterns included by transforms/tf_data_optimization.cc.
gentbl(
name = "tf_data_optimization_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_tf_data_optimization.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/tf_data_optimization.td",
td_srcs = [
":tensorflow_ops_td_files",
"@llvm-project//mlir:StdOpsTdFiles",
],
)
# Library of tf.data-specific graph rewrites (map+batch fusion); depends on
# the tblgen-generated patterns from :tf_data_optimization_inc_gen.
cc_library(
name = "tf_data_optimization",
srcs = [
"transforms/tf_data_optimization.cc",
],
hdrs = [
"transforms/tf_data_optimization.h",
],
deps = [
":tensorflow",
":tensorflow_types",
":tf_data_optimization_inc_gen",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "unroll_batch_matmul_pass",
srcs = [
@ -406,6 +438,7 @@ cc_library(
"transforms/tensor_array_ops_decomposition.cc",
"transforms/tensor_list_ops_decomposition.cc",
"transforms/test_side_effect_analysis.cc",
"transforms/tf_data_optimization_pass.cc",
"transforms/tf_device_assignment.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_dynamic_layout_pass.cc",
@ -444,6 +477,7 @@ cc_library(
":tensorflow",
":tensorflow_optimize_inc_gen",
":tensorflow_types",
":tf_data_optimization",
":tpu_rewrite_device_util",
":translate_utils",
":unroll_batch_matmul_pass",

View File

@ -777,4 +777,133 @@ Formats a string template using a list of tensors, pretty-printing tensor summar
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}
//===----------------------------------------------------------------------===//
// tf.data ops
//===----------------------------------------------------------------------===//
// ODS definition mirroring the TensorFlow `BatchDatasetV2` op. Operands:
// input dataset variant, scalar i64 batch size, scalar i1 drop_remainder.
// Referenced by the map+batch fusion patterns in tf_data_optimization.
def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> {
let summary = [{
Creates a dataset that batches `batch_size` elements from `input_dataset`.
}];
let description = [{
}];
let arguments = (ins
TF_VariantTensor:$input_dataset,
I64Tensor:$batch_size,
I1Tensor:$drop_remainder,
DefaultValuedAttr<BoolAttr, "false">:$parallel_copy,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
// Result is the handle to the batched dataset.
let results = (outs
TF_VariantTensor:$handle
);
}
// ODS definition mirroring the TensorFlow `MapDataset` op: applies the
// function symbol `f` to each element of `input_dataset`.
def TF_MapDatasetOp : TF_Op<"MapDataset", [NoSideEffect]> {
let summary = [{
Creates a dataset that applies `f` to the outputs of `input_dataset`.
}];
let arguments = (ins
TF_VariantTensor:$input_dataset,
Variadic<TF_Tensor>:$other_arguments,
SymbolRefAttr:$f,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism,
DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
);
let results = (outs
TF_VariantTensor:$handle
);
// Targuments is derived from the types of the variadic `other_arguments`.
TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}
// ODS definition mirroring the TensorFlow `MapAndBatchDataset` op — the
// fusion target produced by the tf_data_optimization rewrite patterns.
def TF_MapAndBatchDatasetOp : TF_Op<"MapAndBatchDataset", [NoSideEffect]> {
let summary = "Creates a dataset that fuses mapping with batching.";
let description = [{
Creates a dataset that applies `f` to the outputs of `input_dataset` and then
batches `batch_size` of them.
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
to `batch_size * num_parallel_batches` copies of `f` in parallel.
}];
let arguments = (ins
TF_VariantTensor:$input_dataset,
Variadic<TF_Tensor>:$other_arguments,
I64Tensor:$batch_size,
// Note: i64 here, whereas ParallelMapDataset takes an i32 tensor; the C++
// fusion pattern inserts a tf.Cast to bridge the two.
I64Tensor:$num_parallel_calls,
I1Tensor:$drop_remainder,
SymbolRefAttr:$f,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
);
let results = (outs
TF_VariantTensor:$handle
);
// Targuments is derived from the types of the variadic `other_arguments`.
TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}
// ODS definition mirroring the TensorFlow `ParallelMapDataset` op; takes an
// i32 `num_parallel_calls` tensor operand (contrast with MapAndBatchDataset's
// i64 operand).
def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> {
let summary = [{
Creates a dataset that applies `f` to the outputs of `input_dataset`.
}];
let description = [{
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes
up to `num_parallel_calls` copies of `f` in parallel.
}];
let arguments = (ins
TF_VariantTensor:$input_dataset,
Variadic<TF_Tensor>:$other_arguments,
I32Tensor:$num_parallel_calls,
SymbolRefAttr:$f,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism,
DefaultValuedAttr<BoolAttr, "false">:$sloppy,
DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
);
let results = (outs
TF_VariantTensor:$handle
);
// Targuments is derived from the types of the variadic `other_arguments`.
TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}
// ODS definition mirroring the TensorFlow `TensorSliceDataset` op.
// NOTE(review): declared without NoSideEffect, unlike the other tf.data ops
// above — confirm whether that is intentional.
def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
let summary = [{
Creates a dataset that emits each dim-0 slice of `components` once.
}];
let arguments = (ins
Variadic<TF_Tensor>:$components,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
TF_VariantTensor:$handle
);
// Toutput_types is derived from the types of the variadic `components`.
TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
}
#endif // TF_OPS

View File

@ -0,0 +1,256 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=BatchDatasetV2 -o - | FileCheck %s --dump-input-on-failure
# CHECK-LABEL: func @main() -> tensor<*x!tf.variant>
# CHECK: %[[tensor_slice:.*]], %[[tensor_slice_control:.*]] = tf_executor.island wraps "tf.TensorSliceDataset"
# CHECK: %[[map_dataset:.*]], %[[map_dataset_control:.*]] = tf_executor.island wraps "tf.MapDataset"(%[[tensor_slice]]
# CHECK: %[[batch_dataset:.*]], %[[batch_dataset_control:.*]] = tf_executor.island wraps "tf.BatchDatasetV2"(%[[map_dataset]]
node {
name: "tensors/normalize_tensors/component_0"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 3
}
}
tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000"
}
}
}
}
node {
name: "TensorSliceDataset"
op: "TensorSliceDataset"
input: "tensors/normalize_tensors/component_0"
attr {
key: "Toutput_types"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "output_shapes"
value {
list {
shape {
}
}
}
}
}
node {
name: "MapDataset"
op: "MapDataset"
input: "TensorSliceDataset"
attr {
key: "Targuments"
value {
list {
}
}
}
attr {
key: "f"
value {
func {
name: "__inference_Dataset_map_<lambda>_8"
}
}
}
attr {
key: "output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "output_types"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "preserve_cardinality"
value {
b: false
}
}
attr {
key: "use_inter_op_parallelism"
value {
b: true
}
}
}
node {
name: "batch_size"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT64
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {
}
int64_val: 5
}
}
}
}
node {
name: "drop_remainder"
op: "Const"
attr {
key: "dtype"
value {
type: DT_BOOL
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_BOOL
tensor_shape {
}
bool_val: false
}
}
}
}
node {
name: "BatchDatasetV2"
op: "BatchDatasetV2"
input: "MapDataset"
input: "batch_size"
input: "drop_remainder"
attr {
key: "output_shapes"
value {
list {
shape {
dim {
size: -1
}
}
}
}
}
attr {
key: "output_types"
value {
list {
type: DT_INT32
}
}
}
attr {
key: "parallel_copy"
value {
b: false
}
}
}
library {
function {
signature {
name: "__inference_Dataset_map_<lambda>_8"
input_arg {
name: "args_0"
type: DT_INT32
}
output_arg {
name: "identity"
type: DT_INT32
}
}
node_def {
name: "mul/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 2
}
}
}
}
node_def {
name: "mul"
op: "Mul"
input: "args_0"
input: "mul/y:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node_def {
name: "Identity"
op: "Identity"
input: "mul:z:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
ret {
key: "identity"
value: "Identity:output:0"
}
arg_attr {
key: 0
value {
attr {
key: "_user_specified_name"
value {
s: "args_0"
}
}
}
}
}
}
versions {
producer: 134
min_consumer: 12
}

View File

@ -0,0 +1,29 @@
// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t
// Test: a sequential tf.MapDataset feeding tf.BatchDatasetV2 is rewritten by
// -tf-data-optimization into a single tf.MapAndBatchDataset whose
// num_parallel_calls operand is a new constant 1 (sequential map semantics).
module {
// CHECK-LABEL: fuse_map_and_batch
func @fuse_map_and_batch() -> tensor<!tf.variant> attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} {
%0 = "tf.Const"() {value = dense<5> : tensor<i64>} : () -> tensor<i64>
%1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
%2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[NPC:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>}
// CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset"
%3 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant>
// CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], %[[BSIZE:.*]], %[[NPC]]
// CHECK-SAME: f = @"__inference_Dataset_map_<lambda>_80",
%4 = "tf.MapDataset"(%3) {device = "",
f = @"__inference_Dataset_map_<lambda>_80",
output_shapes = [#tf.shape<>], output_types = [i32],
preserve_cardinality = false, sloppy = false,
use_inter_op_parallelism = true} : (tensor<*x!tf.variant>) -> tensor<!tf.variant>
%5 = "tf.BatchDatasetV2"(%4, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor<!tf.variant>, tensor<i64>, tensor<i1>) -> tensor<!tf.variant>
return %5 : tensor<!tf.variant>
}
// Map function referenced by `f` above: multiplies each i32 element by 2.
func @"__inference_Dataset_map_<lambda>_80"(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
return %2 : tensor<*xi32>
}
}

View File

@ -0,0 +1,29 @@
// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t
// Test: tf.ParallelMapDataset feeding tf.BatchDatasetV2 is rewritten by
// -tf-data-optimization into a single tf.MapAndBatchDataset; the map's
// num_parallel_calls (12, i32) is forwarded (the C++ pattern casts it to i64).
module {
// CHECK-LABEL: fuse_pmap_and_batch
func @fuse_pmap_and_batch() -> tensor<!tf.variant> attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} {
%0 = "tf.Const"() {value = dense<5> : tensor<i64>} : () -> tensor<i64>
%1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
%2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%3 = "tf.Const"() {value = dense<12> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset"
%4 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant>
// CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]],
// CHECK-SAME: f = @"__inference_Dataset_map_<lambda>_80",
%5 = "tf.ParallelMapDataset"(%4, %3) {device = "",
f = @"__inference_Dataset_map_<lambda>_80",
output_shapes = [#tf.shape<>], output_types = [i32],
preserve_cardinality = false, sloppy = false,
use_inter_op_parallelism = true} : (tensor<*x!tf.variant>, tensor<i32>) -> tensor<!tf.variant>
%6 = "tf.BatchDatasetV2"(%5, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor<!tf.variant>, tensor<i64>, tensor<i1>) -> tensor<!tf.variant>
return %6 : tensor<!tf.variant>
}
// Map function referenced by `f` above: multiplies each i32 element by 2.
func @"__inference_Dataset_map_<lambda>_80"(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
return %2 : tensor<*xi32>
}
}

View File

@ -0,0 +1,65 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h"
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace TF {
namespace {
// Fuses a `tf.ParallelMapDataset` feeding a `tf.BatchDatasetV2` into a single
// `tf.MapAndBatchDataset`, preserving the map's parallelism. This pattern is
// hand-written (rather than DRR) because it must insert a Cast op for the
// `num_parallel_calls` operand.
struct FuseParallelMapAndBatch : public OpRewritePattern<BatchDatasetV2Op> {
  using OpRewritePattern<BatchDatasetV2Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(BatchDatasetV2Op op,
                                PatternRewriter &rewriter) const override {
    auto batchInputDataset = op.input_dataset();

    // Only fire when the batch's input is produced by a ParallelMapDataset.
    ParallelMapDatasetOp batchInputOp = dyn_cast_or_null<ParallelMapDatasetOp>(
        batchInputDataset.getDefiningOp());
    if (!batchInputOp) return failure();

    // The type of the `num_parallel_calls` argument in ParallelMapDataset
    // and MapAndBatchDataset is different (int32 and int64 respectively),
    // so cast the value before forwarding it to the fused op.
    auto num_parallel_calls_op = rewriter.create<CastOp>(
        op.getLoc(), UnrankedTensorType::get(rewriter.getIntegerType(64)),
        batchInputOp.num_parallel_calls(), rewriter.getBoolAttr(false));

    // Build the fused op from the map's function/arguments and the batch's
    // batch_size/drop_remainder/output types and shapes.
    auto fused_op = rewriter.create<MapAndBatchDatasetOp>(
        op.getLoc(), op.getType(), batchInputOp.input_dataset(),
        batchInputOp.other_arguments(), op.batch_size(),
        num_parallel_calls_op.y(), op.drop_remainder(), batchInputOp.f(),
        op.output_types(), op.output_shapes(),
        batchInputOp.preserve_cardinality());
    rewriter.replaceOp(op, {fused_op.handle()});
    // BUG FIX: the IR was rewritten above, so the pattern must report
    // success; returning failure() after mutating the IR violates the
    // RewritePattern contract and confuses the greedy pattern driver.
    return success();
  }
};
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_tf_data_optimization.inc"
} // namespace
// Populates `patterns` with all tf.data optimizations: the hand-written
// ParallelMapDataset+BatchDatasetV2 fusion above plus the DRR patterns
// generated from transforms/tf_data_optimization.td.
void PopulateTFDataOptimizationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns->insert<FuseParallelMapAndBatch>(context);
// `populateWithGenerated` is emitted into generated_tf_data_optimization.inc
// by mlir-tblgen (included above).
populateWithGenerated(context, patterns);
}
} // namespace TF
} // namespace mlir

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
namespace mlir {
namespace TF {
// Populates patterns to perform optimizations specific to tf.data operations.
// Currently these fuse (Parallel)MapDataset + BatchDatasetV2 into a single
// MapAndBatchDataset (see tf_data_optimization.cc and .td).
void PopulateTFDataOptimizationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// Builds a scalar (rank-0) i64 DenseElementsAttr holding `value`.
// TODO(jpienaar): Move this somewhere general.
class GetI64ScalarElementsAttr<int value> :
NativeCodeCall<"DenseElementsAttr::get<int64_t>(RankedTensorType::get({}, $_builder.getIntegerType(64)), " # value # ")">;
// DRR pattern: rewrite a sequential MapDataset feeding BatchDatasetV2 into a
// single MapAndBatchDataset with a constant num_parallel_calls of 1 (matching
// MapDataset's sequential semantics). The map's use_inter_op_parallelism and
// the batch's parallel_copy attributes are dropped: the MapAndBatchDataset op
// definition has no corresponding attributes.
def FuseMapAndBatch : Pat<
(TF_BatchDatasetV2Op
(TF_MapDatasetOp $input_dataset, $other_arguments, $f, $output_types,
$output_shapes, $use_inter_op_parallelism, $preserve_cardinality),
$batch_size, $drop_remainder, $parallel_copy, $batch_output_types,
$batch_output_shapes),
(TF_MapAndBatchDatasetOp $input_dataset, $other_arguments, $batch_size,
(TF_ConstOp (GetI64ScalarElementsAttr<1>)), $drop_remainder, $f,
$batch_output_types, $batch_output_shapes, $preserve_cardinality)>;

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h"
namespace mlir {
namespace TF {
namespace {
// Function pass that greedily applies the tf.data rewrite patterns
// (map/batch fusion) to each function.
struct TFDataOptimization
: public PassWrapper<TFDataOptimization, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
// Collect both the hand-written and the tblgen-generated patterns.
mlir::TF::PopulateTFDataOptimizationPatterns(&getContext(), &patterns);
// The driver's result is ignored: non-convergence is not a pass failure.
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
} // namespace
} // namespace TF
} // namespace mlir
// Static registration so `tf-opt -tf-data-optimization` can invoke the pass.
static mlir::PassRegistration<mlir::TF::TFDataOptimization> pass(
"tf-data-optimization", "Performs tf.data optimizations");