From a967cad22b06fd24a400c7b3c27d4a573ee9f68f Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Wed, 6 May 2020 17:27:08 -0700 Subject: [PATCH] [tf.data] Add map and batch fusion rewrite in MLIR PiperOrigin-RevId: 310260964 Change-Id: I6505bcd35f21a3f9ff520f1900038c3c4be15536 --- tensorflow/compiler/mlir/tensorflow/BUILD | 34 +++ .../compiler/mlir/tensorflow/ir/tf_ops.td | 129 +++++++++ .../graphdef2mlir/tf-data-pipeline.pbtxt | 256 ++++++++++++++++++ .../tests/tf_data_fuse_map_and_batch.mlir | 29 ++ .../tests/tf_data_fuse_pmap_and_batch.mlir | 29 ++ .../transforms/tf_data_optimization.cc | 65 +++++ .../transforms/tf_data_optimization.h | 32 +++ .../transforms/tf_data_optimization.td | 32 +++ .../transforms/tf_data_optimization_pass.cc | 40 +++ 9 files changed, 646 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 5f1aaa01514..9099f2be2e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -342,6 +342,38 @@ cc_library( ], ) +gentbl( + name = "tf_data_optimization_inc_gen", + tbl_outs = [ + ( + "-gen-rewriters", + "transforms/generated_tf_data_optimization.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/tf_data_optimization.td", + td_srcs = [ + ":tensorflow_ops_td_files", + "@llvm-project//mlir:StdOpsTdFiles", + ], +) + +cc_library( + name = "tf_data_optimization", + srcs = [ + "transforms/tf_data_optimization.cc", + ], + hdrs = [ + "transforms/tf_data_optimization.h", + ], + deps = [ + ":tensorflow", + ":tensorflow_types", + ":tf_data_optimization_inc_gen", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "unroll_batch_matmul_pass", srcs = [ @@ -406,6 +438,7 @@ cc_library( "transforms/tensor_array_ops_decomposition.cc", "transforms/tensor_list_ops_decomposition.cc", "transforms/test_side_effect_analysis.cc", + "transforms/tf_data_optimization_pass.cc", "transforms/tf_device_assignment.cc", "transforms/tpu_cluster_formation.cc", "transforms/tpu_dynamic_layout_pass.cc", @@ -444,6 +477,7 @@ cc_library( ":tensorflow", ":tensorflow_optimize_inc_gen", ":tensorflow_types", + ":tf_data_optimization", ":tpu_rewrite_device_util", ":translate_utils", ":unroll_batch_matmul_pass", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index edaf75b4011..744d1ac5b71 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -777,4 +777,133 @@ Formats a string template using a list of tensors, pretty-printing tensor summar TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; } +//===----------------------------------------------------------------------===// +// tf.data ops +//===----------------------------------------------------------------------===// + +def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> { + let summary = [{ +Creates a dataset that batches `batch_size` elements from `input_dataset`. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + I64Tensor:$batch_size, + I1Tensor:$drop_remainder, + + DefaultValuedAttr:$parallel_copy, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + +def TF_MapDatasetOp : TF_Op<"MapDataset", [NoSideEffect]> { + let summary = [{ + Creates a dataset that applies `f` to the outputs of `input_dataset`. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$use_inter_op_parallelism, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_MapAndBatchDatasetOp : TF_Op<"MapAndBatchDataset", [NoSideEffect]> { + let summary = "Creates a dataset that fuses mapping with batching."; + + let description = [{ +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + I64Tensor:$batch_size, + I64Tensor:$num_parallel_calls, + I1Tensor:$drop_remainder, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> { + let summary = [{ + Creates a dataset that applies `f` to the outputs of `input_dataset`. + }]; + + let description = [{ + Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes + up to `num_parallel_calls` copies of `f` in parallel. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + I32Tensor:$num_parallel_calls, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$use_inter_op_parallelism, + DefaultValuedAttr:$sloppy, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> { + let summary = [{ + Creates a dataset that emits each dim-0 slice of `components` once. + }]; + + let arguments = (ins + Variadic:$components, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>; +} + + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt new file mode 100644 index 00000000000..1e640baa507 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt @@ -0,0 +1,256 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=BatchDatasetV2 -o - | FileCheck %s --dump-input-on-failure + +# CHECK-LABEL: func @main() -> tensor<*x!tf.variant> +# CHECK: %[[tensor_slice:.*]], %[[tensor_slice_control:.*]] = tf_executor.island wraps "tf.TensorSliceDataset" +# CHECK: %[[map_dataset:.*]], %[[map_dataset_control:.*]] = tf_executor.island wraps "tf.MapDataset"(%[[tensor_slice]] +# CHECK: %[[batch_dataset:.*]], %[[batch_dataset_control:.*]] = tf_executor.island wraps "tf.BatchDatasetV2"(%[[map_dataset]] + +node { + name: "tensors/normalize_tensors/component_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" + } + } + } +} +node { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "tensors/normalize_tensors/component_0" + attr { + key: "Toutput_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "MapDataset" + op: "MapDataset" + input: "TensorSliceDataset" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "__inference_Dataset_map__8" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "preserve_cardinality" + value { + b: false + } + } + attr { + key: "use_inter_op_parallelism" + value { + b: true + } + } +} +node { + name: "batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 5 + } + } + } +} +node { + name: "drop_remainder" + op: "Const" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: false + } + } + } +} +node { + name: "BatchDatasetV2" + op: "BatchDatasetV2" + input: "MapDataset" + input: "batch_size" + input: "drop_remainder" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "parallel_copy" + value { + b: false + } + } +} +library { + function { + signature { + name: "__inference_Dataset_map__8" + input_arg { + name: "args_0" + type: DT_INT32 + } + output_arg { + name: "identity" + type: DT_INT32 + } + } + node_def { + name: "mul/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "mul" + op: "Mul" + input: "args_0" + input: "mul/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "mul:z:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_user_specified_name" + value { + s: "args_0" + } + } + } + } + } +} +versions { + producer: 134 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir new file mode 100644 index 00000000000..39f34caf259 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir @@ -0,0 +1,29 @@ +// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t + +module { +// CHECK-LABEL: fuse_map_and_batch +func @fuse_map_and_batch() -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} { + %0 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: %[[NPC:.*]] = "tf.Const"() {value = dense<1> : tensor} + // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset" + %3 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant> + // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], %[[BSIZE:.*]], %[[NPC]] + // CHECK-SAME: f = @"__inference_Dataset_map__80", + %4 = "tf.MapDataset"(%3) {device = "", + f = @"__inference_Dataset_map__80", + output_shapes = [#tf.shape<>], output_types = [i32], + preserve_cardinality = false, sloppy = false, + use_inter_op_parallelism = true} : (tensor<*x!tf.variant>) -> tensor + %5 = "tf.BatchDatasetV2"(%4, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor, tensor, tensor) -> tensor + return %5 : tensor +} + +func @"__inference_Dataset_map__80"(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir new file mode 100644 index 00000000000..70c5c220fe1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir @@ -0,0 +1,29 @@ +// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t + +module { +// CHECK-LABEL: fuse_pmap_and_batch +func @fuse_pmap_and_batch() -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} { + %0 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %3 = "tf.Const"() {value = dense<12> : tensor} : () -> tensor + // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset" + %4 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant> + // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], + // CHECK-SAME: f = @"__inference_Dataset_map__80", + %5 = "tf.ParallelMapDataset"(%4, %3) {device = "", + f = @"__inference_Dataset_map__80", + output_shapes = [#tf.shape<>], output_types = [i32], + preserve_cardinality = false, sloppy = false, + use_inter_op_parallelism = true} : (tensor<*x!tf.variant>, tensor) -> tensor + %6 = "tf.BatchDatasetV2"(%5, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor, tensor, tensor) -> tensor + return %6 : tensor +} + +func @"__inference_Dataset_map__80"(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc new file mode 100644 index 00000000000..786c4b74b34 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h" + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +namespace { + +struct FuseParallelMapAndBatch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BatchDatasetV2Op op, + PatternRewriter &rewriter) const override { + auto batchInputDataset = op.input_dataset(); + + ParallelMapDatasetOp batchInputOp = dyn_cast_or_null( + batchInputDataset.getDefiningOp()); + if (!batchInputOp) return failure(); + + // The type of the `num_parallel_calls` argument in ParallelMapDataset + // and MapAndBatchDataset is different (int32 and int64 respectively) + auto num_parallel_calls_op = rewriter.create( + op.getLoc(), UnrankedTensorType::get(rewriter.getIntegerType(64)), + batchInputOp.num_parallel_calls(), rewriter.getBoolAttr(false)); + + auto fused_op = rewriter.create( + op.getLoc(), op.getType(), batchInputOp.input_dataset(), + batchInputOp.other_arguments(), op.batch_size(), + num_parallel_calls_op.y(), op.drop_remainder(), batchInputOp.f(), + op.output_types(), op.output_shapes(), + batchInputOp.preserve_cardinality()); + rewriter.replaceOp(op, {fused_op.handle()}); + return failure(); + } +}; + +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_tf_data_optimization.inc" +} // namespace + +void PopulateTFDataOptimizationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns->insert(context); + populateWithGenerated(context, patterns); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h new file mode 100644 index 00000000000..ffbc06a9515 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Populates patterns to perform optimizations specific to tf.data operations. +void PopulateTFDataOptimizationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td new file mode 100644 index 00000000000..4b4239679b2 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +// TODO(jpienaar): Move this somewhere general. +class GetI64ScalarElementsAttr : + NativeCodeCall<"DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getIntegerType(64)), " # value # ")">; + +def FuseMapAndBatch : Pat< + (TF_BatchDatasetV2Op + (TF_MapDatasetOp $input_dataset, $other_arguments, $f, $output_types, + $output_shapes, $use_inter_op_parallelism, $preserve_cardinality), + $batch_size, $drop_remainder, $parallel_copy, $batch_output_types, + $batch_output_shapes), + (TF_MapAndBatchDatasetOp $input_dataset, $other_arguments, $batch_size, + (TF_ConstOp (GetI64ScalarElementsAttr<1>)), $drop_remainder, $f, + $batch_output_types, $batch_output_shapes, $preserve_cardinality)>; + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc new file mode 100644 index 00000000000..5be69bddb11 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h" + +namespace mlir { +namespace TF { +namespace { + +// Perform tf.data optimizations. +struct TFDataOptimization + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + mlir::TF::PopulateTFDataOptimizationPatterns(&getContext(), &patterns); + + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // namespace +} // namespace TF +} // namespace mlir + +static mlir::PassRegistration pass( + "tf-data-optimization", "Performs tf.data optimizations");