Merge branch 'master' into fix_minimum_maximum

Elena Zhelezina 2019-12-09 10:05:03 +00:00 committed by GitHub
commit ec5d3a0603
682 changed files with 66050 additions and 44576 deletions


@ -13,55 +13,4 @@
/tensorflow/tensorboard/ @jart
/tensorflow/tools/docs/ @markdaoust
# contrib
# NEED OWNER: /tensorflow/contrib/all_reduce
/tensorflow/contrib/autograph/ @mdanatg @kkimdev
/tensorflow/contrib/batching/ @alextp @chrisolston
/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon
/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva
/tensorflow/contrib/checkpoint/ @allenlavoie
/tensorflow/contrib/contrib/cluster_resolver/ @frankchn
/tensorflow/contrib/cmake/ @mrry
/tensorflow/contrib/copy_graph/ @tucker @poxvoculi
/tensorflow/contrib/crf/ @kentonl
/tensorflow/contrib/data/ @mrry
/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
/tensorflow/contrib/eager @jaingaurav @alextp
/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
/tensorflow/contrib/ffmpeg/ @fredbertsch
/tensorflow/contrib/framework/ @ebrevdo
/tensorflow/contrib/graph_editor/ @purpledog
# NEED OWNER: /tensorflow/contrib/grid_rnn/
/tensorflow/contrib/hadoop @yongtang
/tensorflow/contrib/hvx/ @satok16
/tensorflow/contrib/integrate/ @shoyer
/tensorflow/contrib/kernel_methods/ @petrosmol
/tensorflow/contrib/ios_examples/ @petewarden
/tensorflow/contrib/labeled_tensor/ @shoyer
/tensorflow/contrib/layers/ @fchollet @martinwicke
/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
/tensorflow/contrib/lookup/ @ysuematsu @andreasst
/tensorflow/contrib/losses/ @alextp @ispirmustafa
/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg
/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa
/tensorflow/contrib/opt/ @strategist333 @alextp
/tensorflow/contrib/pi_examples/ @maciekcc
/tensorflow/contrib/quantization/ @petewarden
/tensorflow/contrib/rnn/ @ebrevdo @scottzhu
/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie
/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang
/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh
/tensorflow/contrib/slim/ @sguada @thenbasilmanran
/tensorflow/contrib/stateless/ @girving @alextp
/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2
# NEED OWNER: /tensorflow/contrib/testing/
/tensorflow/contrib/timeseries/ @allenlavoie
/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
/tensorflow/contrib/training/ @joel-shor @ebrevdo
/tensorflow/contrib/util/ @sherrym
/third_party/systemlibs/ @perfinion


@ -110,19 +110,19 @@ Build Type | Status
### Community Supported Builds
Build Type | Status | Artifacts
------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5, 3.6 and 3.7** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.14.0 PyPI](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
Build Type | Status | Artifacts
----------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
## Resources


@ -195,6 +195,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "chromiumos",
values = {"crosstool_top": "//external:android/chromiumos"},
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_aarch64",
values = {"cpu": "aarch64"},
@ -453,6 +459,7 @@ package_group(
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
],
)


@ -233,7 +233,7 @@ tensorflow::Status GetReplacedFromExistingWorkers(
std::vector<tensorflow::eager::KeepAliveResponse> responses(
existing_workers->size());
for (int i = 0; i < existing_workers->size(); i++) {
tensorflow::eager::EagerClient* eager_client;
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] =
client_cache->GetClient(existing_workers->at(i), &eager_client);
if (!statuses[i].ok()) {
@ -282,7 +282,7 @@ tensorflow::Status CreateRemoteContexts(
continue;
}
tensorflow::eager::EagerClient* eager_client;
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
@ -340,7 +340,7 @@ tensorflow::Status UpdateRemoteContexts(
continue;
}
tensorflow::eager::EagerClient* eager_client;
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
@ -819,7 +819,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
}
// TODO(yuefengz): support partially specified `worker_name`.
tensorflow::eager::EagerClient* eager_client;
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
if (!status->status.ok()) {
return false;


@ -38,6 +38,6 @@ versions {
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %0 : tensor<*xi32>
# CHECK-NEXT: }


@ -1280,3 +1280,13 @@ func @conv2d_backprop_unsupported_data_format(%arg0: tensor<4xi32>, %arg1: tenso
// CHECK-LABEL: conv2d_backprop_unsupported_data_format
// CHECK: tf.Conv2DBackpropInput
}
func @assert_remove(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi1> {
%0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
"tf.Assert"(%0, %arg1) {summarize = 3} : (tensor<1xi1>, tensor<1xi32>) -> ()
return %0 : tensor<1xi1>
// CHECK-LABEL: assert_remove
// CHECK: tfl.less_equal
// CHECK-NOT: Assert
// CHECK: return
}


@ -622,3 +622,12 @@ func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
}
// CHECK-LABEL: fuse_relu_to_add
func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tfl.relu_n1_to_1"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
// CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
// CHECK: return %[[RES]]
}


@ -68,6 +68,7 @@ struct LegalizeTF : public FunctionPass<LegalizeTF> {
// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
// operands are properly supported in declarative rewrite rule specification.
DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Concat);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
@ -86,7 +87,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);
SmallVector<Value*, 4> values(tf_concat_op.values());
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output()->getType();
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
@ -105,7 +106,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
SmallVector<Value*, 4> values(tf_concat_op.values());
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output()->getType();
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
@ -374,6 +375,14 @@ PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
return matchFailure();
}
// TF Lite doesn't support Assert, we just drop the assert from the graph.
PatternMatchResult ConvertTFAssertOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
op->dropAllReferences();
op->erase();
return matchSuccess();
}
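The practical effect of this legalization is that a graph containing Assert ops can still be converted to TensorFlow Lite; the asserts are dropped rather than rejected. A minimal, hedged Python sketch of that user-level flow (the function name, shapes, and the assumption that the converter exercises this pass are illustrative, not part of this change):

```
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([1], tf.int32),
                              tf.TensorSpec([1], tf.int32)])
def less_equal_with_assert(x, y):
    # The stateful Assert op is kept by tf.function's automatic control deps.
    tf.debugging.Assert(tf.reduce_all(tf.less_equal(x, y)), [x, y])
    return tf.less_equal(x, y)

# Assuming the TFLite converter exercises this legalization, the Assert is
# dropped from the converted graph instead of failing the conversion.
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [less_equal_with_assert.get_concrete_function()])
tflite_model = converter.convert()
```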
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
auto* ctx = &getContext();
@ -385,7 +394,8 @@ void LegalizeTF::runOnFunction() {
.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, ConvertTFPackOp,
ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp,
ConvertTFStridedSliceOp, ConvertTFUnpackOp>(ctx);
ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp>(
ctx);
applyPatternsGreedily(func, patterns);
}


@ -484,9 +484,9 @@ struct ConvertTensorListResize : public ConversionPattern {
&rewriter);
// Inserts the two blocks' names into the symbol table held by the module.
// Using ModuleManager will ensure that the inserted symbol names are
// Using SymbolTable will ensure that the inserted symbol names are
// unique.
ModuleManager manager(resize_op.getParentOfType<ModuleOp>());
SymbolTable manager(resize_op.getParentOfType<ModuleOp>());
manager.insert(then_branch_op);
manager.insert(else_branch_op);
@ -754,8 +754,7 @@ struct ConvertWhile : public ConversionPattern {
cloned.removeAttr("T");
UpdateFunctionTypes(cloned);
SmallVector<Value *, 8> results(cloned.getResults());
rewriter.replaceOp(op, results);
rewriter.replaceOp(op, cloned.getResults());
return matchSuccess();
}
};


@ -135,15 +135,15 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
static void EraseDeadFuncs(const FuncSet& candiate_funcs, ModuleOp module) {
if (candiate_funcs.empty()) return;
ModuleManager manager(module);
SymbolTable manager(module);
// Identify the functions that are used as symbols in the module and shouldn't
// be erased.
FuncSet in_use_funcs;
manager.getModule().walk([&](Operation* op) {
manager.getOp()->walk([&](Operation* op) {
for (auto attr : op->getAttrs()) {
if (auto symbol = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
auto func = manager.lookupSymbol<FuncOp>(symbol.getValue());
auto func = manager.lookup<FuncOp>(symbol.getValue());
in_use_funcs.insert(func);
}
}


@ -44,12 +44,13 @@ multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
$multiplier)>;
}
// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// that cannot be simply translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6]] in
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
@ -291,3 +292,18 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
(ConstantOp $NegOne)),
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in {
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1])>;
}
}
// Instantiated FusedBinary patterns for the from-to pairs of ops.
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
TFL_MulOp, TFL_SubOp] in
defm : FusedBinaryActivationFuncOpPat<BinaryOps>;


@ -192,6 +192,37 @@ cc_library(
alwayslink = 1,
)
gentbl(
name = "decompose_resource_ops_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_decompose_resource_ops.inc",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "transforms/decompose_resource_ops.td",
td_srcs = [
":tensorflow_ops_td_files",
"@local_config_mlir//:StdOpsTdFiles",
],
)
cc_library(
name = "decompose_resource_ops",
srcs = [
"transforms/decompose_resource_ops.cc",
],
hdrs = [
"transforms/decompose_resource_ops.h",
],
deps = [
":decompose_resource_ops_inc_gen",
":tensorflow",
"@local_config_mlir//:IR",
],
)
cc_library(
name = "tensorflow_passes",
srcs = [
@ -199,6 +230,7 @@ cc_library(
"transforms/bridge_pass.cc",
"transforms/cluster_formation.cc",
"transforms/cluster_outlining.cc",
"transforms/decompose_resource_ops_pass.cc",
"transforms/delete_unused_funcs.cc",
"transforms/executor_island_coarsening.cc",
"transforms/fold_switch.cc",
@ -213,6 +245,7 @@ cc_library(
"transforms/raise_control_flow.cc",
"transforms/replicate_invariant_op_hoisting.cc",
"transforms/replicate_to_island.cc",
"transforms/resource_device_inference.cc",
"transforms/resource_op_lifting.cc",
"transforms/shape_inference.cc",
"transforms/shape_inference_pass.cc",
@ -236,6 +269,8 @@ cc_library(
":bridge_logger",
":convert_tensor",
":convert_type",
":decompose_resource_ops",
":decompose_resource_ops_inc_gen",
":device_util",
":error_util",
":export_tf_dialect_op",
@ -368,12 +403,14 @@ cc_library(
":convert_tensor",
":convert_type",
":mangling_util",
":tensorflow",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:protobuf",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
@ -564,7 +601,6 @@ cc_library(
hdrs = ["utils/error_util.h"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
],
@ -808,7 +844,6 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core/platform:logging",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",


@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
@ -99,12 +100,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
auto forward_input_to_output = [&](Value* operand, Value* result) {
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
return;
auto& result_ids = resource_value_to_ids_[result];
auto operand_it = resource_value_to_ids_.find(operand);
assert(operand_it != resource_value_to_ids_.end() &&
"A resource-type output does not have the corresponding "
"resource-type input.");
resource_value_to_ids_[result].insert(operand_it->getSecond().begin(),
operand_it->getSecond().end());
result_ids.insert(operand_it->getSecond().begin(),
operand_it->getSecond().end());
};
// TODO(yuanzx): Consider control-flow ops.
func_op.walk([&](Operation* op) {
@ -119,6 +121,16 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
forward_input_to_output(std::get<0>(operand_and_result),
std::get<1>(operand_and_result));
}
} else if (auto replicate = llvm::dyn_cast<tf_device::ReplicateOp>(op)) {
// The nested block for ReplicateOp is handled separately in side-effect
// analysis. Inside that block, we can still treat its block arguments as
// different resources.
for (auto arg : replicate.GetBody().getArguments()) {
if (mlir::getElementTypeOrSelf(arg->getType())
.isa<TF::ResourceType>()) {
resource_value_to_ids_[arg].insert(next_unique_id++);
}
}
} else {
for (auto result : op->getResults()) {
if (!mlir::getElementTypeOrSelf(result->getType())
@ -261,9 +273,36 @@ void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
void SideEffectAnalysis::AnalyzeFunction(
FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) {
// This function populates control_predecessors_ and control_successors_ by
// walking through func_op's body, and tracking resource accesses in
// per_resource_access_info_.
// AnalyzeRegion() recursively analyzes the function body, and only populates
// control_predecessors_.
AnalyzeRegion(&func_op.getBody(), alias_analysis);
// Populate sorted_control_predecessors_ and sorted_control_successors_ based
// on control_predecessors.
for (auto& entry : control_predecessors_) {
auto op = entry.getFirst();
auto& sorted_predecessors = sorted_control_predecessors_[op];
for (auto predecessor : entry.getSecond()) {
sorted_predecessors.push_back(predecessor);
sorted_control_successors_[predecessor].push_back(op);
}
}
control_predecessors_.clear();
for (auto& entry : sorted_control_predecessors_) {
llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
return a->isBeforeInBlock(b);
});
}
for (auto& entry : sorted_control_successors_) {
llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
return a->isBeforeInBlock(b);
});
}
}
void SideEffectAnalysis::AnalyzeRegion(
Region* region, const ResourceAliasAnalysis& alias_analysis) {
// This function populates control_predecessors_ by walking through the
// region, and tracking resource accesses in per_resource_access_info_.
// Returns whether an access to `resource` can skip control edges from
// previous accesses to unknown resources, due to that earlier accesses to
@ -284,82 +323,93 @@ void SideEffectAnalysis::AnalyzeFunction(
(it->second.tracked_last_unknown_read || no_unknown_read);
};
func_op.walk([&](Operation* op) {
// We do not need explicit control edges for declaration ops.
if (OpIsDeclaration(op, alias_analysis)) return;
auto resource_op_info = GetResourceInfoForOp(op);
if (!resource_op_info && op->hasNoSideEffect()) return;
llvm::SmallDenseSet<int64_t, 8> resources =
resource_op_info ? FindAccessedResources(op, alias_analysis)
: UnknownResourceSet();
assert(!resources.empty());
const bool is_unknown = resources.count(kUnknownResourceId) > 0;
const bool read_only = OpIsReadOnly(op);
bool indirectly_tracked_unknown_access = false;
// First add edges from known resources.
if (is_unknown) {
for (auto& entry : per_resource_access_info_) {
if (entry.getFirst() == kUnknownResourceId) continue;
AddPredecessorsForAccess(entry.getFirst(), op, read_only);
indirectly_tracked_unknown_access |=
unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
read_only);
// We explicitly iterate through the regions and blocks in order to handle
// different nested regions separately.
for (auto& block : *region) {
for (auto& op : block) {
if (op.getNumRegions() > 0) {
llvm::SmallVector<SideEffectAnalysis, 4> child_analyses;
for (auto& child_region : op.getRegions()) {
child_analyses.emplace_back();
child_analyses.back().AnalyzeRegion(&child_region, alias_analysis);
}
ConsumeChildAnalyses(std::move(child_analyses));
}
} else {
for (int64_t resource : resources) {
AddPredecessorsForAccess(resource, op, read_only);
indirectly_tracked_unknown_access |=
unknown_access_indirectly_tracked_by_resource(resource, read_only);
// Update access info for known resources.
TrackAccess(resource, op, read_only);
}
}
// If not indirectly tracked, add edges from the unknown resource.
if (!indirectly_tracked_unknown_access) {
AddPredecessorsForAccess(kUnknownResourceId, op, read_only);
}
if (is_unknown) {
// Update access info for unknown resource.
TrackAccess(kUnknownResourceId, op, read_only);
}
});
// Populate control_successors_ based on control_predecessors_.
for (auto& entry : control_predecessors_) {
auto op = entry.getFirst();
for (auto predecessor : entry.getSecond()) {
control_successors_[predecessor].insert(op);
// We do not need explicit control edges for declaration ops.
if (OpIsDeclaration(&op, alias_analysis)) continue;
auto resource_op_info = GetResourceInfoForOp(&op);
if (!resource_op_info && op.hasNoSideEffect()) continue;
llvm::SmallDenseSet<int64_t, 8> resources =
resource_op_info ? FindAccessedResources(&op, alias_analysis)
: UnknownResourceSet();
assert(!resources.empty());
const bool is_unknown = resources.count(kUnknownResourceId) > 0;
const bool read_only = OpIsReadOnly(&op);
bool indirectly_tracked_unknown_access = false;
// First add edges from known resources.
if (is_unknown) {
for (auto& entry : per_resource_access_info_) {
if (entry.getFirst() == kUnknownResourceId) continue;
AddPredecessorsForAccess(entry.getFirst(), &op, read_only);
indirectly_tracked_unknown_access |=
unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
read_only);
}
} else {
for (int64_t resource : resources) {
AddPredecessorsForAccess(resource, &op, read_only);
indirectly_tracked_unknown_access |=
unknown_access_indirectly_tracked_by_resource(resource,
read_only);
// Update access info for known resources.
TrackAccess(resource, &op, read_only);
}
}
// If not indirectly tracked, add edges from the unknown resource.
if (!indirectly_tracked_unknown_access) {
AddPredecessorsForAccess(kUnknownResourceId, &op, read_only);
}
if (is_unknown) {
// Update access info for unknown resource.
TrackAccess(kUnknownResourceId, &op, read_only);
}
}
}
}
llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlPredecessors(
void SideEffectAnalysis::ConsumeChildAnalyses(
llvm::SmallVector<SideEffectAnalysis, 4>&& children) {
for (auto& child : children) {
for (auto& entry : child.control_predecessors_) {
control_predecessors_[entry.getFirst()] = std::move(entry.getSecond());
}
}
}
llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlPredecessors(
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
llvm::SmallVector<Operation*, 8> result;
auto it = control_predecessors_.find(op);
if (it == control_predecessors_.end()) return result;
llvm::SmallVector<Operation*, 4> result;
auto it = sorted_control_predecessors_.find(op);
if (it == sorted_control_predecessors_.end()) return result;
result.reserve(it->getSecond().size());
for (auto predecessor : it->getSecond()) {
if (!filter || filter(predecessor)) result.push_back(predecessor);
}
llvm::sort(result,
[](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
return result;
}
llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlSuccessors(
llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlSuccessors(
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
llvm::SmallVector<Operation*, 8> result;
auto it = control_successors_.find(op);
if (it == control_successors_.end()) return result;
llvm::SmallVector<Operation*, 4> result;
auto it = sorted_control_successors_.find(op);
if (it == sorted_control_successors_.end()) return result;
result.reserve(it->getSecond().size());
for (auto successor : it->getSecond()) {
if (!filter || filter(successor)) result.push_back(successor);
}
llvm::sort(result,
[](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
return result;
}


@ -32,6 +32,9 @@ namespace TF {
// An analysis that runs on a function and maps each resource-type value to a
// set of unique int64_t IDs representing the possible resources it could alias.
//
// If there are nested regions, each region is handled separately. This means
// cross-region aliasing cannot be checked by this analysis.
class ResourceAliasAnalysis {
public:
explicit ResourceAliasAnalysis(Operation* op);
@ -63,8 +66,12 @@ class ResourceAliasAnalysis {
// interfering with all known resource op accesses. It distinguishes accesses
// based on whether they are read-only, and read-only ops do not interfere with
// each other.
//
// If there are nested regions, each region is handled separately, and control
// dependencies are only tracked for ops under the same parent op.
class SideEffectAnalysis {
public:
explicit SideEffectAnalysis() = default;
explicit SideEffectAnalysis(Operation* op);
SideEffectAnalysis(SideEffectAnalysis&& other) = default;
~SideEffectAnalysis() = default;
@ -72,23 +79,32 @@ class SideEffectAnalysis {
// Returns a vector of ops that are direct control predecessors of `op`,
// sorted in program order. If `filter` is provided, only predecessors that
// pass the filter (returning true) will be included.
llvm::SmallVector<Operation*, 8> DirectControlPredecessors(
llvm::SmallVector<Operation*, 4> DirectControlPredecessors(
Operation* op,
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
// Returns a vector of ops that are direct control successors of `op`, sorted
// in program order. If `filter` is provided, only successors that pass the
// filter (returning true) will be included.
llvm::SmallVector<Operation*, 8> DirectControlSuccessors(
llvm::SmallVector<Operation*, 4> DirectControlSuccessors(
Operation* op,
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
private:
// Runs the analysis on `func_op` and populates control_predecessors_ and
// control_successors_.
// Runs the analysis on `func_op` and populates sorted_control_predecessors_
// and sorted_control_successors_.
void AnalyzeFunction(FuncOp func_op,
const ResourceAliasAnalysis& alias_analysis);
// Runs the analysis on `region` and populates control_predecessors_.
void AnalyzeRegion(Region* region,
const ResourceAliasAnalysis& alias_analysis);
// Moves the control_predecessors_ fields in `children` analyses to this
// current analysis.
void ConsumeChildAnalyses(
llvm::SmallVector<SideEffectAnalysis, 4>&& children);
// Updates control_predecessors_ for `op` that is being visited, on the given
// `resource_id`.
void AddPredecessorsForAccess(int64_t resource_id, Operation* op,
@ -98,11 +114,14 @@ class SideEffectAnalysis {
void TrackAccess(int64_t resource_id, Operation* op, bool read_only);
// Maps from an op to its control predecessors.
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 4>, 8>
control_predecessors_;
// Maps from an op to its control successors.
llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
control_successors_;
// Maps from an op to its control predecessors sorted in program order.
llvm::SmallDenseMap<Operation*, llvm::SmallVector<Operation*, 4>, 8>
sorted_control_predecessors_;
// Maps from an op to its control successors sorted in program order.
llvm::SmallDenseMap<Operation*, llvm::SmallVector<Operation*, 4>, 8>
sorted_control_successors_;
// Internal per-resource data structure when we build the dependencies.
struct PerResourceAcessInfo {


@ -332,8 +332,7 @@ struct DropEmptyLaunch : public OpRewritePattern<LaunchOp> {
if (&block.front() != &block.back()) return matchFailure();
// Map launch results to return operands.
llvm::SmallVector<Value*, 8> new_rets(block.front().getOperands());
rewriter.replaceOp(op, new_rets);
rewriter.replaceOp(op, block.front().getOperands());
return matchSuccess();
}


@ -408,8 +408,7 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
if (!wrapped_op) return failure();
OpBuilder builder(parser.getBuilder().getContext());
builder.setInsertionPointToEnd(&block);
builder.create<YieldOp>(wrapped_op->getLoc(),
llvm::to_vector<8>(wrapped_op->getResults()));
builder.create<YieldOp>(wrapped_op->getLoc(), wrapped_op->getResults());
result.location = wrapped_op->getLoc();
} else if (parser.parseRegion(body, llvm::None, llvm::None)) {
return failure();
@ -1065,8 +1064,7 @@ struct DropEmptyGraph : public OpRewritePattern<GraphOp> {
if (&block.front() != &block.back()) return matchFailure();
// Map graph results to fetch operands.
llvm::SmallVector<Value *, 8> new_rets(op.GetFetch().fetches());
rewriter.replaceOp(op, new_rets);
rewriter.replaceOp(op, op.GetFetch().fetches());
return matchSuccess();
}


@ -94,6 +94,8 @@ Inputs must be of same size and shape.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
let hasFolder = 1;
}
def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>,
@ -143,6 +145,8 @@ retained with length 1.
);
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let verifier = [{ return Verify(*this); }];
}
def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> {
@ -169,6 +173,8 @@ retained with length 1.
);
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let verifier = [{ return Verify(*this); }];
}
def TF_ArgMaxOp : TF_Op<"ArgMax", [NoSideEffect]> {
@ -2116,6 +2122,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_HashTableV2Op : TF_Op<"HashTableV2", []> {
let summary = "Creates a non-initialized hash table.";
let description = [{
This op creates a hash table, specifying the type of its keys and values.
Before using the table you will have to initialize it. After initialization the
table will be immutable.
}];
let arguments = (ins
StrAttr:$container,
StrAttr:$shared_name,
DefaultValuedAttr<BoolAttr, "false">:$use_node_name_sharing,
TypeAttr:$key_dtype,
TypeAttr:$value_dtype
);
let results = (outs
TF_ResourceTensor:$table_handle
);
}
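For context, the lifecycle the description above spells out (create a table with fixed key/value dtypes, initialize it once, then perform read-only lookups) corresponds to the following user-level Python sketch; the high-level tf.lookup wrapper used here is an illustrative assumption, not part of this change:

```
import tensorflow as tf

# Create the table with fixed key/value dtypes, initialize it once, then
# perform read-only lookups; after initialization the table is immutable.
keys = tf.constant(["a", "b", "c"])
values = tf.constant([0, 1, 2], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

print(table.lookup(tf.constant(["b", "z"])))  # [1, -1]
```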
def TF_IdentityNOp : TF_Op<"IdentityN", [NoSideEffect]> {
let summary = [{
Returns a list of tensors with the same shapes and contents as the input
@ -2473,7 +2501,7 @@ def TF_LogicalAndOp : TF_Op<"LogicalAnd", [Broadcastable, Commutative, NoSideEff
}
def TF_LogicalNotOp : TF_Op<"LogicalNot", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns the truth value of NOT x element-wise.";
let summary = "Returns the truth value of `NOT x` element-wise.";
let description = [{
}];
@ -4334,6 +4362,37 @@ Resize `images` to `size` using nearest neighbor interpolation.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
let summary = "Update '*var' according to the Adam algorithm.";
let description = [{
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
}];
let arguments = (ins
TF_ResourceTensor:$var,
TF_ResourceTensor:$m,
TF_ResourceTensor:$v,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2_power,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
DefaultValuedAttr<BoolAttr, "false">:$use_locking,
DefaultValuedAttr<BoolAttr, "false">:$use_nesterov
);
let results = (outs);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
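As a sanity check on the formulas in the op description, here is a small NumPy sketch of one Adam update step; the parameter names mirror the operands above, and treating them as plain scalars/arrays is an illustrative assumption:

```
import numpy as np

def adam_step(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
    # lr_t := learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    # m_t := beta1 * m_{t-1} + (1 - beta1) * g
    m = beta1 * m + (1.0 - beta1) * grad
    # v_t := beta2 * v_{t-1} + (1 - beta2) * g * g
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # var := var - lr_t * m_t / (sqrt(v_t) + epsilon)
    var = var - lr_t * m / (np.sqrt(v) + epsilon)
    return var, m, v
```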
def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> {
let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it.";
@ -4353,6 +4412,34 @@ def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_ResourceApplyKerasMomentumOp : TF_Op<"ResourceApplyKerasMomentum", []> {
let summary = [{
Update '*var' according to the momentum scheme.
}];
let description = [{
Set use_nesterov = True if you want to use Nesterov momentum.
accum = accum * momentum - lr * grad
var += accum
}];
let arguments = (ins
TF_ResourceTensor:$var,
TF_ResourceTensor:$accum,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum,
DefaultValuedAttr<BoolAttr, "false">:$use_locking,
DefaultValuedAttr<BoolAttr, "false">:$use_nesterov
);
let results = (outs);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
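Similarly, the momentum scheme described above reduces to the short sketch below; the Nesterov branch reflects the op's documented behavior as commonly stated and should be treated as an assumption rather than part of this change:

```
import numpy as np

def keras_momentum_step(var, accum, lr, grad, momentum, use_nesterov=False):
    # accum = accum * momentum - lr * grad
    accum = accum * momentum - lr * grad
    if use_nesterov:
        # Assumed Nesterov variant: step along the look-ahead direction.
        var = var + accum * momentum - lr * grad
    else:
        # var += accum
        var = var + accum
    return var, accum
```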
def TF_ReverseSequenceOp : TF_Op<"ReverseSequence", [NoSideEffect]> {
let summary = "Reverses variable length slices.";
@ -5117,6 +5204,8 @@ def TF_SplitVOp : TF_Op<"SplitV", [NoSideEffect]> {
TF_DerivedOperandTypeAttr Tlen = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultSizeAttr num_split = TF_DerivedResultSizeAttr<0>;
let verifier = [{ return Verify(*this); }];
}
def TF_SqrtOp : TF_Op<"Sqrt", [NoSideEffect, SameOperandsAndResultType]> {
@ -5491,6 +5580,65 @@ output. For the internal use of the distributed TPU compiler.
TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>;
}
def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> {
let summary = "Connects N inputs to an N-way replicated TPU computation.";
let description = [{
This operation holds a replicated input to a `tpu.replicate()` computation subgraph.
Each replicated input has the same shape and type alongside the output.
For example:
```
%a = "tf.opA"()
%b = "tf.opB"()
%replicated_input = "tf.TPUReplicatedInput"(%a, %b)
%computation = "tf.Computation"(%replicated_input)
```
The above computation has a replicated input of two replicas.
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
DefaultValuedAttr<BoolAttr, "false">:$is_mirrored_variable,
DefaultValuedAttr<I64Attr, "-1">:$index
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_TPUReplicatedOutputOp : TF_Op<"TPUReplicatedOutput", [NoSideEffect]> {
let summary = "Connects N outputs from an N-way replicated TPU computation.";
let description = [{
This operation holds a replicated output from a `tpu.replicate()` computation subgraph.
Each replicated output has the same shape and type alongside the input.
For example:
```
%computation = "tf.Computation"()
%replicated_output:2 = "tf.TPUReplicatedOutput"(%computation)
```
The above computation has a replicated output of two replicas.
}];
let arguments = (ins
TF_Tensor:$input
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
TF_DerivedResultSizeAttr num_replicas = TF_DerivedResultSizeAttr<0>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes hyperbolic tangent of `x` element-wise.";
@ -5905,6 +6053,8 @@ This is the opposite of `pack`.
TF_DerivedResultSizeAttr num = TF_DerivedResultSizeAttr<0>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let verifier = [{ return Verify(*this); }];
}
def TF_VariableShapeOp : TF_Op<"VariableShape", []> {


@ -301,6 +301,15 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<AddToAddV2>(context);
}
//===----------------------------------------------------------------------===//
// AddNOp
//===----------------------------------------------------------------------===//
OpFoldResult AddNOp::fold(ArrayRef<Attribute> operands) {
if (operands.size() == 1) return *inputs().begin();
return {};
}
//===----------------------------------------------------------------------===//
// AddV2Op
//===----------------------------------------------------------------------===//
@ -310,6 +319,49 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
}
//===----------------------------------------------------------------------===//
// AllOp
//===----------------------------------------------------------------------===//
// Verifies a reduction op's `input` and reduction `dims`.
static LogicalResult VerifyReductionInputAndDims(Value *input, Value *dims,
Location loc) {
auto dims_type = dims->getType().dyn_cast<RankedTensorType>();
if (!dims_type) return success();
if (dims_type.getRank() > 1)
return emitError(loc, "dimensions can only be 0D or 1D tensor");
auto input_type = input->getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t rank = input_type.getRank();
DenseIntElementsAttr dims_attr;
if (!matchPattern(dims, m_Constant(&dims_attr))) return success();
for (const auto &dim_pair : llvm::enumerate(dims_attr)) {
int64_t cur_dim = dim_pair.value().getSExtValue();
if (cur_dim < -rank || cur_dim >= rank)
return emitError(loc)
<< dim_pair.index() << "-th dimension should be in the range of [-"
<< rank << ", " << rank << ")";
}
return success();
}
static LogicalResult Verify(AllOp op) {
return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
op.getLoc());
}
//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(AnyOp op) {
return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
op.getLoc());
}
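The constraint these verifiers enforce (reduction dimensions form a 0-D or 1-D tensor whose entries lie in [-rank, rank)) mirrors what users see at the Python level; a small illustrative sketch with arbitrarily chosen shapes:

```
import tensorflow as tf

x = tf.constant([[True, False], [True, True]])   # rank 2

tf.reduce_all(x, axis=0)         # valid: scalar axis in [-2, 2)
tf.reduce_any(x, axis=[-1, 0])   # valid: 1-D list of in-range dims
# tf.reduce_all(x, axis=2)       # error: 2 is outside [-2, 2)
```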
//===----------------------------------------------------------------------===//
// AssertOp
//===----------------------------------------------------------------------===//
@ -1542,17 +1594,23 @@ static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
// SplitOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SplitOp op) {
// Verifies the input and split dimension operands for tf.Split/tf.SplitV.
// Writes the split dimension's index (adjusted with input rank) via `dim_index`
// if it's a constant.
template <class Op>
LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
*dim_index = llvm::None;
Value *split_dim = op.split_dim();
auto split_dim_type = split_dim->getType().dyn_cast<RankedTensorType>();
if (!split_dim_type) return success();
if (split_dim_type.getRank() != 0)
return op.emitOpError("split dimension should be an integer scalar tensor");
if (auto split_dim_type = split_dim->getType().dyn_cast<RankedTensorType>())
if (split_dim_type.getRank() != 0)
return op.emitOpError(
"split dimension should be an integer scalar tensor");
// We can perform further verification if the input tensor to be split has
// known rank and the split dimension tensor is a constant.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
auto input_type = op.value()->getType().template dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t input_rank = input_type.getRank();
@ -1562,21 +1620,95 @@ static LogicalResult Verify(SplitOp op) {
DenseIntElementsAttr split_dim_attr;
if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
int64_t index = (*split_dim_attr.begin()).getSExtValue();
if (dim_index + input_rank < 0 || dim_index >= input_rank) {
if (index + input_rank < 0 || index >= input_rank) {
return op.emitOpError("split dimension must be in range [-")
<< input_rank << ", " << input_rank << ")";
}
if (dim_index < 0) dim_index += input_rank;
if (index < 0) index += input_rank;
*dim_index = index;
int64_t input_dim_size = input_type.getDimSize(dim_index);
if (input_dim_size < 0) return success();
return success();
}
static LogicalResult Verify(SplitOp op) {
Optional<int64_t> dim_index;
if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
if (!dim_index) return success();
int64_t input_dim_size =
op.value()->getType().cast<RankedTensorType>().getDimSize(*dim_index);
if (input_dim_size == ShapedType::kDynamicSize) return success();
if (input_dim_size % op.getNumResults() != 0)
return op.emitOpError("dimension #")
<< dim_index << " not divisible by the number of result tensors";
<< *dim_index << " not divisible by the number of result tensors";
return success();
}
//===----------------------------------------------------------------------===//
// SplitVOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SplitVOp op) {
auto split_sizes_type =
op.size_splits()->getType().dyn_cast<RankedTensorType>();
if (!split_sizes_type) return success();
if (split_sizes_type.getRank() != 1 ||
split_sizes_type.getDimSize(0) != op.getNumResults())
return op.emitOpError("split sizes should be a 1D tensor of ")
<< op.getNumResults() << " elements";
Optional<int64_t> dim_index = 0;
if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
if (!dim_index) return success();
int64_t input_dim_size =
op.value()->getType().cast<RankedTensorType>().getDimSize(*dim_index);
if (input_dim_size == ShapedType::kDynamicSize) return success();
// If split sizes come from a constant, they must sum to the dimension size
// along split_dim, and we can have no more than one dynamic dimension.
DenseIntElementsAttr split_sizes_attr;
if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
return success();
int64_t total_dim_size = 0; // Total dimension size assigned to splits
llvm::Optional<int> dynamic_dim_index;
SmallVector<int64_t, 4> split_sizes;
split_sizes.reserve(
split_sizes_attr.getType().cast<ShapedType>().getNumElements());
for (auto dim : llvm::enumerate(split_sizes_attr)) {
int64_t dim_val = dim.value().getSExtValue();
split_sizes.push_back(dim_val);
if (dim_val == ShapedType::kDynamicSize) {
// We cannot have more than one dynamic dimension.
if (dynamic_dim_index)
return op.emitOpError(
"cannot have more than one dynamic dimension in split sizes");
dynamic_dim_index = dim.index();
} else {
total_dim_size += dim_val;
}
}
if (!dynamic_dim_index && total_dim_size != input_dim_size)
return op.emitOpError(
"split sizes must sum up to the dimension size along split "
"dimension, found ")
<< total_dim_size << " vs " << input_dim_size;
if (dynamic_dim_index && total_dim_size > input_dim_size)
return op.emitOpError(
"split sizes must sum up to be less than or equal to the "
"dimension size along split dimension, found ")
<< total_dim_size << " vs " << input_dim_size;
return success();
}
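The split-size rules enforced by this verifier (explicit sizes must sum to the size of the split dimension, with at most one dynamic entry) match tf.split's Python-level contract; a short illustrative sketch:

```
import tensorflow as tf

x = tf.zeros([4, 19])

# Explicit sizes must sum to the size of the split dimension (19 here).
a, b, c = tf.split(x, num_or_size_splits=[4, 10, 5], axis=1)

# At most one entry may be dynamic (-1); it is inferred as 19 - 4 - 5 = 10.
d, e, f = tf.split(x, num_or_size_splits=[4, -1, 5], axis=1)
```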
@ -1787,6 +1919,30 @@ void TruncateDivOp::getCanonicalizationPatterns(
results.insert<TruncateDivWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(UnpackOp op) {
auto value_type = op.value()->getType().dyn_cast<RankedTensorType>();
if (!value_type) return success();
int64_t value_rank = value_type.getRank();
int64_t axis = op.axis().getSExtValue();
if (axis < -value_rank || axis >= value_rank)
return op.emitOpError("axis attribute must be in the range of [-")
<< value_rank << ", " << value_rank << ')';
axis = GetDimForAxis(axis, value_rank);
int64_t dim_size = value_type.getDimSize(axis);
if (ShapedType::isDynamic(dim_size)) return success();
if (dim_size != op.getNumResults())
return op.emitOpError("result count must be equal to ") << dim_size;
return success();
}
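For reference, the same constraints surface through tf.unstack: the axis must lie within the input rank, and the number of results equals the size of that dimension. A minimal sketch:

```
import tensorflow as tf

x = tf.zeros([3, 4])

parts = tf.unstack(x, axis=0)    # 3 tensors of shape [4]
assert len(parts) == 3

cols = tf.unstack(x, axis=-1)    # negative axis allowed within [-2, 2)
assert len(cols) == 4
```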
//===----------------------------------------------------------------------===//
// VariableShapeOp
//===----------------------------------------------------------------------===//


@ -196,6 +196,42 @@ retained with length 1.
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}
def TF_LegacyCallOp : TF_Op<"LegacyCall",
[CallOpInterface, NoSideEffect]> {
let summary =
"returns `f(inputs)`, where `f` is a function.";
let description = [{
The LegacyCall operation represents a direct call to a function that is
within the same symbol scope as the call and is mapped to a GraphDef node
with the function name as the op name. Unlike a PartitionedCall which
represents asynchronously executing a function across multiple devices, a
LegacyCall represents a function call whose only attribute is
_disable_call_shape_inference.
}];
let arguments = (ins
Variadic<TF_Tensor>:$args,
FlatSymbolRefAttr:$f,
DefaultValuedAttr<BoolAttr, "false">:$_disable_call_shape_inference
);
let results = (outs
Variadic<TF_Tensor>:$output
);
let extraClassDeclaration = [{
// Gets the argument operands to the called function.
operand_range getArgOperands() { return args(); }
// Returns the callee of this operation.
CallInterfaceCallable getCallableForCallee() {
return getAttrOfType<SymbolRefAttr>("f");
}
}];
}
def TF_PartitionedCallOp : TF_Op<"PartitionedCall",
[CallOpInterface, NoSideEffect]> {
let summary =


@ -18,7 +18,7 @@ func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32
// CHECK-LABEL: func @multiple_return
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1
@ -41,7 +41,12 @@ func @multiple_islands(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi3
%res = "tf.Print"(%sub) { message = "sub result" } : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield
}
tf_executor.fetch %island1#1, %island2#1, %island3 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control
%island4 = tf_executor.island(%island1#2, %island2#2) {
%add = "tf.Add"(%island1#1, %island1#1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
%res = "tf.Print"(%add) { message = "add result" } : (tensor<*xi32>) -> (tensor<*xi32>)
tf_executor.yield
}
tf_executor.fetch %island1#1, %island2#1, %island3, %island4 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control, !tf_executor.control
}
return %graph#0, %graph#1 : tensor<*xi32>, tensor<*xi32>
}
@ -49,12 +54,17 @@ func @multiple_islands(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi3
// CHECK-LABEL: func @multiple_islands
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[SUB1:.*]], %[[SUB1_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Sub"(%arg0, %arg1)
// CHECK: %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island(%[[SUB1_control]]) wraps "tf.Mul"(%[[SUB1]], %arg1)
// CHECK: %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island wraps "tf.Mul"(%[[SUB1]], %arg1)
// CHECK: %[[SUB2:.*]], %[[SUB2_control:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) wraps "tf.Sub"(%[[ADD1]], %[[SUB1]])
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island(%[[SUB2_control]]) wraps "tf.Print"(%[[SUB2]]) {message = "sub result"}
// CHECK: tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT_control]] :
// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[SUB2]]) {message = "sub result"}
// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) {
// CHECK: tf_executor.yield
// CHECK: }
// CHECK: %[[ADD3:.*]], %[[ADD3_control:.*]] = tf_executor.island(%[[ISLAND1]], %[[ADD2_control]]) wraps "tf.Add"(%[[ADD2]], %[[ADD2]])
// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD3]]) {message = "add result"}
// CHECK: tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT1_control]], %[[PRINT2_control:.*]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1
@ -74,8 +84,8 @@ func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>
// CHECK-LABEL: func @dangling_print
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1_control:.*]], %arg1)
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"}
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1_control:.*]], %arg1)
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"}
// CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]], %[[PRINT_control]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1
@ -103,11 +113,14 @@ func @switch_and_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi3
// CHECK-LABEL: func @switch_and_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<i32>) {
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Less"(%arg1, %arg1)
// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island(%[[LESS_control]]) wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"}
// CHECK: %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[PRINT1_control]]
// CHECK: %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island wraps "tf.Less"(%arg1, %arg1)
// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"}
// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[LESS_control]], %[[PRINT1_control]]) {
// CHECK: tf_executor.yield
// CHECK: }
// CHECK: %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[ISLAND1]]
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[SWITCH_false]], %arg1)
// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"}
// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"}
// CHECK: %[[MERGE:.*]], %[[MERGE_index:.*]], %{{.*}} = tf_executor.Merge %[[ADD2]], %[[SWITCH_true]], %[[PRINT2_control]]
// CHECK: tf_executor.fetch %[[MERGE]], %[[MERGE_index]]
// CHECK: }
@ -130,7 +143,7 @@ func @control_flow_plumbing(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> tensor<
// CHECK: %[[GRAPH:.*]] = tf_executor.graph {
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%arg0) {message = "Random Print"}
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island(%[[PRINT_control]]) wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: tf_executor.fetch %[[ADD2]] : tensor<*xi32>
// CHECK: }
// CHECK: return %[[GRAPH]] : tensor<*xi32>
@ -150,6 +163,77 @@ func @fetching_arg(%arg0: tensor<*xi32>) {
// CHECK-LABEL: func @fetching_arg
// CHECK: tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg0)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg0)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg0)
// CHECK: tf_executor.fetch %[[ADD2_control]] : !tf_executor.control
// CHECK: }
func @non_aliasing_reads_writes(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg2: tensor<32xf32>) -> (tensor<32xf32>) {
%graph = tf_executor.graph {
%island:2 = tf_executor.island {
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
"tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
%read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
"tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
"tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
%read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
tf_executor.yield %read3 : tensor<32xf32>
}
tf_executor.fetch %island#0 : tensor<32xf32>
}
return %graph : tensor<32xf32>
}
// CHECK-LABEL: func @non_aliasing_reads_writes
// CHECK: %[[GRAPH:.*]] = tf_executor.graph {
// CHECK: %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg0)
// CHECK: %[[ASSIGN0_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %arg2)
// CHECK: %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg1)
// CHECK: %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
// CHECK: %[[READ2:.*]], %[[READ2_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
// CHECK: %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%arg1, %[[READ0:.*]])
// CHECK: %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[ASSIGN0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %[[READ2]])
// CHECK: %[[READ3:.*]], %[[READ3_CONTROL:.*]] = tf_executor.island(%[[ASSIGN2_CONTROL]]) wraps "tf.ReadVariableOp"(%arg0)
// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ASSIGN1_CONTROL]], %[[READ3_CONTROL]]) {
// CHECK: tf_executor.yield
// CHECK: }
// CHECK: tf_executor.fetch %[[READ3]], %[[ISLAND1]] : tensor<32xf32>, !tf_executor.control
// CHECK: }
func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
tf_executor.graph {
%island = tf_executor.island {
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
%read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
"tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
"tf._UnknownSideEffectingOp_"() : () -> ()
%read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
"tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
// CHECK-LABEL: func @unknown_side_effecting_op
// CHECK: tf_executor.graph {
// CHECK: %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
// CHECK: %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v1"}
// CHECK: %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
// CHECK: %[[ASSIGN0_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VH1]], %arg0)
// CHECK: %[[UNKNOWN_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]], %[[ASSIGN0_CONTROL]]) wraps "tf._UnknownSideEffectingOp_"()
// CHECK: %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island(%[[UNKNOWN_CONTROL]]) wraps "tf.ReadVariableOp"(%[[VH1]])
// CHECK: %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[UNKNOWN_CONTROL]]) wraps "tf.AssignVariableOp"(%[[VH0]], %[[READ1]])
// CHECK: %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%[[VH1]], %[[READ0]])
// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ASSIGN1_CONTROL]], %[[ASSIGN2_CONTROL]]) {
// CHECK: tf_executor.yield
// CHECK: }
// CHECK: tf_executor.fetch %[[ISLAND1]] : !tf_executor.control
// CHECK: }


@ -382,3 +382,10 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32
// CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x6x5xf32>
// CHECK: return %1
}
// CHECK-LABEL: func @addN
func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: return %arg0
%0 = "tf.AddN"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
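// In other words, a single-operand tf.AddN is the identity, AddN(x) == x, so the
// canonicalizer simply forwards the operand as the result.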


@ -0,0 +1,67 @@
// RUN: tf-opt %s -split-input-file -tf-device-decompose-resource-ops | FileCheck %s
// -----
// Tests that the composite tf.AssignAddVariableOp operation is decomposed and
// hoisted.
// CHECK-LABEL: func @decompose_assign_add_variable_op
func @decompose_assign_add_variable_op() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
// CHECK: "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
// CHECK: "tf.AssignVariableOp"
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
"tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
return
}
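// For reference, a hand-written sketch (not literal pass output) of the
// read-modify-write sequence the CHECK lines above match against, where
// %resource and %one stand in for the VarHandleOp and Const results:
//   %value = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf.resource>) -> tensor<i32>
//   %sum = "tf.AddV2"(%value, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
//   "tf.AssignVariableOp"(%resource, %sum) : (tensor<*x!tf.resource>, tensor<i32>) -> ()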
// -----
// Tests that the composite tf.AssignSubVariableOp operation is decomposed using
// SubOp.
// CHECK-LABEL: func @decompose_assign_sub_variable_op
func @decompose_assign_sub_variable_op() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
// CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]])
// CHECK: "tf.AssignVariableOp"
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
"tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
return
}
// -----
// Tests that the composite tf.ResourceApplyGradientDescent operation is decomposed.
// CHECK-LABEL: func @decompose_resource_apply_gradient_descent
// CHECK-SAME: (%[[DELTA:.*]]: tensor<f32>)
func @decompose_resource_apply_gradient_descent(%arg0: tensor<f32>) -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
// CHECK: %[[MUL:[0-9]*]] = "tf.Mul"(%[[DELTA]], %[[ALPHA]])
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
// CHECK: %[[SUB:[0-9]*]] = "tf.Sub"(%[[RES_READ_VAL]], %[[MUL]])
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[SUB]])
%1 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
"tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
return
}
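// For reference, a minimal sketch of the update this decomposition implements,
// assuming scalar alpha and delta: var_new = var - alpha * delta, which matches
// the Mul/Sub/AssignVariableOp sequence in the CHECK lines above.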


@ -49,40 +49,33 @@ func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>,
// -----
func @testIf1Then(tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
func @testIf1Else(tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32>
func @testIfThen(%arg0: tensor<!tf.variant>) -> tensor<!tf.variant> {
return %arg0 : tensor<!tf.variant>
}
func @testIfElse(%arg0: tensor<!tf.variant>) -> tensor<!tf.variant> {
return %arg0 : tensor<!tf.variant>
}
// CHECK-LABEL: func @testIf1Casts(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>)
func @testIf1Casts(tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32> {
^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):
%1 = "tf.If"(%arg0, %arg1, %arg2) {
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
// CHECK: cond_br %0, ^bb1, ^bb2
// CHECK:^bb1: // pred: ^bb0
// CHECK: %1 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<2x?xf32>
// CHECK: %2 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x2xf32>
// CHECK: %3 = call @testIf1Then(%1, %2) : (tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = tensor_cast %3 : tensor<2x2xf32> to tensor<2x?xf32>
// CHECK: br ^bb3(%4 : tensor<2x?xf32>)
// CHECK:^bb2: // pred: ^bb0
// CHECK: %5 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<*xf32>
// CHECK: %6 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x?xf32>
// CHECK: %7 = call @testIf1Else(%5, %6) : (tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32>
// CHECK: %8 = tensor_cast %7 : tensor<*xf32> to tensor<2x?xf32>
// CHECK: br ^bb3(%8 : tensor<2x?xf32>)
// CHECK:^bb3(%9: tensor<2x?xf32>): // 2 preds: ^bb1, ^bb2
%2 = "tf.Add"(%1, %1) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
// CHECK: %10 = "tf.Add"(%9, %9) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
return %2 : tensor<2x?xf32>
// CHECK: return %10 : tensor<2x?xf32>
// CHECK-LABEL: func @testIfCasts(%arg0: tensor<i1>, %arg1: tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>>
func @testIfCasts(%arg0: tensor<i1>, %arg1: tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>> {
%0 = "tf.If"(%arg0, %arg1) {
then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false
} : (tensor<i1>, tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>>
return %0: tensor<!tf.variant<tensor<f32>>>
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
// CHECK: cond_br %0, ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant>
// CHECK: %2 = call @testIfThen(%1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
// CHECK: %3 = "tf.Cast"(%2) {Truncate = false} : (tensor<!tf.variant>) -> tensor<!tf.variant<tensor<f32>>>
// CHECK: br ^bb3(%3 : tensor<!tf.variant<tensor<f32>>>)
// CHECK: ^bb2:
// CHECK: %4 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant>
// CHECK: %5 = call @testIfElse(%4) : (tensor<!tf.variant>) -> tensor<!tf.variant>
// CHECK: %6 = "tf.Cast"(%5) {Truncate = false} : (tensor<!tf.variant>) -> tensor<!tf.variant<tensor<f32>>>
// CHECK: br ^bb3(%6 : tensor<!tf.variant<tensor<f32>>>)
// CHECK: ^bb3(%7: tensor<!tf.variant<tensor<f32>>>):
// CHECK: return %7 : tensor<!tf.variant<tensor<f32>>>
}
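// In the generated CFG above, ^bb1 is the then branch, ^bb2 is the else branch,
// and ^bb3 merges the two results; the tf.Cast ops bridge between the variant
// subtype seen by the caller and the plain variant expected by the branch functions.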
// -----
@ -188,31 +181,36 @@ func @testComplexWhile1Result(tensor<*xf32>) -> (tensor<*xf32>) {
// -----
func @testWhileCond(tensor<?x3xf32>) -> (tensor<i1>)
func @testWhileBody(tensor<*xf32>) -> (tensor<?x?xf32>)
func @testWhileCond(%arg0: tensor<!tf.variant>) -> (tensor<i1>) {
%true = "tf.Const"() { value = dense<true> : tensor<i1> } : () -> (tensor<i1>)
return %true : tensor<i1>
}
func @testWhileBody(%arg0: tensor<!tf.variant<tensor<1x?xf32>>>) -> (tensor<!tf.variant<tensor<?x?xf32>>>) {
%0 = "tf.Cast"(%arg0) : (tensor<!tf.variant<tensor<1x?xf32>>>) -> tensor<!tf.variant<tensor<?x?xf32>>>
return %0 : tensor<!tf.variant<tensor<?x?xf32>>>
}
// CHECK-LABEL: func @testWhileCasts(%arg0: tensor<1x3xf32>)
func @testWhileCasts(%arg0: tensor<1x3xf32>) -> (tensor<?x?xf32>) {
// CHECK-LABEL: func @testWhileCasts(%arg0: tensor<!tf.variant<tensor<1x3xf32>>>) -> tensor<!tf.variant<tensor<*xf32>>>
func @testWhileCasts(%arg0: tensor<!tf.variant<tensor<1x3xf32>>>) -> (tensor<!tf.variant<tensor<*xf32>>>) {
%0 = "tf.While"(%arg0) {
cond = @testWhileCond, body = @testWhileBody, is_stateless = false
} : (tensor<1x3xf32>) -> (tensor<?x?xf32>)
// CHECK: %0 = tensor_cast %arg0 : tensor<1x3xf32> to tensor<?x3xf32>
// CHECK: br ^bb1(%0 : tensor<?x3xf32>)
// CHECK: ^bb1(%1: tensor<?x3xf32>):
// CHECK: %2 = call @testWhileCond(%1) : (tensor<?x3xf32>) -> tensor<i1>
} : (tensor<!tf.variant<tensor<1x3xf32>>>) -> (tensor<!tf.variant<tensor<*xf32>>>)
return %0 : tensor<!tf.variant<tensor<*xf32>>>
// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<!tf.variant<tensor<1x3xf32>>>) -> tensor<!tf.variant>
// CHECK: br ^bb1(%0 : tensor<!tf.variant>)
// CHECK: ^bb1(%1: tensor<!tf.variant>): // 2 preds: ^bb0, ^bb2
// CHECK: %2 = call @testWhileCond(%1) : (tensor<!tf.variant>) -> tensor<i1>
// CHECK: %3 = extract_element %2[] : tensor<i1>
// CHECK: %4 = tensor_cast %1 : tensor<?x3xf32> to tensor<*xf32>
// CHECK: cond_br %3, ^bb2(%4 : tensor<*xf32>), ^bb3(%4 : tensor<*xf32>)
// CHECK: ^bb2(%5: tensor<*xf32>):
// CHECK: %6 = call @testWhileBody(%5) : (tensor<*xf32>) -> tensor<?x?xf32>
// CHECK: %7 = tensor_cast %6 : tensor<?x?xf32> to tensor<?x3xf32>
// CHECK: br ^bb1(%7 : tensor<?x3xf32>)
// CHECK: ^bb3(%8: tensor<*xf32>):
// CHECK: %9 = tensor_cast %8 : tensor<*xf32> to tensor<?x?xf32>
// CHECK: %4 = "tf.Cast"(%1) {Truncate = false} : (tensor<!tf.variant>) -> tensor<!tf.variant<tensor<1x?xf32>>>
// CHECK: cond_br %3, ^bb2(%4 : tensor<!tf.variant<tensor<1x?xf32>>>), ^bb3(%4 : tensor<!tf.variant<tensor<1x?xf32>>>)
// CHECK: ^bb2(%5: tensor<!tf.variant<tensor<1x?xf32>>>): // pred: ^bb1
// CHECK: %6 = call @testWhileBody(%5) : (tensor<!tf.variant<tensor<1x?xf32>>>) -> tensor<!tf.variant<tensor<?x?xf32>>>
// CHECK: %7 = "tf.Cast"(%6) {Truncate = false} : (tensor<!tf.variant<tensor<?x?xf32>>>) -> tensor<!tf.variant>
// CHECK: br ^bb1(%7 : tensor<!tf.variant>)
// CHECK: ^bb3(%8: tensor<!tf.variant<tensor<1x?xf32>>>): // pred: ^bb1
// CHECK: %9 = "tf.Cast"(%8) {Truncate = false} : (tensor<!tf.variant<tensor<1x?xf32>>>) -> tensor<!tf.variant<tensor<*xf32>>>
// CHECK: return %9 : tensor<!tf.variant<tensor<*xf32>>>
return %0 : tensor<?x?xf32>
// CHECK: return %9 : tensor<?x?xf32>
}
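// In the generated CFG above, ^bb1 evaluates the loop condition, ^bb2 runs the
// body and branches back to ^bb1, and ^bb3 is the exit block that casts the
// loop-carried variant to the declared result type.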
// -----


@ -54,5 +54,5 @@ versions {
# the names match between the function definition and the uses / call
# site (a numerical suffix may be appended).
# CHECK: "tf.foo0"(
# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, f = @foo0}
# CHECK: func @foo0


@ -8,7 +8,7 @@
# Verify that we can also pull some attributes that are needed to be able to
# create a Graph in memory, like `T`.
# CHECK: tf.MaxPool
# CHECK-SAME: T = "tfdtype$DT_FLOAT"
# CHECK-SAME: T = f32
node {
name: "input"


@ -0,0 +1,65 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=x -tf-input-data-types=DT_INT32 -tf-input-shapes=10 -tf-output-arrays=func_call -o - | FileCheck %s
node {
name: "x"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 1
}
}
int_val: 1
}
}
}
}
node {
name: "func_call"
op: "test_func_name"
input: "x"
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
}
library {
function {
signature {
name: "test_func_name"
input_arg {
name: "a_0"
type: DT_INT32
}
output_arg {
name: "a"
type: DT_INT32
}
}
ret {
key: "a"
value: "a_0"
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
}
}
# CHECK: func @main
# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, f = @test_func_name0}


@ -121,8 +121,8 @@ versions {
# Verify that functions from the library are properly imported.
# CHECK-LABEL: func @main() {
# CHECK: "tf.foo110"()
# CHECK: "tf.foo111"()
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111}
# CHECK-LABEL: func @foo110() {
# CHECK-LABEL: func @foo111() {


@ -39,10 +39,10 @@ versions {
# Verify that functions from the library are properly imported.
# CHECK-LABEL: func @main() {
# CHECK: "tf.foo0"()
# CHECK: "tf.bar0"()
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @foo0() {
# CHECK: "tf.bar0"()
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @bar0() {


@ -0,0 +1,300 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck %s --dump-input=fail
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:0,a:0 -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail
# Generated in Python via
# ```
# import tensorflow as tf
#
# with tf.compat.v1.Graph().as_default() as g:
# w = tf.constant(2.0)
# x = tf.constant(3.0)
# y = tf.constant(4.0)
# var = tf.Variable(2.0)
# var_add = var.assign_add(3.0)
# with g.control_dependencies([var_add]):
# z0, z1, z2 = tf.identity_n((w, x, y))
#
# a = tf.add(z1, z2)
# ```
node {
name: "w"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
}
node {
name: "x"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 3.0
}
}
}
}
node {
name: "y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 4.0
}
}
}
}
node {
name: "var/initial_value"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
}
node {
name: "var"
op: "VariableV2"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "var/Assign"
op: "Assign"
input: "var"
input: "var/initial_value"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@var"
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "var/read"
op: "Identity"
input: "var"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@var"
}
}
}
}
node {
name: "var_add/value"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 3.0
}
}
}
}
node {
name: "var_add"
op: "AssignAdd"
input: "var"
input: "var_add/value"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@var"
}
}
}
attr {
key: "use_locking"
value {
b: false
}
}
}
node {
name: "z"
op: "IdentityN"
input: "w"
input: "x"
input: "y"
input: "^var_add"
attr {
key: "T"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
}
}
}
}
node {
name: "a"
op: "Add"
input: "z:1"
input: "z:2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 230
}
# Test non-zero-index output tensors as feeds. Original ops whose outputs are
# replaced with feeds are preserved, and args and rets are lifted to the
# function. Rets that happen to coincide with a feed should take the value of
# the feed.
#
# CHECK: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# CHECK: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# CHECK: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]
# Test that when non-zero-index output tensors are feeds, the remaining ops that
# are unreachable are pruned if pruning is enabled.
#
# PRUNE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# PRUNE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
# PRUNE-NOT: "tf.Const"
# PRUNE-NOT: "tf.VariableV2"
# PRUNE-NOT: "tf.Assign"
# PRUNE-NOT: "tf.Identity"
# PRUNE-NOT: "tf.AssignAdd"
# PRUNE-NOT: "tf.IdentityN"
# PRUNE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# PRUNE: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]
# Test that when non-zero-index output tensors are feeds, the remaining ops that
# are unreachable are preserved if pruning is not enabled.
#
# PRESERVE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
# PRESERVE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:0,a:0"}}
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# PRESERVE: tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]]


@ -2,11 +2,11 @@
# CHECK: tf_executor.SwitchN
# CHECK-SAME: of 3 : tensor<i32>
# CHECK-SAME: T = "tfdtype$DT_INT32"
# CHECK-SAME: T = i32
# CHECK-SAME: name = "Case/branch_index/_3"
# CHECK: tf_executor.SwitchN
# CHECK-SAME: of 2 : tensor<f32>
# CHECK-SAME: T = "tfdtype$DT_FLOAT"
# CHECK-SAME: T = f32
# CHECK-SAME: name = "Case/Case/input_0/_7"
node {


@ -250,3 +250,19 @@ func @ZerosLike_variant(%arg0: tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf
%0 = "tf.ZerosLike"(%arg0) : (tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf.variant<tensor<2xi32>>>
return %0 : tensor<!tf.variant<tensor<2xi32>>>
}
// CHECK-LABEL: func @addN
func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
// CHECK: %[[SUM1:.*]] = "tf.AddV2"(%[[SUM0]], %arg2)
// CHECK: return %[[SUM1]]
%0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
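// In other words, for three operands the lowering builds a left-leaning chain of
// binary adds: AddN(a, b, c) == AddV2(AddV2(a, b), c).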
// CHECK-LABEL: func @addN_variant
func @addN_variant(%arg0: tensor<!tf.variant<tensor<2xf32>>>, %arg1: tensor<!tf.variant<tensor<2xf32>>>, %arg2: tensor<!tf.variant<tensor<2xf32>>>) -> tensor<!tf.variant<tensor<2xf32>>> {
// CHECK: tf.AddN
%0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor<!tf.variant<tensor<2xf32>>>, tensor<!tf.variant<tensor<2xf32>>>, tensor<!tf.variant<tensor<2xf32>>>) -> tensor<!tf.variant<tensor<2xf32>>>
return %0 : tensor<!tf.variant<tensor<2xf32>>>
}


@ -0,0 +1,26 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main() {
tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Constant", value = dense<0> : tensor<i32>} : () -> tensor<i32>
%outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs) {f = @foo0} : (tensor<i32>) -> tensor<i32>
tf_executor.fetch
}
return
}
func @foo0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = tf_executor.graph {
tf_executor.fetch %arg0 : tensor<*xi32>
}
return %0 : tensor<*xi32>
}
// CHECK: node {
// CHECK: name: "_tf.LegacyCall"
// CHECK-NEXT: op: "foo0"
// CHECK: library {
// CHECK-NEXT: function {
// CHECK-NEXT: signature {
// CHECK-NEXT: name: "foo0"


@ -0,0 +1,244 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-resource-device-inference %s | FileCheck %s --dump-input=fail
// Tests that the pass can correctly propagate device attributes inside the same
// function.
// CHECK-LABEL: func @propagate_in_function
func @propagate_in_function(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:1"}) {
tf_executor.graph {
// CHECK: tf_executor.island
%island = tf_executor.island {
// CHECK-NEXT: "tf.VarHandleOp"
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/CPU:0"}
: () -> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id1 = "tf.Identity"(%id0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/CPU:0"}
%id2 = "tf.Identity"(%var_handle) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
%read = "tf.ReadVariableOp"(%id2) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%id3 = "tf.Identity"(%read) : (tensor<32xf32>) -> tensor<32xf32>
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
// -----
// Tests that the pass can propagate through tf.If's branches.
// CHECK-LABEL: func @propagate_if_op
func @propagate_if_op(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:0"},
%arg1: tensor<i1>) {
tf_executor.graph {
// CHECK: tf_executor.island
%island = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.VarHandleOp"
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
: () -> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.If"
"tf.If"(%arg1, %id0, %var_handle) {
then_branch = @if_then,
else_branch = @if_else,
output_shapes = [], is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
// CHECK-LABEL: func @if_then
func @if_then(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>) {
tf_executor.graph {
// CHECK: tf_executor.island
%island = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:1"}
%id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
// CHECK-LABEL: func @if_else
func @if_else(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>) {
tf_executor.graph {
// CHECK: tf_executor.island
%island = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
// -----
// Tests that the pass can propagate through tf.While's branches.
// CHECK-LABEL: func @propagate_while_op
func @propagate_while_op(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:0"},
%arg1: tensor<i32>) {
tf_executor.graph {
// CHECK: tf_executor.island
%island = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.VarHandleOp"
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
: () -> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.While"
"tf.While"(%arg1, %id0, %var_handle) {
body = @while_body,
cond = @while_cond,
output_shapes = [], is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) ->
(tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>)
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
// CHECK-LABEL: func @while_body
func @while_body(
%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>) ->
(tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) {
%graph:3 = tf_executor.graph {
// CHECK: tf_executor.island
%island:4 = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:1"}
%id1 = "tf.Identity"(%arg2) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
tf_executor.yield %arg0, %id0, %id1
: tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>
}
tf_executor.fetch %island#0, %island#1, %island#2
: tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>
}
return %graph#0, %graph#1, %graph#2
: tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>
}
// CHECK-LABEL: func @while_cond
func @while_cond(
%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32> {
%graph = tf_executor.graph {
// CHECK: tf_executor.island
%island:2 = tf_executor.island {
// CHECK-NEXT: "tf.Identity"
// CHECK-SAME: {device = "/TPU:0"}
%id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
%read = "tf.ReadVariableOp"(%id0)
: (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
tf_executor.yield %read : tensor<32xf32>
}
tf_executor.fetch %island#0 : tensor<32xf32>
}
return %graph : tensor<32xf32>
}
// -----
// Tests that the pass reports an error on conflicting assignments from multiple
// callers.
func @error_on_conflict_multiple_callers(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:0"},
%arg1: tensor<i1>) {
tf_executor.graph {
%island = tf_executor.island {
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
%var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
: () -> tensor<*x!tf.resource<tensor<32xf32>>>
"tf.If"(%arg1, %id0, %var_handle) {
then_branch = @if_then_and_else,
else_branch = @if_then_and_else,
output_shapes = [], is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
"tf.If"(%arg1, %var_handle, %id0) {
// expected-error@above {{Conflicting device assignment for resource}}
then_branch = @if_then_and_else,
else_branch = @if_then_and_else,
output_shapes = [], is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}
func @if_then_and_else(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>) {
tf_executor.graph {
%island = tf_executor.island {
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
%id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
-> tensor<*x!tf.resource<tensor<32xf32>>>
tf_executor.yield
}
tf_executor.fetch %island : !tf_executor.control
}
return
}


@ -8,7 +8,7 @@ func @only_resource_load() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
// CHECK: "tf_device.launch"
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
// CHECK: tf_device.return %[[COMPUTE_RES]]
@ -16,7 +16,7 @@ func @only_resource_load() -> tensor<*xi32> {
// CHECK-SAME: () -> tensor<*xi32>
%1 = "tf_device.launch"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_device.return %3 : tensor<*xi32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@ -39,11 +39,11 @@ func @only_resource_store() -> tensor<*xi32> {
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
%1 = "tf_device.launch"() ( {
%2 = "tf.SomeComputation"() : () -> (tensor<*xi32>)
"tf.AssignVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
"tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
tf_device.return %2 : tensor<*xi32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@ -61,18 +61,18 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
%1 = "tf_device.launch"() ( {
%2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
"tf.AssignVariableOp"(%0, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
tf_device.return %3 : tensor<*xi32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
@ -82,96 +82,6 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
// -----
// Tests that composite tf.AssignAddVariableOp operation is decomposed and
// hoisted.
// CHECK-LABEL: func @decompose_assign_add_variable_op
func @decompose_assign_add_variable_op() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
%1 = "tf_device.launch"() ( {
%2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
"tf.AssignAddVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
%3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
tf_device.return %3 : tensor<*xi32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
// CHECK: return %[[LAUNCH_RES]]#0
return %1 : tensor<*xi32>
}
// -----
// Tests that composite tf.AssignSubVariableOp operation is decomposed using
// SubOp.
// CHECK-LABEL: func @decompose_assign_sub_variable_op
func @decompose_assign_sub_variable_op() -> tensor<*xi32> {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]])
// CHECK: "tf.AssignVariableOp"
%1 = "tf_device.launch"() ( {
%2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
"tf.AssignSubVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
%3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
tf_device.return %3 : tensor<*xi32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
// Tests that composite tf.ResourceApplyGradientDescent operation is decomposed
// and hoisted.
// CHECK-LABEL: func @decompose_resource_apply_gradient_descent
func @decompose_resource_apply_gradient_descent() -> tensor<*xf32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_FLOAT"}
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
// CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
// CHECK: %[[DELTA:[0-9]*]] = "tf.Const"
// CHECK: %[[MUL:[0-9]*]] = "tf.Mul"(%[[ALPHA]], %[[DELTA]])
// CHECK: %[[SUB:[0-9]*]] = "tf.Sub"(%[[RES_READ_VAL]], %[[MUL]])
// CHECK: tf_device.return %[[SUB]], %[[SUB]]
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
// CHECK-SAME: () -> (tensor<*xf32>, tensor<*xf32>)
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_FLOAT"}
%1 = "tf_device.launch"() ( {
%2 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<f32>
%3 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
"tf.ResourceApplyGradientDescent"(%0, %2, %3) : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
%4 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_FLOAT"} : (tensor<*x!tf.resource>) -> tensor<*xf32>
tf_device.return %4 : tensor<*xf32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xf32>
// CHECK: return %[[LAUNCH_RES]]#0
return %1 : tensor<*xf32>
}
// -----
// Tests that internal resource operations are not hoisted.
// CHECK-LABEL: func @internal_resource
@ -184,13 +94,13 @@ func @internal_resource() -> tensor<*xi32> {
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
%2 = "tf.ReadVariableOp"(%1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
%2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[COMPUTE_RES]])
"tf.AssignVariableOp"(%1, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
"tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
// CHECK: tf_device.return %[[COMPUTE_RES]]
tf_device.return %3 : tensor<*xi32>


@ -1,6 +1,17 @@
// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail -color
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
// CHECK: %[[ARG0:.*]] = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: %[[ARG1:.*]] = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[ARG0]], %[[ARG1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: return %[[RESULT]] : tensor<1xi32>
%0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
%1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>
%2 = "tf.AddV2"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %2 : tensor<*xi32>
}
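// That is, the pass refines the unranked tf.Cast and tf.AddV2 results to
// tensor<1xi32> and updates the function return type accordingly; the casts are
// kept but now have identical operand and result types.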
// CHECK-LABEL: func @simple_chain
func @simple_chain(%arg0: tensor<1xf32>) -> tensor<*xf32> {


@ -6,18 +6,15 @@
// CHECK-LABEL: func @non_aliasing_reads_writes
func @non_aliasing_reads_writes(
// expected-remark@above {{ID: 13}}
// expected-remark@above {{Predecessors: {12}}}
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg2: tensor<32xf32>) -> (tensor<32xf32>) {
%graph = tf_executor.graph {
// expected-remark@above {{ID: 11}}
// expected-remark@above {{Predecessors: {10}}}
// expected-remark@above {{Successors: {12}}}
// CHECK: tf_executor.island
%island:2 = tf_executor.island {
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
// expected-remark@above {{Successors: {10}}}
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 0}}
@ -49,17 +46,14 @@ func @non_aliasing_reads_writes(
tf_executor.yield %read3 : tensor<32xf32>
// expected-remark@above {{ID: 8}}
// expected-remark@above {{Predecessors: {4,5,7}}}
// expected-remark@above {{Successors: {9}}}
}
tf_executor.fetch %island#0 : tensor<32xf32>
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
// expected-remark@above {{Successors: {11}}}
}
return %graph : tensor<32xf32>
// expected-remark@above {{ID: 12}}
// expected-remark@above {{Predecessors: {11}}}
// expected-remark@above {{Successors: {13}}}
}
// -----
@ -70,15 +64,12 @@ func @non_aliasing_reads_writes(
// CHECK-LABEL: func @aliasing_reads_writes
func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 14}}
// expected-remark@above {{Predecessors: {13}}}
tf_executor.graph {
// expected-remark@above {{ID: 12}}
// expected-remark@above {{Predecessors: {11}}}
// expected-remark@above {{Successors: {13}}}
// CHECK: tf_executor.island
%island = tf_executor.island {
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
// expected-remark@above {{Successors: {11}}}
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 0}}
@ -112,17 +103,14 @@ func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
tf_executor.yield
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
// expected-remark@above {{Successors: {10}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 11}}
// expected-remark@above {{Predecessors: {10}}}
// expected-remark@above {{Successors: {12}}}
}
return
// expected-remark@above {{ID: 13}}
// expected-remark@above {{Predecessors: {12}}}
// expected-remark@above {{Successors: {14}}}
}
// -----
@ -133,15 +121,12 @@ func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
// CHECK-LABEL: func @unknown_side_effecting_op
func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 13}}
// expected-remark@above {{Predecessors: {12}}}
tf_executor.graph {
// expected-remark@above {{ID: 11}}
// expected-remark@above {{Predecessors: {10}}}
// expected-remark@above {{Successors: {12}}}
// CHECK: tf_executor.island
%island = tf_executor.island {
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
// expected-remark@above {{Successors: {10}}}
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 0}}
@ -172,17 +157,14 @@ func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
tf_executor.yield
// expected-remark@above {{ID: 8}}
// expected-remark@above {{Predecessors: {6,7}}}
// expected-remark@above {{Successors: {9}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
// expected-remark@above {{Successors: {11}}}
}
return
// expected-remark@above {{ID: 12}}
// expected-remark@above {{Predecessors: {11}}}
// expected-remark@above {{Successors: {13}}}
}
// -----
@ -193,15 +175,12 @@ func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
// CHECK-LABEL: func @read_only_unknown_resource
func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
tf_executor.graph {
// expected-remark@above {{ID: 8}}
// expected-remark@above {{Predecessors: {7}}}
// expected-remark@above {{Successors: {9}}}
// CHECK: tf_executor.island
%island = tf_executor.island {
// expected-remark@above {{ID: 6}}
// expected-remark@above {{Predecessors: {5}}}
// expected-remark@above {{Successors: {7}}}
%vh0 = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource<tensor<32xf32>>>
// expected-remark@above {{ID: 0}}
@ -223,15 +202,71 @@ func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () {
tf_executor.yield
// expected-remark@above {{ID: 5}}
// expected-remark@above {{Predecessors: {4}}}
// expected-remark@above {{Successors: {6}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 7}}
// expected-remark@above {{Predecessors: {6}}}
// expected-remark@above {{Successors: {8}}}
}
return
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
// expected-remark@above {{Successors: {10}}}
}
// -----
// Tests that the pass adds control dependencies in nested regions with
// tf_device.replicate
func @with_replicate(
// expected-remark@above {{ID: 12}}
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg3: tensor<*x!tf.resource<tensor<32xf32>>>) {
tf_executor.graph {
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Successors: {11}}}
%island = tf_executor.island {
// expected-remark@above {{ID: 8}}
// expected-remark@above {{Successors: {9}}}
%u0:2 = "tf._UnknownSideEffectingOp_"() : () -> (tensor<32xf32>, tensor<32xf32>)
// expected-remark@above {{ID: 0}}
// expected-remark@above {{Successors: {5}}}
tf_device.replicate(
// expected-remark@above {{ID: 5}}
// expected-remark@above {{Predecessors: {0}}}
// expected-remark@above {{Successors: {6}}}
[%arg0, %arg1] as %r0: tensor<*x!tf.resource<tensor<32xf32>>>,
[%arg2, %arg3] as %r1: tensor<*x!tf.resource<tensor<32xf32>>>,
[%u0#0, %u0#1] as %u : tensor<32xf32>)
{n = 2 : i32, devices = ["/CPU:0", "/GPU:1"]} {
%read0 = "tf.ReadVariableOp"(%r0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 1}}
// expected-remark@above {{Successors: {4}}}
"tf.AssignVariableOp"(%r1, %u) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// expected-remark@above {{ID: 2}}
// expected-remark@above {{Successors: {3}}}
%read1 = "tf.ReadVariableOp"(%r1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// expected-remark@above {{ID: 3}}
// expected-remark@above {{Predecessors: {2}}}
// expected-remark@above {{Successors: {4}}}
tf_device.return
// expected-remark@above {{ID: 4}}
// expected-remark@above {{Predecessors: {1,3}}}
}
"tf._UnknownSideEffectingOp_"() : () -> ()
// expected-remark@above {{ID: 6}}
// expected-remark@above {{Predecessors: {5}}}
// expected-remark@above {{Successors: {7}}}
tf_executor.yield
// expected-remark@above {{ID: 7}}
// expected-remark@above {{Predecessors: {6}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 9}}
// expected-remark@above {{Predecessors: {8}}}
}
return
// expected-remark@above {{ID: 11}}
// expected-remark@above {{Predecessors: {10}}}
}


@ -1610,7 +1610,7 @@ func @testSplitUnknownDimInput(%input: tensor<4x?x4xf32>) {
// -----
func @testSplitNonConstSplitDim(%input: tensor<4x4xf32>, %split_dim: tensor<1xi32>) {
func @testSplitNonScalarSplitDim(%input: tensor<4x4xf32>, %split_dim: tensor<1xi32>) {
// expected-error @+1 {{split dimension should be an integer scalar tensor}}
%0:2 = "tf.Split"(%split_dim, %input) : (tensor<1xi32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>)
return
@ -1674,3 +1674,152 @@ func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
return
}
// -----
func @testSplitVScalarInput(%input: tensor<f32>, %split_sizes: tensor<2xi32>, %split_dim: tensor<i32>) {
// expected-error @+1 {{cannot split scalar input tensor}}
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<f32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
// -----
func @testSplitVNonScalarSplitDim(%input: tensor<4x4xf32>, %split_sizes: tensor<2xi32>, %split_dim: tensor<1xi32>) {
// expected-error @+1 {{split dimension should be an integer scalar tensor}}
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
// -----
func @testSplitVSplitDimOutOfRange(%input: tensor<4x4xf32>, %split_sizes: tensor<2xi32>) {
%split_dim = "tf.Const"() {value = dense<100>: tensor<i32>} : () -> (tensor<i32>)
// expected-error @+1 {{split dimension must be in range [-2, 2)}}
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
// -----
func @testSplitVWrongSplitSizesType(%input: tensor<4x4xf32>, %split_sizes: tensor<2x2xi32>, %split_dim: tensor<i32>) {
// expected-error @+1 {{op split sizes should be a 1D tensor of 2 elements}}
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2x2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
// -----
func @testSplitVMultipleDynamicSizes(%input: tensor<4x4xf32>) {
%split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
%split_sizes = "tf.Const"() {value = dense<[-1, -1]>: tensor<2xi32>} : () -> (tensor<2xi32>)
// expected-error @+1 {{cannot have more than one dynamic dimension in split sizes}}
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
// -----
func @testSplitVSplitSizeOutOfRange(%input: tensor<4x4xf32>) {
%split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
%split_sizes = "tf.Const"() {value = dense<[-1, 100]>: tensor<2xi32>} : () -> (tensor<2xi32>)
// expected-error @+1 {{split sizes must sum up to be less than or equal to the dimension size along split dimension, found 100 vs 4}}
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
// -----
func @testSplitVSplitSizeOutOfRange(%input: tensor<4x4xf32>) {
%split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
%split_sizes = "tf.Const"() {value = dense<[2, 3]>: tensor<2xi32>} : () -> (tensor<2xi32>)
// expected-error @+1 {{split sizes must sum up to the dimension size along split dimension, found 5 vs 4}}
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
// -----
func @testSplitV1(%input: tensor<4x4xf32>) {
%split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
%split_sizes = "tf.Const"() {value = dense<[-1, 4]>: tensor<2xi32>} : () -> (tensor<2xi32>)
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
func @testSplitV2(%input: tensor<4x4xf32>) {
%split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
%split_sizes = "tf.Const"() {value = dense<[3, 1]>: tensor<2xi32>} : () -> (tensor<2xi32>)
%0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
return
}
// -----
//===--------------------------------------------------------------------===//
// tf.All
//===--------------------------------------------------------------------===//
func @testAllDimWrongRank(%input: tensor<4x6xi1>, %dims: tensor<2x2xi32>) {
// expected-error @+1 {{dimensions can only be 0D or 1D tensor}}
%0 = "tf.All"(%input, %dims) : (tensor<4x6xi1>, tensor<2x2xi32>) -> (tensor<*xi1>)
return
}
// -----
func @testAllDimOutOfRange(%input: tensor<4x6xi1>) {
%dims = "tf.Const"() {value = dense<[-1, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
// expected-error @+1 {{1-th dimension should be in the range of [-2, 2)}}
%0 = "tf.All"(%input, %dims) : (tensor<4x6xi1>, tensor<2xi32>) -> (tensor<*xi1>)
return
}
// -----
//===--------------------------------------------------------------------===//
// tf.Any
//===--------------------------------------------------------------------===//
func @testAnyDimWrongRank(%input: tensor<4x6xi1>, %dims: tensor<2x2xi32>) {
// expected-error @+1 {{dimensions can only be 0D or 1D tensor}}
%0 = "tf.Any"(%input, %dims) : (tensor<4x6xi1>, tensor<2x2xi32>) -> (tensor<*xi1>)
return
}
// -----
func @testAnyDimOutOfRange(%input: tensor<4x6xi1>) {
%dims = "tf.Const"() {value = dense<[-1, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
// expected-error @+1 {{1-th dimension should be in the range of [-2, 2)}}
%0 = "tf.Any"(%input, %dims) : (tensor<4x6xi1>, tensor<2xi32>) -> (tensor<*xi1>)
return
}
// -----
//===--------------------------------------------------------------------===//
// tf.Unpack
//===--------------------------------------------------------------------===//
func @testUnpackAxisOutOfRange(%input: tensor<2x6xf32>) {
// expected-error @+1 {{axis attribute must be in the range of [-2, 2)}}
%0:2 = "tf.Unpack"(%input) {axis = 5} : (tensor<2x6xf32>) -> (tensor<6xf32>, tensor<6xf32>)
return
}
// -----
func @testAxisUnknownDim(%input: tensor<?x6xf32>) {
// CHECK: tf.Unpack
%0:2 = "tf.Unpack"(%input) {axis = 0} : (tensor<?x6xf32>) -> (tensor<6xf32>, tensor<6xf32>)
return
}
// -----
func @testAxisDim(%input: tensor<2x6xf32>) {
// expected-error @+1 {{result count must be equal to 6}}
%0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<2x6xf32>) -> (tensor<6xf32>, tensor<6xf32>)
return
}

View File

@ -0,0 +1,41 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: (! %p/exported_python_args 2>&1) | FileCheck %s
# pylint: disable=missing-docstring,line-too-long,dangerous-default-value
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common
class TestModule(tf.Module):
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def some_function(self, x):
return self.callee(x)
# CHECK: While importing SavedModel function 'callee': in input signature:
# CHECK-SAME: Unhandled structured value kind {{.*}} at index path: <value>.1.foo
@tf.function
def callee(self, x, n={'foo': 42}):
return x
if __name__ == '__main__':
common.do_test(TestModule)

View File

@ -31,9 +31,14 @@ void CreateTPUBridge(OpPassManager &pm) {
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
func_pm.addPass(CreateTPUClusterFormationPass());
func_pm.addPass(createCanonicalizerPass());
// Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
// because DecomposeResourceOpsPass uses pattern rewriter which hoists
// changed constants out of tf_device.Launch.
func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
func_pm.addPass(tf_executor::CreateTFExecutorConstantSinkingPass());
func_pm.addPass(TFDevice::CreateResourceOpLiftingPass());
pm.addPass(TF::CreateResourceDeviceInferencePass());
pm.addPass(TFDevice::CreateClusterOutliningPass());
pm.addPass(CreateTPURewritePass());
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());

View File

@ -44,16 +44,16 @@ struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op,
OpBuilder* builder) {
llvm::SmallVector<Value*, 4> operands(launch_return_op.getOperands());
builder->create<ReturnOp>(launch_return_op.getLoc(), operands);
builder->create<ReturnOp>(launch_return_op.getLoc(),
launch_return_op.getOperands());
launch_return_op.erase();
}
// Builds a function that outlines region attached to launch_op and inserts
// built function into given module.
FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
tf_device::LaunchOp launch_op,
ModuleManager* module_manager, OpBuilder* builder) {
tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
OpBuilder* builder) {
llvm::SmallVector<Type, 4> operand_types;
operand_types.reserve(live_ins.size());
for (Value* v : live_ins) operand_types.emplace_back(v->getType());
@ -92,14 +92,14 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
builder->setInsertionPoint(launch_return_op);
ReplaceLaunchReturnWithReturn(launch_return_op, builder);
module_manager->insert(outlined_func);
symbol_table->insert(outlined_func);
return outlined_func;
}
// Outlines the body of `tf_device.launch` into a function and creates a
// `tf_device.launch_func` to invoke that function. `tf_device.launch` is
// removed afterwards.
void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
OpBuilder* builder) {
llvm::SetVector<Value*> live_ins;
getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
@ -108,7 +108,7 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
launch_op.getAttrOfType<StringAttr>(kDeviceAttr).getValue();
FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
launch_op, module_manager, builder);
launch_op, symbol_table, builder);
launch_op.setAttr(builder->getIdentifier(kFuncAttr),
builder->getSymbolRefAttr(outlined_func.getName()));
@ -124,10 +124,10 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
void ClusterOutliningPass::runOnModule() {
ModuleOp m = getModule();
ModuleManager module_manager(m);
SymbolTable symbol_table(m);
OpBuilder builder(m.getContext());
m.walk([&](tf_device::LaunchOp launch) {
OutlineLaunch(launch, &module_manager, &builder);
OutlineLaunch(launch, &symbol_table, &builder);
});
}

View File

@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
populateWithGenerated(context, patterns);
}
} // namespace TF
} // namespace mlir

View File

@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
namespace mlir {
namespace TF {
// Populates rewrite patterns that decompose composite resource operations into
// primitive ones like ReadVariableOp, AssignVariableOp and other computations
// to facilitate transformations like resource op lifting.
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_

View File

@ -0,0 +1,63 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def CreateTFReadVariableOp: NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>("
" $0.getLoc(),"
" UnrankedTensorType::get("
" $1->getType().cast<TensorType>().getElementType()),"
" $2)"
>;
def DecomposeAssignAddVariableOp :
Pat<
(TF_AssignAddVariableOp:$src_op $resource, $value),
(TF_AssignVariableOp
$resource,
(TF_AddV2Op
(CreateTFReadVariableOp $src_op, $value, $resource),
$value
)
)
>;
def DecomposeAssignSubVariableOp :
Pat<
(TF_AssignSubVariableOp:$src_op $resource, $value),
(TF_AssignVariableOp
$resource,
(TF_SubOp
(CreateTFReadVariableOp $src_op, $value, $resource),
$value
)
)
>;
// This decomposition is only correct inside XLA as it ignores use_locking
// attribute.
def DecomposeResourceApplyGradientDescentOp :
Pat<
(TF_ResourceApplyGradientDescentOp:$src_op $resource, $alpha, $delta, $_),
(TF_AssignVariableOp
$resource,
(TF_SubOp
(CreateTFReadVariableOp $src_op, $alpha, $resource),
(TF_MulOp $alpha, $delta)
)
)
>;

View File

@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
namespace mlir {
namespace TFDevice {
namespace {
// A pass that decomposes composite resource operations into primitive ones like
// ReadVariableOp, AssignVariableOp and other computations to facilitate
// transformations like resource op lifting.
//
// For example:
//
// tf.AssignAddVariableOp(%res, %0)
//
// Becomes
//
// %res_val = tf.ReadVariableOp(%res)
// %1 = tf.AddV2(%res_val, %0)
// tf.AssignVariableOp(%res, %1)
struct DecomposeResourceOps : public FunctionPass<DecomposeResourceOps> {
void runOnFunction() override {
// Add lowering patterns to the list.
OwningRewritePatternList patterns;
mlir::TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
applyPatternsGreedily(getFunction(), patterns);
}
};
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateDecomposeResourceOpsPass() {
return std::make_unique<DecomposeResourceOps>();
}
} // namespace TFDevice
} // namespace mlir
static mlir::PassRegistration<mlir::TFDevice::DecomposeResourceOps> pass(
"tf-device-decompose-resource-ops",
"Decompose composite resource variable operations into primitive "
"Read/AssignVariableOp and raw computation");

View File

@ -304,8 +304,7 @@ void InsertDummyIslandForFetch(FetchOp fetch) {
/*control=*/ControlType::get(fetch.getContext()),
/*controlInputs=*/control_fetches);
island.body().push_back(new Block);
OpBuilder(&island.GetBody())
.create<YieldOp>(fetch.getLoc(), llvm::to_vector<4>(data_fetches));
OpBuilder(&island.GetBody()).create<YieldOp>(fetch.getLoc(), data_fetches);
const int fetch_control_idx = data_fetches.size();
for (int i = 0, e = fetch.getNumOperands(); i < e; i++) {
// The fetch could have multiple control operands (all at the end of its

View File

@ -17,6 +17,7 @@ limitations under the License.
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
@ -79,8 +80,11 @@ static Operation* CallFn(Location loc,
for (int i = 0; i < num_operands; ++i) {
Value* val = get_arg(i);
Type expected = fn_type.getInput(i);
if (val->getType() != expected)
val = builder->create<TensorCastOp>(loc, val, expected);
if (val->getType() != expected) {
val =
builder->create<TF::CastOp>(loc, expected, val,
/*Truncate=*/builder->getBoolAttr(false));
}
operands.push_back(val);
}
return builder->create<CallOp>(loc, fn, operands).getOperation();
@ -100,8 +104,11 @@ static llvm::SmallVector<Value*, 4> PrepareValsForJump(
for (int i = 0; i < num_vals; ++i) {
Value* val = get_val(i);
Type expected = block->getArgument(i)->getType();
if (val->getType() != expected)
val = builder->create<TensorCastOp>(loc, val, expected);
if (val->getType() != expected) {
val =
builder->create<TF::CastOp>(loc, expected, val,
/*Truncate=*/builder->getBoolAttr(false));
}
result.push_back(val);
}
return result;
@ -131,8 +138,11 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
Value* arg = block->getArgument(i);
Value* result = op->getResult(i);
if (arg->getType() != result->getType())
arg = builder->create<TensorCastOp>(loc, arg, result->getType());
if (arg->getType() != result->getType()) {
arg =
builder->create<TF::CastOp>(loc, result->getType(), arg,
/*Truncate=*/builder->getBoolAttr(false));
}
result->replaceAllUsesWith(arg);
}
}
@ -301,26 +311,15 @@ void FunctionalControlFlowToCFG::runOnFunction() {
// subsequent blocks.
//
// TODO: Use PatternRewriter to eliminate these function control flow ops.
auto has_variant_operand = [](Operation* op) {
auto is_variant = [](Type ty) {
return getElementTypeOrSelf(ty).getKind() == TensorFlowTypes::VARIANT;
};
if (llvm::none_of(op->getOperandTypes(), is_variant)) return false;
op->emitOpError() << "does not yet support operands of type variant "
"for conversion to CFG";
return true;
};
if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
if (has_variant_operand(&op) || failed(LowerIfOp(if_op))) {
if (failed(LowerIfOp(if_op))) {
return signalPassFailure();
}
break;
}
if (WhileOp while_op = llvm::dyn_cast<WhileOp>(op)) {
if (has_variant_operand(&op) || failed(LowerWhileOp(while_op))) {
if (failed(LowerWhileOp(while_op))) {
return signalPassFailure();
}
break;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/util/tensor_format.h"
namespace mlir {
@ -109,6 +110,39 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
return RankedTensorType::get(shape, ranked_ty.getElementType());
}
// Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
//
// %result = "tf.AddN"(%0, %1, %2)
//
// is lowered to:
//
// %sum_0 = "tf.AddV2"(%0, %1)
// %result = "tf.AddV2"(%sum_0, %2)
//
class LowerAddNOp : public OpRewritePattern<TF::AddNOp> {
public:
explicit LowerAddNOp(MLIRContext *context)
: OpRewritePattern<TF::AddNOp>(context) {}
PatternMatchResult matchAndRewrite(TF::AddNOp op,
PatternRewriter &rewriter) const override {
// TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't
// support variant type so variant types require special handling.
if (getElementTypeOrSelf(op.getType()).isa<VariantType>())
return matchFailure();
// TODO(hinsu): Improve parallelism by splitting operands in two halves and
// accumulating them first.
Value *result = *op.inputs().begin();
for (Value *operand : llvm::drop_begin(op.inputs(), 1)) {
result = rewriter.create<TF::AddV2Op>(op.getLoc(), result, operand);
}
rewriter.replaceOp(op, result);
return matchSuccess();
}
};
// Lowers Pack op to ConcatV2 op after changing shape of the inputs with
// ExpandDims op.
//
@ -159,6 +193,7 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {
void PopulateLoweringTFPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns->insert<LowerAddNOp>(context);
patterns->insert<LowerPackOp>(context);
populateWithGenerated(context, patterns);
}
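For orientation only: a function pass consuming these patterns would look much like the DecomposeResourceOps pass earlier in this change. The sketch below is illustrative, the pass name is hypothetical, and it assumes the usual pass and pattern-rewrite headers are included.
// Sketch (not part of this change): applies the patterns populated by
// PopulateLoweringTFPatterns to the current function.
struct LowerTFSketchPass : public FunctionPass<LowerTFSketchPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    mlir::TF::PopulateLoweringTFPatterns(&getContext(), &patterns);
    applyPatternsGreedily(getFunction(), patterns);
  }
};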

View File

@ -63,7 +63,7 @@ void CreateTFStandardPipeline(OpPassManager &pm,
if (options.enable_inliner) {
pm.addPass(createInlinerPass());
}
pm.addNestedPass<FuncOp>(CreateTFShapeInferencePass());
pm.addPass(CreateTFShapeInferencePass());
pm.addNestedPass<FuncOp>(CreateTFOptimizePass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
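As a usage sketch (hypothetical, not part of this change), the pipeline built here could be run over an already-parsed module roughly as follows, assuming `module` is an mlir::ModuleOp and the standard PassManager headers are available.
// Sketch only: build the TF standard pipeline and run it on `module`.
mlir::PassManager pm(module.getContext());
mlir::TF::StandardPipelineOptions options;
mlir::TF::CreateTFStandardPipeline(pm, options);
if (failed(pm.run(module))) {
  // Handle pipeline failure.
}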

View File

@ -57,6 +57,9 @@ struct StandardPipelineOptions : public PassOptions<StandardPipelineOptions> {
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
void CreateTFStandardPipeline(OpPassManager& pm,
const StandardPipelineOptions& options);
// Propagates device attributes of resources from callers to callees.
std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceDeviceInferencePass();
} // namespace TF
namespace TFControlFlow {
@ -96,6 +99,11 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateClusterFormationPass();
// Creates a pass that outlines regions of tf_device.launch operations.
std::unique_ptr<OpPassBase<ModuleOp>> CreateClusterOutliningPass();
// A pass that decomposes composite resource operations into primitive ones like
// ReadVariableOp, AssignVariableOp and other computations to facilitate
// transformations like resource op lifting.
std::unique_ptr<OpPassBase<FuncOp>> CreateDecomposeResourceOpsPass();
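For reference, wiring the two passes declared above into a pipeline would mirror their use in CreateTPUBridge earlier in this change; the helper name below is hypothetical and the snippet is a sketch, not part of this header.
// Sketch only: nest the decomposition pass on functions and run resource
// device inference at the module level.
void AddResourcePassesSketch(OpPassManager &pm) {
  pm.addNestedPass<FuncOp>(TFDevice::CreateDecomposeResourceOpsPass());
  pm.addPass(TF::CreateResourceDeviceInferencePass());
}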
// Creates a pass that lifts operations on external resource variables from
// device computation nested in `tf_device::LaunchOp` out so that resource
// variable load operations are all before device computation while resource

View File

@ -64,8 +64,8 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
// Replace replicate terminator with YieldOp.
builder->setInsertionPoint(&terminator);
builder->create<tf_executor::YieldOp>(
terminator.getLoc(), llvm::to_vector<8>(terminator.getOperands()));
builder->create<tf_executor::YieldOp>(terminator.getLoc(),
terminator.getOperands());
terminator.erase();
builder->setInsertionPoint(island_op);

View File

@ -0,0 +1,278 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iterator>
#include <memory>
#include <tuple>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/IR/Visitors.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace TF {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";
// A pass that propagates device assignment of resources on a module. It
// performs in-function propagation, as well as cross-function propagation from
// callers to callees.
//
// This pass changes the module by adding "tf.device" attribute to function
// arguments and adding "device" attribute to TF ops.
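//
// For example (an illustrative sketch, not taken from this change's tests):
// given a function argument of resource type carrying
// {tf.device = "/device:CPU:0"}, a "tf.Identity" op whose output aliases that
// argument and has no "device" attribute of its own would get
// device = "/device:CPU:0" attached by this pass.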
struct ResourceDeviceInference : public ModulePass<ResourceDeviceInference> {
void runOnModule() override;
};
// A class that records each resource's device assignment in a function.
class PerFunctionResult {
public:
explicit PerFunctionResult(FuncOp func_op) : alias_analysis_(func_op) {}
// Returns the recorded device assignment for a resource, if any.
llvm::Optional<llvm::StringRef> DeviceForResource(
const Value* resource) const {
llvm::Optional<llvm::StringRef> result;
if (alias_analysis_.IsUnknownResource(resource)) return result;
for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
auto it = resource_id_to_device_.find(id);
if (it == resource_id_to_device_.end()) continue;
if (!result) {
result = it->getSecond();
continue;
}
if (result != it->getSecond()) {
// Got conflicting assignments, clear the result.
result.reset();
return result;
}
}
return result;
}
// Records the device assignment for a resource. If the new assignment
// conflicts with an existing one, returns an error.
//
// If `changed` is provided, assign *changed to true if anything is modified.
LogicalResult AddResourceDevice(const Value* resource, llvm::StringRef device,
bool* changed = nullptr) {
if (alias_analysis_.IsUnknownResource(resource)) return success();
for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
auto emplace_res = resource_id_to_device_.try_emplace(id, device);
if (emplace_res.second) {
if (changed) *changed = true;
} else if (emplace_res.first->getSecond() != device) {
// Existing assignment does not equal the new assignment.
return failure();
}
}
return success();
}
private:
llvm::SmallDenseMap<int64_t, llvm::StringRef, 8> resource_id_to_device_;
TF::ResourceAliasAnalysis alias_analysis_;
};
// Tries to record device assignment for a resource.
LogicalResult AddResourceDeviceAndEmitError(const Value* resource,
llvm::StringRef device,
Operation* error_reporting_op,
PerFunctionResult* result,
bool* changed = nullptr) {
auto res = result->AddResourceDevice(resource, device, changed);
if (failed(res)) {
error_reporting_op->emitError()
<< "Conflicting device assignment for resource";
}
return res;
}
// Propagates device assignment inside a function.
LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
PerFunctionResult* result) {
OpBuilder builder(func_op);
// Function arguments.
for (auto arg : func_op.getArguments()) {
if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>()) {
continue;
}
auto device_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
arg->getArgNumber(), kFuncDeviceAttr);
if (!device_attr || device_attr.getValue() == "") {
// If device_attr does not exist, try to construct it from any recorded
// assignment.
if (auto device = result->DeviceForResource(arg)) {
func_op.setArgAttr(arg->getArgNumber(), kFuncDeviceAttr,
builder.getStringAttr(*device));
}
continue;
}
// Record the attribute.
auto res = AddResourceDeviceAndEmitError(arg, device_attr.getValue(),
func_op, result);
if (failed(res)) return res;
}
auto walk_res = func_op.walk([&](Operation* op) {
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
// Record VarHandleOp's device attribute.
auto device_attr =
var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
if (!device_attr || device_attr.getValue().empty()) {
return WalkResult::advance();
}
auto res = AddResourceDeviceAndEmitError(
var_handle.resource(), device_attr.getValue(), op, result);
if (failed(res)) return WalkResult::interrupt();
}
if (auto identity = llvm::dyn_cast<TF::IdentityOp>(op)) {
// Try to construct IdentityOp's attribute from recorded assignment.
if (!mlir::getElementTypeOrSelf(identity.output()->getType())
.isa<TF::ResourceType>()) {
return WalkResult::advance();
}
if (auto device = result->DeviceForResource(identity.output())) {
auto device_attr =
identity.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
if (!device_attr || device_attr.getValue().empty()) {
identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
}
}
return WalkResult::advance();
}
// Propagate and record output device assignment for other ops based on
// existing recording. E.g., IdentityN.
for (auto output : op->getResults()) {
if (!mlir::getElementTypeOrSelf(output->getType())
.isa<TF::ResourceType>()) {
continue;
}
if (auto device = result->DeviceForResource(output)) {
auto res = AddResourceDeviceAndEmitError(output, *device, op, result);
if (failed(res)) return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
return failure(walk_res.wasInterrupted());
}
void ResourceDeviceInference::runOnModule() {
auto module = getModule();
llvm::SmallDenseMap<Operation*, PerFunctionResult, 4> per_function_results;
llvm::SetVector<FuncOp> worklist;
module.walk([&](FuncOp func_op) {
worklist.insert(func_op);
per_function_results.try_emplace(func_op, func_op);
});
// Helper that propagates an op's recorded operand device assignments to its
// called function's arguments.
auto propagate_operands_to_callee_arguments =
[&](Operation* caller,
llvm::iterator_range<OperandIterator> caller_operands,
llvm::StringRef called_func_name,
const PerFunctionResult& caller_res) {
auto callee =
llvm::dyn_cast<FuncOp>(module.lookupSymbol(called_func_name));
assert(callee);
auto& callee_res = per_function_results.find(callee)->getSecond();
bool callee_needs_recompute = false;
for (auto operand_and_argument :
llvm::zip(caller_operands, callee.getArguments())) {
if (!mlir::getElementTypeOrSelf(
std::get<0>(operand_and_argument)->getType())
.isa<TF::ResourceType>()) {
continue;
}
auto device =
caller_res.DeviceForResource(std::get<0>(operand_and_argument));
if (!device) continue;
if (failed(AddResourceDeviceAndEmitError(
std::get<1>(operand_and_argument), *device, caller,
&callee_res, &callee_needs_recompute))) {
return failure();
}
}
// If the callee recording is modified, make sure that it will be
// reprocessed.
if (callee_needs_recompute) {
worklist.insert(callee);
}
return success();
};
while (!worklist.empty()) {
auto func_op = worklist.back();
worklist.pop_back();
auto& func_res = per_function_results.find(func_op)->getSecond();
// In-function propagation.
if (failed(ComputeResourceDevicesInComputation(func_op, &func_res))) {
return signalPassFailure();
}
// Propagation to callees.
auto walk_res = func_op.walk([&](Operation* op) {
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
if (failed(propagate_operands_to_callee_arguments(
while_op, while_op.getOperands(), while_op.body(), func_res)) ||
failed(propagate_operands_to_callee_arguments(
while_op, while_op.getOperands(), while_op.cond(), func_res))) {
return WalkResult::interrupt();
}
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
if (failed(propagate_operands_to_callee_arguments(
if_op, if_op.input(), if_op.then_branch(), func_res)) ||
failed(propagate_operands_to_callee_arguments(
if_op, if_op.input(), if_op.else_branch(), func_res))) {
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (walk_res.wasInterrupted()) return signalPassFailure();
}
}
} // namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceDeviceInferencePass() {
return std::make_unique<ResourceDeviceInference>();
}
static PassRegistration<ResourceDeviceInference> pass(
"tf-resource-device-inference",
"Propagates the device attribute on resources from callers to callees.");
} // namespace TF
} // namespace mlir

View File

@ -77,129 +77,6 @@ struct ResourceOpLiftingPass : public FunctionPass<ResourceOpLiftingPass> {
void runOnFunction() override;
};
// Rewrites composite variable op `tf.AssignAddVariableOp` or
// `tf.AssignSubVariableOp` into primitive resource/computation ops.
// For example:
//
// tf.AssignAddVariableOp(%res, %0)
//
// Becomes
//
// %res_val = tf.ReadVariableOp(%res)
// %1 = tf.AddV2(%res_val, %0)
// tf.AssignVariableOp(%res, %1)
//
template <typename T>
LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) {
// Read mangled dtype, which indicates type of data stored in resource
// variable. It can then be used to construct type needed for both
// ReadVariableOp and AssignVariableOp.
StringAttr mangled_dtype_attr =
src_op.template getAttrOfType<StringAttr>(kDTypeAttr);
std::string type_string = mangled_dtype_attr.getValue();
tensorflow::DataType dtype_proto;
auto s =
tensorflow::mangling_util::DemangleDataType(type_string, &dtype_proto);
if (!s.ok()) return src_op.emitError() << s.error_message();
Type type;
s = tensorflow::ConvertDataType(dtype_proto, *builder, &type);
if (!s.ok()) return src_op.emitError() << s.error_message();
type = UnrankedTensorType::get(type);
builder->setInsertionPoint(src_op);
auto read_variable_op = builder->create<TF::ReadVariableOp>(
src_op.getLoc(), type, src_op.resource());
read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
mangled_dtype_attr);
Value* result;
if (std::is_same<T, TF::AssignAddVariableOp>()) {
result = builder->create<TF::AddV2Op>(
src_op.getLoc(), read_variable_op.value(), src_op.value());
} else {
result = builder->create<TF::SubOp>(
src_op.getLoc(), read_variable_op.value(), src_op.value());
}
auto assign_variable_op = builder->create<TF::AssignVariableOp>(
src_op.getLoc(), src_op.resource(), result);
assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
mangled_dtype_attr);
src_op.erase();
return success();
}
// Rewrites `tf.ResourceApplyGradientDescent` into primitive resource and
// computation ops.
//
// Specifically:
//
// tf.ResourceApplyGradientDescent(%var, %alpha, %delta)
//
// Becomes
//
// %old_var_val = tf.ReadVariableOp(%var)
// %gradient_update = tf.Mul(%alpha, %delta)
// %new_var_val = tf.Sub(%old_var_val, %gradient_update)
// tf.AssignVariableOp(%var, %new_var_val)
LogicalResult RewriteResourceApplyGradientDescentOp(
TF::ResourceApplyGradientDescentOp op, OpBuilder* builder) {
Type type = op.alpha()->getType();
auto t = UnrankedTensorType::get(type.cast<TensorType>().getElementType());
tensorflow::DataType data_type;
auto s = tensorflow::ConvertToDataType(type, &data_type);
if (!s.ok()) return op.emitError() << s.error_message();
std::string mangled_data_type =
tensorflow::mangling_util::MangleDataType(data_type);
auto mangled_dtype_attr = builder->getStringAttr(mangled_data_type);
builder->setInsertionPoint(op);
auto read_variable_op =
builder->create<TF::ReadVariableOp>(op.getLoc(), t, op.var());
read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
mangled_dtype_attr);
auto mul_op =
builder->create<TF::MulOp>(op.getLoc(), t, op.alpha(), op.delta());
auto sub_op = builder->create<TF::SubOp>(
op.getLoc(), t, read_variable_op.value(), mul_op.z());
auto assign_variable_op =
builder->create<TF::AssignVariableOp>(op.getLoc(), op.var(), sub_op.z());
assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
mangled_dtype_attr);
op.erase();
return success();
}
// Rewrites an operation that updates value of a resource variable into its
// equivalent primitive ones so that following analysis/rewrite can be easier.
// If given op is not a composite resource store op or is an unsupported op, no
// change is applied.
// TODO(ycao): Explore using pattern rewriter after needed operations are
// defined.
// TODO(ycao): Add support for other composite resource store ops.
LogicalResult MaybeRewriteCompositeResourceStore(Operation* op,
OpBuilder* builder) {
if (auto assign_add_op = dyn_cast<TF::AssignAddVariableOp>(op)) {
return RewriteCompositeAssignVariableOp(assign_add_op, builder);
} else if (auto assign_sub_op = dyn_cast<TF::AssignSubVariableOp>(op)) {
return RewriteCompositeAssignVariableOp(assign_sub_op, builder);
} else if (auto resource_apply_gradient_descent_op =
dyn_cast<TF::ResourceApplyGradientDescentOp>(op)) {
return RewriteResourceApplyGradientDescentOp(
resource_apply_gradient_descent_op, builder);
}
return success();
}
// Performs store-load forwarding. This effectively removes
// 1) Any resource loads after a store to that same resource is done
// 2) Any resource stores except the last one.
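// For example (illustrative only), in the sequence
//   tf.AssignVariableOp(%res, %0)
//   %1 = tf.ReadVariableOp(%res)
//   tf.AssignVariableOp(%res, %2)
// the load is forwarded to %0 and the first store is removed, leaving only
// the final store of %2.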
@ -358,10 +235,6 @@ void HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) {
ModuleOp m = launch_op.getParentOfType<ModuleOp>();
OpBuilder builder(m);
// Rewrite composite resource store operations into primitive ones.
launch_op.walk(
[&](Operation* op) { MaybeRewriteCompositeResourceStore(op, &builder); });
// Perform store-load forwarding. So that each resource is only loaded with
// its initial value and is only stored with its final value.
ForwardStoreToLoad(launch_op);

View File

@ -47,6 +47,46 @@ using ::tensorflow::int64;
namespace mlir {
namespace TF {
namespace {
Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
FuncOp func) {
// Only infer shape when there is one return op for now.
if (!has_single_element(func.getBody()) || func.front().empty()) {
return None;
}
// Find the return type.
auto return_op = dyn_cast<mlir::ReturnOp>(func.front().back());
if (!return_op) {
return None;
}
// Manually fold tf.Cast that precedes the return instruction and only differs
// in shape refinement level.
for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
if (auto cast_op = dyn_cast<CastOp>(arg_op.get()->getDefiningOp())) {
// Shape inference should not change the element type.
if (cast_op.SrcT() != cast_op.DstT()) continue;
// We only refine the result shape if the result has a dynamic shape, the
// input has a static shape, and the two shapes are compatible.
auto has_static_shape = [](const Value* value) {
auto shaped_type = value->getType().dyn_cast<ShapedType>();
return shaped_type && shaped_type.hasStaticShape();
};
Value* input = cast_op.x();
Value* result = cast_op.y();
if (!has_static_shape(input) || has_static_shape(result) ||
failed(verifyCompatibleShape(input->getType(), result->getType())))
continue;
arg_op.set(cast_op.x());
if (cast_op.y()->use_empty()) cast_op.erase();
}
}
return llvm::to_vector<4>(return_op.getOperandTypes());
}
} // namespace
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
int64_t graph_version) {
@ -245,11 +285,10 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
return success();
}
LogicalResult InferShapeForFunction(FuncOp op,
LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<ArrayRef<int64_t>> arg_shapes,
int64_t graph_version) {
auto main_func = op;
mlir::FunctionType func_type = main_func.getType();
mlir::FunctionType func_type = func.getType();
bool needs_refinement = false;
llvm::SmallVector<mlir::Type, 4> new_arg_types;
new_arg_types.reserve(func_type.getNumInputs());
@ -276,7 +315,7 @@ LogicalResult InferShapeForFunction(FuncOp op,
auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
if (new_arg_type != func_type.getInput(i)) {
// If the new type is more detailed, trigger shape inference.
main_func.getArgument(i)->setType(new_arg_type);
func.getArgument(i)->setType(new_arg_type);
needs_refinement = true;
}
new_arg_types.push_back(new_arg_type);
@ -287,39 +326,28 @@ LogicalResult InferShapeForFunction(FuncOp op,
}
mlir::LogicalResult result =
mlir::TF::InferShapeUntilFixPoint(&main_func.getBody(), graph_version);
mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version);
if (failed(result)) {
return failure();
}
// Must only have 1 block so that there is only one return op.
if (main_func.getBody().getBlocks().size() != 1 ||
main_func.front().empty()) {
return failure();
auto return_types = InferShapeForFunctionReturnType(func);
func.setType(mlir::FunctionType::get(new_arg_types,
return_types.hasValue()
? return_types.getValue()
: func.getType().getResults(),
func.getContext()));
return success();
}
LogicalResult InferShapeForFunctionType(FuncOp func) {
if (auto return_types = InferShapeForFunctionReturnType(func)) {
func.setType(mlir::FunctionType::get(func.getType().getInputs(),
return_types.getValue(),
func.getContext()));
}
// Find the return type.
auto return_op = dyn_cast<mlir::ReturnOp>(*main_func.front().rbegin());
if (!return_op) {
return failure();
}
// Manually fold tf.Cast that precedes the return instruction and only differ
// in shape refinement level.
for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
if (auto cast_op = dyn_cast<CastOp>(arg_op.get()->getDefiningOp())) {
if (cast_op.SrcT() != cast_op.DstT()) continue;
arg_op.set(cast_op.x());
if (cast_op.y()->use_empty()) cast_op.erase();
}
}
llvm::SmallVector<mlir::Type, 4> return_types(return_op.getOperandTypes());
// Update function signature with the results of inference.
main_func.setType(
mlir::FunctionType::get(new_arg_types, return_types, op.getContext()));
return success();
}

View File

@ -41,12 +41,16 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
int64_t max_iteration = 10);
// Given a list of refined shapes matching the function arguments of op, run
// Given a list of refined shapes matching the function arguments of func, runs
// shape inference over the function to propagate this updated information.
LogicalResult InferShapeForFunction(FuncOp op,
LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<ArrayRef<int64_t>> arg_shapes,
int64_t graph_version);
// Refines the return type of the given function by folding tf.Cast that
// precedes the return instruction.
LogicalResult InferShapeForFunctionType(FuncOp func);
} // namespace TF
} // namespace mlir

View File

@ -65,7 +65,11 @@ struct ShapeInference : public ModulePass<ShapeInference> {
return;
}
for (auto func : module.getOps<FuncOp>()) {
TF::InferShapeUntilFixPoint(&func.getBody(), producer.getInt());
InferShapeUntilFixPoint(&func.getBody(), producer.getInt());
}
if (auto main_func = module.lookupSymbol<mlir::FuncOp>("main")) {
InferShapeForFunctionType(main_func);
}
}
};

View File

@ -34,10 +34,12 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
@ -57,8 +59,6 @@ constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
constexpr char kDeviceAttr[] = "device";
constexpr char kNameAttr[] = "name";
constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kTPUReplicatedInputOp[] = "tf.TPUReplicatedInput";
constexpr char kTPUReplicatedOutputOp[] = "tf.TPUReplicatedOutput";
constexpr char kBadTPUReplicateAttrMsg[] =
"requires '_tpu_replicate' string attribute";
@ -275,9 +275,8 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
mlir::visitUsedValuesDefinedAbove(
launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
Operation* def = operand->get()->getDefiningOp();
if (def && def->getName().getStringRef() == kTPUReplicatedInputOp) {
if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
replicated_input_ops.insert(def);
}
});
// Check if number of operands of each used TPUReplicatedInput op matches
@ -305,10 +304,10 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
int idx = result_and_idx.index();
for (auto& use : result->getUses()) {
Operation* def = use.getOwner();
if (!def || def->getName().getStringRef() != kTPUReplicatedOutputOp)
if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
return launch_op.emitError()
<< "requires output of " << launch_op.getOperationName()
<< " to lead to a '" << kTPUReplicatedOutputOp << "' op";
<< " to lead to a 'tf.TPUReplicatedOutput' op";
if (def->getNumResults() != num_replicas)
return def->emitOpError() << "requires " << num_replicas << " results";
@ -331,9 +330,8 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
// Create terminator for replicate op and move launch into replicate.
builder.setInsertionPointToEnd(&replicate_op.GetBody());
auto return_op = builder.create<tf_device::ReturnOp>(
replicate_op.getLoc(),
llvm::SmallVector<Value*, 8>(launch_op.getResults()));
auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
launch_op.getResults());
launch_op.getOperation()->moveBefore(return_op);
return success();
@ -427,8 +425,8 @@ void TPUClusterFormation::runOnFunction() {
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
auto remove_result = getFunction().walk([&](Operation* op) {
auto op_name = op->getName().getStringRef();
if (op_name != kTPUReplicatedInputOp && op_name != kTPUReplicatedOutputOp)
if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
!llvm::isa<TF::TPUReplicatedOutputOp>(op))
return WalkResult::advance();
// Forward operand to result. When `num_replicas` attribute is 1, no
@ -440,7 +438,8 @@ void TPUClusterFormation::runOnFunction() {
// Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
// `num_replicas` to 1.
if (!op->use_empty()) {
op->emitOpError() << "expects " << op_name << " to have no uses";
op->emitOpError() << "expects " << op->getName().getStringRef()
<< " to have no uses";
return WalkResult::interrupt();
}

View File

@ -109,13 +109,13 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
ModuleManager module_manager(module_for_func.get());
SymbolTable symbol_table(module_for_func.get());
while (!referenced.empty()) {
auto func = referenced.pop_back_val();
// Skip functions that have already been cloned into new module.
if (module_manager.lookupSymbol<FuncOp>(func.getName())) continue;
if (symbol_table.lookup<FuncOp>(func.getName())) continue;
// Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
// all found FuncOps to new_module to make sure new_module is
@ -138,7 +138,7 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
// should be no other reference to it.
clone.setName("main");
}
module_manager.insert(clone);
symbol_table.insert(clone);
}
// Serialize module and return.

View File

@ -13,14 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/STLExtras.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
// This pass is used in preparation for Graph export.
@ -38,12 +43,11 @@ struct BreakUpIslands : OperationPass<BreakUpIslands, FuncOp> {
void runOnOperation() final;
void BreakUpIsland(tf_executor::IslandOp op,
const TF::SideEffectAnalysis& side_effect_analysis,
llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>>*
new_control_edges);
};
} // end anonymous namespace
void BreakUpIslands::runOnOperation() {
auto graph_op_range = getOperation().getBody().front().without_terminator();
tf_executor::GraphOp graph_op;
@ -61,12 +65,13 @@ void BreakUpIslands::runOnOperation() {
// Map from the users of the existing islands to the list of control
// edges that need to be added.
llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>> new_control_edges;
auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
// Iterate in reverse order to avoid invalidating Operation* stored in
// new_control_edges.
for (auto& item :
llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
if (auto island = dyn_cast<tf_executor::IslandOp>(&item)) {
BreakUpIsland(island, &new_control_edges);
BreakUpIsland(island, side_effect_analysis, &new_control_edges);
}
}
OpBuilder builder(getOperation());
@ -106,21 +111,81 @@ void BreakUpIslands::runOnOperation() {
}
}
// Helper that creates an island. If `sub_op` is not nullptr, it will be moved
// to the island.
tf_executor::IslandOp CreateIsland(ArrayRef<Type> result_types,
ArrayRef<Value*> control_inputs,
const tf_executor::ControlType& control_type,
const Location& loc, Operation* sub_op,
tf_executor::IslandOp original_island) {
OpBuilder builder(original_island);
auto island = builder.create<tf_executor::IslandOp>(
loc, result_types, control_type, control_inputs);
island.body().push_back(new Block);
Block* block = &island.body().back();
if (sub_op) {
sub_op->replaceAllUsesWith(island.outputs());
sub_op->moveBefore(block, block->begin());
}
OpBuilder island_builder(original_island);
island_builder.setInsertionPointToEnd(block);
if (sub_op) {
island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
} else {
island_builder.create<tf_executor::YieldOp>(loc, ArrayRef<Value*>{});
}
return island;
}
// A struct that contains the operations in an island that do not have
// incoming or outgoing dependencies.
struct IslandSourcesAndSinks {
// Sub-ops that do not depend on other ops in the island.
llvm::SmallPtrSet<Operation*, 4> sources;
// Sub-ops that do not have other sub-ops in the island depending on them
// (excluding yield).
llvm::SmallPtrSet<Operation*, 4> sinks;
};
// Finds IslandSourcesAndSinks for an unmodified island.
IslandSourcesAndSinks FindSourcesAndSinksInIsland(
tf_executor::IslandOp island,
const TF::SideEffectAnalysis& side_effect_analysis) {
IslandSourcesAndSinks result;
auto island_body = island.GetBody().without_terminator();
for (Operation& sub_op : island_body) {
auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op);
result.sinks.insert(&sub_op);
// Remove predecessor from sinks.
for (auto predecessor : predecessors) result.sinks.erase(predecessor);
bool has_in_island_operands = false;
for (auto operand : sub_op.getOperands()) {
auto defining_op = operand->getDefiningOp();
if (!defining_op || defining_op->getParentOp() != island) continue;
// Remove operands from sinks.
result.sinks.erase(defining_op);
has_in_island_operands = true;
}
if (predecessors.empty() && !has_in_island_operands) {
result.sources.insert(&sub_op);
}
}
return result;
}
// Converts a single island into multiple islands (one for each op). The islands
// are chained together by control flow values.
void BreakUpIslands::BreakUpIsland(
tf_executor::IslandOp op,
const TF::SideEffectAnalysis& side_effect_analysis,
llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>>*
new_control_edges) {
auto island_body = op.GetBody().without_terminator();
// Skip islands that are already only a single op.
// Skip islands that are empty (only yield).
if (island_body.empty() || has_single_element(island_body)) return;
OpBuilder builder(op);
OpBuilder island_builder(op);
auto control_type = tf_executor::ControlType::get(&getContext());
Value* previous_island = nullptr;
auto tmp_control_inputs = llvm::to_vector<4>(op.controlInputs());
auto island_control_inputs = llvm::to_vector<4>(op.controlInputs());
// Add control dependencies for yields of values defined by other islands to
// the island that defines that fetched value.
for (auto* fetch : op.GetYield().fetches()) {
@ -130,7 +195,7 @@ void BreakUpIslands::BreakUpIsland(
// OK, because it is the same island.
} else if (auto island_op = llvm::dyn_cast<tf_executor::IslandOp>(
fetch->getDefiningOp())) {
tmp_control_inputs.push_back(island_op.control());
island_control_inputs.push_back(island_op.control());
} else {
// TODO(parkers): Any defining op that has a control output can be handled
// just like an island.
@ -138,39 +203,71 @@ void BreakUpIslands::BreakUpIsland(
return signalPassFailure();
}
}
ArrayRef<Value*> previous_control = tmp_control_inputs;
// If there are multiple control inputs, create an empty island to group them.
if (island_control_inputs.size() > 1) {
auto island = CreateIsland({}, island_control_inputs, control_type,
op.getLoc(), nullptr, op);
island_control_inputs.clear();
island_control_inputs.push_back(island.control());
}
// Find sources and sinks inside the original island.
auto sources_and_sinks =
FindSourcesAndSinksInIsland(op, side_effect_analysis);
// The corresponding control output of the new island created for each sub-op.
llvm::SmallDenseMap<Operation*, Value*, 8> new_control_for_sub_ops;
// Control outputs of newly created islands that are sinks.
llvm::SmallVector<Value*, 8> sink_island_controls;
// For each operation in the island, construct a new island to wrap the op,
// yield all the results, and replace all the usages with the results of the
// new island.
for (Operation& sub_op : llvm::make_early_inc_range(island_body)) {
auto loc = sub_op.getLoc();
auto island = builder.create<tf_executor::IslandOp>(
loc, llvm::to_vector<4>(sub_op.getResultTypes()), control_type,
previous_control);
island.body().push_back(new Block);
Block* block = &island.body().back();
sub_op.replaceAllUsesWith(island.outputs());
block->getOperations().splice(block->begin(), op.GetBody().getOperations(),
sub_op);
island_builder.setInsertionPointToEnd(block);
island_builder.create<tf_executor::YieldOp>(
loc, llvm::to_vector<4>(sub_op.getResults()));
previous_island = island.control();
previous_control = previous_island;
for (auto& sub_op : llvm::make_early_inc_range(island_body)) {
const auto predecessors =
side_effect_analysis.DirectControlPredecessors(&sub_op);
// Get the controls from the predecessors.
llvm::SmallVector<Value*, 4> predecessors_control;
predecessors_control.reserve(predecessors.size());
for (auto predecessor : predecessors) {
predecessors_control.push_back(new_control_for_sub_ops[predecessor]);
}
// If sub_op is a source, use island_control_inputs, because that's required
// by inter-island dependencies; otherwise, we do not need to include
// island_control_inputs, since they must have been tracked by the (direct
// or indirect) control predecessors or operands.
ArrayRef<Value*> control = sources_and_sinks.sources.count(&sub_op) > 0
? island_control_inputs
: predecessors_control;
auto island =
CreateIsland(llvm::to_vector<4>(sub_op.getResultTypes()), control,
control_type, sub_op.getLoc(), &sub_op, op);
new_control_for_sub_ops[&sub_op] = island.control();
if (sources_and_sinks.sinks.count(&sub_op)) {
sink_island_controls.push_back(island.control());
}
}
op.control()->replaceAllUsesWith(previous_island);
// All existing outputs need to add a control flow edge to the
// previous_island.
// Create output controls for the sinks.
assert(!sink_island_controls.empty());
// If there are multiple output controls, create an empty island to group
// them.
if (sink_island_controls.size() > 1) {
auto island = CreateIsland({}, sink_island_controls, control_type,
op.getLoc(), nullptr, op);
sink_island_controls.clear();
sink_island_controls.push_back(island.control());
}
assert(sink_island_controls.size() == 1);
op.control()->replaceAllUsesWith(sink_island_controls[0]);
// All existing outputs need to add a control flow edge from
// sink_island_controls[0].
for (Value* out : op.outputs()) {
for (auto& use : out->getUses()) {
Operation* owner = use.getOwner();
if (auto island_op =
llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
(*new_control_edges)[island_op].push_back(previous_island);
(*new_control_edges)[island_op].push_back(sink_island_controls[0]);
} else if (llvm::isa<tf_executor::FetchOp>(owner) ||
llvm::isa<tf_executor::MergeOp>(owner) ||
llvm::isa<tf_executor::SwitchOp>(owner)) {
(*new_control_edges)[owner].push_back(previous_island);
(*new_control_edges)[owner].push_back(sink_island_controls[0]);
} else {
use.getOwner()->emitError("Adding control dependency not supported");
return signalPassFailure();
@ -182,6 +279,8 @@ void BreakUpIslands::BreakUpIsland(
op.erase();
}
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateBreakUpIslandsPass() {
return std::make_unique<BreakUpIslands>();
}

View File

@ -535,6 +535,18 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
arg, index,
graph_as_function && !input_names.empty() ? input_names[index] : ""));
}
auto convert_called_function = [&](llvm::StringRef name) {
auto func =
function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
name);
if (func != nullptr) {
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
}
return Status::OK();
};
// Adds nodes for operations.
for (Operation& inst : block) {
auto op_name = GetTensorFlowOpName(inst.getName().getStringRef());
@ -544,13 +556,12 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
// definition library
// TODO(prakalps): If two functions have cyclic dependence, this will
// introduce an infinite loop.
auto func =
function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
op_name.ValueOrDie());
if (func != nullptr) {
TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
}
TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
}
if (IsLegacyCallInstruction(&inst)) {
TF_RETURN_IF_ERROR(convert_called_function(
inst.getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference()));
}
for (auto type : inst.getResultTypes()) {

View File

@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include <iterator>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
@ -97,6 +100,9 @@ using stream_executor::port::StatusOr;
namespace {
const char* disable_call_shape_inference_attribute_name =
"_disable_call_shape_inference";
// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
@ -246,11 +252,14 @@ class ImporterBase {
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes);
// Helper to create either a tf_executor operation or a TF operation wrapped
// in an island.
// in an island. When convert_to_legacy_call is true, converts the operation
// representing a call to a library function with a name represented in
// node_type_name to LegacyCallOp.
mlir::Operation* createOperation(
const Node& node, llvm::StringRef op_name,
const Node& node, llvm::StringRef node_type_name,
const mlir::OperationState& result,
const llvm::SmallVectorImpl<mlir::Value*>& control_operands);
const llvm::SmallVectorImpl<mlir::Value*>& control_operands,
bool convert_to_legacy_call = false);
// Converts one NodeDef from the input GraphDef into an Operation and
// inserts it into the MLIR module using builder_.
@ -297,19 +306,24 @@ class ImporterBase {
// Gets the location information string for the given node.
std::string GetLocationStr(const Node& node, bool includeNodeName = false);
// Inserts a placeholder node in the graph to replace the input node. Replaces
// all the output edges of the input_node with the placeholder node, and
// removes the input_node from the graph. The new node has the same name as
// the input_node, so Nodespecs do not need any modification.
// Inserts a placeholder node in the graph to replace a feed output tensor,
// and returns the new placeholder node and a boolean indicating if the
// original input node was removed from the graph. Uses of the feed output
// tensor are replaced with this placeholder node. If the feed output tensor
// comes from a node with a single output, the control dependencies are
// forwarded to the placeholder node, and the original node will be removed.
// Note: This modifies the graph, and so any list of ordered nodes needs to be
// reconstructed.
StatusOr<Node*> ReplaceWithPlaceholderNode(const TensorShapeProto& shape,
DataType dtype, Node* input_node);
StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
const TensorShapeProto& shape, DataType dtype, Node* node, int index,
const std::unordered_map<string, Node*>& node_name_map);
// Gets the input and output nodes corresponding to the specified input and
// output nodes in specs_. If there are no input or output nodes specified,
// nodes will be empty
Status GetInputOutputNodes(std::unordered_set<const Node*>* nodes);
// nodes will be empty.
Status GetInputOutputNodes(
const std::unordered_map<string, Node*>& node_name_map,
std::unordered_set<const Node*>* nodes);
// The input graph with backedges removed. The removed backedges are stored
// in the back_edge_helper.
@ -339,6 +353,10 @@ class ImporterBase {
NodeValueMap node_values_;
std::unique_ptr<ShapeRefiner> shape_refiner_;
NameUniquifier* function_name_uniquifier_;
protected:
// Maps feed as TensorId to new Placeholder node name.
absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
};
// Returns true if the node with given name has a non primary output that is
@ -419,6 +437,49 @@ Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
return Status::OK();
}
// Mapping from node name to feed (index and ArrayInfo). Node name must outlive
// this map.
using FeedsByNode = absl::flat_hash_map<
absl::string_view,
absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;
// Creates, from a `GraphImportConfig::InputArrays`, a mapping from a feed's
// output tensor name to its index and ArrayInfo. Keys and values are backed
// by `GraphImportConfig::InputArrays`.
StatusOr<FeedsByNode> GetFeedsByNode(
const GraphImportConfig::InputArrays& inputs) {
FeedsByNode feeds_by_node;
feeds_by_node.reserve(inputs.size());
for (const auto& input : inputs) {
TensorId tensor = ParseTensorName(input.first);
if (tensor.index() < 0)
return errors::FailedPrecondition(
"Feed output tensor must be a data output '", tensor.ToString(), "'");
auto& node = feeds_by_node[tensor.node()];
if (!node.insert({tensor.index(), &input}).second)
return errors::FailedPrecondition(
"Multiple feeds for the same output tensor '", tensor.ToString(),
"'");
}
return feeds_by_node;
}
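// Illustrative usage sketch (hypothetical feed names, not taken from this
// change):
//
//   GraphImportConfig::InputArrays inputs;  // keyed by "x:0", "x:1", "y"
//   TF_ASSIGN_OR_RETURN(FeedsByNode feeds_by_node, GetFeedsByNode(inputs));
//   // feeds_by_node["x"] -> {0 -> &inputs["x:0"], 1 -> &inputs["x:1"]}
//   // feeds_by_node["y"] -> {0 -> &inputs["y"]}; ParseTensorName("y") is
//   // treated as output index 0, while a control input such as "^y" parses
//   // to a negative index and is rejected with FailedPrecondition above.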
// Creates a unique name for a node that will replace a feed output tensor.
std::string GetUniqueNodeName(
absl::string_view node_name, int index,
const std::unordered_map<string, Node*>& node_name_map) {
std::string new_node_name_base = absl::StrCat(node_name, "_", index);
int count = 0;
std::string new_node_name = new_node_name_base;
while (node_name_map.find(new_node_name) != node_name_map.end()) {
new_node_name = absl::StrCat(new_node_name_base, "_", count++);
}
return new_node_name;
}
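// Naming example (a hedged sketch with hypothetical node names): a feed on
// output tensor "foo:1" of a multi-output node gets a Placeholder named
// "foo_1"; if "foo_1" is already present in node_name_map, the candidates
// "foo_1_0", "foo_1_1", ... are tried until a free name is found. Feeds on
// single-output nodes reuse the original node name (see
// CreatePlaceholderNodeForFeed below), so no uniquification is needed there.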
Status ImporterBase::RemoveBackedges(const Graph& graph) {
// TODO(fengliuai): Converting to GraphDef and back is the easiest way to
// clone a graph.
@ -459,37 +520,54 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) {
return Status::OK();
}
StatusOr<Node*> ImporterBase::ReplaceWithPlaceholderNode(
const TensorShapeProto& shape, DataType dtype, Node* input_node) {
StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
const TensorShapeProto& shape, DataType dtype, Node* node, int index,
const std::unordered_map<string, Node*>& node_name_map) {
DCHECK_LT(index, node->num_outputs());
const bool update_inplace = node->num_outputs() == 1 && index == 0;
std::string new_node_name =
update_inplace ? node->name()
: GetUniqueNodeName(node->name(), index, node_name_map);
Node* placeholder_node;
NodeBuilder builder(input_node->name(), "Placeholder");
NodeBuilder builder(new_node_name, "Placeholder");
builder.Attr("shape", shape);
builder.Attr("dtype", dtype);
TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));
while (!input_node->out_edges().empty()) {
const Edge* oe = *input_node->out_edges().begin();
// UpdateEdge cannot be used with control edges.
if (oe->src_output() == Graph::kControlSlot) {
graph_->AddControlEdge(placeholder_node, oe->dst());
graph_->RemoveControlEdge(oe);
continue;
// Update edges from original feed with Placeholder node.
std::vector<const Edge*> data_edges;
std::vector<const Edge*> control_edges;
for (const tensorflow::Edge* edge : node->out_edges()) {
if (edge->src_output() == index) {
data_edges.push_back(edge);
} else if (update_inplace && edge->IsControlEdge()) {
control_edges.push_back(edge);
}
TF_RETURN_IF_ERROR(
graph_->UpdateEdge(placeholder_node, 0, oe->dst(), oe->dst_input()));
}
graph_->RemoveNode(input_node);
for (const auto* edge : data_edges) {
TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
edge->dst_input()));
}
return placeholder_node;
for (const auto* edge : control_edges) {
graph_->AddControlEdge(placeholder_node, edge->dst());
graph_->RemoveControlEdge(edge);
}
if (update_inplace) {
graph_->RemoveNode(node);
}
return std::pair<Node*, bool>(placeholder_node, update_inplace);
}
Status ImporterBase::GetInputOutputNodes(
const std::unordered_map<string, Node*>& node_name_map,
std::unordered_set<const Node*>* nodes) {
auto node_name_map = graph_->BuildNodeNameIndex();
auto add_node = [&](const string& name) {
auto it = node_name_map.find(name);
auto add_node = [&](absl::string_view name) {
auto it = node_name_map.find(std::string(name));
if (it == node_name_map.end()) {
return errors::FailedPrecondition(
absl::StrCat("Graph does not contain node: ", name));
@ -498,13 +576,25 @@ Status ImporterBase::GetInputOutputNodes(
return Status::OK();
};
// Remap feeds and fetches to newly created Placeholder nodes.
for (const auto& input : specs_.inputs) {
TF_RETURN_IF_ERROR(add_node(input.first));
TensorId tensor = ParseTensorName(input.first);
auto remapped_it = remapped_feeds_.find(tensor);
if (remapped_it != remapped_feeds_.end()) {
TF_RETURN_IF_ERROR(add_node(remapped_it->second));
} else {
TF_RETURN_IF_ERROR(add_node(tensor.node()));
}
}
for (const auto& output : specs_.outputs) {
auto output_node_name = std::string(ParseTensorName(output).first);
TF_RETURN_IF_ERROR(add_node(output_node_name));
TensorId tensor = ParseTensorName(output);
auto remapped_it = remapped_feeds_.find(tensor);
if (remapped_it != remapped_feeds_.end()) {
TF_RETURN_IF_ERROR(add_node(remapped_it->second));
} else {
TF_RETURN_IF_ERROR(add_node(tensor.node()));
}
}
return Status::OK();
@ -520,6 +610,9 @@ Status ImporterBase::AddNodesToShapeRefiner() {
shape_refiner_->set_require_shape_inference_fns(false);
shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);
TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
auto node_name_map = graph_->BuildNodeNameIndex();
// First add all nodes to the refiner.
for (Node* node : ordered_nodes_) {
// We need to use a TensorFlow node to teach the shape refiner that user
@ -533,28 +626,49 @@ Status ImporterBase::AddNodesToShapeRefiner() {
// it to replace the original input node, so the shape refiner can
// successfully propagate the user's input type and shape to the rest of the
// graph.
auto it = specs_.inputs.find(node->name());
if (it != specs_.inputs.end()) {
auto node_name = node->op_def().name();
if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
node_name != FunctionLibraryDefinition::kArgOp) {
// We do not handle the case where the input node has multiple outputs
if (node->num_outputs() > 1) {
return errors::FailedPrecondition(absl::StrCat(
"Input arrays can only have op with single output. Node op:",
node_name));
bool node_added_to_shape_refiner = false;
auto it = feeds_by_node.find(node->name());
if (it != feeds_by_node.end()) {
auto op_name = node->op_def().name();
if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
op_name != FunctionLibraryDefinition::kArgOp) {
for (const auto& output_tensor : it->second) {
const int index = output_tensor.first;
const ArrayInfo& array_info = output_tensor.second->second;
DataType dtype = array_info.imported_dtype;
// Uses the existing output type if it isn't specified by the user.
if (dtype == DT_INVALID) {
dtype = node->output_type(0);
}
TF_ASSIGN_OR_RETURN(
auto placeholder_node_and_removed,
CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
node_name_map));
Node* placeholder_node = placeholder_node_and_removed.first;
if (placeholder_node_and_removed.second) {
// Original node has been removed from the graph.
node = placeholder_node;
node_added_to_shape_refiner = true;
}
remapped_feeds_[{it->first, index}] = placeholder_node->name();
node_name_map[placeholder_node->name()] = placeholder_node;
// Add the new placeholder node to the shape refiner.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
shape_refiner_->AddNode(placeholder_node),
GetLocationStr(*placeholder_node));
}
// For single output nodes, replace them with Placeholder node.
DataType dtype = it->second.imported_dtype;
// Uses the existing output type if it isn't specified by the user.
if (dtype == DT_INVALID) {
dtype = node->output_type(0);
}
TF_ASSIGN_OR_RETURN(
node, ReplaceWithPlaceholderNode(it->second.shape, dtype, node));
} else {
node->AddAttr("shape", it->second.shape);
DataType dtype = it->second.imported_dtype;
auto index_it = it->second.find(0);
if (index_it == it->second.end()) {
return errors::FailedPrecondition(
"Missing feed output tensor at index 0 for node '", node->name(),
"'");
}
node->AddAttr("shape", index_it->second->second.shape);
DataType dtype = index_it->second->second.imported_dtype;
// Uses the existing output type if it isn't specified by the user.
if (dtype == DT_INVALID) {
dtype = node->output_type(0);
@ -562,9 +676,11 @@ Status ImporterBase::AddNodesToShapeRefiner() {
node->AddAttr("dtype", dtype);
}
}
// Adds the node to the shape refiner.
TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node),
GetLocationStr(*node));
if (!node_added_to_shape_refiner) {
// Add the node to the shape refiner if the node hasn't been removed.
TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node),
GetLocationStr(*node));
}
auto set_shape_from_list_attr = [&](const AttrValue* attr) {
auto& list = attr->list();
@ -625,7 +741,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
// Prune nodes in the graph that are not reachable from the output.
if (specs_.prune_unused_nodes) {
std::unordered_set<const Node*> prune_start;
TF_RETURN_IF_ERROR(GetInputOutputNodes(&prune_start));
TF_RETURN_IF_ERROR(GetInputOutputNodes(node_name_map, &prune_start));
if (!prune_start.empty()) {
if (PruneForReverseReachability(graph_.get(), prune_start)) {
VLOG(1) << "Pruned unused nodes in graphdef";
@ -829,9 +945,11 @@ StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
return builder_.getFloatAttr(builder_.getF32Type(), value.f());
case AttrValue::kB:
return builder_.getBoolAttr(value.b());
case AttrValue::kType:
return builder_.getStringAttr(
mangling_util::MangleDataType(value.type()));
case AttrValue::kType: {
mlir::Type type;
TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type));
return mlir::TypeAttr::get(type);
}
case AttrValue::kShape:
return builder_.getStringAttr(mangling_util::MangleShape(value.shape()));
case AttrValue::kTensor:
@ -1106,11 +1224,9 @@ Status ImporterBase::ConvertFunctionArgAndRets(
builder_.setInsertionPointToEnd(&graph_op.body().front());
builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
inst_to_return);
inst_to_return.assign(graph_op.getResults().begin(),
graph_op.getResults().end());
builder_.setInsertionPointToEnd(bb);
builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
inst_to_return);
graph_op.getResults());
return Status::OK();
}
@ -1210,9 +1326,10 @@ std::string ImporterBase::GetLocationStr(const Node& node,
}
mlir::Operation* ImporterBase::createOperation(
const Node& node, llvm::StringRef op_name,
const Node& node, llvm::StringRef node_type_name,
const mlir::OperationState& result,
const llvm::SmallVectorImpl<mlir::Value*>& control_operands) {
const llvm::SmallVectorImpl<mlir::Value*>& control_operands,
bool convert_to_legacy_call) {
// For the tf.executor specific operations (not wrapped in an island), we
// have an extra returned value for the control result, and we concatenate
// control and non-control operands.
@ -1274,11 +1391,31 @@ mlir::Operation* ImporterBase::createOperation(
mlir::OpBuilder island_builder(&island.GetBody());
// Create the operation inside the island now.
mlir::Operation* inner_op = island_builder.createOperation(result);
mlir::Operation* inner_op;
if (convert_to_legacy_call) {
bool disable_call_shape_inference = false;
for (const auto& name_and_value : node.attrs()) {
const auto& attr_name = name_and_value.first;
const AttrValue& attr_value = name_and_value.second;
if (strcmp(attr_name.c_str(),
disable_call_shape_inference_attribute_name) == 0 &&
attr_value.value_case() == AttrValue::kB) {
disable_call_shape_inference = attr_value.b();
}
}
mlir::BoolAttr attribute =
builder_.getBoolAttr(disable_call_shape_inference);
inner_op = island_builder.create<mlir::TF::LegacyCallOp>(
result.location, result.types, result.operands,
island_builder.getSymbolRefAttr(node_type_name), attribute);
} else {
inner_op = island_builder.createOperation(result);
}
// Add the terminator for the island
mlir::SmallVector<mlir::Value*, 8> ret_vals(inner_op->getResults());
island_builder.create<mlir::tf_executor::YieldOp>(result.location, ret_vals);
island_builder.create<mlir::tf_executor::YieldOp>(result.location,
inner_op->getResults());
return island.getOperation();
}
@ -1293,9 +1430,11 @@ Status ImporterBase::ConvertNode(const Node& node) {
// create the MLIR function and insert it to the module if it doesn't exist.
std::string node_type_name = node.type_string();
const auto* func_def = graph_flib_.Find(node_type_name);
bool convert_to_legacy_call = false;
if (func_def) {
TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
convert_to_legacy_call = true;
}
auto get_full_op_name = [&](const std::string& op_name) {
@ -1380,6 +1519,14 @@ Status ImporterBase::ConvertNode(const Node& node) {
for (const auto& name_and_value : node.attrs()) {
const auto& attr_name = name_and_value.first;
const AttrValue& attr_value = name_and_value.second;
// LegacyCall can only represent the _disable_call_shape_inference attribute.
// If a call has any other attribute, it can't be converted to LegacyCall.
if (convert_to_legacy_call &&
(strcmp(attr_name.c_str(),
disable_call_shape_inference_attribute_name) ||
attr_value.value_case() != AttrValue::kB)) {
convert_to_legacy_call = false;
}
if (attr_value.value_case() == AttrValue::kFunc) {
// Attribute iteration order is not defined for protocol buffer Map.
// Process function attributes separately in the lexicographical order to
@ -1423,9 +1570,8 @@ Status ImporterBase::ConvertNode(const Node& node) {
}
// Register the mapping between the TF node and the newly created operation.
node_values_[node.id()] =
createOperation(node, op_name, result, control_operands);
node_values_[node.id()] = createOperation(
node, node_type_name, result, control_operands, convert_to_legacy_call);
return Status::OK();
}
@ -1667,36 +1813,52 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
const GraphImportConfig& specs, mlir::MLIRContext* context,
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
// Finds out all the input nodes and output nodes.
absl::flat_hash_set<absl::string_view> output_node_names;
for (const auto& output_tensor : specs.outputs) {
output_node_names.insert(ParseTensorName(output_tensor).node());
// Find all the input nodes and output nodes.
// Feeds have been remapped to single output nodes (Placeholder), so an exact
// name match is sufficient.
absl::flat_hash_map<absl::string_view, int> inputs;
for (auto input_and_idx : llvm::enumerate(specs.inputs)) {
TensorId tensor = ParseTensorName(input_and_idx.value().first);
auto remapped_it = remapped_feeds_.find(tensor);
if (remapped_it != remapped_feeds_.end()) {
inputs.insert({remapped_it->second, input_and_idx.index()});
} else {
inputs.insert({tensor.node(), input_and_idx.index()});
}
}
if (!specs.inputs.empty() || !specs.outputs.empty()) {
arg_nodes->resize(specs.inputs.size());
ret_nodes->resize(specs.outputs.size());
absl::flat_hash_set<absl::string_view> output_node_names;
std::vector<TensorId> outputs;
output_node_names.reserve(specs.outputs.size());
for (const auto& output : specs.outputs) {
TensorId tensor = ParseTensorName(output);
auto remapped_it = remapped_feeds_.find(tensor);
if (remapped_it != remapped_feeds_.end()) {
output_node_names.insert(remapped_it->second);
outputs.push_back({remapped_it->second, 0});
} else {
output_node_names.insert(tensor.node());
outputs.push_back(tensor);
}
}
if (!inputs.empty() || !outputs.empty()) {
arg_nodes->resize(inputs.size());
ret_nodes->resize(outputs.size());
for (Node* n : GetOrderedNodes()) {
// Handle inputs/arguments.
auto input_it = specs.inputs.find(n->name());
if (input_it != specs.inputs.end()) {
(*arg_nodes)[std::distance(specs.inputs.begin(), input_it)] = {n, 0};
auto input_it = inputs.find(n->name());
if (input_it != inputs.end()) {
(*arg_nodes)[input_it->second] = {n, 0};
}
// Handle outputs/returns.
if (output_node_names.contains(n->name())) {
for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
std::pair<std::string, std::string> name_and_port =
absl::StrSplit(specs.outputs[i], ':');
auto name = name_and_port.first;
if (name != n->name()) continue;
int port = 0;
if (!name_and_port.second.empty() &&
!absl::SimpleAtoi(name_and_port.second, &port)) {
return errors::InvalidArgument("Invalid port specification: ",
specs.outputs[i]);
}
(*ret_nodes)[i] = {n, port};
for (int i = 0, e = outputs.size(); i != e; ++i) {
TensorId tensor = outputs[i];
if (n->name() != tensor.node()) continue;
(*ret_nodes)[i] = {n, tensor.index()};
}
}
}
@ -2118,7 +2280,11 @@ class StructuredValueLinearizer {
// Returns the list of index paths to each leaf of the StructuredValue,
// in a linearized order matching `tf.nest.flatten`.
llvm::ArrayRef<mlir::ArrayAttr> GetLeafIndexPaths() const;
//
// If an error occurred during the linearization process, an error message
// with `error_context` prepended will be included in the returned status.
StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
llvm::StringRef error_context) const;
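// Illustrative example (an assumption about a typical structure, not taken
// from this change): for a StructuredValue equivalent to the Python value
// {"x": (a, b)}, tf.nest.flatten produces two leaves, so GetLeafIndexPaths
// would return two ArrayAttrs, roughly ["x", 0] and ["x", 1], where each
// element is a StringAttr (dictionary key) or an IntegerAttr (tuple/list
// position).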
private:
// Main function that recursively traverses the StructuredValue.
@ -2130,6 +2296,8 @@ class StructuredValueLinearizer {
llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
// The list of leaf index paths we have discovered so far.
llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
// If non-empty, an error message to report.
std::string error_message_;
};
StructuredValueLinearizer::StructuredValueLinearizer(
@ -2138,9 +2306,19 @@ StructuredValueLinearizer::StructuredValueLinearizer(
RecursivelyFindLeaves(value);
}
llvm::ArrayRef<mlir::ArrayAttr> StructuredValueLinearizer::GetLeafIndexPaths()
const {
return leaf_index_paths_;
StatusOr<llvm::ArrayRef<mlir::ArrayAttr>>
StructuredValueLinearizer::GetLeafIndexPaths(
llvm::StringRef error_context) const {
if (error_message_.empty()) {
return llvm::makeArrayRef(leaf_index_paths_);
}
return errors::InvalidArgument(
error_context.str(), error_message_,
"This likely means that you have @tf.function "
"on an exported function instead of "
"@tf.function(input_signature=[...]). Consider annotating an "
"input_signature or narrowing your set of "
"exported names to not include this function.");
}
void StructuredValueLinearizer::RecursivelyFindLeaves(
@ -2196,7 +2374,20 @@ void StructuredValueLinearizer::RecursivelyFindLeaves(
return;
}
default: {
llvm_unreachable("Unhandled StructuredValue kind!");
llvm::raw_string_ostream os(error_message_);
// TODO(silvasean): Use an enumerant name string instead of a number.
os << "Unhandled structured value kind " << value.kind_case()
<< " at index path: <value>";
for (auto path_element : current_index_path_) {
os << ".";
if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
os << integer.getValue();
} else {
auto str = path_element.cast<mlir::StringAttr>();
os << str.getValue();
}
}
os << "\n";
}
}
}
@ -2290,6 +2481,9 @@ Status CreateSavedModelIR(
if (object_names.GetExportedNames(node_id).empty()) {
continue;
}
std::string error_context =
"While importing SavedModel function '" +
object_names.GetExportedNames(node_id)[0].str() + "': ";
const SavedFunction& function = object.function();
auto orig_func = symbol_table.lookup<mlir::FuncOp>(
tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
@ -2314,8 +2508,7 @@ Status CreateSavedModelIR(
/*config=*/builder.getStringAttr(""),
/*config_proto=*/builder.getStringAttr(""),
/*executor_type=*/builder.getStringAttr(""));
body_builder.create<mlir::ReturnOp>(
func.getLoc(), llvm::to_vector<4>(call.getResults()));
body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
}
func.setAttr(
"tf_saved_model.exported_names",
@ -2338,9 +2531,12 @@ Status CreateSavedModelIR(
int bound_input_base =
func.getNumArguments() - concrete_function.bound_inputs_size();
auto input_index_paths = input_linearizer.GetLeafIndexPaths();
TF_ASSIGN_OR_RETURN(auto input_index_paths,
input_linearizer.GetLeafIndexPaths(
error_context + "in input signature: "));
if (bound_input_base != input_index_paths.size()) {
return errors::InvalidArgument(
error_context,
"Argument mismatch between concrete function input signature "
"vs underlying FunctionDef for concrete function '",
function.concrete_functions(0), "' (", input_index_paths.size(),
@ -2361,9 +2557,12 @@ Status CreateSavedModelIR(
StructuredValueLinearizer output_linearizer(
concrete_function.output_signature(), builder.getContext());
auto output_index_paths = output_linearizer.GetLeafIndexPaths();
TF_ASSIGN_OR_RETURN(auto output_index_paths,
output_linearizer.GetLeafIndexPaths(
error_context + "in output signature: "));
if (func.getNumResults() != output_index_paths.size()) {
return errors::InvalidArgument(
error_context,
"Result mismatch between concrete function output signature "
"vs underlying FunctionDef for concrete function '",
function.concrete_functions(0), "' (", output_index_paths.size(),

View File

@ -67,8 +67,6 @@ void FunctionalToExecutorDialectConversion::runOnFunction() {
LLVM_DEBUG(llvm::dbgs() << "Expect function to end with return\n");
return;
}
llvm::SmallVector<Value*, 4> args =
llvm::to_vector<4>(return_op.getOperands());
// Build GraphOp.
OpBuilder builder(&body, body.begin());
auto graph_op = builder.create<tf_executor::GraphOp>(
@ -79,10 +77,10 @@ void FunctionalToExecutorDialectConversion::runOnFunction() {
loc, getFunction().getType().getResults(),
tf_executor::ControlType::get(&getContext()), ArrayRef<Value*>());
// Create Fetch.
auto to_fetch = llvm::to_vector<4>(island.getResults());
ValueRange to_fetch = island.getResults();
if (to_fetch.size() != 1) {
// Drop control result for fetch.
to_fetch.pop_back();
to_fetch = to_fetch.drop_back();
}
builder.create<tf_executor::FetchOp>(loc, to_fetch);
// Build Island.
@ -91,7 +89,7 @@ void FunctionalToExecutorDialectConversion::runOnFunction() {
island.body().front().begin(), body.getOperations(), copy_range.begin(),
copy_range.end());
builder.setInsertionPointToEnd(&island.body().front());
builder.create<tf_executor::YieldOp>(loc, args);
builder.create<tf_executor::YieldOp>(loc, return_op.getOperands());
for (auto item : llvm::enumerate(graph_op.getResults())) {
return_op.setOperand(item.index(), item.value());
}

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
@ -58,7 +58,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
// Converts arg_shapes to xla::Shapes and stores them in xla_input_shapes.
Status GetXlaInputShapes(
mlir::ModuleOp module, absl::Span<TensorShape> arg_shapes,
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
const xla::CustomShapeRepresentationFn shape_representation_fn,
std::vector<xla::Shape>* xla_input_shapes) {
xla_input_shapes->clear();
@ -150,7 +150,8 @@ void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
}
// Refine MLIR types based on new shape information.
Status RefineShapes(absl::Span<TensorShape> arg_shapes, mlir::ModuleOp module) {
Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes,
mlir::ModuleOp module) {
auto versions = module.getAttrOfType<::mlir::DictionaryAttr>("tf.versions");
if (!versions) {
return errors::Internal(
@ -234,7 +235,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
}
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, absl::Span<TensorShape> arg_shapes,
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result) {
mlir::MLIRContext mlir_context;

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -40,7 +40,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
// metadata and stores them in CompilationResult.
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, absl::Span<TensorShape> arg_shapes,
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result);
} // namespace tensorflow

View File

@ -41,9 +41,9 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
std::vector<TensorShape> arg_shapes;
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
invalid_mlir_module, absl::Span<TensorShape>(arg_shapes),
TestShapeRepresentation, &compilation_result);
Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes,
TestShapeRepresentation,
&compilation_result);
EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
}
@ -61,8 +61,7 @@ TEST(CompileSerializedMlirToXlaHloTest, Success) {
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, absl::Span<TensorShape>(arg_shapes), TestShapeRepresentation,
&compilation_result);
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
ASSERT_TRUE(s.ok());
const xla::HloModuleConfig module_config(
@ -134,8 +133,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, absl::Span<TensorShape>(arg_shapes), TestShapeRepresentation,
&compilation_result);
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
ASSERT_TRUE(s.ok());
const xla::HloModuleConfig module_config(
@ -174,8 +172,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, absl::Span<TensorShape>(arg_shapes), TestShapeRepresentation,
&compilation_result);
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
TF_ASSERT_OK(s);
const xla::HloModuleConfig module_config(

View File

@ -22,13 +22,11 @@ limitations under the License.
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
// Error utilities for MLIR when interacting with code using Status returns.
namespace mlir {
// TensorFlow's Status is used for error reporting back to callers.
using stream_executor::port::StatusOr;
using tensorflow::Status;
// Diagnostic handler that collects all the diagnostics reported and can produce

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
@ -253,21 +254,30 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
// Note: we do not use NodeBuilder or NodeDefBuilder as that would require
// mapping back from the inputs to the input arguments.
// Some control flow ops in TensorFlow Graph have their respective "Ref" ops
// as well. For example there is Enter and RefEnter op. RefEnter forwards
// the input ref buffer to output. However both Enter and RefEnter are
// mapped to tf_executor::EnterOp during import and then to _tf.Enter op in
// control dialect. Check if it is a Ref op to correctly map to the TensorFlow
// Graph op.
llvm::SmallString<64> op_name;
if (IsRefTypeControlOp(inst)) op_name = "Ref";
TF_ASSIGN_OR_RETURN(auto tf_name,
GetTensorFlowOpName(inst->getName().getStringRef()));
op_name.append(tf_name);
if (IsLegacyCallInstruction(inst)) {
// The op_name is the name of the function.
op_name.append(
inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
// Remove the attribute from the instruction as it is already converted to
// op_name.
auto attr_id = mlir::Identifier::get("f", inst->getContext());
inst->removeAttr(attr_id);
} else {
// Some control flow ops in TensorFlow Graph have their respective "Ref" ops
// as well. For example there is Enter and RefEnter op. RefEnter forwards
// the input ref buffer to output. However both Enter and RefEnter are
// mapped to tf_executor::EnterOp during import and then to _tf.Enter op in
// control dialect. Check if it is a Ref op to correctly map to the
// TensorFlow Graph op.
if (IsRefTypeControlOp(inst)) op_name = "Ref";
TF_ASSIGN_OR_RETURN(auto tf_name,
GetTensorFlowOpName(inst->getName().getStringRef()));
op_name.append(tf_name);
}
node_def->set_name(name.str());
node_def->set_op(op_name.str());
node_def->set_name(name);
// Add inputs to the NodeDef based on the number of operands. This is required
// as later when edges are added to the Node using Graph::AddEdge the
@ -454,4 +464,9 @@ Status SetSizeAttribute(absl::string_view name, size_t size,
return Status::OK();
}
bool IsLegacyCallInstruction(mlir::Operation* inst) {
return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst) ||
inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
}
} // namespace tensorflow

View File

@ -73,5 +73,16 @@ Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shape,
// If the attribute already exists with a different value, returns an error.
Status SetSizeAttribute(absl::string_view name, size_t size,
AttrValueMap* values);
// Returns true if the given instruction is an mlir::TF::LegacyCallOp or the
// result of such an operation transformed by the
// ExecutorToControlDialectConversion pass.
//
// TODO(b/145706023): When the ExecutorToControlDialectConversion pass runs
// before the exporter, it mutates an mlir::TF::LegacyCallOp instruction to
// an instruction with a different operation name. As such, this routine checks
// both forms of a LegacyCall instruction. Once the ticket is resolved, only
// mlir::TF::LegacyCallOp will need to be checked.
bool IsLegacyCallInstruction(mlir::Operation* inst);
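// Hedged usage sketch, mirroring the exporter logic shown earlier in this
// change rather than prescribing an API: callers typically branch on this
// predicate when deriving the NodeDef op name.
//
//   llvm::SmallString<64> op_name;
//   if (IsLegacyCallInstruction(inst)) {
//     // The callee symbol held in the "f" attribute becomes the op name.
//     op_name.append(
//         inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
//   } else {
//     TF_ASSIGN_OR_RETURN(auto tf_name,
//                         GetTensorFlowOpName(inst->getName().getStringRef()));
//     op_name.append(tf_name);
//   }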
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EXPORTER_UTILS_H_

View File

@ -23,6 +23,8 @@ package_group(
],
)
exports_files(["ir/hlo_ops.td"])
filegroup(
name = "hlo_ops_td_files",
srcs = [
@ -406,6 +408,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",

View File

@ -182,13 +182,12 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
// Setup the return type (HLO only supports a single return value).
TF_ASSIGN_OR_RETURN(auto result,
GetMlirValue(computation->root_instruction()));
llvm::SmallVector<Value*, 1> return_values({result});
// Create terminator op depending on the parent op of this region.
if (llvm::isa<FuncOp>(block->getParentOp())) {
builder.create<mlir::ReturnOp>(loc, makeArrayRef(return_values));
builder.create<mlir::ReturnOp>(loc, result);
} else {
builder.create<mlir::xla_hlo::ReturnOp>(loc, makeArrayRef(return_values));
builder.create<mlir::xla_hlo::ReturnOp>(loc, result);
}
return tensorflow::Status::OK();
}
@ -266,32 +265,20 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
MakeAndReturn(CompareOp);
}
case HloOpcode::kGather: {
const auto& gather_dimensions = instruction->gather_dimension_numbers();
std::vector<int64_t> offset_dims(gather_dimensions.offset_dims().begin(),
gather_dimensions.offset_dims().end());
auto gather_instruction = static_cast<HloGatherInstruction*>(instruction);
attributes.push_back(ConvertGatherDimensionNumbers(
gather_instruction->gather_dimension_numbers()));
std::vector<int64_t> slice_sizes(
instruction->gather_slice_sizes().begin(),
instruction->gather_slice_sizes().end());
gather_instruction->gather_slice_sizes().begin(),
gather_instruction->gather_slice_sizes().end());
attributes.push_back(
builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(gather_instruction->indices_are_sorted())));
std::vector<int64_t> collapsed_slice_dims(
gather_dimensions.collapsed_slice_dims().begin(),
gather_dimensions.collapsed_slice_dims().end());
std::vector<int64_t> start_index_map(
gather_dimensions.start_index_map().begin(),
gather_dimensions.start_index_map().end());
// TODO(b/132057942): Change to explicitly passing an integer instead of
// call getI64IntegerAttr here.
return func_builder
->create<mlir::xla_hlo::GatherOp>(
loc, result_type, operands[0], operands[1],
func_builder->getI64IntegerAttr(
gather_dimensions.index_vector_dim()),
Convert(offset_dims), Convert(slice_sizes),
Convert(collapsed_slice_dims), Convert(start_index_map))
.getOperation();
MakeAndReturn(GatherOp);
}
case HloOpcode::kDynamicUpdateSlice: {
return func_builder
@ -707,4 +694,19 @@ mlir::NamedAttribute HloFunctionImporter::ConvertConvDimensionNumbers(
return builder_->getNamedAttr("dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums) {
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
std::vector<int64_t> collapsed_slice_dims(
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
auto attr = mlir::xla_hlo::GatherDimensionNumbers::get(
Convert(offset_dims), Convert(collapsed_slice_dims),
Convert(start_index_map),
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
return builder_->getNamedAttr("dimension_numbers", attr);
}
} // namespace xla

View File

@ -117,6 +117,10 @@ class HloFunctionImporter {
mlir::NamedAttribute ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums);
// Converts the gather dimensions to attributes.
mlir::NamedAttribute ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums);
mlir::MLIRContext* context_;
mlir::ModuleOp module_;
mlir::Builder* builder_;

View File

@ -606,7 +606,7 @@ static TensorType GetReduceResultType(Type operand_ty,
}
void ReduceOp::build(Builder* builder, OperationState& state,
ArrayRef<Value*> operands, ArrayRef<Value*> init_values,
ValueRange operands, ValueRange init_values,
DenseIntElementsAttr dimensions) {
SmallVector<Type, 1> result_ty;
result_ty.reserve(operands.size());
@ -845,9 +845,8 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
// SortOp
//===----------------------------------------------------------------------===//
void SortOp::build(Builder* builder, OperationState& state,
ArrayRef<Value*> operands, int64_t dimension,
bool is_stable) {
void SortOp::build(Builder* builder, OperationState& state, ValueRange operands,
int64_t dimension, bool is_stable) {
state.addOperands(operands);
state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder->getBoolAttr(dimension));
@ -990,7 +989,7 @@ void GetTupleElementOp::build(Builder* builder, OperationState& result,
//===----------------------------------------------------------------------===//
void TupleOp::build(Builder* builder, OperationState& result,
ArrayRef<Value*> values) {
ValueRange values) {
SmallVector<Type, 4> types;
types.reserve(values.size());
for (auto val : values) {

View File

@ -405,8 +405,8 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
let results = (outs Variadic<HLO_TensorOrTuple>);
let builders = [OpBuilder<
"Builder *, OperationState &state, ArrayRef<Value *> operands, "
"ArrayRef<Value *> init_values, DenseIntElementsAttr dimensions"
"Builder *, OperationState &state, ValueRange operands, "
"ValueRange init_values, DenseIntElementsAttr dimensions"
>];
let hasFolder = 1;
@ -445,7 +445,7 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
let builders = [OpBuilder<
"Builder *builder, OperationState &results, "
"ArrayRef<Value*> values">];
"ValueRange values">];
// TupleOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
@ -777,21 +777,25 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
let hasCustomHLOConverter = 1;
}
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
StructFieldAttr<"start_index_map", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for gather";
}
def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
let arguments = (ins
HLO_Tensor:$operand,
HLO_IntTensor:$start_indices,
I64Attr:$index_vector_dim,
I64ElementsAttr:$offset_dims,
GatherDimensionNumbers:$dimension_numbers,
I64ElementsAttr:$slice_sizes,
I64ElementsAttr:$collapsed_slice_dims,
I64ElementsAttr:$start_index_map
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
);
let results = (outs HLO_Tensor);
// TODO(b/129422361) Attributes are not supported by the codegen. The
// optional argument (dimensions) needs to be added as an attribute.
let hasCustomHLOConverter = 1;
}
@ -880,7 +884,7 @@ def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
let regions = (region SizedRegion<1>:$comparator);
let builders = [OpBuilder<
"Builder *builder, OperationState &state, ArrayRef<Value *> operands, "
"Builder *builder, OperationState &state, ValueRange operands, "
"int64_t dimension, bool is_stable"
>];

View File

@ -40,7 +40,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using ::stream_executor::port::StatusOr;
using ::tensorflow::int16;
using ::tensorflow::int32;
using ::tensorflow::int64;
@ -149,6 +151,7 @@ I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(strides);
I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
#undef I64_ELEMENTS_ATTR_TO_VECTOR
@ -267,6 +270,30 @@ static xla::ComparisonDirection Convert_comparison_direction(
.ValueOrDie();
}
static xla::GatherDimensionNumbers Convert_gather_dimension_numbers(
mlir::xla_hlo::GatherDimensionNumbers input) {
xla::GatherDimensionNumbers output;
auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
std::copy(offset_dims.begin(), offset_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_offset_dims()));
auto collapsed_slice_dims = ConvertDenseIntAttr(input.collapsed_slice_dims());
std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_collapsed_slice_dims()));
auto start_index_map = ConvertDenseIntAttr(input.start_index_map());
std::copy(start_index_map.begin(), start_index_map.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_start_index_map()));
output.set_index_vector_dim(
ConvertAPInt(input.index_vector_dim().getValue()));
return output;
}
static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
mlir::xla_hlo::ScatterDimensionNumbers input) {
xla::ScatterDimensionNumbers output;
@ -496,7 +523,13 @@ LogicalResult ExportXlaOp(DynamicUpdateSliceOp op, OpLoweringContext ctx) {
LogicalResult ExportXlaOp(FftOp op, OpLoweringContext ctx) { return failure(); }
LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
return failure();
auto& value_map = *ctx.values;
xla::GatherDimensionNumbers dimension_numbers =
Convert_gather_dimension_numbers(op.dimension_numbers());
value_map[op] = xla::Gather(
value_map[op.operand()], value_map[op.start_indices()], dimension_numbers,
Convert_slice_sizes(op.slice_sizes()), op.indices_are_sorted());
return success();
}
LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {

View File

@ -25,18 +25,25 @@ func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>,
// CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>}
// CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>}
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
// CHECK-LABEL: func @biasAdd_NCHW
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
// CHECK-LABEL: func @biasAdd_dynamic
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
// CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
return %0 : tensor<?x?x?x?xi32>
}
//===----------------------------------------------------------------------===//
// Binary op legalizations.
//===----------------------------------------------------------------------===//
@ -666,11 +673,18 @@ func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
// CHECK-LABEL: @const
func @const() -> tensor<2xi32> {
// CHECK-NEXT: xla_hlo.constant dense<0> : tensor<2xi32>
// CHECK: xla_hlo.constant dense<0> : tensor<2xi32>
%0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>)
return %0: tensor<2xi32>
}
// CHECK-LABEL: @const_dynamic_output
func @const_dynamic_output() -> tensor<*xi32> {
// CHECK: xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32>
%0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
return %0: tensor<*xi32>
}
// CHECK-LABEL: @opaque_const
func @opaque_const() -> tensor<!tf.variant<tensor<2xi32>>> {
// CHECK-NOT: xla_hlo.constant
@ -838,13 +852,14 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
}
// CHECK-LABEL: func @relu_grad
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<4x8xf32>)
func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<4x8xf32>) -> tensor<4x8xf32> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
// CHECK: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xi1>
// CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<4x8xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
// CHECK: return %[[RESULT]] : tensor<4x8xf32>
%2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
// CHECK-DAG: %[[ZERO_SCALAR:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
// CHECK-DAG: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO_SCALAR]]) {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<*xi1>
// CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<*xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
// CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32>
%2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32>
}
@ -1019,6 +1034,14 @@ func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
return %0 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_3d_int32
func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>)
// CHECK: "xla_hlo.transpose"
%0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}
// CHECK-LABEL: @transpose_3d
func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>)
@ -1344,35 +1367,42 @@ func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: reshape
func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> {
// CHECK: %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32>
// CHECK: "xla_hlo.reshape"
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}
// CHECK-LABEL: reshape_dynamic
func @reshape_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
// CHECK: %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
func @reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> {
// CHECK: "xla_hlo.reshape"
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xi32>) -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}
// CHECK-LABEL: reshape_unranked
func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
// CHECK: "tf.Reshape"
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: squeeze
func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> {
// CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
// CHECK: "xla_hlo.reshape"
%0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
return %0 : tensor<1x10xf32>
}
// CHECK-LABEL: squeeze_dynamic
func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
// CHECK: "tf.Squeeze"
%0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: expand_dims
func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> {
// CHECK: "xla_hlo.reshape"{{.*}} : (tensor<2xf32>) -> tensor<1x2xf32>
// CHECK: "xla_hlo.reshape"
%0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
@ -1380,7 +1410,8 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> {
// CHECK-LABEL: slice_constant_start
func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32>
// CHECK: return %[[RESULT]] : tensor<2xi32>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
@ -1388,10 +1419,22 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
return %0 : tensor<2xi32>
}
// CHECK-LABEL: slice_i32_consts
func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: slice_sizes = dense<2> : tensor<1xi64>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>)
%0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// CHECK-LABEL: slice_constant_start_negative_one_size
func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32>
// CHECK: return %[[RESULT]] : tensor<3xi32>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
@ -1402,7 +1445,8 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi
// CHECK-LABEL: slice_constant_start_dynamic_shape
func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: return %[[RESULT]] : tensor<1x4xi32>
%starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
%sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
@ -1412,7 +1456,8 @@ func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2
// CHECK-LABEL: slice_variable_start
func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: return %[[RESULT]] : tensor<1x4xi32>
%sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
%0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
@ -1525,6 +1570,16 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
return %0 : tensor<4x1xf16>
}
// CHECK-LABEL: func @mean_scalar_dim
func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
// Verify that a tf.Mean op with a scalar reduction dimension is lowered successfully.
// CHECK-NOT: tf.Mean
%dimension = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64>
%0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<i64>) -> tensor<4x1xf16>
return %0 : tensor<4x1xf16>
}
// CHECK-LABEL: func @mean_dynamic
func @mean_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
%dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
@ -1601,6 +1656,66 @@ func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
return %0 : tensor<4x1xf16>
}
// CHECK-LABEL: @all
func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> {
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[INIT:.*]] = xla_hlo.constant dense<true> : tensor<i1>
// CHECK: "xla_hlo.reduce"(%{{.*}}, %[[INIT]]) ( {
// CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
// CHECK: %[[AND:.*]] = xla_hlo.and %[[ARGA]], %[[ARGB]] : tensor<i1>
// CHECK: "xla_hlo.return"(%[[AND]]) : (tensor<i1>) -> ()
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
%0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
// CHECK-LABEL: @all_keep_dim
func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
// CHECK: "xla_hlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
return %0 : tensor<4x1xi1>
}
// CHECK-LABEL: @all_dynamic
func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[ARG:.*]] = "xla_hlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
// CHECK: "xla_hlo.reduce"(%[[ARG]]
%0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
return %0 : tensor<4x1xi1>
}
// CHECK-LABEL: @any
func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> {
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[INIT:.*]] = xla_hlo.constant dense<false> : tensor<i1>
// CHECK: "xla_hlo.reduce"(%{{.*}}, %[[INIT]]) ( {
// CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
// CHECK: %[[OR:.*]] = xla_hlo.or %[[ARGA]], %[[ARGB]] : tensor<i1>
// CHECK: "xla_hlo.return"(%[[OR]]) : (tensor<i1>) -> ()
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
%0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
// CHECK-LABEL: @any_keep_dim
func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
// CHECK: "xla_hlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
return %0 : tensor<4x1xi1>
}
// CHECK-LABEL: @any_dynamic
func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[ARG:.*]] = "xla_hlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
// CHECK: "xla_hlo.reduce"(%[[ARG]]
%0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
return %0 : tensor<4x1xi1>
}
//===----------------------------------------------------------------------===//
// Tile op legalizations.
//===----------------------------------------------------------------------===//
@ -1924,12 +2039,23 @@ func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32
return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32>
}
// CHECK-LABEL: @split_match_and_split_into_two_dynamic
func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) {
%cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
// CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
%0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>)
// CHECK: return %[[ONE]], %[[TWO]]
return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32>
}
// CHECK-LABEL: @split_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) {
%cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
%0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
@ -1973,3 +2099,82 @@ func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
}
//===----------------------------------------------------------------------===//
// tf.SplitV legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @splitv_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
%split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
%split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32>
// CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32>
%0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}
// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic
func @splitv_match_and_split_into_three_dynamic(%input: tensor<?x6xf32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) {
%split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
%split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x1xf32>
// CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x2xf32>
// CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x3xf32>
%0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<?x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>)
return %0#0, %0#1, %0#2 : tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>
}
// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes
func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
%split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
%split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>
// CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>
// CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>
%0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}
//===----------------------------------------------------------------------===//
// tf.Assert legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @assert
func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
// CHECK-NOT: tf.Assert
"tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor<i1>, tensor<*xf32>) -> ()
return
}
//===----------------------------------------------------------------------===//
// tf.Unpack legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @unpack
func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// CHECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32>
// CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// CHECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// CHECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
%0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
// CHECK: return %[[RES1]], %[[RES2]], %[[RES3]]
return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>
}
// CHECK-LABEL: @unpack_dynamic
func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
// CHECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
// CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
// CHECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
%0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}

View File

@ -317,16 +317,33 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
// -----
// CHECK-LABEL: HloModule
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
// Simple einsum is lowered to HLO dot op.
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
%0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
return %0 : tensor<3x5xi32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x300xf32> {
// CHECK: [[ARG0:%.*]] = f32[200,100,300] parameter(0)
// CHECK: [[ARG1:%.*]] = s32[10,2] parameter(1)
// CHECK: f32[10,300] gather(f32[200,100,300] [[ARG0]], s32[10,2] [[ARG1]])
// CHECK-SAME: offset_dims={1}
// CHECK-SAME: collapsed_slice_dims={0,1}
// CHECK-SAME: start_index_map={0,1}
// CHECK-SAME: index_vector_dim=1
// CHECK-SAME: slice_sizes={1,1,300}
// CHECK-SAME: indices_are_sorted=true
%0 = "xla_hlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32>
return %0 : tensor<10x300xf32>
}
// -----
// CHECK-LABEL: HloModule
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>

View File

@ -317,6 +317,28 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1)
}
// CHECK-LABEL: func @test_gather(
// CHECK-SAME: [[ARG0:%.+]]: tensor<200x100x300xf32>, [[ARG1:%.+]]: tensor<10x2xi32>) -> tensor<10x300xf32> {
%test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] {
%arg.0 = f32[200,100,300] parameter(0)
%arg.1 = s32[10,2] parameter(1)
// CHECK: "xla_hlo.gather"([[ARG0]], [[ARG1]])
// CHECK-SAME: dimension_numbers
// CHECK-SAME: collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: index_vector_dim = 1 : i64
// CHECK-SAME: offset_dims = dense<1> : tensor<1xi64>
// CHECK-SAME: start_index_map = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: indices_are_sorted = true
// CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>
ROOT gather = f32[10,300] gather(f32[200,100,300] %arg.0, s32[10,2] %arg.1),
collapsed_slice_dims={0,1},
index_vector_dim=1,
offset_dims={1},
start_index_map={0,1},
indices_are_sorted=true,
slice_sizes={1,1,300}
}
// CHECK-LABEL: func @test_get_dimension_size
// CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>)
%test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
@ -38,13 +39,19 @@ namespace {
constexpr StringRef kTempBufferAttr = "temp";
Value* GetTensorStoreMemRef(Value* value) {
Value* GetTensorStoreOrReturnMemRef(Value* value) {
for (const auto& user : value->getUsers()) {
if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
if (tensor_store.getOperand(0) == value) {
return tensor_store.getOperand(1);
}
}
if (auto return_op = dyn_cast<xla_hlo::ReturnOp>(user)) {
if (return_op.getOperand(0) == value) {
auto block = return_op.getOperation()->getBlock();
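// In the LHLO form the buffer for the returned value is appended as the
// last block argument, so reuse that buffer instead of allocating a new one.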
return *block->args_rbegin();
}
}
}
return nullptr;
}
@ -88,8 +95,8 @@ Value* InsertAllocAndDealloc(Location loc, Value* result,
/// function to store the values held in the tensor.
Value* GetBufferForResultValue(Location loc, Value* result,
ConversionPatternRewriter* rewriter) {
if (auto tensor_store_memref = GetTensorStoreMemRef(result)) {
return tensor_store_memref;
if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) {
return existing_memref;
}
return InsertAllocAndDealloc(loc, result, rewriter);
}
@ -117,7 +124,63 @@ class HloToLhloOpConverter : public ConversionPattern {
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
llvm::to_vector<4>(original_results));
original_results);
return matchSuccess();
}
};
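// Converts xla_hlo.reduce to xla_lhlo.reduce: results become extra buffer
// operands, the reduction body is inlined into the new op, its tensor block
// arguments are rewritten to memrefs (plus one additional memref argument for
// the result), and an xla_lhlo.terminator is appended to the body.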
struct HloToLHloReduceConverter
: public OpConversionPattern<xla_hlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
xla_hlo::ReduceOp op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
if (op.getNumResults() != 1) return matchFailure();
if (op.getParentRegion()->getBlocks().size() != 1) {
emitError(loc,
"tensor to buffer conversion expects a single block in the "
"region containing the operation");
}
const auto& original_results = op.getResults();
SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
// Create new block arguments with correct type.
auto& entry_block = new_op.body().front();
int original_arg_count = entry_block.getNumArguments();
for (int i = 0; i < original_arg_count; ++i) {
auto old_arg = entry_block.getArgument(i);
auto old_type = old_arg->getType().cast<TensorType>();
auto new_type =
MemRefType::get(old_type.getShape(), old_type.getElementType());
auto new_arg = entry_block.addArgument(new_type);
rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
}
// Add an argument for the result.
entry_block.addArgument(
entry_block.getArgument(original_arg_count)->getType());
// Remove the old arguments.
for (int i = original_arg_count - 1; i >= 0; --i) {
entry_block.eraseArgument(i);
}
// Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<xla_lhlo::TerminatorOp>(loc);
rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
original_results);
return matchSuccess();
}
};
@ -130,11 +193,12 @@ class HloToLhloTensorLoadConverter : public ConversionPattern {
PatternMatchResult matchAndRewrite(
Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOp(op, operands, llvm::to_vector<4>(op->getResults()));
rewriter.replaceOp(op, operands, op->getResults());
return matchSuccess();
}
};
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreConverter : public ConversionPattern {
public:
explicit HloToLhloTensorStoreConverter(MLIRContext* context)
@ -148,6 +212,19 @@ class HloToLhloTensorStoreConverter : public ConversionPattern {
}
};
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloReturnConverter : public OpConversionPattern<xla_hlo::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
xla_hlo::ReturnOp op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.eraseOp(op);
return matchSuccess();
}
};
// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
@ -215,6 +292,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
xla_lhlo::BroadcastInDimOp>,
HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>,
HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>,
@ -229,6 +307,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
HloToLHloReduceConverter, HloToLhloReturnConverter,
HloToLhloTensorLoadConverter, HloToLhloTensorStoreConverter
>(context);
// clang-format on

View File

@ -53,8 +53,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
if (!return_op) continue;
builder->setInsertionPointToEnd(block);
builder->create<mlir::BranchOp>(
loc, target_block, llvm::to_vector<4>(return_op.getOperands()));
builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
return_op.erase();
}
@ -196,8 +195,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
builder.create<mlir::BranchOp>(loc, cond_block,
llvm::to_vector<4>(return_op.getOperands()));
builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
return_op.erase();
}

View File

@ -127,8 +127,8 @@ static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value *value,
/// Returns a `ConvertOp` that casts the elements to an i64 type while retaining
/// the shape of the input value.
static ConvertOp CastElementsToI64(Location loc, Value *value,
PatternRewriter *rewriter) {
static ConvertOp CastValueToI64(Location loc, Value *value,
PatternRewriter *rewriter) {
return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}
@ -207,7 +207,8 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
// Bias op utilities.
//===----------------------------------------------------------------------===//
/// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd.
// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd.
// Requires the input to be a ranked tensor.
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
StringAttr format,
Value *input) {
@ -418,7 +419,8 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
Builder *builder) {
DenseIntElementsAttr constant_start_indices;
if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
return slice_sizes;
return xla::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64))
.cast<DenseIntElementsAttr>();
}
auto input_ty = input->getType().dyn_cast<RankedTensorType>();
@ -687,7 +689,7 @@ class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
op, op.getType(), *op.inputs().begin(), equation);
} else if (op.N() == 2) {
auto inputs = llvm::to_vector<2>(op.inputs());
ValueRange inputs = op.inputs();
rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
inputs[1], equation);
} else {
@ -924,7 +926,7 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
};
// Converts the tf.Split op into a series of HLO slice ops when the tensor to be
// split has fuly static shape and the dimension to split is a constant.
// split has fully static shape and the dimension to split is a constant.
//
// The main logic of this pattern is to calculate the index start and end range
// for each slice. And this happens only on the dimension to be split; for all
@ -962,9 +964,9 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
PatternMatchResult matchAndRewrite(TF::SplitOp op,
PatternRewriter &rewriter) const override {
// We can only match when the tensor to be split has fully static shape.
// We can only split along static dimensions.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
if (!input_type || !input_type.hasStaticShape()) return matchFailure();
if (!input_type) return matchFailure();
// We can only match when the split dimension is a constant scalar.
DenseIntElementsAttr split_dim_attr;
@ -978,6 +980,10 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
// Calculate the dimension size for each slice along the split dimension.
int64_t input_dim_size = input_type.getDimSize(dim_index);
// If we are splitting along the dynamic dimension then we cannot compute
// the static dimension length.
if (TensorType::isDynamic(input_dim_size)) return matchFailure();
int64_t num_splits = op.getNumResults();
int64_t slice_size = input_dim_size / num_splits;
@ -1011,6 +1017,118 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
}
};
// Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
// be split has fully static shape and the dimension to split and split sizes
// are constants.
//
// This is similar to the conversion for the tf.Split op, except that the size
// of each chunk along the split dimension is explicitly given as an op operand
// and the chunks are not necessarily the same size.
//
// For example, given the following IR:
//
// %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
// %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
// %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
// (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
// (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
//
// We will generate the following slices:
// %0 = "xla_hlo.slice"(%input) {
// limit_indices = dense<[4, 1]> : tensor<2xi64>,
// start_indices = dense<0> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x1xf32>
// %1 = "xla_hlo.slice"(%input) {
// limit_indices = dense<[4, 3]> : tensor<2xi64>,
// start_indices = dense<[0, 1]> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x2xf32>
// %2 = "xla_hlo.slice"(%input) {
// limit_indices = dense<[4, 6]> : tensor<2xi64>,
// start_indices = dense<[0, 3]> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x3xf32>
class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
public:
using OpRewritePattern::OpRewritePattern;
PatternMatchResult matchAndRewrite(TF::SplitVOp op,
PatternRewriter &rewriter) const override {
// We can only split along static dimensions.
// TODO(b/145731001): enhance to support dynamic-shaped inputs.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
if (!input_type) return matchFailure();
// We can only match when the split dimension is a constant scalar.
DenseIntElementsAttr split_dim_attr;
if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
return matchFailure();
// We can only match when the split sizes is a constant int vector.
DenseIntElementsAttr split_sizes_attr;
if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
return matchFailure();
// Get each chunk's size along the dimension to split. It may contain
// dynamic sizes and we need to update it if so.
SmallVector<int64_t, 4> split_sizes;
int64_t total_dim_size = 0; // Total dimension size assigned to splits
llvm::Optional<int> dynamic_dim_index;
split_sizes.reserve(
split_sizes_attr.getType().cast<ShapedType>().getNumElements());
for (auto dim : llvm::enumerate(split_sizes_attr)) {
int64_t dim_val = dim.value().getSExtValue();
split_sizes.push_back(dim_val);
if (dim_val == ShapedType::kDynamicSize) {
// We cannot have more than one dynamic dimension.
assert(!dynamic_dim_index && "invalid split sizes");
dynamic_dim_index = dim.index();
} else {
total_dim_size += dim_val;
}
}
// Get the dimension we are splitting at. Offset properly if it's negative.
int64_t input_rank = input_type.getRank();
int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
if (dim_index < 0) dim_index += input_rank;
int64_t input_dim_size = input_type.getDimSize(dim_index);
if (TensorType::isDynamic(input_dim_size)) return matchFailure();
assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
(!dynamic_dim_index && total_dim_size == input_dim_size)) &&
"invalid split sizes");
// Update the dynamic dimension with calculated concrete size.
if (dynamic_dim_index)
split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;
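// Illustrative arithmetic (values taken from the example above): with split
// sizes [1, -1, 3] and an input dimension of size 6, total_dim_size is
// 1 + 3 = 4, so the dynamic entry is rewritten to 6 - 4 = 2, giving [1, 2, 3].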
// Parameters for constructing each slice.
SmallVector<int64_t, 4> begin_indices(input_rank, 0);
auto end_indices = llvm::to_vector<4>(input_type.getShape());
SmallVector<int64_t, 4> strides(input_rank, 1);
// All HLO slice results used to replace the original tf.Split op.
SmallVector<Value *, 4> slices;
slices.reserve(op.getNumResults());
for (int i = 0; i < op.getNumResults(); ++i) {
end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
slices.push_back(rewriter.create<xla_hlo::SliceOp>(
op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter)));
// Prepare the begin indices for the next slice.
begin_indices[dim_index] = end_indices[dim_index];
}
rewriter.replaceOp(op, slices);
return matchSuccess();
}
};
// Converts StridedSlice op to HLO Slice op along with Reverse op to handle
// negative strides and Reshape op to update the output shape. Indices and
// strides operands are converted to attributes with non-negative indexing.
@ -1182,8 +1300,7 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
ArrayRef<int64_t> input_shape = input_ty.getShape();
DenseIntElementsAttr dimensions;
if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)) ||
dimensions.getType().getRank() != 1)
if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
return this->matchFailure();
// Build the final shape from input_shape and dimensions using a bitmap
@ -1260,7 +1377,6 @@ class ConvertMeanOp
: public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
static Value *GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
@ -1300,6 +1416,36 @@ class ConvertMaxOp
}
};
// Converts All op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %all = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.and"]
// {dimensions = ...}
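// For example (mirroring the lowering exercised by the tests above), a
// "tf.All" over dimension 1 of a tensor<4x8xi1> becomes an "xla_hlo.reduce"
// with init xla_hlo.constant dense<true> and an xla_hlo.and body, yielding a
// tensor<4xi1> (reshaped back to tensor<4x1xi1> when keep_dims is set).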
class ConvertAllOp
: public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
static Value *GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 1, &rewriter);
}
};
// Converts Any op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %any = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.or"]
// {dimensions = ...}
class ConvertAnyOp
: public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
static Value *GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
}
};
// Converts tensorflow ArgMin or ArgMax op to xla_hlo operations that perform
// a reduction on the original input and the corresponding index. The reduction
// sub-computation selects the max (or min) value and the index for the value.
@ -2000,6 +2146,53 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
}
};
// Converts tf.Unpack to a series of XLA HLO slice ops.
//
// Each slice takes one element along the dimension to unpack and takes the full
// range for all other dimensions. Each slice is then reshaped to drop the
// dimension to unpack (which is always of size 1).
// TODO(antiagainst): consider changing this into a TF internal lowering pass.
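// As a sketch of the index bookkeeping: unpacking a tensor<4x3x6xf32> along
// axis 1 produces slices with begin/end indices [0, 0, 0]/[4, 1, 6],
// [0, 1, 0]/[4, 2, 6], and [0, 2, 0]/[4, 3, 6], each followed by a reshape
// from tensor<4x1x6xf32> that drops the unpacked dimension.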
class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
public:
using OpRewritePattern::OpRewritePattern;
PatternMatchResult matchAndRewrite(TF::UnpackOp op,
PatternRewriter &rewriter) const override {
auto value_type = op.value()->getType().cast<RankedTensorType>();
if (!value_type) return matchFailure();
int64_t value_rank = value_type.getRank();
int64_t axis = op.axis().getSExtValue();
if (axis < 0) axis += value_rank;
// Parameters for constructing each slice.
SmallVector<int64_t, 4> begin_indices(value_rank, 0);
auto end_indices = llvm::to_vector<4>(value_type.getShape());
SmallVector<int64_t, 4> strides(value_rank, 1);
// All HLO slice+reshape results used to replace the original tf.Unpack op.
SmallVector<Value *, 4> results;
results.reserve(op.getNumResults());
for (int i = 0; i < op.getNumResults(); ++i) {
begin_indices[axis] = i;
end_indices[axis] = i + 1;
auto slice_op = rewriter.create<xla_hlo::SliceOp>(
op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter));
// Reshape to drop the axis dimension.
auto reshape_op = rewriter.create<xla_hlo::ReshapeOp>(
op.getLoc(), op.getType(i), slice_op);
results.push_back(reshape_op);
}
rewriter.replaceOp(op, results);
return matchSuccess();
}
};
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
@ -2013,16 +2206,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
// level TensorFlow ops. So, we don't have to target all the TensorFlow ops
// here for lowering to HLO.
TF::PopulateLoweringTFPatterns(context, &patterns);
patterns
.insert<ConvertArgMaxOp, ConvertBF16FloorDivOp, ConvertConv2D,
ConvertEinsumOp, ConvertMaxPoolOp, ConvertRangeOp,
ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp,
ConvertSigmoidOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp,
ConvertStridedSliceOp, ConvertTopKV2Op, ConvertMeanOp,
ConvertSumOp, ConvertMaxOp, ConvertTileOp, ConvertMaxPoolGradOp,
ConvertOneHotOp, ConvertConv2DBackpropInputOp,
ConvertConv2DBackpropFilterOp>(op->getContext());
patterns.insert<
ConvertArgMaxOp, ConvertBF16FloorDivOp, ConvertConv2D, ConvertEinsumOp,
ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp,
ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertTopKV2Op, ConvertUnpackOp, ConvertMeanOp,
ConvertSumOp, ConvertMaxOp, ConvertAllOp, ConvertAnyOp, ConvertTileOp,
ConvertMaxPoolGradOp, ConvertOneHotOp, ConvertConv2DBackpropInputOp,
ConvertConv2DBackpropFilterOp>(op->getContext());
ConversionTarget target(*context);
target.addLegalDialect<XlaHloDialect>();

View File

@ -95,8 +95,7 @@ void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc,
detupled_args.push_back(extract);
}
llvm::SmallVector<Value*, 4> result(
builder.create<CallOp>(loc, func, detupled_args).getResults());
auto result = builder.create<CallOp>(loc, func, detupled_args).getResults();
if (!tuple_return) {
builder.create<xla_hlo::ReturnOp>(loc, result);
} else {

View File

@ -29,6 +29,9 @@ def FeatureDimension : NativeCodeCall<
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;
def CastValueToI64: NativeCodeCall<
"CastValueToI64($0->getLoc(), $1, &$_builder)">;
def : Pattern<
(TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon,
$data_format, FalseBoolAttr:$is_training),
@ -43,13 +46,22 @@ def : Pattern<
[(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
(HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>;
//===----------------------------------------------------------------------===//
// Assert op pattern.
//===----------------------------------------------------------------------===//
// HLO and XLA don't support assertions.
def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>;
//===----------------------------------------------------------------------===//
// Bias op patterns.
//===----------------------------------------------------------------------===//
def BiasAddFeatureDimension : NativeCodeCall<
"getBiasFeatureDimension($_builder, $0, $1)">;
def : Pat<(TF_BiasAddOp AnyStaticShapeTensor:$input, $bias, $data_format),
// $input needs to be a ranked tensor to identify the index of the feature
// dimension depending on the data_format 'NHWC' or 'NCHW'.
def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format),
(HLO_AddOp $input, $bias,
(BiasAddFeatureDimension $data_format, $input))>;
@ -298,7 +310,7 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
//===----------------------------------------------------------------------===//
def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value),
[(AnyStaticShapeTensor $res), (HLO_Tensor $res)]>;
[(HLO_Tensor $res)]>;
//===----------------------------------------------------------------------===//
// Relu op patterns.
@ -316,11 +328,21 @@ def : Pat<(TF_Relu6Op AnyStaticShapeTensor:$input),
(HLO_ConstOp (ConstantSplat<"6"> $input)))>;
// ReluGrad(gradients, features) = gradients * (features > 0)
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyStaticShapeTensor:$features),
//
// $gradients needs to be of static shape so that on_true and on_false operands
// of SelectOp have the same shape.
//
// $features needs to be ranked for computation of the broadcast dimensions for
// CompareOp.
//
// TODO(hinsu): Relax $gradients static shape requirement when there is a way
// to create splat tensor of dynamic shape in HLO.
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features),
(HLO_SelectOp
(HLO_CompareOp $features, (HLO_ConstOp:$zero (ConstantSplat<"0"> $features)),
(HLO_CompareOp $features,
(HLO_ConstOp (GetScalarOfType<0> $features)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT),
$gradients, $zero)>;
$gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>;
//===----------------------------------------------------------------------===//
// Slice op patterns.
@ -333,9 +355,9 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall<
"TFSliceSizes2HLOSliceSizes($0, $1, $2.cast<DenseIntElementsAttr>(),"
"&$_builder)">;
def : Pat<(TF_SliceOp HLO_Tensor:$input, HLO_Tensor:$starting_indices,
(TF_ConstOp I64ElementsAttr:$slice_sizes)),
(HLO_DynamicSliceOp $input, $starting_indices,
def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices,
(TF_ConstOp $slice_sizes)),
(HLO_DynamicSliceOp $input, (CastValueToI64 $op, $starting_indices),
(TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)),
[(CanBeTranslatedToDynamicSlice $input, $starting_indices,
$slice_sizes)]>;
@ -383,19 +405,21 @@ foreach Mapping = [
def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse),
(HLO_ConvertOp $arg)>;
def : Pat<(TF_TransposeOp:$res $arg, (TF_ConstOp I64ElementsAttr:$permutation)),
(HLO_TransposeOp $arg, (CastIntElementsAttr $permutation))>;
def : Pat<(TF_TransposeOp:$res $arg, (TF_ConstOp $permutation)),
(HLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>;
// Result of the following ops changing tensor shape needs to have static
// shape as HLO doesn't yet support dynamic reshaping ops.
//
// TODO(hinsu): Update once HLO supports dynamic reshaping ops.
foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in {
def : Pat<(TfOp:$res AnyStaticShapeTensor:$arg, $ignored),
def : Pat<(TfOp:$res $arg, $ignored),
(HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;
}
//===----------------------------------------------------------------------===//
// RngUniform.
//===----------------------------------------------------------------------===//
def CastElementsToI64: NativeCodeCall<
"CastElementsToI64($0->getLoc(), $1, &$_builder)">;
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2),
@ -404,5 +428,5 @@ def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2),
(NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)),
(HLO_ConstOp
(NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)),
(CastElementsToI64 $old, $shape)),
(CastValueToI64 $old, $shape)),
[(IsShapedTensor $shape)]>;

View File

@ -54,7 +54,8 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
auto op = cast<LinalgOp>(generic_op.getOperation());
for (const Value* result : op.getOutputs()) {
if (!func_args.count(result)) continue;
if (linalg::tileLinalgOp(b, op, tile_sizes, &folder)) {
if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{},
&folder)) {
generic_op.erase();
return;
}

View File

@ -112,7 +112,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
rewriter.setInsertionPointToEnd(block);
Operation* op = MapLhloOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(lhlo_op), bodyResultTypes, bodyArgs, rewriter);
rewriter.create<linalg::YieldOp>(loc, llvm::to_vector<1>(op->getResults()));
rewriter.create<linalg::YieldOp>(loc, op->getResults());
rewriter.eraseOp(lhlo_op);
return ConversionPattern::matchSuccess();
}

View File

@ -653,7 +653,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
np_result = np.true_divide(nums, divs)
np_result[:, divs[0] == 0] = 0
self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result)
self._testBinary(
gen_math_ops.div_no_nan,
nums,
divs,
expected=np_result,
rtol=7e-15 if dtype == np.float64 else None,
atol=3.9e-15 if dtype == np.float64 else None)
if dtype not in self.complex_types: # floordiv unsupported for complex.
self._testBinary(

View File

@ -164,7 +164,8 @@ class TensorArrayTest(xla_test.XLATestCase):
dtype=tf_dtype, tensor_array_name="foo", size=3)
# Unpack a matrix into vectors.
w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
w1 = ta.unstack(
convert([[1.0, 1.03125], [2.0, 2.03125], [3.0, 3.03125]]))
r0 = w1.read(0)
r1 = w1.read(1)
r2 = w1.read(2)
@ -172,9 +173,9 @@ class TensorArrayTest(xla_test.XLATestCase):
d0, d1, d2 = self.evaluate(xla.compile(fn))
self.assertAllEqual(convert([1.0, 1.1]), d0)
self.assertAllEqual(convert([2.0, 2.1]), d1)
self.assertAllEqual(convert([3.0, 3.1]), d2)
self.assertAllEqual(convert([1.0, 1.03125]), d0)
self.assertAllEqual(convert([2.0, 2.03125]), d1)
self.assertAllEqual(convert([3.0, 3.03125]), d2)
def fn():
# Reset ta because we're going to change the shape, else shape

View File

@ -307,7 +307,7 @@ void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
}
}
}
LOG(FATAL) << "Node " << (**node).name() << " not found in any engine.";
LOG(FATAL) << "Node " << node_name << " not found in any engine.";
}
// Function to insert a TRT engine node into the graph.

View File

@ -654,9 +654,8 @@ class ConverterTest : public ::testing::Test {
ConverterTest() { Reset(); }
void Reset() {
builder_.reset(nvinfer1::createInferBuilder(logger_));
converter_ =
std::move(Converter::Create(builder_.get(), TrtPrecisionMode::FP32,
std::move(Converter::Create(TrtPrecisionMode::FP32,
/*use_calibration=*/false, &logger_)
.ValueOrDie());
weight_store_ = &converter_->weight_store_;
@ -702,9 +701,6 @@ class ConverterTest : public ::testing::Test {
private:
Logger logger_;
// These members are ordered in a way such that the destruction order is:
// converter_ -> builder_
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
protected:
std::unique_ptr<Converter> converter_;
@ -996,9 +992,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
FakeITensor input, infer_1, infer_2, infer_3;
FakeITensor not_infer;
Logger logger;
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(logger));
auto int8_converter = Converter::Create(builder.get(), TrtPrecisionMode::INT8,
auto int8_converter = Converter::Create(TrtPrecisionMode::INT8,
/*use_calibration=*/true, &logger)
.ValueOrDie();
int8_converter->ProvideQuantizationRange(&input, -5.0f, 5.0f);
@ -1255,12 +1249,8 @@ class OpConverterTest : public ::testing::Test {
engine_.reset(nullptr);
// Re-create them in proper order.
builder_.reset(nvinfer1::createInferBuilder(logger_));
builder_->setMaxWorkspaceSize(1 << 26);
// Reset the converter.
converter_ =
std::move(Converter::Create(builder_.get(), precision_mode_to_test_,
std::move(Converter::Create(precision_mode_to_test_,
/*use_calibration=*/false, &logger_)
.ValueOrDie());
@ -1294,18 +1284,13 @@ class OpConverterTest : public ::testing::Test {
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
// Build the TRT engine.
if (precision_mode == TrtPrecisionMode::FP16) {
builder_->setFp16Mode(true);
} else if (precision_mode == TrtPrecisionMode::INT8) {
// Setting FP16 mode as well allows TRT to also consider FP16 kernels and
// use them in situations where they are faster than INT8 or where INT8 is
// not supported for a given layer.
builder_->setFp16Mode(true);
builder_->setInt8Mode(true);
}
ASSERT_EQ(nullptr, engine_.get());
builder_->setMaxBatchSize(batch_size);
TF_ASSERT_OK(converter_->BuildCudaEngine(&engine_));
TF_ASSERT_OK(
converter_->BuildCudaEngine(&engine_,
/*max_batch_size=*/batch_size,
/*max_workspace_size_bytes=*/1 << 26,
/*allocator=*/nullptr,
/*calibrator=*/nullptr));
CHECK_NOTNULL(engine_.get());
CheckDataTypeMatches(input_data);
CheckDataTypeMatches(*output_data);
@ -1473,7 +1458,6 @@ class OpConverterTest : public ::testing::Test {
private:
Logger logger_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
cudaStream_t stream_;
// Used to create placeholders with shape and data type information. The

View File

@ -143,6 +143,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
REGISTER_XLA_OP(
Name("DataFormatVecPermute").TypeConstraint("T", {DT_INT32, DT_INT64}),
DataFormatVecPermuteOp);
REGISTER_XLA_OP(Name("DataFormatVecPermute")
.Label("host")
.TypeConstraint("T", {DT_INT32, DT_INT64}),
DataFormatVecPermuteOp);
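// The "host" label mirrors a "_kernel" attr value of "host" on the node, so
// such nodes now match this registration directly (see the Label support
// added to XlaOpRegistry) instead of relying on the attribute being cleared.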
} // namespace
} // namespace tensorflow

View File

@ -723,17 +723,6 @@ Status XlaCompiler::CompileFunction(
std::unique_ptr<Graph> graph = GetGraph(fbody);
// Clear the "_kernel" attribute if it is set to "host". This is used to
// indicate that a computation should happen on the host instead of the
// accelerator, but doesn't make sense in XLA.
const char* const kKernelAttr = "_kernel";
for (Node* n : graph->nodes()) {
string value;
if (TryGetNodeAttr(n->attrs(), kKernelAttr, &value) && value == "host") {
n->ClearAttr(kKernelAttr);
}
}
// _Arg and _Retval nodes don't exist in the stored subgraph for the function;
// they are added by the function body looked up. Therefore, they don't have
// core assignments here.
@ -1059,7 +1048,12 @@ Status XlaCompiler::BuildArguments(
const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
VLOG(2) << " XLA arg " << i
<< " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
<< " name: " << arg.name << " TF arg " << input_to_args->at(i);
<< " name: " << arg.name << " TF arg " << input_to_args->at(i)
<< " node name: " << arg.node_name
<< (arg_shardings.find(i) == arg_shardings.end()
? ""
: absl::StrCat(" sharding: ",
arg_shardings.at(i).DebugString()));
XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
switch (arg.kind) {
case XlaCompiler::Argument::kResource: {

View File

@ -147,6 +147,9 @@ class XlaCompiler {
// The name of this argument, used for debugging.
string name;
// The name of the TensorFlow _Arg node, used for debugging.
string node_name;
// For a kResource, what kind of resource is it?
XlaResource::Kind resource_kind = XlaResource::kInvalid;

View File

@ -61,6 +61,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
/* static */ bool XlaOpRegistry::IsCompatible(const OpRegistration& x,
const OpRegistration& y) {
if (x.name != y.name) return true;
if (x.label != y.label) return true;
// The registrations refer to the same Op: ensures they are compatible and
// are restricted to different device whitelists.
if (x.compilation_only != y.compilation_only) {
@ -256,6 +257,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
std::unique_ptr<KernelDef> kdef(new KernelDef);
kdef->set_op(op_registration->name);
kdef->set_device_type(backend.first);
kdef->set_label(op_registration->label);
// Constrain each type attribute to the intersection of:
// a) the types supported by the backend, and
@ -539,6 +541,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
return *this;
}
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Label(std::string label) {
registration_->label = label;
return *this;
}
std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;

View File

@ -270,6 +270,8 @@ class XlaOpRegistry {
// operands and not their values.
bool is_metadata_op = false;
std::string label;
// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};
@ -350,6 +352,9 @@ class XlaOpRegistrationBuilder {
// operands and not their values.
XlaOpRegistrationBuilder& IsMetadataOp();
// Specifies a particular value for the "_kernel" attr.
XlaOpRegistrationBuilder& Label(std::string label);
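// Illustrative use (taken from the DataFormatVecPermute registration in this
// change):
//
//   REGISTER_XLA_OP(Name("DataFormatVecPermute")
//                       .Label("host")
//                       .TypeConstraint("T", {DT_INT32, DT_INT64}),
//                   DataFormatVecPermuteOp);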
std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);

View File

@ -319,6 +319,8 @@ XlaOp Erf(XlaOp x) {
});
}
namespace {
// Approximation for the inverse error function from
// Giles, M., "Approximating the erfinv function".
// The approximation has the form:
@ -331,7 +333,7 @@ XlaOp Erf(XlaOp x) {
// p = sum_{i=1}^n gq[i]*w^i
// }
// return p*x
XlaOp ErfInv(XlaOp x) {
XlaOp ErfInv32(XlaOp x) {
constexpr int kDegree = 9;
constexpr std::array<float, 9> w_less_than_5_constants = {
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
@ -371,6 +373,101 @@ XlaOp ErfInv(XlaOp x) {
});
}
XlaOp ErfInv64(XlaOp x) {
constexpr std::array<double, 23> w_less_than_6_25_constants = {
-3.6444120640178196996e-21, -1.685059138182016589e-19,
1.2858480715256400167e-18, 1.115787767802518096e-17,
-1.333171662854620906e-16, 2.0972767875968561637e-17,
6.6376381343583238325e-15, -4.0545662729752068639e-14,
-8.1519341976054721522e-14, 2.6335093153082322977e-12,
-1.2975133253453532498e-11, -5.4154120542946279317e-11,
1.051212273321532285e-09, -4.1126339803469836976e-09,
-2.9070369957882005086e-08, 4.2347877827932403518e-07,
-1.3654692000834678645e-06, -1.3882523362786468719e-05,
0.0001867342080340571352, -0.00074070253416626697512,
-0.0060336708714301490533, 0.24015818242558961693,
1.6536545626831027356};
constexpr std::array<double, 19> w_less_than_16_constants = {
2.2137376921775787049e-09, 9.0756561938885390979e-08,
-2.7517406297064545428e-07, 1.8239629214389227755e-08,
1.5027403968909827627e-06, -4.013867526981545969e-06,
2.9234449089955446044e-06, 1.2475304481671778723e-05,
-4.7318229009055733981e-05, 6.8284851459573175448e-05,
2.4031110387097893999e-05, -0.0003550375203628474796,
0.00095328937973738049703, -0.0016882755560235047313,
0.0024914420961078508066, -0.0037512085075692412107,
0.005370914553590063617, 1.0052589676941592334,
3.0838856104922207635,
};
constexpr std::array<double, 17> w_greater_than_16_constants = {
-2.7109920616438573243e-11, -2.5556418169965252055e-10,
1.5076572693500548083e-09, -3.7894654401267369937e-09,
7.6157012080783393804e-09, -1.4960026627149240478e-08,
2.9147953450901080826e-08, -6.7711997758452339498e-08,
2.2900482228026654717e-07, -9.9298272942317002539e-07,
4.5260625972231537039e-06, -1.9681778105531670567e-05,
7.5995277030017761139e-05, -0.00021503011930044477347,
-0.00013871931833623122026, 1.0103004648645343977,
4.8499064014085844221,
};
// Compute logarithm of (1+arg) using log1p(arg) which is more precise than
// log(1+arg) when arg is close to zero. For more details, see
// https://en.cppreference.com/w/cpp/numeric/math/log1p
auto w = -Log1p(-x * x);
auto lt_6_25 = Lt(w, ScalarLike(x, 6.25));
auto lt_16 = Lt(w, ScalarLike(x, 16));
auto coefficient = [&](int i) {
auto c = FullLike(x, w_less_than_6_25_constants[i]);
if (i < 19) {
c = Select(lt_6_25, c, FullLike(x, w_less_than_16_constants[i]));
}
if (i < 17) {
c = Select(lt_16, c, FullLike(x, w_greater_than_16_constants[i]));
}
return c;
};
auto sqrt_w = Sqrt(w);
w = Select(lt_6_25, w - ScalarLike(x, 3.125),
sqrt_w - Select(lt_16, ScalarLike(x, 3.25), ScalarLike(x, 5.0)));
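// Evaluate all three minimax polynomials in lockstep with Horner's scheme:
// coefficient(i) picks the constant matching each lane's regime, the first 17
// terms apply to every lane, terms 17-18 only where w < 16, and terms 19-22
// only where w < 6.25.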
auto p = coefficient(0);
for (int i = 1; i < 17; ++i) {
p = coefficient(i) + p * w;
}
for (int i = 17; i < 19; ++i) {
p = Select(lt_16, coefficient(i) + p * w, p);
}
for (int i = 19; i < 23; ++i) {
p = Select(lt_6_25, coefficient(i) + p * w, p);
}
// Result modulo edge cases.
XlaOp result = p * x;
// Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
// indeterminate, and can give nan or -/+inf.)
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
return Select(Eq(Abs(x), ScalarLike(x, 1)),
x * MaxValue(&b, shape.element_type()), result);
});
}
} // namespace
XlaOp ErfInv(XlaOp x) {
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("ErfInv", x));
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
if (shape.element_type() == F64) {
return ErfInv64(x);
}
return DoWithUpcastToF32(x, {BF16, F16},
[](XlaOp x) { return ErfInv32(x); });
});
}
namespace {
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -116,6 +117,10 @@ class MathTypedTest : public MathTest {
//
// For good measure, we also check pow with an exponent other than 0.5.
void TestSqrtPowInequivalence() {
// TODO(b/145798892): test fails on GPU for double values.
if (std::is_same<T, double>::value) {
return;
}
SetFastMathDisabled(true);
// Tests disable constant folding by default, but this test needs it
@ -151,11 +156,16 @@ class MathTypedTest : public MathTest {
};
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
using TestTypes = ::testing::Types<float>;
#else
using TestTypes = ::testing::Types<float, Eigen::half>;
using TestTypes = ::testing::Types<float
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
,
Eigen::half
#endif
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64
,
double
#endif
>;
TYPED_TEST_CASE(MathTypedTest, TestTypes);
@ -224,6 +234,28 @@ XLA_TEST_F(MathTest, SqrtF32) {
ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
}
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64
XLA_TEST_F(MathTest, ErfInvF64) {
XlaBuilder builder(TestName());
auto x = ConstantR1<double>(
&builder, {-0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1,
0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9});
ErfInv(x);
std::vector<double> expected = {-1.163087153676674, -0.9061938024368231,
-0.732869077959217, -0.5951160814499948,
-0.4769362762044698, -0.37080715859355795,
-0.27246271472675443, -0.1791434546212916,
-0.08885599049425767, 0.,
0.08885599049425777, 0.1791434546212916,
0.27246271472675443, 0.37080715859355784,
0.4769362762044698, 0.5951160814499948,
0.732869077959217, 0.9061938024368231,
1.1630871536766736};
ComputeAndCompareR1<double>(&builder, expected, {}, ErrorSpec{1e-15});
}
#endif
XLA_TEST_F(MathTest, SquareTenValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(

View File

@ -2112,7 +2112,8 @@ XlaOp XlaBuilder::CrossReplicaSum(
XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@ -2136,9 +2137,31 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
operand_shapes.push_back(operand_shape);
operands.push_back(operand);
}
TF_ASSIGN_OR_RETURN(Shape shape,
TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferAllReduceShape(operand_shapes));
*instr.mutable_shape() = shape.ToProto();
if (shape_with_layout) {
if (!LayoutUtil::HasLayout(*shape_with_layout)) {
return InvalidArgument("shape_with_layout must have the layout set: %s",
shape_with_layout->ToString());
}
if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
return InvalidArgument(
"Provided shape_with_layout must be compatible with the "
"operand shape: %s vs %s",
shape_with_layout->ToString(), operand_shape->ToString());
}
instr.set_constrain_layout(true);
if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
// For a single-element tuple, take the tuple element shape.
TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
*instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
} else {
*instr.mutable_shape() = shape_with_layout->ToProto();
}
} else {
*instr.mutable_shape() = inferred_shape.ToProto();
}
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
@ -2153,10 +2176,10 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
TF_ASSIGN_OR_RETURN(
auto all_reduce,
AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
if (operand_shape->IsTuple() && !shape.IsTuple()) {
if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
// For a single-element tuple, wrap the result into a tuple.
TF_RET_CHECK(operand_shapes.size() == 1);
TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], shape));
TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
return Tuple({all_reduce});
}
return all_reduce;
@@ -3282,9 +3305,10 @@ XlaOp CrossReplicaSum(const XlaOp operand,
XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout) {
return operand.builder()->AllReduce(operand, computation, replica_groups,
channel_id);
channel_id, shape_with_layout);
}
XlaOp AllToAll(const XlaOp operand, int64 split_dimension,

View File

@@ -514,7 +514,8 @@ class XlaBuilder {
XlaOp AllReduce(
XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
const absl::optional<Shape>& shape_with_layout = absl::nullopt);
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
int64 split_count,
@@ -922,7 +923,8 @@ absl::Span<const ReplicaGroup> replica_groups);
absl::Span<const ReplicaGroup> replica_groups);
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id);
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout);
friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
@@ -1666,10 +1668,14 @@ XlaOp CrossReplicaSum(XlaOp operand,
// - `channel_id`: for Allreduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied cross modules.
XlaOp AllReduce(
XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
//
// - `shape_with_layout`: forces the layout of the AllReduce to the given
// layout. This is used to guarantee the same layout for a group of AllReduce
// ops compiled separately.
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
const absl::optional<Shape>& shape_with_layout = absl::nullopt);
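// Editor sketch of a call site for the new parameter (an illustrative
// assumption, not code from this change): pin the all-reduce result to an
// explicit f32[8] layout so that separately compiled modules agree on it.
// CreateScalarAddComputation comes from client/lib/arithmetic.h.
XlaBuilder b("all_reduce_layout_example");
XlaComputation add = CreateScalarAddComputation(F32, &b);
XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8}), "x");
Shape forced = ShapeUtil::MakeShapeWithLayout(F32, {8}, /*minor_to_major=*/{0});
XlaOp sum = AllReduce(x, add, /*replica_groups=*/{},
                      /*channel_id=*/absl::nullopt,
                      /*shape_with_layout=*/forced);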
// Enqueues an operation that does an AllToAll of the operand across cores.
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,

View File

@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -131,18 +132,23 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
}
} else if (shape.IsArray()) {
if (allocate_arrays) {
// Literals can be used as DMA targets, which can require alignment. We
// force a 16-byte minimum alignment.
constexpr int kMinimumAlignment = 16;
if (LayoutUtil::IsSparseArray(shape)) {
// For sparse arrays, the buffer must be of the size of the maximum
// number of sparse elements possible.
const int64 max_sparse_elements =
LayoutUtil::MaxSparseElements(shape.layout());
piece->set_buffer(
new char[max_sparse_elements *
ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
max_sparse_elements *
ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()),
kMinimumAlignment)));
piece->set_sparse_indices(
new SparseIndexArray(max_sparse_elements, shape.rank()));
} else {
piece->set_buffer(new char[piece->size_bytes()]);
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
piece->size_bytes(), kMinimumAlignment)));
}
}
} else {
@@ -174,7 +180,7 @@ void Literal::DeallocateBuffers() {
root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (piece->buffer() != nullptr) {
delete[] piece->buffer();
tensorflow::port::AlignedFree(piece->buffer());
delete piece->sparse_indices();
}
});
@@ -504,7 +510,7 @@ Status Literal::MoveFrom(Literal&& src_literal,
dest_index.push_back(i);
}
Piece& dest_piece = piece(dest_index);
delete[] dest_piece.buffer();
tensorflow::port::AlignedFree(dest_piece.buffer());
dest_piece.set_buffer(src_piece.buffer());
delete dest_piece.sparse_indices();
dest_piece.set_sparse_indices(src_piece.sparse_indices());
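// Editor sketch of the allocation discipline this change introduces (the
// alignment and size below are illustrative values, mirroring the 16-byte DMA
// alignment above): buffers obtained from tensorflow::port::AlignedMalloc,
// declared in tensorflow/core/platform/mem.h, must be released with
// AlignedFree, never with delete[].
constexpr int kIllustrativeAlignment = 16;
char* buffer = static_cast<char*>(
    tensorflow::port::AlignedMalloc(/*size=*/1024, kIllustrativeAlignment));
// ... use buffer as a DMA-capable literal backing store ...
tensorflow::port::AlignedFree(buffer);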

View File

@@ -26,7 +26,6 @@ py_test(
name = "xla_client_test",
srcs = ["xla_client_test.py"],
main = "xla_client_test.py",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = ["no_oss"], # TODO(phawkins): This test passes, but requires --config=monolithic.
deps = [

View File

@@ -31,6 +31,11 @@ tf_proto_library_cc(
use_grpc_namespace = True,
)
cc_library(
name = "c_api",
hdrs = ["c_api.h"],
)
cc_library(
name = "tpu_driver",
srcs = [

View File

@@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_
#define TPUDRIVER_CAPI_EXPORT __attribute__((visibility("default")))
extern "C" {
TPUDRIVER_CAPI_EXPORT extern void TpuDriver_Initialize();
TPUDRIVER_CAPI_EXPORT extern void TpuDriver_Open(const char* worker);
TPUDRIVER_CAPI_EXPORT extern const char* TpuDriver_Version(void);
}
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_
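// Editor sketch (not part of this change): since the header exposes an
// extern "C", default-visibility surface, a likely consumption pattern is to
// resolve the symbols from a shared object at runtime. The library name
// "libtpu_driver.so" below is an assumption; this change does not specify it.
#include <dlfcn.h>
#include <cstdio>

void QueryTpuDriverVersionSketch() {
  void* handle = dlopen("libtpu_driver.so", RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) return;
  auto initialize =
      reinterpret_cast<void (*)()>(dlsym(handle, "TpuDriver_Initialize"));
  auto version =
      reinterpret_cast<const char* (*)(void)>(dlsym(handle, "TpuDriver_Version"));
  initialize();
  std::printf("TPU driver version: %s\n", version());
  dlclose(handle);
}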

Some files were not shown because too many files have changed in this diff.