[tf.data] Add map and batch fusion rewrite in MLIR
PiperOrigin-RevId: 310260964
Change-Id: I6505bcd35f21a3f9ff520f1900038c3c4be15536

commit a967cad22b (parent 469de83a9c)
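In short, the change adds rewrite patterns (exposed through a new -tf-data-optimization pass) that fuse a tf.MapDataset or tf.ParallelMapDataset feeding a tf.BatchDatasetV2 into a single tf.MapAndBatchDataset. A condensed before/after sketch in TF MLIR, abridged from the fuse_map_and_batch test added below; %ds and @map_fn are placeholder names and most attributes and types are elided:

    // Before: a map followed by a batch.
    %map = "tf.MapDataset"(%ds) {f = @map_fn, ...}
    %out = "tf.BatchDatasetV2"(%map, %batch_size, %drop_remainder) {...}

    // After -tf-data-optimization: one fused op; for the MapDataset case
    // num_parallel_calls is materialized as a new i64 constant of 1.
    %npc = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
    %out = "tf.MapAndBatchDataset"(%ds, %batch_size, %npc, %drop_remainder) {f = @map_fn, ...}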
|
@@ -342,6 +342,38 @@ cc_library(
    ],
)

gentbl(
    name = "tf_data_optimization_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "transforms/generated_tf_data_optimization.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "transforms/tf_data_optimization.td",
    td_srcs = [
        ":tensorflow_ops_td_files",
        "@llvm-project//mlir:StdOpsTdFiles",
    ],
)

cc_library(
    name = "tf_data_optimization",
    srcs = [
        "transforms/tf_data_optimization.cc",
    ],
    hdrs = [
        "transforms/tf_data_optimization.h",
    ],
    deps = [
        ":tensorflow",
        ":tensorflow_types",
        ":tf_data_optimization_inc_gen",
        "@llvm-project//mlir:IR",
    ],
)

cc_library(
    name = "unroll_batch_matmul_pass",
    srcs = [
@@ -406,6 +438,7 @@ cc_library(
        "transforms/tensor_array_ops_decomposition.cc",
        "transforms/tensor_list_ops_decomposition.cc",
        "transforms/test_side_effect_analysis.cc",
        "transforms/tf_data_optimization_pass.cc",
        "transforms/tf_device_assignment.cc",
        "transforms/tpu_cluster_formation.cc",
        "transforms/tpu_dynamic_layout_pass.cc",

@@ -444,6 +477,7 @@ cc_library(
        ":tensorflow",
        ":tensorflow_optimize_inc_gen",
        ":tensorflow_types",
        ":tf_data_optimization",
        ":tpu_rewrite_device_util",
        ":translate_utils",
        ":unroll_batch_matmul_pass",
@@ -777,4 +777,133 @@ Formats a string template using a list of tensors, pretty-printing tensor summaries.

  TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
}

//===----------------------------------------------------------------------===//
// tf.data ops
//===----------------------------------------------------------------------===//

def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> {
  let summary = [{
Creates a dataset that batches `batch_size` elements from `input_dataset`.
  }];

  let description = [{
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    I64Tensor:$batch_size,
    I1Tensor:$drop_remainder,

    DefaultValuedAttr<BoolAttr, "false">:$parallel_copy,
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
  );

  let results = (outs
    TF_VariantTensor:$handle
  );
}

def TF_MapDatasetOp : TF_Op<"MapDataset", [NoSideEffect]> {
  let summary = [{
Creates a dataset that applies `f` to the outputs of `input_dataset`.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    Variadic<TF_Tensor>:$other_arguments,

    SymbolRefAttr:$f,
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism,
    DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
  );

  let results = (outs
    TF_VariantTensor:$handle
  );

  TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}

def TF_MapAndBatchDatasetOp : TF_Op<"MapAndBatchDataset", [NoSideEffect]> {
  let summary = "Creates a dataset that fuses mapping with batching.";

  let description = [{
Creates a dataset that applies `f` to the outputs of `input_dataset` and then
batches `batch_size` of them.

Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
to `batch_size * num_parallel_batches` copies of `f` in parallel.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    Variadic<TF_Tensor>:$other_arguments,
    I64Tensor:$batch_size,
    I64Tensor:$num_parallel_calls,
    I1Tensor:$drop_remainder,

    SymbolRefAttr:$f,
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
  );

  let results = (outs
    TF_VariantTensor:$handle
  );

  TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}

def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> {
  let summary = [{
Creates a dataset that applies `f` to the outputs of `input_dataset`.
  }];

  let description = [{
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes
up to `num_parallel_calls` copies of `f` in parallel.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_dataset,
    Variadic<TF_Tensor>:$other_arguments,
    I32Tensor:$num_parallel_calls,

    SymbolRefAttr:$f,
    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
    DefaultValuedAttr<BoolAttr, "true">:$use_inter_op_parallelism,
    DefaultValuedAttr<BoolAttr, "false">:$sloppy,
    DefaultValuedAttr<BoolAttr, "false">:$preserve_cardinality
  );

  let results = (outs
    TF_VariantTensor:$handle
  );

  TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>;
}

def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
  let summary = [{
Creates a dataset that emits each dim-0 slice of `components` once.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$components,
    Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
  );

  let results = (outs
    TF_VariantTensor:$handle
  );

  TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
}

#endif // TF_OPS
@@ -0,0 +1,256 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=BatchDatasetV2 -o - | FileCheck %s --dump-input-on-failure

# CHECK-LABEL: func @main() -> tensor<*x!tf.variant>
# CHECK: %[[tensor_slice:.*]], %[[tensor_slice_control:.*]] = tf_executor.island wraps "tf.TensorSliceDataset"
# CHECK: %[[map_dataset:.*]], %[[map_dataset_control:.*]] = tf_executor.island wraps "tf.MapDataset"(%[[tensor_slice]]
# CHECK: %[[batch_dataset:.*]], %[[batch_dataset_control:.*]] = tf_executor.island wraps "tf.BatchDatasetV2"(%[[map_dataset]]

node {
  name: "tensors/normalize_tensors/component_0"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 3
          }
        }
        tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000"
      }
    }
  }
}
node {
  name: "TensorSliceDataset"
  op: "TensorSliceDataset"
  input: "tensors/normalize_tensors/component_0"
  attr {
    key: "Toutput_types"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
      }
    }
  }
}
node {
  name: "MapDataset"
  op: "MapDataset"
  input: "TensorSliceDataset"
  attr {
    key: "Targuments"
    value {
      list {
      }
    }
  }
  attr {
    key: "f"
    value {
      func {
        name: "__inference_Dataset_map_<lambda>_8"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
      }
    }
  }
  attr {
    key: "output_types"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "preserve_cardinality"
    value {
      b: false
    }
  }
  attr {
    key: "use_inter_op_parallelism"
    value {
      b: true
    }
  }
}
node {
  name: "batch_size"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT64
        tensor_shape {
        }
        int64_val: 5
      }
    }
  }
}
node {
  name: "drop_remainder"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_BOOL
        tensor_shape {
        }
        bool_val: false
      }
    }
  }
}
node {
  name: "BatchDatasetV2"
  op: "BatchDatasetV2"
  input: "MapDataset"
  input: "batch_size"
  input: "drop_remainder"
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
          dim {
            size: -1
          }
        }
      }
    }
  }
  attr {
    key: "output_types"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "parallel_copy"
    value {
      b: false
    }
  }
}
library {
  function {
    signature {
      name: "__inference_Dataset_map_<lambda>_8"
      input_arg {
        name: "args_0"
        type: DT_INT32
      }
      output_arg {
        name: "identity"
        type: DT_INT32
      }
    }
    node_def {
      name: "mul/y"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 2
          }
        }
      }
    }
    node_def {
      name: "mul"
      op: "Mul"
      input: "args_0"
      input: "mul/y:output:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node_def {
      name: "Identity"
      op: "Identity"
      input: "mul:z:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "identity"
      value: "Identity:output:0"
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_user_specified_name"
          value {
            s: "args_0"
          }
        }
      }
    }
  }
}
versions {
  producer: 134
  min_consumer: 12
}
@@ -0,0 +1,29 @@
// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t

module {
  // CHECK-LABEL: fuse_map_and_batch
  func @fuse_map_and_batch() -> tensor<!tf.variant> attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} {
    %0 = "tf.Const"() {value = dense<5> : tensor<i64>} : () -> tensor<i64>
    %1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
    %2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
    // CHECK: %[[NPC:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>}
    // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset"
    %3 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant>
    // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], %[[BSIZE:.*]], %[[NPC]]
    // CHECK-SAME: f = @"__inference_Dataset_map_<lambda>_80",
    %4 = "tf.MapDataset"(%3) {device = "",
             f = @"__inference_Dataset_map_<lambda>_80",
             output_shapes = [#tf.shape<>], output_types = [i32],
             preserve_cardinality = false, sloppy = false,
             use_inter_op_parallelism = true} : (tensor<*x!tf.variant>) -> tensor<!tf.variant>
    %5 = "tf.BatchDatasetV2"(%4, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor<!tf.variant>, tensor<i64>, tensor<i1>) -> tensor<!tf.variant>
    return %5 : tensor<!tf.variant>
  }

  func @"__inference_Dataset_map_<lambda>_80"(%arg0: tensor<*xi32>) -> tensor<*xi32> {
    %0 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
    %1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
    %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
    return %2 : tensor<*xi32>
  }
}
@@ -0,0 +1,29 @@
// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t

module {
  // CHECK-LABEL: fuse_pmap_and_batch
  func @fuse_pmap_and_batch() -> tensor<!tf.variant> attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} {
    %0 = "tf.Const"() {value = dense<5> : tensor<i64>} : () -> tensor<i64>
    %1 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
    %2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
    %3 = "tf.Const"() {value = dense<12> : tensor<i32>} : () -> tensor<i32>
    // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset"
    %4 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant>
    // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]],
    // CHECK-SAME: f = @"__inference_Dataset_map_<lambda>_80",
    %5 = "tf.ParallelMapDataset"(%4, %3) {device = "",
             f = @"__inference_Dataset_map_<lambda>_80",
             output_shapes = [#tf.shape<>], output_types = [i32],
             preserve_cardinality = false, sloppy = false,
             use_inter_op_parallelism = true} : (tensor<*x!tf.variant>, tensor<i32>) -> tensor<!tf.variant>
    %6 = "tf.BatchDatasetV2"(%5, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor<!tf.variant>, tensor<i64>, tensor<i1>) -> tensor<!tf.variant>
    return %6 : tensor<!tf.variant>
  }

  func @"__inference_Dataset_map_<lambda>_80"(%arg0: tensor<*xi32>) -> tensor<*xi32> {
    %0 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
    %1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
    %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
    return %2 : tensor<*xi32>
  }
}
@@ -0,0 +1,65 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h"

#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
namespace TF {

namespace {

// Fuses a ParallelMapDataset that feeds a BatchDatasetV2 into a single
// MapAndBatchDataset.
struct FuseParallelMapAndBatch : public OpRewritePattern<BatchDatasetV2Op> {
  using OpRewritePattern<BatchDatasetV2Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(BatchDatasetV2Op op,
                                PatternRewriter &rewriter) const override {
    auto batchInputDataset = op.input_dataset();

    ParallelMapDatasetOp batchInputOp = dyn_cast_or_null<ParallelMapDatasetOp>(
        batchInputDataset.getDefiningOp());
    if (!batchInputOp) return failure();

    // The type of the `num_parallel_calls` argument in ParallelMapDataset and
    // MapAndBatchDataset is different (int32 and int64, respectively), so
    // cast it to int64 before building the fused op.
    auto num_parallel_calls_op = rewriter.create<CastOp>(
        op.getLoc(), UnrankedTensorType::get(rewriter.getIntegerType(64)),
        batchInputOp.num_parallel_calls(), rewriter.getBoolAttr(false));

    auto fused_op = rewriter.create<MapAndBatchDatasetOp>(
        op.getLoc(), op.getType(), batchInputOp.input_dataset(),
        batchInputOp.other_arguments(), op.batch_size(),
        num_parallel_calls_op.y(), op.drop_remainder(), batchInputOp.f(),
        op.output_types(), op.output_shapes(),
        batchInputOp.preserve_cardinality());
    rewriter.replaceOp(op, {fused_op.handle()});
    // The IR was rewritten, so signal success to the pattern driver.
    return success();
  }
};

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_tf_data_optimization.inc"
}  // namespace

void PopulateTFDataOptimizationPatterns(MLIRContext *context,
                                        OwningRewritePatternList *patterns) {
  patterns->insert<FuseParallelMapAndBatch>(context);
  populateWithGenerated(context, patterns);
}

}  // namespace TF
}  // namespace mlir
@@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_

#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project

namespace mlir {
namespace TF {

// Populates patterns to perform optimizations specific to tf.data operations.
void PopulateTFDataOptimizationPatterns(MLIRContext *context,
                                        OwningRewritePatternList *patterns);

}  // namespace TF
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_
@@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

// TODO(jpienaar): Move this somewhere general.
class GetI64ScalarElementsAttr<int value> :
  NativeCodeCall<"DenseElementsAttr::get<int64_t>(RankedTensorType::get({}, $_builder.getIntegerType(64)), " # value # ")">;

def FuseMapAndBatch : Pat<
  (TF_BatchDatasetV2Op
     (TF_MapDatasetOp $input_dataset, $other_arguments, $f, $output_types,
        $output_shapes, $use_inter_op_parallelism, $preserve_cardinality),
     $batch_size, $drop_remainder, $parallel_copy, $batch_output_types,
     $batch_output_shapes),
  (TF_MapAndBatchDatasetOp $input_dataset, $other_arguments, $batch_size,
     (TF_ConstOp (GetI64ScalarElementsAttr<1>)), $drop_remainder, $f,
     $batch_output_types, $batch_output_shapes, $preserve_cardinality)>;
@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h"

namespace mlir {
namespace TF {
namespace {

// Perform tf.data optimizations.
struct TFDataOptimization
    : public PassWrapper<TFDataOptimization, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    mlir::TF::PopulateTFDataOptimizationPatterns(&getContext(), &patterns);

    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

}  // namespace
}  // namespace TF
}  // namespace mlir

static mlir::PassRegistration<mlir::TF::TFDataOptimization> pass(
    "tf-data-optimization", "Performs tf.data optimizations");