diff --git a/CODEOWNERS b/CODEOWNERS
index 271e3b5b2ff..3ef02ffd68c 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -13,55 +13,4 @@
/tensorflow/tensorboard/ @jart
/tensorflow/tools/docs/ @markdaoust
-# contrib
-
-# NEED OWNER: /tensorflow/contrib/all_reduce
-/tensorflow/contrib/autograph/ @mdanatg @kkimdev
-/tensorflow/contrib/batching/ @alextp @chrisolston
-/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon
-/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva
-/tensorflow/contrib/checkpoint/ @allenlavoie
-/tensorflow/contrib/contrib/cluster_resolver/ @frankchn
-/tensorflow/contrib/cmake/ @mrry
-/tensorflow/contrib/copy_graph/ @tucker @poxvoculi
-/tensorflow/contrib/crf/ @kentonl
-/tensorflow/contrib/data/ @mrry
-/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
-/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
-/tensorflow/contrib/eager @jaingaurav @alextp
-/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
-/tensorflow/contrib/ffmpeg/ @fredbertsch
-/tensorflow/contrib/framework/ @ebrevdo
-/tensorflow/contrib/graph_editor/ @purpledog
-# NEED OWNER: /tensorflow/contrib/grid_rnn/
-/tensorflow/contrib/hadoop @yongtang
-/tensorflow/contrib/hvx/ @satok16
-/tensorflow/contrib/integrate/ @shoyer
-/tensorflow/contrib/kernel_methods/ @petrosmol
-/tensorflow/contrib/ios_examples/ @petewarden
-/tensorflow/contrib/labeled_tensor/ @shoyer
-/tensorflow/contrib/layers/ @fchollet @martinwicke
-/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
-/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
-/tensorflow/contrib/lookup/ @ysuematsu @andreasst
-/tensorflow/contrib/losses/ @alextp @ispirmustafa
-/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg
-/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa
-/tensorflow/contrib/opt/ @strategist333 @alextp
-/tensorflow/contrib/pi_examples/ @maciekcc
-/tensorflow/contrib/quantization/ @petewarden
-/tensorflow/contrib/rnn/ @ebrevdo @scottzhu
-/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie
-/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang
-/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh
-/tensorflow/contrib/slim/ @sguada @thenbasilmanran
-/tensorflow/contrib/stateless/ @girving @alextp
-/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
-/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2
-# NEED OWNER: /tensorflow/contrib/testing/
-/tensorflow/contrib/timeseries/ @allenlavoie
-/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
-/tensorflow/contrib/training/ @joel-shor @ebrevdo
-/tensorflow/contrib/util/ @sherrym
-
/third_party/systemlibs/ @perfinion
diff --git a/README.md b/README.md
index 51ca43e1571..05b1e4de458 100644
--- a/README.md
+++ b/README.md
@@ -110,19 +110,19 @@ Build Type | Status
### Community Supported Builds
-Build Type | Status | Artifacts
-------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
-**Linux AMD ROCm GPU** Nightly | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
-**Linux AMD ROCm GPU** Stable Release | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
-**Linux s390x** Nightly | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
-**Linux s390x CPU** Stable Release | [](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
-**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
-**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
-**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
-**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
-**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
-**Linux CPU with Intel® MKL-DNN**
**Supports Python 2.7, 3.4, 3.5, 3.6 and 3.7** | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.14.0 PyPI](https://pypi.org/project/intel-tensorflow/)
-**Red Hat® Enterprise Linux® 7.6 CPU & GPU**
Python 2.7, 3.6 | [](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
+Build Type | Status | Artifacts
+----------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
+**Linux AMD ROCm GPU** Nightly | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
+**Linux AMD ROCm GPU** Stable Release | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
+**Linux s390x** Nightly | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
+**Linux s390x CPU** Stable Release | [](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
+**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
+**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
+**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
+**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
+**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
+**Linux CPU with Intel® MKL-DNN** Stable Release |  | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
+**Red Hat® Enterprise Linux® 7.6 CPU & GPU**
Python 2.7, 3.6 | [](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
## Resources
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 0f299ec13f8..603c2a5c45c 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -195,6 +195,12 @@ config_setting(
visibility = ["//visibility:public"],
)
+config_setting(
+ name = "chromiumos",
+ values = {"crosstool_top": "//external:android/chromiumos"},
+ visibility = ["//visibility:public"],
+)
+
config_setting(
name = "linux_aarch64",
values = {"cpu": "aarch64"},
@@ -453,6 +459,7 @@ package_group(
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
+ "//third_party/swift/tensorflow/x10/...",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 46ade1b2e77..8793e308466 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -233,7 +233,7 @@ tensorflow::Status GetReplacedFromExistingWorkers(
std::vector responses(
existing_workers->size());
for (int i = 0; i < existing_workers->size(); i++) {
- tensorflow::eager::EagerClient* eager_client;
+ tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] =
client_cache->GetClient(existing_workers->at(i), &eager_client);
if (!statuses[i].ok()) {
@@ -282,7 +282,7 @@ tensorflow::Status CreateRemoteContexts(
continue;
}
- tensorflow::eager::EagerClient* eager_client;
+ tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
@@ -340,7 +340,7 @@ tensorflow::Status UpdateRemoteContexts(
continue;
}
- tensorflow::eager::EagerClient* eager_client;
+ tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
@@ -819,7 +819,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
}
// TODO(yuefengz): support partially specified `worker_name`.
- tensorflow::eager::EagerClient* eager_client;
+ tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
if (!status->status.ok()) {
return false;
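Context for the hunks above: the raw `EagerClient*` out-parameters are replaced with `tensorflow::core::RefCountPtr`, a smart pointer that drops a reference when it goes out of scope instead of leaking the count. The following standalone sketch uses hypothetical names (`RefCounted`, `UnrefDeleter`, `FakeEagerClient`) and is only an illustration of that ownership idea, not the TensorFlow implementation:

```
#include <cassert>
#include <memory>

// Minimal ref-counted base class (hypothetical, for illustration only).
class RefCounted {
 public:
  void Ref() { ++count_; }
  void Unref() {
    if (--count_ == 0) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  int count_ = 1;  // Starts with one reference owned by the creator.
};

// Deleter that drops a reference instead of calling delete directly.
struct UnrefDeleter {
  void operator()(RefCounted* p) const {
    if (p != nullptr) p->Unref();
  }
};

template <typename T>
using RefCountPtr = std::unique_ptr<T, UnrefDeleter>;

class FakeEagerClient : public RefCounted {};

int main() {
  // The client cache hands out an owned reference; when `client` goes out of
  // scope the reference is released automatically, unlike a raw pointer.
  RefCountPtr<FakeEagerClient> client(new FakeEagerClient);
  assert(client != nullptr);
  return 0;
}
```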
diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt
index 7036ef71b58..0fcee7d7e8f 100644
--- a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt
@@ -38,6 +38,6 @@ versions {
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
-# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
+# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %0 : tensor<*xi32>
# CHECK-NEXT: }
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 27eff39c397..ec618ffa276 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1280,3 +1280,13 @@ func @conv2d_backprop_unsupported_data_format(%arg0: tensor<4xi32>, %arg1: tenso
// CHECK-LABEL: conv2d_backprop_unsupported_data_format
// CHECK: tf.Conv2DBackpropInput
}
+
+func @assert_remove(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi1> {
+ %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ "tf.Assert"(%0, %arg1) {summarize = 3} : (tensor<1xi1>, tensor<1xi32>) -> ()
+ return %0 : tensor<1xi1>
+ // CHECK-LABEL: assert_remove
+ // CHECK: tfl.less_equal
+ // CHECK-NOT: Assert
+ // CHECK: return
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index f7913f11f72..1d51adb16f2 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -622,3 +622,12 @@ func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
}
+
+// CHECK-LABEL: fuse_relu_to_add
+func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
+ %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %1 = "tfl.relu_n1_to_1"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+ return %1 : tensor<2x3xf32>
+ // CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
+ // CHECK: return %[[RES]]
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index bc6ff5e3b47..0512bc98cab 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -68,6 +68,7 @@ struct LegalizeTF : public FunctionPass {
// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
// operands are properly supported in declarative rewrite rule specification.
+DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Concat);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
@@ -86,7 +87,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast(op);
- SmallVector values(tf_concat_op.values());
+ auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output()->getType();
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
@@ -105,7 +106,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast(op);
- SmallVector values(tf_concat_op.values());
+ auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output()->getType();
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
@@ -374,6 +375,14 @@ PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
return matchFailure();
}
+// TF Lite doesn't support Assert, we just drop the assert from the graph.
+PatternMatchResult ConvertTFAssertOp::matchAndRewrite(
+ Operation* op, PatternRewriter& rewriter) const {
+ op->dropAllReferences();
+ op->erase();
+ return matchSuccess();
+}
+
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
auto* ctx = &getContext();
@@ -385,7 +394,8 @@ void LegalizeTF::runOnFunction() {
.insert(ctx);
+ ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp>(
+ ctx);
applyPatternsGreedily(func, patterns);
}
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index bf0e7169584..3f50c3ad1c1 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -484,9 +484,9 @@ struct ConvertTensorListResize : public ConversionPattern {
&rewriter);
// Inserts the two blocks' names into the symbol table held by the module.
- // Using ModuleManager will ensure that the inserted symbol names are
+ // Using SymbolTable will ensure that the inserted symbol names are
// unique.
- ModuleManager manager(resize_op.getParentOfType());
+ SymbolTable manager(resize_op.getParentOfType());
manager.insert(then_branch_op);
manager.insert(else_branch_op);
@@ -754,8 +754,7 @@ struct ConvertWhile : public ConversionPattern {
cloned.removeAttr("T");
UpdateFunctionTypes(cloned);
- SmallVector results(cloned.getResults());
- rewriter.replaceOp(op, results);
+ rewriter.replaceOp(op, cloned.getResults());
return matchSuccess();
}
};
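The switch from `ModuleManager` to `SymbolTable` preserves the behavior the comment describes: inserted function symbols are renamed as needed so names stay unique within the module. A tiny standalone sketch of that uniquing behavior, with a hypothetical helper rather than MLIR's actual API:

```
#include <iostream>
#include <set>
#include <string>

// If `name` is already taken, append an increasing counter until the name is
// unique, then record and return it.
std::string InsertUnique(std::set<std::string>& table, std::string name) {
  if (table.insert(name).second) return name;
  for (int i = 0;; ++i) {
    std::string candidate = name + "_" + std::to_string(i);
    if (table.insert(candidate).second) return candidate;
  }
}

int main() {
  std::set<std::string> table;
  std::cout << InsertUnique(table, "cond_true") << "\n";   // cond_true
  std::cout << InsertUnique(table, "cond_true") << "\n";   // cond_true_0
  std::cout << InsertUnique(table, "cond_false") << "\n";  // cond_false
  return 0;
}
```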
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
index c8b54d26653..173785ba5b0 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
@@ -135,15 +135,15 @@ class FoldIfOp : public OpRewritePattern {
static void EraseDeadFuncs(const FuncSet& candiate_funcs, ModuleOp module) {
if (candiate_funcs.empty()) return;
- ModuleManager manager(module);
+ SymbolTable manager(module);
// Identify the functions that are used as symbols in the module and shouldn't
// be erased.
FuncSet in_use_funcs;
- manager.getModule().walk([&](Operation* op) {
+ manager.getOp()->walk([&](Operation* op) {
for (auto attr : op->getAttrs()) {
if (auto symbol = attr.second.dyn_cast()) {
- auto func = manager.lookupSymbol(symbol.getValue());
+ auto func = manager.lookup(symbol.getValue());
in_use_funcs.insert(func);
}
}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 905f01d8413..a91f6de1971 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -44,12 +44,13 @@ multiclass FuseActFnIntoConvOpPat {
$multiplier)>;
}
-// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
+// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// those cannot be simply translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
- [TFL_Relu6Op, TFL_AF_Relu6]] in
+ [TFL_Relu6Op, TFL_AF_Relu6],
+ [TFL_Relu1Op, TFL_AF_Relu1]] in
defm : FuseActFnIntoConvOpPat;
@@ -291,3 +292,18 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
(ConstantOp $NegOne)),
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
+
+// Multi-pattern consisting of matching stand-alone op or op followed by relu.
+multiclass FusedBinaryActivationFuncOpPat {
+ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
+ [TFL_Relu6Op, TFL_AF_Relu6],
+ [TFL_Relu1Op, TFL_AF_Relu1]] in {
+ def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
+ (BinaryOp $lhs, $rhs, actFnPair[1])>;
+ }
+}
+
+// Instantiated FusedBinary patterns for the from-to pairs of ops.
+foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
+ TFL_MulOp, TFL_SubOp] in
+ defm : FusedBinaryActivationFuncOpPat;
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 5484988d0f5..5f93210f06e 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -192,6 +192,37 @@ cc_library(
alwayslink = 1,
)
+gentbl(
+ name = "decompose_resource_ops_inc_gen",
+ tbl_outs = [
+ (
+ "-gen-rewriters",
+ "transforms/generated_decompose_resource_ops.inc",
+ ),
+ ],
+ tblgen = "@local_config_mlir//:mlir-tblgen",
+ td_file = "transforms/decompose_resource_ops.td",
+ td_srcs = [
+ ":tensorflow_ops_td_files",
+ "@local_config_mlir//:StdOpsTdFiles",
+ ],
+)
+
+cc_library(
+ name = "decompose_resource_ops",
+ srcs = [
+ "transforms/decompose_resource_ops.cc",
+ ],
+ hdrs = [
+ "transforms/decompose_resource_ops.h",
+ ],
+ deps = [
+ ":decompose_resource_ops_inc_gen",
+ ":tensorflow",
+ "@local_config_mlir//:IR",
+ ],
+)
+
cc_library(
name = "tensorflow_passes",
srcs = [
@@ -199,6 +230,7 @@ cc_library(
"transforms/bridge_pass.cc",
"transforms/cluster_formation.cc",
"transforms/cluster_outlining.cc",
+ "transforms/decompose_resource_ops_pass.cc",
"transforms/delete_unused_funcs.cc",
"transforms/executor_island_coarsening.cc",
"transforms/fold_switch.cc",
@@ -213,6 +245,7 @@ cc_library(
"transforms/raise_control_flow.cc",
"transforms/replicate_invariant_op_hoisting.cc",
"transforms/replicate_to_island.cc",
+ "transforms/resource_device_inference.cc",
"transforms/resource_op_lifting.cc",
"transforms/shape_inference.cc",
"transforms/shape_inference_pass.cc",
@@ -236,6 +269,8 @@ cc_library(
":bridge_logger",
":convert_tensor",
":convert_type",
+ ":decompose_resource_ops",
+ ":decompose_resource_ops_inc_gen",
":device_util",
":error_util",
":export_tf_dialect_op",
@@ -368,12 +403,14 @@ cc_library(
":convert_tensor",
":convert_type",
":mangling_util",
+ ":tensorflow",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/platform:protobuf",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
@@ -564,7 +601,6 @@ cc_library(
hdrs = ["utils/error_util.h"],
deps = [
"//tensorflow/core:lib",
- "//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
],
@@ -808,7 +844,6 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core/platform:logging",
"//tensorflow/stream_executor/lib",
- "@com_google_absl//absl/types:span",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
index 8d43c9330d0..898393479b0 100644
--- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
+++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
@@ -99,12 +100,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
auto forward_input_to_output = [&](Value* operand, Value* result) {
if (!mlir::getElementTypeOrSelf(result->getType()).isa())
return;
+ auto& result_ids = resource_value_to_ids_[result];
auto operand_it = resource_value_to_ids_.find(operand);
assert(operand_it != resource_value_to_ids_.end() &&
"A resource-type output does not have the corresponding "
"resource-type input.");
- resource_value_to_ids_[result].insert(operand_it->getSecond().begin(),
- operand_it->getSecond().end());
+ result_ids.insert(operand_it->getSecond().begin(),
+ operand_it->getSecond().end());
};
// TODO(yuanzx): Consider control-flow ops.
func_op.walk([&](Operation* op) {
@@ -119,6 +121,16 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
forward_input_to_output(std::get<0>(operand_and_result),
std::get<1>(operand_and_result));
}
+ } else if (auto replicate = llvm::dyn_cast<tf_device::ReplicateOp>(op)) {
+ // The nested block for ReplicateOp is handled separately in side-effect
+ // analysis. Inside that block, we can still treat its block arguments as
+ // different resources.
+ for (auto arg : replicate.GetBody().getArguments()) {
+ if (mlir::getElementTypeOrSelf(arg->getType())
+ .isa()) {
+ resource_value_to_ids_[arg].insert(next_unique_id++);
+ }
+ }
} else {
for (auto result : op->getResults()) {
if (!mlir::getElementTypeOrSelf(result->getType())
@@ -261,9 +273,36 @@ void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
void SideEffectAnalysis::AnalyzeFunction(
FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) {
- // This function populates control_predecessors_ and control_successors_ by
- // walking through func_op's body, and tracking resource accesses in
- // per_resource_access_info_.
+ // AnalyzeRegion() recursively analyzes the function body, and only populates
+ // control_predecessors_.
+ AnalyzeRegion(&func_op.getBody(), alias_analysis);
+ // Populate sorted_control_predecessors_ and sorted_control_successors_ based
+ // on control_predecessors.
+ for (auto& entry : control_predecessors_) {
+ auto op = entry.getFirst();
+ auto& sorted_predecessors = sorted_control_predecessors_[op];
+ for (auto predecessor : entry.getSecond()) {
+ sorted_predecessors.push_back(predecessor);
+ sorted_control_successors_[predecessor].push_back(op);
+ }
+ }
+ control_predecessors_.clear();
+ for (auto& entry : sorted_control_predecessors_) {
+ llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
+ return a->isBeforeInBlock(b);
+ });
+ }
+ for (auto& entry : sorted_control_successors_) {
+ llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
+ return a->isBeforeInBlock(b);
+ });
+ }
+}
+
+void SideEffectAnalysis::AnalyzeRegion(
+ Region* region, const ResourceAliasAnalysis& alias_analysis) {
+ // This function populates control_predecessors_ by walking through the
+ // region, and tracking resource accesses in per_resource_access_info_.
// Returns whether an access to `resource` can skip control edges from
// prevoius accesses to unknown resources, due to that earlier accesses to
@@ -284,82 +323,93 @@ void SideEffectAnalysis::AnalyzeFunction(
(it->second.tracked_last_unknown_read || no_unknown_read);
};
- func_op.walk([&](Operation* op) {
- // We do not need explicit control edges for declaration ops.
- if (OpIsDeclaration(op, alias_analysis)) return;
-
- auto resource_op_info = GetResourceInfoForOp(op);
- if (!resource_op_info && op->hasNoSideEffect()) return;
-
- llvm::SmallDenseSet resources =
- resource_op_info ? FindAccessedResources(op, alias_analysis)
- : UnknownResourceSet();
- assert(!resources.empty());
- const bool is_unknown = resources.count(kUnknownResourceId) > 0;
- const bool read_only = OpIsReadOnly(op);
- bool indirectly_tracked_unknown_access = false;
- // First add edges from known resources.
- if (is_unknown) {
- for (auto& entry : per_resource_access_info_) {
- if (entry.getFirst() == kUnknownResourceId) continue;
- AddPredecessorsForAccess(entry.getFirst(), op, read_only);
- indirectly_tracked_unknown_access |=
- unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
- read_only);
+ // We explicitly iterate through the regions and blocks in order to handle
+ // different nested regions separately.
+ for (auto& block : *region) {
+ for (auto& op : block) {
+ if (op.getNumRegions() > 0) {
+ llvm::SmallVector child_analyses;
+ for (auto& child_region : op.getRegions()) {
+ child_analyses.emplace_back();
+ child_analyses.back().AnalyzeRegion(&child_region, alias_analysis);
+ }
+ ConsumeChildAnalyses(std::move(child_analyses));
}
- } else {
- for (int64_t resource : resources) {
- AddPredecessorsForAccess(resource, op, read_only);
- indirectly_tracked_unknown_access |=
- unknown_access_indirectly_tracked_by_resource(resource, read_only);
- // Update access info for known resources.
- TrackAccess(resource, op, read_only);
- }
- }
- // If not indirectly tracked, add edges from the unknown resource.
- if (!indirectly_tracked_unknown_access) {
- AddPredecessorsForAccess(kUnknownResourceId, op, read_only);
- }
- if (is_unknown) {
- // Update access info for unknown resource.
- TrackAccess(kUnknownResourceId, op, read_only);
- }
- });
- // Populate control_successors_ based on control_predecessors_.
- for (auto& entry : control_predecessors_) {
- auto op = entry.getFirst();
- for (auto predecessor : entry.getSecond()) {
- control_successors_[predecessor].insert(op);
+ // We do not need explicit control edges for declaration ops.
+ if (OpIsDeclaration(&op, alias_analysis)) continue;
+
+ auto resource_op_info = GetResourceInfoForOp(&op);
+ if (!resource_op_info && op.hasNoSideEffect()) continue;
+
+ llvm::SmallDenseSet resources =
+ resource_op_info ? FindAccessedResources(&op, alias_analysis)
+ : UnknownResourceSet();
+ assert(!resources.empty());
+ const bool is_unknown = resources.count(kUnknownResourceId) > 0;
+ const bool read_only = OpIsReadOnly(&op);
+ bool indirectly_tracked_unknown_access = false;
+ // First add edges from known resources.
+ if (is_unknown) {
+ for (auto& entry : per_resource_access_info_) {
+ if (entry.getFirst() == kUnknownResourceId) continue;
+ AddPredecessorsForAccess(entry.getFirst(), &op, read_only);
+ indirectly_tracked_unknown_access |=
+ unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
+ read_only);
+ }
+ } else {
+ for (int64_t resource : resources) {
+ AddPredecessorsForAccess(resource, &op, read_only);
+ indirectly_tracked_unknown_access |=
+ unknown_access_indirectly_tracked_by_resource(resource,
+ read_only);
+ // Update access info for known resources.
+ TrackAccess(resource, &op, read_only);
+ }
+ }
+ // If not indirectly tracked, add edges from the unknown resource.
+ if (!indirectly_tracked_unknown_access) {
+ AddPredecessorsForAccess(kUnknownResourceId, &op, read_only);
+ }
+ if (is_unknown) {
+ // Update access info for unknown resource.
+ TrackAccess(kUnknownResourceId, &op, read_only);
+ }
}
}
}
-llvm::SmallVector SideEffectAnalysis::DirectControlPredecessors(
+void SideEffectAnalysis::ConsumeChildAnalyses(
+ llvm::SmallVector&& children) {
+ for (auto& child : children) {
+ for (auto& entry : child.control_predecessors_) {
+ control_predecessors_[entry.getFirst()] = std::move(entry.getSecond());
+ }
+ }
+}
+
+llvm::SmallVector SideEffectAnalysis::DirectControlPredecessors(
Operation* op, llvm::function_ref filter) const {
- llvm::SmallVector result;
- auto it = control_predecessors_.find(op);
- if (it == control_predecessors_.end()) return result;
+ llvm::SmallVector result;
+ auto it = sorted_control_predecessors_.find(op);
+ if (it == sorted_control_predecessors_.end()) return result;
result.reserve(it->getSecond().size());
for (auto predecessor : it->getSecond()) {
if (!filter || filter(predecessor)) result.push_back(predecessor);
}
- llvm::sort(result,
- [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
return result;
}
-llvm::SmallVector SideEffectAnalysis::DirectControlSuccessors(
+llvm::SmallVector SideEffectAnalysis::DirectControlSuccessors(
Operation* op, llvm::function_ref filter) const {
- llvm::SmallVector result;
- auto it = control_successors_.find(op);
- if (it == control_successors_.end()) return result;
+ llvm::SmallVector result;
+ auto it = sorted_control_successors_.find(op);
+ if (it == sorted_control_successors_.end()) return result;
result.reserve(it->getSecond().size());
for (auto successor : it->getSecond()) {
if (!filter || filter(successor)) result.push_back(successor);
}
- llvm::sort(result,
- [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
return result;
}
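The restructured analysis now materializes sorted predecessor and successor lists once, instead of re-sorting inside every `DirectControlPredecessors`/`DirectControlSuccessors` query. A standalone sketch of that transpose-then-sort step, using integer program positions in place of `Operation*` (illustrative only, not the TensorFlow code):

```
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <vector>

int main() {
  // op -> set of control predecessors (unordered), keyed by program position.
  std::map<int, std::set<int>> control_predecessors = {
      {3, {1, 2}}, {4, {2}}, {5, {4, 1}}};

  // Transpose into successor lists while copying the predecessors out.
  std::map<int, std::vector<int>> sorted_predecessors, sorted_successors;
  for (const auto& entry : control_predecessors) {
    for (int pred : entry.second) {
      sorted_predecessors[entry.first].push_back(pred);
      sorted_successors[pred].push_back(entry.first);
    }
  }

  // Sort both directions in "program order" so queries can return them as-is.
  for (auto* m : {&sorted_predecessors, &sorted_successors})
    for (auto& entry : *m) std::sort(entry.second.begin(), entry.second.end());

  for (const auto& entry : sorted_successors) {
    std::cout << entry.first << " ->";
    for (int succ : entry.second) std::cout << " " << succ;
    std::cout << "\n";
  }
  return 0;
}
```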
diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h
index 5eee28a6ae0..3d65217db27 100644
--- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h
+++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h
@@ -32,6 +32,9 @@ namespace TF {
// An analysis that runs on a function and maps each resource-type value to a
// set of unique int64_t IDs representing the possible resources it could alias.
+//
+// If there are nested regions, each region is handled separately. This means
+// cross-region aliasing cannot be checked by this analysis.
class ResourceAliasAnalysis {
public:
explicit ResourceAliasAnalysis(Operation* op);
@@ -63,8 +66,12 @@ class ResourceAliasAnalysis {
// interfering with all known resource op accesses. It distinguishes accesses
// based on whether they are read-only, and read-only ops do not interfer with
// each other.
+//
+// If there are nested regions, each region is handled separately, and control
+// dependencies are only tracked for ops under the same parent op.
class SideEffectAnalysis {
public:
+ explicit SideEffectAnalysis() = default;
explicit SideEffectAnalysis(Operation* op);
SideEffectAnalysis(SideEffectAnalysis&& other) = default;
~SideEffectAnalysis() = default;
@@ -72,23 +79,32 @@ class SideEffectAnalysis {
// Returns a vector of ops that are direct control predecessors of `op`,
// sorted in program order. If `filter` is provided, only predecessors that
// pass the filter (returning true) will be included.
- llvm::SmallVector DirectControlPredecessors(
+ llvm::SmallVector DirectControlPredecessors(
Operation* op,
llvm::function_ref filter = nullptr) const;
// Returns a vector of ops that are direct control successors of `op`, sorted
// in program order. If `filter` is provided, only successors that pass the
// filter (returning true) will be included.
- llvm::SmallVector DirectControlSuccessors(
+ llvm::SmallVector DirectControlSuccessors(
Operation* op,
llvm::function_ref filter = nullptr) const;
private:
- // Runs the analysis on `func_op` and populates control_predecessors_ and
- // control_successors_.
+ // Runs the analysis on `func_op` and populates sorted_control_predecessors_
+ // and sorted_control_successors_.
void AnalyzeFunction(FuncOp func_op,
const ResourceAliasAnalysis& alias_analysis);
+ // Runs the analysis on `region` and populates control_predecessors_.
+ void AnalyzeRegion(Region* region,
+ const ResourceAliasAnalysis& alias_analysis);
+
+ // Moves the control_predecessors_ fields in `children` analyses to this
+ // current analysis.
+ void ConsumeChildAnalyses(
+ llvm::SmallVector&& children);
+
// Updates control_predecessors_ for `op` that is being visted, on the given
// `resource_id`.
void AddPredecessorsForAccess(int64_t resource_id, Operation* op,
@@ -98,11 +114,14 @@ class SideEffectAnalysis {
void TrackAccess(int64_t resource_id, Operation* op, bool read_only);
// Maps from an op to its control predecessors.
- llvm::SmallDenseMap, 8>
+ llvm::SmallDenseMap, 8>
control_predecessors_;
- // Maps from an op to its control successors.
- llvm::SmallDenseMap, 8>
- control_successors_;
+ // Maps from an op to its control predecessors sorted in program order.
+ llvm::SmallDenseMap, 8>
+ sorted_control_predecessors_;
+ // Maps from an op to its control successors sorted in program order.
+ llvm::SmallDenseMap, 8>
+ sorted_control_successors_;
// Internal per-resource data structure when we build the dependencies.
struct PerResourceAcessInfo {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
index 20483691a92..ffba86e78ff 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
@@ -332,8 +332,7 @@ struct DropEmptyLaunch : public OpRewritePattern {
if (&block.front() != &block.back()) return matchFailure();
// Map launch results to return operands.
- llvm::SmallVector new_rets(block.front().getOperands());
- rewriter.replaceOp(op, new_rets);
+ rewriter.replaceOp(op, block.front().getOperands());
return matchSuccess();
}
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index d2174255a05..5a018a39fd7 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -408,8 +408,7 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
if (!wrapped_op) return failure();
OpBuilder builder(parser.getBuilder().getContext());
builder.setInsertionPointToEnd(&block);
- builder.create(wrapped_op->getLoc(),
- llvm::to_vector<8>(wrapped_op->getResults()));
+ builder.create(wrapped_op->getLoc(), wrapped_op->getResults());
result.location = wrapped_op->getLoc();
} else if (parser.parseRegion(body, llvm::None, llvm::None)) {
return failure();
@@ -1065,8 +1064,7 @@ struct DropEmptyGraph : public OpRewritePattern {
if (&block.front() != &block.back()) return matchFailure();
// Map graph results to fetch operands.
- llvm::SmallVector new_rets(op.GetFetch().fetches());
- rewriter.replaceOp(op, new_rets);
+ rewriter.replaceOp(op, op.GetFetch().fetches());
return matchSuccess();
}
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index cdc545d5681..5b5c028c89d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -94,6 +94,8 @@ Inputs must be of same size and shape.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+
+ let hasFolder = 1;
}
def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>,
@@ -143,6 +145,8 @@ retained with length 1.
);
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+ let verifier = [{ return Verify(*this); }];
}
def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> {
@@ -169,6 +173,8 @@ retained with length 1.
);
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+ let verifier = [{ return Verify(*this); }];
}
def TF_ArgMaxOp : TF_Op<"ArgMax", [NoSideEffect]> {
@@ -2116,6 +2122,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
+def TF_HashTableV2Op : TF_Op<"HashTableV2", []> {
+ let summary = "Creates a non-initialized hash table.";
+
+ let description = [{
+This op creates a hash table, specifying the type of its keys and values.
+Before using the table you will have to initialize it. After initialization the
+table will be immutable.
+ }];
+
+ let arguments = (ins
+ StrAttr:$container,
+ StrAttr:$shared_name,
+ DefaultValuedAttr:$use_node_name_sharing,
+ TypeAttr:$key_dtype,
+ TypeAttr:$value_dtype
+ );
+
+ let results = (outs
+ TF_ResourceTensor:$table_handle
+ );
+}
+
def TF_IdentityNOp : TF_Op<"IdentityN", [NoSideEffect]> {
let summary = [{
Returns a list of tensors with the same shapes and contents as the input
@@ -2473,7 +2501,7 @@ def TF_LogicalAndOp : TF_Op<"LogicalAnd", [Broadcastable, Commutative, NoSideEff
}
def TF_LogicalNotOp : TF_Op<"LogicalNot", [NoSideEffect, SameOperandsAndResultType]> {
- let summary = "Returns the truth value of NOT x element-wise.";
+ let summary = "Returns the truth value of `NOT x` element-wise.";
let description = [{
}];
@@ -4334,6 +4362,37 @@ Resize `images` to `size` using nearest neighbor interpolation.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
+def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
+ let summary = "Update '*var' according to the Adam algorithm.";
+
+ let description = [{
+$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
+$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
+ }];
+
+ let arguments = (ins
+ TF_ResourceTensor:$var,
+ TF_ResourceTensor:$m,
+ TF_ResourceTensor:$v,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2_power,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
+
+ DefaultValuedAttr:$use_locking,
+ DefaultValuedAttr:$use_nesterov
+ );
+
+ let results = (outs);
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
+}
+
def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> {
let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it.";
@@ -4353,6 +4412,34 @@ def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
+def TF_ResourceApplyKerasMomentumOp : TF_Op<"ResourceApplyKerasMomentum", []> {
+ let summary = [{
+Update '*var' according to the momentum scheme.
+ }];
+
+ let description = [{
+Set use_nesterov = True if you want to use Nesterov momentum.
+
+accum = accum * momentum - lr * grad
+var += accum
+ }];
+
+ let arguments = (ins
+ TF_ResourceTensor:$var,
+ TF_ResourceTensor:$accum,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum,
+
+ DefaultValuedAttr:$use_locking,
+ DefaultValuedAttr:$use_nesterov
+ );
+
+ let results = (outs);
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
+}
+
def TF_ReverseSequenceOp : TF_Op<"ReverseSequence", [NoSideEffect]> {
let summary = "Reverses variable length slices.";
@@ -5117,6 +5204,8 @@ def TF_SplitVOp : TF_Op<"SplitV", [NoSideEffect]> {
TF_DerivedOperandTypeAttr Tlen = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultSizeAttr num_split = TF_DerivedResultSizeAttr<0>;
+
+ let verifier = [{ return Verify(*this); }];
}
def TF_SqrtOp : TF_Op<"Sqrt", [NoSideEffect, SameOperandsAndResultType]> {
@@ -5491,6 +5580,65 @@ output. For the internal use of the distributed TPU compiler.
TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>;
}
+def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> {
+ let summary = "Connects N inputs to an N-way replicated TPU computation.";
+
+ let description = [{
+This operation holds a replicated input to a `tpu.replicate()` computation subgraph.
+Each replicated input has the same shape and type alongside the output.
+
+For example:
+```
+%a = "tf.opA"()
+%b = "tf.opB"()
+%replicated_input = "tf.TPUReplicatedInput"(%a, %b)
+%computation = "tf.Computation"(%replicated_input)
+```
+The above computation has a replicated input of two replicas.
+ }];
+
+ let arguments = (ins
+ Variadic:$inputs,
+
+ DefaultValuedAttr:$is_mirrored_variable,
+ DefaultValuedAttr:$index
+ );
+
+ let results = (outs
+ TF_Tensor:$output
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+ TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
+}
+
+def TF_TPUReplicatedOutputOp : TF_Op<"TPUReplicatedOutput", [NoSideEffect]> {
+ let summary = "Connects N outputs from an N-way replicated TPU computation.";
+
+ let description = [{
+This operation holds a replicated output from a `tpu.replicate()` computation subgraph.
+Each replicated output has the same shape and type alongside the input.
+
+For example:
+```
+%computation = "tf.Computation"()
+%replicated_output:2 = "tf.TPUReplicatedOutput"(%computation)
+```
+The above computation has a replicated output of two replicas.
+ }];
+
+ let arguments = (ins
+ TF_Tensor:$input
+ );
+
+ let results = (outs
+ Variadic:$outputs
+ );
+
+ TF_DerivedResultSizeAttr num_replicas = TF_DerivedResultSizeAttr<0>;
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes hyperbolic tangent of `x` element-wise.";
@@ -5905,6 +6053,8 @@ This is the opposite of `pack`.
TF_DerivedResultSizeAttr num = TF_DerivedResultSizeAttr<0>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+ let verifier = [{ return Verify(*this); }];
}
def TF_VariableShapeOp : TF_Op<"VariableShape", []> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 1bd9accbb78..9d2f634161c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -301,6 +301,15 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert(context);
}
+//===----------------------------------------------------------------------===//
+// AddNOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AddNOp::fold(ArrayRef operands) {
+ if (operands.size() == 1) return *inputs().begin();
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// AddV2Op
//===----------------------------------------------------------------------===//
@@ -310,6 +319,49 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert(context);
}
+//===----------------------------------------------------------------------===//
+// AllOp
+//===----------------------------------------------------------------------===//
+
+// Verifies a reduction op's `input` and reduction `dims`.
+static LogicalResult VerifyReductionInputAndDims(Value *input, Value *dims,
+ Location loc) {
+ auto dims_type = dims->getType().dyn_cast();
+ if (!dims_type) return success();
+ if (dims_type.getRank() > 1)
+ return emitError(loc, "dimensions can only be 0D or 1D tensor");
+
+ auto input_type = input->getType().dyn_cast();
+ if (!input_type) return success();
+ int64_t rank = input_type.getRank();
+
+ DenseIntElementsAttr dims_attr;
+ if (!matchPattern(dims, m_Constant(&dims_attr))) return success();
+ for (const auto &dim_pair : llvm::enumerate(dims_attr)) {
+ int64_t cur_dim = dim_pair.value().getSExtValue();
+ if (cur_dim < -rank || cur_dim >= rank)
+ return emitError(loc)
+ << dim_pair.index() << "-th dimension should be in the range of [-"
+ << rank << ", " << rank << ")";
+ }
+
+ return success();
+}
+
+static LogicalResult Verify(AllOp op) {
+ return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
+ op.getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// AnyOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AnyOp op) {
+ return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
+ op.getLoc());
+}
+
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
@@ -1542,17 +1594,23 @@ static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
// SplitOp
//===----------------------------------------------------------------------===//
-static LogicalResult Verify(SplitOp op) {
+// Verifies the input and split dimension operands for tf.Split/tf.SplitV.
+// Writes the split dimension's index (adjusted with input rank) via `dim_index`
+// if it's a constant.
+template <class Op>
+LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
+ *dim_index = llvm::None;
+
Value *split_dim = op.split_dim();
- auto split_dim_type = split_dim->getType().dyn_cast();
- if (!split_dim_type) return success();
- if (split_dim_type.getRank() != 0)
- return op.emitOpError("split dimension should be an integer scalar tensor");
+ if (auto split_dim_type = split_dim->getType().dyn_cast())
+ if (split_dim_type.getRank() != 0)
+ return op.emitOpError(
+ "split dimension should be an integer scalar tensor");
// We can perform further verification if the input tensor to be split has
// known rank and the split dimension tensor is a constant.
- auto input_type = op.value()->getType().dyn_cast();
+ auto input_type = op.value()->getType().template dyn_cast();
if (!input_type) return success();
int64_t input_rank = input_type.getRank();
@@ -1562,21 +1620,95 @@ static LogicalResult Verify(SplitOp op) {
DenseIntElementsAttr split_dim_attr;
if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
- int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
+ int64_t index = (*split_dim_attr.begin()).getSExtValue();
- if (dim_index + input_rank < 0 || dim_index >= input_rank) {
+ if (index + input_rank < 0 || index >= input_rank) {
return op.emitOpError("split dimension must be in range [-")
<< input_rank << ", " << input_rank << ")";
}
- if (dim_index < 0) dim_index += input_rank;
+ if (index < 0) index += input_rank;
+ *dim_index = index;
- int64_t input_dim_size = input_type.getDimSize(dim_index);
- if (input_dim_size < 0) return success();
+ return success();
+}
+
+static LogicalResult Verify(SplitOp op) {
+ Optional<int64_t> dim_index;
+ if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
+ if (!dim_index) return success();
+
+ int64_t input_dim_size =
+ op.value()->getType().cast().getDimSize(*dim_index);
+ if (input_dim_size == ShapedType::kDynamicSize) return success();
if (input_dim_size % op.getNumResults() != 0)
return op.emitOpError("dimension #")
- << dim_index << " not divisible by the number of result tensors";
+ << *dim_index << " not divisible by the number of result tensors";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SplitVOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(SplitVOp op) {
+ auto split_sizes_type =
+ op.size_splits()->getType().dyn_cast();
+ if (!split_sizes_type) return success();
+
+ if (split_sizes_type.getRank() != 1 ||
+ split_sizes_type.getDimSize(0) != op.getNumResults())
+ return op.emitOpError("split sizes should be a 1D tensor of ")
+ << op.getNumResults() << " elements";
+
+ Optional<int64_t> dim_index = 0;
+ if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
+ if (!dim_index) return success();
+
+ int64_t input_dim_size =
+ op.value()->getType().cast().getDimSize(*dim_index);
+ if (input_dim_size == ShapedType::kDynamicSize) return success();
+
+ // If split sizes come from a constant, they must sum to the dimension size
+ // along split_dim, and we can have no more than one dynamic dimension.
+ DenseIntElementsAttr split_sizes_attr;
+ if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
+ return success();
+
+ int64_t total_dim_size = 0; // Total dimension size assigned to splits
+ llvm::Optional<int64_t> dynamic_dim_index;
+
+ SmallVector split_sizes;
+ split_sizes.reserve(
+ split_sizes_attr.getType().cast().getNumElements());
+
+ for (auto dim : llvm::enumerate(split_sizes_attr)) {
+ int64_t dim_val = dim.value().getSExtValue();
+ split_sizes.push_back(dim_val);
+ if (dim_val == ShapedType::kDynamicSize) {
+ // We cannot have more than one dynamic dimension.
+ if (dynamic_dim_index)
+ return op.emitOpError(
+ "cannot have more than one dynamic dimension in split sizes");
+ dynamic_dim_index = dim.index();
+ } else {
+ total_dim_size += dim_val;
+ }
+ }
+
+ if (!dynamic_dim_index && total_dim_size != input_dim_size)
+ return op.emitOpError(
+ "split sizes must sum up to the dimension size along split "
+ "dimension, found ")
+ << total_dim_size << " vs " << input_dim_size;
+
+ if (dynamic_dim_index && total_dim_size > input_dim_size)
+ return op.emitOpError(
+ "split sizes must sum up to be less than or equal to the "
+ "dimension size along split dimension, found ")
+ << total_dim_size << " vs " << input_dim_size;
return success();
}
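To make the `size_splits` accounting above concrete: at most one entry may be dynamic, and the remaining entries must sum exactly to the dimension size, or to at most that size when a dynamic entry is present. Below is a minimal standalone version of the same check on plain integers, using -1 for the single allowed dynamic entry (analogous to `ShapedType::kDynamicSize`); it is an illustration, not the MLIR verifier itself:

```
#include <cstdint>
#include <iostream>
#include <vector>

// Returns true if `split_sizes` is a valid partition of `dim_size`.
// At most one entry may be -1 (its size is inferred); the remaining entries
// must sum exactly to dim_size, or to at most dim_size when a -1 is present.
bool ValidSplitSizes(const std::vector<int64_t>& split_sizes,
                     int64_t dim_size) {
  int64_t total = 0;
  int num_dynamic = 0;
  for (int64_t size : split_sizes) {
    if (size == -1) {
      ++num_dynamic;
    } else {
      total += size;
    }
  }
  if (num_dynamic > 1) return false;  // More than one dynamic entry.
  if (num_dynamic == 0) return total == dim_size;
  return total <= dim_size;           // Leave room for the inferred entry.
}

int main() {
  std::cout << ValidSplitSizes({4, 4, 2}, 10) << "\n";   // 1: sums to 10
  std::cout << ValidSplitSizes({4, -1, 2}, 10) << "\n";  // 1: -1 becomes 4
  std::cout << ValidSplitSizes({4, 4, 4}, 10) << "\n";   // 0: sums to 12
  std::cout << ValidSplitSizes({-1, -1}, 10) << "\n";    // 0: two dynamic
  return 0;
}
```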
@@ -1787,6 +1919,30 @@ void TruncateDivOp::getCanonicalizationPatterns(
results.insert(context);
}
+//===----------------------------------------------------------------------===//
+// UnpackOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(UnpackOp op) {
+ auto value_type = op.value()->getType().dyn_cast();
+ if (!value_type) return success();
+
+ int64_t value_rank = value_type.getRank();
+ int64_t axis = op.axis().getSExtValue();
+ if (axis < -value_rank || axis >= value_rank)
+ return op.emitOpError("axis attribute must be in the range of [-")
+ << value_rank << ", " << value_rank << ')';
+
+ axis = GetDimForAxis(axis, value_rank);
+ int64_t dim_size = value_type.getDimSize(axis);
+ if (ShapedType::isDynamic(dim_size)) return success();
+
+ if (dim_size != op.getNumResults())
+ return op.emitOpError("result count must be equal to ") << dim_size;
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// VariableShapeOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 8d975e909bb..9b6196cda5b 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -196,6 +196,42 @@ retained with length 1.
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}
+def TF_LegacyCallOp : TF_Op<"LegacyCall",
+ [CallOpInterface, NoSideEffect]> {
+ let summary =
+ "returns `f(inputs)`, where `f` is a function.";
+
+ let description = [{
+ The LegacyCall operation represents a direct call to a function that is
+ within the same symbol scope as the call and is mapped to a GraphDef node
+ with the function name as the op name. Unlike a PartitionedCall which
+ represents asynchronously executing a function across multiple devices, a
+ LegacyCall represents a function call whose only attribute is
+ _disable_call_shape_inference.
+ }];
+
+ let arguments = (ins
+ Variadic:$args,
+
+ FlatSymbolRefAttr:$f,
+ DefaultValuedAttr:$_disable_call_shape_inference
+ );
+
+ let results = (outs
+ Variadic:$output
+ );
+
+ let extraClassDeclaration = [{
+ // Gets the argument operands to the called function.
+ operand_range getArgOperands() { return args(); }
+
+ // Returns the callee of this operation.
+ CallInterfaceCallable getCallableForCallee() {
+ return getAttrOfType("f");
+ }
+ }];
+}
+
def TF_PartitionedCallOp : TF_Op<"PartitionedCall",
[CallOpInterface, NoSideEffect]> {
let summary =
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
index 67c3982fe3b..d5a5c16cbff 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
@@ -18,7 +18,7 @@ func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32
// CHECK-LABEL: func @multiple_return
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
-// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
+// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1
@@ -41,7 +41,12 @@ func @multiple_islands(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi3
%res = "tf.Print"(%sub) { message = "sub result" } : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield
}
- tf_executor.fetch %island1#1, %island2#1, %island3 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control
+ %island4 = tf_executor.island(%island1#2, %island2#2) {
+ %add = "tf.Add"(%island1#1, %island1#1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+ %res = "tf.Print"(%add) { message = "add result" } : (tensor<*xi32>) -> (tensor<*xi32>)
+ tf_executor.yield
+ }
+ tf_executor.fetch %island1#1, %island2#1, %island3, %island4 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control, !tf_executor.control
}
return %graph#0, %graph#1 : tensor<*xi32>, tensor<*xi32>
}
@@ -49,12 +54,17 @@ func @multiple_islands(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi3
// CHECK-LABEL: func @multiple_islands
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
-// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
+// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[SUB1:.*]], %[[SUB1_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Sub"(%arg0, %arg1)
-// CHECK: %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island(%[[SUB1_control]]) wraps "tf.Mul"(%[[SUB1]], %arg1)
+// CHECK: %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island wraps "tf.Mul"(%[[SUB1]], %arg1)
// CHECK: %[[SUB2:.*]], %[[SUB2_control:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) wraps "tf.Sub"(%[[ADD1]], %[[SUB1]])
-// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island(%[[SUB2_control]]) wraps "tf.Print"(%[[SUB2]]) {message = "sub result"}
-// CHECK: tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT_control]] :
+// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[SUB2]]) {message = "sub result"}
+// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) {
+// CHECK: tf_executor.yield
+// CHECK: }
+// CHECK: %[[ADD3:.*]], %[[ADD3_control:.*]] = tf_executor.island(%[[ISLAND1]], %[[ADD2_control]]) wraps "tf.Add"(%[[ADD2]], %[[ADD2]])
+// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD3]]) {message = "add result"}
+// CHECK: tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT1_control]], %[[PRINT2_control]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1
@@ -74,8 +84,8 @@ func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32>
// CHECK-LABEL: func @dangling_print
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
-// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1_control:.*]], %arg1)
-// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"}
+// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1_control:.*]], %arg1)
+// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"}
// CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]], %[[PRINT_control]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1
@@ -103,11 +113,14 @@ func @switch_and_merge(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi3
// CHECK-LABEL: func @switch_and_merge(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32>, tensor) {
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
-// CHECK: %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Less"(%arg1, %arg1)
-// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island(%[[LESS_control]]) wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"}
-// CHECK: %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[PRINT1_control]]
+// CHECK: %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island wraps "tf.Less"(%arg1, %arg1)
+// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"}
+// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[LESS_control]], %[[PRINT1_control]]) {
+// CHECK: tf_executor.yield
+// CHECK: }
+// CHECK: %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[ISLAND1]]
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[SWITCH_false]], %arg1)
-// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"}
+// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"}
// CHECK: %[[MERGE:.*]], %[[MERGE_index:.*]], %{{.*}} = tf_executor.Merge %[[ADD2]], %[[SWITCH_true]], %[[PRINT2_control]]
// CHECK: tf_executor.fetch %[[MERGE]], %[[MERGE_index]]
// CHECK: }
@@ -130,7 +143,7 @@ func @control_flow_plumbing(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor<
// CHECK: %[[GRAPH:.*]] = tf_executor.graph {
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%arg0) {message = "Random Print"}
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island(%[[PRINT_control]]) wraps "tf.Add"(%arg0, %arg1)
-// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
+// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: tf_executor.fetch %[[ADD2]] : tensor<*xi32>
// CHECK: }
// CHECK: return %[[GRAPH]] : tensor<*xi32>
@@ -150,6 +163,77 @@ func @fetching_arg(%arg0: tensor<*xi32>) {
// CHECK-LABEL: func @fetching_arg
// CHECK: tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg0)
-// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg0)
+// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg0)
// CHECK: tf_executor.fetch %[[ADD2_control]] : !tf_executor.control
// CHECK: }
+
+func @non_aliasing_reads_writes(
+  %arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
+  %arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
+  %arg2: tensor<32xf32>) -> (tensor<32xf32>) {
+  %graph = tf_executor.graph {
+    %island:2 = tf_executor.island {
+      %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
+      "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
+      %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
+      %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
+      %read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
+      "tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
+      "tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
+      %read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
+ tf_executor.yield %read3 : tensor<32xf32>
+ }
+ tf_executor.fetch %island#0 : tensor<32xf32>
+ }
+ return %graph : tensor<32xf32>
+}
+
+// CHECK-LABEL: func @non_aliasing_reads_writes
+// CHECK: %[[GRAPH:.*]] = tf_executor.graph {
+// CHECK: %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg0)
+// CHECK: %[[ASSIGN0_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %arg2)
+// CHECK: %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg1)
+// CHECK: %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
+// CHECK: %[[READ2:.*]], %[[READ2_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
+// CHECK: %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%arg1, %[[READ0]])
+// CHECK: %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[ASSIGN0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %[[READ2]])
+// CHECK: %[[READ3:.*]], %[[READ3_CONTROL:.*]] = tf_executor.island(%[[ASSIGN2_CONTROL]]) wraps "tf.ReadVariableOp"(%arg0)
+// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ASSIGN1_CONTROL]], %[[READ3_CONTROL]]) {
+// CHECK: tf_executor.yield
+// CHECK: }
+// CHECK: tf_executor.fetch %[[READ3]], %[[ISLAND1]] : tensor<32xf32>, !tf_executor.control
+// CHECK: }
+
+func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
+ tf_executor.graph {
+ %island = tf_executor.island {
+      %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
+      %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
+      %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
+      "tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
+      "tf._UnknownSideEffectingOp_"() : () -> ()
+      %read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
+      "tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
+      "tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
+ tf_executor.yield
+ }
+ tf_executor.fetch %island : !tf_executor.control
+ }
+ return
+}
+
+// CHECK-LABEL: func @unknown_side_effecting_op
+// CHECK: tf_executor.graph {
+// CHECK: %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
+// CHECK: %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v1"}
+// CHECK: %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
+// CHECK: %[[ASSIGN0_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VH1]], %arg0)
+// CHECK: %[[UNKNOWN_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]], %[[ASSIGN0_CONTROL]]) wraps "tf._UnknownSideEffectingOp_"()
+// CHECK: %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island(%[[UNKNOWN_CONTROL]]) wraps "tf.ReadVariableOp"(%[[VH1]])
+// CHECK: %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[UNKNOWN_CONTROL]]) wraps "tf.AssignVariableOp"(%[[VH0]], %[[READ1]])
+// CHECK: %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%[[VH1]], %[[READ0]])
+// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ASSIGN1_CONTROL]], %[[ASSIGN2_CONTROL]]) {
+// CHECK: tf_executor.yield
+// CHECK: }
+// CHECK: tf_executor.fetch %[[ISLAND1]] : !tf_executor.control
+// CHECK: }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index a2cc33a8201..18c63912a86 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -382,3 +382,10 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32
// CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x6x5xf32>
// CHECK: return %1
}
+
+// CHECK-LABEL: func @addN
+func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: return %arg0
+ %0 = "tf.AddN"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
new file mode 100644
index 00000000000..67d58b41199
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
@@ -0,0 +1,67 @@
+// RUN: tf-opt %s -split-input-file -tf-device-decompose-resource-ops | FileCheck %s
+
+// -----
+
+// Tests that the composite tf.AssignAddVariableOp operation is decomposed
+// into primitive ops.
+
+// CHECK-LABEL: func @decompose_assign_add_variable_op
+func @decompose_assign_add_variable_op() -> () {
+
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+
+ // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor}
+ // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
+ // CHECK: "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
+ // CHECK: "tf.AssignVariableOp"
+
+ %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor
+ "tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor) -> ()
+
+ return
+}
+
+// -----
+
+// Tests that the composite tf.AssignSubVariableOp operation is decomposed
+// using a Sub op.
+
+// CHECK-LABEL: func @decompose_assign_sub_variable_op
+func @decompose_assign_sub_variable_op() -> () {
+
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+
+ // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor}
+ // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
+ // CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]])
+ // CHECK: "tf.AssignVariableOp"
+
+ %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor
+ "tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor) -> ()
+
+ return
+}
+
+// -----
+
+// Tests that the composite tf.ResourceApplyGradientDescent operation is decomposed.
+
+// CHECK-LABEL: func @decompose_resource_apply_gradient_descent
+// CHECK-SAME: (%[[DELTA:.*]]: tensor)
+func @decompose_resource_apply_gradient_descent(%arg0: tensor) -> () {
+
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+
+ // CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
+ // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
+ // CHECK: %[[MUL:[0-9]*]] = "tf.Mul"(%[[DELTA]], %[[ALPHA]])
+ // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
+ // CHECK: %[[SUB:[0-9]*]] = "tf.Sub"(%[[RES_READ_VAL]], %[[MUL]])
+ // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[SUB]])
+
+ %1 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor
+ "tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<*x!tf.resource>, tensor, tensor) -> ()
+
+ return
+}
+
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
index 2a0434b69e0..a0390ec8738 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir
@@ -49,40 +49,33 @@ func @testIf3Result(tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>,
// -----
-func @testIf1Then(tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-func @testIf1Else(tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32>
+func @testIfThen(%arg0: tensor) -> tensor {
+ return %arg0 : tensor
+}
+func @testIfElse(%arg0: tensor) -> tensor {
+ return %arg0 : tensor
+}
-// CHECK-LABEL: func @testIf1Casts(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>)
-func @testIf1Casts(tensor, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32> {
-^bb0(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):
-
- %1 = "tf.If"(%arg0, %arg1, %arg2) {
- then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
- } : (tensor, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>
-
-// CHECK: %0 = extract_element %arg0[] : tensor
-// CHECK: cond_br %0, ^bb1, ^bb2
-// CHECK:^bb1: // pred: ^bb0
-// CHECK: %1 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<2x?xf32>
-// CHECK: %2 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x2xf32>
-// CHECK: %3 = call @testIf1Then(%1, %2) : (tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-// CHECK: %4 = tensor_cast %3 : tensor<2x2xf32> to tensor<2x?xf32>
-// CHECK: br ^bb3(%4 : tensor<2x?xf32>)
-
-// CHECK:^bb2: // pred: ^bb0
-// CHECK: %5 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<*xf32>
-// CHECK: %6 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x?xf32>
-// CHECK: %7 = call @testIf1Else(%5, %6) : (tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32>
-// CHECK: %8 = tensor_cast %7 : tensor<*xf32> to tensor<2x?xf32>
-// CHECK: br ^bb3(%8 : tensor<2x?xf32>)
-
-// CHECK:^bb3(%9: tensor<2x?xf32>): // 2 preds: ^bb1, ^bb2
-
- %2 = "tf.Add"(%1, %1) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
-// CHECK: %10 = "tf.Add"(%9, %9) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
-
- return %2 : tensor<2x?xf32>
-// CHECK: return %10 : tensor<2x?xf32>
+// CHECK-LABEL: func @testIfCasts(%arg0: tensor, %arg1: tensor>>) -> tensor>>
+func @testIfCasts(%arg0: tensor, %arg1: tensor>>) -> tensor>> {
+ %0 = "tf.If"(%arg0, %arg1) {
+ then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false
+ } : (tensor, tensor>>) -> tensor>>
+ return %0: tensor>>
+// CHECK: %0 = extract_element %arg0[] : tensor
+// CHECK: cond_br %0, ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor>>) -> tensor
+// CHECK: %2 = call @testIfThen(%1) : (tensor) -> tensor
+// CHECK: %3 = "tf.Cast"(%2) {Truncate = false} : (tensor) -> tensor>>
+// CHECK: br ^bb3(%3 : tensor>>)
+// CHECK: ^bb2:
+// CHECK: %4 = "tf.Cast"(%arg1) {Truncate = false} : (tensor>>) -> tensor
+// CHECK: %5 = call @testIfElse(%4) : (tensor) -> tensor
+// CHECK: %6 = "tf.Cast"(%5) {Truncate = false} : (tensor) -> tensor>>
+// CHECK: br ^bb3(%6 : tensor>>)
+// CHECK: ^bb3(%7: tensor>>):
+// CHECK: return %7 : tensor>>
}
// -----
@@ -188,31 +181,36 @@ func @testComplexWhile1Result(tensor<*xf32>) -> (tensor<*xf32>) {
// -----
-func @testWhileCond(tensor) -> (tensor)
-func @testWhileBody(tensor<*xf32>) -> (tensor)
+func @testWhileCond(%arg0: tensor) -> (tensor) {
+ %true = "tf.Const"() { value = dense : tensor } : () -> (tensor)
+ return %true : tensor
+}
+func @testWhileBody(%arg0: tensor>>) -> (tensor>>) {
+ %0 = "tf.Cast"(%arg0) : (tensor>>) -> tensor>>
+ return %0 : tensor>>
+}
-// CHECK-LABEL: func @testWhileCasts(%arg0: tensor<1x3xf32>)
-func @testWhileCasts(%arg0: tensor<1x3xf32>) -> (tensor) {
+// CHECK-LABEL: func @testWhileCasts(%arg0: tensor>>) -> tensor>>
+func @testWhileCasts(%arg0: tensor>>) -> (tensor>>) {
%0 = "tf.While"(%arg0) {
cond = @testWhileCond, body = @testWhileBody, is_stateless = false
- } : (tensor<1x3xf32>) -> (tensor)
-
-// CHECK: %0 = tensor_cast %arg0 : tensor<1x3xf32> to tensor
-// CHECK: br ^bb1(%0 : tensor)
-// CHECK: ^bb1(%1: tensor):
-// CHECK: %2 = call @testWhileCond(%1) : (tensor) -> tensor
+ } : (tensor>>) -> (tensor>>)
+ return %0 : tensor>>
+// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor>>) -> tensor
+// CHECK: br ^bb1(%0 : tensor)
+// CHECK: ^bb1(%1: tensor): // 2 preds: ^bb0, ^bb2
+// CHECK: %2 = call @testWhileCond(%1) : (tensor) -> tensor
// CHECK: %3 = extract_element %2[] : tensor
-// CHECK: %4 = tensor_cast %1 : tensor to tensor<*xf32>
-// CHECK: cond_br %3, ^bb2(%4 : tensor<*xf32>), ^bb3(%4 : tensor<*xf32>)
-// CHECK: ^bb2(%5: tensor<*xf32>):
-// CHECK: %6 = call @testWhileBody(%5) : (tensor<*xf32>) -> tensor
-// CHECK: %7 = tensor_cast %6 : tensor to tensor
-// CHECK: br ^bb1(%7 : tensor)
-// CHECK: ^bb3(%8: tensor<*xf32>):
-// CHECK: %9 = tensor_cast %8 : tensor<*xf32> to tensor
+// CHECK: %4 = "tf.Cast"(%1) {Truncate = false} : (tensor) -> tensor>>
+// CHECK: cond_br %3, ^bb2(%4 : tensor>>), ^bb3(%4 : tensor>>)
+// CHECK: ^bb2(%5: tensor>>): // pred: ^bb1
+// CHECK: %6 = call @testWhileBody(%5) : (tensor>>) -> tensor>>
+// CHECK: %7 = "tf.Cast"(%6) {Truncate = false} : (tensor>>) -> tensor
+// CHECK: br ^bb1(%7 : tensor)
+// CHECK: ^bb3(%8: tensor>>): // pred: ^bb1
+// CHECK: %9 = "tf.Cast"(%8) {Truncate = false} : (tensor>>) -> tensor>>
+// CHECK: return %9 : tensor>>
- return %0 : tensor
-// CHECK: return %9 : tensor
}
// -----
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
index 9ce15315832..207d6676f61 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt
@@ -54,5 +54,5 @@ versions {
# the names are matching between the function definition and the uses / call
# site (a numerical suffix may be appended).
-# CHECK: "tf.foo0"(
+# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, f = @foo0}
# CHECK: func @foo0
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt
index b26d7e7f2ba..ac248041994 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt
@@ -8,7 +8,7 @@
# Verify that we can also pull some attributes that are needed to be able to
# create a Graph in memory, like `T`.
# CHECK: tf.MaxPool
-# CHECK-SAME: T = "tfdtype$DT_FLOAT"
+# CHECK-SAME: T = f32
node {
name: "input"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt
new file mode 100644
index 00000000000..f0a7a574ae3
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt
@@ -0,0 +1,65 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=x -tf-input-data-types=DT_INT32 -tf-input-shapes=10 -tf-output-arrays=func_call -o - | FileCheck %s
+
+node {
+ name: "x"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "func_call"
+ op: "test_func_name"
+ input: "x"
+ attr {
+ key: "_disable_call_shape_inference"
+ value {
+ b: true
+ }
+ }
+}
+library {
+ function {
+ signature {
+ name: "test_func_name"
+ input_arg {
+ name: "a_0"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "a"
+ type: DT_INT32
+ }
+ }
+ ret {
+ key: "a"
+ value: "a_0"
+ }
+ attr {
+ key: "_disable_call_shape_inference"
+ value {
+ b: true
+ }
+ }
+ }
+}
+
+# CHECK: func @main
+# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, f = @test_func_name0}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt
index dcdbe67ccb6..563007f4305 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt
@@ -121,8 +121,8 @@ versions {
# Verify that functions from the library are properly imported.
# CHECK-LABEL: func @main() {
-# CHECK: "tf.foo110"()
-# CHECK: "tf.foo111"()
+# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110}
+# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111}
# CHECK-LABEL: func @foo110() {
# CHECK-LABEL: func @foo111() {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
index 17b2655aa5d..b65984227f6 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt
@@ -39,10 +39,10 @@ versions {
# Verify that functions from the library are properly imported.
# CHECK-LABEL: func @main() {
-# CHECK: "tf.foo0"()
-# CHECK: "tf.bar0"()
+# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo0}
+# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @foo0() {
-# CHECK: "tf.bar0"()
+# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @bar0() {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt
new file mode 100644
index 00000000000..b28e2818730
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt
@@ -0,0 +1,300 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck %s --dump-input=fail
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:0,a:0 -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail
+
+# Generated in Python via
+# ```
+# import tensorflow as tf
+#
+# with tf.compat.v1.Graph().as_default() as g:
+# w = tf.constant(2.0)
+# x = tf.constant(3.0)
+# y = tf.constant(4.0)
+# var = tf.Variable(2.0)
+# var_add = var.assign_add(3.0)
+# with g.control_dependencies([var_add]):
+# z0, z1, z2 = tf.identity_n((w, x, y))
+#
+# a = tf.add(z1, z2)
+# ```
+
+node {
+ name: "w"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "x"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 3.0
+ }
+ }
+ }
+}
+node {
+ name: "y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 4.0
+ }
+ }
+ }
+}
+node {
+ name: "var/initial_value"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "var"
+ op: "VariableV2"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "var/Assign"
+ op: "Assign"
+ input: "var"
+ input: "var/initial_value"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@var"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "var/read"
+ op: "Identity"
+ input: "var"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@var"
+ }
+ }
+ }
+}
+node {
+ name: "var_add/value"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 3.0
+ }
+ }
+ }
+}
+node {
+ name: "var_add"
+ op: "AssignAdd"
+ input: "var"
+ input: "var_add/value"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@var"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "z"
+ op: "IdentityN"
+ input: "w"
+ input: "x"
+ input: "y"
+ input: "^var_add"
+ attr {
+ key: "T"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_FLOAT
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+node {
+ name: "a"
+ op: "Add"
+ input: "z:1"
+ input: "z:2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+versions {
+ producer: 230
+}
+
+# Test non-zero index output tensors as feeds. Original ops whose outputs are
+# replaced with feeds are preserved, and args and rets are lifted to the
+# function. Rets that happen to coincide with a feed should take their value
+# from the feed.
+#
+# CHECK: func @main(%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor, tensor)
+# CHECK: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
+# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
+# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
+# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
+# CHECK: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]
+
+# Test that when non-zero index output tensors are feeds, the remaining
+# unreachable ops are pruned if pruning is enabled.
+#
+# PRUNE: func @main(%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor, tensor)
+# PRUNE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
+# PRUNE-NOT: "tf.Const"
+# PRUNE-NOT: "tf.VariableV2"
+# PRUNE-NOT: "tf.Assign"
+# PRUNE-NOT: "tf.Identity"
+# PRUNE-NOT: "tf.AssignAdd"
+# PRUNE-NOT: "tf.IdentityN"
+# PRUNE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
+# PRUNE: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]
+
+# Test that when non-zero index output tensors are feeds, the remaining
+# unreachable ops are preserved if pruning is not enabled.
+#
+# PRESERVE: func @main(%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor)
+# PRESERVE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:0,a:0"}}
+# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
+# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
+# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
+# PRESERVE: tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]]
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt
index d33ac2f3b5b..3dd5ce58ed2 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt
@@ -2,11 +2,11 @@
# CHECK: tf_executor.SwitchN
# CHECK-SAME: of 3 : tensor
-# CHECK-SAME: T = "tfdtype$DT_INT32"
+# CHECK-SAME: T = i32
# CHECK-SAME: name = "Case/branch_index/_3"
# CHECK: tf_executor.SwitchN
# CHECK-SAME: of 2 : tensor
-# CHECK-SAME: T = "tfdtype$DT_FLOAT"
+# CHECK-SAME: T = f32
# CHECK-SAME: name = "Case/Case/input_0/_7"
node {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
index 120e73f6e94..60ffc924ae5 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
@@ -250,3 +250,19 @@ func @ZerosLike_variant(%arg0: tensor>>) -> tensor>>) -> tensor>>
return %0 : tensor>>
}
+
+// CHECK-LABEL: func @addN
+func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
+ // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%[[SUM0]], %arg2)
+  // CHECK: return %[[SUM1]]
+ %0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// CHECK-LABEL: func @addN_variant
+func @addN_variant(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>) -> tensor>> {
+ // CHECK: tf.AddN
+ %0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor>>, tensor>>, tensor>>) -> tensor>>
+ return %0 : tensor>>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir
new file mode 100644
index 00000000000..6c83b45295e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir
@@ -0,0 +1,26 @@
+// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
+
+func @main() {
+ tf_executor.graph {
+ %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Constant", value = dense<0> : tensor} : () -> tensor
+ %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs) {f = @foo0} : (tensor) -> tensor
+ tf_executor.fetch
+ }
+ return
+}
+func @foo0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+ %0 = tf_executor.graph {
+ tf_executor.fetch %arg0 : tensor<*xi32>
+ }
+ return %0 : tensor<*xi32>
+}
+
+// CHECK: node {
+// CHECK: name: "_tf.LegacyCall"
+// CHECK-NEXT: op: "foo0"
+
+// CHECK: library {
+// CHECK-NEXT: function {
+// CHECK-NEXT: signature {
+// CHECK-NEXT: name: "foo0"
+
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir
new file mode 100644
index 00000000000..c98e40fed05
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir
@@ -0,0 +1,244 @@
+// RUN: tf-opt -split-input-file -verify-diagnostics -tf-resource-device-inference %s | FileCheck %s --dump-input=fail
+
+// Tests that the pass can correctly propagate device attributes inside the same
+// function.
+
+// CHECK-LABEL: func @propagate_in_function
+func @propagate_in_function(
+ %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"},
+ %arg1: tensor<*x!tf.resource>> {tf.device = "/TPU:1"}) {
+ tf_executor.graph {
+ // CHECK: tf_executor.island
+ %island = tf_executor.island {
+ // CHECK-NEXT: "tf.VarHandleOp"
+ %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/CPU:0"}
+ : () -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:0"}
+ %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:0"}
+ %id1 = "tf.Identity"(%id0) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/CPU:0"}
+ %id2 = "tf.Identity"(%var_handle) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ %read = "tf.ReadVariableOp"(%id2) : (tensor<*x!tf.resource>>) -> tensor<32xf32>
+ %id3 = "tf.Identity"(%read) : (tensor<32xf32>) -> tensor<32xf32>
+ tf_executor.yield
+ }
+ tf_executor.fetch %island : !tf_executor.control
+ }
+ return
+}
+
+// -----
+
+// Tests that the pass can propagate through tf.If's branches.
+
+// CHECK-LABEL: func @propagate_if_op
+func @propagate_if_op(
+ %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"},
+ %arg1: tensor) {
+ tf_executor.graph {
+ // CHECK: tf_executor.island
+ %island = tf_executor.island {
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:0"}
+ %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.VarHandleOp"
+ %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
+ : () -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.If"
+ "tf.If"(%arg1, %id0, %var_handle) {
+ then_branch = @if_then,
+ else_branch = @if_else,
+ output_shapes = [], is_stateless = false}
+ : (tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>) -> ()
+ tf_executor.yield
+ }
+ tf_executor.fetch %island : !tf_executor.control
+ }
+ return
+}
+
+// CHECK-LABEL: func @if_then
+func @if_then(
+ %arg0: tensor<*x!tf.resource>>,
+ %arg1: tensor<*x!tf.resource>>) {
+ tf_executor.graph {
+ // CHECK: tf_executor.island
+ %island = tf_executor.island {
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:0"}
+ %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:1"}
+ %id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ tf_executor.yield
+ }
+ tf_executor.fetch %island : !tf_executor.control
+ }
+ return
+}
+
+// CHECK-LABEL: func @if_else
+func @if_else(
+ %arg0: tensor<*x!tf.resource>>,
+ %arg1: tensor<*x!tf.resource>>) {
+ tf_executor.graph {
+ // CHECK: tf_executor.island
+ %island = tf_executor.island {
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:0"}
+ %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ tf_executor.yield
+ }
+ tf_executor.fetch %island : !tf_executor.control
+ }
+ return
+}
+
+
+// -----
+
+// Tests that the pass can propagate through tf.While's branches.
+
+// CHECK-LABEL: func @propagate_while_op
+func @propagate_while_op(
+ %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"},
+ %arg1: tensor) {
+ tf_executor.graph {
+ // CHECK: tf_executor.island
+ %island = tf_executor.island {
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:0"}
+ %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.VarHandleOp"
+ %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
+ : () -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.While"
+ "tf.While"(%arg1, %id0, %var_handle) {
+ body = @while_body,
+ cond = @while_cond,
+ output_shapes = [], is_stateless = false}
+ : (tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>) ->
+ (tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>)
+ tf_executor.yield
+ }
+ tf_executor.fetch %island : !tf_executor.control
+ }
+ return
+}
+
+// CHECK-LABEL: func @while_body
+func @while_body(
+ %arg0: tensor,
+ %arg1: tensor<*x!tf.resource>>,
+ %arg2: tensor<*x!tf.resource>>) ->
+ (tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>) {
+ %graph:3 = tf_executor.graph {
+ // CHECK: tf_executor.island
+ %island:4 = tf_executor.island {
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:0"}
+ %id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:1"}
+ %id1 = "tf.Identity"(%arg2) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ tf_executor.yield %arg0, %id0, %id1
+ : tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>
+ }
+ tf_executor.fetch %island#0, %island#1, %island#2
+ : tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>
+ }
+ return %graph#0, %graph#1, %graph#2
+ : tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>
+}
+
+// CHECK-LABEL: func @while_cond
+func @while_cond(
+ %arg0: tensor,
+ %arg1: tensor<*x!tf.resource>>,
+ %arg2: tensor<*x!tf.resource>>) -> tensor<32xf32> {
+ %graph = tf_executor.graph {
+ // CHECK: tf_executor.island
+ %island:2 = tf_executor.island {
+ // CHECK-NEXT: "tf.Identity"
+ // CHECK-SAME: {device = "/TPU:0"}
+ %id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ %read = "tf.ReadVariableOp"(%id0)
+ : (tensor<*x!tf.resource>>) -> tensor<32xf32>
+ tf_executor.yield %read : tensor<32xf32>
+ }
+ tf_executor.fetch %island#0 : tensor<32xf32>
+ }
+ return %graph : tensor<32xf32>
+}
+
+// -----
+
+// Tests that the pass reports an error on conflicting assignments from multiple
+// callers.
+
+func @error_on_conflict_multiple_callers(
+ %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"},
+ %arg1: tensor) {
+ tf_executor.graph {
+ %island = tf_executor.island {
+ %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
+ : () -> tensor<*x!tf.resource>>
+ "tf.If"(%arg1, %id0, %var_handle) {
+ then_branch = @if_then_and_else,
+ else_branch = @if_then_and_else,
+ output_shapes = [], is_stateless = false}
+ : (tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>) -> ()
+ "tf.If"(%arg1, %var_handle, %id0) {
+ // expected-error@above {{Conflicting device assignment for resource}}
+ then_branch = @if_then_and_else,
+ else_branch = @if_then_and_else,
+ output_shapes = [], is_stateless = false}
+ : (tensor, tensor<*x!tf.resource>>,
+ tensor<*x!tf.resource>>) -> ()
+ tf_executor.yield
+ }
+ tf_executor.fetch %island : !tf_executor.control
+ }
+ return
+}
+
+func @if_then_and_else(
+ %arg0: tensor<*x!tf.resource>>,
+ %arg1: tensor<*x!tf.resource>>) {
+ tf_executor.graph {
+ %island = tf_executor.island {
+ %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ %id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>)
+ -> tensor<*x!tf.resource>>
+ tf_executor.yield
+ }
+ tf_executor.fetch %island : !tf_executor.control
+ }
+ return
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index 8ff72dbc7fc..e5905e5f681 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -8,7 +8,7 @@ func @only_resource_load() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
+ // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
// CHECK: "tf_device.launch"
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
// CHECK: tf_device.return %[[COMPUTE_RES]]
@@ -16,7 +16,7 @@ func @only_resource_load() -> tensor<*xi32> {
// CHECK-SAME: () -> tensor<*xi32>
%1 = "tf_device.launch"() ( {
- %2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+ %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_device.return %3 : tensor<*xi32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@@ -39,11 +39,11 @@ func @only_resource_store() -> tensor<*xi32> {
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
- // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
+ // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
%1 = "tf_device.launch"() ( {
%2 = "tf.SomeComputation"() : () -> (tensor<*xi32>)
- "tf.AssignVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+ "tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
tf_device.return %2 : tensor<*xi32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@@ -61,18 +61,18 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
+ // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
- // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
+ // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
%1 = "tf_device.launch"() ( {
- %2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+ %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
- "tf.AssignVariableOp"(%0, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+ "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
tf_device.return %3 : tensor<*xi32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@@ -82,96 +82,6 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
// -----
-// Tests that composite tf.AssignAddVariableOp operation is decomposed and
-// hoisted.
-
-// CHECK-LABEL: func @decompose_assign_add_variable_op
-func @decompose_assign_add_variable_op() -> tensor<*xi32> {
-
- // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-
- // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
- // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
- // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor}
- // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
- // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
- // CHECK: {device = "tpu0", launch_attr = "launch_attr"}
- // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
- // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
-
- %1 = "tf_device.launch"() ( {
- %2 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor
- "tf.AssignAddVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor) -> ()
- %3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
- tf_device.return %3 : tensor<*xi32>
- }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
-
- // CHECK: return %[[LAUNCH_RES]]#0
- return %1 : tensor<*xi32>
-}
-
-// -----
-
-// Tests that composite tf.AssignSubVariableOp operation is decomposed using
-// SubOp.
-
-// CHECK-LABEL: func @decompose_assign_sub_variable_op
-func @decompose_assign_sub_variable_op() -> tensor<*xi32> {
-
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-
- // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
- // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor}
- // CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]])
- // CHECK: "tf.AssignVariableOp"
-
- %1 = "tf_device.launch"() ( {
- %2 = "tf.Const"() {value = dense<1> : tensor