[tf.data] Add map and batch fusion rewrite in MLIR
PiperOrigin-RevId: 310260964 Change-Id: I6505bcd35f21a3f9ff520f1900038c3c4be15536
This commit is contained in:
parent
469de83a9c
commit
a967cad22b
|
@ -342,6 +342,38 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tf_data_optimization_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-rewriters",
|
||||
"transforms/generated_tf_data_optimization.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/tf_data_optimization.td",
|
||||
td_srcs = [
|
||||
":tensorflow_ops_td_files",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_data_optimization",
|
||||
srcs = [
|
||||
"transforms/tf_data_optimization.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/tf_data_optimization.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":tensorflow_types",
|
||||
":tf_data_optimization_inc_gen",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "unroll_batch_matmul_pass",
|
||||
srcs = [
|
||||
|
@ -406,6 +438,7 @@ cc_library(
|
|||
"transforms/tensor_array_ops_decomposition.cc",
|
||||
"transforms/tensor_list_ops_decomposition.cc",
|
||||
"transforms/test_side_effect_analysis.cc",
|
||||
"transforms/tf_data_optimization_pass.cc",
|
||||
"transforms/tf_device_assignment.cc",
|
||||
"transforms/tpu_cluster_formation.cc",
|
||||
"transforms/tpu_dynamic_layout_pass.cc",
|
||||
|
@ -444,6 +477,7 @@ cc_library(
|
|||
":tensorflow",
|
||||
":tensorflow_optimize_inc_gen",
|
||||
":tensorflow_types",
|
||||
":tf_data_optimization",
|
||||
":tpu_rewrite_device_util",
|
||||
":translate_utils",
|
||||
":unroll_batch_matmul_pass",
|
||||
|
|
|
@ -777,4 +777,133 @@ Formats a string template using a list of tensors, pretty-printing tensor summar
|
|||
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf.data ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Creates a dataset that batches `batch_size` elements from `input_dataset`.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
I64Tensor:$batch_size,
|
||||
I1Tensor:$drop_remainder,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$parallel_copy,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_MapDatasetOp : TF_Op<"MapDataset", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Creates a dataset that applies `f` to the outputs of `input_dataset`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
Variadic<TF_Tensor>:$other_arguments,
|
||||
|
||||
SymbolRefAttr:$f,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
|
||||
}
|
||||
|
||||
def TF_MapAndBatchDatasetOp : TF_Op<"MapAndBatchDataset", [NoSideEffect]> {
|
||||
let summary = "Creates a dataset that fuses mapping with batching.";
|
||||
|
||||
let description = [{
|
||||
Creates a dataset that applies `f` to the outputs of `input_dataset` and then
|
||||
batches `batch_size` of them.
|
||||
|
||||
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
|
||||
to `batch_size * num_parallel_batches` copies of `f` in parallel.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
Variadic<TF_Tensor>:$other_arguments,
|
||||
I64Tensor:$batch_size,
|
||||
I64Tensor:$num_parallel_calls,
|
||||
I1Tensor:$drop_remainder,
|
||||
|
||||
SymbolRefAttr:$f,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
|
||||
}
|
||||
|
||||
def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Creates a dataset that applies `f` to the outputs of `input_dataset`.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes
|
||||
up to `num_parallel_calls` copies of `f` in parallel.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
Variadic<TF_Tensor>:$other_arguments,
|
||||
I32Tensor:$num_parallel_calls,
|
||||
|
||||
SymbolRefAttr:$f,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$sloppy,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
|
||||
}
|
||||
|
||||
def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
|
||||
let summary = [{
|
||||
Creates a dataset that emits each dim-0 slice of `components` once.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_Tensor>:$components,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
|
||||
}
|
||||
|
||||
|
||||
#endif // TF_OPS
|
||||
|
|
|
@ -0,0 +1,256 @@
|
|||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=BatchDatasetV2 -o - | FileCheck %s --dump-input-on-failure
|
||||
|
||||
# CHECK-LABEL: func @main() -> tensor<*x!tf.variant>
|
||||
# CHECK: %[[tensor_slice:.*]], %[[tensor_slice_control:.*]] = tf_executor.island wraps "tf.TensorSliceDataset"
|
||||
# CHECK: %[[map_dataset:.*]], %[[map_dataset_control:.*]] = tf_executor.island wraps "tf.MapDataset"(%[[tensor_slice]]
|
||||
# CHECK: %[[batch_dataset:.*]], %[[batch_dataset_control:.*]] = tf_executor.island wraps "tf.BatchDatasetV2"(%[[map_dataset]]
|
||||
|
||||
node {
|
||||
name: "tensors/normalize_tensors/component_0"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "TensorSliceDataset"
|
||||
op: "TensorSliceDataset"
|
||||
input: "tensors/normalize_tensors/component_0"
|
||||
attr {
|
||||
key: "Toutput_types"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "MapDataset"
|
||||
op: "MapDataset"
|
||||
input: "TensorSliceDataset"
|
||||
attr {
|
||||
key: "Targuments"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "f"
|
||||
value {
|
||||
func {
|
||||
name: "__inference_Dataset_map_<lambda>_8"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_types"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "preserve_cardinality"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "use_inter_op_parallelism"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "batch_size"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT64
|
||||
tensor_shape {
|
||||
}
|
||||
int64_val: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "drop_remainder"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_BOOL
|
||||
tensor_shape {
|
||||
}
|
||||
bool_val: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "BatchDatasetV2"
|
||||
op: "BatchDatasetV2"
|
||||
input: "MapDataset"
|
||||
input: "batch_size"
|
||||
input: "drop_remainder"
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_types"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "parallel_copy"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "__inference_Dataset_map_<lambda>_8"
|
||||
input_arg {
|
||||
name: "args_0"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "identity"
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "mul/y"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "mul"
|
||||
op: "Mul"
|
||||
input: "args_0"
|
||||
input: "mul/y:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Identity"
|
||||
op: "Identity"
|
||||
input: "mul:z:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "identity"
|
||||
value: "Identity:output:0"
|
||||
}
|
||||
arg_attr {
|
||||
key: 0
|
||||
value {
|
||||
attr {
|
||||
key: "_user_specified_name"
|
||||
value {
|
||||
s: "args_0"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 134
|
||||
min_consumer: 12
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t
|
||||
|
||||
module {
|
||||
// CHECK-LABEL: fuse_map_and_batch
|
||||
func @fuse_map_and_batch() -> tensor<!tf.variant> attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} {
|
||||
%0 = "tf.Const"() {value = dense<5> : tensor<i64>} : () -> tensor<i64>
|
||||
%1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
|
||||
%2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK: %[[NPC:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>}
|
||||
// CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset"
|
||||
%3 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant>
|
||||
// CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], %[[BSIZE:.*]], %[[NPC]]
|
||||
// CHECK-SAME: f = @"__inference_Dataset_map_<lambda>_80",
|
||||
%4 = "tf.MapDataset"(%3) {device = "",
|
||||
f = @"__inference_Dataset_map_<lambda>_80",
|
||||
output_shapes = [#tf.shape<>], output_types = [i32],
|
||||
preserve_cardinality = false, sloppy = false,
|
||||
use_inter_op_parallelism = true} : (tensor<*x!tf.variant>) -> tensor<!tf.variant>
|
||||
%5 = "tf.BatchDatasetV2"(%4, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor<!tf.variant>, tensor<i64>, tensor<i1>) -> tensor<!tf.variant>
|
||||
return %5 : tensor<!tf.variant>
|
||||
}
|
||||
|
||||
func @"__inference_Dataset_map_<lambda>_80"(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
|
||||
return %2 : tensor<*xi32>
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t
|
||||
|
||||
module {
|
||||
// CHECK-LABEL: fuse_pmap_and_batch
|
||||
func @fuse_pmap_and_batch() -> tensor<!tf.variant> attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} {
|
||||
%0 = "tf.Const"() {value = dense<5> : tensor<i64>} : () -> tensor<i64>
|
||||
%1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
|
||||
%2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%3 = "tf.Const"() {value = dense<12> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset"
|
||||
%4 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant>
|
||||
// CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]],
|
||||
// CHECK-SAME: f = @"__inference_Dataset_map_<lambda>_80",
|
||||
%5 = "tf.ParallelMapDataset"(%4, %3) {device = "",
|
||||
f = @"__inference_Dataset_map_<lambda>_80",
|
||||
output_shapes = [#tf.shape<>], output_types = [i32],
|
||||
preserve_cardinality = false, sloppy = false,
|
||||
use_inter_op_parallelism = true} : (tensor<*x!tf.variant>, tensor<i32>) -> tensor<!tf.variant>
|
||||
%6 = "tf.BatchDatasetV2"(%5, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor<!tf.variant>, tensor<i64>, tensor<i1>) -> tensor<!tf.variant>
|
||||
return %6 : tensor<!tf.variant>
|
||||
}
|
||||
|
||||
func @"__inference_Dataset_map_<lambda>_80"(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
|
||||
return %2 : tensor<*xi32>
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h"
|
||||
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
|
||||
struct FuseParallelMapAndBatch : public OpRewritePattern<BatchDatasetV2Op> {
|
||||
using OpRewritePattern<BatchDatasetV2Op>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(BatchDatasetV2Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto batchInputDataset = op.input_dataset();
|
||||
|
||||
ParallelMapDatasetOp batchInputOp = dyn_cast_or_null<ParallelMapDatasetOp>(
|
||||
batchInputDataset.getDefiningOp());
|
||||
if (!batchInputOp) return failure();
|
||||
|
||||
// The type of the `num_parallel_calls` argument in ParallelMapDataset
|
||||
// and MapAndBatchDataset is different (int32 and int64 respectively)
|
||||
auto num_parallel_calls_op = rewriter.create<CastOp>(
|
||||
op.getLoc(), UnrankedTensorType::get(rewriter.getIntegerType(64)),
|
||||
batchInputOp.num_parallel_calls(), rewriter.getBoolAttr(false));
|
||||
|
||||
auto fused_op = rewriter.create<MapAndBatchDatasetOp>(
|
||||
op.getLoc(), op.getType(), batchInputOp.input_dataset(),
|
||||
batchInputOp.other_arguments(), op.batch_size(),
|
||||
num_parallel_calls_op.y(), op.drop_remainder(), batchInputOp.f(),
|
||||
op.output_types(), op.output_shapes(),
|
||||
batchInputOp.preserve_cardinality());
|
||||
rewriter.replaceOp(op, {fused_op.handle()});
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_tf_data_optimization.inc"
|
||||
} // namespace
|
||||
|
||||
void PopulateTFDataOptimizationPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns->insert<FuseParallelMapAndBatch>(context);
|
||||
populateWithGenerated(context, patterns);
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
|
@ -0,0 +1,32 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
// Populates patterns to perform optimizations specific to tf.data operations.
|
||||
void PopulateTFDataOptimizationPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_
|
|
@ -0,0 +1,32 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
// TODO(jpienaar): Move this somewhere general.
|
||||
class GetI64ScalarElementsAttr<int value> :
|
||||
NativeCodeCall<"DenseElementsAttr::get<int64_t>(RankedTensorType::get({}, $_builder.getIntegerType(64)), " # value # ")">;
|
||||
|
||||
def FuseMapAndBatch : Pat<
|
||||
(TF_BatchDatasetV2Op
|
||||
(TF_MapDatasetOp $input_dataset, $other_arguments, $f, $output_types,
|
||||
$output_shapes, $use_inter_op_parallelism, $preserve_cardinality),
|
||||
$batch_size, $drop_remainder, $parallel_copy, $batch_output_types,
|
||||
$batch_output_shapes),
|
||||
(TF_MapAndBatchDatasetOp $input_dataset, $other_arguments, $batch_size,
|
||||
(TF_ConstOp (GetI64ScalarElementsAttr<1>)), $drop_remainder, $f,
|
||||
$batch_output_types, $batch_output_shapes, $preserve_cardinality)>;
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
// Perform tf.data optimizations.
|
||||
struct TFDataOptimization
|
||||
: public PassWrapper<TFDataOptimization, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
mlir::TF::PopulateTFDataOptimizationPatterns(&getContext(), &patterns);
|
||||
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::TF::TFDataOptimization> pass(
|
||||
"tf-data-optimization", "Performs tf.data optimizations");
|
Loading…
Reference in New Issue