diff --git a/CODEOWNERS b/CODEOWNERS index 271e3b5b2ff..3ef02ffd68c 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -13,55 +13,4 @@ /tensorflow/tensorboard/ @jart /tensorflow/tools/docs/ @markdaoust -# contrib - -# NEED OWNER: /tensorflow/contrib/all_reduce -/tensorflow/contrib/autograph/ @mdanatg @kkimdev -/tensorflow/contrib/batching/ @alextp @chrisolston -/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon -/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva -/tensorflow/contrib/checkpoint/ @allenlavoie -/tensorflow/contrib/contrib/cluster_resolver/ @frankchn -/tensorflow/contrib/cmake/ @mrry -/tensorflow/contrib/copy_graph/ @tucker @poxvoculi -/tensorflow/contrib/crf/ @kentonl -/tensorflow/contrib/data/ @mrry -/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn -/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi -/tensorflow/contrib/eager @jaingaurav @alextp -/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo -/tensorflow/contrib/ffmpeg/ @fredbertsch -/tensorflow/contrib/framework/ @ebrevdo -/tensorflow/contrib/graph_editor/ @purpledog -# NEED OWNER: /tensorflow/contrib/grid_rnn/ -/tensorflow/contrib/hadoop @yongtang -/tensorflow/contrib/hvx/ @satok16 -/tensorflow/contrib/integrate/ @shoyer -/tensorflow/contrib/kernel_methods/ @petrosmol -/tensorflow/contrib/ios_examples/ @petewarden -/tensorflow/contrib/labeled_tensor/ @shoyer -/tensorflow/contrib/layers/ @fchollet @martinwicke -/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp -/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis -/tensorflow/contrib/lookup/ @ysuematsu @andreasst -/tensorflow/contrib/losses/ @alextp @ispirmustafa -/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg -/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa -/tensorflow/contrib/opt/ @strategist333 @alextp -/tensorflow/contrib/pi_examples/ @maciekcc -/tensorflow/contrib/quantization/ @petewarden -/tensorflow/contrib/rnn/ @ebrevdo @scottzhu -/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie -/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang -/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh -/tensorflow/contrib/slim/ @sguada @thenbasilmanran -/tensorflow/contrib/stateless/ @girving @alextp -/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2 -# NEED OWNER: /tensorflow/contrib/testing/ -/tensorflow/contrib/timeseries/ @allenlavoie -/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj -/tensorflow/contrib/training/ @joel-shor @ebrevdo -/tensorflow/contrib/util/ @sherrym - /third_party/systemlibs/ @perfinion diff --git a/README.md b/README.md index 51ca43e1571..05b1e4de458 100644 --- a/README.md +++ b/README.md @@ -110,19 +110,19 @@ Build Type | Status ### Community Supported Builds -Build Type | Status | Artifacts -------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/) -**Linux AMD ROCm GPU** Stable Release | [![Build 
Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/) -**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) -**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) -**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) -**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/) -**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) -**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/) -**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) -**Linux CPU with Intel® MKL-DNN**
**Supports Python 2.7, 3.4, 3.5, 3.6 and 3.7** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.14.0 PyPI](https://pypi.org/project/intel-tensorflow/) -**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br>
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/) +Build Type | Status | Artifacts +----------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/) +**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/) +**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) +**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) +**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) +**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/) +**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/) +**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) +**Linux CPU with Intel® MKL-DNN** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/) 
+**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br>
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/) ## Resources diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 0f299ec13f8..603c2a5c45c 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -195,6 +195,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "chromiumos", + values = {"crosstool_top": "//external:android/chromiumos"}, + visibility = ["//visibility:public"], +) + config_setting( name = "linux_aarch64", values = {"cpu": "aarch64"}, @@ -453,6 +459,7 @@ package_group( "//tensorflow_estimator/python/estimator/...", "//tensorflow_models/official/...", "//third_party/py/autograph/...", + "//third_party/swift/tensorflow/x10/...", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 46ade1b2e77..8793e308466 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -233,7 +233,7 @@ tensorflow::Status GetReplacedFromExistingWorkers( std::vector responses( existing_workers->size()); for (int i = 0; i < existing_workers->size(); i++) { - tensorflow::eager::EagerClient* eager_client; + tensorflow::core::RefCountPtr eager_client; statuses[i] = client_cache->GetClient(existing_workers->at(i), &eager_client); if (!statuses[i].ok()) { @@ -282,7 +282,7 @@ tensorflow::Status CreateRemoteContexts( continue; } - tensorflow::eager::EagerClient* eager_client; + tensorflow::core::RefCountPtr eager_client; statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client); if (eager_client == nullptr) { statuses[i] = tensorflow::errors::Internal( @@ -340,7 +340,7 @@ tensorflow::Status UpdateRemoteContexts( continue; } - tensorflow::eager::EagerClient* eager_client; + tensorflow::core::RefCountPtr eager_client; statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client); if (eager_client == nullptr) { statuses[i] = tensorflow::errors::Internal( @@ -819,7 +819,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, } // TODO(yuefengz): support partially specified `worker_name`. 
- tensorflow::eager::EagerClient* eager_client; + tensorflow::core::RefCountPtr eager_client; status->status = remote_eager_workers->GetClient(worker_name, &eager_client); if (!status->status.ok()) { return false; diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt index 7036ef71b58..0fcee7d7e8f 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt @@ -38,6 +38,6 @@ versions { # CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32> # CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} { -# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32> +# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32> # CHECK-NEXT: return %0 : tensor<*xi32> # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 27eff39c397..ec618ffa276 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1280,3 +1280,13 @@ func @conv2d_backprop_unsupported_data_format(%arg0: tensor<4xi32>, %arg1: tenso // CHECK-LABEL: conv2d_backprop_unsupported_data_format // CHECK: tf.Conv2DBackpropInput } + +func @assert_remove(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi1> { + %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + "tf.Assert"(%0, %arg1) {summarize = 3} : (tensor<1xi1>, tensor<1xi32>) -> () + return %0 : tensor<1xi1> + // CHECK-LABEL: assert_remove + // CHECK: tfl.less_equal + // CHECK-NOT: Assert + // CHECK: return +} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index f7913f11f72..1d51adb16f2 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -622,3 +622,12 @@ func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1" } + +// CHECK-LABEL: fuse_relu_to_add +func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %1 = "tfl.relu_n1_to_1"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %1 : tensor<2x3xf32> + // CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"} + // CHECK: return %[[RES]] +} diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index bc6ff5e3b47..0512bc98cab 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -68,6 +68,7 @@ struct LegalizeTF : public FunctionPass { // TODO(antiagainst): Define this pattern in a table-driven manner once variadic // operands are properly supported in declarative rewrite rule specification. 
+DECL_CONVERT_OP(Assert); DECL_CONVERT_OP(Concat); DECL_CONVERT_OP(ConcatV2); DECL_CONVERT_OP(MatMul); @@ -86,7 +87,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concat_op = cast(op); - SmallVector values(tf_concat_op.values()); + auto values = tf_concat_op.values(); auto output_type = tf_concat_op.output()->getType(); // Extract axis attribute from constant concat_dims tensor ElementsAttr axis; @@ -105,7 +106,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concat_op = cast(op); - SmallVector values(tf_concat_op.values()); + auto values = tf_concat_op.values(); auto output_type = tf_concat_op.output()->getType(); // Extract axis attribute from constant axis tensor ElementsAttr axis; @@ -374,6 +375,14 @@ PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite( return matchFailure(); } +// TF Lite doesn't support Assert, we just drop the assert from the graph. +PatternMatchResult ConvertTFAssertOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + op->dropAllReferences(); + op->erase(); + return matchSuccess(); +} + void LegalizeTF::runOnFunction() { OwningRewritePatternList patterns; auto* ctx = &getContext(); @@ -385,7 +394,8 @@ void LegalizeTF::runOnFunction() { .insert(ctx); + ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp>( + ctx); applyPatternsGreedily(func, patterns); } diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index bf0e7169584..3f50c3ad1c1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -484,9 +484,9 @@ struct ConvertTensorListResize : public ConversionPattern { &rewriter); // Inserts the two blocks' names into the symbol table held by the module. - // Using ModuleManager will ensure that the inserted symbol names are + // Using SymbolTable will ensure that the inserted symbol names are // unique. - ModuleManager manager(resize_op.getParentOfType()); + SymbolTable manager(resize_op.getParentOfType()); manager.insert(then_branch_op); manager.insert(else_branch_op); @@ -754,8 +754,7 @@ struct ConvertWhile : public ConversionPattern { cloned.removeAttr("T"); UpdateFunctionTypes(cloned); - SmallVector results(cloned.getResults()); - rewriter.replaceOp(op, results); + rewriter.replaceOp(op, cloned.getResults()); return matchSuccess(); } }; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index c8b54d26653..173785ba5b0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -135,15 +135,15 @@ class FoldIfOp : public OpRewritePattern { static void EraseDeadFuncs(const FuncSet& candiate_funcs, ModuleOp module) { if (candiate_funcs.empty()) return; - ModuleManager manager(module); + SymbolTable manager(module); // Identify the functions that are used as symbols in the module and shouldn't // be erased. 
FuncSet in_use_funcs; - manager.getModule().walk([&](Operation* op) { + manager.getOp()->walk([&](Operation* op) { for (auto attr : op->getAttrs()) { if (auto symbol = attr.second.dyn_cast()) { - auto func = manager.lookupSymbol(symbol.getValue()); + auto func = manager.lookup(symbol.getValue()); in_use_funcs.insert(func); } } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 905f01d8413..a91f6de1971 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -44,12 +44,13 @@ multiclass FuseActFnIntoConvOpPat { $multiplier)>; } -// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused +// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused // activation functions. // Currently we're not fusing tanh, sigmoid, hard_swish and other activations // those cannot be simply translated into clamping. foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], - [TFL_Relu6Op, TFL_AF_Relu6]] in + [TFL_Relu6Op, TFL_AF_Relu6], + [TFL_Relu1Op, TFL_AF_Relu1]] in defm : FuseActFnIntoConvOpPat; @@ -291,3 +292,18 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $NegOne)), (TFL_Relu1Op $input), [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>; + +// Multi-pattern consisting of matching stand-alone op or op followed by relu. +multiclass FusedBinaryActivationFuncOpPat { + foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], + [TFL_Relu6Op, TFL_AF_Relu6], + [TFL_Relu1Op, TFL_AF_Relu1]] in { + def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)), + (BinaryOp $lhs, $rhs, actFnPair[1])>; + } +} + +// Instantiated FusedBinary patterns for the from-to pairs of ops. 
+foreach BinaryOps = [TFL_AddOp, TFL_DivOp, + TFL_MulOp, TFL_SubOp] in + defm : FusedBinaryActivationFuncOpPat; diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 5484988d0f5..5f93210f06e 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -192,6 +192,37 @@ cc_library( alwayslink = 1, ) +gentbl( + name = "decompose_resource_ops_inc_gen", + tbl_outs = [ + ( + "-gen-rewriters", + "transforms/generated_decompose_resource_ops.inc", + ), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "transforms/decompose_resource_ops.td", + td_srcs = [ + ":tensorflow_ops_td_files", + "@local_config_mlir//:StdOpsTdFiles", + ], +) + +cc_library( + name = "decompose_resource_ops", + srcs = [ + "transforms/decompose_resource_ops.cc", + ], + hdrs = [ + "transforms/decompose_resource_ops.h", + ], + deps = [ + ":decompose_resource_ops_inc_gen", + ":tensorflow", + "@local_config_mlir//:IR", + ], +) + cc_library( name = "tensorflow_passes", srcs = [ @@ -199,6 +230,7 @@ cc_library( "transforms/bridge_pass.cc", "transforms/cluster_formation.cc", "transforms/cluster_outlining.cc", + "transforms/decompose_resource_ops_pass.cc", "transforms/delete_unused_funcs.cc", "transforms/executor_island_coarsening.cc", "transforms/fold_switch.cc", @@ -213,6 +245,7 @@ cc_library( "transforms/raise_control_flow.cc", "transforms/replicate_invariant_op_hoisting.cc", "transforms/replicate_to_island.cc", + "transforms/resource_device_inference.cc", "transforms/resource_op_lifting.cc", "transforms/shape_inference.cc", "transforms/shape_inference_pass.cc", @@ -236,6 +269,8 @@ cc_library( ":bridge_logger", ":convert_tensor", ":convert_type", + ":decompose_resource_ops", + ":decompose_resource_ops_inc_gen", ":device_util", ":error_util", ":export_tf_dialect_op", @@ -368,12 +403,14 @@ cc_library( ":convert_tensor", ":convert_type", ":mangling_util", + ":tensorflow", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:protobuf", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -564,7 +601,6 @@ cc_library( hdrs = ["utils/error_util.h"], deps = [ "//tensorflow/core:lib", - "//tensorflow/stream_executor/lib", "@llvm//:support", "@local_config_mlir//:IR", ], @@ -808,7 +844,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/types:span", "@llvm//:support", "@local_config_mlir//:IR", "@local_config_mlir//:Parser", diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 8d43c9330d0..898393479b0 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" @@ -99,12 +100,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { auto forward_input_to_output = [&](Value* operand, Value* result) { if (!mlir::getElementTypeOrSelf(result->getType()).isa()) return; + auto& result_ids = resource_value_to_ids_[result]; auto operand_it = resource_value_to_ids_.find(operand); assert(operand_it != resource_value_to_ids_.end() && "A resource-type output does not have the corresponding " "resource-type input."); - resource_value_to_ids_[result].insert(operand_it->getSecond().begin(), - operand_it->getSecond().end()); + result_ids.insert(operand_it->getSecond().begin(), + operand_it->getSecond().end()); }; // TODO(yuanzx): Consider control-flow ops. func_op.walk([&](Operation* op) { @@ -119,6 +121,16 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { forward_input_to_output(std::get<0>(operand_and_result), std::get<1>(operand_and_result)); } + } else if (auto replicate = llvm::dyn_cast(op)) { + // The nested block for RepliateOp is handled separately in side-effect + // analysis. Inside that block, we can still treat its block arguments as + // different resources. + for (auto arg : replicate.GetBody().getArguments()) { + if (mlir::getElementTypeOrSelf(arg->getType()) + .isa()) { + resource_value_to_ids_[arg].insert(next_unique_id++); + } + } } else { for (auto result : op->getResults()) { if (!mlir::getElementTypeOrSelf(result->getType()) @@ -261,9 +273,36 @@ void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id, void SideEffectAnalysis::AnalyzeFunction( FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) { - // This function populates control_predecessors_ and control_successors_ by - // walking through func_op's body, and tracking resource accesses in - // per_resource_access_info_. + // AnalyzeRegion() recursively analyzes the function body, and only populates + // control_predecessors_. + AnalyzeRegion(&func_op.getBody(), alias_analysis); + // Populate sorted_control_predecessors_ and sorted_control_successors_ based + // on control_predecessors. + for (auto& entry : control_predecessors_) { + auto op = entry.getFirst(); + auto& sorted_predecessors = sorted_control_predecessors_[op]; + for (auto predecessor : entry.getSecond()) { + sorted_predecessors.push_back(predecessor); + sorted_control_successors_[predecessor].push_back(op); + } + } + control_predecessors_.clear(); + for (auto& entry : sorted_control_predecessors_) { + llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) { + return a->isBeforeInBlock(b); + }); + } + for (auto& entry : sorted_control_successors_) { + llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) { + return a->isBeforeInBlock(b); + }); + } +} + +void SideEffectAnalysis::AnalyzeRegion( + Region* region, const ResourceAliasAnalysis& alias_analysis) { + // This function populates control_predecessors_ by walking through the + // region, and tracking resource accesses in per_resource_access_info_. 
// Returns whether an access to `resource` can skip control edges from // prevoius accesses to unknown resources, due to that earlier accesses to @@ -284,82 +323,93 @@ void SideEffectAnalysis::AnalyzeFunction( (it->second.tracked_last_unknown_read || no_unknown_read); }; - func_op.walk([&](Operation* op) { - // We do not need explicit control edges for declaration ops. - if (OpIsDeclaration(op, alias_analysis)) return; - - auto resource_op_info = GetResourceInfoForOp(op); - if (!resource_op_info && op->hasNoSideEffect()) return; - - llvm::SmallDenseSet resources = - resource_op_info ? FindAccessedResources(op, alias_analysis) - : UnknownResourceSet(); - assert(!resources.empty()); - const bool is_unknown = resources.count(kUnknownResourceId) > 0; - const bool read_only = OpIsReadOnly(op); - bool indirectly_tracked_unknown_access = false; - // First add edges from known resources. - if (is_unknown) { - for (auto& entry : per_resource_access_info_) { - if (entry.getFirst() == kUnknownResourceId) continue; - AddPredecessorsForAccess(entry.getFirst(), op, read_only); - indirectly_tracked_unknown_access |= - unknown_access_indirectly_tracked_by_resource(entry.getFirst(), - read_only); + // We explicitly iterates through the regions and blocks, in order to handle + // different nested regions separately. + for (auto& block : *region) { + for (auto& op : block) { + if (op.getNumRegions() > 0) { + llvm::SmallVector child_analyses; + for (auto& child_region : op.getRegions()) { + child_analyses.emplace_back(); + child_analyses.back().AnalyzeRegion(&child_region, alias_analysis); + } + ConsumeChildAnalyses(std::move(child_analyses)); } - } else { - for (int64_t resource : resources) { - AddPredecessorsForAccess(resource, op, read_only); - indirectly_tracked_unknown_access |= - unknown_access_indirectly_tracked_by_resource(resource, read_only); - // Update access info for known resources. - TrackAccess(resource, op, read_only); - } - } - // If not indirectly tracked, add edges from the unknown resource. - if (!indirectly_tracked_unknown_access) { - AddPredecessorsForAccess(kUnknownResourceId, op, read_only); - } - if (is_unknown) { - // Update access info for unknown resource. - TrackAccess(kUnknownResourceId, op, read_only); - } - }); - // Populate control_successors_ based on control_predecessors_. - for (auto& entry : control_predecessors_) { - auto op = entry.getFirst(); - for (auto predecessor : entry.getSecond()) { - control_successors_[predecessor].insert(op); + // We do not need explicit control edges for declaration ops. + if (OpIsDeclaration(&op, alias_analysis)) continue; + + auto resource_op_info = GetResourceInfoForOp(&op); + if (!resource_op_info && op.hasNoSideEffect()) continue; + + llvm::SmallDenseSet resources = + resource_op_info ? FindAccessedResources(&op, alias_analysis) + : UnknownResourceSet(); + assert(!resources.empty()); + const bool is_unknown = resources.count(kUnknownResourceId) > 0; + const bool read_only = OpIsReadOnly(&op); + bool indirectly_tracked_unknown_access = false; + // First add edges from known resources. 
+ if (is_unknown) { + for (auto& entry : per_resource_access_info_) { + if (entry.getFirst() == kUnknownResourceId) continue; + AddPredecessorsForAccess(entry.getFirst(), &op, read_only); + indirectly_tracked_unknown_access |= + unknown_access_indirectly_tracked_by_resource(entry.getFirst(), + read_only); + } + } else { + for (int64_t resource : resources) { + AddPredecessorsForAccess(resource, &op, read_only); + indirectly_tracked_unknown_access |= + unknown_access_indirectly_tracked_by_resource(resource, + read_only); + // Update access info for known resources. + TrackAccess(resource, &op, read_only); + } + } + // If not indirectly tracked, add edges from the unknown resource. + if (!indirectly_tracked_unknown_access) { + AddPredecessorsForAccess(kUnknownResourceId, &op, read_only); + } + if (is_unknown) { + // Update access info for unknown resource. + TrackAccess(kUnknownResourceId, &op, read_only); + } } } } -llvm::SmallVector SideEffectAnalysis::DirectControlPredecessors( +void SideEffectAnalysis::ConsumeChildAnalyses( + llvm::SmallVector&& children) { + for (auto& child : children) { + for (auto& entry : child.control_predecessors_) { + control_predecessors_[entry.getFirst()] = std::move(entry.getSecond()); + } + } +} + +llvm::SmallVector SideEffectAnalysis::DirectControlPredecessors( Operation* op, llvm::function_ref filter) const { - llvm::SmallVector result; - auto it = control_predecessors_.find(op); - if (it == control_predecessors_.end()) return result; + llvm::SmallVector result; + auto it = sorted_control_predecessors_.find(op); + if (it == sorted_control_predecessors_.end()) return result; result.reserve(it->getSecond().size()); for (auto predecessor : it->getSecond()) { if (!filter || filter(predecessor)) result.push_back(predecessor); } - llvm::sort(result, - [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); }); return result; } -llvm::SmallVector SideEffectAnalysis::DirectControlSuccessors( +llvm::SmallVector SideEffectAnalysis::DirectControlSuccessors( Operation* op, llvm::function_ref filter) const { - llvm::SmallVector result; - auto it = control_successors_.find(op); - if (it == control_successors_.end()) return result; + llvm::SmallVector result; + auto it = sorted_control_successors_.find(op); + if (it == sorted_control_successors_.end()) return result; result.reserve(it->getSecond().size()); for (auto successor : it->getSecond()) { if (!filter || filter(successor)) result.push_back(successor); } - llvm::sort(result, - [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); }); return result; } diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index 5eee28a6ae0..3d65217db27 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -32,6 +32,9 @@ namespace TF { // An analysis that runs on a function and maps each resource-type value to a // set of unique int64_t IDs representing the possible resources it could alias. +// +// If there are nested regions, each region is handled separately. This means +// cross-region aliasing cannot be checked by this analysis. class ResourceAliasAnalysis { public: explicit ResourceAliasAnalysis(Operation* op); @@ -63,8 +66,12 @@ class ResourceAliasAnalysis { // interfering with all known resource op accesses. 
It distinguishes accesses // based on whether they are read-only, and read-only ops do not interfer with // each other. +// +// If there are nested regions, each region is handled separately, and control +// dependencies are only tracked for ops under the same parent op. class SideEffectAnalysis { public: + explicit SideEffectAnalysis() = default; explicit SideEffectAnalysis(Operation* op); SideEffectAnalysis(SideEffectAnalysis&& other) = default; ~SideEffectAnalysis() = default; @@ -72,23 +79,32 @@ class SideEffectAnalysis { // Returns a vector of ops that are direct control predecessors of `op`, // sorted in program order. If `filter` is provided, only predecessors that // pass the filter (returning true) will be included. - llvm::SmallVector DirectControlPredecessors( + llvm::SmallVector DirectControlPredecessors( Operation* op, llvm::function_ref filter = nullptr) const; // Returns a vector of ops that are direct control successors of `op`, sorted // in program order. If `filter` is provided, only successors that pass the // filter (returning true) will be included. - llvm::SmallVector DirectControlSuccessors( + llvm::SmallVector DirectControlSuccessors( Operation* op, llvm::function_ref filter = nullptr) const; private: - // Runs the analysis on `func_op` and populates control_predecessors_ and - // control_successors_. + // Runs the analysis on `func_op` and populates sorted_control_predecessors_ + // and sorted_control_successors_. void AnalyzeFunction(FuncOp func_op, const ResourceAliasAnalysis& alias_analysis); + // Runs the analysis on `region` and populates control_predecessors_. + void AnalyzeRegion(Region* region, + const ResourceAliasAnalysis& alias_analysis); + + // Moves the control_predecessors_ fields in `children` analyses to this + // current analysis. + void ConsumeChildAnalyses( + llvm::SmallVector&& children); + // Updates control_predecessors_ for `op` that is being visted, on the given // `resource_id`. void AddPredecessorsForAccess(int64_t resource_id, Operation* op, @@ -98,11 +114,14 @@ class SideEffectAnalysis { void TrackAccess(int64_t resource_id, Operation* op, bool read_only); // Maps from an op to its control predecessors. - llvm::SmallDenseMap, 8> + llvm::SmallDenseMap, 8> control_predecessors_; - // Maps from an op to its control successors. - llvm::SmallDenseMap, 8> - control_successors_; + // Maps from an op to its control predecessors sorted in program order. + llvm::SmallDenseMap, 8> + sorted_control_predecessors_; + // Maps from an op to its control successors sorted in program order. + llvm::SmallDenseMap, 8> + sorted_control_successors_; // Internal per-resource data structure when we build the dependencies. struct PerResourceAcessInfo { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 20483691a92..ffba86e78ff 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -332,8 +332,7 @@ struct DropEmptyLaunch : public OpRewritePattern { if (&block.front() != &block.back()) return matchFailure(); // Map launch results to return operands. 
- llvm::SmallVector new_rets(block.front().getOperands()); - rewriter.replaceOp(op, new_rets); + rewriter.replaceOp(op, block.front().getOperands()); return matchSuccess(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index d2174255a05..5a018a39fd7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -408,8 +408,7 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) { if (!wrapped_op) return failure(); OpBuilder builder(parser.getBuilder().getContext()); builder.setInsertionPointToEnd(&block); - builder.create(wrapped_op->getLoc(), - llvm::to_vector<8>(wrapped_op->getResults())); + builder.create(wrapped_op->getLoc(), wrapped_op->getResults()); result.location = wrapped_op->getLoc(); } else if (parser.parseRegion(body, llvm::None, llvm::None)) { return failure(); @@ -1065,8 +1064,7 @@ struct DropEmptyGraph : public OpRewritePattern { if (&block.front() != &block.back()) return matchFailure(); // Map graph results to fetch operands. - llvm::SmallVector new_rets(op.GetFetch().fetches()); - rewriter.replaceOp(op, new_rets); + rewriter.replaceOp(op, op.GetFetch().fetches()); return matchSuccess(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index cdc545d5681..5b5c028c89d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -94,6 +94,8 @@ Inputs must be of same size and shape. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; + + let hasFolder = 1; } def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>, @@ -143,6 +145,8 @@ retained with length 1. ); TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ return Verify(*this); }]; } def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> { @@ -169,6 +173,8 @@ retained with length 1. ); TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ return Verify(*this); }]; } def TF_ArgMaxOp : TF_Op<"ArgMax", [NoSideEffect]> { @@ -2116,6 +2122,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_HashTableV2Op : TF_Op<"HashTableV2", []> { + let summary = "Creates a non-initialized hash table."; + + let description = [{ +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. 
+ }]; + + let arguments = (ins + StrAttr:$container, + StrAttr:$shared_name, + DefaultValuedAttr:$use_node_name_sharing, + TypeAttr:$key_dtype, + TypeAttr:$value_dtype + ); + + let results = (outs + TF_ResourceTensor:$table_handle + ); +} + def TF_IdentityNOp : TF_Op<"IdentityN", [NoSideEffect]> { let summary = [{ Returns a list of tensors with the same shapes and contents as the input @@ -2473,7 +2501,7 @@ def TF_LogicalAndOp : TF_Op<"LogicalAnd", [Broadcastable, Commutative, NoSideEff } def TF_LogicalNotOp : TF_Op<"LogicalNot", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Returns the truth value of NOT x element-wise."; + let summary = "Returns the truth value of `NOT x` element-wise."; let description = [{ }]; @@ -4334,6 +4362,37 @@ Resize `images` to `size` using nearest neighbor interpolation. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> { + let summary = "Update '*var' according to the Adam algorithm."; + + let description = [{ +$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ +$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ +$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ +$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + }]; + + let arguments = (ins + TF_ResourceTensor:$var, + TF_ResourceTensor:$m, + TF_ResourceTensor:$v, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2_power, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$use_nesterov + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>; +} + def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> { let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it."; @@ -4353,6 +4412,34 @@ def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", [] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } +def TF_ResourceApplyKerasMomentumOp : TF_Op<"ResourceApplyKerasMomentum", []> { + let summary = [{ +Update '*var' according to the momentum scheme. + }]; + + let description = [{ +Set use_nesterov = True if you want to use Nesterov momentum. 
+ +accum = accum * momentum - lr * grad +var += accum + }]; + + let arguments = (ins + TF_ResourceTensor:$var, + TF_ResourceTensor:$accum, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, + + DefaultValuedAttr:$use_locking, + DefaultValuedAttr:$use_nesterov + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + def TF_ReverseSequenceOp : TF_Op<"ReverseSequence", [NoSideEffect]> { let summary = "Reverses variable length slices."; @@ -5117,6 +5204,8 @@ def TF_SplitVOp : TF_Op<"SplitV", [NoSideEffect]> { TF_DerivedOperandTypeAttr Tlen = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedResultSizeAttr num_split = TF_DerivedResultSizeAttr<0>; + + let verifier = [{ return Verify(*this); }]; } def TF_SqrtOp : TF_Op<"Sqrt", [NoSideEffect, SameOperandsAndResultType]> { @@ -5491,6 +5580,65 @@ output. For the internal use of the distributed TPU compiler. TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>; } +def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> { + let summary = "Connects N inputs to an N-way replicated TPU computation."; + + let description = [{ +This operation holds a replicated input to a `tpu.replicate()` computation subgraph. +Each replicated input has the same shape and type alongside the output. + +For example: +``` +%a = "tf.opA"() +%b = "tf.opB"() +%replicated_input = "tf.TPUReplicatedInput"(%a, %b) +%computation = "tf.Computation"(%replicated_input) +``` +The above computation has a replicated input of two replicas. + }]; + + let arguments = (ins + Variadic:$inputs, + + DefaultValuedAttr:$is_mirrored_variable, + DefaultValuedAttr:$index + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; +} + +def TF_TPUReplicatedOutputOp : TF_Op<"TPUReplicatedOutput", [NoSideEffect]> { + let summary = "Connects N outputs from an N-way replicated TPU computation."; + + let description = [{ +This operation holds a replicated output from a `tpu.replicate()` computation subgraph. +Each replicated output has the same shape and type alongside the input. + +For example: +``` +%computation = "tf.Computation"() +%replicated_output:2 = "tf.TPUReplicatedOutput"(%computation) +``` +The above computation has a replicated output of two replicas. + }]; + + let arguments = (ins + TF_Tensor:$input + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultSizeAttr num_replicas = TF_DerivedResultSizeAttr<0>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes hyperbolic tangent of `x` element-wise."; @@ -5905,6 +6053,8 @@ This is the opposite of `pack`. 
TF_DerivedResultSizeAttr num = TF_DerivedResultSizeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let verifier = [{ return Verify(*this); }]; } def TF_VariableShapeOp : TF_Op<"VariableShape", []> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 1bd9accbb78..9d2f634161c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -301,6 +301,15 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +//===----------------------------------------------------------------------===// +// AddNOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddNOp::fold(ArrayRef operands) { + if (operands.size() == 1) return *inputs().begin(); + return {}; +} + //===----------------------------------------------------------------------===// // AddV2Op //===----------------------------------------------------------------------===// @@ -310,6 +319,49 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +//===----------------------------------------------------------------------===// +// AllOp +//===----------------------------------------------------------------------===// + +// Verifies an reduction op's `input` and reduction `dims`. +static LogicalResult VerifyReductionInputAndDims(Value *input, Value *dims, + Location loc) { + auto dims_type = dims->getType().dyn_cast(); + if (!dims_type) return success(); + if (dims_type.getRank() > 1) + return emitError(loc, "dimensions can only be 0D or 1D tensor"); + + auto input_type = input->getType().dyn_cast(); + if (!input_type) return success(); + int64_t rank = input_type.getRank(); + + DenseIntElementsAttr dims_attr; + if (!matchPattern(dims, m_Constant(&dims_attr))) return success(); + for (const auto &dim_pair : llvm::enumerate(dims_attr)) { + int64_t cur_dim = dim_pair.value().getSExtValue(); + if (cur_dim < -rank || cur_dim >= rank) + return emitError(loc) + << dim_pair.index() << "-th dimension should be in the range of [-" + << rank << ", " << rank << ")"; + } + + return success(); +} + +static LogicalResult Verify(AllOp op) { + return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), + op.getLoc()); +} + +//===----------------------------------------------------------------------===// +// AnyOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AnyOp op) { + return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), + op.getLoc()); +} + //===----------------------------------------------------------------------===// // AssertOp //===----------------------------------------------------------------------===// @@ -1542,17 +1594,23 @@ static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { // SplitOp //===----------------------------------------------------------------------===// -static LogicalResult Verify(SplitOp op) { +// Verifies the input and split dimension operands for tf.Split/tf.SplitV. +// Writes the split dimension's index (adjusted with input rank) via `dim_index` +// if it's a constant. 
+template +LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { + *dim_index = llvm::None; + Value *split_dim = op.split_dim(); - auto split_dim_type = split_dim->getType().dyn_cast(); - if (!split_dim_type) return success(); - if (split_dim_type.getRank() != 0) - return op.emitOpError("split dimension should be an integer scalar tensor"); + if (auto split_dim_type = split_dim->getType().dyn_cast()) + if (split_dim_type.getRank() != 0) + return op.emitOpError( + "split dimension should be an integer scalar tensor"); // We can perform further verification if the input tensor to be split has // known rank and the split dimension tensor is a constant. - auto input_type = op.value()->getType().dyn_cast(); + auto input_type = op.value()->getType().template dyn_cast(); if (!input_type) return success(); int64_t input_rank = input_type.getRank(); @@ -1562,21 +1620,95 @@ static LogicalResult Verify(SplitOp op) { DenseIntElementsAttr split_dim_attr; if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success(); - int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); + int64_t index = (*split_dim_attr.begin()).getSExtValue(); - if (dim_index + input_rank < 0 || dim_index >= input_rank) { + if (index + input_rank < 0 || index >= input_rank) { return op.emitOpError("split dimension must be in range [-") << input_rank << ", " << input_rank << ")"; } - if (dim_index < 0) dim_index += input_rank; + if (index < 0) index += input_rank; + *dim_index = index; - int64_t input_dim_size = input_type.getDimSize(dim_index); - if (input_dim_size < 0) return success(); + return success(); +} + +static LogicalResult Verify(SplitOp op) { + Optional dim_index; + if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); + if (!dim_index) return success(); + + int64_t input_dim_size = + op.value()->getType().cast().getDimSize(*dim_index); + if (input_dim_size == ShapedType::kDynamicSize) return success(); if (input_dim_size % op.getNumResults() != 0) return op.emitOpError("dimension #") - << dim_index << " not divisible by the number of result tensors"; + << *dim_index << " not divisible by the number of result tensors"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SplitVOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SplitVOp op) { + auto split_sizes_type = + op.size_splits()->getType().dyn_cast(); + if (!split_sizes_type) return success(); + + if (split_sizes_type.getRank() != 1 || + split_sizes_type.getDimSize(0) != op.getNumResults()) + return op.emitOpError("split sizes should be a 1D tensor of ") + << op.getNumResults() << " elements"; + + Optional dim_index = 0; + if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); + if (!dim_index) return success(); + + int64_t input_dim_size = + op.value()->getType().cast().getDimSize(*dim_index); + if (input_dim_size == ShapedType::kDynamicSize) return success(); + + // If split sizes come from a constant, they must sum to the dimension size + // along split_dim, and we can have no more than one dynamic dimension. 
+ DenseIntElementsAttr split_sizes_attr; + if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) + return success(); + + int64_t total_dim_size = 0; // Total dimension size assigned to splits + llvm::Optional dynamic_dim_index; + + SmallVector split_sizes; + split_sizes.reserve( + split_sizes_attr.getType().cast().getNumElements()); + + for (auto dim : llvm::enumerate(split_sizes_attr)) { + int64_t dim_val = dim.value().getSExtValue(); + split_sizes.push_back(dim_val); + if (dim_val == ShapedType::kDynamicSize) { + // We cannot have more than one dynamic dimension. + if (dynamic_dim_index) + return op.emitOpError( + "cannot have more than one dynamic dimension in split sizes"); + dynamic_dim_index = dim.index(); + } else { + total_dim_size += dim_val; + } + } + + if (!dynamic_dim_index && total_dim_size != input_dim_size) + return op.emitOpError( + "split sizes must sum up to the dimension size along split " + "dimension, found ") + << total_dim_size << " vs " << input_dim_size; + + if (dynamic_dim_index && total_dim_size > input_dim_size) + return op.emitOpError( + "split sizes must sum up to be less than or equal to the " + "dimension size along split dimension, found ") + << total_dim_size << " vs " << input_dim_size; return success(); } @@ -1787,6 +1919,30 @@ void TruncateDivOp::getCanonicalizationPatterns( results.insert(context); } +//===----------------------------------------------------------------------===// +// UnpackOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(UnpackOp op) { + auto value_type = op.value()->getType().dyn_cast(); + if (!value_type) return success(); + + int64_t value_rank = value_type.getRank(); + int64_t axis = op.axis().getSExtValue(); + if (axis < -value_rank || axis >= value_rank) + return op.emitOpError("axis attribute must be in the range of [-") + << value_rank << ", " << value_rank << ')'; + + axis = GetDimForAxis(axis, value_rank); + int64_t dim_size = value_type.getDimSize(axis); + if (ShapedType::isDynamic(dim_size)) return success(); + + if (dim_size != op.getNumResults()) + return op.emitOpError("result count must be equal to ") << dim_size; + + return success(); +} + //===----------------------------------------------------------------------===// // VariableShapeOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 8d975e909bb..9b6196cda5b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -196,6 +196,42 @@ retained with length 1. TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } +def TF_LegacyCallOp : TF_Op<"LegacyCall", + [CallOpInterface, NoSideEffect]> { + let summary = + "returns `f(inputs)`, where `f` is a function."; + + let description = [{ + The LegacyCall operation represents a direct call to a function that is + within the same symbol scope as the call and is mapped to a GraphDef node + with the function name as the op name. Unlike a PartitionedCall which + represents asynchronously executing a function across multiple devices, a + LegacyCall represents a function call with the only attribute + _diable_call_shape_inference. 
+ }]; + + let arguments = (ins + Variadic:$args, + + FlatSymbolRefAttr:$f, + DefaultValuedAttr:$_disable_call_shape_inference + ); + + let results = (outs + Variadic:$output + ); + + let extraClassDeclaration = [{ + // Gets the argument operands to the called function. + operand_range getArgOperands() { return args(); } + + // Returns the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return getAttrOfType("f"); + } + }]; +} + def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [CallOpInterface, NoSideEffect]> { let summary = diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir index 67c3982fe3b..d5a5c16cbff 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir @@ -18,7 +18,7 @@ func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32 // CHECK-LABEL: func @multiple_return // CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph { // CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1) -// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1) +// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1) // CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]] : // CHECK: } // CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1 @@ -41,7 +41,12 @@ func @multiple_islands(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi3 %res = "tf.Print"(%sub) { message = "sub result" } : (tensor<*xi32>) -> (tensor<*xi32>) tf_executor.yield } - tf_executor.fetch %island1#1, %island2#1, %island3 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control + %island4 = tf_executor.island(%island1#2, %island2#2) { + %add = "tf.Add"(%island1#1, %island1#1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %res = "tf.Print"(%add) { message = "add result" } : (tensor<*xi32>) -> (tensor<*xi32>) + tf_executor.yield + } + tf_executor.fetch %island1#1, %island2#1, %island3, %island4 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control, !tf_executor.control } return %graph#0, %graph#1 : tensor<*xi32>, tensor<*xi32> } @@ -49,12 +54,17 @@ func @multiple_islands(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi3 // CHECK-LABEL: func @multiple_islands // CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph { // CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1) -// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1) +// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1) // CHECK: %[[SUB1:.*]], %[[SUB1_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Sub"(%arg0, %arg1) -// CHECK: %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island(%[[SUB1_control]]) wraps "tf.Mul"(%[[SUB1]], %arg1) +// CHECK: %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island wraps "tf.Mul"(%[[SUB1]], %arg1) // CHECK: %[[SUB2:.*]], %[[SUB2_control:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) wraps "tf.Sub"(%[[ADD1]], %[[SUB1]]) -// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island(%[[SUB2_control]]) wraps "tf.Print"(%[[SUB2]]) {message = "sub result"} -// CHECK: tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT_control]] : +// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[SUB2]]) {message = "sub result"} 
+// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) { +// CHECK: tf_executor.yield +// CHECK: } +// CHECK: %[[ADD3:.*]], %[[ADD3_control:.*]] = tf_executor.island(%[[ISLAND1]], %[[ADD2_control]]) wraps "tf.Add"(%[[ADD2]], %[[ADD2]]) +// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD3]]) {message = "add result"} +// CHECK: tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT1_control]], %[[PRINT2_control:.*]] : // CHECK: } // CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1 @@ -74,8 +84,8 @@ func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32> // CHECK-LABEL: func @dangling_print // CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph { // CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1) -// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1_control:.*]], %arg1) -// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"} +// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1_control:.*]], %arg1) +// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"} // CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]], %[[PRINT_control]] : // CHECK: } // CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1 @@ -103,11 +113,14 @@ func @switch_and_merge(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi3 // CHECK-LABEL: func @switch_and_merge(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32>, tensor) { // CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph { // CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1) -// CHECK: %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Less"(%arg1, %arg1) -// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island(%[[LESS_control]]) wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"} -// CHECK: %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[PRINT1_control]] +// CHECK: %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island wraps "tf.Less"(%arg1, %arg1) +// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"} +// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[LESS_control]], %[[PRINT1_control]]) { +// CHECK: tf_executor.yield +// CHECK: } +// CHECK: %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[ISLAND1]] // CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[SWITCH_false]], %arg1) -// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"} +// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"} // CHECK: %[[MERGE:.*]], %[[MERGE_index:.*]], %{{.*}} = tf_executor.Merge %[[ADD2]], %[[SWITCH_true]], %[[PRINT2_control]] // CHECK: tf_executor.fetch %[[MERGE]], %[[MERGE_index]] // CHECK: } @@ -130,7 +143,7 @@ func @control_flow_plumbing(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor< // CHECK: %[[GRAPH:.*]] = tf_executor.graph { // CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%arg0) {message = "Random Print"} // CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = 
tf_executor.island(%[[PRINT_control]]) wraps "tf.Add"(%arg0, %arg1) -// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1) +// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1) // CHECK: tf_executor.fetch %[[ADD2]] : tensor<*xi32> // CHECK: } // CHECK: return %[[GRAPH]] : tensor<*xi32> @@ -150,6 +163,77 @@ func @fetching_arg(%arg0: tensor<*xi32>) { // CHECK-LABEL: func @fetching_arg // CHECK: tf_executor.graph { // CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg0) -// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg0) +// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg0) // CHECK: tf_executor.fetch %[[ADD2_control]] : !tf_executor.control // CHECK: } + +func @non_aliasing_reads_writes( + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor<32xf32>) -> (tensor<32xf32>) { + %graph = tf_executor.graph { + %island:2 = tf_executor.island { + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + %read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + %read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + tf_executor.yield %read3 : tensor<32xf32> + } + tf_executor.fetch %island#0 : tensor<32xf32> + } + return %graph : tensor<32xf32> +} + +// CHECK-LABEL: func @non_aliasing_reads_writes +// CHECK: %[[GRAPH:.*]] = tf_executor.graph { +// CHECK: %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) +// CHECK: %[[ASSIGN0_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %arg2) +// CHECK: %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg1) +// CHECK: %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"} +// CHECK: %[[READ2:.*]], %[[READ2_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]]) +// CHECK: %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%arg1, %[[READ0:.*]]) +// CHECK: %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[ASSIGN0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %[[READ2]]) +// CHECK: %[[READ3:.*]], %[[READ3_CONTROL:.*]] = tf_executor.island(%[[ASSIGN2_CONTROL]]) wraps "tf.ReadVariableOp"(%arg0) +// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ASSIGN1_CONTROL]], %[[READ3_CONTROL]]) { +// CHECK: tf_executor.yield +// CHECK: } +// CHECK: tf_executor.fetch %[[READ3]], %[[ISLAND1]] : tensor<32xf32>, !tf_executor.control +// CHECK: } + +func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () { + tf_executor.graph { + %island = tf_executor.island { + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> 
tensor<*x!tf.resource>> + %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + "tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + "tf._UnknownSideEffectingOp_"() : () -> () + %read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + "tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// CHECK-LABEL: func @unknown_side_effecting_op +// CHECK: tf_executor.graph { +// CHECK: %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"} +// CHECK: %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v1"} +// CHECK: %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]]) +// CHECK: %[[ASSIGN0_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VH1]], %arg0) +// CHECK: %[[UNKNOWN_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]], %[[ASSIGN0_CONTROL]]) wraps "tf._UnknownSideEffectingOp_"() +// CHECK: %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island(%[[UNKNOWN_CONTROL]]) wraps "tf.ReadVariableOp"(%[[VH1]]) +// CHECK: %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[UNKNOWN_CONTROL]]) wraps "tf.AssignVariableOp"(%[[VH0]], %[[READ1]]) +// CHECK: %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%[[VH1]], %[[READ0]]) +// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ASSIGN1_CONTROL]], %[[ASSIGN2_CONTROL]]) { +// CHECK: tf_executor.yield +// CHECK: } +// CHECK: tf_executor.fetch %[[ISLAND1]] : !tf_executor.control +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index a2cc33a8201..18c63912a86 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -382,3 +382,10 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32 // CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x6x5xf32> // CHECK: return %1 } + +// CHECK-LABEL: func @addN +func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: return %arg0 + %0 = "tf.AddN"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir new file mode 100644 index 00000000000..67d58b41199 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir @@ -0,0 +1,67 @@ +// RUN: tf-opt %s -split-input-file -tf-device-decompose-resource-ops | FileCheck %s + +// ----- + +// Tests that composite tf.AssignAddVariableOp operation is decomposed and +// hoisted. 
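+// As a rough sketch (assuming the resource holds a scalar i32; the value
+// names are illustrative), the expected rewrite turns
+//   "tf.AssignAddVariableOp"(%resource, %one)
+// into a read-modify-write sequence of plain ops:
+//   %old = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf.resource>) -> tensor<i32>
+//   %new = "tf.AddV2"(%old, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+//   "tf.AssignVariableOp"(%resource, %new) : (tensor<*x!tf.resource>, tensor<i32>) -> ()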
+ +// CHECK-LABEL: func @decompose_assign_add_variable_op +func @decompose_assign_add_variable_op() -> () { + + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + + // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor} + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp" + // CHECK: "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]]) + // CHECK: "tf.AssignVariableOp" + + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor) -> () + + return +} + +// ----- + +// Tests that composite tf.AssignSubVariableOp operation is decomposed using +// SubOp. + +// CHECK-LABEL: func @decompose_assign_sub_variable_op +func @decompose_assign_sub_variable_op() -> () { + + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + + // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor} + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp" + // CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]]) + // CHECK: "tf.AssignVariableOp" + + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor) -> () + + return +} + +// ----- + +// Tests that composite tf.ResourceApplyGradientDescent operation is decomposed. + +// CHECK-LABEL: func @decompose_resource_apply_gradient_descent +// CHECK-SAME: (%[[DELTA:.*]]: tensor) +func @decompose_resource_apply_gradient_descent(%arg0: tensor) -> () { + + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> + + // CHECK: %[[ALPHA:[0-9]*]] = "tf.Const" + // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" + // CHECK: %[[MUL:[0-9]*]] = "tf.Mul"(%[[DELTA]], %[[ALPHA]]) + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) + // CHECK: %[[SUB:[0-9]*]] = "tf.Sub"(%[[RES_READ_VAL]], %[[MUL]]) + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[SUB]]) + + %1 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor + "tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<*x!tf.resource>, tensor, tensor) -> () + + return +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir index 2a0434b69e0..a0390ec8738 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir @@ -49,40 +49,33 @@ func @testIf3Result(tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, // ----- -func @testIf1Then(tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -func @testIf1Else(tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32> +func @testIfThen(%arg0: tensor) -> tensor { + return %arg0 : tensor +} +func @testIfElse(%arg0: tensor) -> tensor { + return %arg0 : tensor +} -// CHECK-LABEL: func @testIf1Casts(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>) -func @testIf1Casts(tensor, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32> { -^bb0(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>): - - %1 = "tf.If"(%arg0, %arg1, %arg2) { - then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false - } : (tensor, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32> - -// CHECK: %0 = extract_element %arg0[] : tensor -// CHECK: cond_br %0, ^bb1, 
^bb2 -// CHECK:^bb1: // pred: ^bb0 -// CHECK: %1 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<2x?xf32> -// CHECK: %2 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x2xf32> -// CHECK: %3 = call @testIf1Then(%1, %2) : (tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// CHECK: %4 = tensor_cast %3 : tensor<2x2xf32> to tensor<2x?xf32> -// CHECK: br ^bb3(%4 : tensor<2x?xf32>) - -// CHECK:^bb2: // pred: ^bb0 -// CHECK: %5 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<*xf32> -// CHECK: %6 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x?xf32> -// CHECK: %7 = call @testIf1Else(%5, %6) : (tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32> -// CHECK: %8 = tensor_cast %7 : tensor<*xf32> to tensor<2x?xf32> -// CHECK: br ^bb3(%8 : tensor<2x?xf32>) - -// CHECK:^bb3(%9: tensor<2x?xf32>): // 2 preds: ^bb1, ^bb2 - - %2 = "tf.Add"(%1, %1) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32> -// CHECK: %10 = "tf.Add"(%9, %9) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32> - - return %2 : tensor<2x?xf32> -// CHECK: return %10 : tensor<2x?xf32> +// CHECK-LABEL: func @testIfCasts(%arg0: tensor, %arg1: tensor>>) -> tensor>> +func @testIfCasts(%arg0: tensor, %arg1: tensor>>) -> tensor>> { + %0 = "tf.If"(%arg0, %arg1) { + then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false + } : (tensor, tensor>>) -> tensor>> + return %0: tensor>> +// CHECK: %0 = extract_element %arg0[] : tensor +// CHECK: cond_br %0, ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor>>) -> tensor +// CHECK: %2 = call @testIfThen(%1) : (tensor) -> tensor +// CHECK: %3 = "tf.Cast"(%2) {Truncate = false} : (tensor) -> tensor>> +// CHECK: br ^bb3(%3 : tensor>>) +// CHECK: ^bb2: +// CHECK: %4 = "tf.Cast"(%arg1) {Truncate = false} : (tensor>>) -> tensor +// CHECK: %5 = call @testIfElse(%4) : (tensor) -> tensor +// CHECK: %6 = "tf.Cast"(%5) {Truncate = false} : (tensor) -> tensor>> +// CHECK: br ^bb3(%6 : tensor>>) +// CHECK: ^bb3(%7: tensor>>): +// CHECK: return %7 : tensor>> } // ----- @@ -188,31 +181,36 @@ func @testComplexWhile1Result(tensor<*xf32>) -> (tensor<*xf32>) { // ----- -func @testWhileCond(tensor) -> (tensor) -func @testWhileBody(tensor<*xf32>) -> (tensor) +func @testWhileCond(%arg0: tensor) -> (tensor) { + %true = "tf.Const"() { value = dense : tensor } : () -> (tensor) + return %true : tensor +} +func @testWhileBody(%arg0: tensor>>) -> (tensor>>) { + %0 = "tf.Cast"(%arg0) : (tensor>>) -> tensor>> + return %0 : tensor>> +} -// CHECK-LABEL: func @testWhileCasts(%arg0: tensor<1x3xf32>) -func @testWhileCasts(%arg0: tensor<1x3xf32>) -> (tensor) { +// CHECK-LABEL: func @testWhileCasts(%arg0: tensor>>) -> tensor>> +func @testWhileCasts(%arg0: tensor>>) -> (tensor>>) { %0 = "tf.While"(%arg0) { cond = @testWhileCond, body = @testWhileBody, is_stateless = false - } : (tensor<1x3xf32>) -> (tensor) - -// CHECK: %0 = tensor_cast %arg0 : tensor<1x3xf32> to tensor -// CHECK: br ^bb1(%0 : tensor) -// CHECK: ^bb1(%1: tensor): -// CHECK: %2 = call @testWhileCond(%1) : (tensor) -> tensor + } : (tensor>>) -> (tensor>>) + return %0 : tensor>> +// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor>>) -> tensor +// CHECK: br ^bb1(%0 : tensor) +// CHECK: ^bb1(%1: tensor): // 2 preds: ^bb0, ^bb2 +// CHECK: %2 = call @testWhileCond(%1) : (tensor) -> tensor // CHECK: %3 = extract_element %2[] : tensor -// CHECK: %4 = tensor_cast %1 : tensor to tensor<*xf32> -// CHECK: cond_br %3, ^bb2(%4 : tensor<*xf32>), ^bb3(%4 : tensor<*xf32>) -// CHECK: ^bb2(%5: tensor<*xf32>): -// CHECK: 
%6 = call @testWhileBody(%5) : (tensor<*xf32>) -> tensor -// CHECK: %7 = tensor_cast %6 : tensor to tensor -// CHECK: br ^bb1(%7 : tensor) -// CHECK: ^bb3(%8: tensor<*xf32>): -// CHECK: %9 = tensor_cast %8 : tensor<*xf32> to tensor +// CHECK: %4 = "tf.Cast"(%1) {Truncate = false} : (tensor) -> tensor>> +// CHECK: cond_br %3, ^bb2(%4 : tensor>>), ^bb3(%4 : tensor>>) +// CHECK: ^bb2(%5: tensor>>): // pred: ^bb1 +// CHECK: %6 = call @testWhileBody(%5) : (tensor>>) -> tensor>> +// CHECK: %7 = "tf.Cast"(%6) {Truncate = false} : (tensor>>) -> tensor +// CHECK: br ^bb1(%7 : tensor) +// CHECK: ^bb3(%8: tensor>>): // pred: ^bb1 +// CHECK: %9 = "tf.Cast"(%8) {Truncate = false} : (tensor>>) -> tensor>> +// CHECK: return %9 : tensor>> - return %0 : tensor -// CHECK: return %9 : tensor } // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt index 9ce15315832..207d6676f61 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt @@ -54,5 +54,5 @@ versions { # the names are matching between the function definition and the uses / call # site (a numerical suffix may be appended). -# CHECK: "tf.foo0"( +# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, f = @foo0} # CHECK: func @foo0 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt index b26d7e7f2ba..ac248041994 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt @@ -8,7 +8,7 @@ # Verify that we can also pull some attributes that are needed to be able to # create a Graph in memory, like `T`. 
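+# Note: dtype-valued attributes are imported as MLIR type attributes, so the
+# MaxPool's T attribute is expected to appear as T = f32 rather than the
+# legacy string form T = "tfdtype$DT_FLOAT".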
# CHECK: tf.MaxPool -# CHECK-SAME: T = "tfdtype$DT_FLOAT" +# CHECK-SAME: T = f32 node { name: "input" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt new file mode 100644 index 00000000000..f0a7a574ae3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt @@ -0,0 +1,65 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=x -tf-input-data-types=DT_INT32 -tf-input-shapes=10 -tf-output-arrays=func_call -o - | FileCheck %s + +node { + name: "x" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "func_call" + op: "test_func_name" + input: "x" + attr { + key: "_disable_call_shape_inference" + value { + b: true + } + } +} +library { + function { + signature { + name: "test_func_name" + input_arg { + name: "a_0" + type: DT_INT32 + } + output_arg { + name: "a" + type: DT_INT32 + } + } + ret { + key: "a" + value: "a_0" + } + attr { + key: "_disable_call_shape_inference" + value { + b: true + } + } + } +} + +# CHECK: func @main +# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, f = @test_func_name0} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt index dcdbe67ccb6..563007f4305 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt @@ -121,8 +121,8 @@ versions { # Verify that functions from the library are properly imported. # CHECK-LABEL: func @main() { -# CHECK: "tf.foo110"() -# CHECK: "tf.foo111"() +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110} +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111} # CHECK-LABEL: func @foo110() { # CHECK-LABEL: func @foo111() { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt index 17b2655aa5d..b65984227f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt @@ -39,10 +39,10 @@ versions { # Verify that functions from the library are properly imported. 
# CHECK-LABEL: func @main() { -# CHECK: "tf.foo0"() -# CHECK: "tf.bar0"() +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo0} +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0} # CHECK-LABEL: func @foo0() { -# CHECK: "tf.bar0"() +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0} # CHECK-LABEL: func @bar0() { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt new file mode 100644 index 00000000000..b28e2818730 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt @@ -0,0 +1,300 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:0,a:0 -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail + +# Generated in Python via +# ``` +# import tensorflow as tf +# +# with tf.compat.v1.Graph().as_default() as g: +# w = tf.constant(2.0) +# x = tf.constant(3.0) +# y = tf.constant(4.0) +# var = tf.Variable(2.0) +# var_add = var.assign_add(3.0) +# with g.control_dependencies([var_add]): +# z0, z1, z2 = tf.identity_n((w, x, y)) +# +# a = tf.add(z1, z2) +# ``` + +node { + name: "w" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 2.0 + } + } + } +} +node { + name: "x" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 3.0 + } + } + } +} +node { + name: "y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 4.0 + } + } + } +} +node { + name: "var/initial_value" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 2.0 + } + } + } +} +node { + name: "var" + op: "VariableV2" + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "var/Assign" + op: "Assign" + input: "var" + input: "var/initial_value" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@var" + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } +} +node { + name: "var/read" + op: "Identity" + input: "var" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@var" + } + } + } +} +node { + name: "var_add/value" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 3.0 + } + } + } +} +node { + 
name: "var_add" + op: "AssignAdd" + input: "var" + input: "var_add/value" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@var" + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "z" + op: "IdentityN" + input: "w" + input: "x" + input: "y" + input: "^var_add" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + } + } + } +} +node { + name: "a" + op: "Add" + input: "z:1" + input: "z:2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +versions { + producer: 230 +} + +# Test non zero index output tensors as feeds. Original ops where their outputs +# are replaced with feeds are preserved and args and rets are lifted to the +# function. Rets that happen to coincide with a feed should have its value be +# of the feed. +# +# CHECK: func @main(%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor, tensor) +# CHECK: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}} +# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd" +# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN" +# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]]) +# CHECK: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]] + +# Test when non zero index output tensors are feeds, remaining ops that are +# unreachable are pruned if pruning is enabled. +# +# PRUNE: func @main(%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor, tensor) +# PRUNE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}} +# PRUNE-NOT: "tf.Const" +# PRUNE-NOT: "tf.VariableV2" +# PRUNE-NOT: "tf.Assign" +# PRUNE-NOT: "tf.Identity" +# PRUNE-NOT: "tf.AssignAdd" +# PRUNE-NOT: "tf.IdentityN" +# PRUNE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]]) +# PRUNE: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]] + +# Test when non zero index output tensors are feeds, remaining ops that are +# unreachable are preserved if pruning is not enabled. 
+# +# PRESERVE: func @main(%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor) +# PRESERVE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:0,a:0"}} +# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd" +# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN" +# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]]) +# PRESERVE: tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt index d33ac2f3b5b..3dd5ce58ed2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt @@ -2,11 +2,11 @@ # CHECK: tf_executor.SwitchN # CHECK-SAME: of 3 : tensor -# CHECK-SAME: T = "tfdtype$DT_INT32" +# CHECK-SAME: T = i32 # CHECK-SAME: name = "Case/branch_index/_3" # CHECK: tf_executor.SwitchN # CHECK-SAME: of 2 : tensor -# CHECK-SAME: T = "tfdtype$DT_FLOAT" +# CHECK-SAME: T = f32 # CHECK-SAME: name = "Case/Case/input_0/_7" node { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 120e73f6e94..60ffc924ae5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -250,3 +250,19 @@ func @ZerosLike_variant(%arg0: tensor>>) -> tensor>>) -> tensor>> return %0 : tensor>> } + +// CHECK-LABEL: func @addN +func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) + // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%[[SUM0]], %arg2) + // return %[[SUM1]] + %0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @addN_variant +func @addN_variant(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>) -> tensor>> { + // CHECK: tf.AddN + %0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor>>, tensor>>, tensor>>) -> tensor>> + return %0 : tensor>> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir new file mode 100644 index 00000000000..6c83b45295e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir @@ -0,0 +1,26 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main() { + tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Constant", value = dense<0> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs) {f = @foo0} : (tensor) -> tensor + tf_executor.fetch + } + return +} +func @foo0(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = tf_executor.graph { + tf_executor.fetch %arg0 : tensor<*xi32> + } + return %0 : tensor<*xi32> +} + +// CHECK: node { +// CHECK: name: "_tf.LegacyCall" +// CHECK-NEXT: op: "foo0" + +// CHECK: library { +// CHECK-NEXT: function { +// CHECK-NEXT: signature { +// CHECK-NEXT: name: "foo0" + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir new file mode 100644 index 
00000000000..c98e40fed05 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir @@ -0,0 +1,244 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-resource-device-inference %s | FileCheck %s --dump-input=fail + +// Tests that the pass can correctly propagate device attributes inside the same +// function. + +// CHECK-LABEL: func @propagate_in_function +func @propagate_in_function( + %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"}, + %arg1: tensor<*x!tf.resource>> {tf.device = "/TPU:1"}) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/CPU:0"} + : () -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id1 = "tf.Identity"(%id0) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/CPU:0"} + %id2 = "tf.Identity"(%var_handle) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id2) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + %id3 = "tf.Identity"(%read) : (tensor<32xf32>) -> tensor<32xf32> + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// ----- + +// Tesets that the pass can propagate through tf.If's branches. + +// CHECK-LABEL: func @propagate_if_op +func @propagate_if_op( + %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"}, + %arg1: tensor) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} + : () -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.If" + "tf.If"(%arg1, %id0, %var_handle) { + then_branch = @if_then, + else_branch = @if_else, + output_shapes = [], is_stateless = false} + : (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) -> () + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// CHECK-LABEL: func @if_then +func @if_then( + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// CHECK-LABEL: func @if_else +func @if_else( + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + + +// ----- + +// Tesets that the pass can 
propagate through tf.While's branches. + +// CHECK-LABEL: func @propagate_while_op +func @propagate_while_op( + %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"}, + %arg1: tensor) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} + : () -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.While" + "tf.While"(%arg1, %id0, %var_handle) { + body = @while_body, + cond = @while_cond, + output_shapes = [], is_stateless = false} + : (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) -> + (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// CHECK-LABEL: func @while_body +func @while_body( + %arg0: tensor, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor<*x!tf.resource>>) -> + (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) { + %graph:3 = tf_executor.graph { + // CHECK: tf_executor.island + %island:4 = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id1 = "tf.Identity"(%arg2) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + tf_executor.yield %arg0, %id0, %id1 + : tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + tf_executor.fetch %island#0, %island#1, %island#2 + : tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + return %graph#0, %graph#1, %graph#2 + : tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> +} + +// CHECK-LABEL: func @while_cond +func @while_cond( + %arg0: tensor, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor<*x!tf.resource>>) -> tensor<32xf32> { + %graph = tf_executor.graph { + // CHECK: tf_executor.island + %island:2 = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id0) + : (tensor<*x!tf.resource>>) -> tensor<32xf32> + tf_executor.yield %read : tensor<32xf32> + } + tf_executor.fetch %island#0 : tensor<32xf32> + } + return %graph : tensor<32xf32> +} + +// ----- + +// Tesets that the pass reports error on conflicting assignments from multiple +// callers. 
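+// In the test below, the same branch function @if_then_and_else is reached
+// from two tf.If ops whose resource operands are swapped, so one call site
+// implies /TPU:0 for an argument while the other implies /TPU:1; the pass is
+// expected to reject this with a "Conflicting device assignment" error.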
+ +func @error_on_conflict_multiple_callers( + %arg0: tensor<*x!tf.resource>> {tf.device = "/TPU:0"}, + %arg1: tensor) { + tf_executor.graph { + %island = tf_executor.island { + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} + : () -> tensor<*x!tf.resource>> + "tf.If"(%arg1, %id0, %var_handle) { + then_branch = @if_then_and_else, + else_branch = @if_then_and_else, + output_shapes = [], is_stateless = false} + : (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) -> () + "tf.If"(%arg1, %var_handle, %id0) { + // expected-error@above {{Conflicting device assignment for resource}} + then_branch = @if_then_and_else, + else_branch = @if_then_and_else, + output_shapes = [], is_stateless = false} + : (tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>) -> () + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +func @if_then_and_else( + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>) { + tf_executor.graph { + %island = tf_executor.island { + %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + %id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) + -> tensor<*x!tf.resource>> + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 8ff72dbc7fc..e5905e5f681 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -8,7 +8,7 @@ func @only_resource_load() -> tensor<*xi32> { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"} + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} // CHECK: "tf_device.launch" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]] @@ -16,7 +16,7 @@ func @only_resource_load() -> tensor<*xi32> { // CHECK-SAME: () -> tensor<*xi32> %1 = "tf_device.launch"() ( { - %2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32> + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) tf_device.return %3 : tensor<*xi32> }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> @@ -39,11 +39,11 @@ func @only_resource_store() -> tensor<*xi32> { // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] // CHECK: {device = "tpu0", launch_attr = "launch_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32} %1 = "tf_device.launch"() ( { %2 = "tf.SomeComputation"() : () -> (tensor<*xi32>) - "tf.AssignVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () + "tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () tf_device.return %2 : tensor<*xi32> }) {device = 
"tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> @@ -61,18 +61,18 @@ func @same_resource_load_and_store() -> tensor<*xi32> { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"} + // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] // CHECK: {device = "tpu0", launch_attr = "launch_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32} %1 = "tf_device.launch"() ( { - %2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32> + %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) - "tf.AssignVariableOp"(%0, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () + "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () tf_device.return %3 : tensor<*xi32> }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> @@ -82,96 +82,6 @@ func @same_resource_load_and_store() -> tensor<*xi32> { // ----- -// Tests that composite tf.AssignAddVariableOp operation is decomposed and -// hoisted. - -// CHECK-LABEL: func @decompose_assign_add_variable_op -func @decompose_assign_add_variable_op() -> tensor<*xi32> { - - // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"} - // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch" - // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor} - // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]]) - // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} - // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"} - - %1 = "tf_device.launch"() ( { - %2 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - "tf.AssignAddVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor) -> () - %3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32> - tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> - - // CHECK: return %[[LAUNCH_RES]]#0 - return %1 : tensor<*xi32> -} - -// ----- - -// Tests that composite tf.AssignSubVariableOp operation is decomposed using -// SubOp. 
- -// CHECK-LABEL: func @decompose_assign_sub_variable_op -func @decompose_assign_sub_variable_op() -> tensor<*xi32> { - - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp" - // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor} - // CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]]) - // CHECK: "tf.AssignVariableOp" - - %1 = "tf_device.launch"() ( { - %2 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - "tf.AssignSubVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor) -> () - %3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32> - tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> - - return %1 : tensor<*xi32> -} - -// ----- - -// Tests that composite tf.ResourceApplyGradientDescent operation is decomposed -// and hoisted. - -// CHECK-LABEL: func @decompose_resource_apply_gradient_descent -func @decompose_resource_apply_gradient_descent() -> tensor<*xf32> { - - // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - - // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_FLOAT"} - // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch" - // CHECK: %[[ALPHA:[0-9]*]] = "tf.Const" - // CHECK: %[[DELTA:[0-9]*]] = "tf.Const" - // CHECK: %[[MUL:[0-9]*]] = "tf.Mul"(%[[ALPHA]], %[[DELTA]]) - // CHECK: %[[SUB:[0-9]*]] = "tf.Sub"(%[[RES_READ_VAL]], %[[MUL]]) - // CHECK: tf_device.return %[[SUB]], %[[SUB]] - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} - // CHECK-SAME: () -> (tensor<*xf32>, tensor<*xf32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_FLOAT"} - - %1 = "tf_device.launch"() ( { - %2 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[1.0]> : tensor<1xf32>} : () -> tensor - %3 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[0.5]> : tensor<1xf32>} : () -> tensor - "tf.ResourceApplyGradientDescent"(%0, %2, %3) : (tensor<*x!tf.resource>, tensor, tensor) -> () - %4 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_FLOAT"} : (tensor<*x!tf.resource>) -> tensor<*xf32> - tf_device.return %4 : tensor<*xf32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xf32> - - // CHECK: return %[[LAUNCH_RES]]#0 - return %1 : tensor<*xf32> -} - -// ----- - // Tests that internal resource operations are not hoisted. 
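+// That is, the variable handle below is created inside the tf_device.launch
+// body, so its reads and writes are expected to stay inside the launch; only
+// accesses to resources defined outside the launch body are lifted out.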
// CHECK-LABEL: func @internal_resource @@ -184,13 +94,13 @@ func @internal_resource() -> tensor<*xi32> { %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) - %2 = "tf.ReadVariableOp"(%1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32> + %2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[COMPUTE_RES]]) - "tf.AssignVariableOp"(%1, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () + "tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () // CHECK: tf_device.return %[[COMPUTE_RES]] tf_device.return %3 : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index acf236f8e1f..5a3c531023c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1,6 +1,17 @@ // RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail -color module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} { +// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> + func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> { + // CHECK: %[[ARG0:.*]] = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[ARG1:.*]] = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[ARG0]], %[[ARG1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: return %[[RESULT]] : tensor<1xi32> + %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32> + %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32> + %2 = "tf.AddV2"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> + } // CHECK-LABEL: func @simple_chain func @simple_chain(%arg0: tensor<1xf32>) -> tensor<*xf32> { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index c6eb4663e57..678c2373a1b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -6,18 +6,15 @@ // CHECK-LABEL: func @non_aliasing_reads_writes func @non_aliasing_reads_writes( // expected-remark@above {{ID: 13}} -// expected-remark@above {{Predecessors: {12}}} %arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor<32xf32>) -> (tensor<32xf32>) { %graph = tf_executor.graph { // expected-remark@above {{ID: 11}} - // expected-remark@above {{Predecessors: {10}}} // expected-remark@above {{Successors: {12}}} // CHECK: tf_executor.island %island:2 = tf_executor.island { // expected-remark@above {{ID: 9}} - // expected-remark@above {{Predecessors: {8}}} // expected-remark@above {{Successors: {10}}} %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> // expected-remark@above {{ID: 0}} @@ -49,17 +46,14 @@ func @non_aliasing_reads_writes( tf_executor.yield %read3 : tensor<32xf32> // expected-remark@above {{ID: 8}} // 
expected-remark@above {{Predecessors: {4,5,7}}} - // expected-remark@above {{Successors: {9}}} } tf_executor.fetch %island#0 : tensor<32xf32> // expected-remark@above {{ID: 10}} // expected-remark@above {{Predecessors: {9}}} - // expected-remark@above {{Successors: {11}}} } return %graph : tensor<32xf32> // expected-remark@above {{ID: 12}} // expected-remark@above {{Predecessors: {11}}} - // expected-remark@above {{Successors: {13}}} } // ----- @@ -70,15 +64,12 @@ func @non_aliasing_reads_writes( // CHECK-LABEL: func @aliasing_reads_writes func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () { // expected-remark@above {{ID: 14}} -// expected-remark@above {{Predecessors: {13}}} tf_executor.graph { // expected-remark@above {{ID: 12}} - // expected-remark@above {{Predecessors: {11}}} // expected-remark@above {{Successors: {13}}} // CHECK: tf_executor.island %island = tf_executor.island { // expected-remark@above {{ID: 10}} - // expected-remark@above {{Predecessors: {9}}} // expected-remark@above {{Successors: {11}}} %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> // expected-remark@above {{ID: 0}} @@ -112,17 +103,14 @@ func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () { tf_executor.yield // expected-remark@above {{ID: 9}} // expected-remark@above {{Predecessors: {8}}} - // expected-remark@above {{Successors: {10}}} } tf_executor.fetch %island : !tf_executor.control // expected-remark@above {{ID: 11}} // expected-remark@above {{Predecessors: {10}}} - // expected-remark@above {{Successors: {12}}} } return // expected-remark@above {{ID: 13}} // expected-remark@above {{Predecessors: {12}}} - // expected-remark@above {{Successors: {14}}} } // ----- @@ -133,15 +121,12 @@ func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () { // CHECK-LABEL: func @unknown_side_effecting_op func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () { // expected-remark@above {{ID: 13}} -// expected-remark@above {{Predecessors: {12}}} tf_executor.graph { // expected-remark@above {{ID: 11}} - // expected-remark@above {{Predecessors: {10}}} // expected-remark@above {{Successors: {12}}} // CHECK: tf_executor.island %island = tf_executor.island { // expected-remark@above {{ID: 9}} - // expected-remark@above {{Predecessors: {8}}} // expected-remark@above {{Successors: {10}}} %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource>> // expected-remark@above {{ID: 0}} @@ -172,17 +157,14 @@ func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () { tf_executor.yield // expected-remark@above {{ID: 8}} // expected-remark@above {{Predecessors: {6,7}}} - // expected-remark@above {{Successors: {9}}} } tf_executor.fetch %island : !tf_executor.control // expected-remark@above {{ID: 10}} // expected-remark@above {{Predecessors: {9}}} - // expected-remark@above {{Successors: {11}}} } return // expected-remark@above {{ID: 12}} // expected-remark@above {{Predecessors: {11}}} - // expected-remark@above {{Successors: {13}}} } // ----- @@ -193,15 +175,12 @@ func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () { // CHECK-LABEL: func @read_only_unknown_resource func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () { // expected-remark@above {{ID: 10}} -// expected-remark@above {{Predecessors: {9}}} tf_executor.graph { // expected-remark@above {{ID: 8}} - // expected-remark@above {{Predecessors: {7}}} // expected-remark@above {{Successors: {9}}} // CHECK: tf_executor.island %island = tf_executor.island { // 
expected-remark@above {{ID: 6}} - // expected-remark@above {{Predecessors: {5}}} // expected-remark@above {{Successors: {7}}} %vh0 = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource>> // expected-remark@above {{ID: 0}} @@ -223,15 +202,71 @@ func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () { tf_executor.yield // expected-remark@above {{ID: 5}} // expected-remark@above {{Predecessors: {4}}} - // expected-remark@above {{Successors: {6}}} } tf_executor.fetch %island : !tf_executor.control // expected-remark@above {{ID: 7}} // expected-remark@above {{Predecessors: {6}}} - // expected-remark@above {{Successors: {8}}} } return // expected-remark@above {{ID: 9}} // expected-remark@above {{Predecessors: {8}}} - // expected-remark@above {{Successors: {10}}} +} + +// ----- + +// Tests that the pass adds control dependencies in nested regions with +// tf_device.replicate + +func @with_replicate( + // expected-remark@above {{ID: 12}} + %arg0: tensor<*x!tf.resource>>, + %arg1: tensor<*x!tf.resource>>, + %arg2: tensor<*x!tf.resource>>, + %arg3: tensor<*x!tf.resource>>) { + tf_executor.graph { + // expected-remark@above {{ID: 10}} + // expected-remark@above {{Successors: {11}}} + %island = tf_executor.island { + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Successors: {9}}} + %u0:2 = "tf._UnknownSideEffectingOp_"() : () -> (tensor<32xf32>, tensor<32xf32>) + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {5}}} + tf_device.replicate( + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {6}}} + [%arg0, %arg1] as %r0: tensor<*x!tf.resource>>, + [%arg2, %arg3] as %r1: tensor<*x!tf.resource>>, + [%u0#0, %u0#1] as %u : tensor<32xf32>) + {n = 2 : i32, devices = ["/CPU:0", "/GPU:1"]} { + %read0 = "tf.ReadVariableOp"(%r0) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Successors: {4}}} + "tf.AssignVariableOp"(%r1, %u) : (tensor<*x!tf.resource>>, tensor<32xf32>) -> () + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Successors: {3}}} + %read1 = "tf.ReadVariableOp"(%r1) : (tensor<*x!tf.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Predecessors: {2}}} + // expected-remark@above {{Successors: {4}}} + tf_device.return + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {1,3}}} + } + "tf._UnknownSideEffectingOp_"() : () -> () + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {5}}} + // expected-remark@above {{Successors: {7}}} + tf_executor.yield + // expected-remark@above {{ID: 7}} + // expected-remark@above {{Predecessors: {6}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 9}} + // expected-remark@above {{Predecessors: {8}}} + } + return + // expected-remark@above {{ID: 11}} + // expected-remark@above {{Predecessors: {10}}} } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index e064c1a53ef..90aa6e73f79 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1610,7 +1610,7 @@ func @testSplitUnknownDimInput(%input: tensor<4x?x4xf32>) { // ----- -func @testSplitNonConstSplitDim(%input: tensor<4x4xf32>, %split_dim: tensor<1xi32>) { +func @testSplitNonScalarSplitDim(%input: tensor<4x4xf32>, %split_dim: 
tensor<1xi32>) { // expected-error @+1 {{split dimension should be an integer scalar tensor}} %0:2 = "tf.Split"(%split_dim, %input) : (tensor<1xi32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) return @@ -1674,3 +1674,152 @@ func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) { %0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>) return } + +// ----- + +func @testSplitVScalarInput(%input: tensor, %split_sizes: tensor<2xi32>, %split_dim: tensor) { + // expected-error @+1 {{cannot split scalar input tensor}} + %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<2xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +// ----- + +func @testSplitVNonScalarSplitDim(%input: tensor<4x4xf32>, %split_sizes: tensor<2xi32>, %split_dim: tensor<1xi32>) { + // expected-error @+1 {{split dimension should be an integer scalar tensor}} + %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +// ----- + +func @testSplitVSplitDimOutOfRange(%input: tensor<4x4xf32>, %split_sizes: tensor<2xi32>) { + %split_dim = "tf.Const"() {value = dense<100>: tensor} : () -> (tensor) + // expected-error @+1 {{split dimension must be in range [-2, 2)}} + %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +// ----- + +func @testSplitVWrongSplitSizesType(%input: tensor<4x4xf32>, %split_sizes: tensor<2x2xi32>, %split_dim: tensor) { + // expected-error @+1 {{op split sizes should be a 1D tensor of 2 elements}} + %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2x2xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +// ----- + +func @testSplitVMultipleDynamicSizes(%input: tensor<4x4xf32>) { + %split_dim = "tf.Const"() {value = dense<1>: tensor} : () -> (tensor) + %split_sizes = "tf.Const"() {value = dense<[-1, -1]>: tensor<2xi32>} : () -> (tensor<2xi32>) + // expected-error @+1 {{cannot have more than one dynamic dimension in split sizes}} + %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +// ----- + +func @testSplitVSplitSizeOutOfRange(%input: tensor<4x4xf32>) { + %split_dim = "tf.Const"() {value = dense<1>: tensor} : () -> (tensor) + %split_sizes = "tf.Const"() {value = dense<[-1, 100]>: tensor<2xi32>} : () -> (tensor<2xi32>) + // expected-error @+1 {{split sizes must sum up to be less than or equal to the dimension size along split dimension, found 100 vs 4}} + %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +// ----- + +func @testSplitVSplitSizeOutOfRange(%input: tensor<4x4xf32>) { + %split_dim = "tf.Const"() {value = dense<1>: tensor} : () -> (tensor) + %split_sizes = "tf.Const"() {value = dense<[2, 3]>: tensor<2xi32>} : () -> (tensor<2xi32>) + // expected-error @+1 {{split sizes must sum up to the dimension size along split dimension, found 5 vs 4}} + %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +// ----- + +func @testSplitV1(%input: tensor<4x4xf32>) { + %split_dim = "tf.Const"() {value = dense<1>: tensor} : () -> (tensor) + %split_sizes = "tf.Const"() {value = dense<[-1, 4]>: tensor<2xi32>} : () -> (tensor<2xi32>) + 
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +func @testSplitV2(%input: tensor<4x4xf32>) { + %split_dim = "tf.Const"() {value = dense<1>: tensor} : () -> (tensor) + %split_sizes = "tf.Const"() {value = dense<[3, 1]>: tensor<2xi32>} : () -> (tensor<2xi32>) + %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) + return +} + +// ----- + +//===--------------------------------------------------------------------===// +// tf.All +//===--------------------------------------------------------------------===// + +func @testAllDimWrongRank(%input: tensor<4x6xi1>, %dims: tensor<2x2xi32>) { + // expected-error @+1 {{dimensions can only be 0D or 1D tensor}} + %0 = "tf.All"(%input, %dims) : (tensor<4x6xi1>, tensor<2x2xi32>) -> (tensor<*xi1>) + return +} + +// ----- + +func @testAllDimOutOfRange(%input: tensor<4x6xi1>) { + %dims = "tf.Const"() {value = dense<[-1, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>) + // expected-error @+1 {{1-th dimension should be in the range of [-2, 2)}} + %0 = "tf.All"(%input, %dims) : (tensor<4x6xi1>, tensor<2xi32>) -> (tensor<*xi1>) + return +} + +// ----- + +//===--------------------------------------------------------------------===// +// tf.Any +//===--------------------------------------------------------------------===// + +func @testAnyDimWrongRank(%input: tensor<4x6xi1>, %dims: tensor<2x2xi32>) { + // expected-error @+1 {{dimensions can only be 0D or 1D tensor}} + %0 = "tf.Any"(%input, %dims) : (tensor<4x6xi1>, tensor<2x2xi32>) -> (tensor<*xi1>) + return +} + +// ----- + +func @testAnyDimOutOfRange(%input: tensor<4x6xi1>) { + %dims = "tf.Const"() {value = dense<[-1, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>) + // expected-error @+1 {{1-th dimension should be in the range of [-2, 2)}} + %0 = "tf.Any"(%input, %dims) : (tensor<4x6xi1>, tensor<2xi32>) -> (tensor<*xi1>) + return +} + +// ----- + +//===--------------------------------------------------------------------===// +// tf.Unpack +//===--------------------------------------------------------------------===// + +func @testUnpackAxisOutOfRange(%input: tensor<2x6xf32>) { + // expected-error @+1 {{axis attribute must be in the range of [-2, 2)}} + %0:2 = "tf.Unpack"(%input) {axis = 5} : (tensor<2x6xf32>) -> (tensor<6xf32>, tensor<6xf32>) + return +} + +// ----- + +func @testAxisUnknownDim(%input: tensor) { + // CHECK: tf.Unpack + %0:2 = "tf.Unpack"(%input) {axis = 0} : (tensor) -> (tensor<6xf32>, tensor<6xf32>) + return +} + +// ----- + +func @testAxisDim(%input: tensor<2x6xf32>) { + // expected-error @+1 {{result count must be equal to 6}} + %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<2x6xf32>) -> (tensor<6xf32>, tensor<6xf32>) + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/exported_python_args.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/exported_python_args.py new file mode 100644 index 00000000000..f73aa83a76c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/exported_python_args.py @@ -0,0 +1,41 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: (! %p/exported_python_args 2>&1) | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long,dangerous-default-value +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common + + +class TestModule(tf.Module): + + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def some_function(self, x): + return self.callee(x) + + # CHECK: While importing SavedModel function 'callee': in input signature: + # CHECK-SAME: Unhandled structured value kind {{.*}} at index path: .1.foo + @tf.function + def callee(self, x, n={'foo': 42}): + return x + + +if __name__ == '__main__': + common.do_test(TestModule) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index a7f45c41f15..c08d17104ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -31,9 +31,14 @@ void CreateTPUBridge(OpPassManager &pm) { func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); func_pm.addPass(CreateTPUClusterFormationPass()); func_pm.addPass(createCanonicalizerPass()); + // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass + // because DecomposeResourceOpsPass uses pattern rewriter which hoists + // changed constants out of tf_device.Launch. + func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass()); func_pm.addPass(tf_executor::CreateTFExecutorConstantSinkingPass()); func_pm.addPass(TFDevice::CreateResourceOpLiftingPass()); + pm.addPass(TF::CreateResourceDeviceInferencePass()); pm.addPass(TFDevice::CreateClusterOutliningPass()); pm.addPass(CreateTPURewritePass()); pm.addNestedPass(TFDevice::CreateReplicateInvariantOpHoistingPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index 7dab06124dc..10337df1a66 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -44,16 +44,16 @@ struct ClusterOutliningPass : public ModulePass { void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op, OpBuilder* builder) { - llvm::SmallVector operands(launch_return_op.getOperands()); - builder->create(launch_return_op.getLoc(), operands); + builder->create(launch_return_op.getLoc(), + launch_return_op.getOperands()); launch_return_op.erase(); } // Builds a function that outlines region attached to launch_op and inserts // built function into given module. 
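// As an illustrative sketch of the overall outlining (operand and symbol names
// here are hypothetical, not the pass's verbatim output), a launch such as
//
//   %0 = "tf_device.launch"() ( {
//     %1 = "tf.A"(%arg0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %1 : tensor<i1>
//   }) {device = "tpu0"} : () -> tensor<i1>
//
// yields an outlined function
//
//   func @tpu0_func(%arg0: tensor<i1>) -> tensor<i1> {
//     %0 = "tf.A"(%arg0) : (tensor<i1>) -> tensor<i1>
//     return %0 : tensor<i1>
//   }
//
// and the launch itself is replaced by a `tf_device.launch_func` that invokes
// @tpu0_func with the live-in values (see OutlineLaunch below).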
FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, - tf_device::LaunchOp launch_op, - ModuleManager* module_manager, OpBuilder* builder) { + tf_device::LaunchOp launch_op, SymbolTable* symbol_table, + OpBuilder* builder) { llvm::SmallVector operand_types; operand_types.reserve(live_ins.size()); for (Value* v : live_ins) operand_types.emplace_back(v->getType()); @@ -92,14 +92,14 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, builder->setInsertionPoint(launch_return_op); ReplaceLaunchReturnWithReturn(launch_return_op, builder); - module_manager->insert(outlined_func); + symbol_table->insert(outlined_func); return outlined_func; } // Outlines body of `tf_device.launch` into a function and create a // `tf_device.launch_func` to invoke that function. `tf_device.launch` is // removed afterwards.` -void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager, +void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table, OpBuilder* builder) { llvm::SetVector live_ins; getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins); @@ -108,7 +108,7 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager, launch_op.getAttrOfType(kDeviceAttr).getValue(); FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(), - launch_op, module_manager, builder); + launch_op, symbol_table, builder); launch_op.setAttr(builder->getIdentifier(kFuncAttr), builder->getSymbolRefAttr(outlined_func.getName())); @@ -124,10 +124,10 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager, void ClusterOutliningPass::runOnModule() { ModuleOp m = getModule(); - ModuleManager module_manager(m); + SymbolTable symbol_table(m); OpBuilder builder(m.getContext()); m.walk([&](tf_device::LaunchOp launch) { - OutlineLaunch(launch, &module_manager, &builder); + OutlineLaunch(launch, &symbol_table, &builder); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc new file mode 100644 index 00000000000..b70d14fd43b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h" + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { + +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc" + +void PopulateDecomposeResourceOpsPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + populateWithGenerated(context, patterns); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h new file mode 100644 index 00000000000..813fc649059 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ + +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir + +namespace mlir { +namespace TF { + +// Populates rewrite patterns that decompose composite resource operations into +// primitive ones like ReadVariableOp, AssignVariableOp and other computations +// to facilitate transformations like resource op lifting. +void PopulateDecomposeResourceOpsPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td new file mode 100644 index 00000000000..29c99cdc3d0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +include "mlir/IR/OpBase.td" +include "mlir/Dialect/StandardOps/Ops.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +def CreateTFReadVariableOp: NativeCodeCall< + "$_builder.create(" + " $0.getLoc()," + " UnrankedTensorType::get(" + " $1->getType().cast().getElementType())," + " $2)" + >; + +def DecomposeAssignAddVariableOp : + Pat< + (TF_AssignAddVariableOp:$src_op $resource, $value), + (TF_AssignVariableOp + $resource, + (TF_AddV2Op + (CreateTFReadVariableOp $src_op, $value, $resource), + $value + ) + ) + >; + +def DecomposeAssignSubVariableOp : + Pat< + (TF_AssignSubVariableOp:$src_op $resource, $value), + (TF_AssignVariableOp + $resource, + (TF_SubOp + (CreateTFReadVariableOp $src_op, $value, $resource), + $value + ) + ) + >; + +// This decomposition is only correct inside XLA as it ignores use_locking +// attribute. +def DecomposeResourceApplyGradientDescentOp : + Pat< + (TF_ResourceApplyGradientDescentOp:$src_op $resource, $alpha, $delta, $_), + (TF_AssignVariableOp + $resource, + (TF_SubOp + (CreateTFReadVariableOp $src_op, $alpha, $resource), + (TF_MulOp $alpha, $delta) + ) + ) + >; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc new file mode 100644 index 00000000000..b7be4ff8742 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc @@ -0,0 +1,59 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h" + +namespace mlir { +namespace TFDevice { +namespace { + +// A pass that decomposes composite resource operations into primitive ones like +// ReadVariableOp, AssignVariableOp and other computations to facilitate +// transformations like resource op lifting. +// +// For example: +// +// tf.AssignAddVariableOp(%res, %0) +// +// Becomes +// +// %res_val = tf.ReadVariableOp(%res) +// %1 = tf.AddV2(%res_val, %0) +// tf.AssignVariableOp(%res, %1) +struct DecomposeResourceOps : public FunctionPass { + void runOnFunction() override { + // Add lowering patterns to the list. 
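+    // As a sketch of what the generated patterns are expected to produce
+    // (illustrative, types omitted; see decompose_resource_ops.td), a
+    // composite update such as
+    //
+    //   "tf.ResourceApplyGradientDescent"(%var, %alpha, %delta)
+    //
+    // is rewritten into primitive resource ops:
+    //
+    //   %prod = "tf.Mul"(%alpha, %delta)
+    //   %old  = "tf.ReadVariableOp"(%var)
+    //   %new  = "tf.Sub"(%old, %prod)
+    //   "tf.AssignVariableOp"(%var, %new)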
+ OwningRewritePatternList patterns; + mlir::TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns); + + applyPatternsGreedily(getFunction(), patterns); + } +}; + +} // namespace + +std::unique_ptr> CreateDecomposeResourceOpsPass() { + return std::make_unique(); +} + +} // namespace TFDevice +} // namespace mlir + +static mlir::PassRegistration pass( + "tf-device-decompose-resource-ops", + "Decompose composite resource variable operations into primitive " + "Read/AssignVariableOp and raw computation"); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index c6958d992f1..918e6ac3078 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -304,8 +304,7 @@ void InsertDummyIslandForFetch(FetchOp fetch) { /*control=*/ControlType::get(fetch.getContext()), /*controlInputs=*/control_fetches); island.body().push_back(new Block); - OpBuilder(&island.GetBody()) - .create(fetch.getLoc(), llvm::to_vector<4>(data_fetches)); + OpBuilder(&island.GetBody()).create(fetch.getLoc(), data_fetches); const int fetch_control_idx = data_fetches.size(); for (int i = 0, e = fetch.getNumOperands(); i < e; i++) { // The fetch could have multiple control operands (all at the end of its diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index 880b4c4210b..e9b3879c025 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -17,6 +17,7 @@ limitations under the License. // standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form. 
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir @@ -79,8 +80,11 @@ static Operation* CallFn(Location loc, for (int i = 0; i < num_operands; ++i) { Value* val = get_arg(i); Type expected = fn_type.getInput(i); - if (val->getType() != expected) - val = builder->create(loc, val, expected); + if (val->getType() != expected) { + val = + builder->create(loc, expected, val, + /*Truncate=*/builder->getBoolAttr(false)); + } operands.push_back(val); } return builder->create(loc, fn, operands).getOperation(); @@ -100,8 +104,11 @@ static llvm::SmallVector PrepareValsForJump( for (int i = 0; i < num_vals; ++i) { Value* val = get_val(i); Type expected = block->getArgument(i)->getType(); - if (val->getType() != expected) - val = builder->create(loc, val, expected); + if (val->getType() != expected) { + val = + builder->create(loc, expected, val, + /*Truncate=*/builder->getBoolAttr(false)); + } result.push_back(val); } return result; @@ -131,8 +138,11 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op, for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { Value* arg = block->getArgument(i); Value* result = op->getResult(i); - if (arg->getType() != result->getType()) - arg = builder->create(loc, arg, result->getType()); + if (arg->getType() != result->getType()) { + arg = + builder->create(loc, result->getType(), arg, + /*Truncate=*/builder->getBoolAttr(false)); + } result->replaceAllUsesWith(arg); } } @@ -301,26 +311,15 @@ void FunctionalControlFlowToCFG::runOnFunction() { // subsequent blocks. // // TODO: Use PatternRewriter to eliminate these function control flow ops. - auto has_variant_operand = [](Operation* op) { - auto is_variant = [](Type ty) { - return getElementTypeOrSelf(ty).getKind() == TensorFlowTypes::VARIANT; - }; - - if (llvm::none_of(op->getOperandTypes(), is_variant)) return false; - - op->emitOpError() << "does not yet support operands of type variant " - "for conversion to CFG"; - return true; - }; if (IfOp if_op = llvm::dyn_cast(op)) { - if (has_variant_operand(&op) || failed(LowerIfOp(if_op))) { + if (failed(LowerIfOp(if_op))) { return signalPassFailure(); } break; } if (WhileOp while_op = llvm::dyn_cast(op)) { - if (has_variant_operand(&op) || failed(LowerWhileOp(while_op))) { + if (failed(LowerWhileOp(while_op))) { return signalPassFailure(); } break; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 65c6ac86288..89941c2fab4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/util/tensor_format.h" namespace mlir { @@ -109,6 +110,39 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) { return RankedTensorType::get(shape, ranked_ty.getElementType()); } +// Lowers AddN op to a sequence of AddV2 ops to accumulate operands. 
+// +// %result = "tf.AddN"(%0, %1, %2) +// +// is lowered to: +// +// %sum_0 = "tf.AddV2"(%0, %1) +// %result = "tf.AddV2"(%sum_0, %2) +// +class LowerAddNOp : public OpRewritePattern { + public: + explicit LowerAddNOp(MLIRContext *context) + : OpRewritePattern(context) {} + + PatternMatchResult matchAndRewrite(TF::AddNOp op, + PatternRewriter &rewriter) const override { + // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't + // support variant type so variant types require special handling. + if (getElementTypeOrSelf(op.getType()).isa()) + return matchFailure(); + + // TODO(hinsu): Improve parallelism by splitting operands in two halves and + // accumulating them first. + Value *result = *op.inputs().begin(); + for (Value *operand : llvm::drop_begin(op.inputs(), 1)) { + result = rewriter.create(op.getLoc(), result, operand); + } + + rewriter.replaceOp(op, result); + return matchSuccess(); + } +}; + // Lowers Pack op to ConcatV2 op after changing shape of the inputs with // ExpandDims op. // @@ -159,6 +193,7 @@ class LowerPackOp : public OpRewritePattern { void PopulateLoweringTFPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { + patterns->insert(context); patterns->insert(context); populateWithGenerated(context, patterns); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index b0420663bde..6e28b19ad80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -63,7 +63,7 @@ void CreateTFStandardPipeline(OpPassManager &pm, if (options.enable_inliner) { pm.addPass(createInlinerPass()); } - pm.addNestedPass(CreateTFShapeInferencePass()); + pm.addPass(CreateTFShapeInferencePass()); pm.addNestedPass(CreateTFOptimizePass()); pm.addNestedPass(createCSEPass()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 30ee91f4aea..fca1c02bc62 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -57,6 +57,9 @@ struct StandardPipelineOptions : public PassOptions { // NOLINTNEXTLINE - MLIR contract is pass by mutable reference. void CreateTFStandardPipeline(OpPassManager& pm, const StandardPipelineOptions& options); + +// Propagates device attributes of resources from callers to callees. +std::unique_ptr> CreateResourceDeviceInferencePass(); } // namespace TF namespace TFControlFlow { @@ -96,6 +99,11 @@ std::unique_ptr> CreateClusterFormationPass(); // Creates a pass that outlines regions of tf_device.launch operations. std::unique_ptr> CreateClusterOutliningPass(); +// A pass that decomposes composite resource operations into primitive ones like +// ReadVariableOp, AssignVariableOp and other computations to facilitate +// transformations like resource op lifting. 
+std::unique_ptr> CreateDecomposeResourceOpsPass(); + // Creates a pass that lifts operations on external resource variables from // device computation nested in `tf_device::LaunchOp` out so that resource // variable load operations are all before device computation while resource diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 8033773cfaa..9787ac0f0f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -64,8 +64,8 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // Replace replicate terminator with YieldOp. builder->setInsertionPoint(&terminator); - builder->create( - terminator.getLoc(), llvm::to_vector<8>(terminator.getOperands())); + builder->create(terminator.getLoc(), + terminator.getOperands()); terminator.erase(); builder->setInsertionPoint(island_op); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc new file mode 100644 index 00000000000..616c2cb10e8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -0,0 +1,278 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "mlir/IR/Visitors.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +namespace { +constexpr char kDeviceAttr[] = "device"; +constexpr char kFuncDeviceAttr[] = "tf.device"; + +// A pass that propagates device assignment of resources on a module. It +// performs in-function propagation, as well as cross-function propagation from +// callers to callees. +// +// This pass changes the module by adding "tf.device" attribute to function +// arguments and adding "device" attribute to TF ops. 
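+//
+// As an illustrative sketch (types abbreviated, not verbatim output): given
+//
+//   %var = "tf.VarHandleOp"() {device = "/CPU:0", shared_name = "v", ...}
+//       : () -> tensor<*x!tf.resource<tensor<f32>>>
+//   %id = "tf.Identity"(%var) : (...) -> (...)
+//
+// the pass records "/CPU:0" for the resource defined by the VarHandleOp and
+// annotates the Identity with {device = "/CPU:0"}; when that resource is later
+// passed to a callee (e.g. the body of a tf.While), the matching function
+// argument receives a "tf.device" = "/CPU:0" argument attribute.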
+struct ResourceDeviceInference : public ModulePass { + void runOnModule() override; +}; + +// A class that records each resource's device assignment in a function. +class PerFunctionResult { + public: + explicit PerFunctionResult(FuncOp func_op) : alias_analysis_(func_op) {} + + // Returns the recorded device assignment for a resource, if any. + llvm::Optional DeviceForResource( + const Value* resource) const { + llvm::Optional result; + if (alias_analysis_.IsUnknownResource(resource)) return result; + for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) { + auto it = resource_id_to_device_.find(id); + if (it == resource_id_to_device_.end()) continue; + if (!result) { + result = it->getSecond(); + continue; + } + if (result != it->getSecond()) { + // Got conflicting assignments, clear the result. + result.reset(); + return result; + } + } + return result; + } + + // Records the device assignment for a resource. If the new assignment + // conflicts with an existing one, returns an error. + // + // If `changed` is provided, assign *changed to true if anything is modified. + LogicalResult AddResourceDevice(const Value* resource, llvm::StringRef device, + bool* changed = nullptr) { + if (alias_analysis_.IsUnknownResource(resource)) return success(); + for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) { + auto emplace_res = resource_id_to_device_.try_emplace(id, device); + if (emplace_res.second) { + if (changed) *changed = true; + } else if (emplace_res.first->getSecond() != device) { + // Existing assignment does not equal the new assignment. + return failure(); + } + } + return success(); + } + + private: + llvm::SmallDenseMap resource_id_to_device_; + TF::ResourceAliasAnalysis alias_analysis_; +}; + +// Tries to record device assignment for a resource. +LogicalResult AddResourceDeviceAndEmitError(const Value* resource, + llvm::StringRef device, + Operation* error_reporting_op, + PerFunctionResult* result, + bool* changed = nullptr) { + auto res = result->AddResourceDevice(resource, device, changed); + if (failed(res)) { + error_reporting_op->emitError() + << "Conflicting device assignment for resource"; + } + return res; +} + +// Propagates device assignment inside a function. +LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, + PerFunctionResult* result) { + OpBuilder builder(func_op); + // Function arguments. + for (auto arg : func_op.getArguments()) { + if (!mlir::getElementTypeOrSelf(arg->getType()).isa()) { + continue; + } + auto device_attr = func_op.getArgAttrOfType( + arg->getArgNumber(), kFuncDeviceAttr); + if (!device_attr || device_attr.getValue() == "") { + // If device_attr does not exist, try to construct it from any recorded + // assignment. + if (auto device = result->DeviceForResource(arg)) { + func_op.setArgAttr(arg->getArgNumber(), kFuncDeviceAttr, + builder.getStringAttr(*device)); + } + continue; + } + // Record the attribute. + auto res = AddResourceDeviceAndEmitError(arg, device_attr.getValue(), + func_op, result); + if (failed(res)) return res; + } + auto walk_res = func_op.walk([&](Operation* op) { + if (auto var_handle = llvm::dyn_cast(op)) { + // Record VarHanldeOp's device attribute. 
+ auto device_attr = + var_handle.getAttrOfType(kDeviceAttr); + if (!device_attr || device_attr.getValue().empty()) { + return WalkResult::advance(); + } + auto res = AddResourceDeviceAndEmitError( + var_handle.resource(), device_attr.getValue(), op, result); + if (failed(res)) return WalkResult::interrupt(); + } + if (auto identity = llvm::dyn_cast(op)) { + // Try to construct IdentityOp's attribute from recorded assignment. + if (!mlir::getElementTypeOrSelf(identity.output()->getType()) + .isa()) { + return WalkResult::advance(); + } + if (auto device = result->DeviceForResource(identity.output())) { + auto device_attr = + identity.getAttrOfType(kDeviceAttr); + if (!device_attr || device_attr.getValue().empty()) { + identity.setAttr(kDeviceAttr, builder.getStringAttr(*device)); + } + } + return WalkResult::advance(); + } + // Propagate and record output device assignment for other ops based on + // existing recording. E.g., IdentityN. + for (auto output : op->getResults()) { + if (!mlir::getElementTypeOrSelf(output->getType()) + .isa()) { + continue; + } + if (auto device = result->DeviceForResource(output)) { + auto res = AddResourceDeviceAndEmitError(output, *device, op, result); + if (failed(res)) return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + return failure(walk_res.wasInterrupted()); +} + +void ResourceDeviceInference::runOnModule() { + auto module = getModule(); + llvm::SmallDenseMap per_function_results; + llvm::SetVector worklist; + module.walk([&](FuncOp func_op) { + worklist.insert(func_op); + per_function_results.try_emplace(func_op, func_op); + }); + // Helper that propagates an op's recorded operand device assignments to its + // called function's arguments. + auto propagate_operands_to_callee_arguments = + [&](Operation* caller, + llvm::iterator_range caller_operands, + llvm::StringRef called_func_name, + const PerFunctionResult& caller_res) { + auto callee = + llvm::dyn_cast(module.lookupSymbol(called_func_name)); + assert(callee); + auto& callee_res = per_function_results.find(callee)->getSecond(); + bool callee_needs_recompute = false; + for (auto operand_and_argument : + llvm::zip(caller_operands, callee.getArguments())) { + if (!mlir::getElementTypeOrSelf( + std::get<0>(operand_and_argument)->getType()) + .isa()) { + continue; + } + auto device = + caller_res.DeviceForResource(std::get<0>(operand_and_argument)); + if (!device) continue; + if (failed(AddResourceDeviceAndEmitError( + std::get<1>(operand_and_argument), *device, caller, + &callee_res, &callee_needs_recompute))) { + return failure(); + } + } + // If the callee recording is modified, make sure that it will be + // reprocessed. + if (callee_needs_recompute) { + worklist.insert(callee); + } + return success(); + }; + while (!worklist.empty()) { + auto func_op = worklist.back(); + worklist.pop_back(); + auto& func_res = per_function_results.find(func_op)->getSecond(); + // In-function propagation. + if (failed(ComputeResourceDevicesInComputation(func_op, &func_res))) { + return signalPassFailure(); + } + // Propagation to callees. 
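+    // For example (sketch): for
+    //
+    //   %w:2 = "tf.While"(%resource, %x) {body = @body, cond = @cond, ...}
+    //
+    // the device recorded for %resource in this function is copied onto the
+    // corresponding argument of @body and @cond, and those callees are put
+    // back on the worklist if their recording changed.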
+ auto walk_res = func_op.walk([&](Operation* op) { + if (auto while_op = llvm::dyn_cast(op)) { + if (failed(propagate_operands_to_callee_arguments( + while_op, while_op.getOperands(), while_op.body(), func_res)) || + failed(propagate_operands_to_callee_arguments( + while_op, while_op.getOperands(), while_op.cond(), func_res))) { + return WalkResult::interrupt(); + } + } else if (auto if_op = llvm::dyn_cast(op)) { + if (failed(propagate_operands_to_callee_arguments( + if_op, if_op.input(), if_op.then_branch(), func_res)) || + failed(propagate_operands_to_callee_arguments( + if_op, if_op.input(), if_op.else_branch(), func_res))) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + if (walk_res.wasInterrupted()) return signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> CreateResourceDeviceInferencePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-resource-device-inference", + "Propagates the device attribute on resources from callers to callees."); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 7aa5c19fead..2f32a3a2c28 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -77,129 +77,6 @@ struct ResourceOpLiftingPass : public FunctionPass { void runOnFunction() override; }; -// Rewrites composite variable op `tf.AssignAddVariableOp` or -// `tf.AssignSubVariableOp` into primitive resource/computation ops. -// For example: -// -// tf.AssignAddVariableOp(%res, %0) -// -// Becomes -// -// %res_val = tf.ReadVariableOp(%res) -// %1 = tf.AddV2(%res_val, %0) -// tf.AssignVariableOp(%res, %1) -// -template -LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) { - // Read mangled dtype, which indicates type of data stored in resource - // variable. It can then be used to construct type needed for both - // ReadVariableOp and AssignVariableOp. - StringAttr mangled_dtype_attr = - src_op.template getAttrOfType(kDTypeAttr); - std::string type_string = mangled_dtype_attr.getValue(); - tensorflow::DataType dtype_proto; - auto s = - tensorflow::mangling_util::DemangleDataType(type_string, &dtype_proto); - if (!s.ok()) return src_op.emitError() << s.error_message(); - - Type type; - s = tensorflow::ConvertDataType(dtype_proto, *builder, &type); - if (!s.ok()) return src_op.emitError() << s.error_message(); - type = UnrankedTensorType::get(type); - - builder->setInsertionPoint(src_op); - - auto read_variable_op = builder->create( - src_op.getLoc(), type, src_op.resource()); - read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), - mangled_dtype_attr); - - Value* result; - if (std::is_same()) { - result = builder->create( - src_op.getLoc(), read_variable_op.value(), src_op.value()); - } else { - result = builder->create( - src_op.getLoc(), read_variable_op.value(), src_op.value()); - } - - auto assign_variable_op = builder->create( - src_op.getLoc(), src_op.resource(), result); - assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), - mangled_dtype_attr); - - src_op.erase(); - return success(); -} - -// Rewrites `tf.ResourceApplyGradientDescent` into primitive resource and -// computation ops. 
-// -// Specifically: -// -// tf.ResourceApplyGradientDescent(%var, %alpha, %delta) -// -// Becomes -// -// %old_var_val = tf.ReadVariableOp(%var) -// %gradient_update = tf.Mul(%alpha, %delta) -// %new_var_val = tf.Sub(%old_var_val, %gradient_update) -// tf.AssignVariableOp(%var, %new_var_val) -LogicalResult RewriteResourceApplyGradientDescentOp( - TF::ResourceApplyGradientDescentOp op, OpBuilder* builder) { - Type type = op.alpha()->getType(); - auto t = UnrankedTensorType::get(type.cast().getElementType()); - - tensorflow::DataType data_type; - auto s = tensorflow::ConvertToDataType(type, &data_type); - if (!s.ok()) return op.emitError() << s.error_message(); - - std::string mangled_data_type = - tensorflow::mangling_util::MangleDataType(data_type); - auto mangled_dtype_attr = builder->getStringAttr(mangled_data_type); - - builder->setInsertionPoint(op); - auto read_variable_op = - builder->create(op.getLoc(), t, op.var()); - read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), - mangled_dtype_attr); - - auto mul_op = - builder->create(op.getLoc(), t, op.alpha(), op.delta()); - auto sub_op = builder->create( - op.getLoc(), t, read_variable_op.value(), mul_op.z()); - auto assign_variable_op = - builder->create(op.getLoc(), op.var(), sub_op.z()); - assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr), - mangled_dtype_attr); - - op.erase(); - - return success(); -} - -// Rewrites an operation that updates value of a resource variable into its -// equivalent primitive ones so that following analysis/rewrite can be easier. -// If given op is not a composite resource store op or is an unsupported op, no -// change is applied. -// TODO(ycao): Explore using pattern rewriter after needed operations are -// defined. -// TODO(ycao): Add support for other composite resource store ops. -LogicalResult MaybeRewriteCompositeResourceStore(Operation* op, - OpBuilder* builder) { - if (auto assign_add_op = dyn_cast(op)) { - return RewriteCompositeAssignVariableOp(assign_add_op, builder); - } else if (auto assign_sub_op = dyn_cast(op)) { - return RewriteCompositeAssignVariableOp(assign_sub_op, builder); - } else if (auto resource_apply_gradient_descent_op = - dyn_cast(op)) { - return RewriteResourceApplyGradientDescentOp( - resource_apply_gradient_descent_op, builder); - } - - return success(); -} - // Performs store-load forwarding. This effectively removes // 1) Any resource loads after a store to that same resource is done // 2) Any resource stores except the last one. @@ -358,10 +235,6 @@ void HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) { ModuleOp m = launch_op.getParentOfType(); OpBuilder builder(m); - // Rewrite composite resource store operations into primitive ones. - launch_op.walk( - [&](Operation* op) { MaybeRewriteCompositeResourceStore(op, &builder); }); - // Perform store-load forwarding. So that each resource is only loaded with // its initial value and is only stored with its final value. ForwardStoreToLoad(launch_op); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index c44c81d1cef..812100ced64 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -47,6 +47,46 @@ using ::tensorflow::int64; namespace mlir { namespace TF { +namespace { +Optional> InferShapeForFunctionReturnType( + FuncOp func) { + // Only infer shape when there is one return op for now. 
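+  // Sketch of the intended refinement (illustrative types): if the block ends
+  // with
+  //
+  //   %0 = "tf.Cast"(%static) : (tensor<4xf32>) -> tensor<*xf32>
+  //   return %0 : tensor<*xf32>
+  //
+  // the cast only relaxes the shape, so the return is rewired to use %static
+  // directly and the inferred return type becomes tensor<4xf32>.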
+ if (!has_single_element(func.getBody()) || func.front().empty()) { + return None; + } + + // Find the return type. + auto return_op = dyn_cast(func.front().back()); + if (!return_op) { + return None; + } + + // Manually fold tf.Cast that precedes the return instruction and only differs + // in shape refinement level. + for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) { + if (auto cast_op = dyn_cast(arg_op.get()->getDefiningOp())) { + // Shape inference should not change the element type. + if (cast_op.SrcT() != cast_op.DstT()) continue; + // We only refine the result shape if the result a dynamic shape, the + // input has static shape, and the two shapes are compatible. + auto has_static_shape = [](const Value* value) { + auto shaped_type = value->getType().dyn_cast(); + return shaped_type && shaped_type.hasStaticShape(); + }; + Value* input = cast_op.x(); + Value* result = cast_op.y(); + if (!has_static_shape(input) || has_static_shape(result) || + failed(verifyCompatibleShape(input->getType(), result->getType()))) + continue; + + arg_op.set(cast_op.x()); + if (cast_op.y()->use_empty()) cast_op.erase(); + } + } + + return llvm::to_vector<4>(return_op.getOperandTypes()); +} +} // namespace bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, int64_t graph_version) { @@ -245,11 +285,10 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, return success(); } -LogicalResult InferShapeForFunction(FuncOp op, +LogicalResult InferShapeForFunction(FuncOp func, ArrayRef> arg_shapes, int64_t graph_version) { - auto main_func = op; - mlir::FunctionType func_type = main_func.getType(); + mlir::FunctionType func_type = func.getType(); bool needs_refinement = false; llvm::SmallVector new_arg_types; new_arg_types.reserve(func_type.getNumInputs()); @@ -276,7 +315,7 @@ LogicalResult InferShapeForFunction(FuncOp op, auto new_arg_type = mlir::RankedTensorType::get(shape, element_type); if (new_arg_type != func_type.getInput(i)) { // If the new type is more detailed, trigger shape inference. - main_func.getArgument(i)->setType(new_arg_type); + func.getArgument(i)->setType(new_arg_type); needs_refinement = true; } new_arg_types.push_back(new_arg_type); @@ -287,39 +326,28 @@ LogicalResult InferShapeForFunction(FuncOp op, } mlir::LogicalResult result = - mlir::TF::InferShapeUntilFixPoint(&main_func.getBody(), graph_version); + mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version); if (failed(result)) { return failure(); } - // Must only have 1 block so that there is only one return op. - if (main_func.getBody().getBlocks().size() != 1 || - main_func.front().empty()) { - return failure(); + auto return_types = InferShapeForFunctionReturnType(func); + func.setType(mlir::FunctionType::get(new_arg_types, + return_types.hasValue() + ? return_types.getValue() + : func.getType().getResults(), + func.getContext())); + + return success(); +} + +LogicalResult InferShapeForFunctionType(FuncOp func) { + if (auto return_types = InferShapeForFunctionReturnType(func)) { + func.setType(mlir::FunctionType::get(func.getType().getInputs(), + return_types.getValue(), + func.getContext())); } - // Find the return type. - auto return_op = dyn_cast(*main_func.front().rbegin()); - if (!return_op) { - return failure(); - } - - // Manually fold tf.Cast that precedes the return instruction and only differ - // in shape refinement level. 
- for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) { - if (auto cast_op = dyn_cast(arg_op.get()->getDefiningOp())) { - if (cast_op.SrcT() != cast_op.DstT()) continue; - arg_op.set(cast_op.x()); - if (cast_op.y()->use_empty()) cast_op.erase(); - } - } - - llvm::SmallVector return_types(return_op.getOperandTypes()); - - // Update function signature with the results of inference. - main_func.setType( - mlir::FunctionType::get(new_arg_types, return_types, op.getContext())); - return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index 1cbd5eb6c29..0529e6414b7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -41,12 +41,16 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, int64_t max_iteration = 10); -// Given a list of refined shapes matching the function arguments of op, run +// Given a list of refined shapes matching the function arguments of func, runs // shape inference over the function to propagate this updated information. -LogicalResult InferShapeForFunction(FuncOp op, +LogicalResult InferShapeForFunction(FuncOp func, ArrayRef> arg_shapes, int64_t graph_version); +// Refines the return type of the given function by folding tf.Cast that +// precedes the return instruction. +LogicalResult InferShapeForFunctionType(FuncOp func); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index 2ef601e914d..637b14346b0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -65,7 +65,11 @@ struct ShapeInference : public ModulePass { return; } for (auto func : module.getOps()) { - TF::InferShapeUntilFixPoint(&func.getBody(), producer.getInt()); + InferShapeUntilFixPoint(&func.getBody(), producer.getInt()); + } + + if (auto main_func = module.lookupSymbol("main")) { + InferShapeForFunctionType(main_func); } } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index 6580ad53129..3fb311ff415 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -34,10 +34,12 @@ limitations under the License. 
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Identifier.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir @@ -57,8 +59,6 @@ constexpr char kTPUReplicateAttr[] = "_tpu_replicate"; constexpr char kDeviceAttr[] = "device"; constexpr char kNameAttr[] = "name"; constexpr char kNumReplicasAttr[] = "num_replicas"; -constexpr char kTPUReplicatedInputOp[] = "tf.TPUReplicatedInput"; -constexpr char kTPUReplicatedOutputOp[] = "tf.TPUReplicatedOutput"; constexpr char kBadTPUReplicateAttrMsg[] = "requires '_tpu_replicate' string attribute"; @@ -275,9 +275,8 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, mlir::visitUsedValuesDefinedAbove( launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) { Operation* def = operand->get()->getDefiningOp(); - if (def && def->getName().getStringRef() == kTPUReplicatedInputOp) { + if (def && llvm::isa(def)) replicated_input_ops.insert(def); - } }); // Check if number of operands of each used TPUReplicatedInput op matches @@ -305,10 +304,10 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, int idx = result_and_idx.index(); for (auto& use : result->getUses()) { Operation* def = use.getOwner(); - if (!def || def->getName().getStringRef() != kTPUReplicatedOutputOp) + if (!def || !llvm::isa(def)) return launch_op.emitError() << "requires output of " << launch_op.getOperationName() - << " to lead to a '" << kTPUReplicatedOutputOp << "' op"; + << " to lead to a 'tf.TPUReplicatedOutput' op"; if (def->getNumResults() != num_replicas) return def->emitOpError() << "requires " << num_replicas << " results"; @@ -331,9 +330,8 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, // Create terminator for replicate op and move launch into replicate. builder.setInsertionPointToEnd(&replicate_op.GetBody()); - auto return_op = builder.create( - replicate_op.getLoc(), - llvm::SmallVector(launch_op.getResults())); + auto return_op = builder.create(replicate_op.getLoc(), + launch_op.getResults()); launch_op.getOperation()->moveBefore(return_op); return success(); @@ -427,8 +425,8 @@ void TPUClusterFormation::runOnFunction() { // Remove TPUReplicatedInput and TPUReplicatedOutput nodes. auto remove_result = getFunction().walk([&](Operation* op) { - auto op_name = op->getName().getStringRef(); - if (op_name != kTPUReplicatedInputOp && op_name != kTPUReplicatedOutputOp) + if (!llvm::isa(op) && + !llvm::isa(op)) return WalkResult::advance(); // Forward operand to result. When `num_replicas` attribute is 1, no @@ -440,7 +438,8 @@ void TPUClusterFormation::runOnFunction() { // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of // `num_replicas` to 1. 
if (!op->use_empty()) { - op->emitOpError() << "expects " << op_name << " to have no uses"; + op->emitOpError() << "expects " << op->getName().getStringRef() + << " to have no uses"; return WalkResult::interrupt(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index c5bf918a496..1033670dd1c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -109,13 +109,13 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr)); module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr); - ModuleManager module_manager(module_for_func.get()); + SymbolTable symbol_table(module_for_func.get()); while (!referenced.empty()) { auto func = referenced.pop_back_val(); // Skip functions that have already been cloned into new module. - if (module_manager.lookupSymbol(func.getName())) continue; + if (symbol_table.lookup(func.getName())) continue; // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone // all found FuncOps to new_module to make sure new_module is @@ -138,7 +138,7 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, // should be no other reference to it. clone.setName("main"); } - module_manager.insert(clone); + symbol_table.insert(clone); } // Serialize module and return. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 22d04b27dd1..764c7915577 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -13,14 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir #include "mlir/Support/STLExtras.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" // This pass is used in preparation for Graph export. @@ -38,12 +43,11 @@ struct BreakUpIslands : OperationPass { void runOnOperation() final; void BreakUpIsland(tf_executor::IslandOp op, + const TF::SideEffectAnalysis& side_effect_analysis, llvm::DenseMap>* new_control_edges); }; -} // end anonymous namespace - void BreakUpIslands::runOnOperation() { auto graph_op_range = getOperation().getBody().front().without_terminator(); tf_executor::GraphOp graph_op; @@ -61,12 +65,13 @@ void BreakUpIslands::runOnOperation() { // Map from the users of the existing islands to the list of control // edges that need to be added. llvm::DenseMap> new_control_edges; + auto& side_effect_analysis = getAnalysis(); // Iterate in reverse order to avoid invalidating Operation* stored in // new_control_edges. 
for (auto& item : llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) { if (auto island = dyn_cast(&item)) { - BreakUpIsland(island, &new_control_edges); + BreakUpIsland(island, side_effect_analysis, &new_control_edges); } } OpBuilder builder(getOperation()); @@ -106,21 +111,81 @@ void BreakUpIslands::runOnOperation() { } } +// Helper that creates an island. If `sub_op` is not nullptr, it will be moved +// to the island. +tf_executor::IslandOp CreateIsland(ArrayRef result_types, + ArrayRef control_inputs, + const tf_executor::ControlType& control_type, + const Location& loc, Operation* sub_op, + tf_executor::IslandOp original_island) { + OpBuilder builder(original_island); + auto island = builder.create( + loc, result_types, control_type, control_inputs); + island.body().push_back(new Block); + Block* block = &island.body().back(); + if (sub_op) { + sub_op->replaceAllUsesWith(island.outputs()); + sub_op->moveBefore(block, block->begin()); + } + OpBuilder island_builder(original_island); + island_builder.setInsertionPointToEnd(block); + if (sub_op) { + island_builder.create(loc, sub_op->getResults()); + } else { + island_builder.create(loc, ArrayRef{}); + } + return island; +} + +// A struct containing the operations in an island that do not have incoming or +// outgoing dependencies. +struct IslandSourcesAndSinks { + // Sub-ops that do not depend on other ops in the island. + llvm::SmallPtrSet sources; + // Sub-ops that do not have other sub-ops in the island depending on them + // (excluding yield). + llvm::SmallPtrSet sinks; +}; + +// Finds IslandSourcesAndSinks for an unmodified island. +IslandSourcesAndSinks FindSourcesAndSinksInIsland( + tf_executor::IslandOp island, + const TF::SideEffectAnalysis& side_effect_analysis) { + IslandSourcesAndSinks result; + auto island_body = island.GetBody().without_terminator(); + for (Operation& sub_op : island_body) { + auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op); + result.sinks.insert(&sub_op); + // Remove predecessor from sinks. + for (auto predecessor : predecessors) result.sinks.erase(predecessor); + bool has_in_island_operands = false; + for (auto operand : sub_op.getOperands()) { + auto defining_op = operand->getDefiningOp(); + if (!defining_op || defining_op->getParentOp() != island) continue; + // Remove operands from sinks. + result.sinks.erase(defining_op); + has_in_island_operands = true; + } + if (predecessors.empty() && !has_in_island_operands) { + result.sources.insert(&sub_op); + } + } + return result; +} + // Converts a single island into multiple islands (one for each op). The islands // are chained together by control flow values. void BreakUpIslands::BreakUpIsland( tf_executor::IslandOp op, + const TF::SideEffectAnalysis& side_effect_analysis, llvm::DenseMap>* new_control_edges) { auto island_body = op.GetBody().without_terminator(); // Skip islands that are already only a single op. // Skip islands that are empty (only yield). if (island_body.empty() || has_single_element(island_body)) return; - OpBuilder builder(op); - OpBuilder island_builder(op); auto control_type = tf_executor::ControlType::get(&getContext()); - Value* previous_island = nullptr; - auto tmp_control_inputs = llvm::to_vector<4>(op.controlInputs()); + auto island_control_inputs = llvm::to_vector<4>(op.controlInputs()); // Add control dependencies for yields of values defined by other islands to // the island that defines that fetched value.
for (auto* fetch : op.GetYield().fetches()) { @@ -130,7 +195,7 @@ void BreakUpIslands::BreakUpIsland( // OK, because it is the same island. } else if (auto island_op = llvm::dyn_cast( fetch->getDefiningOp())) { - tmp_control_inputs.push_back(island_op.control()); + island_control_inputs.push_back(island_op.control()); } else { // TODO(parkers): Any defining op that has a control output can be handled // just like an island. @@ -138,39 +203,71 @@ void BreakUpIslands::BreakUpIsland( return signalPassFailure(); } } - ArrayRef previous_control = tmp_control_inputs; + // If there are multiple control inputs, create an empty island to group them. + if (island_control_inputs.size() > 1) { + auto island = CreateIsland({}, island_control_inputs, control_type, + op.getLoc(), nullptr, op); + island_control_inputs.clear(); + island_control_inputs.push_back(island.control()); + } + // Find sources and sinks inside the original island. + auto sources_and_sinks = + FindSourcesAndSinksInIsland(op, side_effect_analysis); + // The corresponding control output of the new island created for each sub-op. + llvm::SmallDenseMap new_control_for_sub_ops; + // Control outputs of newly created islands that are sinks. + llvm::SmallVector sink_island_controls; // For each operation in the island, construct a new island to wrap the op, // yield all the results, and replace all the usages with the results of the // new island. - for (Operation& sub_op : llvm::make_early_inc_range(island_body)) { - auto loc = sub_op.getLoc(); - auto island = builder.create( - loc, llvm::to_vector<4>(sub_op.getResultTypes()), control_type, - previous_control); - island.body().push_back(new Block); - Block* block = &island.body().back(); - sub_op.replaceAllUsesWith(island.outputs()); - block->getOperations().splice(block->begin(), op.GetBody().getOperations(), - sub_op); - island_builder.setInsertionPointToEnd(block); - island_builder.create( - loc, llvm::to_vector<4>(sub_op.getResults())); - previous_island = island.control(); - previous_control = previous_island; + for (auto& sub_op : llvm::make_early_inc_range(island_body)) { + const auto predecessors = + side_effect_analysis.DirectControlPredecessors(&sub_op); + // Get the controls from the predecessors. + llvm::SmallVector predecessors_control; + predecessors_control.reserve(predecessors.size()); + for (auto predecessor : predecessors) { + predecessors_control.push_back(new_control_for_sub_ops[predecessor]); + } + // If sub_op is a source, use island_control_inputs, because that's required + // by inter-islands dependencies; otherwise, we do not need to include + // island_control_inputs, since they must have been tracked by the (direct + // or indirect) control predecessors or operands. + ArrayRef control = sources_and_sinks.sources.count(&sub_op) > 0 + ? island_control_inputs + : predecessors_control; + auto island = + CreateIsland(llvm::to_vector<4>(sub_op.getResultTypes()), control, + control_type, sub_op.getLoc(), &sub_op, op); + new_control_for_sub_ops[&sub_op] = island.control(); + if (sources_and_sinks.sinks.count(&sub_op)) { + sink_island_controls.push_back(island.control()); + } } - op.control()->replaceAllUsesWith(previous_island); - // All existing outputs need to add a control flow edge to the - // previous_island. + // Create output controls for the sinks. + assert(!sink_island_controls.empty()); + // If there are multiple output controls, create an empty island to group + // them. 
+ if (sink_island_controls.size() > 1) { + auto island = CreateIsland({}, sink_island_controls, control_type, + op.getLoc(), nullptr, op); + sink_island_controls.clear(); + sink_island_controls.push_back(island.control()); + } + assert(sink_island_controls.size() == 1); + op.control()->replaceAllUsesWith(sink_island_controls[0]); + // All existing outputs need to add a control flow edge from + // sink_island_controls[0]. for (Value* out : op.outputs()) { for (auto& use : out->getUses()) { Operation* owner = use.getOwner(); if (auto island_op = llvm::dyn_cast(owner->getParentOp())) { - (*new_control_edges)[island_op].push_back(previous_island); + (*new_control_edges)[island_op].push_back(sink_island_controls[0]); } else if (llvm::isa(owner) || llvm::isa(owner) || llvm::isa(owner)) { - (*new_control_edges)[owner].push_back(previous_island); + (*new_control_edges)[owner].push_back(sink_island_controls[0]); } else { use.getOwner()->emitError("Adding control dependency not supported"); return signalPassFailure(); @@ -182,6 +279,8 @@ void BreakUpIslands::BreakUpIsland( op.erase(); } +} // namespace + std::unique_ptr> CreateBreakUpIslandsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index bac3ea22973..58242e62f1c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -535,6 +535,18 @@ StatusOr> Exporter::Convert( arg, index, graph_as_function && !input_names.empty() ? input_names[index] : "")); } + + auto convert_called_function = [&](llvm::StringRef name) { + auto func = + function.getParentOfType().lookupSymbol( + name); + if (func != nullptr) { + TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib)); + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib)); + } + return Status::OK(); + }; + // Adds nodes for operations. for (Operation& inst : block) { auto op_name = GetTensorFlowOpName(inst.getName().getStringRef()); @@ -544,13 +556,12 @@ StatusOr> Exporter::Convert( // definition library // TODO(prakalps): If two functions have cyclic dependence, this will // introduce an infinite loop. - auto func = - function.getParentOfType().lookupSymbol( - op_name.ValueOrDie()); - if (func != nullptr) { - TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib)); - TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib)); - } + TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str())); + } + + if (IsLegacyCallInstruction(&inst)) { + TF_RETURN_IF_ERROR(convert_called_function( + inst.getAttrOfType("f").getLeafReference())); } for (auto type : inst.getResultTypes()) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index da2e6a67445..7bc7c914f56 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -16,8 +16,11 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include +#include #include #include +#include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" @@ -97,6 +100,9 @@ using stream_executor::port::StatusOr; namespace { +const char* disable_call_shape_inference_attribute_name = + "_disable_call_shape_inference"; + // This class is used to generate new MLIR function name strings that are both // unique in the TF function library `flib_` and unique among the name strings // generated by the class object during its lifetime. @@ -246,11 +252,14 @@ class ImporterBase { llvm::SmallVector* attributes); // Helper to create either a tf_executor operation or a TF operation wrapped - // in an island. + // in an island. When convert_to_legacy_call is true, converts the operation + // representing a call to a library function with a name represented in + // node_type_name to LegacyCallOp. mlir::Operation* createOperation( - const Node& node, llvm::StringRef op_name, + const Node& node, llvm::StringRef node_type_name, const mlir::OperationState& result, - const llvm::SmallVectorImpl& control_operands); + const llvm::SmallVectorImpl& control_operands, + bool convert_to_legacy_call = false); // Converts one NodeDef from the input GraphDef into an Operation and // inserts it into the MLIR module using builder_. @@ -297,19 +306,24 @@ class ImporterBase { // Gets the location information string for the given node. std::string GetLocationStr(const Node& node, bool includeNodeName = false); - // Inserts a placeholder node in the graph to replace the input node. Replaces - // all the output edges of the input_node with the placeholder node, and - // removes the input_node from the graph. The new node has the same name as - // the input_node, so Nodespecs do not need any modification. + // Inserts a placeholder node in the graph to replace a feed output tensor, + // and returns the new placeholder node and a boolean indicating if the + // original input node was removed from the graph. Uses of the feed output + // tensor are replaced with this placeholder node. If the feed output tensor + // is of a single output node, the control dependencies are forwarded to + // the placeholder node, and the original node will be removed. // Note: This modifies the graph, and so any list of ordered nodes needs to be // reconstructed. - StatusOr ReplaceWithPlaceholderNode(const TensorShapeProto& shape, - DataType dtype, Node* input_node); + StatusOr> CreatePlaceholderNodeForFeed( + const TensorShapeProto& shape, DataType dtype, Node* node, int index, + const std::unordered_map& node_name_map); // Gets the input and output nodes corresponding to the specified input and // output nodes in specs_. If there are no input or output nodes specified, - // nodes will be empty - Status GetInputOutputNodes(std::unordered_set* nodes); + // nodes will be empty. + Status GetInputOutputNodes( + const std::unordered_map& node_name_map, + std::unordered_set* nodes); // The input graph with backedges removed. The removed backedges are stored // in the back_edge_helper. @@ -339,6 +353,10 @@ class ImporterBase { NodeValueMap node_values_; std::unique_ptr shape_refiner_; NameUniquifier* function_name_uniquifier_; + + protected: + // Maps feed as TensorId to new Placeholder node name.
+ absl::flat_hash_map remapped_feeds_; }; // Returns true if the node with given name has a non primary output that is @@ -419,6 +437,49 @@ Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { return Status::OK(); } +// Mapping from node name to feed (index and ArrayInfo). Node name must outlive +// this map. +using FeedsByNode = absl::flat_hash_map< + absl::string_view, + absl::flat_hash_map*>>; + +// Creates from a `GraphImportConfig::InputArrays` a mapping from a feeds output +// tensor name to index and ArrayInfo. Keys and values are backed by +// `GraphImportConfig::InputArrays`. +StatusOr GetFeedsByNode( + const GraphImportConfig::InputArrays& inputs) { + FeedsByNode feeds_by_node; + feeds_by_node.reserve(inputs.size()); + + for (const auto& input : inputs) { + TensorId tensor = ParseTensorName(input.first); + if (tensor.index() < 0) + return errors::FailedPrecondition( + "Feed output tensor must be a data output '", tensor.ToString(), "'"); + + auto& node = feeds_by_node[tensor.node()]; + if (!node.insert({tensor.index(), &input}).second) + return errors::FailedPrecondition( + "Multiple feeds for the same output tensor '", tensor.ToString(), + "'"); + } + + return feeds_by_node; +} + +// Creates a unique name for a node that will be replacing a feed output tensor. +std::string GetUniqueNodeName( + absl::string_view node_name, int index, + const std::unordered_map& node_name_map) { + std::string new_node_name_base = absl::StrCat(node_name, "_", index); + int count = 0; + std::string new_node_name = new_node_name_base; + while (node_name_map.find(new_node_name) != node_name_map.end()) { + new_node_name = absl::StrCat(new_node_name_base, "_", count++); + } + return new_node_name; +} + Status ImporterBase::RemoveBackedges(const Graph& graph) { // TODO(fengliuai): Converting to GraphDef and back is the easiest way to // clone a graph. @@ -459,37 +520,54 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) { return Status::OK(); } -StatusOr ImporterBase::ReplaceWithPlaceholderNode( - const TensorShapeProto& shape, DataType dtype, Node* input_node) { +StatusOr> ImporterBase::CreatePlaceholderNodeForFeed( + const TensorShapeProto& shape, DataType dtype, Node* node, int index, + const std::unordered_map& node_name_map) { + DCHECK_LT(index, node->num_outputs()); + const bool update_inplace = node->num_outputs() == 1 && index == 0; + std::string new_node_name = + update_inplace ? node->name() + : GetUniqueNodeName(node->name(), index, node_name_map); + Node* placeholder_node; - NodeBuilder builder(input_node->name(), "Placeholder"); + NodeBuilder builder(new_node_name, "Placeholder"); builder.Attr("shape", shape); builder.Attr("dtype", dtype); TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node)); - while (!input_node->out_edges().empty()) { - const Edge* oe = *input_node->out_edges().begin(); - // UpdateEdge cannot be used with control edges. - if (oe->src_output() == Graph::kControlSlot) { - graph_->AddControlEdge(placeholder_node, oe->dst()); - graph_->RemoveControlEdge(oe); - continue; + // Update edges from original feed with Placeholder node. 
+ std::vector data_edges; + std::vector control_edges; + for (const tensorflow::Edge* edge : node->out_edges()) { + if (edge->src_output() == index) { + data_edges.push_back(edge); + } else if (update_inplace && edge->IsControlEdge()) { + control_edges.push_back(edge); } - - TF_RETURN_IF_ERROR( - graph_->UpdateEdge(placeholder_node, 0, oe->dst(), oe->dst_input())); } - graph_->RemoveNode(input_node); + for (const auto* edge : data_edges) { + TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(), + edge->dst_input())); + } - return placeholder_node; + for (const auto* edge : control_edges) { + graph_->AddControlEdge(placeholder_node, edge->dst()); + graph_->RemoveControlEdge(edge); + } + + if (update_inplace) { + graph_->RemoveNode(node); + } + + return std::pair(placeholder_node, update_inplace); } Status ImporterBase::GetInputOutputNodes( + const std::unordered_map& node_name_map, std::unordered_set* nodes) { - auto node_name_map = graph_->BuildNodeNameIndex(); - auto add_node = [&](const string& name) { - auto it = node_name_map.find(name); + auto add_node = [&](absl::string_view name) { + auto it = node_name_map.find(std::string(name)); if (it == node_name_map.end()) { return errors::FailedPrecondition( absl::StrCat("Graph does not contain node: ", name)); @@ -498,13 +576,25 @@ Status ImporterBase::GetInputOutputNodes( return Status::OK(); }; + // Remap feeds and fetches to newly created Placeholder nodes. for (const auto& input : specs_.inputs) { - TF_RETURN_IF_ERROR(add_node(input.first)); + TensorId tensor = ParseTensorName(input.first); + auto remapped_it = remapped_feeds_.find(tensor); + if (remapped_it != remapped_feeds_.end()) { + TF_RETURN_IF_ERROR(add_node(remapped_it->second)); + } else { + TF_RETURN_IF_ERROR(add_node(tensor.node())); + } } for (const auto& output : specs_.outputs) { - auto output_node_name = std::string(ParseTensorName(output).first); - TF_RETURN_IF_ERROR(add_node(output_node_name)); + TensorId tensor = ParseTensorName(output); + auto remapped_it = remapped_feeds_.find(tensor); + if (remapped_it != remapped_feeds_.end()) { + TF_RETURN_IF_ERROR(add_node(remapped_it->second)); + } else { + TF_RETURN_IF_ERROR(add_node(tensor.node())); + } } return Status::OK(); @@ -520,6 +610,9 @@ Status ImporterBase::AddNodesToShapeRefiner() { shape_refiner_->set_require_shape_inference_fns(false); shape_refiner_->set_function_library_for_shape_inference(&graph_flib_); + TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); + auto node_name_map = graph_->BuildNodeNameIndex(); + // First add all nodes to the refiner. for (Node* node : ordered_nodes_) { // We need to use a TensorFlow node to teach the shape refiner that user @@ -533,28 +626,49 @@ Status ImporterBase::AddNodesToShapeRefiner() { // it to replace the original input node, so the shape refiner can // successfully propagate the user's input type and shape to the rest of the // graph. - auto it = specs_.inputs.find(node->name()); - if (it != specs_.inputs.end()) { - auto node_name = node->op_def().name(); - if (node_name != "Placeholder" && node_name != "LegacyFedInput" && - node_name != FunctionLibraryDefinition::kArgOp) { - // We do not handle the case where the input node has multiple outputs - if (node->num_outputs() > 1) { - return errors::FailedPrecondition(absl::StrCat( - "Input arrays can only have op with single output. 
Node op:", - node_name)); + bool node_added_to_shape_refiner = false; + auto it = feeds_by_node.find(node->name()); + if (it != feeds_by_node.end()) { + auto op_name = node->op_def().name(); + if (op_name != "Placeholder" && op_name != "LegacyFedInput" && + op_name != FunctionLibraryDefinition::kArgOp) { + for (const auto& output_tensor : it->second) { + const int index = output_tensor.first; + const ArrayInfo& array_info = output_tensor.second->second; + + DataType dtype = array_info.imported_dtype; + // Uses the existing output type if it isn't specified by the user. + if (dtype == DT_INVALID) { + dtype = node->output_type(0); + } + + TF_ASSIGN_OR_RETURN( + auto placeholder_node_and_removed, + CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, + node_name_map)); + + Node* placeholder_node = placeholder_node_and_removed.first; + if (placeholder_node_and_removed.second) { + // Original node has been removed from the graph. + node = placeholder_node; + node_added_to_shape_refiner = true; + } + remapped_feeds_[{it->first, index}] = placeholder_node->name(); + node_name_map[placeholder_node->name()] = placeholder_node; + // Add the new placeholder node to the shape refiner. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + shape_refiner_->AddNode(placeholder_node), + GetLocationStr(*placeholder_node)); } - // For single output nodes, replace them with Placeholder node. - DataType dtype = it->second.imported_dtype; - // Uses the existing output type if it isn't specified by the user. - if (dtype == DT_INVALID) { - dtype = node->output_type(0); - } - TF_ASSIGN_OR_RETURN( - node, ReplaceWithPlaceholderNode(it->second.shape, dtype, node)); } else { - node->AddAttr("shape", it->second.shape); - DataType dtype = it->second.imported_dtype; + auto index_it = it->second.find(0); + if (index_it == it->second.end()) { + return errors::FailedPrecondition( + "Missing feed output tensor at index 0 for node '", node->name(), + "'"); + } + node->AddAttr("shape", index_it->second->second.shape); + DataType dtype = index_it->second->second.imported_dtype; // Uses the existing output type if it isn't specified by the user. if (dtype == DT_INVALID) { dtype = node->output_type(0); @@ -562,9 +676,11 @@ Status ImporterBase::AddNodesToShapeRefiner() { node->AddAttr("dtype", dtype); } } - // Adds the node to the shape refiner. - TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node), - GetLocationStr(*node)); + if (!node_added_to_shape_refiner) { + // Add the node to the shape refiner if the node hasn't been removed. + TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node), + GetLocationStr(*node)); + } auto set_shape_from_list_attr = [&](const AttrValue* attr) { auto& list = attr->list(); @@ -625,7 +741,7 @@ Status ImporterBase::AddNodesToShapeRefiner() { // Prune nodes in the graph that are not reachable from the output. 
if (specs_.prune_unused_nodes) { std::unordered_set prune_start; - TF_RETURN_IF_ERROR(GetInputOutputNodes(&prune_start)); + TF_RETURN_IF_ERROR(GetInputOutputNodes(node_name_map, &prune_start)); if (!prune_start.empty()) { if (PruneForReverseReachability(graph_.get(), prune_start)) { VLOG(1) << "Pruned unused nodes in graphdef"; @@ -829,9 +945,11 @@ StatusOr ImporterBase::ConvertAttributeValue( return builder_.getFloatAttr(builder_.getF32Type(), value.f()); case AttrValue::kB: return builder_.getBoolAttr(value.b()); - case AttrValue::kType: - return builder_.getStringAttr( - mangling_util::MangleDataType(value.type())); + case AttrValue::kType: { + mlir::Type type; + TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type)); + return mlir::TypeAttr::get(type); + } case AttrValue::kShape: return builder_.getStringAttr(mangling_util::MangleShape(value.shape())); case AttrValue::kTensor: @@ -1106,11 +1224,9 @@ Status ImporterBase::ConvertFunctionArgAndRets( builder_.setInsertionPointToEnd(&graph_op.body().front()); builder_.create(graph_op.getLoc(), inst_to_return); - inst_to_return.assign(graph_op.getResults().begin(), - graph_op.getResults().end()); builder_.setInsertionPointToEnd(bb); builder_.create(mlir::UnknownLoc::get(context_), - inst_to_return); + graph_op.getResults()); return Status::OK(); } @@ -1210,9 +1326,10 @@ std::string ImporterBase::GetLocationStr(const Node& node, } mlir::Operation* ImporterBase::createOperation( - const Node& node, llvm::StringRef op_name, + const Node& node, llvm::StringRef node_type_name, const mlir::OperationState& result, - const llvm::SmallVectorImpl& control_operands) { + const llvm::SmallVectorImpl& control_operands, + bool convert_to_legacy_call) { // For the tf.executor specific operations (not wrapped in an island), we // have an extra returned value for the control result, and we concatenate // control and non-control operands. @@ -1274,11 +1391,31 @@ mlir::Operation* ImporterBase::createOperation( mlir::OpBuilder island_builder(&island.GetBody()); // Create the operation inside the island now. - mlir::Operation* inner_op = island_builder.createOperation(result); + mlir::Operation* inner_op; + if (convert_to_legacy_call) { + bool disable_call_shape_inference = false; + for (const auto& name_and_value : node.attrs()) { + const auto& attr_name = name_and_value.first; + const AttrValue& attr_value = name_and_value.second; + if (strcmp(attr_name.c_str(), + disable_call_shape_inference_attribute_name) == 0 && + attr_value.value_case() == AttrValue::kB) { + disable_call_shape_inference = attr_value.b(); + } + } + + mlir::BoolAttr attribute = + builder_.getBoolAttr(disable_call_shape_inference); + inner_op = island_builder.create( + result.location, result.types, result.operands, + island_builder.getSymbolRefAttr(node_type_name), attribute); + } else { + inner_op = island_builder.createOperation(result); + } // Add the terminator for the island - mlir::SmallVector ret_vals(inner_op->getResults()); - island_builder.create(result.location, ret_vals); + island_builder.create(result.location, + inner_op->getResults()); return island.getOperation(); } @@ -1293,9 +1430,11 @@ Status ImporterBase::ConvertNode(const Node& node) { // create the MLIR function and insert it to the module if it doesn't exist. 
std::string node_type_name = node.type_string(); const auto* func_def = graph_flib_.Find(node_type_name); + bool convert_to_legacy_call = false; if (func_def) { TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name)); node_type_name = (*tf_name_to_mlir_name_)[node_type_name]; + convert_to_legacy_call = true; } auto get_full_op_name = [&](const std::string& op_name) { @@ -1380,6 +1519,14 @@ Status ImporterBase::ConvertNode(const Node& node) { for (const auto& name_and_value : node.attrs()) { const auto& attr_name = name_and_value.first; const AttrValue& attr_value = name_and_value.second; + // LegacyCall can only represent the _disable_call_shape_inference attribute. + // If a call has other attributes, it can't be converted to LegacyCall. + if (convert_to_legacy_call && + (strcmp(attr_name.c_str(), + disable_call_shape_inference_attribute_name) || + attr_value.value_case() != AttrValue::kB)) { + convert_to_legacy_call = false; + } if (attr_value.value_case() == AttrValue::kFunc) { // Attribute iteration order is not defined for protocol buffer Map. // Process function attributes separately in the lexicographical order to @@ -1423,9 +1570,8 @@ Status ImporterBase::ConvertNode(const Node& node) { } // Register the mapping between the TF node and the newly created operation. - node_values_[node.id()] = - createOperation(node, op_name, result, control_operands); - + node_values_[node.id()] = createOperation( + node, node_type_name, result, control_operands, convert_to_legacy_call); return Status::OK(); } @@ -1667,36 +1813,52 @@ StatusOr GraphDefImporter::InferMainFunctionType( const GraphImportConfig& specs, mlir::MLIRContext* context, absl::InlinedVector* arg_nodes, absl::InlinedVector* ret_nodes) { - // Finds out all the input nodes and output nodes. - absl::flat_hash_set output_node_names; - for (const auto& output_tensor : specs.outputs) { - output_node_names.insert(ParseTensorName(output_tensor).node()); + // Find all the input nodes and output nodes. + // Feeds have been remapped to single output nodes (Placeholder), so an exact + // name match is sufficient. + absl::flat_hash_map inputs; + for (auto input_and_idx : llvm::enumerate(specs.inputs)) { + TensorId tensor = ParseTensorName(input_and_idx.value().first); + auto remapped_it = remapped_feeds_.find(tensor); + if (remapped_it != remapped_feeds_.end()) { + inputs.insert({remapped_it->second, input_and_idx.index()}); + } else { + inputs.insert({tensor.node(), input_and_idx.index()}); + } } - if (!specs.inputs.empty() || !specs.outputs.empty()) { - arg_nodes->resize(specs.inputs.size()); - ret_nodes->resize(specs.outputs.size()); + + absl::flat_hash_set output_node_names; + std::vector outputs; + output_node_names.reserve(specs.outputs.size()); + for (const auto& output : specs.outputs) { + TensorId tensor = ParseTensorName(output); + auto remapped_it = remapped_feeds_.find(tensor); + if (remapped_it != remapped_feeds_.end()) { + output_node_names.insert(remapped_it->second); + outputs.push_back({remapped_it->second, 0}); + } else { + output_node_names.insert(tensor.node()); + outputs.push_back(tensor); + } + } + + if (!inputs.empty() || !outputs.empty()) { + arg_nodes->resize(inputs.size()); + ret_nodes->resize(outputs.size()); + for (Node* n : GetOrderedNodes()) { // Handle inputs/arguments.
- auto input_it = specs.inputs.find(n->name()); - if (input_it != specs.inputs.end()) { - (*arg_nodes)[std::distance(specs.inputs.begin(), input_it)] = {n, 0}; + auto input_it = inputs.find(n->name()); + if (input_it != inputs.end()) { + (*arg_nodes)[input_it->second] = {n, 0}; } // Handle outputs/returns. if (output_node_names.contains(n->name())) { - for (int i = 0, e = specs.outputs.size(); i != e; ++i) { - std::pair name_and_port = - absl::StrSplit(specs.outputs[i], ':'); - auto name = name_and_port.first; - if (name != n->name()) continue; - int port = 0; - if (!name_and_port.second.empty() && - !absl::SimpleAtoi(name_and_port.second, &port)) { - return errors::InvalidArgument("Invalid port specification: ", - specs.outputs[i]); - } - (*ret_nodes)[i] = {n, port}; + for (int i = 0, e = outputs.size(); i != e; ++i) { + TensorId tensor = outputs[i]; + if (n->name() != tensor.node()) continue; + (*ret_nodes)[i] = {n, tensor.index()}; } } } @@ -2118,7 +2280,11 @@ class StructuredValueLinearizer { // Returns the list of index paths to each leaf of the StructuredValue, // in a linearized order matching `tf.nest.flatten`. - llvm::ArrayRef GetLeafIndexPaths() const; + // + // If an error occurred during the linearization process, an error message with + // `error_context` prepended will be included in the returned status. + StatusOr> GetLeafIndexPaths( + llvm::StringRef error_context) const; private: // Main function that recursively traverses the StructuredValue. @@ -2130,6 +2296,8 @@ class StructuredValueLinearizer { llvm::SmallVector current_index_path_; // The list of leaf index paths we have discovered so far. llvm::SmallVector leaf_index_paths_; + // If non-empty, an error message to report. + std::string error_message_; }; StructuredValueLinearizer::StructuredValueLinearizer( @@ -2138,9 +2306,19 @@ StructuredValueLinearizer::StructuredValueLinearizer( RecursivelyFindLeaves(value); } -llvm::ArrayRef StructuredValueLinearizer::GetLeafIndexPaths() - const { - return leaf_index_paths_; +StatusOr> +StructuredValueLinearizer::GetLeafIndexPaths( + llvm::StringRef error_context) const { + if (error_message_.empty()) { + return llvm::makeArrayRef(leaf_index_paths_); + } + return errors::InvalidArgument( + error_context.str(), error_message_, + "This likely means that you have @tf.function " + "on an exported function instead of " + "@tf.function(input_signature=[...]). Consider annotating an " + "input_signature or narrowing your set of " + "exported names to not include this function."); } void StructuredValueLinearizer::RecursivelyFindLeaves( @@ -2196,7 +2374,20 @@ void StructuredValueLinearizer::RecursivelyFindLeaves( return; } default: { - llvm_unreachable("Unhandled StructuredValue kind!"); + llvm::raw_string_ostream os(error_message_); + // TODO(silvasean): Use an enumerant name string instead of a number.
+ os << "Unhandled structured value kind " << value.kind_case() + << " at index path: "; + for (auto path_element : current_index_path_) { + os << "."; + if (auto integer = path_element.dyn_cast()) { + os << integer.getValue(); + } else { + auto str = path_element.cast(); + os << str.getValue(); + } + } + os << "\n"; } } } @@ -2290,6 +2481,9 @@ Status CreateSavedModelIR( if (object_names.GetExportedNames(node_id).empty()) { continue; } + std::string error_context = + "While importing SavedModel function '" + + object_names.GetExportedNames(node_id)[0].str() + "': "; const SavedFunction& function = object.function(); auto orig_func = symbol_table.lookup( tf_name_to_mlir_name.find(function.concrete_functions(0))->second); @@ -2314,8 +2508,7 @@ Status CreateSavedModelIR( /*config=*/builder.getStringAttr(""), /*config_proto=*/builder.getStringAttr(""), /*executor_type=*/builder.getStringAttr("")); - body_builder.create( - func.getLoc(), llvm::to_vector<4>(call.getResults())); + body_builder.create(func.getLoc(), call.getResults()); } func.setAttr( "tf_saved_model.exported_names", @@ -2338,9 +2531,12 @@ Status CreateSavedModelIR( int bound_input_base = func.getNumArguments() - concrete_function.bound_inputs_size(); - auto input_index_paths = input_linearizer.GetLeafIndexPaths(); + TF_ASSIGN_OR_RETURN(auto input_index_paths, + input_linearizer.GetLeafIndexPaths( + error_context + "in input signature: ")); if (bound_input_base != input_index_paths.size()) { return errors::InvalidArgument( + error_context, "Argument mismatch between concrete function input signature " "vs underlying FunctionDef for concrete function '", function.concrete_functions(0), "' (", input_index_paths.size(), @@ -2361,9 +2557,12 @@ Status CreateSavedModelIR( StructuredValueLinearizer output_linearizer( concrete_function.output_signature(), builder.getContext()); - auto output_index_paths = output_linearizer.GetLeafIndexPaths(); + TF_ASSIGN_OR_RETURN(auto output_index_paths, + output_linearizer.GetLeafIndexPaths( + error_context + "in output signature: ")); if (func.getNumResults() != output_index_paths.size()) { return errors::InvalidArgument( + error_context, "Result mismatch between concrete function output signature " "vs underlying FunctionDef for concrete function '", function.concrete_functions(0), "' (", output_index_paths.size(), diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc index ff397e4b456..86fbff91db1 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc @@ -67,8 +67,6 @@ void FunctionalToExecutorDialectConversion::runOnFunction() { LLVM_DEBUG(llvm::dbgs() << "Expect function to end with return\n"); return; } - llvm::SmallVector args = - llvm::to_vector<4>(return_op.getOperands()); // Build GraphOp. OpBuilder builder(&body, body.begin()); auto graph_op = builder.create( @@ -79,10 +77,10 @@ void FunctionalToExecutorDialectConversion::runOnFunction() { loc, getFunction().getType().getResults(), tf_executor::ControlType::get(&getContext()), ArrayRef()); // Create Fetch. - auto to_fetch = llvm::to_vector<4>(island.getResults()); + ValueRange to_fetch = island.getResults(); if (to_fetch.size() != 1) { // Drop control result for fetch. - to_fetch.pop_back(); + to_fetch = to_fetch.drop_back(); } builder.create(loc, to_fetch); // Build Island. 
@@ -91,7 +89,7 @@ void FunctionalToExecutorDialectConversion::runOnFunction() { island.body().front().begin(), body.getOperations(), copy_range.begin(), copy_range.end()); builder.setInsertionPointToEnd(&island.body().front()); - builder.create(loc, args); + builder.create(loc, return_op.getOperands()); for (auto item : llvm::enumerate(graph_op.getResults())) { return_op.setOperand(item.index(), item.value()); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index a5839cf7645..dc9ec6aa8ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" -#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir @@ -58,7 +58,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string, // Converts arg_shapes to xla::Shape's and store into xla_input_shapes. Status GetXlaInputShapes( - mlir::ModuleOp module, absl::Span arg_shapes, + mlir::ModuleOp module, llvm::ArrayRef arg_shapes, const xla::CustomShapeRepresentationFn shape_representation_fn, std::vector* xla_input_shapes) { xla_input_shapes->clear(); @@ -150,7 +150,8 @@ void GetInputMappingForMlir(int num_inputs, std::vector* input_mapping) { } // Refine MLIR types based on new shape information. -Status RefineShapes(absl::Span arg_shapes, mlir::ModuleOp module) { +Status RefineShapes(llvm::ArrayRef arg_shapes, + mlir::ModuleOp module) { auto versions = module.getAttrOfType<::mlir::DictionaryAttr>("tf.versions"); if (!versions) { return errors::Internal( @@ -234,7 +235,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, } Status CompileSerializedMlirToXlaHlo( - llvm::StringRef mlir_module_string, absl::Span arg_shapes, + llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { mlir::MLIRContext mlir_context; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 635c1d67f82..a07927ce432 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ -#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Module.h" // TF:local_config_mlir #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -40,7 +40,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. 
Status CompileSerializedMlirToXlaHlo( - llvm::StringRef mlir_module_string, absl::Span arg_shapes, + llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 3574b336f9a..1668cf615f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -41,9 +41,9 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { std::vector arg_shapes; XlaCompiler::CompilationResult compilation_result; - Status s = CompileSerializedMlirToXlaHlo( - invalid_mlir_module, absl::Span(arg_shapes), - TestShapeRepresentation, &compilation_result); + Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes, + TestShapeRepresentation, + &compilation_result); EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT); } @@ -61,8 +61,7 @@ TEST(CompileSerializedMlirToXlaHloTest, Success) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, absl::Span(arg_shapes), TestShapeRepresentation, - &compilation_result); + mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result); ASSERT_TRUE(s.ok()); const xla::HloModuleConfig module_config( @@ -134,8 +133,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, absl::Span(arg_shapes), TestShapeRepresentation, - &compilation_result); + mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result); ASSERT_TRUE(s.ok()); const xla::HloModuleConfig module_config( @@ -174,8 +172,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, absl::Span(arg_shapes), TestShapeRepresentation, - &compilation_result); + mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); const xla::HloModuleConfig module_config( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index 198d04e0486..a60d90cbfb7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -22,13 +22,11 @@ limitations under the License. #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/stream_executor/lib/statusor.h" // Error utilities for MLIR when interacting with code using Status returns. namespace mlir { // TensorFlow's Status is used for error reporting back to callers. -using stream_executor::port::StatusOr; using tensorflow::Status; // Diagnostic handler that collects all the diagnostics reported and can produce diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 69b309f0632..e35b7130de8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -253,21 +254,30 @@ StatusOr> GetOperationNodeDef( // Note: we do not use NodeBuilder or NodeDefBuilder as that would require // mapping back from the inputs to the input arguments. - // Some control flow ops in TensorFlow Graph have their respective "Ref" ops - // as well. For example there is Enter and RefEnter op. RefEnter forwards - // the input ref buffer to output. However both Enter and RefEnter are - // mapped to tf_executor::EnterOp during import and then to _tf.Enter op in - // control dialect. Check if it is a Ref op to correctly map to the TensorFlow - // Graph op. llvm::SmallString<64> op_name; - if (IsRefTypeControlOp(inst)) op_name = "Ref"; - - TF_ASSIGN_OR_RETURN(auto tf_name, - GetTensorFlowOpName(inst->getName().getStringRef())); - op_name.append(tf_name); + if (IsLegacyCallInstruction(inst)) { + // The op_name is the name of the function. + op_name.append( + inst->getAttrOfType("f").getLeafReference()); + // Remove the attribute from the instruction as it is already converted to + // op_name. + auto attr_id = mlir::Identifier::get("f", inst->getContext()); + inst->removeAttr(attr_id); + } else { + // Some control flow ops in TensorFlow Graph have their respective "Ref" ops + // as well. For example there is Enter and RefEnter op. RefEnter forwards + // the input ref buffer to output. However both Enter and RefEnter are + // mapped to tf_executor::EnterOp during import and then to _tf.Enter op in + // control dialect. Check if it is a Ref op to correctly map to the + // TensorFlow Graph op. + if (IsRefTypeControlOp(inst)) op_name = "Ref"; + TF_ASSIGN_OR_RETURN(auto tf_name, + GetTensorFlowOpName(inst->getName().getStringRef())); + op_name.append(tf_name); + } + node_def->set_name(name.str()); node_def->set_op(op_name.str()); - node_def->set_name(name); // Add inputs to the NodeDef based on the number of operands. This is required // as later when edges are added to the Node using Graph::AddEdge the @@ -454,4 +464,9 @@ Status SetSizeAttribute(absl::string_view name, size_t size, return Status::OK(); } +bool IsLegacyCallInstruction(mlir::Operation* inst) { + return llvm::dyn_cast(inst) || + inst->getName().getStringRef().compare("_tf.LegacyCall") == 0; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 8d813b53bd8..df176762c07 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -73,5 +73,16 @@ Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shape, // If the attribute already exists with a different value, returns an error. Status SetSizeAttribute(absl::string_view name, size_t size, AttrValueMap* values); + +// Returns true if the given instruction is an mlir::TF::LegacyCallOp or the +// result of such an operation transformed by the +// ExecutorToControlDialectConversion pass. 
+// +// TODO(b/145706023): When the ExecutorToControlDialectConversion pass runs +// before the exporter, it mutates an mlir::TF::LegacyCallOp instruction to +// an instruction with a different operation name. As such, this routine checks +// both forms of a LegacyCall instruction. We only need to check for +// mlir::TF::LegacyCallOp when the ticket is resolved. +bool IsLegacyCallInstruction(mlir::Operation* inst); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EXPORTER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index ac3475cebc4..bf71bcda776 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -23,6 +23,8 @@ package_group( ], ) +exports_files(["ir/hlo_ops.td"]) + filegroup( name = "hlo_ops_td_files", srcs = [ @@ -406,6 +408,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/stream_executor/lib", "@llvm//:support", "@local_config_mlir//:Analysis", "@local_config_mlir//:IR", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 7c95a13285b..1da4fd04ffb 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -182,13 +182,12 @@ tensorflow::Status HloFunctionImporter::ImportInstructions( // Setup the return type (HLO only supports a single return value). TF_ASSIGN_OR_RETURN(auto result, GetMlirValue(computation->root_instruction())); - llvm::SmallVector return_values({result}); // Create terminator op depending on the parent op of this region. if (llvm::isa(block->getParentOp())) { - builder.create(loc, makeArrayRef(return_values)); + builder.create(loc, result); } else { - builder.create(loc, makeArrayRef(return_values)); + builder.create(loc, result); } return tensorflow::Status::OK(); } @@ -266,32 +265,20 @@ StatusOr HloFunctionImporter::ImportInstruction( MakeAndReturn(CompareOp); } case HloOpcode::kGather: { - const auto& gather_dimensions = instruction->gather_dimension_numbers(); - std::vector offset_dims(gather_dimensions.offset_dims().begin(), - gather_dimensions.offset_dims().end()); + auto gather_instruction = static_cast(instruction); + attributes.push_back(ConvertGatherDimensionNumbers( + gather_instruction->gather_dimension_numbers())); std::vector slice_sizes( - instruction->gather_slice_sizes().begin(), - instruction->gather_slice_sizes().end()); + gather_instruction->gather_slice_sizes().begin(), + gather_instruction->gather_slice_sizes().end()); + attributes.push_back( + builder_->getNamedAttr("slice_sizes", Convert(slice_sizes))); + attributes.push_back(builder_->getNamedAttr( + "indices_are_sorted", + builder_->getBoolAttr(gather_instruction->indices_are_sorted()))); - std::vector collapsed_slice_dims( - gather_dimensions.collapsed_slice_dims().begin(), - gather_dimensions.collapsed_slice_dims().end()); - - std::vector start_index_map( - gather_dimensions.start_index_map().begin(), - gather_dimensions.start_index_map().end()); - - // TODO(b/132057942): Change to explicitly passing an integer instead of - // call getI64IntegerAttr here. 
- return func_builder - ->create( - loc, result_type, operands[0], operands[1], - func_builder->getI64IntegerAttr( - gather_dimensions.index_vector_dim()), - Convert(offset_dims), Convert(slice_sizes), - Convert(collapsed_slice_dims), Convert(start_index_map)) - .getOperation(); + MakeAndReturn(GatherOp); } case HloOpcode::kDynamicUpdateSlice: { return func_builder @@ -707,4 +694,19 @@ mlir::NamedAttribute HloFunctionImporter::ConvertConvDimensionNumbers( return builder_->getNamedAttr("dimension_numbers", attr); } +mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers( + const xla::GatherDimensionNumbers& dnums) { + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + std::vector collapsed_slice_dims( + dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + auto attr = mlir::xla_hlo::GatherDimensionNumbers::get( + Convert(offset_dims), Convert(collapsed_slice_dims), + Convert(start_index_map), + builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_); + return builder_->getNamedAttr("dimension_numbers", attr); +} + } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index 11a6b1c7dd5..bd36c9b2b54 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -117,6 +117,10 @@ class HloFunctionImporter { mlir::NamedAttribute ConvertConvDimensionNumbers( const xla::ConvolutionDimensionNumbers& dnums); + // Converts the gather dimensions to attributes. + mlir::NamedAttribute ConvertGatherDimensionNumbers( + const xla::GatherDimensionNumbers& dnums); + mlir::MLIRContext* context_; mlir::ModuleOp module_; mlir::Builder* builder_; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index b2f02bdf76f..08967372bcb 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -606,7 +606,7 @@ static TensorType GetReduceResultType(Type operand_ty, } void ReduceOp::build(Builder* builder, OperationState& state, - ArrayRef operands, ArrayRef init_values, + ValueRange operands, ValueRange init_values, DenseIntElementsAttr dimensions) { SmallVector result_ty; result_ty.reserve(operands.size()); @@ -845,9 +845,8 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand, // SortOp //===----------------------------------------------------------------------===// -void SortOp::build(Builder* builder, OperationState& state, - ArrayRef operands, int64_t dimension, - bool is_stable) { +void SortOp::build(Builder* builder, OperationState& state, ValueRange operands, + int64_t dimension, bool is_stable) { state.addOperands(operands); state.addAttribute("dimension", builder->getI64IntegerAttr(dimension)); state.addAttribute("is_stable", builder->getBoolAttr(dimension)); @@ -990,7 +989,7 @@ void GetTupleElementOp::build(Builder* builder, OperationState& result, //===----------------------------------------------------------------------===// void TupleOp::build(Builder* builder, OperationState& result, - ArrayRef values) { + ValueRange values) { SmallVector types; types.reserve(values.size()); for (auto val : values) { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index c9b3e7985fc..3c4fd473eb6 100644 --- 
a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -405,8 +405,8 @@ def HLO_ReduceOp: HLO_Op<"reduce", [ let results = (outs Variadic); let builders = [OpBuilder< - "Builder *, OperationState &state, ArrayRef operands, " - "ArrayRef init_values, DenseIntElementsAttr dimensions" + "Builder *, OperationState &state, ValueRange operands, " + "ValueRange init_values, DenseIntElementsAttr dimensions" >]; let hasFolder = 1; @@ -445,7 +445,7 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { let builders = [OpBuilder< "Builder *builder, OperationState &results, " - "ArrayRef values">]; + "ValueRange values">]; // TupleOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; @@ -777,21 +777,25 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { let hasCustomHLOConverter = 1; } +def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, + [StructFieldAttr<"offset_dims", I64ElementsAttr>, + StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, + StructFieldAttr<"start_index_map", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for gather"; +} + def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { let arguments = (ins HLO_Tensor:$operand, HLO_IntTensor:$start_indices, - I64Attr:$index_vector_dim, - I64ElementsAttr:$offset_dims, + GatherDimensionNumbers:$dimension_numbers, I64ElementsAttr:$slice_sizes, - I64ElementsAttr:$collapsed_slice_dims, - I64ElementsAttr:$start_index_map + DefaultValuedAttr:$indices_are_sorted ); let results = (outs HLO_Tensor); - // TODO(b/129422361) Attributes are not supported by the codegen. The - // optional argument (dimensions) needs to be added as an attribute. let hasCustomHLOConverter = 1; } @@ -880,7 +884,7 @@ def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp { let regions = (region SizedRegion<1>:$comparator); let builders = [OpBuilder< - "Builder *builder, OperationState &state, ArrayRef operands, " + "Builder *builder, OperationState &state, ValueRange operands, " "int64_t dimension, bool is_stable" >]; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index e9bf3bac44b..26cd512aa85 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -40,7 +40,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" +using ::stream_executor::port::StatusOr; using ::tensorflow::int16; using ::tensorflow::int32; using ::tensorflow::int64; @@ -149,6 +151,7 @@ I64_ELEMENTS_ATTR_TO_VECTOR(permutation); I64_ELEMENTS_ATTR_TO_VECTOR(start_indices); I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices); I64_ELEMENTS_ATTR_TO_VECTOR(strides); +I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes); #undef I64_ELEMENTS_ATTR_TO_VECTOR @@ -267,6 +270,30 @@ static xla::ComparisonDirection Convert_comparison_direction( .ValueOrDie(); } +static xla::GatherDimensionNumbers Convert_gather_dimension_numbers( + mlir::xla_hlo::GatherDimensionNumbers input) { + xla::GatherDimensionNumbers output; + + auto offset_dims = ConvertDenseIntAttr(input.offset_dims()); + std::copy(offset_dims.begin(), offset_dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + output.mutable_offset_dims())); + + auto collapsed_slice_dims = ConvertDenseIntAttr(input.collapsed_slice_dims()); + std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + output.mutable_collapsed_slice_dims())); + + auto start_index_map = ConvertDenseIntAttr(input.start_index_map()); + std::copy(start_index_map.begin(), start_index_map.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + output.mutable_start_index_map())); + + output.set_index_vector_dim( + ConvertAPInt(input.index_vector_dim().getValue())); + return output; +} + static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( mlir::xla_hlo::ScatterDimensionNumbers input) { xla::ScatterDimensionNumbers output; @@ -496,7 +523,13 @@ LogicalResult ExportXlaOp(DynamicUpdateSliceOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(FftOp op, OpLoweringContext ctx) { return failure(); } LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) { - return failure(); + auto& value_map = *ctx.values; + xla::GatherDimensionNumbers dimension_numbers = + Convert_gather_dimension_numbers(op.dimension_numbers()); + value_map[op] = xla::Gather( + value_map[op.operand()], value_map[op.start_indices()], dimension_numbers, + Convert_slice_sizes(op.slice_sizes()), op.indices_are_sorted()); + return success(); } LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) { diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 8aa9b5ef101..c95b2c86960 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -25,18 +25,25 @@ func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, // CHECK-LABEL: func @biasAdd_NHWC func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_NCHW func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = 
dense<1> : tensor<1xi64>} + // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } +// CHECK-LABEL: func @biasAdd_dynamic +func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor + return %0 : tensor +} + //===----------------------------------------------------------------------===// // Binary op legalizations. //===----------------------------------------------------------------------===// @@ -666,11 +673,18 @@ func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: @const func @const() -> tensor<2xi32> { - // CHECK-NEXT: xla_hlo.constant dense<0> : tensor<2xi32> + // CHECK: xla_hlo.constant dense<0> : tensor<2xi32> %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>) return %0: tensor<2xi32> } +// CHECK-LABEL: @const_dynamic_output +func @const_dynamic_output() -> tensor<*xi32> { + // CHECK: xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32> + %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>) + return %0: tensor<*xi32> +} + // CHECK-LABEL: @opaque_const func @opaque_const() -> tensor>> { // CHECK-NOT: xla_hlo.constant @@ -838,13 +852,14 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { } // CHECK-LABEL: func @relu_grad -// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<4x8xf32>) -func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> - // CHECK: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xi1> - // CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<4x8xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> - // CHECK: return %[[RESULT]] : tensor<4x8xf32> - %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> +// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor) +func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { + // CHECK-DAG: %[[ZERO_SCALAR:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> + // CHECK-DAG: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO_SCALAR]]) {comparison_direction = "GT"} : (tensor, tensor) -> tensor<*xi1> + // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<*xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32> + %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> } @@ -1019,6 +1034,14 @@ func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { return %0 : tensor<3x2xf32> } +// CHECK-LABEL: @transpose_3d_int32 +func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { + %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> 
(tensor<3xi32>) + // CHECK: "xla_hlo.transpose" + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> + return %0 : tensor<3x2x1xf32> +} + // CHECK-LABEL: @transpose_3d func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>) @@ -1344,35 +1367,42 @@ func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: reshape func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> { - // CHECK: %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32> + // CHECK: "xla_hlo.reshape" %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x1xf32> return %0 : tensor<1x1xf32> } // CHECK-LABEL: reshape_dynamic -func @reshape_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor { - // CHECK: %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor +func @reshape_dynamic(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<1x1xf32> { + // CHECK: "xla_hlo.reshape" + %0 = "tf.Reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor<1x1xf32> + return %0 : tensor<1x1xf32> +} + +// CHECK-LABEL: reshape_unranked +func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor { + // CHECK: "tf.Reshape" %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor return %0 : tensor } // CHECK-LABEL: squeeze func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { - // CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> + // CHECK: "xla_hlo.reshape" %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> return %0 : tensor<1x10xf32> } // CHECK-LABEL: squeeze_dynamic func @squeeze_dynamic(%arg0: tensor) -> tensor<*xf32> { - // CHECK-NEXT: %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor<*xf32> + // CHECK: "tf.Squeeze" %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: expand_dims func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { - // CHECK: "xla_hlo.reshape"{{.*}} : (tensor<2xf32>) -> tensor<1x2xf32> + // CHECK: "xla_hlo.reshape" %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> return %0 : tensor<1x2xf32> } @@ -1380,7 +1410,8 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { // CHECK-LABEL: slice_constant_start func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32> + // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32> // CHECK: return %[[RESULT]] : tensor<2xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) @@ -1388,10 +1419,22 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { return %0 : tensor<2xi32> } +// CHECK-LABEL: slice_i32_consts +func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { + // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32> + // CHECK: 
%[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64> + // CHECK: slice_sizes = dense<2> : tensor<1xi64> + %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) + %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + // CHECK-LABEL: slice_constant_start_negative_one_size func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32> + // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32> // CHECK: return %[[RESULT]] : tensor<3xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) @@ -1402,7 +1445,8 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi // CHECK-LABEL: slice_constant_start_dynamic_shape func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor<2xi64>) -> tensor<1x4xi32> + // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64> + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor<2xi64>) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) @@ -1412,7 +1456,8 @@ func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2 // CHECK-LABEL: slice_variable_start func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64> + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> @@ -1525,6 +1570,16 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { return %0 : tensor<4x1xf16> } +// CHECK-LABEL: func @mean_scalar_dim +func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { + // Verify that tf.Mean op with scalar attributes are lowered successfully. 
+ + // CHECK-NOT: tf.Mean + %dimension = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor) -> tensor<4x1xf16> + return %0 : tensor<4x1xf16> +} + // CHECK-LABEL: func @mean_dynamic func @mean_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> @@ -1601,6 +1656,66 @@ func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { return %0 : tensor<4x1xf16> } +// CHECK-LABEL: @all +func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[INIT:.*]] = xla_hlo.constant dense : tensor + // CHECK: "xla_hlo.reduce"(%{{.*}}, %[[INIT]]) ( { + // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): + // CHECK: %[[AND:.*]] = xla_hlo.and %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "xla_hlo.return"(%[[AND]]) : (tensor) -> () + // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor) -> tensor<4xi1> + %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// CHECK-LABEL: @all_keep_dim +func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { + // CHECK: "xla_hlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> + return %0 : tensor<4x1xi1> +} + +// CHECk-LABEL: @all_dynamic +func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ARG:.*]] = "xla_hlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> + // CHECK: "xla_hlo.reduce"(%[[ARG]] + %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> + return %0 : tensor<4x1xi1> +} + +// CHECK-LABEL: @any +func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[INIT:.*]] = xla_hlo.constant dense : tensor + // CHECK: "xla_hlo.reduce"(%{{.*}}, %[[INIT]]) ( { + // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor, %[[ARGB:.*]]: tensor): + // CHECK: %[[AND:.*]] = xla_hlo.or %[[ARGA]], %[[ARGB]] : tensor + // CHECK: "xla_hlo.return"(%[[AND]]) : (tensor) -> () + // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor) -> tensor<4xi1> + %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// CHECK-LABEL: @any_keep_dim +func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { + // CHECK: "xla_hlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> + return %0 : tensor<4x1xi1> +} + +// CHECk-LABEL: @any_dynamic +func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { + %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ARG:.*]] = "xla_hlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> + // CHECK: "xla_hlo.reduce"(%[[ARG]] + %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> + return %0 : tensor<4x1xi1> +} + 
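// Rough sketch of the reduction lowering exercised by the tests above
// (illustrative only, not a FileCheck'd case): a keep_dims tf.Any over axis 1
// of a tensor<4x8xi1> is expected to expand to approximately
//
//   %init = xla_hlo.constant dense<false> : tensor<i1>
//   %red = "xla_hlo.reduce"(%input, %init) ( {
//   ^bb0(%lhs: tensor<i1>, %rhs: tensor<i1>):
//     %or = xla_hlo.or %lhs, %rhs : tensor<i1>
//     "xla_hlo.return"(%or) : (tensor<i1>) -> ()
//   }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
//   %res = "xla_hlo.reshape"(%red) : (tensor<4xi1>) -> tensor<4x1xi1>
//
// tf.All follows the same pattern with xla_hlo.and and an init of dense<true>.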
//===----------------------------------------------------------------------===// // Tile op legalizations. //===----------------------------------------------------------------------===// @@ -1924,12 +2039,23 @@ func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32 return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32> } +// CHECK-LABEL: @split_match_and_split_into_two_dynamic +func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> + // CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> + %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) + // CHECK: return %[[ONE]], %[[TWO]] + return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32> +} + // CHECK-LABEL: @split_match_and_split_into_three +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> %0:3 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32> @@ -1973,3 +2099,82 @@ func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>) return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> } + +//===----------------------------------------------------------------------===// +// tf.SplitV legalization +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @splitv_match_and_split_into_three 
+// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) +func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { + %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32> + // CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32> + // CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32> + %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) + // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] + return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> +} + +// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic +func @splitv_match_and_split_into_three_dynamic(%input: tensor) -> (tensor, tensor, tensor) { + %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor + // CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor + // CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor + %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) + return %0#0, %0#1, %0#2 : tensor, tensor, tensor +} + +// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes +func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { + %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64> + // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64> + // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64> + %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) + return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> +} + +//===----------------------------------------------------------------------===// +// tf.Assert legalization +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @assert +func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { + // CHECK-NOT: tf.Assert + "tf.Assert"(%arg0, 
%arg1) {summarize = 1} : (tensor, tensor<*xf32>) -> () + return +} + +//===----------------------------------------------------------------------===// +// tf.Unpack legalization +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @unpack +func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { + // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> + // CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> + // CHECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> + + %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) + // return %[[RES1]], %[[RES2]], %[[RES3]] + return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> +} + +// CHECK-LABEL: @unpack_dynamic +func @unpack_dynamic(%input: tensor) -> (tensor, tensor) { + // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor + // CHECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor) -> tensor + // CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor + // CHECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor) -> tensor + + %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor) -> (tensor, tensor) + return %0#0, %0#1 : tensor, tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index ffcc1cc9df3..85ed317f8c6 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -317,16 +317,33 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { // ----- -// CHECK-LABEL: HloModule +// CHECK-LABEL: HloModule func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> { // Simple einsum is lowered to HLO dot op. 
- // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0} + // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0} %0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> return %0 : tensor<3x5xi32> } // ----- +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x300xf32> { + // CHECK: [[ARG0:%.*]] = f32[200,100,300] parameter(0) + // CHECK: [[ARG1:%.*]] = s32[10,2] parameter(1) + // CHECK: f32[10,300] gather(f32[200,100,300] [[ARG0]], s32[10,2] [[ARG1]]) + // CHECK-SAME: offset_dims={1} + // CHECK-SAME: collapsed_slice_dims={0,1} + // CHECK-SAME: start_index_map={0,1} + // CHECK-SAME: index_vector_dim=1 + // CHECK-SAME: slice_sizes={1,1,300} + // CHECK-SAME: indices_are_sorted=true + %0 = "xla_hlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> + return %0 : tensor<10x300xf32> +} + +// ----- + // CHECK-LABEL: HloModule func @main(%arg: tensor<4x2xf32>) -> tensor { %0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 77d74253132..a68e0237b14 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -317,6 +317,28 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1) } +// CHECK-LABEL: func @test_gather( +// CHECK-SAME: [[ARG0:%.+]]: tensor<200x100x300xf32>, [[ARG1:%.+]]: tensor<10x2xi32>) -> tensor<10x300xf32> { +%test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] { + %arg.0 = f32[200,100,300] parameter(0) + %arg.1 = s32[10,2] parameter(1) + // CHECK: "xla_hlo.gather"([[ARG0]], [[ARG1]]) + // CHECK-SAME: dimension_numbers + // CHECK-SAME: collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64> + // CHECK-SAME: index_vector_dim = 1 : i64 + // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64> + // CHECK-SAME: start_index_map = dense<[0, 1]> : tensor<2xi64> + // CHECK-SAME: indices_are_sorted = true + // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> + ROOT gather = f32[10,300] gather(f32[200,100,300] %arg.0, s32[10,2] %arg.1), + collapsed_slice_dims={0,1}, + index_vector_dim=1, + offset_dims={1}, + start_index_map={0,1}, + indices_are_sorted=true, + slice_sizes={1,1,300} +} + // CHECK-LABEL: func @test_get_dimension_size // CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>) %test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] { diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 58d5b7aa02b..4a74fe4b2ae 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -18,6 +18,7 @@ limitations under the License. 
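// Overview of the reduce conversion added below (an illustrative sketch that
// assumes the element-wise body ops are handled by the other patterns in this
// file; it is not verbatim test output): a reduction body of the form
//
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
//     %sum = xla_hlo.add %lhs, %rhs : tensor<f32>
//     "xla_hlo.return"(%sum) : (tensor<f32>) -> ()
//
// is rewritten so that the entry block takes buffer arguments instead of
// tensors, with one extra trailing argument for the result buffer, roughly
//
//   ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
//     ... converted body ..., followed by the LHLO terminator
//
// which is what HloToLHloReduceConverter below implements.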
#include "absl/memory/memory.h" #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Location.h" // TF:local_config_mlir @@ -38,13 +39,19 @@ namespace { constexpr StringRef kTempBufferAttr = "temp"; -Value* GetTensorStoreMemRef(Value* value) { +Value* GetTensorStoreOrReturnMemRef(Value* value) { for (const auto& user : value->getUsers()) { if (auto tensor_store = dyn_cast(user)) { if (tensor_store.getOperand(0) == value) { return tensor_store.getOperand(1); } } + if (auto return_op = dyn_cast(user)) { + if (return_op.getOperand(0) == value) { + auto block = return_op.getOperation()->getBlock(); + return *block->args_rbegin(); + } + } } return nullptr; } @@ -88,8 +95,8 @@ Value* InsertAllocAndDealloc(Location loc, Value* result, /// function to store that values held in the tensor. Value* GetBufferForResultValue(Location loc, Value* result, ConversionPatternRewriter* rewriter) { - if (auto tensor_store_memref = GetTensorStoreMemRef(result)) { - return tensor_store_memref; + if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) { + return existing_memref; } return InsertAllocAndDealloc(loc, result, rewriter); } @@ -117,7 +124,63 @@ class HloToLhloOpConverter : public ConversionPattern { rewriter.create(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size()), - llvm::to_vector<4>(original_results)); + original_results); + return matchSuccess(); + } +}; + +struct HloToLHloReduceConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + xla_hlo::ReduceOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + // TODO(b/137624192) Implement variadic reduce. + if (op.getNumResults() != 1) return matchFailure(); + if (op.getParentRegion()->getBlocks().size() != 1) { + emitError(loc, + "tensor to buffer conversion expects a single block in the " + "region containing the operation"); + } + const auto& original_results = op.getResults(); + SmallVector buffer_args(operands.begin(), operands.end()); + for (auto result : original_results) { + buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter)); + } + auto new_op = rewriter.create( + loc, llvm::None, buffer_args, op.getAttrs()); + + // Copy over the operations inside the region. + rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); + + // Create new block arguments with correct type. + auto& entry_block = new_op.body().front(); + int original_arg_count = entry_block.getNumArguments(); + for (int i = 0; i < original_arg_count; ++i) { + auto old_arg = entry_block.getArgument(i); + auto old_type = old_arg->getType().cast(); + auto new_type = + MemRefType::get(old_type.getShape(), old_type.getElementType()); + auto new_arg = entry_block.addArgument(new_type); + rewriter.replaceUsesOfBlockArgument(old_arg, new_arg); + } + // Add an argument for the result. + entry_block.addArgument( + entry_block.getArgument(original_arg_count)->getType()); + // Remove the old arguments. + for (int i = original_arg_count - 1; i >= 0; --i) { + entry_block.eraseArgument(i); + } + // Insert terminator at the end. 
+ rewriter.setInsertionPointToEnd(&entry_block); + rewriter.create(loc); + + rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size()), + original_results); + return matchSuccess(); } }; @@ -130,11 +193,12 @@ class HloToLhloTensorLoadConverter : public ConversionPattern { PatternMatchResult matchAndRewrite( Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { - rewriter.replaceOp(op, operands, llvm::to_vector<4>(op->getResults())); + rewriter.replaceOp(op, operands, op->getResults()); return matchSuccess(); } }; +// TODO(b/137624192): Rewrite into a copy and elide copy if possible. class HloToLhloTensorStoreConverter : public ConversionPattern { public: explicit HloToLhloTensorStoreConverter(MLIRContext* context) @@ -148,6 +212,19 @@ class HloToLhloTensorStoreConverter : public ConversionPattern { } }; +// TODO(b/137624192): Rewrite into a copy and elide copy if possible. +class HloToLhloReturnConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + xla_hlo::ReturnOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary // buffers if necessary. // @@ -215,6 +292,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, xla_lhlo::BroadcastInDimOp>, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -229,6 +307,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLHloReduceConverter, HloToLhloReturnConverter, HloToLhloTensorLoadConverter, HloToLhloTensorStoreConverter >(context); // clang-format on diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index 9be161851d9..8a8afc01bec 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -53,8 +53,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block, auto return_op = dyn_cast(block->getTerminator()); if (!return_op) continue; builder->setInsertionPointToEnd(block); - builder->create( - loc, target_block, llvm::to_vector<4>(return_op.getOperands())); + builder->create(loc, target_block, return_op.getOperands()); return_op.erase(); } @@ -196,8 +195,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { dyn_cast(new_block->getTerminator()); if (!return_op) continue; builder.setInsertionPointToEnd(new_block); - builder.create(loc, cond_block, - llvm::to_vector<4>(return_op.getOperands())); + builder.create(loc, cond_block, return_op.getOperands()); return_op.erase(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index f0ba67e2fd5..02a9c7e69e0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -127,8 +127,8 @@ static llvm::Optional GetIntegerHLOAxisFromTFAxis(Value *value, /// Returns a `ConvertOp` that casts the elements to a i64 type while retaining /// the shape of the input value. 
-static ConvertOp CastElementsToI64(Location loc, Value *value, - PatternRewriter *rewriter) { +static ConvertOp CastValueToI64(Location loc, Value *value, + PatternRewriter *rewriter) { return rewriter->create(loc, value, rewriter->getIntegerType(64)); } @@ -207,7 +207,8 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, // Bias op utilities. //===----------------------------------------------------------------------===// -/// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd. +// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd. +// Requires input to have ranked tensor. static DenseIntElementsAttr getBiasFeatureDimension(Builder &b, StringAttr format, Value *input) { @@ -418,7 +419,8 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( Builder *builder) { DenseIntElementsAttr constant_start_indices; if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { - return slice_sizes; + return xla::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64)) + .cast(); } auto input_ty = input->getType().dyn_cast(); @@ -687,7 +689,7 @@ class ConvertEinsumOp : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, op.getType(), *op.inputs().begin(), equation); } else if (op.N() == 2) { - auto inputs = llvm::to_vector<2>(op.inputs()); + ValueRange inputs = op.inputs(); rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], inputs[1], equation); } else { @@ -924,7 +926,7 @@ class ConvertSizeOp : public OpRewritePattern { }; // Converts the tf.Split op into a series of HLO slice ops when the tensor to be -// split has fuly static shape and the dimension to split is a constant. +// split has fully static shape and the dimension to split is a constant. // // The main logic of this pattern is to calculate the index start and end range // for each slice. And this happens only on the dimension to be split; for all @@ -962,9 +964,9 @@ class ConvertSplitOp : public OpRewritePattern { PatternMatchResult matchAndRewrite(TF::SplitOp op, PatternRewriter &rewriter) const override { - // We can only match when the tensor to be split has fully static shape. + // We can only split along static dimensions. auto input_type = op.value()->getType().dyn_cast(); - if (!input_type || !input_type.hasStaticShape()) return matchFailure(); + if (!input_type) return matchFailure(); // We can only match when the split dimension is a constant scalar. DenseIntElementsAttr split_dim_attr; @@ -978,6 +980,10 @@ class ConvertSplitOp : public OpRewritePattern { // Calculate the dimension size for each slice along the split dimension. int64_t input_dim_size = input_type.getDimSize(dim_index); + // If we are splitting along the dynamic dimension then we cannot compute + // the static dimension length. + if (TensorType::isDynamic(input_dim_size)) return matchFailure(); + int64_t num_splits = op.getNumResults(); int64_t slice_size = input_dim_size / num_splits; @@ -1011,6 +1017,118 @@ class ConvertSplitOp : public OpRewritePattern { } }; +// Converts the tf.SplitV op into a series of HLO slice ops when the tensor to +// be split has fully static shape and the dimension to split and split sizes +// are constants. +// +// This is similar to the conversion for tf.Split op other than that the size of +// each chunk on the dimension to split is explicitly given as an op operand +// and they are not necessarily the same. 
+// +// For example, given the following IR: +// +// %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} +// %split_dim = "tf.Const"() {value = dense<1> : tensor} +// %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : +// (tensor<4x6xf32>, tensor<3xi32>, tensor) -> +// (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) +// +// We will generate slices following slices: +// %0 = "xla_hlo.slice"(%input) { +// limit_indices = dense<[4, 1]> : tensor<2xi64>, +// start_indices = dense<0> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x1xf32> +// %1 = "xla_hlo.slice"(%input) { +// limit_indices = dense<[4, 3]> : tensor<2xi64>, +// start_indices = dense<[0, 1]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// %2 = "xla_hlo.slice"(%input) { +// limit_indices = dense<[4, 6]> : tensor<2xi64>, +// start_indices = dense<[0, 3]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x3xf32> +class ConvertSplitVOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::SplitVOp op, + PatternRewriter &rewriter) const override { + // We can only split along static dimensions. + // TODO(b/145731001): enhance to support dynamic-shaped inputs. + auto input_type = op.value()->getType().dyn_cast(); + if (!input_type) return matchFailure(); + + // We can only match when the split dimension is a constant scalar. + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr))) + return matchFailure(); + + // We can only match when the split sizes is a constant int vector. + DenseIntElementsAttr split_sizes_attr; + if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) + return matchFailure(); + + // Get each chunck's size along the dimension to split. It may contain + // dynamic sizes and we need to update it if so. + SmallVector split_sizes; + int64_t total_dim_size = 0; // Total dimension size assigned to splits + llvm::Optional dynamic_dim_index; + split_sizes.reserve( + split_sizes_attr.getType().cast().getNumElements()); + for (auto dim : llvm::enumerate(split_sizes_attr)) { + int64_t dim_val = dim.value().getSExtValue(); + split_sizes.push_back(dim_val); + if (dim_val == ShapedType::kDynamicSize) { + // We cannot have more than one dynamic dimension. + assert(!dynamic_dim_index && "invalid split sizes"); + dynamic_dim_index = dim.index(); + } else { + total_dim_size += dim_val; + } + } + + // Get the dimension we are splitting at. Offset properly if it's negative. + int64_t input_rank = input_type.getRank(); + int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); + if (dim_index < 0) dim_index += input_rank; + + int64_t input_dim_size = input_type.getDimSize(dim_index); + if (TensorType::isDynamic(input_dim_size)) return matchFailure(); + + assert(((dynamic_dim_index && total_dim_size <= input_dim_size) || + (!dynamic_dim_index && total_dim_size == input_dim_size)) && + "invalid split sizes"); + + // Update the dynamic dimension with calculated concrete size. + if (dynamic_dim_index) + split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size; + + // Parameters for constructing each slice. + SmallVector begin_indices(input_rank, 0); + auto end_indices = llvm::to_vector<4>(input_type.getShape()); + SmallVector strides(input_rank, 1); + + // All HLO slice results used to replace the original tf.Split op. 
+ SmallVector slices; + slices.reserve(op.getNumResults()); + + for (int i = 0; i < op.getNumResults(); ++i) { + end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; + slices.push_back(rewriter.create( + op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter))); + // Prepare the begin indice for the next slice. + begin_indices[dim_index] = end_indices[dim_index]; + } + + rewriter.replaceOp(op, slices); + return matchSuccess(); + } +}; + // Converts StridedSlice op to HLO Slice op along with Reverse op to handle // negative strides and Reshape op to update the output shape. Indices and // strides operands are converted to attributes with non-negative indexing. @@ -1182,8 +1300,7 @@ class GenericConvertReductionOp : public OpRewritePattern { ArrayRef input_shape = input_ty.getShape(); DenseIntElementsAttr dimensions; - if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)) || - dimensions.getType().getRank() != 1) + if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions))) return this->matchFailure(); // Build the final shape from input_shape and dimensions using a bitmap @@ -1260,7 +1377,6 @@ class ConvertMeanOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; - static Value *GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter &rewriter) { return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter); @@ -1300,6 +1416,36 @@ class ConvertMaxOp } }; +// Converts All op to HLO Reduce op. +// +// %init = constant dense<...> : tensor +// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.and"] +// {dimensions = ...} +class ConvertAllOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + static Value *GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 1, &rewriter); + } +}; + +// Converts Any op to HLO Reduce op. +// +// %init = constant dense<...> : tensor +// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.or"] +// {dimensions = ...} +class ConvertAnyOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + static Value *GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter); + } +}; + // Converts tensorflow ArgMin or ArgMax op to xla_hlo operations that perform // a reduction on the original input and the corresponding index. The reduction // sub-computation selects the max (or min) value and the index for the value. @@ -2000,6 +2146,53 @@ class ConvertTopKV2Op : public OpRewritePattern { } }; +// Converts tf.Unpack to a series of XLA HLO slice ops. +// +// Each slice takes one element along the dimension to unpack and takes the full +// range for all other dimenions. Each slice is then reshaped to drop the +// dimension to unpack (which is always of size 1). +// TODO(antiagainst): consider changing this into a TF internal lowering pass. 
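// As a concrete illustration (shapes mirroring the legalize-tf.mlir test in
// this change, not additional generated code): unpacking a tensor<4x3x6xf32>
// along axis = 1 produces, for i in {0, 1, 2}, pairs of ops of the form
//
//   %slice_i = "xla_hlo.slice"(%input)
//       {start_indices = dense<[0, i, 0]>, limit_indices = dense<[4, i+1, 6]>,
//        strides = dense<1>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
//   %res_i = "xla_hlo.reshape"(%slice_i) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>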
+class ConvertUnpackOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::UnpackOp op, + PatternRewriter &rewriter) const override { + auto value_type = op.value()->getType().cast(); + if (!value_type) return matchFailure(); + + int64_t value_rank = value_type.getRank(); + int64_t axis = op.axis().getSExtValue(); + if (axis < 0) axis += value_rank; + + // Parameters for constructing each slice. + SmallVector begin_indices(value_rank, 0); + auto end_indices = llvm::to_vector<4>(value_type.getShape()); + SmallVector strides(value_rank, 1); + + // All HLO slice+reshape results used to replace the original tf.Unpack op. + SmallVector results; + results.reserve(op.getNumResults()); + + for (int i = 0; i < op.getNumResults(); ++i) { + begin_indices[axis] = i; + end_indices[axis] = i + 1; + + auto slice_op = rewriter.create( + op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + // Reshape to drop the axis dimension. + auto reshape_op = rewriter.create( + op.getLoc(), op.getType(i), slice_op); + results.push_back(reshape_op); + } + + rewriter.replaceOp(op, results); + return matchSuccess(); + } +}; + #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { @@ -2013,16 +2206,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { // level TensorFlow ops. So, we don't have to target all the TensorFlow ops // here for lowering to HLO. TF::PopulateLoweringTFPatterns(context, &patterns); - patterns - .insert, - ConvertSoftmaxOp, ConvertSplitOp, - ConvertStridedSliceOp, ConvertTopKV2Op, ConvertMeanOp, - ConvertSumOp, ConvertMaxOp, ConvertTileOp, ConvertMaxPoolGradOp, - ConvertOneHotOp, ConvertConv2DBackpropInputOp, - ConvertConv2DBackpropFilterOp>(op->getContext()); + patterns.insert< + ConvertArgMaxOp, ConvertBF16FloorDivOp, ConvertConv2D, ConvertEinsumOp, + ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp, + ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, + ConvertSoftmaxOp, + ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, + ConvertStridedSliceOp, ConvertTopKV2Op, ConvertUnpackOp, ConvertMeanOp, + ConvertSumOp, ConvertMaxOp, ConvertAllOp, ConvertAnyOp, ConvertTileOp, + ConvertMaxPoolGradOp, ConvertOneHotOp, ConvertConv2DBackpropInputOp, + ConvertConv2DBackpropFilterOp>(op->getContext()); ConversionTarget target(*context); target.addLegalDialect(); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index a794e274f59..d2177041ba7 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -95,8 +95,7 @@ void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc, detupled_args.push_back(extract); } - llvm::SmallVector result( - builder.create(loc, func, detupled_args).getResults()); + auto result = builder.create(loc, func, detupled_args).getResults(); if (!tuple_return) { builder.create(loc, result); } else { diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index fb8c6736309..14075134f11 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ 
b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -29,6 +29,9 @@ def FeatureDimension : NativeCodeCall< def FalseBoolAttr : AttrConstraint>; def TrueBoolAttr : AttrConstraint>; +def CastValueToI64: NativeCodeCall< + "CastValueToI64($0->getLoc(), $1, &$_builder)">; + def : Pattern< (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon, $data_format, FalseBoolAttr:$is_training), @@ -43,13 +46,22 @@ def : Pattern< [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>; +//===----------------------------------------------------------------------===// +// Assert op pattern. +//===----------------------------------------------------------------------===// + +// HLO and XLA doesn't support Assertions. +def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>; + //===----------------------------------------------------------------------===// // Bias op patterns. //===----------------------------------------------------------------------===// def BiasAddFeatureDimension : NativeCodeCall< "getBiasFeatureDimension($_builder, $0, $1)">; -def : Pat<(TF_BiasAddOp AnyStaticShapeTensor:$input, $bias, $data_format), +// $input needs to be a ranked tensor to identify index of the feature +// dimension depending on the data_format 'NHWC' or 'NCHW'. +def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format), (HLO_AddOp $input, $bias, (BiasAddFeatureDimension $data_format, $input))>; @@ -298,7 +310,7 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b), //===----------------------------------------------------------------------===// def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value), - [(AnyStaticShapeTensor $res), (HLO_Tensor $res)]>; + [(HLO_Tensor $res)]>; //===----------------------------------------------------------------------===// // Relu op patterns. @@ -316,11 +328,21 @@ def : Pat<(TF_Relu6Op AnyStaticShapeTensor:$input), (HLO_ConstOp (ConstantSplat<"6"> $input)))>; // ReluGrad(gradients, features) = gradients * (features > 0) -def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyStaticShapeTensor:$features), +// +// $gradients needs to be of static shape so that on_true and on_false operands +// of SelectOp have same shape. +// +// $features needs to be ranked for computation of the broadcast dimensions for +// CompareOp. +// +// TODO(hinsu): Relax $gradients static shape requirement when there is a way +// to create splat tensor of dynamic shape in HLO. +def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), (HLO_SelectOp - (HLO_CompareOp $features, (HLO_ConstOp:$zero (ConstantSplat<"0"> $features)), + (HLO_CompareOp $features, + (HLO_ConstOp (GetScalarOfType<0> $features)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT), - $gradients, $zero)>; + $gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>; //===----------------------------------------------------------------------===// // Slice op patterns. 
@@ -333,9 +355,9 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall< "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," "&$_builder)">; -def : Pat<(TF_SliceOp HLO_Tensor:$input, HLO_Tensor:$starting_indices, - (TF_ConstOp I64ElementsAttr:$slice_sizes)), - (HLO_DynamicSliceOp $input, $starting_indices, +def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, + (TF_ConstOp $slice_sizes)), + (HLO_DynamicSliceOp $input, (CastValueToI64 $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), [(CanBeTranslatedToDynamicSlice $input, $starting_indices, $slice_sizes)]>; @@ -383,19 +405,21 @@ foreach Mapping = [ def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse), (HLO_ConvertOp $arg)>; -def : Pat<(TF_TransposeOp:$res $arg, (TF_ConstOp I64ElementsAttr:$permutation)), - (HLO_TransposeOp $arg, (CastIntElementsAttr $permutation))>; +def : Pat<(TF_TransposeOp:$res $arg, (TF_ConstOp $permutation)), + (HLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; +// Result of the following ops changing tensor shape needs to have static +// shape as HLO doesn't yet support dynamic reshaping ops. +// +// TODO(hinsu): Update once HLO supports dynamic reshaping ops. foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { - def : Pat<(TfOp:$res AnyStaticShapeTensor:$arg, $ignored), + def : Pat<(TfOp:$res $arg, $ignored), (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>; } //===----------------------------------------------------------------------===// // RngUniform. //===----------------------------------------------------------------------===// -def CastElementsToI64: NativeCodeCall< - "CastElementsToI64($0->getLoc(), $1, &$_builder)">; // TODO(misard,phawkins): handle random number generator seeds/states correctly. 
def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), @@ -404,5 +428,5 @@ def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)), (HLO_ConstOp (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)), - (CastElementsToI64 $old, $shape)), + (CastValueToI64 $old, $shape)), [(IsShapedTensor $shape)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index 4dabe0dea42..928bfc20cdb 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -54,7 +54,8 @@ struct LhloFuseLinalg : public FunctionPass { auto op = cast(generic_op.getOperation()); for (const Value* result : op.getOutputs()) { if (!func_args.count(result)) continue; - if (linalg::tileLinalgOp(b, op, tile_sizes, &folder)) { + if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{}, + &folder)) { generic_op.erase(); return; } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc index 28bacfa87f0..c4787d9bfd9 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc @@ -112,7 +112,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { rewriter.setInsertionPointToEnd(block); Operation* op = MapLhloOpToStdScalarOp( llvm::cast(lhlo_op), bodyResultTypes, bodyArgs, rewriter); - rewriter.create(loc, llvm::to_vector<1>(op->getResults())); + rewriter.create(loc, op->getResults()); rewriter.eraseOp(lhlo_op); return ConversionPattern::matchSuccess(); } diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index bfd0ce3d072..4d85ca67777 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -653,7 +653,13 @@ class BinaryOpsTest(xla_test.XLATestCase): divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24) np_result = np.true_divide(nums, divs) np_result[:, divs[0] == 0] = 0 - self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result) + self._testBinary( + gen_math_ops.div_no_nan, + nums, + divs, + expected=np_result, + rtol=7e-15 if dtype == np.float64 else None, + atol=3.9e-15 if dtype == np.float64 else None) if dtype not in self.complex_types: # floordiv unsupported for complex. self._testBinary( diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 99847e84c28..1bc88509542 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -164,7 +164,8 @@ class TensorArrayTest(xla_test.XLATestCase): dtype=tf_dtype, tensor_array_name="foo", size=3) # Unpack a matrix into vectors. 
- w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])) + w1 = ta.unstack( + convert([[1.0, 1.03125], [2.0, 2.03125], [3.0, 3.03125]])) r0 = w1.read(0) r1 = w1.read(1) r2 = w1.read(2) @@ -172,9 +173,9 @@ class TensorArrayTest(xla_test.XLATestCase): d0, d1, d2 = self.evaluate(xla.compile(fn)) - self.assertAllEqual(convert([1.0, 1.1]), d0) - self.assertAllEqual(convert([2.0, 2.1]), d1) - self.assertAllEqual(convert([3.0, 3.1]), d2) + self.assertAllEqual(convert([1.0, 1.03125]), d0) + self.assertAllEqual(convert([2.0, 2.03125]), d1) + self.assertAllEqual(convert([3.0, 3.03125]), d2) def fn(): # Reset ta because we're going to change the shape, else shape diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index d011be2c5af..20804af5229 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -307,7 +307,7 @@ void UpdateToEngineNode(const std::vector& infos, } } } - LOG(FATAL) << "Node " << (**node).name() << " not found in any engine."; + LOG(FATAL) << "Node " << node_name << " not found in any engine."; } // Function to insert a TRT engine node into the graph. diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 6c2b8fdc091..ef03ab91714 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -654,9 +654,8 @@ class ConverterTest : public ::testing::Test { ConverterTest() { Reset(); } void Reset() { - builder_.reset(nvinfer1::createInferBuilder(logger_)); converter_ = - std::move(Converter::Create(builder_.get(), TrtPrecisionMode::FP32, + std::move(Converter::Create(TrtPrecisionMode::FP32, /*use_calibration=*/false, &logger_) .ValueOrDie()); weight_store_ = &converter_->weight_store_; @@ -702,9 +701,6 @@ class ConverterTest : public ::testing::Test { private: Logger logger_; - // These members are ordered in a way such that the destruction order is: - // converter_ -> builder_ - TrtUniquePtrType builder_; protected: std::unique_ptr converter_; @@ -996,9 +992,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { FakeITensor input, infer_1, infer_2, infer_3; FakeITensor not_infer; Logger logger; - TrtUniquePtrType builder( - nvinfer1::createInferBuilder(logger)); - auto int8_converter = Converter::Create(builder.get(), TrtPrecisionMode::INT8, + auto int8_converter = Converter::Create(TrtPrecisionMode::INT8, /*use_calibration=*/true, &logger) .ValueOrDie(); int8_converter->ProvideQuantizationRange(&input, -5.0f, 5.0f); @@ -1255,12 +1249,8 @@ class OpConverterTest : public ::testing::Test { engine_.reset(nullptr); // Re-create them in proper order. - builder_.reset(nvinfer1::createInferBuilder(logger_)); - builder_->setMaxWorkspaceSize(1 << 26); - - // Reset the converter. converter_ = - std::move(Converter::Create(builder_.get(), precision_mode_to_test_, + std::move(Converter::Create(precision_mode_to_test_, /*use_calibration=*/false, &logger_) .ValueOrDie()); @@ -1294,18 +1284,13 @@ class OpConverterTest : public ::testing::Test { TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info)); // Build the TRT engine. 
- if (precision_mode == TrtPrecisionMode::FP16) { - builder_->setFp16Mode(true); - } else if (precision_mode == TrtPrecisionMode::INT8) { - // Setting FP16 mode as well allows TRT to also consider FP16 kernels and - // use them in situations where they are faster than INT8 or where INT8 is - // not supported for a given layer. - builder_->setFp16Mode(true); - builder_->setInt8Mode(true); - } ASSERT_EQ(nullptr, engine_.get()); - builder_->setMaxBatchSize(batch_size); - TF_ASSERT_OK(converter_->BuildCudaEngine(&engine_)); + TF_ASSERT_OK( + converter_->BuildCudaEngine(&engine_, + /*max_batch_size=*/batch_size, + /*max_workspace_size_bytes=*/1 << 26, + /*allocator=*/nullptr, + /*calibrator=*/nullptr)); CHECK_NOTNULL(engine_.get()); CheckDataTypeMatches(input_data); CheckDataTypeMatches(*output_data); @@ -1473,7 +1458,6 @@ class OpConverterTest : public ::testing::Test { private: Logger logger_; - TrtUniquePtrType builder_; TrtUniquePtrType engine_; cudaStream_t stream_; // Used to create placeholders with shape and data type information. The diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index fea2407a5d1..fb89742b139 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -143,6 +143,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel { REGISTER_XLA_OP( Name("DataFormatVecPermute").TypeConstraint("T", {DT_INT32, DT_INT64}), DataFormatVecPermuteOp); +REGISTER_XLA_OP(Name("DataFormatVecPermute") + .Label("host") + .TypeConstraint("T", {DT_INT32, DT_INT64}), + DataFormatVecPermuteOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 9d10be1d90a..defd96b570c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -723,17 +723,6 @@ Status XlaCompiler::CompileFunction( std::unique_ptr graph = GetGraph(fbody); - // Clear the "_kernel" attribute if it is set to "host". This is used to - // indicate that a computation should happen on the host instead of the - // accelerator, but doesn't make sense in XLA. - const char* const kKernelAttr = "_kernel"; - for (Node* n : graph->nodes()) { - string value; - if (TryGetNodeAttr(n->attrs(), kKernelAttr, &value) && value == "host") { - n->ClearAttr(kKernelAttr); - } - } - // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have // core assignments here. @@ -1059,7 +1048,12 @@ Status XlaCompiler::BuildArguments( const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) - << " name: " << arg.name << " TF arg " << input_to_args->at(i); + << " name: " << arg.name << " TF arg " << input_to_args->at(i) + << " node name: " << arg.node_name + << (arg_shardings.find(i) == arg_shardings.end() + ? 
"" + : absl::StrCat(" sharding: ", + arg_shardings.at(i).DebugString())); XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index c3e9b3edeca..670da043c1a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -147,6 +147,9 @@ class XlaCompiler { // The name of this argument, used for debugging. string name; + // The name of TensorFlow _Arg node, used for debugging. + string node_name; + // For a kResource, what kind of resource is it? XlaResource::Kind resource_kind = XlaResource::kInvalid; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index e70012f761a..a43608bd434 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -61,6 +61,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; /* static */ bool XlaOpRegistry::IsCompatible(const OpRegistration& x, const OpRegistration& y) { if (x.name != y.name) return true; + if (x.label != y.label) return true; // The registrations refer to the same Op: ensures they are compatible and // are restricted to different device whitelists. if (x.compilation_only != y.compilation_only) { @@ -256,6 +257,7 @@ void XlaOpRegistry::RegisterCompilationKernels() { std::unique_ptr kdef(new KernelDef); kdef->set_op(op_registration->name); kdef->set_device_type(backend.first); + kdef->set_label(op_registration->label); // Constrain each type attribute to the intersection of: // a) the types supported by the backend, and @@ -539,6 +541,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() { return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Label(std::string label) { + registration_->label = label; + return *this; +} + std::unique_ptr XlaOpRegistrationBuilder::Build( XlaOpRegistry::Factory factory) { registration_->factory = factory; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index af08790e02e..c6f6ffb2853 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -270,6 +270,8 @@ class XlaOpRegistry { // operands and not their values. bool is_metadata_op = false; + std::string label; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -350,6 +352,9 @@ class XlaOpRegistrationBuilder { // operands and not their values. XlaOpRegistrationBuilder& IsMetadataOp(); + // Specifies a particular value for the "_kernel" attr. + XlaOpRegistrationBuilder& Label(std::string label); + std::unique_ptr Build( XlaOpRegistry::Factory factory); diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 989968b5cbc..8c85482c8f8 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -319,6 +319,8 @@ XlaOp Erf(XlaOp x) { }); } +namespace { + // Approximation for the inverse error function from // Giles, M., "Approximating the erfinv function". 
// The approximation has the form: @@ -331,7 +333,7 @@ XlaOp Erf(XlaOp x) { // p = sum_{i=1}^n gq[i]*w^i // } // return p*x -XlaOp ErfInv(XlaOp x) { +XlaOp ErfInv32(XlaOp x) { constexpr int kDegree = 9; constexpr std::array w_less_than_5_constants = { 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, @@ -371,6 +373,101 @@ XlaOp ErfInv(XlaOp x) { }); } +XlaOp ErfInv64(XlaOp x) { + constexpr std::array w_less_than_6_25_constants = { + -3.6444120640178196996e-21, -1.685059138182016589e-19, + 1.2858480715256400167e-18, 1.115787767802518096e-17, + -1.333171662854620906e-16, 2.0972767875968561637e-17, + 6.6376381343583238325e-15, -4.0545662729752068639e-14, + -8.1519341976054721522e-14, 2.6335093153082322977e-12, + -1.2975133253453532498e-11, -5.4154120542946279317e-11, + 1.051212273321532285e-09, -4.1126339803469836976e-09, + -2.9070369957882005086e-08, 4.2347877827932403518e-07, + -1.3654692000834678645e-06, -1.3882523362786468719e-05, + 0.0001867342080340571352, -0.00074070253416626697512, + -0.0060336708714301490533, 0.24015818242558961693, + 1.6536545626831027356}; + constexpr std::array w_less_than_16_constants = { + 2.2137376921775787049e-09, 9.0756561938885390979e-08, + -2.7517406297064545428e-07, 1.8239629214389227755e-08, + 1.5027403968909827627e-06, -4.013867526981545969e-06, + 2.9234449089955446044e-06, 1.2475304481671778723e-05, + -4.7318229009055733981e-05, 6.8284851459573175448e-05, + 2.4031110387097893999e-05, -0.0003550375203628474796, + 0.00095328937973738049703, -0.0016882755560235047313, + 0.0024914420961078508066, -0.0037512085075692412107, + 0.005370914553590063617, 1.0052589676941592334, + 3.0838856104922207635, + }; + constexpr std::array w_greater_than_16_constants = { + -2.7109920616438573243e-11, -2.5556418169965252055e-10, + 1.5076572693500548083e-09, -3.7894654401267369937e-09, + 7.6157012080783393804e-09, -1.4960026627149240478e-08, + 2.9147953450901080826e-08, -6.7711997758452339498e-08, + 2.2900482228026654717e-07, -9.9298272942317002539e-07, + 4.5260625972231537039e-06, -1.9681778105531670567e-05, + 7.5995277030017761139e-05, -0.00021503011930044477347, + -0.00013871931833623122026, 1.0103004648645343977, + 4.8499064014085844221, + }; + // Compute logarithm of (1+arg) using log1p(arg) which is more precise than + // log(1+arg) when arg is close to zero. For more details, see + // https://en.cppreference.com/w/cpp/numeric/math/log1p + auto w = -Log1p(-x * x); + + auto lt_6_25 = Lt(w, ScalarLike(x, 6.25)); + auto lt_16 = Lt(w, ScalarLike(x, 16)); + auto coefficient = [&](int i) { + auto c = FullLike(x, w_less_than_6_25_constants[i]); + if (i < 19) { + c = Select(lt_6_25, c, FullLike(x, w_less_than_16_constants[i])); + } + if (i < 17) { + c = Select(lt_16, c, FullLike(x, w_greater_than_16_constants[i])); + } + return c; + }; + auto sqrt_w = Sqrt(w); + w = Select(lt_6_25, w - ScalarLike(x, 3.125), + sqrt_w - Select(lt_16, ScalarLike(x, 3.25), ScalarLike(x, 5.0))); + auto p = coefficient(0); + for (int i = 1; i < 17; ++i) { + p = coefficient(i) + p * w; + } + for (int i = 17; i < 19; ++i) { + p = Select(lt_16, coefficient(i) + p * w, p); + } + for (int i = 19; i < 23; ++i) { + p = Select(lt_6_25, coefficient(i) + p * w, p); + } + // Result modulo edge cases. + XlaOp result = p * x; + + // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is + // indeterminate, and can give nan or -/+inf.) 
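Two numerical details of the double-precision path above are worth calling out: w is formed as -log1p(-x*x) rather than -log((1-x)(1+x)) for accuracy near zero, and erfinv(+/-1) is pinned to +/-inf because the fitted polynomials are indeterminate there. A self-contained reference sketch using Newton's method on std::erf, handy for sanity-checking the polynomial approximation; the initial guess is an arbitrary assumption and this is not the XLA implementation:

#include <cmath>
#include <iostream>
#include <limits>

// Invert erf by Newton iteration on f(t) = erf(t) - y.
double ErfInvReference(double y) {
  if (std::abs(y) == 1.0) {
    return y * std::numeric_limits<double>::infinity();  // erfinv(+/-1) = +/-inf
  }
  const double w = -std::log1p(-y * y);  // same variable the polynomials use
  const double sqrt_pi = std::sqrt(std::acos(-1.0));
  double t = y * std::sqrt(w);  // rough initial guess (assumption)
  for (int i = 0; i < 60; ++i) {
    const double err = std::erf(t) - y;
    const double deriv = 2.0 / sqrt_pi * std::exp(-t * t);
    t -= err / deriv;
  }
  return t;
}

int main() {
  std::cout << ErfInvReference(0.5) << "\n";  // ~0.476936, matches the ErfInvF64 test data
  std::cout << ErfInvReference(1.0) << "\n";  // inf
}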
+ auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x)); + return Select(Eq(Abs(x), ScalarLike(x, 1)), + x * MaxValue(&b, shape.element_type()), result); + }); +} + +} // namespace + +XlaOp ErfInv(XlaOp x) { + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("ErfInv", x)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); + if (shape.element_type() == F64) { + return ErfInv64(x); + } + return DoWithUpcastToF32(x, {BF16, F16}, + [](XlaOp x) { return ErfInv32(x); }); + }); +} + namespace { // Coefficients for the Lanczos approximation of the gamma function. The // coefficients are uniquely determined by the choice of g and n (kLanczosGamma diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 6415e9383b5..8d13922e0e3 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/math.h" + #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -116,6 +117,10 @@ class MathTypedTest : public MathTest { // // For good measure, we also check pow with an exponent other than 0.5. void TestSqrtPowInequivalence() { + // TODO(b/145798892): test fails on GPU for double values. + if (std::is_same::value) { + return; + } SetFastMathDisabled(true); // Tests disable constant folding by default, but this test needs it @@ -151,11 +156,16 @@ class MathTypedTest : public MathTest { }; // TODO(b/123355973): Add bfloat16 to TestTypes once it's working. 
-#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 -using TestTypes = ::testing::Types; -#else -using TestTypes = ::testing::Types; +using TestTypes = ::testing::Types; TYPED_TEST_CASE(MathTypedTest, TestTypes); @@ -224,6 +234,28 @@ XLA_TEST_F(MathTest, SqrtF32) { ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 +XLA_TEST_F(MathTest, ErfInvF64) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {-0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, + 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}); + ErfInv(x); + + std::vector expected = {-1.163087153676674, -0.9061938024368231, + -0.732869077959217, -0.5951160814499948, + -0.4769362762044698, -0.37080715859355795, + -0.27246271472675443, -0.1791434546212916, + -0.08885599049425767, 0., + 0.08885599049425777, 0.1791434546212916, + 0.27246271472675443, 0.37080715859355784, + 0.4769362762044698, 0.5951160814499948, + 0.732869077959217, 0.9061938024368231, + 1.1630871536766736}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec{1e-15}); +} +#endif + XLA_TEST_F(MathTest, SquareTenValues) { XlaBuilder builder(TestName()); auto x = ConstantR1( diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 8f480c0dec3..290b9c0f647 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2112,7 +2112,8 @@ XlaOp XlaBuilder::CrossReplicaSum( XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, - const absl::optional& channel_id) { + const absl::optional& channel_id, + const absl::optional& shape_with_layout) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ -2136,9 +2137,31 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation, operand_shapes.push_back(operand_shape); operands.push_back(operand); } - TF_ASSIGN_OR_RETURN(Shape shape, + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferAllReduceShape(operand_shapes)); - *instr.mutable_shape() = shape.ToProto(); + if (shape_with_layout) { + if (!LayoutUtil::HasLayout(*shape_with_layout)) { + return InvalidArgument("shape_with_layout must have the layout set: %s", + shape_with_layout->ToString()); + } + if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) { + return InvalidArgument( + "Provided shape_with_layout must be compatible with the " + "operand shape: %s vs %s", + shape_with_layout->ToString(), operand_shape->ToString()); + } + instr.set_constrain_layout(true); + if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) { + // For a single-element tuple, take the tuple element shape. 
+ TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1); + *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto(); + } else { + *instr.mutable_shape() = shape_with_layout->ToProto(); + } + } else { + *instr.mutable_shape() = inferred_shape.ToProto(); + } for (const ReplicaGroup& group : replica_groups) { *instr.add_replica_groups() = group; @@ -2153,10 +2176,10 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation, TF_ASSIGN_OR_RETURN( auto all_reduce, AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands)); - if (operand_shape->IsTuple() && !shape.IsTuple()) { + if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) { // For a single-element tuple, wrap the result into a tuple. TF_RET_CHECK(operand_shapes.size() == 1); - TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], shape)); + TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape)); return Tuple({all_reduce}); } return all_reduce; @@ -3282,9 +3305,10 @@ XlaOp CrossReplicaSum(const XlaOp operand, XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, - const absl::optional& channel_id) { + const absl::optional& channel_id, + const absl::optional& shape_with_layout) { return operand.builder()->AllReduce(operand, computation, replica_groups, - channel_id); + channel_id, shape_with_layout); } XlaOp AllToAll(const XlaOp operand, int64 split_dimension, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 3822e907203..5e93bb2b3ba 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -514,7 +514,8 @@ class XlaBuilder { XlaOp AllReduce( XlaOp operand, const XlaComputation& computation, absl::Span replica_groups = {}, - const absl::optional& channel_id = absl::nullopt); + const absl::optional& channel_id = absl::nullopt, + const absl::optional& shape_with_layout = absl::nullopt); XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension, int64 split_count, @@ -922,7 +923,8 @@ class XlaBuilder { absl::Span replica_groups); friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, - const absl::optional& channel_id); + const absl::optional& channel_id, + const absl::optional& shape_with_layout); friend XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); @@ -1666,10 +1668,14 @@ XlaOp CrossReplicaSum(XlaOp operand, // - `channel_id`: for Allreduce nodes from different modules, if they have the // same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be // applied cross modules. -XlaOp AllReduce( - XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const absl::optional& channel_id = absl::nullopt); +// +// - `shape_with_layout`: forces the layout of the AllReduce to the given +// layout. This is used to guarantee the same layout for a group of AllReduce +// ops compiled separately. +XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt, + const absl::optional& shape_with_layout = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. 
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension, diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 3a219673304..bbd640f6064 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -131,18 +132,23 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { } } else if (shape.IsArray()) { if (allocate_arrays) { + // Literals can be used as DMA targets, which can require alignment. We + // force a 16-byte minimum alignment. + constexpr int kMinimumAlignment = 16; if (LayoutUtil::IsSparseArray(shape)) { // For sparse arrays, the buffer must be of the size of the maximum // number of sparse elements possible. const int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout()); - piece->set_buffer( - new char[max_sparse_elements * - ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_buffer(static_cast(tensorflow::port::AlignedMalloc( + max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()), + kMinimumAlignment))); piece->set_sparse_indices( new SparseIndexArray(max_sparse_elements, shape.rank())); } else { - piece->set_buffer(new char[piece->size_bytes()]); + piece->set_buffer(static_cast(tensorflow::port::AlignedMalloc( + piece->size_bytes(), kMinimumAlignment))); } } } else { @@ -174,7 +180,7 @@ void Literal::DeallocateBuffers() { root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (piece->buffer() != nullptr) { - delete[] piece->buffer(); + tensorflow::port::AlignedFree(piece->buffer()); delete piece->sparse_indices(); } }); @@ -504,7 +510,7 @@ Status Literal::MoveFrom(Literal&& src_literal, dest_index.push_back(i); } Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); + tensorflow::port::AlignedFree(dest_piece.buffer()); dest_piece.set_buffer(src_piece.buffer()); delete dest_piece.sparse_indices(); dest_piece.set_sparse_indices(src_piece.sparse_indices()); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 409d954748c..cdbe69d617e 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -26,7 +26,6 @@ py_test( name = "xla_client_test", srcs = ["xla_client_test.py"], main = "xla_client_test.py", - python_version = "PY3", srcs_version = "PY2AND3", tags = ["no_oss"], # TODO(phawkins): This test passes, but requires --config=monolithic. deps = [ diff --git a/tensorflow/compiler/xla/python/tpu_driver/BUILD b/tensorflow/compiler/xla/python/tpu_driver/BUILD index 96c6636323b..99a07c31256 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/BUILD @@ -31,6 +31,11 @@ tf_proto_library_cc( use_grpc_namespace = True, ) +cc_library( + name = "c_api", + hdrs = ["c_api.h"], +) + cc_library( name = "tpu_driver", srcs = [ diff --git a/tensorflow/compiler/xla/python/tpu_driver/c_api.h b/tensorflow/compiler/xla/python/tpu_driver/c_api.h new file mode 100644 index 00000000000..5b892dfdaa3 --- /dev/null +++ b/tensorflow/compiler/xla/python/tpu_driver/c_api.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_ + +#define TPUDRIVER_CAPI_EXPORT __attribute__((visibility("default"))) + +extern "C" { + +TPUDRIVER_CAPI_EXPORT extern void TpuDriver_Initialize(); + +TPUDRIVER_CAPI_EXPORT extern void TpuDriver_Open(const char* worker); + +TPUDRIVER_CAPI_EXPORT extern const char* TpuDriver_Version(void); +} + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_ diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/c_api_client.c b/tensorflow/compiler/xla/python/tpu_driver/client/c_api_client.c new file mode 100644 index 00000000000..70ab4af85fd --- /dev/null +++ b/tensorflow/compiler/xla/python/tpu_driver/client/c_api_client.c @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// To compile: gcc -o c_api_client c_api_client.c -ldl +// To run, make sure c_api.so and c_api_client in the same directory, and then +// sudo ./c_api_client + +#include +#include +#include + +int main(int argc, char** argv) { + void* handle; + handle = dlopen("./c_api.so", RTLD_NOW); + if (!handle) { + fprintf(stderr, "Error: %s\n", dlerror()); + exit(EXIT_FAILURE); + } + + const char* (*TpuDriver_Version)(void); + void (*TpuDriver_Initialize)(void); + void (*TpuDriver_Open)(const char* worker); + + fprintf(stdout, "------ Going to Find Out Version ------\n"); + *(void**)(&TpuDriver_Version) = dlsym(handle, "TpuDriver_Version"); + fprintf(stdout, "TPU Driver Version: %s\n", TpuDriver_Version()); + + fprintf(stdout, "------ Going to Initialize ------\n"); + *(void**)(&TpuDriver_Initialize) = dlsym(handle, "TpuDriver_Initialize"); + TpuDriver_Initialize(); + + fprintf(stdout, "------ Going to Open a TPU Driver ------\n"); + *(void**)(&TpuDriver_Open) = dlsym(handle, "TpuDriver_Open"); + TpuDriver_Open("local://"); + + dlclose(handle); + exit(EXIT_SUCCESS); +} diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py index 43c0d1a40c3..a3ad8b117ef 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py @@ -104,6 +104,9 @@ class TpuBackend(xla_client.Backend): options, self.client, compile_options.device_assignment) + def get_default_device_assignment(self, num_replicas): + return self.client.GetDefaultDeviceAssignment(num_replicas) + def serialize(self, executable): return self.client.SerializeExecutable(executable) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index e7d1e2ef9d9..60886416a62 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -32,6 +32,21 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def("devices", &PyTpuClient::devices) .def("local_devices", &PyTpuClient::local_devices) .def("host_id", &PyTpuClient::host_id) + .def("GetDefaultDeviceAssignment", + [](PyTpuClient* client, int num_replicas) + -> StatusOr>> { + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client->GetDefaultDeviceAssignment(num_replicas)); + std::vector> result; + for (int i = 0; i < num_replicas; ++i) { + int device_id = device_assignment(i, 0); + auto iter = client->id_to_device().find(device_id); + CHECK(iter != client->id_to_device().end()) << device_id; + result.push_back(iter->second); + } + return result; + }) .def("TransferToInfeed", [](PyTpuClient* client, const LiteralSlice& literal, int device_ordinal) { @@ -189,6 +204,11 @@ PYBIND11_MODULE(tpu_client_extension, m) { py::call_guard(), py::arg("arguments")) .def("ExecutePerReplica", &PyTpuExecutable::ExecutePerReplica, py::call_guard(), py::arg("arguments")); + + py::class_>(m, "TpuDevice") + .def("__repr__", [](const TpuDevice& device) { + return absl::StrFormat("TpuDevice(id=%i)", device.id()); + }); } // NOLINT(readability/fn_size) } // namespace xla diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 054c1da9e03..13968154188 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ 
b/tensorflow/compiler/xla/python/xla.cc @@ -366,6 +366,21 @@ PYBIND11_MODULE(xla_extension, m) { .def("devices", &PyLocalClient::devices) .def("local_devices", &PyLocalClient::local_devices) .def("host_id", &PyLocalClient::host_id) + .def("GetDefaultDeviceAssignment", + [](PyLocalClient* client, int num_replicas) + -> StatusOr>> { + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client->GetDefaultDeviceAssignment(num_replicas)); + std::vector> result; + for (int i = 0; i < num_replicas; ++i) { + int device_id = device_assignment(i, 0); + auto iter = client->id_to_device().find(device_id); + CHECK(iter != client->id_to_device().end()) << device_id; + result.push_back(iter->second); + } + return result; + }) .def("TransferToInfeed", [](PyLocalClient* client, const LiteralSlice& literal, int device_ordinal) { @@ -624,10 +639,12 @@ PYBIND11_MODULE(xla_extension, m) { py::module ops = m.def_submodule("ops", "XLA operations"); ops.def("AfterAll", &AfterAll); - ops.def("AllReduce", - static_cast, - const absl::optional&)>(&AllReduce)); + ops.def( + "AllReduce", + static_cast, + const absl::optional&, const absl::optional&)>( + &AllReduce)); ops.def("AllToAll", &AllToAll); ops.def("CollectivePermute", &CollectivePermute); ops.def("CreateToken", &CreateToken); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c8f66f704d7..a7e35a8a81f 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -91,6 +91,23 @@ class Backend(object): def compile(self, computation, compile_options): """Compiles a computation. Returns an executable.""" + @abc.abstractmethod + def get_default_device_assignment(self, num_replicas): + """Returns the default device assignment that `compile` would use. + + If `compile_options.device_assignment` isn't set, `compile` will pick a + deterministic device assignment based on the number of replicas, possibly + optimizing for device locality. This method returns that assignment, which + is useful for e.g. manually replicating a value before passing it to a + compiled executable. + + Args: + num_replicas: the number of replicas needed. + + Returns: + A list of Devices of length `num_replicas` indexed by replica ID. 
+ """ + class LocalBackend(Backend): """XLA backend implemented using the in-process xla::LocalClient API.""" @@ -143,6 +160,9 @@ class LocalBackend(Backend): options, self.client, compile_options.device_assignment) + def get_default_device_assignment(self, num_replicas): + return self.client.GetDefaultDeviceAssignment(num_replicas) + def serialize(self, executable): return self.client.SerializeExecutable(executable) @@ -1014,7 +1034,7 @@ class ComputationBuilder(object): """ replica_groups_protos = _get_replica_groups_protos(replica_groups) return ops.AllReduce(operand, computation.computation, - replica_groups_protos, None) + replica_groups_protos, None, None) def AllToAll(self, operand, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a6300d2dc73..14e6f66741e 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1505,6 +1505,7 @@ cc_library( hdrs = ["hlo_query.h"], deps = [ ":hlo", + ":hlo_casting_utils", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index eb6692ade5b..ac5edd82bee 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -239,6 +239,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum, /*replica_groups=*/{}, + /*constrain_layout=*/false, /*channel_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index f7a5ee691f3..ec93a868022 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -259,6 +259,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, /*replica_groups=*/{}, + /*constrain_layout=*/false, /*channel_id=*/absl::nullopt)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index d716e62d467..aee1f652abd 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -211,7 +211,8 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { HloInstruction* all_reduce = builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction, - /*replica_groups=*/{}, /*channel_id=*/1)); + /*replica_groups=*/{}, /*constrain_layout=*/false, + /*channel_id=*/1)); HloInstruction* gte0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, all_reduce, 0)); HloInstruction* gte1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h 
b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index c5ed810c917..37a54f86d3d 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -22,7 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -254,6 +256,16 @@ using ConstDfsHloVisitorWithDefault = // visiting. class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { public: + // Runs a visitor on the module and returns whether the module has changed. + StatusOr RunOnModule(HloModule* module) { + bool is_changed = false; + for (const auto& computation : module->computations()) { + TF_RETURN_IF_ERROR(computation->Accept(this)); + is_changed |= changed(); + } + return is_changed; + } + // Default visitor action is to do nothing and return OK. Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 9634401fe96..eb8b848fc3f 100755 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1196,6 +1196,9 @@ cc_library( ":gpu_conv_padding_legalization", ":gpu_conv_rewriter", ":gpu_layout_assignment", + ":reduction_degenerate_dim_remover", + ":reduction_dimension_grouper", + ":reduction_layout_normalizer", ":stream_executor_util", ":target_constants", "//tensorflow/compiler/xla:status_macros", @@ -1664,3 +1667,66 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "reduction_degenerate_dim_remover", + srcs = ["reduction_degenerate_dim_remover.cc"], + hdrs = ["reduction_degenerate_dim_remover.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "reduction_dimension_grouper", + srcs = ["reduction_dimension_grouper.cc"], + hdrs = ["reduction_dimension_grouper.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "reduction_layout_normalizer", + srcs = ["reduction_layout_normalizer.cc"], + hdrs = 
["reduction_layout_normalizer.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc old mode 100755 new mode 100644 index 6404c6d826f..30b204e6fd5 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -161,6 +161,12 @@ Status GpuCompiler::OptimizeHloModule( // where possible. Not every batchnorm op can be implemented as a call to // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs. if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) { + // Since BatchNorm inference is essentially pointwise operations, it is + // always advantageous to use kernel fusion rather than cudnn. + pass.AddPass( + /*rewrite_training_op=*/false, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/false); pass.AddPass(); } pass.AddPass( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 72f69ca2017..b2067fe916d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -124,6 +124,24 @@ bool IsCublasGemm(const HloInstruction& hlo) { hlo.custom_call_target() == kGemmCallTarget; } +std::array GetReductionTiling( + const ReductionDimensions& reduction_dimensions) { + if (reduction_dimensions.is_row_reduction) { + int64 tile_z = std::min(reduction_dimensions.dimensions[0], 8LL); + if (reduction_dimensions.dimensions[1] == 1) { + CHECK_EQ(reduction_dimensions.dimensions[0], 1); + return {tile_z, 1, 16}; + } + if (reduction_dimensions.dimensions[2] % (kWarpSize * 64) == 0) { + return {tile_z, 1, 64}; + } + return {tile_z, 1, 8}; + } + + // Column reduction. 
+ return {1, 128, 1}; +} + const char* const kCudnnBatchNormForwardInferenceCallTarget = "__cudnn$batchNormalizationForwardInference"; const char* const kCudnnBatchNormForwardTrainingCallTarget = @@ -201,8 +219,7 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { } ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(input->shape(), - reduce.dimensions()); + GetReductionKindAndContiguousComponents(reduce); if (reduction_dimensions.is_row_reduction) { // For row reduction, the tile block is 1 x tile_size_x, and we are reducing @@ -218,7 +235,9 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { } ReductionDimensions GetReductionKindAndContiguousComponents( - const Shape& input_shape, absl::Span dims_to_reduce) { + const HloInstruction& reduce) { + const Shape& input_shape = reduce.operand(0)->shape(); + absl::Span dims_to_reduce = reduce.dimensions(); DimensionVector dims_to_keep; for (int64 dim = 0; dim < input_shape.rank(); ++dim) { if (!absl::c_linear_search(dims_to_reduce, dim)) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index db3cd228841..2c37a63c05a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -169,14 +169,17 @@ struct ReductionDimensions { std::array dimensions; }; -// Given the input shape and dimensions to reduce for a reduction, returns -// ReductionDimensions. +// Given the reduction operation, returns ReductionDimensions. // // Prerequisite: the reduction instruction passes the check // IsReductionFromOrToContiguousDimensions, which guarantees either the // dimensions to reduce or the dimensions to keep are consecutive. ReductionDimensions GetReductionKindAndContiguousComponents( - const Shape& input_shape, absl::Span dims_to_reduce); + const HloInstruction& reduce); + +// Get tiling per thread for the given reduction in dimensions [D, H, W]. +std::array GetReductionTiling( + const ReductionDimensions& reduction_dimensions); // Emits call to "vprintf" with given format and arguments. llvm::Value* EmitPrintf(absl::string_view fmt, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index dbc2c95773a..2f8fd5e01cf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1988,10 +1988,11 @@ static int GetNumberOfPartialResults( if (reduction_info.IsRowReduction()) { return 1; } - int64 num_thread = mapping_scheme.GetNumberOfThreadsForDimensionX(); - int64 tile_size = mapping_scheme.GetTileSizeForDimensionX(); - CHECK_EQ(tile_size % num_thread, 0); - return tile_size / num_thread; + int64 num_partial_results = mapping_scheme.DilatedX() ? 
1 : 2; + CHECK_EQ(num_partial_results, + (mapping_scheme.GetTileSizeForDimensionX() / + mapping_scheme.GetNumberOfThreadsForDimensionX())); + return num_partial_results; } void IrEmitterUnnested::EmitPrologueForOneReduction( @@ -2876,36 +2877,26 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) { const Shape& input_shape = first_reduce->operand(0)->shape(); ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(input_shape, - first_reduce->dimensions()); + GetReductionKindAndContiguousComponents(*first_reduce); VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction << " " << reduction_dimensions.dimensions[0] << " " << reduction_dimensions.dimensions[1] << " " << reduction_dimensions.dimensions[2]; + std::array reduction_tiling = + GetReductionTiling(reduction_dimensions); + int64 tile_size_y = reduction_tiling[1]; + int64 block_size_z = reduction_tiling[0]; + bool dilated_x = + !reduction_dimensions.is_row_reduction && + !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, + reduction_dimensions.dimensions[2]); + int64 tile_size_x = 1; - int64 tile_size_y = 1; - int64 block_size_z = 1; int64 num_threads_x = 1; - bool dilated_x = true; if (reduction_dimensions.is_row_reduction) { num_threads_x = kWarpSize; - if (reduction_dimensions.dimensions[1] == 1) { - // Scalar reduction is handled differently than the other kind of row - // reduction. - CHECK_EQ(reduction_dimensions.dimensions[0], 1); - tile_size_x = kWarpSize * 16; - } else { - if (reduction_dimensions.dimensions[2] % (kWarpSize * 64) == 0) { - tile_size_x = kWarpSize * 64; - } else { - tile_size_x = kWarpSize * 8; - block_size_z = 8; - while (reduction_dimensions.dimensions[0] % block_size_z != 0) { - block_size_z -= 1; - } - } - } + tile_size_x = reduction_tiling[2] * kWarpSize; } else { // Column reduction without transpose doesn't require communication among // threads processing elements in the same tile. The current implementation @@ -2915,20 +2906,17 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( // num_threads_x and tile_size_x to allow a bigger hardware thread block. int64 hw_threads_per_block_limit = ThreadsPerBlockLimit(ir_emitter_context_->device_description()); - if (IsUnrollingColumnReductionBeneficial( - unnested_hlo, input_shape, reduction_dimensions.dimensions[2])) { + if (!dilated_x) { // Vectorized loads: two elements per thread. tile_size_x = std::min(2 * hw_threads_per_block_limit, reduction_dimensions.dimensions[2]); num_threads_x = tile_size_x / 2; - dilated_x = false; } else { // One element per thread. 
tile_size_x = std::min(hw_threads_per_block_limit, reduction_dimensions.dimensions[2]); num_threads_x = tile_size_x; } - tile_size_y = 128; } KernelMappingScheme mapping_scheme( diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h index 345abbd0935..2eede7036cf 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h @@ -85,15 +85,14 @@ class KernelMappingScheme { dims_in_tiles_{dims_in_elems[0], CeilOfRatio(dims_in_elems[1], tile_size_y), CeilOfRatio(dims_in_elems[2], tile_size_x)}, - dims_in_blocks_{dims_in_tiles_[0] / block_size_z, dims_in_tiles_[1], - dims_in_tiles_[2]}, + dims_in_blocks_{CeilOfRatio(dims_in_tiles_[0], block_size_z), + dims_in_tiles_[1], dims_in_tiles_[2]}, block_size_z_{block_size_z}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y), dilated_x_(is_dilated_x) { CHECK_EQ(tile_size_y % num_threads_y_, 0); CHECK_EQ(tile_size_x % num_threads_x_, 0); - CHECK_EQ((dims_in_elems[0] % block_size_z), 0); VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ","); VLOG(10) << "dims_in_tiles_ = " << absl::StrJoin(dims_in_tiles_, ","); VLOG(10) << "dims_in_blocks_ = " << absl::StrJoin(dims_in_blocks_, ","); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 489cbd101e2..6635b68899d 100755 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -32,6 +32,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -154,6 +157,10 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( /*allow_mixed_precision=*/false, LayoutAssignment::InstructionCanChangeLayout); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. AlgebraicSimplifierOptions options; diff --git a/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.cc b/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.cc new file mode 100644 index 00000000000..2c786b577fc --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.cc @@ -0,0 +1,92 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace gpu { + +class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor { + public: + Status HandleReduce(HloInstruction *instr) override { + HloInstruction *reduced_op = instr->mutable_operand(0); + const Shape &input_shape = reduced_op->shape(); + const Shape &reduce_shape = instr->shape(); + + if (!instr->shape().IsArray() || + !ShapeUtil::HasDegenerateDimensions(reduced_op->shape())) { + return Status::OK(); + } + Shape canonical_input_shape = + ShapeUtil::DropDegenerateDimensions(input_shape); + + Shape canonical_reduce_shape = + ShapeUtil::DropDegenerateDimensions(reduce_shape); + + const std::vector &reduced_dimensions = instr->dimensions(); + std::vector updated_reduced_dimensions; + int64 shift = 0; + + for (int dim = 0; dim < input_shape.rank(); dim++) { + if (input_shape.dimensions(dim) == 1) { + shift++; + } else { + if (absl::c_linear_search(reduced_dimensions, dim)) { + updated_reduced_dimensions.push_back(dim - shift); + } + } + } + + HloInstruction *input_reshape = instr->parent()->AddInstruction( + HloInstruction::CreateBitcast(canonical_input_shape, reduced_op)); + + std::unique_ptr new_reduce = HloInstruction::CreateReduce( + canonical_reduce_shape, input_reshape, instr->mutable_operand(1), + updated_reduced_dimensions, instr->to_apply()); + + if (canonical_reduce_shape != reduce_shape) { + HloInstruction *wrapped_reduce = + instr->parent()->AddInstruction(std::move(new_reduce)); + new_reduce = HloInstruction::CreateBitcast(reduce_shape, wrapped_reduce); + } + + return ReplaceWithNewInstruction(instr, std::move(new_reduce)); + } +}; + +StatusOr ReductionDegenerateDimRemover::Run(HloModule *module) { + TF_ASSIGN_OR_RETURN( + bool changed, ReductionDegenerateDimRemoverVisitor().RunOnModule(module)); + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h b/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h new file mode 100644 index 00000000000..eeb26da607a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Enforces the invariant that reduction input and output have no degenerate +// (size 1) dimension. Since these dimensions are physically meaningless, they +// are removed using bitcasts. +// +// For example, +// +// f[1] out = reduce(f[100, 1, 1] input, dimensions={0, 1}) +// +// becomes: +// +// +// f[100] tmp1 = f[100] bitcast(f[100, 1, 1], input) +// f[] tmp2 = reduce(f[100] tmp1, dimensions={0}) +// f[1] out = f[] bitcast(tmp2) +// +class ReductionDegenerateDimRemover : public HloModulePass { + public: + absl::string_view name() const override { + return "reduction-degenerate-dim-remover"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.cc b/tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.cc new file mode 100644 index 00000000000..66b458e1ba4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.cc @@ -0,0 +1,107 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace gpu { + +class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor { + public: + Status HandleReduce(HloInstruction *reduce) override { + VLOG(4) << "Input: " << reduce->ToString(); + + if (!reduce->shape().IsArray()) { + // TODO(cheshire): Handle variadic reduction. + return Status::OK(); + } + + std::vector new_grouped_dims; + std::vector reduced_dims_grouped; + HloInstruction *operand = reduce->mutable_operand(0); + const Shape &shape = operand->shape(); + CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape)) + << "Default layout should be enforced on reduction operand"; + auto is_reduced = [&](int dim) { + return absl::c_linear_search(reduce->dimensions(), dim); + }; + + bool changed = false; + int64 next_dim_size = 1; + + // Since we have enforced the standard layout, iteration over logical + // dimensions is equivalent to iteration over the major-to-minor order. 
+ for (int logical_dim = 0; logical_dim < shape.rank(); logical_dim++) { + VLOG(5) << "Processing dimension " << logical_dim << " of size " + << shape.dimensions(logical_dim); + if (is_reduced(logical_dim) && logical_dim < shape.rank() - 1 && + is_reduced(logical_dim + 1)) { + VLOG(5) << "This and consecutive dimension are reduced, merging"; + changed = true; + next_dim_size *= shape.dimensions(logical_dim); + continue; + } + + if (is_reduced(logical_dim)) { + new_grouped_dims.push_back(next_dim_size * + shape.dimensions(logical_dim)); + reduced_dims_grouped.push_back(new_grouped_dims.size() - 1); + next_dim_size = 1; + } else { + new_grouped_dims.push_back(shape.dimensions(logical_dim)); + } + } + + if (!changed) { + return Status::OK(); + } + + Shape grouped_shape = + ShapeUtil::MakeShape(shape.element_type(), new_grouped_dims); + HloInstruction *reduce_input_grouped = reduce->parent()->AddInstruction( + HloInstruction::CreateBitcast(grouped_shape, operand)); + + std::unique_ptr new_reduce = HloInstruction::CreateReduce( + reduce->shape(), reduce_input_grouped, reduce->mutable_operand(1), + reduced_dims_grouped, reduce->to_apply()); + VLOG(5) << "Generated new reduction: " << new_reduce->ToString(); + return ReplaceWithNewInstruction(reduce, std::move(new_reduce)); + } +}; + +StatusOr ReductionDimensionGrouper::Run(HloModule *module) { + TF_ASSIGN_OR_RETURN(bool changed, + ReduceDimensionGroupVisitor().RunOnModule(module)); + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h b/tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h new file mode 100644 index 00000000000..8a78d3fca07 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Groups adjacent (logically and physically) reduced dimensions in reduction +// input. +// +// Precondition: ReductionLayoutNormalizer has been run (physical proximity and +// logical proximity become the same). 
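For readers skimming the grouping visitor above: the core transformation just multiplies out runs of adjacent reduced dimensions, which is valid because ReductionLayoutNormalizer has already made logical and physical adjacency coincide. A simplified standalone sketch of that bookkeeping, not the XLA visitor itself (the struct and function names here are made up for illustration):

#include <cstdint>
#include <iostream>
#include <vector>

struct GroupedReduction {
  std::vector<int64_t> dims;          // grouped dimension sizes
  std::vector<int64_t> reduced_dims;  // indices into `dims` that are reduced
};

GroupedReduction GroupAdjacentReducedDims(const std::vector<int64_t>& dims,
                                          const std::vector<bool>& is_reduced) {
  GroupedReduction out;
  int64_t pending = 1;  // running product over a run of reduced dimensions
  for (size_t d = 0; d < dims.size(); ++d) {
    if (!is_reduced[d]) {
      out.dims.push_back(dims[d]);
      continue;
    }
    pending *= dims[d];
    if (d + 1 < dims.size() && is_reduced[d + 1]) continue;  // keep merging
    out.dims.push_back(pending);
    out.reduced_dims.push_back(static_cast<int64_t>(out.dims.size()) - 1);
    pending = 1;
  }
  return out;
}

int main() {
  // reduce(f[8,4,16], dimensions={0,1}) becomes reduce(f[32,16], dimensions={0}).
  GroupedReduction g = GroupAdjacentReducedDims({8, 4, 16}, {true, true, false});
  for (int64_t d : g.dims) std::cout << d << " ";  // prints: 32 16
  std::cout << "\n";
}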
+// +// For example, +// +// f[] out = reduce(f[10,20,30] input, dimensions={0,1,2}) +// +// becomes: +// +// f[600] tmp = f[600] bitcast(f[10,20,30] input) +// f[] out = reduce(f[600] tmp, dimensions={0}) +// +// TODO(cheshire): handle variadic reduction +class ReductionDimensionGrouper : public HloModulePass { + public: + absl::string_view name() const override { + return "reduction-dimension-grouper"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.cc b/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.cc new file mode 100644 index 00000000000..295ccebd442 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.cc @@ -0,0 +1,129 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace gpu { + +class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { + Status HandleReduce(HloInstruction *reduce) override { + VLOG(5) << "Input: " << reduce->ToString(); + HloInstruction *operand = reduce->mutable_operand(0); + const Shape &operand_shape = operand->shape(); + const Layout &operand_layout = operand_shape.layout(); + const Shape &reduce_shape = reduce->shape(); + + if (!reduce_shape.IsArray()) { + // TODO(cheshire): Handle variadic reduction. + return Status::OK(); + } + + std::vector new_reduce_dimensions; + std::vector new_operand_shape_data; + std::vector new_reduce_shape_data; + + // The layout order of the reduction output can be different to the + // ordering of kept dimensions in the input operand, thus we need to + // calculate the new layout. 
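+    // For example, for the reduction documented in
+    // reduction_layout_normalizer.h,
+    //
+    //   f[20,30]{0,1} out = reduce(f[10,20,30]{2,0,1} input, dimensions={0})
+    //
+    // the operand is bitcast to f[20,10,30]{2,1,0}, the reduced dimension
+    // becomes {1}, and the output keeps its requested layout {0,1}.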
+ std::vector new_reduce_shape_layout(reduce_shape.rank()); + std::vector reduce_shape_logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(reduce_shape.layout()); + + auto to_reduce_logical_dim = [&](int64 op_logical_dim) { + return op_logical_dim - + absl::c_count_if(reduce->dimensions(), [&](int64 dim) { + CHECK(dim != op_logical_dim); + return dim < op_logical_dim; + }); + }; + + for (int i = 0; i < operand_shape.rank(); i++) { + // Process the dimensions in the major-to-minor order in order to enforce + // the default layout. + int64 major_to_minor_dim_idx = operand_shape.rank() - i - 1; + int64 logical_dim = operand_layout.minor_to_major(major_to_minor_dim_idx); + int64 dim_size = operand_shape.dimensions(logical_dim); + VLOG(5) << "Processing logical dimension " << logical_dim << " of size " + << dim_size; + new_operand_shape_data.push_back(dim_size); + + if (absl::c_linear_search(reduce->dimensions(), logical_dim)) { + new_reduce_dimensions.push_back(i); + } else { + new_reduce_shape_data.push_back(dim_size); + int64 logical_reduce_dim = to_reduce_logical_dim(logical_dim); + int64 physical_reduce_dim = + reduce_shape_logical_to_physical[logical_reduce_dim]; + VLOG(5) << "logical_reduce_dim = " << logical_reduce_dim << ", " + << "physical_reduce_dim = " << physical_reduce_dim; + new_reduce_shape_layout[reduce_shape.rank() - physical_reduce_dim - 1] = + new_reduce_shape_data.size() - 1; + } + } + + Shape new_operand_shape = ShapeUtil::MakeShape(operand_shape.element_type(), + new_operand_shape_data); + if (new_operand_shape == operand_shape) { + return Status::OK(); + } + + Shape new_reduce_shape = ShapeUtil::MakeShapeWithLayout( + reduce_shape.element_type(), new_reduce_shape_data, + new_reduce_shape_layout); + HloInstruction *canonical_reduce_input = reduce->parent()->AddInstruction( + HloInstruction::CreateBitcast(new_operand_shape, operand)); + + VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString(); + std::unique_ptr new_reduce = HloInstruction::CreateReduce( + new_reduce_shape, canonical_reduce_input, reduce->mutable_operand(1), + new_reduce_dimensions, reduce->to_apply()); + VLOG(5) << "Generated new reduction: " << new_reduce->ToString(); + + if (new_reduce_shape != reduce_shape) { + HloInstruction *wrapped_reduce = + reduce->parent()->AddInstruction(std::move(new_reduce)); + new_reduce = HloInstruction::CreateBitcast(reduce_shape, wrapped_reduce); + } + + VLOG(5) << "Generated output: " << new_reduce->ToString(); + return ReplaceWithNewInstruction(reduce, std::move(new_reduce)); + } +}; + +StatusOr ReductionLayoutNormalizer::Run(HloModule *module) { + TF_ASSIGN_OR_RETURN(bool changed, + EnforceMinorToMajorReduceOpVisitor().RunOnModule(module)); + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h b/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h new file mode 100644 index 00000000000..d27c847f8ea --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Enforces default (minor-to-major) layout on all reduction inputs. +// Note that since reduction output can request a custom layout, +// this pass only guarantees standard layout for the input. +// +// For example, +// +// f[20,30]{0,1} out = reduce(f[10,20,30]{2,0,1} input, dimensions={0}) +// +// becomes: +// +// f[20,10,30] tmp = f[20,10,30] bitcast(f[10,20,30]{2,0,1} input) +// f[20,30]{0,1} out = reduce(f[20,10,30]{2,1,0} tmp, dimensions={1}) +class ReductionLayoutNormalizer : public HloModulePass { + public: + absl::string_view name() const override { + return "reduction-layout-normalizer"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 11cb5f0cbf7..51a12e1f2fe 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -81,6 +81,87 @@ tf_cc_test( ], ) +tf_cc_test( + name = "reduction_degenerate_dim_remover_test", + srcs = [ + "reduction_degenerate_dim_remover_test.cc", + ], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gemm_rewriter", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "reduction_layout_normalizer_test", + srcs = [ + "reduction_layout_normalizer_test.cc", + ], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gemm_rewriter", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "reduction_dimension_grouper_test", + srcs = [ + "reduction_dimension_grouper_test.cc", + ], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gemm_rewriter", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 92bb84065a2..ae10fb161d6 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -461,7 +461,7 @@ TEST_F(GpuKernelTilingTest, ColumnReductionWithLayoutChangeTiled) { .ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( -; CHECK-LABEL: define void @reduce +; CHECK-LABEL: define void @ ; CHECK: atomicrmw fadd float ; CHECK: } )", diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc new file mode 100644 index 00000000000..686092706f7 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +class ReductionDegenerateDimRemoverTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + debug_options.add_xla_disable_hlo_passes("reduction-layout-normalizer"); + debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper"); + return debug_options; + } +}; + +TEST_F(ReductionDegenerateDimRemoverTest, ReductionWithDegenerateDimensions) { + const char* hlo_text = R"( +HloModule ReduceWithDegenerateDimensions + +add { + accum = f32[] parameter(0) + op = f32[] parameter(1) + ROOT out = f32[] add(accum, op) +} + +ENTRY main { + input = f32[1,3,1,4,1,5,1] parameter(0) + zero = f32[] constant(0) + + ROOT out = f32[1,1,1,1] reduce(input, zero), dimensions={1,3,5}, to_apply=add +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHloWithShapes(hlo_text, + R"( +// CHECK: f32[] reduce(f32[3,4,5]{2,1,0} {{.+}}, f32[] {{.+}}), dimensions={0,1,2}, to_apply=%add + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc new file mode 100644 index 00000000000..a9e0b9b5c5f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +namespace { + +class ReductionDimensionGrouperTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + debug_options.add_xla_disable_hlo_passes("reduction-layout-normalizer"); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + return debug_options; + } +}; + +TEST_F(ReductionDimensionGrouperTest, ReductionWithGrouping) { + const char* hlo_text = R"( +HloModule ReductionWithGrouping + +add { + accum = f32[] parameter(0) + op = f32[] parameter(1) + ROOT out = f32[] add(accum, op) +} + +ENTRY main { + input = f32[100,10,32,3]{3,2,1,0} parameter(0) + zero = f32[] constant(0) + + ROOT out = f32[100,10]{0,1} reduce(input, zero), dimensions={2,3}, to_apply=add +} + + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHloWithShapes(hlo_text, + R"( +// CHECK: f32[100,10]{0,1} reduce(f32[100,10,96]{2,1,0} {{.+}}, f32[] {{.+}}), dimensions={2}, to_apply=%add + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc new file mode 100644 index 00000000000..49b8bbf1d6b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +namespace { + +class ReductionLayoutNormalizerTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper"); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + return debug_options; + } +}; + +TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) { + const char* hlo_text = R"( +HloModule ReduceWithLayoutChange + +add { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) +} + +ENTRY main { + arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[4,5,16,12,12]{4,3,2,1,0} reduce(arg0, constant0), + dimensions={1,6,7}, to_apply=add +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHloWithShapes(hlo_text, + R"( +// CHECK: f32[4,12,12,16,5]{2,1,3,4,0} reduce(f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} {{.+}}, f32[] {{.+}}), dimensions={0,1,2}, to_apply=%add + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 368a3876f8c..bc099371d08 100755 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -400,6 +400,7 @@ StatusOr> HloInstruction::CreateFromProto( /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), + /*constrain_layout=*/proto.constrain_layout(), /*channel_id=*/channel_id); break; } @@ -900,10 +901,11 @@ HloInstruction::CreateReducePrecision(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id) { return absl::make_unique( - shape, operands, reduce_computation, replica_groups, channel_id); + shape, operands, reduce_computation, replica_groups, constrain_layout, + channel_id); } /* static */ std::unique_ptr HloInstruction::CreateAllToAll( @@ -1341,7 +1343,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kTrace: return true; case HloOpcode::kAllReduce: - return channel_id().has_value(); + return channel_id().has_value() || + Cast(this)->constrain_layout(); case HloOpcode::kCustomCall: return Cast(this) ->custom_call_has_side_effect(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5855911650d..238a96e52a0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -607,7 +607,7 @@ class HloInstruction { static std::unique_ptr CreateAllReduce( const Shape& 
shape, absl::Span operands, HloComputation* reduce_computation, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id); // An all-to-all op takes N array operands of the same shape and scatters them diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 9448feb7d8a..a150efd8c83 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -553,10 +553,11 @@ bool HloCollectiveInstruction::IdenticalSlowPath( HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id) : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands, - replica_groups, channel_id) { + replica_groups, channel_id), + constrain_layout_(constrain_layout) { AppendComputation(reduce_computation); } @@ -569,12 +570,29 @@ bool HloAllReduceInstruction::IsNoop() const { return !channel_id(); } +HloInstructionProto HloAllReduceInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); + proto.set_constrain_layout(constrain_layout_); + return proto; +} + +std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); + if (constrain_layout_) { + result.push_back("constrain_layout=true"); + } + return result; +} + bool HloAllReduceInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& eq_computations) const { const auto& casted_other = static_cast(other); return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + constrain_layout() == casted_other.constrain_layout() && eq_computations(to_apply(), casted_other.to_apply()); } @@ -583,7 +601,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( - shape, new_operands, to_apply(), replica_groups(), channel_id()); + shape, new_operands, to_apply(), replica_groups(), constrain_layout(), + channel_id()); } HloAllToAllInstruction::HloAllToAllInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 8950e6218e3..1863c78e7e1 100755 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -336,13 +336,33 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { explicit HloAllReduceInstruction( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id); // Returns true if the AllReduce does no communication, so it's equivalent // to a mem copy. bool IsNoop() const; + // Returns true if the layout of the AllReduce is enforced by XLA client (as + // the layout set in the shape). The only reason for the client to set the + // layout is to separately compile computations that communicate with + // AllReduce. 
Since this field is only set `true` by the client, the compiler + // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set + // `false` for all other cases. + // + // When this is `true`, there may be communication endpoints outside the + // current compilation unit, so the compiler considers this AllReduce as + // side-effecting to disable compiler transformations. The compiler is free to + // transform unconstrained AllReduces differently across compilation units. + // It is an error for an HloModule to have a mix of constrained and + // unconstrained AllReduce instructions (checked by HloVerifier). + bool constrain_layout() const { return constrain_layout_; } + + protected: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + HloInstructionProto ToProto() const override; + private: bool IdenticalSlowPath( const HloInstruction& other, @@ -353,6 +373,8 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; + + bool constrain_layout_; }; class HloAllToAllInstruction : public HloCollectiveInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ef58b37b469..3ecd0af3480 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -857,11 +857,14 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, optional to_apply; optional> replica_group_ids; optional channel_id; + optional constrain_layout; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; attrs["replica_groups"] = {/*required=*/false, AttrTy::kBracedInt64ListList, &tmp_groups}; attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; + attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool, + &constrain_layout}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -870,7 +873,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, replica_groups = CreateReplicaGroups(*tmp_groups); } instruction = builder->AddInstruction(HloInstruction::CreateAllReduce( - shape, operands, *to_apply, replica_groups, channel_id)); + shape, operands, *to_apply, replica_groups, + constrain_layout ? 
*constrain_layout : false, channel_id)); break; } case HloOpcode::kAllToAll: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index a522b1ddbfe..29a6a5e4297 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1472,6 +1472,24 @@ ENTRY AllReduceWithSubgroups { ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add } +)" +}, +// all-reduce with constrained layout +{ +"AllReduceWithLayout", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, constrain_layout=true, to_apply=add +} + )" }, // all-reduce with all-reduce-id diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index f968a4a9445..defd6abd8f6 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -119,5 +121,17 @@ bool ContainsInstrWithOpcode(const HloComputation* comp, return false; } +bool ContainsLayoutConstrainedAllReduce(const HloModule& module) { + for (auto computation : module.computations()) { + for (auto hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kAllReduce && + DynCast(hlo)->constrain_layout()) { + return true; + } + } + } + return false; +} + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index 215051f8834..0ea36ae83f8 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { @@ -72,6 +73,10 @@ bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode, HloInstruction** matching_operand, HloInstruction** other_operand); +// Returns whether the module contains all-reduce instructions with constrained +// layout. +bool ContainsLayoutConstrainedAllReduce(const HloModule& module); + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 4d460ee30ca..1218f7dfc6f 100755 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1310,6 +1310,29 @@ Status VerifyAsynchronousCopies(const HloModule& module) { return Status::OK(); } +// Checks that AllReduce instructions in the module are either all layout +// constrained or all unconstrained. 
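+// This enforces the invariant documented on
+// HloAllReduceInstruction::constrain_layout() in hlo_instructions.h.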
+Status VerifyLayoutConstrainedAllReduce(const HloModule& module) { + const HloAllReduceInstruction* reference = nullptr; + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kAllReduce) { + continue; + } + auto all_reduce = DynCast(instruction); + if (!reference) { + reference = all_reduce; + } + if (reference->constrain_layout() != all_reduce->constrain_layout()) { + return FailedPrecondition( + "HloModule has a mix of layout constrained and unconstrained " + "AllReduce instructions."); + } + } + } + return Status::OK(); +} + // Checks various invariants of send and recv instructions. Status VerifySendsAndRecvs(const HloModule& module) { absl::flat_hash_map host_channels; @@ -1697,6 +1720,7 @@ StatusOr HloVerifier::Run(HloModule* module) { })); TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module)); + TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module)); return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index df603102157..1b273909991 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -988,5 +988,30 @@ TEST_F(HloVerifierTest, FusionShapeVerifier) { HasSubstr("Fused computation shape")); } +TEST_F(HloVerifierTest, AllReduceVerifier) { + const char* const kModuleStr = R"( + HloModule test + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry { + input = f32[8,12]{0,1} parameter(0) + crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add + crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add, + constrain_layout=true + ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT( + verifier().Run(module.get()).status().error_message(), + HasSubstr("mix of layout constrained and unconstrained AllReduce")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 81a42de6816..defaf4cd7ab 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -432,6 +432,12 @@ bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) { return custom_call != nullptr && custom_call->layout_constrained(); } +bool IsLayoutConstrainedAllReduce(HloInstruction* instruction) { + const HloAllReduceInstruction* all_reduce = + DynCast(instruction); + return all_reduce != nullptr && all_reduce->constrain_layout(); +} + } // namespace Status LayoutAssignment::AddMandatoryConstraints( @@ -516,6 +522,9 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } + } else if (IsLayoutConstrainedAllReduce(instruction)) { + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(instruction->shape(), instruction)); } else if (instruction->IsCrossModuleAllReduce()) { CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; @@ -1765,7 +1774,8 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { } // Some instructions carry mandatory layouts in their shape. 
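+    // (infeed, layout-constrained custom calls, and layout-constrained
+    // all-reduces); those layouts must not be cleared here.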
if (instruction->opcode() != HloOpcode::kInfeed && - !IsLayoutConstrainedCustomCall(instruction)) { + !IsLayoutConstrainedCustomCall(instruction) && + !IsLayoutConstrainedAllReduce(instruction)) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index e5b6138257b..f7d0aa6b669 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -139,6 +139,9 @@ cc_library( hdrs = ["kernel_lowering.h"], deps = [ "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:hlo_legalize_to_lhlo", + "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg", "//tensorflow/compiler/mlir/xla:xla_dialect_registration", @@ -157,10 +160,12 @@ cc_library( "@local_config_mlir//:Linalg", "@local_config_mlir//:LinalgDialectRegistration", "@local_config_mlir//:LoopDialectRegistration", + "@local_config_mlir//:LoopOps", "@local_config_mlir//:LoopsToGPUPass", "@local_config_mlir//:NVVMDialect", "@local_config_mlir//:Pass", "@local_config_mlir//:StandardDialectRegistration", + "@local_config_mlir//:StandardOps", "@local_config_mlir//:Transforms", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 7cbbb3ec44e..c749af3a1c3 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -25,32 +25,234 @@ limitations under the License. #include "mlir/Dialect/GPU/Passes.h" // TF:local_config_mlir #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir #include "mlir/Dialect/LLVMIR/NVVMDialect.h" // TF:local_config_mlir +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:local_config_mlir #include "mlir/Dialect/Linalg/Passes.h" // TF:local_config_mlir +#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/OperationSupport.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/Region.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassManager.h" // TF:local_config_mlir #include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir #include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace mlir_gpu { +namespace { + +using ::mlir::xla_lhlo::FusionOp; + +// Following are some small transformations that are required to clean up code +// after lowering from linalg to loops. + +// A simple pass that applies lowering of HLO to LHLO only within Fusion +// operations. This is needed, as FusionOp is not closed from above and hence +// nested pass managers can not be applied. 
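+// The pass therefore walks every FusionOp in the function and applies the
+// HLO-to-LHLO conversion patterns to the operations nested in its region.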
+struct FusionToLhloConverter + : public mlir::FunctionPass { + void runOnFunction() override { + auto& ctx = getContext(); + mlir::OwningRewritePatternList patterns; + mlir::ConversionTarget target(ctx); + target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>(); + ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns); + + getFunction().walk([&](FusionOp op) { + if (failed(applyPartialConversion(op, target, patterns, nullptr))) { + signalPassFailure(); + } + }); + } +}; + +// Replaces a FusionOp by the operations contained in its region. +struct FusionOpRemover : public mlir::FunctionPass { + void runOnFunction() override { + getFunction().walk([&](FusionOp op) { + mlir::OpBuilder builder(op); + // FusionOp has a single region with a single block, so we can just walk + // over it and clone operations to the outside. + mlir::BlockAndValueMapping mapping; + for (auto& nested_op : op.region().front().without_terminator()) { + auto clone = builder.clone(nested_op, mapping); + for (auto pair : + llvm::zip(nested_op.getResults(), clone->getResults())) { + mapping.map(std::get<0>(pair), std::get<1>(pair)); + } + } + op.erase(); + }); + } +}; + +// Rewrite the single-trip loops we get out of linalg into just their bodies. +// TODO(herhut): Make this a general pattern. +struct SingleTripLoopRemoval + : public mlir::FunctionPass { + void runOnFunction() override { + auto getConstantValue = [](mlir::Value* value) -> llvm::Optional { + auto definingOp = value->getDefiningOp(); + if (!definingOp) return llvm::None; + auto constantOp = llvm::dyn_cast(definingOp); + if (!constantOp) return llvm::None; + auto integer = constantOp.getValue().dyn_cast(); + if (!integer) return llvm::None; + return integer.getInt(); + }; + getFunction().walk([&](mlir::loop::ForOp forOp) { + auto lower = getConstantValue(forOp.lowerBound()); + auto upper = getConstantValue(forOp.upperBound()); + auto step = getConstantValue(forOp.step()); + if (!lower || !upper || !step) return; + if ((lower.getValue() < upper.getValue()) && + (lower.getValue() + step.getValue() >= upper.getValue())) { + // This loop has a single trip, so we can move the body in front. + mlir::BlockAndValueMapping mapping; + mlir::OpBuilder b(forOp); + mapping.map(forOp.getInductionVar(), forOp.lowerBound()); + for (auto& nested_op : forOp.getBody()->without_terminator()) { + auto clone = b.clone(nested_op, mapping); + for (auto pair : + llvm::zip(nested_op.getResults(), clone->getResults())) { + mapping.map(std::get<0>(pair), std::get<1>(pair)); + } + } + forOp.erase(); + } + }); + } +}; + +// Simple pass that replaces a load that immediately follows a store to the +// same address with the stored value. This needs generalization. +struct StoreForwardingPass : mlir::FunctionPass { + void runOnFunction() override { + getFunction().walk([&](mlir::LoadOp loadOp) { + auto block = loadOp.getOperation()->getBlock(); + auto iterator = std::find_if(block->rbegin(), block->rend(), + [&loadOp](mlir::Operation& other) { + return &other == loadOp.getOperation(); + }); + if (++iterator == block->rend()) return; + mlir::StoreOp storeOp = llvm::dyn_cast(&*(iterator)); + if (!storeOp) return; + // Check both store to the same value. 
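+      // I.e. forward the value only if the load reads the same memref and the
+      // same indices that the immediately preceding store wrote.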
+ if (storeOp.memref() != loadOp.memref()) return; + auto storeIndices = storeOp.getIndices(); + auto loadIndices = loadOp.getIndices(); + if (!std::equal(storeIndices.begin(), storeIndices.end(), + loadIndices.begin(), loadIndices.end())) { + return; + } + loadOp.replaceAllUsesWith(storeOp.getValueToStore()); + loadOp.erase(); + }); + }; +}; + +// Simple pass that removes temporary buffers that are only written to but +// never read from or that are read but the read value is not used. +// Needs an analysis that proves that loads and stores are side-effect free +// (in bounds, no aliasing, etc.). +struct DeadTempBufferRemoval : mlir::FunctionPass { + bool operationConsideredDead(mlir::Operation* op) { + for (auto result : op->getResults()) { + if (!llvm::all_of(result->getUsers(), [&](mlir::Operation* op) { + // Store and Dealloc is OK. + if (llvm::isa(op) || + llvm::isa(op)) { + return true; + } + // Load without uses is also ok. + if (auto loadOp = llvm::dyn_cast(op)) { + return loadOp.use_empty(); + } + // Subview is ok if it is dead itself. + if (llvm::isa(op)) { + return operationConsideredDead(op); + } + return false; + })) { + return false; + } + } + return true; + } + + void recursiveErase(mlir::Operation* op) { + for (auto result : op->getResults()) { + for (auto user : llvm::make_early_inc_range(result->getUsers())) { + recursiveErase(user); + } + } + op->erase(); + } + + void runOnFunction() override { + getFunction().walk([&](mlir::AllocOp allocOp) { + if (!operationConsideredDead(allocOp)) { + return; + } + + // TODO(herhut): There should be a generic helper for this. + recursiveErase(allocOp); + }); + } +}; + +// Neat little helper pass to dump the IR inbetween passes. +struct DumpPass : public mlir::ModulePass { + void runOnModule() override { +#if DEBUG + getModule().dump(); +#endif + } +}; + +} // namespace Status LowerLHLOToGPU(mlir::ModuleOp module) { mlir::PassManager pm(module.getContext()); - // Transform element-wise operations to LinAlg. + // First, lower bodies of fusion operations from hlo to lhlo. + pm.addPass(absl::make_unique()); + // Next, we can strip the outer fusion operation. + pm.addPass(absl::make_unique()); + // Transform lhlo operations to LinAlg. pm.addPass(::mlir::xla_lhlo::createLegalizeToLinalgPass()); - // Go from affine to normal loops. + // Fuse linalg operations. This will yield a single tiled loop nest where + // the inner loops are single trip. + pm.addPass(::mlir::xla_lhlo::createLhloFuseLinalg()); + pm.addPass(absl::make_unique()); + // Go from linalg to normal loops. pm.addPass(::mlir::linalg::createConvertLinalgToLoopsPass()); - // Lower affine to ordinary loops. - pm.addPass(::mlir::createLowerAffinePass()); - // Move constants out of the loop. - pm.addPass(::mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(absl::make_unique()); + // Canonicalize the code to simplify index computations. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); + pm.addPass(absl::make_unique()); + // The innermost loops will be single-trip. + pm.addPass(absl::make_unique()); + pm.addPass(absl::make_unique()); + // Run CSE to ensure that loads and stores to the same subview get + // recognized as such. + pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + pm.addPass(absl::make_unique()); + // Forward stores to buffers to loads. + pm.addPass(absl::make_unique()); + pm.addPass(absl::make_unique()); + // Remove now unused temporary buffers. 
+ pm.addPass(absl::make_unique()); + pm.addPass(absl::make_unique()); // Coalesce generated loops to have 1d loops. pm.addPass(::mlir::createLoopCoalescingPass()); // Transform the now 1d loops to gpu launches. @@ -65,6 +267,7 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) { if (failed(pm.run(module))) { return InternalError("Lowering to GPU kernels failed."); } + return Status::OK(); } @@ -73,7 +276,7 @@ Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) { ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); // Rewrite kernel functions to LLVM IR. - auto &kernelPm = pm.nest<::mlir::ModuleOp>(); + auto& kernelPm = pm.nest<::mlir::ModuleOp>(); kernelPm.addPass(::mlir::createLowerGpuOpsToNVVMOpsPass()); // Some basic cleanup. kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index b035a8ddcb5..92f7e5a08ac 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -522,6 +522,10 @@ StatusOr> MlirCompiler::RunBackend( auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + if (!llvmModule) { + return InternalError("Translation to LLVM failed"); + } + llvmModule->setModuleIdentifier(emission_context.getHloModule()->name()); // TODO(herhut): Why is this needed and does not come from the template? llvmModule->setDataLayout(gpu::nvptx::kDataLayout); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc index 3ad958dfe6d..1d37aa1ba75 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc @@ -113,41 +113,20 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { ;CHECK: "gpu.launch_func"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[ARG2]] ;CHECK: } ;CHECK: func @add_kernel(%[[ARG0]]: [[TYPE]], %[[ARG1]]: [[TYPE]], %[[ARG2]]: [[TYPE]] -;CHECK: load %[[ARG0]][[INDEX:.*]] -;CHECK: load %[[ARG1]][[INDEX]] -;CHECK: store %{{.*}}, %[[ARG2]][[INDEX]] +;CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]] +;CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]] +;CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]] +;CHECK: %[[VAL1:.*]] = load %{{.*\[}}[[INDEX:.*]]] +;CHECK: %[[VAL2:.*]] = load %{{.*\[}}[[INDEX]]] +;CHECK: %[[RES:.*]] = addf %[[VAL1]], %[[VAL2]] +;CHECK: store %[[RES]], %{{.*\[}}[[INDEX]]] )", LoweringStage::GPU); } -TEST_F(LhloGenTest, AddInLVVMDialect) { - CompileAndVerifyIr(R"( -HloModule Add - -ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { - %x = f32[2,2]{1,0} parameter(0) - %y = f32[2,2]{1,0} parameter(1) - ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) -})", - R"( -;CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm<.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]] -;CHECK: %[[LD0:.*]] = llvm.load %[[ARG0]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> -;CHECK: %[[LD1:.*]] = llvm.load %[[ARG1]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> -;CHECK: %[[LD2:.*]] = llvm.load %[[ARG2]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> -;CHECK: %[[PTR0:.*]] = llvm.extractvalue %[[LD0]][1] -;CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[PTR0]] -;CHECK: %[[VAL0:.*]] = llvm.load %[[GEP0]] -;CHECK: %[[PTR1:.*]] = llvm.extractvalue 
%[[LD1]][1] -;CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[PTR1]] -;CHECK: %[[VAL1:.*]] = llvm.load %[[GEP1]] -;CHECK: %[[VAL2:.*]] = llvm.fadd %[[VAL0]], %[[VAL1]] -;CHECK: %[[PTR2:.*]] = llvm.extractvalue %[[LD2]][1] -;CHECK: %[[GEP2:.*]] = llvm.getelementptr %[[PTR2]] -;CHECK: llvm.store %[[VAL2]], %[[GEP2]] - )", - LoweringStage::LLVM); -} - +// This test verifies that the kernel signature is amended correctly. The actual +// body of the generated function does not matter, it is already checked at the +// GPU level above. TEST_F(LhloGenTest, AddAsKernel) { CompileAndVerifyIr(R"( HloModule Add @@ -219,20 +198,6 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { ;CHECK: llvm.store %{{.*}}, %[[GEP2ST0]] ;CHECK: %[[GEP2ST1:.*]] = llvm.getelementptr %[[DESC2]] ;CHECK: llvm.store %{{.*}}, %[[GEP2ST1]] - -;CHECK: %[[VL0:.*]] = llvm.load %[[DESC0]] -;CHECK: %[[VL1:.*]] = llvm.load %[[DESC1]] -;CHECK: %[[VL2:.*]] = llvm.load %[[DESC2]] -;CHECK: %[[EV0:.*]] = llvm.extractvalue %[[VL0]][1] -;CHECK: %[[VGEP0:.*]] = llvm.getelementptr %[[EV0]] -;CHECK: %[[VAL0:.*]] = llvm.load %[[VGEP0]] -;CHECK: %[[EV1:.*]] = llvm.extractvalue %[[VL1]][1] -;CHECK: %[[VGEP1:.*]] = llvm.getelementptr %[[EV1]] -;CHECK: %[[VAL1:.*]] = llvm.load %[[VGEP1]] -;CHECK: %[[VAL2:.*]] = llvm.fadd %[[VAL0]], %[[VAL1]] -;CHECK: %[[EV2:.*]] = llvm.extractvalue %[[VL2]][1] -;CHECK: %[[SGEP:.*]] = llvm.getelementptr %[[EV2]] -;CHECK: llvm.store %[[VAL2]], %[[SGEP]] )", LoweringStage::KERNEL); } @@ -262,6 +227,34 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { )"); } +TEST_F(LhloGenTest, AddMultiplyGPU) { + CompileAndVerifyIr(R"( +HloModule AddMultiply + +ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + %z = f32[2,2]{1,0} parameter(2) + %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) + ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %add, f32[2,2]{1,0} %z) +})", + R"( +;CHECK: func @fusion_kernel(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]]) +;CHECK-DAG: std.subview %[[ARG0]]{{\[}}[[INDEX:.*]]] +;CHECK-DAG: std.subview %[[ARG1]]{{\[}}[[INDEX]]] +;CHECK-DAG: std.subview %[[ARG2]]{{\[}}[[INDEX]]] +;CHECK-DAG: std.subview %[[RESULT]]{{\[}}[[INDEX]]] +;CHECK: %[[V0:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] +;CHECK: %[[V1:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] +;CHECK: %[[ADD:.*]] = addf %[[V0]], %[[V1]] +;CHECK: %[[V2:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] +;CHECK: %[[MUL:.*]] = mulf %[[ADD]], %[[V2]] +;CHECK: store %[[MUL]], %{{.*\[}}[[CSTIDX:.*]]] +;CHECK-NEXT: return + )", + LoweringStage::GPU); +} + TEST_F(LhloGenTest, FusedReduce) { CompileAndVerifyIr(R"( HloModule FusedReduce @@ -275,12 +268,14 @@ HloModule FusedReduce %fused_computation (param: f32[100,10]) -> f32[10] { %param = f32[100,10] parameter(0) %constant = f32[] constant(0) - ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant), dimensions={0}, to_apply=%add + ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant), + dimensions={0}, to_apply=%add } ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] { %x = f32[100,10] parameter(0) - ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput, calls=%fused_computation + ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput, + calls=%fused_computation } )", R"( @@ -316,21 +311,20 @@ ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] { )"); } -// TODO(pifon): Re-enable when Iota can be lowered 
all the way to GPU. -// TEST_F(LhloGenTest, Iota) { -// CompileAndVerifyIr(R"( -// HloModule Iota -// -// ENTRY %Iota() -> s64[10, 5] { -// ROOT %iota = s64[10, 5]{1,0} iota(), iota_dimension=0 -// })", -// R"( -// ;CHECK: func @iota(%[[OUT:.*]]: [[OUT_T:.*]]) { -// ;CHECK: "xla_lhlo.iota"(%[[OUT]]) -// ;CHECK: {iota_dimension = 0 : i64} : ([[OUT_T]]) -> () -// ;CHECK: } -// )"); -// } +TEST_F(LhloGenTest, Iota) { + CompileAndVerifyIr(R"( + HloModule Iota + + ENTRY %Iota() -> s64[10, 5] { + ROOT %iota = s64[10, 5]{1,0} iota(), iota_dimension=0 +})", + R"( +;CHECK: func @iota(%[[OUT:.*]]: [[OUT_T:.*]]) { +;CHECK: "xla_lhlo.iota"(%[[OUT]]) +;CHECK: {iota_dimension = 0 : i64} : ([[OUT_T]]) -> () +;CHECK: } +)"); +} TEST_F(LhloGenTest, AddReduce) { CompileAndVerifyIr(R"( diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 74f2c95102a..07b6fb5bf85 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -108,6 +108,11 @@ StatusOr MultiOutputFusion::Run(HloModule* module) { changed = true; } } + // Clean up state in case this pass is wrapped in an HloPassPipeline. + candidates_.clear(); + candidates_index_.clear(); + all_fusion_candidates_.clear(); + reachability_.reset(); return changed; } diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 8b95c17d199..c2dc9125479 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -98,9 +98,9 @@ void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo, StatusOr> LlvmIrGenTestBase::GetOptimizedModule( absl::string_view hlo) { - HloModuleConfig config; - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo, config)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo, GetModuleConfigForTest())); return backend().compiler()->RunHloPasses( std::move(module), backend().default_stream_executor(), backend().default_stream_executor()->GetAllocator()); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index da20d28ea81..8e6e9b46100 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -252,3 +252,114 @@ sh_test( srcs = ["interactive_graphviz_test.sh"], data = [":interactive_graphviz"], ) + +cc_library( + name = "hlo_module_loader", + srcs = ["hlo_module_loader.cc"], + hdrs = ["hlo_module_loader.h"], + deps = [ + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf_headers", + ], +) + +tf_cc_test( + name = "hlo_module_loader_test", + srcs = ["hlo_module_loader_test.cc"], + deps = [ + ":hlo_module_loader", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + ], +) + +cc_library( + name = "prepare_reference_module", + srcs = ["prepare_reference_module.cc"], + hdrs = ["prepare_reference_module.h"], + deps = [ + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_proto_cc", + 
"//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core/platform:errors", + "//tensorflow/stream_executor:platform", + "//tensorflow/stream_executor/lib", + ], +) + +cc_library( + name = "run_hlo_module_lib", + testonly = True, + srcs = ["run_hlo_module.cc"], + hdrs = ["run_hlo_module.h"], + deps = [ + ":hlo_module_loader", + ":prepare_reference_module", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:test", + "//tensorflow/stream_executor:platform", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_binary( + name = "run_hlo_module", + testonly = True, + srcs = ["run_hlo_module_main.cc"], + deps = [ + ":run_hlo_module_lib", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:interpreter_plugin", + "//tensorflow/core:framework_internal", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:test", + "@com_google_absl//absl/strings", + ], +) + +# Same as run_hlo_module, but supports the MLIR GPU backend instead of the XLA +# GPU backend. +tf_cc_binary( + name = "run_hlo_module_mlir_gpu", + testonly = True, + srcs = ["run_hlo_module_main.cc"], + deps = [ + ":run_hlo_module_lib", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:interpreter_plugin", + "//tensorflow/compiler/xla/service:mlir_gpu_plugin", + "//tensorflow/core:framework_internal", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:test", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/xla/tools/hlo_module_loader.cc b/tensorflow/compiler/xla/tools/hlo_module_loader.cc new file mode 100644 index 00000000000..8eb170b25e5 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_module_loader.cc @@ -0,0 +1,125 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Emits an HLO module in a text form suitable for diffing. 
+ +#include "tensorflow/compiler/xla/tools/hlo_module_loader.h" + +#include +#include +#include + +#include "google/protobuf/text_format.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { +namespace { + +Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config, + HloModuleConfig* config) { + config->set_replica_count(ovr_config.num_replicas); + return Status::OK(); +} + +} // namespace + +string StripLogHeaders(const string& hlo_string) { + // I0521 12:04:45.883483 1509 service.cc:186] ... + static RE2* matcher = new RE2( + "[IWEF]\\d{4} " + "\\d{2}:\\d{2}:\\d{2}\\.\\d+\\s+\\d+\\s+[^:]+:\\d+\\]\\s?(.*)"); + absl::string_view matches[4]; + std::vector lines = absl::StrSplit(hlo_string, '\n'); + for (auto& line : lines) { + if (matcher->Match(line, 0, line.size(), RE2::ANCHOR_START, matches, 4)) { + line = string(matches[1]); + } + } + return absl::StrJoin(lines, "\n", [](string* out, const string& line) { + absl::StrAppend(out, line); + }); +} + +StatusOr> LoadModuleFromData( + const string& data, const string& format, + hlo_module_loader_details::Config ovr_config, + const std::function& config_modifier_hook) { + DebugOptions debug_options = GetDebugOptionsFromFlags(); + std::unique_ptr module; + if (format == "hlo" || format == "txt") { + string hlo_string = StripLogHeaders(data); + HloModuleConfig config; + config.set_debug_options(debug_options); + TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config)); + if (config_modifier_hook) { + config_modifier_hook(&config); + } + TF_ASSIGN_OR_RETURN(module, + ParseAndReturnUnverifiedModule(hlo_string, config)); + } else { + HloSnapshot proto; + if (format == "pb") { + if (!proto.ParseFromString(data) && + !proto.mutable_hlo()->ParseFromString(data)) { + return InvalidArgument("Failed to parse input as HLO protobuf binary"); + } + } else if (format == "pbtxt") { + if (!proto2::TextFormat::ParseFromString(data, &proto) && + !proto2::TextFormat::ParseFromString(data, proto.mutable_hlo())) { + return InvalidArgument("Failed to parse input as HLO protobuf text"); + } + } else { + return InvalidArgument( + "Invalid format from file extension: '%s'. 
Expected: hlo, txt, pb, " + "or pbtxt", + format); + } + TF_ASSIGN_OR_RETURN(HloModuleConfig config, + HloModule::CreateModuleConfigFromProto( + proto.hlo().hlo_module(), debug_options)); + TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config)); + if (config_modifier_hook) { + config_modifier_hook(&config); + } + TF_ASSIGN_OR_RETURN( + module, HloModule::CreateFromProto(proto.hlo().hlo_module(), config)); + } + return std::move(module); +} + +StatusOr> LoadModuleFromFile( + const string& path, hlo_module_loader_details::Config ovr_config, + string format, + const std::function& config_modifier_hook) { + string data; + if (format.empty()) { + format = string(tensorflow::io::Extension(path)); + } + TF_RETURN_IF_ERROR( + tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &data)); + return LoadModuleFromData(data, format, ovr_config, config_modifier_hook); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/hlo_module_loader.h b/tensorflow/compiler/xla/tools/hlo_module_loader.h new file mode 100644 index 00000000000..8e174cef08f --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_module_loader.h @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_HLO_MODULE_LOADER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_HLO_MODULE_LOADER_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace hlo_module_loader_details { + +struct Config { + Config() {} + int64 num_replicas = 1; +}; + +} // namespace hlo_module_loader_details + +// Given a string composed by multiple lines, strip the log headers, if present +// at the beginning of each line. +string StripLogHeaders(const string& hlo_string); + +// Loads an HLO module from a string. +// The data can have the followings formats: +// 1) A binary of text proto file, the proto should be in xla.HloProto type. It +// can be a binary proto (format must be "pb"), or a text proto (format must +// be "pbtxt"). +// 2) A hlo text dump, the string should be in HloModule::ToString() format +// (format must be "txt" or "hlo"). The input data can also contain log +// headers, which will be stripped. +// The ovr_config data can be used to override certain fields of the +// HloModuleConfig. +// The HloModuleConfig is passed to config_modifier_hook for custom +// modifications before use. +StatusOr> LoadModuleFromData( + const string& data, const string& format, + hlo_module_loader_details::Config ovr_config = + hlo_module_loader_details::Config(), + const std::function& config_modifier_hook = {}); + +// Loads an HLO module from file. +// The file can be one of the followings: +// 1) A binary of text proto file, the proto should be in xla.HloProto type. It +// can be a binary proto (with .pb extension), or a text proto (with a .pbtxt +// extension). 
+// 2) A hlo text dump, the string should be in HloModule::ToString() format +// (with a .hlo or .txt extension). A text file can also contain log headers, +// which will be stripped. +// If the format is specified (not empty), it overrides the one guessed from the +// file extension. The ovr_config data can be used to override certain fields of +// the HloModuleConfig. +// The HloModuleConfig is passed to config_modifier_hook for custom +// modifications before use. +StatusOr> LoadModuleFromFile( + const string& path, + hlo_module_loader_details::Config ovr_config = + hlo_module_loader_details::Config(), + string format = "", + const std::function& config_modifier_hook = {}); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_HLO_MODULE_LOADER_H_ diff --git a/tensorflow/compiler/xla/tools/hlo_module_loader_test.cc b/tensorflow/compiler/xla/tools/hlo_module_loader_test.cc new file mode 100644 index 00000000000..e88d03e6b33 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_module_loader_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_module_loader.h" + +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class HloModuleLoaderTest : public HloTestBase {}; + +TEST_F(HloModuleLoaderTest, StripsLogHeaders) { + const string& hlo_string = R"( +I0521 12:04:45.883483 1509 service.cc:186] HloModule test_log_stripping +I0521 12:04:45.883483 1509 service.cc:186] +I0521 12:04:45.883483 1509 service.cc:186] ENTRY entry { +I0521 12:04:45.883483 1509 service.cc:186] p0 = f32[4]{0} parameter(0) +I0521 12:04:45.883483 1509 service.cc:186] p1 = f32[4]{0} parameter(1) +I0521 12:04:45.883483 1509 service.cc:186] add = f32[4]{0} add(p0, p1) +I0521 12:04:45.883483 1509 service.cc:186] ROOT rooty = (f32[4]{0}, f32[4]{0}) tuple(p1, add) +I0521 12:04:45.883483 1509 service.cc:186] } +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + LoadModuleFromData(hlo_string, "txt")); + EXPECT_NE(FindInstruction(hlo_module.get(), "p0"), nullptr); + EXPECT_NE(FindInstruction(hlo_module.get(), "p1"), nullptr); + EXPECT_NE(FindInstruction(hlo_module.get(), "add"), nullptr); + EXPECT_NE(FindInstruction(hlo_module.get(), "rooty"), nullptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/prepare_reference_module.cc b/tensorflow/compiler/xla/tools/prepare_reference_module.cc new file mode 100644 index 00000000000..65489c2d5db --- /dev/null +++ b/tensorflow/compiler/xla/tools/prepare_reference_module.cc @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/prepare_reference_module.h" + +#include +#include + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform.h" + +namespace xla { + +StatusOr> PrepareReferenceModule( + const HloModule& test_module, + const ::stream_executor::Platform::Id& test_platform_id, + const std::function& config_modifier_hook, + const std::function& module_modifier_hook) { + DebugOptions debug_options = GetDebugOptionsFromFlags(); + // The combination of fast math and optimizations leads to unsound code + // transformations (see third_party/tensorflow/compiler/xla/xla.proto for + // details). The test platform should not change this from the default. + debug_options.set_xla_cpu_enable_fast_math(false); + debug_options.set_xla_gpu_enable_fast_min_max(false); + HloModuleConfig reference_config = test_module.config(); + reference_config.set_debug_options(debug_options); + if (config_modifier_hook) { + config_modifier_hook(&reference_config); + } + std::unique_ptr reference_module = + test_module.Clone(reference_config, "reference"); + if (module_modifier_hook) { + TF_RETURN_IF_ERROR(module_modifier_hook(test_module, test_platform_id, + reference_module.get())); + } else { + TF_RETURN_IF_ERROR(Despecializer().Run(reference_module.get()).status()); + } + return std::move(reference_module); +} +}; // namespace xla diff --git a/tensorflow/compiler/xla/tools/prepare_reference_module.h b/tensorflow/compiler/xla/tools/prepare_reference_module.h new file mode 100644 index 00000000000..f98e50fc1e8 --- /dev/null +++ b/tensorflow/compiler/xla/tools/prepare_reference_module.h @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PREPARE_REFERENCE_MODULE_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PREPARE_REFERENCE_MODULE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform.h" + +namespace xla { + +// A helper function that takes a HloModule, derives a HloModuleConfig from it +// which disables fast-math und sets the DebugOptions from flags, then runs the +// deoptimization pipeline (or calls 'module_modifier_hook' if provided). This +// is meant to produce a reference module that is comparable to our custom test +// platforms. +StatusOr> PrepareReferenceModule( + const HloModule& test_module, + const ::stream_executor::Platform::Id& test_platform_id, + const std::function& config_modifier_hook = {}, + const std::function& module_modifier_hook = {}); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PREPARE_REFERENCE_MODULE_H_ diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 67a2c26201a..095655085e5 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -346,10 +346,10 @@ StatusOr> ParseRecordIoFile(absl::string_view filename, std::vector snapshots; uint64 offset = 0; - string record; + tensorflow::tstring record; while (reader.ReadRecord(&offset, &record).ok()) { HloSnapshot snapshot; - if (snapshot.mutable_hlo()->ParseFromString(record)) { + if (snapshot.mutable_hlo()->ParseFromStringPiece(record)) { snapshots.push_back(std::move(snapshot)); } else { LOG(ERROR) << "Encountered bad proto"; diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.cc b/tensorflow/compiler/xla/tools/run_hlo_module.cc new file mode 100644 index 00000000000..39b545af393 --- /dev/null +++ b/tensorflow/compiler/xla/tools/run_hlo_module.cc @@ -0,0 +1,145 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/run_hlo_module.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/testing.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/hlo_module_loader.h" +#include "tensorflow/compiler/xla/tools/prepare_reference_module.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" + +namespace se = ::stream_executor; + +namespace xla { +namespace { + +Literal ExecuteOnPlatform(std::unique_ptr module, + absl::Span args, + se::Platform* platform, bool run_hlo_passes) { + HloRunner runner(platform); + + TF_QCHECK_OK(VerifyHloModule(module.get(), /*layout_sensitive=*/false, + /*allow_mixed_precision=*/true)) + << " (on " << platform->Name() << ")"; + + std::cerr << "Running HLO module on platform " << platform->Name() << "...\n"; + XLA_VLOG_LINES(1, module->ToString()); + const auto start = std::chrono::high_resolution_clock::now(); + auto result_status = runner.Execute(std::move(module), args, run_hlo_passes); + const auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = end - start; + std::cerr << "... compiled and ran in " << diff.count() << "s.\n"; + + TF_QCHECK_OK(result_status.status()) + << "Failed to execute on " << platform->Name() << "\n"; + + return result_status.ConsumeValueOrDie(); +} +} // namespace + +::testing::AssertionResult RunAndCompare( + const std::string& hlo_filename, const std::string& test_platform_name, + const std::string& reference_platform_name, std::minstd_rand0* engine, + const RunHloModuleOptions& options, + std::function + reference_module_modifier_hook) { + se::Platform* test_platform = + xla::PlatformUtil::GetPlatform(test_platform_name).ValueOrDie(); + se::Platform* reference_platform = + reference_platform_name.empty() + ? nullptr + : xla::PlatformUtil::GetPlatform(reference_platform_name) + .ValueOrDie(); + auto config_modifier = [](HloModuleConfig* config) { config->set_seed(42); }; + + std::unique_ptr test_module = + LoadModuleFromFile(hlo_filename, hlo_module_loader_details::Config(), + options.input_format, config_modifier) + .ValueOrDie(); + const HloModuleProto test_module_proto = test_module->ToProto(); + + std::vector args = MakeFakeArguments(test_module.get(), engine, + options.use_large_float_range) + .ConsumeValueOrDie(); + + if (options.print_literals) { + for (int i = 0; i < args.size(); ++i) { + std::cout << "\n** Argument " << i << " **\n" + << args[i].ToString() << "\n"; + } + } + + std::unique_ptr reference_module; + if (reference_platform != nullptr) { + // PrepareReferenceModule needs to know the *test* platform, in order to + // properly match the test platform's numerics. 
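+    // PrepareReferenceModule clones the test module with fast math disabled
+    // and, when no modifier hook is supplied, runs the Despecializer over the
+    // clone so the reference platform executes a deoptimized module.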
+ reference_module = + PrepareReferenceModule(*test_module, test_platform->id(), + config_modifier, reference_module_modifier_hook) + .ConsumeValueOrDie(); + } + + Literal test_result = ExecuteOnPlatform( + std::move(test_module), args, test_platform, options.run_test_hlo_passes); + if (options.print_literals) { + std::cout << "\n** Result on test platform " << test_platform->Name() + << " **\n" + << test_result.ToString() << "\n"; + } + + if (reference_module == nullptr) { + std::cerr << "Skipping reference platform\n"; + return ::testing::AssertionSuccess(); + } + + Literal reference_result = + ExecuteOnPlatform(std::move(reference_module), args, reference_platform, + options.run_reference_hlo_passes); + + if (options.print_literals) { + std::cout << "\n** Result on reference platform " + << reference_platform->Name() << " **\n" + << reference_result.ToString() << "\n"; + } + ErrorSpec error_spec(static_cast(options.abs_error_bound), + static_cast(options.rel_error_bound)); + return LiteralTestUtil::Near(/*expected=*/reference_result, + /*actual=*/test_result, + /*error_spec=*/error_spec, + /*detailed_message=*/true); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.h b/tensorflow/compiler/xla/tools/run_hlo_module.h new file mode 100644 index 00000000000..932cc22f4dd --- /dev/null +++ b/tensorflow/compiler/xla/tools/run_hlo_module.h @@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/platform.h" + +namespace xla { + +// Command-line options to this tool. See main() in run_hlo_module_main.cc for +// descriptions of these fields. +struct RunHloModuleOptions { + RunHloModuleOptions() + : platform(""), + reference_platform("default"), + print_literals(false), + run_test_hlo_passes(true), + run_reference_hlo_passes(true), + use_large_float_range(true), + // TODO(b/68721786): These tolerances are set to match the values in the + // isolation test. The goal is to lower these to 0.001. 
+ abs_error_bound(0.1), + rel_error_bound(0.1), + input_format("hlo"), + input_module(""), + iterations(1) {} + std::string platform; + std::string reference_platform; + bool print_literals; + bool run_test_hlo_passes; + bool run_reference_hlo_passes; + bool use_large_float_range; + float abs_error_bound; + float rel_error_bound; + std::string input_format; + std::string input_module; + int iterations; +}; + +// Reads a HloModule from 'hlo_filename', runs it on the platform with the name +// 'test_platform_name', and if 'reference_platform_name' is non-empty, it also +// runs it on the platform with the name 'reference_platform_name' and compares +// the results. 'reference_module_modifier_hook' can be used to transform the +// HloModule before it is run on the reference platform. This may be necessary +// to match the numerics of the test platform. +::testing::AssertionResult RunAndCompare( + const std::string& hlo_filename, const std::string& test_platform_name, + const std::string& reference_platform_name, std::minstd_rand0* engine, + const RunHloModuleOptions& options, + std::function + reference_module_modifier_hook = {}); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_RUN_HLO_MODULE_H_ diff --git a/tensorflow/compiler/xla/tools/run_hlo_module_main.cc b/tensorflow/compiler/xla/tools/run_hlo_module_main.cc new file mode 100644 index 00000000000..7079f413eeb --- /dev/null +++ b/tensorflow/compiler/xla/tools/run_hlo_module_main.cc @@ -0,0 +1,184 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A tool for reading a HloModule from a HloProto file and execute the module on +// given platform(s). See kUsage for details. + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/tools/run_hlo_module.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +const char* const kUsage = R"( +This tool lets you read a HloModule from a file and execute the module on given +platform. + +The file can be one of the followings: +1) a binary or text proto file, the proto should be in xla.HloProto type. +2) a hlo text dump, the string should be in HloModule::ToString() format. + +By default, the module is run on a reference platform such as the interpreter +and the reference result is compared against the test result. + +You can also pass in debug option flags for the HloModule. + +Usage: + + bazel run run_hlo_module -- \ + --input_format=[hlo|pb|pbtxt] \ + --platform=[CPU|CUDA|Interpreter] \ + path/to/hlo_module +)"; +const char kInterpreterPlatformName[] = "Interpreter"; + +// Returns the name of the test platform. 
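+// The flag is required; an empty --platform value aborts via QCHECK.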
+std::string GetTestPlatformName(std::string name) { + QCHECK(!name.empty()) << "Must pass --platform flag."; + return name; +} + +// Returns the name of the reference platform +std::string GetReferencePlatformName(std::string reference_platform) { + if (reference_platform == "default") { + return kInterpreterPlatformName; + } + return reference_platform; +} +} // namespace + +int main(int argc, char** argv) { + xla::RunHloModuleOptions opts; + std::vector flag_list = { + tensorflow::Flag( + "platform", &opts.platform, + "The test platform that the HLO module will be executed on " + "(gpu, cpu, etc)."), + tensorflow::Flag( + "reference_platform", &opts.reference_platform, + "The reference platform that HLO module will be " + "executed on. The result produced on the reference platform will " + "be compared against the result produced on the test platform. A " + "value of 'default' will use the TPU_Interpreter as a reference if " + "the test platform is a TPU, and 'interpreter' otherwise. If the " + "flag value is the empty string, then the module will not be run " + "on a reference platform at all."), + tensorflow::Flag("print_literals", &opts.print_literals, + "Print the input and result literals to stdout."), + tensorflow::Flag( + "run_test_hlo_passes", &opts.run_test_hlo_passes, + "Run HLO pass pipeline for the test platform on the HLO module " + "before running the module on the test platform. This should be " + "set to true if the HLO module is unoptimized and set to false if " + "the HLO module already has been optimized."), + tensorflow::Flag( + "run_reference_hlo_passes", &opts.run_reference_hlo_passes, + "Run HLO pass pipeline for the reference platform on the HLO module " + "before running the module on the reference platform. " + "In general, if the given HLO module was optimized for a platform " + "other " + "than the reference this is necessary because some HLO passes are " + "legalization passes which must be run prior to code generation."), + + tensorflow::Flag( + "use_large_float_range", &opts.use_large_float_range, + "Generate floating point values using a large uniform-log " + "distribtion as opposed to a small uniform distribution."), + tensorflow::Flag( + "abs_error_bound", &opts.abs_error_bound, + "The absolute error bound used when comparing the test and " + "reference results."), + tensorflow::Flag( + "rel_error_bound", &opts.rel_error_bound, + "The relative error bound used when comparing the test and " + "reference results."), + tensorflow::Flag("input_format", &opts.input_format, + "The format of the input file. Valid values:\n" + " hlo : HLO textual format\n" + " pb : xla::HloProto in binary proto format\n" + " pbtxt : xla::HloProto in text proto format"), + tensorflow::Flag( + "input_module", &opts.input_module, + "A path to a file containing the HLO module. Can also pass " + "a this as argv[1], but this flag is more explicit."), + tensorflow::Flag( + "iterations", &opts.iterations, + "The number of times to run the module. Each iteration will be run " + "with different input data.")}; + xla::AppendDebugOptionsFlags(&flag_list); + // The usage string includes the message at the top of the file, the + // DebugOptions flags and the flags defined above. 
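+  // It is printed in full (via LOG(QFATAL)) when flag parsing fails.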
+ const std::string kUsageString = absl::StrCat( + kUsage, "\n\n", tensorflow::Flags::Usage(argv[0], flag_list)); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(kUsageString.c_str(), &argc, &argv); + if (!parse_ok) { + LOG(QFATAL) << kUsageString; + } + + const std::string test_platform_name = GetTestPlatformName(opts.platform); + const std::string reference_platform_name = + GetReferencePlatformName(opts.reference_platform); + + std::string hlo_filename; + if (!opts.input_module.empty()) { + hlo_filename = opts.input_module; + } else { + QCHECK(argc == 2) << "Must specify a single input file"; + hlo_filename = argv[1]; + } + + std::minstd_rand0 engine; + int failure_count = 0; + const int iteration_count = opts.iterations; + for (int i = 1; i <= iteration_count; ++i) { + if (iteration_count != 1) { + std::cerr << "\n=== Iteration " << i << "\n"; + } + ::testing::AssertionResult matched = + xla::RunAndCompare(hlo_filename, test_platform_name, + reference_platform_name, &engine, opts); + + // The AssertionResult is only meaningful when the reference is + // used. Without a reference, the test just verifies that nothing blew up + // when running the module. + if (!reference_platform_name.empty()) { + if (matched) { + // Success. + std::cerr << "\n** Results on " << test_platform_name << " and " + << reference_platform_name << " are close enough. **\n"; + } else { + failure_count++; + std::cerr << matched.message() << "\n"; + } + } + } + + if (!reference_platform_name.empty()) { + std::cerr << failure_count << "/" << iteration_count + << " runs miscompared.\n"; + } + + return failure_count == 0 ? 0 : -1; +} diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 588420eb1b6..29cfe52a196 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -67,6 +67,7 @@ load( "//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_android", + "if_chromiumos", "if_emscripten", "if_ios", "if_mobile", @@ -105,6 +106,7 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_additional_core_deps", + "tf_additional_env_hdrs", "tf_additional_lib_deps", "tf_additional_monitoring_hdrs", "tf_additional_test_deps", @@ -398,7 +400,7 @@ filegroup( "//tensorflow/core/platform:file_statistics.h", "//tensorflow/core/platform:file_system.h", "//tensorflow/core/platform:path.h", - ], + ] + tf_additional_env_hdrs(), visibility = ["//visibility:private"], ) @@ -464,7 +466,6 @@ cc_library( "//tensorflow/core/lib/core:legacy_lib_proto_parsing_headers", "//tensorflow/core/lib/strings:legacy_lib_proto_parsing_headers", "//tensorflow/core/platform:init_main.h", - "//tensorflow/core/platform:legacy_proto_hdrs", "//tensorflow/core/platform:logging.h", "//tensorflow/core/platform:macros.h", "//tensorflow/core/platform:platform.h", @@ -1569,7 +1570,6 @@ filegroup( "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", "//tensorflow/core/lib/strings:legacy_lib_strings_all_srcs", "//tensorflow/core/platform/default/build_config:android_srcs", - "//tensorflow/core/platform:legacy_srcs_no_runtime", "//tensorflow/core/profiler:mobile_srcs", "//tensorflow/core/util/ctc:android_srcs", "//tensorflow/core/util/sparse:mobile_srcs_no_runtime_group", @@ -1604,6 +1604,9 @@ filegroup( "common_runtime/eager/*", "common_runtime/gpu_device_factory.*", ], + ) + if_chromiumos( + ["//tensorflow/core/platform:legacy_srcs_no_runtime_google"], + otherwise = ["//tensorflow/core/platform:legacy_srcs_no_runtime"], ), visibility = 
["//visibility:private"], ) @@ -2166,8 +2169,6 @@ cc_library( "lib/png/**/*", ], ) + [ - "//tensorflow/core/platform:legacy_monitoring_srcs", - "//tensorflow/core/platform:legacy_platform_lib_srcs", "//tensorflow/core/platform:legacy_lib_internal_srcs", ], hdrs = LIB_INTERNAL_PUBLIC_HEADERS, @@ -2255,6 +2256,7 @@ cc_library( "//tensorflow/core/lib/strings:strcat", "//tensorflow/core/lib/strings:stringprintf", "//tensorflow/core/platform:abi", + "//tensorflow/core/platform:base64", "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:coding", "//tensorflow/core/platform:context", @@ -2270,6 +2272,7 @@ cc_library( "//tensorflow/core/platform:hash", "//tensorflow/core/platform:load_library", "//tensorflow/core/platform:logger", + "//tensorflow/core/platform:monitoring", "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:notification", "//tensorflow/core/platform:net", @@ -2283,6 +2286,8 @@ cc_library( "//tensorflow/core/platform:regexp", "//tensorflow/core/platform:scanner", "//tensorflow/core/platform:setround", + "//tensorflow/core/platform:stacktrace", + "//tensorflow/core/platform:stacktrace_handler", "//tensorflow/core/platform:status", "//tensorflow/core/platform:strcat", "//tensorflow/core/platform:stringpiece", @@ -2695,6 +2700,8 @@ tf_cuda_library( "@com_google_absl//absl/time", "//third_party/eigen3", "//tensorflow/core/framework:attr_value_proto_text", + "//tensorflow/core/framework:bfloat16", + "//tensorflow/core/framework:numeric_types", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/lib:traceme", @@ -2772,7 +2779,7 @@ tf_cuda_library( deps = [":framework_lite"], ) -# TODO(josh11b): Is this needed, or can we just use ":protos_all"? +# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"? cc_library( name = "protos_cc", visibility = ["//visibility:public"], diff --git a/tensorflow/core/api_def/base_api/api_def_DebugNumericSummaryV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_DebugNumericSummaryV2.pbtxt index c9097723057..28f0271c7e8 100644 --- a/tensorflow/core/api_def/base_api/api_def_DebugNumericSummaryV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DebugNumericSummaryV2.pbtxt @@ -15,17 +15,67 @@ Tensor debug mode: the mode in which the input tensor is summarized tensorflow/core/protobuf/debug_event.proto for details. Supported values: - 8 (REDUCE_INF_NAN_THREE_SLOTS): Output a float32 tensor of shape + 2 (CURT_HEALTH): Output a float32/64 tensor of shape [2]. The 1st + element is the tensor_id, if provided, and -1 otherwise. The 2nd + element is a bit which is set to 1 if the input tensor has an + infinity or nan value, or zero otherwise. + + 3 (CONCISE_HEALTH): Ouput a float32/64 tensor of shape [5]. The 1st + element is the tensor_id, if provided, and -1 otherwise. The + remaining four slots are the total number of elements, -infs, + +infs, and nans in the input tensor respectively. + + 4 (FULL_HEALTH): Output a float32/64 tensor of shape [11]. The 1st + element is the tensor_id, if provided, and -1 otherwise. The 2nd + element is the device_id, if provided, and -1 otherwise. The 3rd + element holds the datatype value of the input tensor as according + to the enumerated type in tensorflow/core/framework/types.proto. + The remaining elements hold the total number of elements, -infs, + +infs, nans, negative finite numbers, zeros, and positive finite + numbers in the input tensor respectively. 
+ + 5 (SHAPE): Output a float32/64 tensor of shape [10]. The 1st + element is the tensor_id, if provided, and -1 otherwise. The 2nd + element holds the datatype value of the input tensor as according + to the enumerated type in tensorflow/core/framework/types.proto. + The 3rd element holds the rank of the tensor. The 4th element holds + the number of elements within the tensor. Finally the remaining 6 + elements hold the shape of the tensor. If the rank of the tensor + is lower than 6, the shape is right padded with zeros. If the rank + is greater than 6, the head of the shape is truncated. + + 6 (FULL_NUMERICS): Output a float32/64 tensor of shape [22]. The 1st + element is the tensor_id, if provided, and -1 otherwise. The 2nd + element is the device_id, if provided, and -1 otherwise. The 3rd + element holds the datatype value of the input tensor as according + to the enumerated type in tensorflow/core/framework/types.proto. + The 4th element holds the rank of the tensor. The 5th to 11th + elements hold the shape of the tensor. If the rank of the tensor + is lower than 6, the shape is right padded with zeros. If the rank + is greater than 6, the head of the shape is truncated. The 12th to + 18th elements hold the number of elements, -infs, +infs, nans, + denormal floats, negative finite numbers, zeros, and positive + finite numbers in the input tensor respectively. The final four + elements hold the min value, max value, mean, and variance of the + input tensor. + + 8 (REDUCE_INF_NAN_THREE_SLOTS): Output a float32/64 tensor of shape [3]. The 1st element is -inf if any elements of the input tensor is -inf, or zero otherwise. The 2nd element is +inf if any elements of the input tensor is +inf, or zero otherwise. The 3rd element is - nan if any element of the input tensor is nan, or zero otherwise + nan if any element of the input tensor is nan, or zero otherwise. END } attr { name: "tensor_id" description: <>> strings = tf.constant(['Hello','TensorFlow', '\U0001F642']) +>>> tf.strings.length(strings).numpy() # default counts bytes +array([ 5, 10, 4], dtype=int32) +>>> tf.strings.length(strings, unit="UTF8_CHAR").numpy() +array([ 5, 10, 1], dtype=int32) + END } diff --git a/tensorflow/core/api_def/base_api/api_def_TPUReplicatedInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_TPUReplicatedInput.pbtxt index acd52a735cb..d632da17ad9 100644 --- a/tensorflow/core/api_def/base_api/api_def_TPUReplicatedInput.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_TPUReplicatedInput.pbtxt @@ -2,4 +2,17 @@ op { graph_op_name: "TPUReplicatedInput" visibility: HIDDEN summary: "Connects N inputs to an N-way replicated TPU computation." + description: <
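To make the DebugNumericSummaryV2 tensor debug modes described above concrete, a small worked example (values illustrative, assuming a provided tensor_id of 7): for an input tensor [1.0, -inf, nan, 2.0], mode 2 (CURT_HEALTH) would produce [7.0, 1.0], since at least one element is infinite or nan, while mode 3 (CONCISE_HEALTH) would produce [7.0, 4.0, 1.0, 0.0, 1.0]: four elements in total, one -inf, no +inf, and one nan.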