Merge branch 'master' into fix_minimum_maximum

Commit: ec5d3a0603

CODEOWNERS (51 changed lines)

@@ -13,55 +13,4 @@
/tensorflow/tensorboard/ @jart
/tensorflow/tools/docs/ @markdaoust

# contrib

# NEED OWNER: /tensorflow/contrib/all_reduce
/tensorflow/contrib/autograph/ @mdanatg @kkimdev
/tensorflow/contrib/batching/ @alextp @chrisolston
/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon
/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva
/tensorflow/contrib/checkpoint/ @allenlavoie
/tensorflow/contrib/contrib/cluster_resolver/ @frankchn
/tensorflow/contrib/cmake/ @mrry
/tensorflow/contrib/copy_graph/ @tucker @poxvoculi
/tensorflow/contrib/crf/ @kentonl
/tensorflow/contrib/data/ @mrry
/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
/tensorflow/contrib/eager @jaingaurav @alextp
/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
/tensorflow/contrib/ffmpeg/ @fredbertsch
/tensorflow/contrib/framework/ @ebrevdo
/tensorflow/contrib/graph_editor/ @purpledog
# NEED OWNER: /tensorflow/contrib/grid_rnn/
/tensorflow/contrib/hadoop @yongtang
/tensorflow/contrib/hvx/ @satok16
/tensorflow/contrib/integrate/ @shoyer
/tensorflow/contrib/kernel_methods/ @petrosmol
/tensorflow/contrib/ios_examples/ @petewarden
/tensorflow/contrib/labeled_tensor/ @shoyer
/tensorflow/contrib/layers/ @fchollet @martinwicke
/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
/tensorflow/contrib/lookup/ @ysuematsu @andreasst
/tensorflow/contrib/losses/ @alextp @ispirmustafa
/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg
/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa
/tensorflow/contrib/opt/ @strategist333 @alextp
/tensorflow/contrib/pi_examples/ @maciekcc
/tensorflow/contrib/quantization/ @petewarden
/tensorflow/contrib/rnn/ @ebrevdo @scottzhu
/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie
/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang
/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh
/tensorflow/contrib/slim/ @sguada @thenbasilmanran
/tensorflow/contrib/stateless/ @girving @alextp
/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank
/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2
# NEED OWNER: /tensorflow/contrib/testing/
/tensorflow/contrib/timeseries/ @allenlavoie
/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj
/tensorflow/contrib/training/ @joel-shor @ebrevdo
/tensorflow/contrib/util/ @sherrym

/third_party/systemlibs/ @perfinion
README.md (26 changed lines)

@@ -110,19 +110,19 @@ Build Type | Status
### Community Supported Builds

Build Type | Status | Artifacts
--------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------ | ---------
**Linux AMD ROCm GPU** Nightly | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5, 3.6 and 3.7** | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.14.0 PyPI](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
Build Type | Status | Artifacts
----------------------------------------------------------------- | ------------------------------------------------------------------------------------------------ | ---------
**Linux AMD ROCm GPU** Nightly | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** Stable Release |  | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)

## Resources
@@ -195,6 +195,12 @@ config_setting(
    visibility = ["//visibility:public"],
)

config_setting(
    name = "chromiumos",
    values = {"crosstool_top": "//external:android/chromiumos"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "linux_aarch64",
    values = {"cpu": "aarch64"},
@@ -453,6 +459,7 @@ package_group(
        "//tensorflow_estimator/python/estimator/...",
        "//tensorflow_models/official/...",
        "//third_party/py/autograph/...",
        "//third_party/swift/tensorflow/x10/...",
    ],
)
@@ -233,7 +233,7 @@ tensorflow::Status GetReplacedFromExistingWorkers(
  std::vector<tensorflow::eager::KeepAliveResponse> responses(
      existing_workers->size());
  for (int i = 0; i < existing_workers->size(); i++) {
    tensorflow::eager::EagerClient* eager_client;
    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] =
        client_cache->GetClient(existing_workers->at(i), &eager_client);
    if (!statuses[i].ok()) {
@@ -282,7 +282,7 @@ tensorflow::Status CreateRemoteContexts(
      continue;
    }

    tensorflow::eager::EagerClient* eager_client;
    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = tensorflow::errors::Internal(
@@ -340,7 +340,7 @@ tensorflow::Status UpdateRemoteContexts(
      continue;
    }

    tensorflow::eager::EagerClient* eager_client;
    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = tensorflow::errors::Internal(
@@ -819,7 +819,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
  }

  // TODO(yuefengz): support partially specified `worker_name`.
  tensorflow::eager::EagerClient* eager_client;
  tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
  status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
  if (!status->status.ok()) {
    return false;
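The four hunks above switch `GetClient` callers from a raw `tensorflow::eager::EagerClient*` to a `tensorflow::core::RefCountPtr`, so each caller holds its own reference that is released automatically. The following is a minimal, self-contained mock of that ownership scheme; `Client`, `RefCounted`, and the `RefCountPtr` shown here are simplified stand-ins, not TensorFlow's real classes:

```
#include <cstdio>

// Intrusive reference count: objects start with one reference owned by their
// creator and delete themselves when the count reaches zero.
class RefCounted {
 public:
  void Ref() { ++count_; }
  void Unref() {
    if (--count_ == 0) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  int count_ = 1;
};

// Move-only holder that drops its reference automatically on destruction.
template <typename T>
class RefCountPtr {
 public:
  RefCountPtr() = default;
  RefCountPtr(const RefCountPtr&) = delete;
  RefCountPtr& operator=(const RefCountPtr&) = delete;
  ~RefCountPtr() {
    if (ptr_) ptr_->Unref();
  }
  void reset(T* p) {
    if (ptr_) ptr_->Unref();
    ptr_ = p;
  }
  T* operator->() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

class Client : public RefCounted {
 public:
  void KeepAlive() { std::puts("keep-alive sent"); }
};

// A cache-style accessor hands out an *additional* reference to a shared
// client; the caller's RefCountPtr owns that reference from then on.
void GetClient(Client* cached, RefCountPtr<Client>* out) {
  cached->Ref();
  out->reset(cached);
}

int main() {
  Client* cached = new Client;  // The cache's own reference.
  {
    RefCountPtr<Client> client;
    GetClient(cached, &client);
    client->KeepAlive();
  }  // client's reference released here; the cached client survives.
  cached->Unref();  // The cache drops its reference; the object is deleted.
}
```

With raw pointers, callers had no way to express when they were done with the client; the smart-pointer form ties the reference's lifetime to the caller's scope.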
@@ -38,6 +38,6 @@ versions {

# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %0 : tensor<*xi32>
# CHECK-NEXT: }
@@ -1280,3 +1280,13 @@ func @conv2d_backprop_unsupported_data_format(%arg0: tensor<4xi32>, %arg1: tenso
// CHECK-LABEL: conv2d_backprop_unsupported_data_format
// CHECK: tf.Conv2DBackpropInput
}

func @assert_remove(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi1> {
  %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
  "tf.Assert"(%0, %arg1) {summarize = 3} : (tensor<1xi1>, tensor<1xi32>) -> ()
  return %0 : tensor<1xi1>
// CHECK-LABEL: assert_remove
// CHECK: tfl.less_equal
// CHECK-NOT: Assert
// CHECK: return
}
@@ -622,3 +622,12 @@ func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {

// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
}

// CHECK-LABEL: fuse_relu_to_add
func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  %1 = "tfl.relu_n1_to_1"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %1 : tensor<2x3xf32>
// CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
// CHECK: return %[[RES]]
}
@@ -68,6 +68,7 @@ struct LegalizeTF : public FunctionPass<LegalizeTF> {
// TODO(antiagainst): Define this pattern in a table-driven manner once variadic
// operands are properly supported in declarative rewrite rule specification.

DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Concat);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
@@ -86,7 +87,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_concat_op = cast<TF::ConcatOp>(op);

  SmallVector<Value*, 4> values(tf_concat_op.values());
  auto values = tf_concat_op.values();
  auto output_type = tf_concat_op.output()->getType();
  // Extract axis attribute from constant concat_dims tensor
  ElementsAttr axis;
@@ -105,7 +106,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_concat_op = cast<TF::ConcatV2Op>(op);

  SmallVector<Value*, 4> values(tf_concat_op.values());
  auto values = tf_concat_op.values();
  auto output_type = tf_concat_op.output()->getType();
  // Extract axis attribute from constant axis tensor
  ElementsAttr axis;
@@ -374,6 +375,14 @@ PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite(
  return matchFailure();
}

// TF Lite doesn't support Assert, we just drop the assert from the graph.
PatternMatchResult ConvertTFAssertOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  op->dropAllReferences();
  op->erase();
  return matchSuccess();
}

void LegalizeTF::runOnFunction() {
  OwningRewritePatternList patterns;
  auto* ctx = &getContext();
@@ -385,7 +394,8 @@ void LegalizeTF::runOnFunction() {
      .insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
              ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, ConvertTFPackOp,
              ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp,
              ConvertTFStridedSliceOp, ConvertTFUnpackOp>(ctx);
              ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp>(
          ctx);
  applyPatternsGreedily(func, patterns);
}
@@ -484,9 +484,9 @@ struct ConvertTensorListResize : public ConversionPattern {
        &rewriter);

    // Inserts the two blocks' names into the symbol table held by the module.
    // Using ModuleManager will ensure that the inserted symbol names are
    // Using SymbolTable will ensure that the inserted symbol names are
    // unique.
    ModuleManager manager(resize_op.getParentOfType<ModuleOp>());
    SymbolTable manager(resize_op.getParentOfType<ModuleOp>());
    manager.insert(then_branch_op);
    manager.insert(else_branch_op);

@@ -754,8 +754,7 @@ struct ConvertWhile : public ConversionPattern {
    cloned.removeAttr("T");
    UpdateFunctionTypes(cloned);

    SmallVector<Value *, 8> results(cloned.getResults());
    rewriter.replaceOp(op, results);
    rewriter.replaceOp(op, cloned.getResults());
    return matchSuccess();
  }
};
@@ -135,15 +135,15 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
static void EraseDeadFuncs(const FuncSet& candiate_funcs, ModuleOp module) {
  if (candiate_funcs.empty()) return;

  ModuleManager manager(module);
  SymbolTable manager(module);

  // Identify the functions that are used as symbols in the module and shouldn't
  // be erased.
  FuncSet in_use_funcs;
  manager.getModule().walk([&](Operation* op) {
  manager.getOp()->walk([&](Operation* op) {
    for (auto attr : op->getAttrs()) {
      if (auto symbol = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
        auto func = manager.lookupSymbol<FuncOp>(symbol.getValue());
        auto func = manager.lookup<FuncOp>(symbol.getValue());
        in_use_funcs.insert(func);
      }
    }
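The `ModuleManager`-to-`SymbolTable` swap in the two files above leans on the guarantee stated in the comment: inserted symbol names are kept unique. A toy model of that renaming policy, illustrative only and not MLIR's actual implementation:

```
#include <iostream>
#include <set>
#include <string>

// Toy unique-name insertion: on collision, append an increasing counter
// until the candidate name is free.
class ToySymbolTable {
 public:
  std::string Insert(std::string name) {
    if (names_.insert(name).second) return name;
    for (int i = 0;; ++i) {
      std::string candidate = name + "_" + std::to_string(i);
      if (names_.insert(candidate).second) return candidate;
    }
  }

 private:
  std::set<std::string> names_;
};

int main() {
  ToySymbolTable table;
  std::cout << table.Insert("then_branch") << "\n";  // then_branch
  std::cout << table.Insert("then_branch") << "\n";  // then_branch_0
}
```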
@@ -44,12 +44,13 @@ multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
                   $multiplier)>;
}

// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// that cannot be simply translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                     [TFL_Relu6Op, TFL_AF_Relu6]] in
                     [TFL_Relu6Op, TFL_AF_Relu6],
                     [TFL_Relu1Op, TFL_AF_Relu1]] in
  defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;

@@ -291,3 +292,18 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
                                        (ConstantOp $NegOne)),
          (TFL_Relu1Op $input),
          [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;

// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
              (BinaryOp $lhs, $rhs, actFnPair[1])>;
  }
}

// Instantiated FusedBinary patterns for the from-to pairs of ops.
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
                     TFL_MulOp, TFL_SubOp] in
  defm : FusedBinaryActivationFuncOpPat<BinaryOps>;
@@ -192,6 +192,37 @@ cc_library(
    alwayslink = 1,
)

gentbl(
    name = "decompose_resource_ops_inc_gen",
    tbl_outs = [
        (
            "-gen-rewriters",
            "transforms/generated_decompose_resource_ops.inc",
        ),
    ],
    tblgen = "@local_config_mlir//:mlir-tblgen",
    td_file = "transforms/decompose_resource_ops.td",
    td_srcs = [
        ":tensorflow_ops_td_files",
        "@local_config_mlir//:StdOpsTdFiles",
    ],
)

cc_library(
    name = "decompose_resource_ops",
    srcs = [
        "transforms/decompose_resource_ops.cc",
    ],
    hdrs = [
        "transforms/decompose_resource_ops.h",
    ],
    deps = [
        ":decompose_resource_ops_inc_gen",
        ":tensorflow",
        "@local_config_mlir//:IR",
    ],
)

cc_library(
    name = "tensorflow_passes",
    srcs = [
@@ -199,6 +230,7 @@ cc_library(
        "transforms/bridge_pass.cc",
        "transforms/cluster_formation.cc",
        "transforms/cluster_outlining.cc",
        "transforms/decompose_resource_ops_pass.cc",
        "transforms/delete_unused_funcs.cc",
        "transforms/executor_island_coarsening.cc",
        "transforms/fold_switch.cc",
@@ -213,6 +245,7 @@ cc_library(
        "transforms/raise_control_flow.cc",
        "transforms/replicate_invariant_op_hoisting.cc",
        "transforms/replicate_to_island.cc",
        "transforms/resource_device_inference.cc",
        "transforms/resource_op_lifting.cc",
        "transforms/shape_inference.cc",
        "transforms/shape_inference_pass.cc",
@@ -236,6 +269,8 @@ cc_library(
        ":bridge_logger",
        ":convert_tensor",
        ":convert_type",
        ":decompose_resource_ops",
        ":decompose_resource_ops_inc_gen",
        ":device_util",
        ":error_util",
        ":export_tf_dialect_op",
@@ -368,12 +403,14 @@ cc_library(
        ":convert_tensor",
        ":convert_type",
        ":mangling_util",
        ":tensorflow",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:protobuf",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/memory",
@@ -564,7 +601,6 @@ cc_library(
    hdrs = ["utils/error_util.h"],
    deps = [
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor/lib",
        "@llvm//:support",
        "@local_config_mlir//:IR",
    ],
@@ -808,7 +844,6 @@ cc_library(
        "//tensorflow/core:framework",
        "//tensorflow/core/platform:logging",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/types:span",
        "@llvm//:support",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Parser",
@@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
@@ -99,12 +100,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
  auto forward_input_to_output = [&](Value* operand, Value* result) {
    if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
      return;
    auto& result_ids = resource_value_to_ids_[result];
    auto operand_it = resource_value_to_ids_.find(operand);
    assert(operand_it != resource_value_to_ids_.end() &&
           "A resource-type output does not have the corresponding "
           "resource-type input.");
    resource_value_to_ids_[result].insert(operand_it->getSecond().begin(),
                                          operand_it->getSecond().end());
    result_ids.insert(operand_it->getSecond().begin(),
                      operand_it->getSecond().end());
  };
  // TODO(yuanzx): Consider control-flow ops.
  func_op.walk([&](Operation* op) {
@@ -119,6 +121,16 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
        forward_input_to_output(std::get<0>(operand_and_result),
                                std::get<1>(operand_and_result));
      }
    } else if (auto replicate = llvm::dyn_cast<tf_device::ReplicateOp>(op)) {
      // The nested block for ReplicateOp is handled separately in side-effect
      // analysis. Inside that block, we can still treat its block arguments as
      // different resources.
      for (auto arg : replicate.GetBody().getArguments()) {
        if (mlir::getElementTypeOrSelf(arg->getType())
                .isa<TF::ResourceType>()) {
          resource_value_to_ids_[arg].insert(next_unique_id++);
        }
      }
    } else {
      for (auto result : op->getResults()) {
        if (!mlir::getElementTypeOrSelf(result->getType())
@@ -261,9 +273,36 @@ void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,

void SideEffectAnalysis::AnalyzeFunction(
    FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) {
  // This function populates control_predecessors_ and control_successors_ by
  // walking through func_op's body, and tracking resource accesses in
  // per_resource_access_info_.
  // AnalyzeRegion() recursively analyzes the function body, and only populates
  // control_predecessors_.
  AnalyzeRegion(&func_op.getBody(), alias_analysis);
  // Populate sorted_control_predecessors_ and sorted_control_successors_ based
  // on control_predecessors.
  for (auto& entry : control_predecessors_) {
    auto op = entry.getFirst();
    auto& sorted_predecessors = sorted_control_predecessors_[op];
    for (auto predecessor : entry.getSecond()) {
      sorted_predecessors.push_back(predecessor);
      sorted_control_successors_[predecessor].push_back(op);
    }
  }
  control_predecessors_.clear();
  for (auto& entry : sorted_control_predecessors_) {
    llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
      return a->isBeforeInBlock(b);
    });
  }
  for (auto& entry : sorted_control_successors_) {
    llvm::sort(entry.getSecond(), [](Operation* a, Operation* b) {
      return a->isBeforeInBlock(b);
    });
  }
}

void SideEffectAnalysis::AnalyzeRegion(
    Region* region, const ResourceAliasAnalysis& alias_analysis) {
  // This function populates control_predecessors_ by walking through the
  // region, and tracking resource accesses in per_resource_access_info_.

  // Returns whether an access to `resource` can skip control edges from
  // previous accesses to unknown resources, due to that earlier accesses to
@@ -284,82 +323,93 @@ void SideEffectAnalysis::AnalyzeFunction(
           (it->second.tracked_last_unknown_read || no_unknown_read);
  };

  func_op.walk([&](Operation* op) {
    // We do not need explicit control edges for declaration ops.
    if (OpIsDeclaration(op, alias_analysis)) return;

    auto resource_op_info = GetResourceInfoForOp(op);
    if (!resource_op_info && op->hasNoSideEffect()) return;

    llvm::SmallDenseSet<int64_t, 8> resources =
        resource_op_info ? FindAccessedResources(op, alias_analysis)
                         : UnknownResourceSet();
    assert(!resources.empty());
    const bool is_unknown = resources.count(kUnknownResourceId) > 0;
    const bool read_only = OpIsReadOnly(op);
    bool indirectly_tracked_unknown_access = false;
    // First add edges from known resources.
    if (is_unknown) {
      for (auto& entry : per_resource_access_info_) {
        if (entry.getFirst() == kUnknownResourceId) continue;
        AddPredecessorsForAccess(entry.getFirst(), op, read_only);
        indirectly_tracked_unknown_access |=
            unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
                                                          read_only);
  // We explicitly iterate through the regions and blocks, in order to handle
  // different nested regions separately.
  for (auto& block : *region) {
    for (auto& op : block) {
      if (op.getNumRegions() > 0) {
        llvm::SmallVector<SideEffectAnalysis, 4> child_analyses;
        for (auto& child_region : op.getRegions()) {
          child_analyses.emplace_back();
          child_analyses.back().AnalyzeRegion(&child_region, alias_analysis);
        }
        ConsumeChildAnalyses(std::move(child_analyses));
      }
    } else {
      for (int64_t resource : resources) {
        AddPredecessorsForAccess(resource, op, read_only);
        indirectly_tracked_unknown_access |=
            unknown_access_indirectly_tracked_by_resource(resource, read_only);
        // Update access info for known resources.
        TrackAccess(resource, op, read_only);
      }
    }
    // If not indirectly tracked, add edges from the unknown resource.
    if (!indirectly_tracked_unknown_access) {
      AddPredecessorsForAccess(kUnknownResourceId, op, read_only);
    }
    if (is_unknown) {
      // Update access info for unknown resource.
      TrackAccess(kUnknownResourceId, op, read_only);
    }
  });

  // Populate control_successors_ based on control_predecessors_.
  for (auto& entry : control_predecessors_) {
    auto op = entry.getFirst();
    for (auto predecessor : entry.getSecond()) {
      control_successors_[predecessor].insert(op);
      // We do not need explicit control edges for declaration ops.
      if (OpIsDeclaration(&op, alias_analysis)) continue;

      auto resource_op_info = GetResourceInfoForOp(&op);
      if (!resource_op_info && op.hasNoSideEffect()) continue;

      llvm::SmallDenseSet<int64_t, 8> resources =
          resource_op_info ? FindAccessedResources(&op, alias_analysis)
                           : UnknownResourceSet();
      assert(!resources.empty());
      const bool is_unknown = resources.count(kUnknownResourceId) > 0;
      const bool read_only = OpIsReadOnly(&op);
      bool indirectly_tracked_unknown_access = false;
      // First add edges from known resources.
      if (is_unknown) {
        for (auto& entry : per_resource_access_info_) {
          if (entry.getFirst() == kUnknownResourceId) continue;
          AddPredecessorsForAccess(entry.getFirst(), &op, read_only);
          indirectly_tracked_unknown_access |=
              unknown_access_indirectly_tracked_by_resource(entry.getFirst(),
                                                            read_only);
        }
      } else {
        for (int64_t resource : resources) {
          AddPredecessorsForAccess(resource, &op, read_only);
          indirectly_tracked_unknown_access |=
              unknown_access_indirectly_tracked_by_resource(resource,
                                                            read_only);
          // Update access info for known resources.
          TrackAccess(resource, &op, read_only);
        }
      }
      // If not indirectly tracked, add edges from the unknown resource.
      if (!indirectly_tracked_unknown_access) {
        AddPredecessorsForAccess(kUnknownResourceId, &op, read_only);
      }
      if (is_unknown) {
        // Update access info for unknown resource.
        TrackAccess(kUnknownResourceId, &op, read_only);
      }
    }
  }
}

llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlPredecessors(
void SideEffectAnalysis::ConsumeChildAnalyses(
    llvm::SmallVector<SideEffectAnalysis, 4>&& children) {
  for (auto& child : children) {
    for (auto& entry : child.control_predecessors_) {
      control_predecessors_[entry.getFirst()] = std::move(entry.getSecond());
    }
  }
}

llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlPredecessors(
    Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
  llvm::SmallVector<Operation*, 8> result;
  auto it = control_predecessors_.find(op);
  if (it == control_predecessors_.end()) return result;
  llvm::SmallVector<Operation*, 4> result;
  auto it = sorted_control_predecessors_.find(op);
  if (it == sorted_control_predecessors_.end()) return result;
  result.reserve(it->getSecond().size());
  for (auto predecessor : it->getSecond()) {
    if (!filter || filter(predecessor)) result.push_back(predecessor);
  }
  llvm::sort(result,
             [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
  return result;
}

llvm::SmallVector<Operation*, 8> SideEffectAnalysis::DirectControlSuccessors(
llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlSuccessors(
    Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
  llvm::SmallVector<Operation*, 8> result;
  auto it = control_successors_.find(op);
  if (it == control_successors_.end()) return result;
  llvm::SmallVector<Operation*, 4> result;
  auto it = sorted_control_successors_.find(op);
  if (it == sorted_control_successors_.end()) return result;
  result.reserve(it->getSecond().size());
  for (auto successor : it->getSecond()) {
    if (!filter || filter(successor)) result.push_back(successor);
  }
  llvm::sort(result,
             [](Operation* a, Operation* b) { return a->isBeforeInBlock(b); });
  return result;
}
@@ -32,6 +32,9 @@ namespace TF {

// An analysis that runs on a function and maps each resource-type value to a
// set of unique int64_t IDs representing the possible resources it could alias.
//
// If there are nested regions, each region is handled separately. This means
// cross-region aliasing cannot be checked by this analysis.
class ResourceAliasAnalysis {
 public:
  explicit ResourceAliasAnalysis(Operation* op);
@@ -63,8 +66,12 @@ class ResourceAliasAnalysis {
// interfering with all known resource op accesses. It distinguishes accesses
// based on whether they are read-only, and read-only ops do not interfere with
// each other.
//
// If there are nested regions, each region is handled separately, and control
// dependencies are only tracked for ops under the same parent op.
class SideEffectAnalysis {
 public:
  explicit SideEffectAnalysis() = default;
  explicit SideEffectAnalysis(Operation* op);
  SideEffectAnalysis(SideEffectAnalysis&& other) = default;
  ~SideEffectAnalysis() = default;
@@ -72,23 +79,32 @@ class SideEffectAnalysis {
  // Returns a vector of ops that are direct control predecessors of `op`,
  // sorted in program order. If `filter` is provided, only predecessors that
  // pass the filter (returning true) will be included.
  llvm::SmallVector<Operation*, 8> DirectControlPredecessors(
  llvm::SmallVector<Operation*, 4> DirectControlPredecessors(
      Operation* op,
      llvm::function_ref<bool(Operation*)> filter = nullptr) const;

  // Returns a vector of ops that are direct control successors of `op`, sorted
  // in program order. If `filter` is provided, only successors that pass the
  // filter (returning true) will be included.
  llvm::SmallVector<Operation*, 8> DirectControlSuccessors(
  llvm::SmallVector<Operation*, 4> DirectControlSuccessors(
      Operation* op,
      llvm::function_ref<bool(Operation*)> filter = nullptr) const;

 private:
  // Runs the analysis on `func_op` and populates control_predecessors_ and
  // control_successors_.
  // Runs the analysis on `func_op` and populates sorted_control_predecessors_
  // and sorted_control_successors_.
  void AnalyzeFunction(FuncOp func_op,
                       const ResourceAliasAnalysis& alias_analysis);

  // Runs the analysis on `region` and populates control_predecessors_.
  void AnalyzeRegion(Region* region,
                     const ResourceAliasAnalysis& alias_analysis);

  // Moves the control_predecessors_ fields in `children` analyses to this
  // current analysis.
  void ConsumeChildAnalyses(
      llvm::SmallVector<SideEffectAnalysis, 4>&& children);

  // Updates control_predecessors_ for `op` that is being visited, on the given
  // `resource_id`.
  void AddPredecessorsForAccess(int64_t resource_id, Operation* op,
@@ -98,11 +114,14 @@ class SideEffectAnalysis {
  void TrackAccess(int64_t resource_id, Operation* op, bool read_only);

  // Maps from an op to its control predecessors.
  llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
  llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 4>, 8>
      control_predecessors_;
  // Maps from an op to its control successors.
  llvm::SmallDenseMap<Operation*, llvm::SmallPtrSet<Operation*, 8>, 8>
      control_successors_;
  // Maps from an op to its control predecessors sorted in program order.
  llvm::SmallDenseMap<Operation*, llvm::SmallVector<Operation*, 4>, 8>
      sorted_control_predecessors_;
  // Maps from an op to its control successors sorted in program order.
  llvm::SmallDenseMap<Operation*, llvm::SmallVector<Operation*, 4>, 8>
      sorted_control_successors_;

  // Internal per-resource data structure when we build the dependencies.
  struct PerResourceAcessInfo {
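Since the header above only declares the query API, here is a hypothetical usage sketch. The include path, the `mlir::TF` namespace nesting, and the `ReportControlDeps` helper are assumptions; only the constructor and `DirectControlPredecessors` declared above are used:

```
// Hypothetical usage sketch; include path and namespaces are assumptions.
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"

// Reports, for each op in `func`, how many control predecessors the analysis
// says it must wait on before it may execute.
void ReportControlDeps(mlir::FuncOp func) {
  mlir::TF::SideEffectAnalysis analysis(func.getOperation());
  func.walk([&](mlir::Operation* op) {
    auto preds = analysis.DirectControlPredecessors(op);
    if (!preds.empty())
      op->emitRemark() << "waits on " << preds.size() << " ops";
  });
}
```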
@@ -332,8 +332,7 @@ struct DropEmptyLaunch : public OpRewritePattern<LaunchOp> {
    if (&block.front() != &block.back()) return matchFailure();

    // Map launch results to return operands.
    llvm::SmallVector<Value*, 8> new_rets(block.front().getOperands());
    rewriter.replaceOp(op, new_rets);
    rewriter.replaceOp(op, block.front().getOperands());

    return matchSuccess();
  }
@@ -408,8 +408,7 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
    if (!wrapped_op) return failure();
    OpBuilder builder(parser.getBuilder().getContext());
    builder.setInsertionPointToEnd(&block);
    builder.create<YieldOp>(wrapped_op->getLoc(),
                            llvm::to_vector<8>(wrapped_op->getResults()));
    builder.create<YieldOp>(wrapped_op->getLoc(), wrapped_op->getResults());
    result.location = wrapped_op->getLoc();
  } else if (parser.parseRegion(body, llvm::None, llvm::None)) {
    return failure();
@@ -1065,8 +1064,7 @@ struct DropEmptyGraph : public OpRewritePattern<GraphOp> {
    if (&block.front() != &block.back()) return matchFailure();

    // Map graph results to fetch operands.
    llvm::SmallVector<Value *, 8> new_rets(op.GetFetch().fetches());
    rewriter.replaceOp(op, new_rets);
    rewriter.replaceOp(op, op.GetFetch().fetches());

    return matchSuccess();
  }
@@ -94,6 +94,8 @@ Inputs must be of same size and shape.

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;

  let hasFolder = 1;
}

def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>,
@@ -143,6 +145,8 @@ retained with length 1.
  );

  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;

  let verifier = [{ return Verify(*this); }];
}

def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> {
@@ -169,6 +173,8 @@ retained with length 1.
  );

  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;

  let verifier = [{ return Verify(*this); }];
}

def TF_ArgMaxOp : TF_Op<"ArgMax", [NoSideEffect]> {
@@ -2116,6 +2122,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True]
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_HashTableV2Op : TF_Op<"HashTableV2", []> {
  let summary = "Creates a non-initialized hash table.";

  let description = [{
This op creates a hash table, specifying the type of its keys and values.
Before using the table you will have to initialize it. After initialization the
table will be immutable.
  }];

  let arguments = (ins
    StrAttr:$container,
    StrAttr:$shared_name,
    DefaultValuedAttr<BoolAttr, "false">:$use_node_name_sharing,
    TypeAttr:$key_dtype,
    TypeAttr:$value_dtype
  );

  let results = (outs
    TF_ResourceTensor:$table_handle
  );
}

def TF_IdentityNOp : TF_Op<"IdentityN", [NoSideEffect]> {
  let summary = [{
Returns a list of tensors with the same shapes and contents as the input
@@ -2473,7 +2501,7 @@ def TF_LogicalAndOp : TF_Op<"LogicalAnd", [Broadcastable, Commutative, NoSideEff
}

def TF_LogicalNotOp : TF_Op<"LogicalNot", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Returns the truth value of NOT x element-wise.";
  let summary = "Returns the truth value of `NOT x` element-wise.";

  let description = [{
  }];
@@ -4334,6 +4362,37 @@ Resize `images` to `size` using nearest neighbor interpolation.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
  let summary = "Update '*var' according to the Adam algorithm.";

  let description = [{
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
  }];

  let arguments = (ins
    TF_ResourceTensor:$var,
    TF_ResourceTensor:$m,
    TF_ResourceTensor:$v,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1_power,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2_power,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta1,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$beta2,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,

    DefaultValuedAttr<BoolAttr, "false">:$use_locking,
    DefaultValuedAttr<BoolAttr, "false">:$use_nesterov
  );

  let results = (outs);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
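The following is a plain C++ transcription of the four update rules in the description above, assuming dense float tensors flattened to vectors; it is an illustrative sketch, not the kernel TensorFlow actually runs:

```
#include <cmath>
#include <cstddef>
#include <vector>

// Parameter names follow the op's operands (var, m, v, grad, lr, beta1, ...).
void ApplyAdam(std::vector<float>& var, std::vector<float>& m,
               std::vector<float>& v, const std::vector<float>& grad,
               float lr, float beta1, float beta2, float epsilon,
               float beta1_power, float beta2_power) {
  // lr_t := learning_rate * sqrt(1 - beta_2^t) / (1 - beta_1^t)
  const float lr_t = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  for (std::size_t i = 0; i < var.size(); ++i) {
    // m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
    m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
    // v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
    v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
    // variable := variable - lr_t * m_t / (sqrt(v_t) + epsilon)
    var[i] -= lr_t * m[i] / (std::sqrt(v[i]) + epsilon);
  }
}
```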
def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> {
  let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it.";

@@ -4353,6 +4412,34 @@ def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}

def TF_ResourceApplyKerasMomentumOp : TF_Op<"ResourceApplyKerasMomentum", []> {
  let summary = [{
Update '*var' according to the momentum scheme.
  }];

  let description = [{
Set use_nesterov = True if you want to use Nesterov momentum.

accum = accum * momentum - lr * grad
var += accum
  }];

  let arguments = (ins
    TF_ResourceTensor:$var,
    TF_ResourceTensor:$accum,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum,

    DefaultValuedAttr<BoolAttr, "false">:$use_locking,
    DefaultValuedAttr<BoolAttr, "false">:$use_nesterov
  );

  let results = (outs);

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
}
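The same kind of transcription for the momentum pseudo-code in the description above, covering only the `use_nesterov = false` case it spells out; again an illustrative sketch, not TensorFlow's actual kernel:

```
#include <cstddef>
#include <vector>

void ApplyKerasMomentum(std::vector<float>& var, std::vector<float>& accum,
                        const std::vector<float>& grad, float lr,
                        float momentum) {
  for (std::size_t i = 0; i < var.size(); ++i) {
    // accum = accum * momentum - lr * grad
    accum[i] = accum[i] * momentum - lr * grad[i];
    // var += accum
    var[i] += accum[i];
  }
}
```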
def TF_ReverseSequenceOp : TF_Op<"ReverseSequence", [NoSideEffect]> {
  let summary = "Reverses variable length slices.";

@@ -5117,6 +5204,8 @@ def TF_SplitVOp : TF_Op<"SplitV", [NoSideEffect]> {
  TF_DerivedOperandTypeAttr Tlen = TF_DerivedOperandTypeAttr<1>;
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedResultSizeAttr num_split = TF_DerivedResultSizeAttr<0>;

  let verifier = [{ return Verify(*this); }];
}

def TF_SqrtOp : TF_Op<"Sqrt", [NoSideEffect, SameOperandsAndResultType]> {
@@ -5491,6 +5580,65 @@ output. For the internal use of the distributed TPU compiler.
  TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>;
}

def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> {
  let summary = "Connects N inputs to an N-way replicated TPU computation.";

  let description = [{
This operation holds a replicated input to a `tpu.replicate()` computation subgraph.
Each replicated input has the same shape and type alongside the output.

For example:
```
%a = "tf.opA"()
%b = "tf.opB"()
%replicated_input = "tf.TPUReplicatedInput"(%a, %b)
%computation = "tf.Computation"(%replicated_input)
```
The above computation has a replicated input of two replicas.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$inputs,

    DefaultValuedAttr<BoolAttr, "false">:$is_mirrored_variable,
    DefaultValuedAttr<I64Attr, "-1">:$index
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}

def TF_TPUReplicatedOutputOp : TF_Op<"TPUReplicatedOutput", [NoSideEffect]> {
  let summary = "Connects N outputs from an N-way replicated TPU computation.";

  let description = [{
This operation holds a replicated output from a `tpu.replicate()` computation subgraph.
Each replicated output has the same shape and type alongside the input.

For example:
```
%computation = "tf.Computation"()
%replicated_output:2 = "tf.TPUReplicatedOutput"(%computation)
```
The above computation has a replicated output of two replicas.
  }];

  let arguments = (ins
    TF_Tensor:$input
  );

  let results = (outs
    Variadic<TF_Tensor>:$outputs
  );

  TF_DerivedResultSizeAttr num_replicas = TF_DerivedResultSizeAttr<0>;
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes hyperbolic tangent of `x` element-wise.";

@@ -5905,6 +6053,8 @@ This is the opposite of `pack`.

  TF_DerivedResultSizeAttr num = TF_DerivedResultSizeAttr<0>;
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let verifier = [{ return Verify(*this); }];
}

def TF_VariableShapeOp : TF_Op<"VariableShape", []> {
@ -301,6 +301,15 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<AddToAddV2>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddNOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AddNOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (operands.size() == 1) return *inputs().begin();
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -310,6 +319,49 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies an reduction op's `input` and reduction `dims`.
|
||||
static LogicalResult VerifyReductionInputAndDims(Value *input, Value *dims,
|
||||
Location loc) {
|
||||
auto dims_type = dims->getType().dyn_cast<RankedTensorType>();
|
||||
if (!dims_type) return success();
|
||||
if (dims_type.getRank() > 1)
|
||||
return emitError(loc, "dimensions can only be 0D or 1D tensor");
|
||||
|
||||
auto input_type = input->getType().dyn_cast<RankedTensorType>();
|
||||
if (!input_type) return success();
|
||||
int64_t rank = input_type.getRank();
|
||||
|
||||
DenseIntElementsAttr dims_attr;
|
||||
if (!matchPattern(dims, m_Constant(&dims_attr))) return success();
|
||||
for (const auto &dim_pair : llvm::enumerate(dims_attr)) {
|
||||
int64_t cur_dim = dim_pair.value().getSExtValue();
|
||||
if (cur_dim < -rank || cur_dim >= rank)
|
||||
return emitError(loc)
|
||||
<< dim_pair.index() << "-th dimension should be in the range of [-"
|
||||
<< rank << ", " << rank << ")";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult Verify(AllOp op) {
|
||||
return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
|
||||
op.getLoc());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AnyOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(AnyOp op) {
|
||||
return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
|
||||
op.getLoc());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AssertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1542,17 +1594,23 @@ static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
|
||||
// SplitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(SplitOp op) {
|
||||
// Verifies the input and split dimension operands for tf.Split/tf.SplitV.
|
||||
// Writes the split dimension's index (adjusted with input rank) via `dim_index`
|
||||
// if it's a constant.
|
||||
template <class Op>
|
||||
LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
|
||||
*dim_index = llvm::None;
|
||||
|
||||
Value *split_dim = op.split_dim();
|
||||
auto split_dim_type = split_dim->getType().dyn_cast<RankedTensorType>();
|
||||
if (!split_dim_type) return success();
|
||||
if (split_dim_type.getRank() != 0)
|
||||
return op.emitOpError("split dimension should be an integer scalar tensor");
|
||||
if (auto split_dim_type = split_dim->getType().dyn_cast<RankedTensorType>())
|
||||
if (split_dim_type.getRank() != 0)
|
||||
return op.emitOpError(
|
||||
"split dimension should be an integer scalar tensor");
|
||||
|
||||
// We can perform further verification if the input tensor to be split has
|
||||
// known rank and the split dimension tensor is a constant.
|
||||
|
||||
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
||||
auto input_type = op.value()->getType().template dyn_cast<RankedTensorType>();
|
||||
if (!input_type) return success();
|
||||
|
||||
int64_t input_rank = input_type.getRank();
|
||||
@ -1562,21 +1620,95 @@ static LogicalResult Verify(SplitOp op) {
|
||||
DenseIntElementsAttr split_dim_attr;
|
||||
if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success();
|
||||
|
||||
int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
|
||||
int64_t index = (*split_dim_attr.begin()).getSExtValue();
|
||||
|
||||
if (dim_index + input_rank < 0 || dim_index >= input_rank) {
|
||||
if (index + input_rank < 0 || index >= input_rank) {
|
||||
return op.emitOpError("split dimension must be in range [-")
|
||||
<< input_rank << ", " << input_rank << ")";
|
||||
}
|
||||
|
||||
if (dim_index < 0) dim_index += input_rank;
|
||||
if (index < 0) index += input_rank;
|
||||
*dim_index = index;
|
||||
|
||||
int64_t input_dim_size = input_type.getDimSize(dim_index);
|
||||
if (input_dim_size < 0) return success();
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult Verify(SplitOp op) {
|
||||
Optional<int64_t> dim_index;
|
||||
if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
|
||||
if (!dim_index) return success();
|
||||
|
||||
int64_t input_dim_size =
|
||||
op.value()->getType().cast<RankedTensorType>().getDimSize(*dim_index);
|
||||
if (input_dim_size == ShapedType::kDynamicSize) return success();
|
||||
|
||||
if (input_dim_size % op.getNumResults() != 0)
|
||||
return op.emitOpError("dimension #")
|
||||
<< dim_index << " not divisible by the number of result tensors";
|
||||
<< *dim_index << " not divisible by the number of result tensors";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplitVOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(SplitVOp op) {
|
||||
auto split_sizes_type =
|
||||
op.size_splits()->getType().dyn_cast<RankedTensorType>();
|
||||
if (!split_sizes_type) return success();
|
||||
|
||||
if (split_sizes_type.getRank() != 1 ||
|
||||
split_sizes_type.getDimSize(0) != op.getNumResults())
|
||||
return op.emitOpError("split sizes should be a 1D tensor of ")
|
||||
<< op.getNumResults() << " elements";
|
||||
|
||||
Optional<int64_t> dim_index = 0;
|
||||
if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure();
|
||||
if (!dim_index) return success();
|
||||
|
||||
int64_t input_dim_size =
|
||||
op.value()->getType().cast<RankedTensorType>().getDimSize(*dim_index);
|
||||
if (input_dim_size == ShapedType::kDynamicSize) return success();
|
||||
|
||||
// If split sizes come from a constant, they must sum to the dimension size
|
||||
// along split_dim, and we can have no more than one dynamic dimension.
|
||||
DenseIntElementsAttr split_sizes_attr;
|
||||
if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
|
||||
return success();
|
||||
|
||||
int64_t total_dim_size = 0; // Total dimension size assigned to splits
|
||||
llvm::Optional<int> dynamic_dim_index;
|
||||
|
||||
SmallVector<int64_t, 4> split_sizes;
|
||||
split_sizes.reserve(
|
||||
split_sizes_attr.getType().cast<ShapedType>().getNumElements());
|
||||
|
||||
for (auto dim : llvm::enumerate(split_sizes_attr)) {
|
||||
int64_t dim_val = dim.value().getSExtValue();
|
||||
split_sizes.push_back(dim_val);
|
||||
if (dim_val == ShapedType::kDynamicSize) {
|
||||
// We cannot have more than one dynamic dimension.
|
||||
if (dynamic_dim_index)
|
||||
return op.emitOpError(
|
||||
"cannot have more than one dynamic dimension in split sizes");
|
||||
dynamic_dim_index = dim.index();
|
||||
} else {
|
||||
total_dim_size += dim_val;
|
||||
}
|
||||
}
|
||||
|
||||
if (!dynamic_dim_index && total_dim_size != input_dim_size)
|
||||
return op.emitOpError(
|
||||
"split sizes must sum up to the dimension size along split "
|
||||
"dimension, found ")
|
||||
<< total_dim_size << " vs " << input_dim_size;
|
||||
|
||||
if (dynamic_dim_index && total_dim_size > input_dim_size)
|
||||
return op.emitOpError(
|
||||
"split sizes must sum up to be less than or equal to the "
|
||||
"dimension size along split dimension, found ")
|
||||
<< total_dim_size << " vs " << input_dim_size;
|
||||
|
||||
return success();
|
||||
}
|
||||
@ -1787,6 +1919,30 @@ void TruncateDivOp::getCanonicalizationPatterns(
  results.insert<TruncateDivWithSqrtDivisor>(context);
}

//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(UnpackOp op) {
  auto value_type = op.value()->getType().dyn_cast<RankedTensorType>();
  if (!value_type) return success();

  int64_t value_rank = value_type.getRank();
  int64_t axis = op.axis().getSExtValue();
  if (axis < -value_rank || axis >= value_rank)
    return op.emitOpError("axis attribute must be in the range of [-")
           << value_rank << ", " << value_rank << ')';

  axis = GetDimForAxis(axis, value_rank);
  int64_t dim_size = value_type.getDimSize(axis);
  if (ShapedType::isDynamic(dim_size)) return success();

  if (dim_size != op.getNumResults())
    return op.emitOpError("result count must be equal to ") << dim_size;

  return success();
}
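A short sketch (ours, not from the diff) of the two checks above on a rank-2 input; the %input name is hypothetical:

// Accepted: axis 0 of tensor<3x2xf32> has size 3, matching the 3 results.
%u:3 = "tf.Unpack"(%input) {axis = 0 : i64}
    : (tensor<3x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
// Rejected: axis = 2 falls outside [-2, 2) for a rank-2 input, producing
// "axis attribute must be in the range of [-2, 2)".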

//===----------------------------------------------------------------------===//
// VariableShapeOp
//===----------------------------------------------------------------------===//

@ -196,6 +196,42 @@ retained with length 1.
  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}

def TF_LegacyCallOp : TF_Op<"LegacyCall",
    [CallOpInterface, NoSideEffect]> {
  let summary =
    "returns `f(inputs)`, where `f` is a function.";

  let description = [{
    The LegacyCall operation represents a direct call to a function that is
    within the same symbol scope as the call and is mapped to a GraphDef node
    with the function name as the op name. Unlike a PartitionedCall, which
    represents asynchronously executing a function across multiple devices, a
    LegacyCall represents a function call whose only attribute is
    _disable_call_shape_inference.
  }];

  let arguments = (ins
    Variadic<TF_Tensor>:$args,

    FlatSymbolRefAttr:$f,
    DefaultValuedAttr<BoolAttr, "false">:$_disable_call_shape_inference
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() {
      return getAttrOfType<SymbolRefAttr>("f");
    }
  }];
}
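As an illustrative sketch (the @my_func symbol is made up, not from the patch), a LegacyCall as defined above appears in MLIR as:

// Direct call to a function in the same symbol scope; exported to a GraphDef
// node whose op name is "my_func".
%res = "tf.LegacyCall"(%x) {f = @my_func, _disable_call_shape_inference = false}
    : (tensor<i32>) -> tensor<i32>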

def TF_PartitionedCallOp : TF_Op<"PartitionedCall",
    [CallOpInterface, NoSideEffect]> {
  let summary =
@ -18,7 +18,7 @@ func @multiple_return(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32
// CHECK-LABEL: func @multiple_return
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1
@ -41,7 +41,12 @@ func @multiple_islands(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi3
      %res = "tf.Print"(%sub) { message = "sub result" } : (tensor<*xi32>) -> (tensor<*xi32>)
      tf_executor.yield
    }
    tf_executor.fetch %island1#1, %island2#1, %island3 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control
    %island4 = tf_executor.island(%island1#2, %island2#2) {
      %add = "tf.Add"(%island1#1, %island1#1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
      %res = "tf.Print"(%add) { message = "add result" } : (tensor<*xi32>) -> (tensor<*xi32>)
      tf_executor.yield
    }
    tf_executor.fetch %island1#1, %island2#1, %island3, %island4 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control, !tf_executor.control
  }
  return %graph#0, %graph#1 : tensor<*xi32>, tensor<*xi32>
}
@ -49,12 +54,17 @@ func @multiple_islands(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi3
// CHECK-LABEL: func @multiple_islands
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[SUB1:.*]], %[[SUB1_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Sub"(%arg0, %arg1)
// CHECK: %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island(%[[SUB1_control]]) wraps "tf.Mul"(%[[SUB1]], %arg1)
// CHECK: %[[MUL:.*]], %[[MUL_control:.*]] = tf_executor.island wraps "tf.Mul"(%[[SUB1]], %arg1)
// CHECK: %[[SUB2:.*]], %[[SUB2_control:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) wraps "tf.Sub"(%[[ADD1]], %[[SUB1]])
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island(%[[SUB2_control]]) wraps "tf.Print"(%[[SUB2]]) {message = "sub result"}
// CHECK: tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT_control]] :
// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[SUB2]]) {message = "sub result"}
// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ADD2_control]], %[[MUL_control]]) {
// CHECK: tf_executor.yield
// CHECK: }
// CHECK: %[[ADD3:.*]], %[[ADD3_control:.*]] = tf_executor.island(%[[ISLAND1]], %[[ADD2_control]]) wraps "tf.Add"(%[[ADD2]], %[[ADD2]])
// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD3]]) {message = "add result"}
// CHECK: tf_executor.fetch %[[ADD2]], %[[MUL]], %[[PRINT1_control]], %[[PRINT2_control:.*]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1

@ -74,8 +84,8 @@ func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>
// CHECK-LABEL: func @dangling_print
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1_control:.*]], %arg1)
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"}
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1_control:.*]], %arg1)
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2_control:.*]]) {message = "add result"}
// CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]], %[[PRINT_control]] :
// CHECK: }
// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1
@ -103,11 +113,14 @@ func @switch_and_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi3
// CHECK-LABEL: func @switch_and_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<i32>) {
// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Less"(%arg1, %arg1)
// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island(%[[LESS_control]]) wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"}
// CHECK: %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[PRINT1_control]]
// CHECK: %[[LESS:.*]], %[[LESS_control:.*]] = tf_executor.island wraps "tf.Less"(%arg1, %arg1)
// CHECK: %[[PRINT1:.*]], %[[PRINT1_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD1]]) {message = "add result 1"}
// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[LESS_control]], %[[PRINT1_control]]) {
// CHECK: tf_executor.yield
// CHECK: }
// CHECK: %[[SWITCH_false:.*]], %[[SWITCH_true:.*]], {{.*}} = tf_executor.Switch %[[ADD1]], %[[LESS]], %[[ISLAND1]]
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[SWITCH_false]], %arg1)
// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island(%[[ADD2_control]]) wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"}
// CHECK: %[[PRINT2:.*]], %[[PRINT2_control:.*]] = tf_executor.island wraps "tf.Print"(%[[ADD2]]) {message = "add result 2"}
// CHECK: %[[MERGE:.*]], %[[MERGE_index:.*]], %{{.*}} = tf_executor.Merge %[[ADD2]], %[[SWITCH_true]], %[[PRINT2_control]]
// CHECK: tf_executor.fetch %[[MERGE]], %[[MERGE_index]]
// CHECK: }
@ -130,7 +143,7 @@ func @control_flow_plumbing(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> tensor<
// CHECK: %[[GRAPH:.*]] = tf_executor.graph {
// CHECK: %[[PRINT:.*]], %[[PRINT_control:.*]] = tf_executor.island wraps "tf.Print"(%arg0) {message = "Random Print"}
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island(%[[PRINT_control]]) wraps "tf.Add"(%arg0, %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1)
// CHECK: tf_executor.fetch %[[ADD2]] : tensor<*xi32>
// CHECK: }
// CHECK: return %[[GRAPH]] : tensor<*xi32>
@ -150,6 +163,77 @@ func @fetching_arg(%arg0: tensor<*xi32>) {
// CHECK-LABEL: func @fetching_arg
// CHECK: tf_executor.graph {
// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg0)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island(%[[ADD1_control]]) wraps "tf.Add"(%[[ADD1]], %arg0)
// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg0)
// CHECK: tf_executor.fetch %[[ADD2_control]] : !tf_executor.control
// CHECK: }

func @non_aliasing_reads_writes(
    %arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
    %arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
    %arg2: tensor<32xf32>) -> (tensor<32xf32>) {
  %graph = tf_executor.graph {
    %island:2 = tf_executor.island {
      %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
      %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      %read2 = "tf.ReadVariableOp"(%var_handle) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      "tf.AssignVariableOp"(%arg1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
      "tf.AssignVariableOp"(%arg0, %read2) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
      %read3 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      tf_executor.yield %read3 : tensor<32xf32>
    }
    tf_executor.fetch %island#0 : tensor<32xf32>
  }
  return %graph : tensor<32xf32>
}

// CHECK-LABEL: func @non_aliasing_reads_writes
// CHECK: %[[GRAPH:.*]] = tf_executor.graph {
// CHECK: %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg0)
// CHECK: %[[ASSIGN0_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %arg2)
// CHECK: %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%arg1)
// CHECK: %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
// CHECK: %[[READ2:.*]], %[[READ2_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
// CHECK: %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%arg1, %[[READ0:.*]])
// CHECK: %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[ASSIGN0_CONTROL]]) wraps "tf.AssignVariableOp"(%arg0, %[[READ2]])
// CHECK: %[[READ3:.*]], %[[READ3_CONTROL:.*]] = tf_executor.island(%[[ASSIGN2_CONTROL]]) wraps "tf.ReadVariableOp"(%arg0)
// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ASSIGN1_CONTROL]], %[[READ3_CONTROL]]) {
// CHECK: tf_executor.yield
// CHECK: }
// CHECK: tf_executor.fetch %[[READ3]], %[[ISLAND1]] : tensor<32xf32>, !tf_executor.control
// CHECK: }

func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
  tf_executor.graph {
    %island = tf_executor.island {
      %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      %read0 = "tf.ReadVariableOp"(%vh0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      "tf.AssignVariableOp"(%vh1, %arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
      "tf._UnknownSideEffectingOp_"() : () -> ()
      %read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      "tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
      "tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
      tf_executor.yield
    }
    tf_executor.fetch %island : !tf_executor.control
  }
  return
}

// CHECK-LABEL: func @unknown_side_effecting_op
// CHECK: tf_executor.graph {
// CHECK: %[[VH0:.*]], %[[VH0_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v0"}
// CHECK: %[[VH1:.*]], %[[VH1_CONTROL:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "v1"}
// CHECK: %[[READ0:.*]], %[[READ0_CONTROL:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[VH0]])
// CHECK: %[[ASSIGN0_CONTROL:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VH1]], %arg0)
// CHECK: %[[UNKNOWN_CONTROL:.*]] = tf_executor.island(%[[READ0_CONTROL]], %[[ASSIGN0_CONTROL]]) wraps "tf._UnknownSideEffectingOp_"()
// CHECK: %[[READ1:.*]], %[[READ1_CONTROL:.*]] = tf_executor.island(%[[UNKNOWN_CONTROL]]) wraps "tf.ReadVariableOp"(%[[VH1]])
// CHECK: %[[ASSIGN1_CONTROL:.*]] = tf_executor.island(%[[UNKNOWN_CONTROL]]) wraps "tf.AssignVariableOp"(%[[VH0]], %[[READ1]])
// CHECK: %[[ASSIGN2_CONTROL:.*]] = tf_executor.island(%[[READ1_CONTROL]]) wraps "tf.AssignVariableOp"(%[[VH1]], %[[READ0]])
// CHECK: %[[ISLAND1:.*]] = tf_executor.island(%[[ASSIGN1_CONTROL]], %[[ASSIGN2_CONTROL]]) {
// CHECK: tf_executor.yield
// CHECK: }
// CHECK: tf_executor.fetch %[[ISLAND1]] : !tf_executor.control
// CHECK: }
@ -382,3 +382,10 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32
  // CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x6x5xf32>
  // CHECK: return %1
}

// CHECK-LABEL: func @addN
func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: return %arg0
  %0 = "tf.AddN"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

@ -0,0 +1,67 @@
// RUN: tf-opt %s -split-input-file -tf-device-decompose-resource-ops | FileCheck %s

// -----

// Tests that composite tf.AssignAddVariableOp operation is decomposed and
// hoisted.

// CHECK-LABEL: func @decompose_assign_add_variable_op
func @decompose_assign_add_variable_op() -> () {

  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>

  // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
  // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
  // CHECK: "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
  // CHECK: "tf.AssignVariableOp"

  %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  "tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()

  return
}
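An editorial note, not part of the test: in equation form the decomposition above rewrites tf.AssignAddVariableOp(resource, x) as a read-modify-write, value = tf.ReadVariableOp(resource); tf.AssignVariableOp(resource, tf.AddV2(value, x)).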

// -----

// Tests that composite tf.AssignSubVariableOp operation is decomposed using
// SubOp.

// CHECK-LABEL: func @decompose_assign_sub_variable_op
func @decompose_assign_sub_variable_op() -> () {

  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>

  // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
  // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
  // CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]])
  // CHECK: "tf.AssignVariableOp"

  %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  "tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()

  return
}

// -----

// Tests that composite tf.ResourceApplyGradientDescent operation is decomposed.

// CHECK-LABEL: func @decompose_resource_apply_gradient_descent
// CHECK-SAME: (%[[DELTA:.*]]: tensor<f32>)
func @decompose_resource_apply_gradient_descent(%arg0: tensor<f32>) -> () {

  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>

  // CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
  // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
  // CHECK: %[[MUL:[0-9]*]] = "tf.Mul"(%[[DELTA]], %[[ALPHA]])
  // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
  // CHECK: %[[SUB:[0-9]*]] = "tf.Sub"(%[[RES_READ_VAL]], %[[MUL]])
  // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[SUB]])

  %1 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
  "tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()

  return
}
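For reference (editorial, not from the test): the CHECK sequence above is the plain gradient-descent update var_new = var - alpha * delta, with the product materialized by tf.Mul and the update applied by tf.Sub followed by tf.AssignVariableOp.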
@ -49,40 +49,33 @@ func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>,

// -----

func @testIf1Then(tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
func @testIf1Else(tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32>
func @testIfThen(%arg0: tensor<!tf.variant>) -> tensor<!tf.variant> {
  return %arg0 : tensor<!tf.variant>
}
func @testIfElse(%arg0: tensor<!tf.variant>) -> tensor<!tf.variant> {
  return %arg0 : tensor<!tf.variant>
}

// CHECK-LABEL: func @testIf1Casts(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>)
func @testIf1Casts(tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32> {
^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):

  %1 = "tf.If"(%arg0, %arg1, %arg2) {
    then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
  } : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>

  // CHECK: %0 = extract_element %arg0[] : tensor<i1>
  // CHECK: cond_br %0, ^bb1, ^bb2
  // CHECK: ^bb1: // pred: ^bb0
  // CHECK: %1 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<2x?xf32>
  // CHECK: %2 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x2xf32>
  // CHECK: %3 = call @testIf1Then(%1, %2) : (tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: %4 = tensor_cast %3 : tensor<2x2xf32> to tensor<2x?xf32>
  // CHECK: br ^bb3(%4 : tensor<2x?xf32>)

  // CHECK: ^bb2: // pred: ^bb0
  // CHECK: %5 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<*xf32>
  // CHECK: %6 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x?xf32>
  // CHECK: %7 = call @testIf1Else(%5, %6) : (tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32>
  // CHECK: %8 = tensor_cast %7 : tensor<*xf32> to tensor<2x?xf32>
  // CHECK: br ^bb3(%8 : tensor<2x?xf32>)

  // CHECK: ^bb3(%9: tensor<2x?xf32>): // 2 preds: ^bb1, ^bb2

  %2 = "tf.Add"(%1, %1) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
  // CHECK: %10 = "tf.Add"(%9, %9) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>

  return %2 : tensor<2x?xf32>
  // CHECK: return %10 : tensor<2x?xf32>
// CHECK-LABEL: func @testIfCasts(%arg0: tensor<i1>, %arg1: tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>>
func @testIfCasts(%arg0: tensor<i1>, %arg1: tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>> {
  %0 = "tf.If"(%arg0, %arg1) {
    then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false
  } : (tensor<i1>, tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant<tensor<f32>>>
  return %0: tensor<!tf.variant<tensor<f32>>>
  // CHECK: %0 = extract_element %arg0[] : tensor<i1>
  // CHECK: cond_br %0, ^bb1, ^bb2
  // CHECK: ^bb1:
  // CHECK: %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant>
  // CHECK: %2 = call @testIfThen(%1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
  // CHECK: %3 = "tf.Cast"(%2) {Truncate = false} : (tensor<!tf.variant>) -> tensor<!tf.variant<tensor<f32>>>
  // CHECK: br ^bb3(%3 : tensor<!tf.variant<tensor<f32>>>)
  // CHECK: ^bb2:
  // CHECK: %4 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<!tf.variant<tensor<f32>>>) -> tensor<!tf.variant>
  // CHECK: %5 = call @testIfElse(%4) : (tensor<!tf.variant>) -> tensor<!tf.variant>
  // CHECK: %6 = "tf.Cast"(%5) {Truncate = false} : (tensor<!tf.variant>) -> tensor<!tf.variant<tensor<f32>>>
  // CHECK: br ^bb3(%6 : tensor<!tf.variant<tensor<f32>>>)
  // CHECK: ^bb3(%7: tensor<!tf.variant<tensor<f32>>>):
  // CHECK: return %7 : tensor<!tf.variant<tensor<f32>>>
}

// -----
@ -188,31 +181,36 @@ func @testComplexWhile1Result(tensor<*xf32>) -> (tensor<*xf32>) {

// -----

func @testWhileCond(tensor<?x3xf32>) -> (tensor<i1>)
func @testWhileBody(tensor<*xf32>) -> (tensor<?x?xf32>)
func @testWhileCond(%arg0: tensor<!tf.variant>) -> (tensor<i1>) {
  %true = "tf.Const"() { value = dense<true> : tensor<i1> } : () -> (tensor<i1>)
  return %true : tensor<i1>
}
func @testWhileBody(%arg0: tensor<!tf.variant<tensor<1x?xf32>>>) -> (tensor<!tf.variant<tensor<?x?xf32>>>) {
  %0 = "tf.Cast"(%arg0) : (tensor<!tf.variant<tensor<1x?xf32>>>) -> tensor<!tf.variant<tensor<?x?xf32>>>
  return %0 : tensor<!tf.variant<tensor<?x?xf32>>>
}

// CHECK-LABEL: func @testWhileCasts(%arg0: tensor<1x3xf32>)
func @testWhileCasts(%arg0: tensor<1x3xf32>) -> (tensor<?x?xf32>) {
// CHECK-LABEL: func @testWhileCasts(%arg0: tensor<!tf.variant<tensor<1x3xf32>>>) -> tensor<!tf.variant<tensor<*xf32>>>
func @testWhileCasts(%arg0: tensor<!tf.variant<tensor<1x3xf32>>>) -> (tensor<!tf.variant<tensor<*xf32>>>) {
  %0 = "tf.While"(%arg0) {
    cond = @testWhileCond, body = @testWhileBody, is_stateless = false
  } : (tensor<1x3xf32>) -> (tensor<?x?xf32>)

  // CHECK: %0 = tensor_cast %arg0 : tensor<1x3xf32> to tensor<?x3xf32>
  // CHECK: br ^bb1(%0 : tensor<?x3xf32>)
  // CHECK: ^bb1(%1: tensor<?x3xf32>):
  // CHECK: %2 = call @testWhileCond(%1) : (tensor<?x3xf32>) -> tensor<i1>
  } : (tensor<!tf.variant<tensor<1x3xf32>>>) -> (tensor<!tf.variant<tensor<*xf32>>>)
  return %0 : tensor<!tf.variant<tensor<*xf32>>>
  // CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<!tf.variant<tensor<1x3xf32>>>) -> tensor<!tf.variant>
  // CHECK: br ^bb1(%0 : tensor<!tf.variant>)
  // CHECK: ^bb1(%1: tensor<!tf.variant>): // 2 preds: ^bb0, ^bb2
  // CHECK: %2 = call @testWhileCond(%1) : (tensor<!tf.variant>) -> tensor<i1>
  // CHECK: %3 = extract_element %2[] : tensor<i1>
  // CHECK: %4 = tensor_cast %1 : tensor<?x3xf32> to tensor<*xf32>
  // CHECK: cond_br %3, ^bb2(%4 : tensor<*xf32>), ^bb3(%4 : tensor<*xf32>)
  // CHECK: ^bb2(%5: tensor<*xf32>):
  // CHECK: %6 = call @testWhileBody(%5) : (tensor<*xf32>) -> tensor<?x?xf32>
  // CHECK: %7 = tensor_cast %6 : tensor<?x?xf32> to tensor<?x3xf32>
  // CHECK: br ^bb1(%7 : tensor<?x3xf32>)
  // CHECK: ^bb3(%8: tensor<*xf32>):
  // CHECK: %9 = tensor_cast %8 : tensor<*xf32> to tensor<?x?xf32>
  // CHECK: %4 = "tf.Cast"(%1) {Truncate = false} : (tensor<!tf.variant>) -> tensor<!tf.variant<tensor<1x?xf32>>>
  // CHECK: cond_br %3, ^bb2(%4 : tensor<!tf.variant<tensor<1x?xf32>>>), ^bb3(%4 : tensor<!tf.variant<tensor<1x?xf32>>>)
  // CHECK: ^bb2(%5: tensor<!tf.variant<tensor<1x?xf32>>>): // pred: ^bb1
  // CHECK: %6 = call @testWhileBody(%5) : (tensor<!tf.variant<tensor<1x?xf32>>>) -> tensor<!tf.variant<tensor<?x?xf32>>>
  // CHECK: %7 = "tf.Cast"(%6) {Truncate = false} : (tensor<!tf.variant<tensor<?x?xf32>>>) -> tensor<!tf.variant>
  // CHECK: br ^bb1(%7 : tensor<!tf.variant>)
  // CHECK: ^bb3(%8: tensor<!tf.variant<tensor<1x?xf32>>>): // pred: ^bb1
  // CHECK: %9 = "tf.Cast"(%8) {Truncate = false} : (tensor<!tf.variant<tensor<1x?xf32>>>) -> tensor<!tf.variant<tensor<*xf32>>>
  // CHECK: return %9 : tensor<!tf.variant<tensor<*xf32>>>

  return %0 : tensor<?x?xf32>
  // CHECK: return %9 : tensor<?x?xf32>
}

// -----
@ -54,5 +54,5 @@ versions {
# the names are matching between the function definition and the uses / call
# site (a numerical suffix may be appended).

# CHECK: "tf.foo0"(
# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, f = @foo0}
# CHECK: func @foo0
@ -8,7 +8,7 @@
# Verify that we can also pull some attributes that are needed to be able to
# create a Graph in memory, like `T`.
# CHECK: tf.MaxPool
# CHECK-SAME: T = "tfdtype$DT_FLOAT"
# CHECK-SAME: T = f32

node {
  name: "input"
@ -0,0 +1,65 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=x -tf-input-data-types=DT_INT32 -tf-input-shapes=10 -tf-output-arrays=func_call -o - | FileCheck %s

node {
  name: "x"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "func_call"
  op: "test_func_name"
  input: "x"
  attr {
    key: "_disable_call_shape_inference"
    value {
      b: true
    }
  }
}
library {
  function {
    signature {
      name: "test_func_name"
      input_arg {
        name: "a_0"
        type: DT_INT32
      }
      output_arg {
        name: "a"
        type: DT_INT32
      }
    }
    ret {
      key: "a"
      value: "a_0"
    }
    attr {
      key: "_disable_call_shape_inference"
      value {
        b: true
      }
    }
  }
}

# CHECK: func @main
# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, f = @test_func_name0}
@ -121,8 +121,8 @@ versions {
# Verify that functions from the library are properly imported.

# CHECK-LABEL: func @main() {
# CHECK: "tf.foo110"()
# CHECK: "tf.foo111"()
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111}

# CHECK-LABEL: func @foo110() {
# CHECK-LABEL: func @foo111() {
@ -39,10 +39,10 @@ versions {
# Verify that functions from the library are properly imported.

# CHECK-LABEL: func @main() {
# CHECK: "tf.foo0"()
# CHECK: "tf.bar0"()
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}

# CHECK-LABEL: func @foo0() {
# CHECK: "tf.bar0"()
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}

# CHECK-LABEL: func @bar0() {
@ -0,0 +1,300 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck %s --dump-input=fail
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:0,a:0 -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail

# Generated in Python via
# ```
# import tensorflow as tf
#
# with tf.compat.v1.Graph().as_default() as g:
#   w = tf.constant(2.0)
#   x = tf.constant(3.0)
#   y = tf.constant(4.0)
#   var = tf.Variable(2.0)
#   var_add = var.assign_add(3.0)
#   with g.control_dependencies([var_add]):
#     z0, z1, z2 = tf.identity_n((w, x, y))
#
#   a = tf.add(z1, z2)
# ```

node {
  name: "w"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
node {
  name: "x"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 3.0
      }
    }
  }
}
node {
  name: "y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 4.0
      }
    }
  }
}
node {
  name: "var/initial_value"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
node {
  name: "var"
  op: "VariableV2"
  attr {
    key: "container"
    value {
      s: ""
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
  attr {
    key: "shared_name"
    value {
      s: ""
    }
  }
}
node {
  name: "var/Assign"
  op: "Assign"
  input: "var"
  input: "var/initial_value"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@var"
      }
    }
  }
  attr {
    key: "use_locking"
    value {
      b: true
    }
  }
  attr {
    key: "validate_shape"
    value {
      b: true
    }
  }
}
node {
  name: "var/read"
  op: "Identity"
  input: "var"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@var"
      }
    }
  }
}
node {
  name: "var_add/value"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 3.0
      }
    }
  }
}
node {
  name: "var_add"
  op: "AssignAdd"
  input: "var"
  input: "var_add/value"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@var"
      }
    }
  }
  attr {
    key: "use_locking"
    value {
      b: false
    }
  }
}
node {
  name: "z"
  op: "IdentityN"
  input: "w"
  input: "x"
  input: "y"
  input: "^var_add"
  attr {
    key: "T"
    value {
      list {
        type: DT_FLOAT
        type: DT_FLOAT
        type: DT_FLOAT
      }
    }
  }
}
node {
  name: "a"
  op: "Add"
  input: "z:1"
  input: "z:2"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 230
}

# Test non-zero index output tensors as feeds. Original ops whose outputs are
# replaced with feeds are preserved, and args and rets are lifted to the
# function. Rets that happen to coincide with a feed should take the value of
# the feed.
#
# CHECK: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# CHECK: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
# CHECK: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# CHECK: %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# CHECK: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# CHECK: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]

# Test that when non-zero index output tensors are feeds, remaining ops that
# are unreachable are pruned if pruning is enabled.
#
# PRUNE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
# PRUNE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:2,z:1,a:0"}}
# PRUNE-NOT: "tf.Const"
# PRUNE-NOT: "tf.VariableV2"
# PRUNE-NOT: "tf.Assign"
# PRUNE-NOT: "tf.Identity"
# PRUNE-NOT: "tf.AssignAdd"
# PRUNE-NOT: "tf.IdentityN"
# PRUNE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# PRUNE: tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]

# Test that when non-zero index output tensors are feeds, remaining ops that
# are unreachable are preserved if pruning is not enabled.
#
# PRESERVE: func @main(%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
# PRESERVE: attributes {tf.entry_function = {inputs = "z:1,z:2", outputs = "z:0,a:0"}}
# PRESERVE: %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# PRESERVE: %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# PRESERVE: %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# PRESERVE: tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]]
@ -2,11 +2,11 @@

# CHECK: tf_executor.SwitchN
# CHECK-SAME: of 3 : tensor<i32>
# CHECK-SAME: T = "tfdtype$DT_INT32"
# CHECK-SAME: T = i32
# CHECK-SAME: name = "Case/branch_index/_3"
# CHECK: tf_executor.SwitchN
# CHECK-SAME: of 2 : tensor<f32>
# CHECK-SAME: T = "tfdtype$DT_FLOAT"
# CHECK-SAME: T = f32
# CHECK-SAME: name = "Case/Case/input_0/_7"

node {
@ -250,3 +250,19 @@ func @ZerosLike_variant(%arg0: tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf
  %0 = "tf.ZerosLike"(%arg0) : (tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf.variant<tensor<2xi32>>>
  return %0 : tensor<!tf.variant<tensor<2xi32>>>
}

// CHECK-LABEL: func @addN
func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
  // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%[[SUM0]], %arg2)
  // CHECK: return %[[SUM1]]
%0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
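A hedged sketch of the same lowering for four operands (the %a through %d names are hypothetical): an N-ary tf.AddN becomes a left-leaning chain of N - 1 tf.AddV2 ops.

// Input:
%sum = "tf.AddN"(%a, %b, %c, %d) : (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// Expected lowered form:
%t0 = "tf.AddV2"(%a, %b) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%t1 = "tf.AddV2"(%t0, %c) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%t2 = "tf.AddV2"(%t1, %d) : (tensor<f32>, tensor<f32>) -> tensor<f32>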

// CHECK-LABEL: func @addN_variant
func @addN_variant(%arg0: tensor<!tf.variant<tensor<2xf32>>>, %arg1: tensor<!tf.variant<tensor<2xf32>>>, %arg2: tensor<!tf.variant<tensor<2xf32>>>) -> tensor<!tf.variant<tensor<2xf32>>> {
  // CHECK: tf.AddN
  %0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor<!tf.variant<tensor<2xf32>>>, tensor<!tf.variant<tensor<2xf32>>>, tensor<!tf.variant<tensor<2xf32>>>) -> tensor<!tf.variant<tensor<2xf32>>>
  return %0 : tensor<!tf.variant<tensor<2xf32>>>
}
@ -0,0 +1,26 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s

func @main() {
  tf_executor.graph {
    %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Constant", value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs) {f = @foo0} : (tensor<i32>) -> tensor<i32>
    tf_executor.fetch
  }
  return
}
func @foo0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  %0 = tf_executor.graph {
    tf_executor.fetch %arg0 : tensor<*xi32>
  }
  return %0 : tensor<*xi32>
}

// CHECK: node {
// CHECK: name: "_tf.LegacyCall"
// CHECK-NEXT: op: "foo0"

// CHECK: library {
// CHECK-NEXT: function {
// CHECK-NEXT: signature {
// CHECK-NEXT: name: "foo0"
@ -0,0 +1,244 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-resource-device-inference %s | FileCheck %s --dump-input=fail

// Tests that the pass can correctly propagate device attributes inside the same
// function.

// CHECK-LABEL: func @propagate_in_function
func @propagate_in_function(
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:0"},
  %arg1: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:1"}) {
  tf_executor.graph {
    // CHECK: tf_executor.island
    %island = tf_executor.island {
      // CHECK-NEXT: "tf.VarHandleOp"
      %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/CPU:0"}
        : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:0"}
      %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:0"}
      %id1 = "tf.Identity"(%id0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/CPU:0"}
      %id2 = "tf.Identity"(%var_handle) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      %read = "tf.ReadVariableOp"(%id2) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      %id3 = "tf.Identity"(%read) : (tensor<32xf32>) -> tensor<32xf32>
      tf_executor.yield
    }
    tf_executor.fetch %island : !tf_executor.control
  }
  return
}

// -----

// Tests that the pass can propagate through tf.If's branches.

// CHECK-LABEL: func @propagate_if_op
func @propagate_if_op(
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:0"},
  %arg1: tensor<i1>) {
  tf_executor.graph {
    // CHECK: tf_executor.island
    %island = tf_executor.island {
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:0"}
      %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.VarHandleOp"
      %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
        : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.If"
      "tf.If"(%arg1, %id0, %var_handle) {
          then_branch = @if_then,
          else_branch = @if_else,
          output_shapes = [], is_stateless = false}
        : (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
           tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
      tf_executor.yield
    }
    tf_executor.fetch %island : !tf_executor.control
  }
  return
}

// CHECK-LABEL: func @if_then
func @if_then(
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) {
  tf_executor.graph {
    // CHECK: tf_executor.island
    %island = tf_executor.island {
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:0"}
      %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:1"}
      %id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      tf_executor.yield
    }
    tf_executor.fetch %island : !tf_executor.control
  }
  return
}

// CHECK-LABEL: func @if_else
func @if_else(
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) {
  tf_executor.graph {
    // CHECK: tf_executor.island
    %island = tf_executor.island {
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:0"}
      %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      tf_executor.yield
    }
    tf_executor.fetch %island : !tf_executor.control
  }
  return
}

// -----

// Tests that the pass can propagate through tf.While's branches.

// CHECK-LABEL: func @propagate_while_op
func @propagate_while_op(
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:0"},
  %arg1: tensor<i32>) {
  tf_executor.graph {
    // CHECK: tf_executor.island
    %island = tf_executor.island {
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:0"}
      %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.VarHandleOp"
      %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
        : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.While"
      "tf.While"(%arg1, %id0, %var_handle) {
          body = @while_body,
          cond = @while_cond,
          output_shapes = [], is_stateless = false}
        : (tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
           tensor<*x!tf.resource<tensor<32xf32>>>) ->
          (tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
           tensor<*x!tf.resource<tensor<32xf32>>>)
      tf_executor.yield
    }
    tf_executor.fetch %island : !tf_executor.control
  }
  return
}

// CHECK-LABEL: func @while_body
func @while_body(
  %arg0: tensor<i32>,
  %arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg2: tensor<*x!tf.resource<tensor<32xf32>>>) ->
    (tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
     tensor<*x!tf.resource<tensor<32xf32>>>) {
  %graph:3 = tf_executor.graph {
    // CHECK: tf_executor.island
    %island:4 = tf_executor.island {
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:0"}
      %id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:1"}
      %id1 = "tf.Identity"(%arg2) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      tf_executor.yield %arg0, %id0, %id1
        : tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
          tensor<*x!tf.resource<tensor<32xf32>>>
    }
    tf_executor.fetch %island#0, %island#1, %island#2
      : tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
        tensor<*x!tf.resource<tensor<32xf32>>>
  }
  return %graph#0, %graph#1, %graph#2
    : tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
      tensor<*x!tf.resource<tensor<32xf32>>>
}

// CHECK-LABEL: func @while_cond
func @while_cond(
  %arg0: tensor<i32>,
  %arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg2: tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32> {
  %graph = tf_executor.graph {
    // CHECK: tf_executor.island
    %island:2 = tf_executor.island {
      // CHECK-NEXT: "tf.Identity"
      // CHECK-SAME: {device = "/TPU:0"}
      %id0 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      %read = "tf.ReadVariableOp"(%id0)
        : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      tf_executor.yield %read : tensor<32xf32>
    }
    tf_executor.fetch %island#0 : tensor<32xf32>
  }
  return %graph : tensor<32xf32>
}

// -----

// Tests that the pass reports an error on conflicting assignments from
// multiple callers.

func @error_on_conflict_multiple_callers(
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/TPU:0"},
  %arg1: tensor<i1>) {
  tf_executor.graph {
    %island = tf_executor.island {
      %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"}
        : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      "tf.If"(%arg1, %id0, %var_handle) {
          then_branch = @if_then_and_else,
          else_branch = @if_then_and_else,
          output_shapes = [], is_stateless = false}
        : (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
           tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
      "tf.If"(%arg1, %var_handle, %id0) {
          // expected-error@above {{Conflicting device assignment for resource}}
          then_branch = @if_then_and_else,
          else_branch = @if_then_and_else,
          output_shapes = [], is_stateless = false}
        : (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
           tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
      tf_executor.yield
    }
    tf_executor.fetch %island : !tf_executor.control
  }
  return
}

func @if_then_and_else(
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) {
  tf_executor.graph {
    %island = tf_executor.island {
      %id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      %id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
        -> tensor<*x!tf.resource<tensor<32xf32>>>
      tf_executor.yield
    }
    tf_executor.fetch %island : !tf_executor.control
  }
  return
}
@ -8,7 +8,7 @@ func @only_resource_load() -> tensor<*xi32> {
|
||||
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]]
|
||||
@ -16,7 +16,7 @@ func @only_resource_load() -> tensor<*xi32> {
|
||||
// CHECK-SAME: () -> tensor<*xi32>
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
@ -39,11 +39,11 @@ func @only_resource_store() -> tensor<*xi32> {
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
|
||||
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.SomeComputation"() : () -> (tensor<*xi32>)
|
||||
"tf.AssignVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
"tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
tf_device.return %2 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
|
||||
@ -61,18 +61,18 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
|
||||
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32}
|
||||
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
|
||||
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32}
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
"tf.AssignVariableOp"(%0, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
"tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
|
||||
@ -82,96 +82,6 @@ func @same_resource_load_and_store() -> tensor<*xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.AssignAddVariableOp operation is decomposed and
|
||||
// hoisted.
|
||||
|
||||
// CHECK-LABEL: func @decompose_assign_add_variable_op
|
||||
func @decompose_assign_add_variable_op() -> tensor<*xi32> {
|
||||
|
||||
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_INT32"}
|
||||
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.AddV2"(%[[RES_READ_VAL]], %[[ONE]])
|
||||
// CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
|
||||
// CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_INT32"}
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
"tf.AssignAddVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
|
||||
%3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
|
||||
// CHECK: return %[[LAUNCH_RES]]#0
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.AssignSubVariableOp operation is decomposed using
|
||||
// SubOp.
|
||||
|
||||
// CHECK-LABEL: func @decompose_assign_sub_variable_op
|
||||
func @decompose_assign_sub_variable_op() -> tensor<*xi32> {
|
||||
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
|
||||
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// CHECK: "tf.Sub"(%[[RES_READ_VAL]], %[[ONE]])
|
||||
// CHECK: "tf.AssignVariableOp"
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
"tf.AssignSubVariableOp"(%0, %2) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
|
||||
%3 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
|
||||
tf_device.return %3 : tensor<*xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32>
|
||||
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceApplyGradientDescent operation is decomposed
|
||||
// and hoisted.
|
||||
|
||||
// CHECK-LABEL: func @decompose_resource_apply_gradient_descent
|
||||
func @decompose_resource_apply_gradient_descent() -> tensor<*xf32> {
|
||||
|
||||
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = "tfdtype$DT_FLOAT"}
|
||||
// CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
|
||||
// CHECK: %[[DELTA:[0-9]*]] = "tf.Const"
|
||||
// CHECK: %[[MUL:[0-9]*]] = "tf.Mul"(%[[ALPHA]], %[[DELTA]])
|
||||
// CHECK: %[[SUB:[0-9]*]] = "tf.Sub"(%[[RES_READ_VAL]], %[[MUL]])
|
||||
// CHECK: tf_device.return %[[SUB]], %[[SUB]]
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"}
|
||||
// CHECK-SAME: () -> (tensor<*xf32>, tensor<*xf32>)
|
||||
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = "tfdtype$DT_FLOAT"}
|
||||
|
||||
%1 = "tf_device.launch"() ( {
|
||||
%2 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<f32>
|
||||
%3 = "tf.Const"() {T = "tfdtype$DT_FLOAT", value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
|
||||
"tf.ResourceApplyGradientDescent"(%0, %2, %3) : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
|
||||
%4 = "tf.ReadVariableOp"(%0) {dtype = "tfdtype$DT_FLOAT"} : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
tf_device.return %4 : tensor<*xf32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xf32>
|
||||
|
||||
// CHECK: return %[[LAUNCH_RES]]#0
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that internal resource operations are not hoisted.
|
||||
|
||||
// CHECK-LABEL: func @internal_resource
|
||||
@ -184,13 +94,13 @@ func @internal_resource() -> tensor<*xi32> {
    %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>

    // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
    %2 = "tf.ReadVariableOp"(%1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<*xi32>
    %2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>

    // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
    %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)

    // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[COMPUTE_RES]])
    "tf.AssignVariableOp"(%1, %3) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
    "tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()

    // CHECK: tf_device.return %[[COMPUTE_RES]]
    tf_device.return %3 : tensor<*xi32>
@ -1,6 +1,17 @@
// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail -color

module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
  // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
  func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
    // CHECK: %[[ARG0:.*]] = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
    // CHECK: %[[ARG1:.*]] = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<1xi32>
    // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[ARG0]], %[[ARG1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
    // CHECK: return %[[RESULT]] : tensor<1xi32>
    %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
    %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>
    %2 = "tf.AddV2"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
    return %2 : tensor<*xi32>
  }

  // CHECK-LABEL: func @simple_chain
  func @simple_chain(%arg0: tensor<1xf32>) -> tensor<*xf32> {
@ -6,18 +6,15 @@
// CHECK-LABEL: func @non_aliasing_reads_writes
func @non_aliasing_reads_writes(
// expected-remark@above {{ID: 13}}
// expected-remark@above {{Predecessors: {12}}}
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg2: tensor<32xf32>) -> (tensor<32xf32>) {
  %graph = tf_executor.graph {
  // expected-remark@above {{ID: 11}}
  // expected-remark@above {{Predecessors: {10}}}
  // expected-remark@above {{Successors: {12}}}
    // CHECK: tf_executor.island
    %island:2 = tf_executor.island {
    // expected-remark@above {{ID: 9}}
    // expected-remark@above {{Predecessors: {8}}}
    // expected-remark@above {{Successors: {10}}}
      %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
      // expected-remark@above {{ID: 0}}
@ -49,17 +46,14 @@ func @non_aliasing_reads_writes(
      tf_executor.yield %read3 : tensor<32xf32>
      // expected-remark@above {{ID: 8}}
      // expected-remark@above {{Predecessors: {4,5,7}}}
      // expected-remark@above {{Successors: {9}}}
    }
    tf_executor.fetch %island#0 : tensor<32xf32>
    // expected-remark@above {{ID: 10}}
    // expected-remark@above {{Predecessors: {9}}}
    // expected-remark@above {{Successors: {11}}}
  }
  return %graph : tensor<32xf32>
  // expected-remark@above {{ID: 12}}
  // expected-remark@above {{Predecessors: {11}}}
  // expected-remark@above {{Successors: {13}}}
}

// -----
@ -70,15 +64,12 @@ func @non_aliasing_reads_writes(
// CHECK-LABEL: func @aliasing_reads_writes
func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 14}}
// expected-remark@above {{Predecessors: {13}}}
  tf_executor.graph {
  // expected-remark@above {{ID: 12}}
  // expected-remark@above {{Predecessors: {11}}}
  // expected-remark@above {{Successors: {13}}}
    // CHECK: tf_executor.island
    %island = tf_executor.island {
    // expected-remark@above {{ID: 10}}
    // expected-remark@above {{Predecessors: {9}}}
    // expected-remark@above {{Successors: {11}}}
      %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      // expected-remark@above {{ID: 0}}
@ -112,17 +103,14 @@ func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
      tf_executor.yield
      // expected-remark@above {{ID: 9}}
      // expected-remark@above {{Predecessors: {8}}}
      // expected-remark@above {{Successors: {10}}}
    }
    tf_executor.fetch %island : !tf_executor.control
    // expected-remark@above {{ID: 11}}
    // expected-remark@above {{Predecessors: {10}}}
    // expected-remark@above {{Successors: {12}}}
  }
  return
  // expected-remark@above {{ID: 13}}
  // expected-remark@above {{Predecessors: {12}}}
  // expected-remark@above {{Successors: {14}}}
}

// -----
@ -133,15 +121,12 @@ func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
// CHECK-LABEL: func @unknown_side_effecting_op
func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 13}}
// expected-remark@above {{Predecessors: {12}}}
  tf_executor.graph {
  // expected-remark@above {{ID: 11}}
  // expected-remark@above {{Predecessors: {10}}}
  // expected-remark@above {{Successors: {12}}}
    // CHECK: tf_executor.island
    %island = tf_executor.island {
    // expected-remark@above {{ID: 9}}
    // expected-remark@above {{Predecessors: {8}}}
    // expected-remark@above {{Successors: {10}}}
      %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      // expected-remark@above {{ID: 0}}
@ -172,17 +157,14 @@ func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
      tf_executor.yield
      // expected-remark@above {{ID: 8}}
      // expected-remark@above {{Predecessors: {6,7}}}
      // expected-remark@above {{Successors: {9}}}
    }
    tf_executor.fetch %island : !tf_executor.control
    // expected-remark@above {{ID: 10}}
    // expected-remark@above {{Predecessors: {9}}}
    // expected-remark@above {{Successors: {11}}}
  }
  return
  // expected-remark@above {{ID: 12}}
  // expected-remark@above {{Predecessors: {11}}}
  // expected-remark@above {{Successors: {13}}}
}

// -----
@ -193,15 +175,12 @@ func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
// CHECK-LABEL: func @read_only_unknown_resource
func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () {
// expected-remark@above {{ID: 10}}
// expected-remark@above {{Predecessors: {9}}}
  tf_executor.graph {
  // expected-remark@above {{ID: 8}}
  // expected-remark@above {{Predecessors: {7}}}
  // expected-remark@above {{Successors: {9}}}
    // CHECK: tf_executor.island
    %island = tf_executor.island {
    // expected-remark@above {{ID: 6}}
    // expected-remark@above {{Predecessors: {5}}}
    // expected-remark@above {{Successors: {7}}}
      %vh0 = "tf._UnknownSideEffectingOp_"() : () -> tensor<*x!tf.resource<tensor<32xf32>>>
      // expected-remark@above {{ID: 0}}
@ -223,15 +202,71 @@ func @read_only_unknown_resource(%arg0: tensor<32xf32>) -> () {
      tf_executor.yield
      // expected-remark@above {{ID: 5}}
      // expected-remark@above {{Predecessors: {4}}}
      // expected-remark@above {{Successors: {6}}}
    }
    tf_executor.fetch %island : !tf_executor.control
    // expected-remark@above {{ID: 7}}
    // expected-remark@above {{Predecessors: {6}}}
    // expected-remark@above {{Successors: {8}}}
  }
  return
  // expected-remark@above {{ID: 9}}
  // expected-remark@above {{Predecessors: {8}}}
  // expected-remark@above {{Successors: {10}}}
}

// -----

// Tests that the pass adds control dependencies in nested regions with
// tf_device.replicate

func @with_replicate(
// expected-remark@above {{ID: 12}}
  %arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg2: tensor<*x!tf.resource<tensor<32xf32>>>,
  %arg3: tensor<*x!tf.resource<tensor<32xf32>>>) {
  tf_executor.graph {
  // expected-remark@above {{ID: 10}}
  // expected-remark@above {{Successors: {11}}}
    %island = tf_executor.island {
    // expected-remark@above {{ID: 8}}
    // expected-remark@above {{Successors: {9}}}
      %u0:2 = "tf._UnknownSideEffectingOp_"() : () -> (tensor<32xf32>, tensor<32xf32>)
      // expected-remark@above {{ID: 0}}
      // expected-remark@above {{Successors: {5}}}
      tf_device.replicate(
      // expected-remark@above {{ID: 5}}
      // expected-remark@above {{Predecessors: {0}}}
      // expected-remark@above {{Successors: {6}}}
          [%arg0, %arg1] as %r0: tensor<*x!tf.resource<tensor<32xf32>>>,
          [%arg2, %arg3] as %r1: tensor<*x!tf.resource<tensor<32xf32>>>,
          [%u0#0, %u0#1] as %u : tensor<32xf32>)
          {n = 2 : i32, devices = ["/CPU:0", "/GPU:1"]} {
        %read0 = "tf.ReadVariableOp"(%r0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
        // expected-remark@above {{ID: 1}}
        // expected-remark@above {{Successors: {4}}}
        "tf.AssignVariableOp"(%r1, %u) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
        // expected-remark@above {{ID: 2}}
        // expected-remark@above {{Successors: {3}}}
        %read1 = "tf.ReadVariableOp"(%r1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
        // expected-remark@above {{ID: 3}}
        // expected-remark@above {{Predecessors: {2}}}
        // expected-remark@above {{Successors: {4}}}
        tf_device.return
        // expected-remark@above {{ID: 4}}
        // expected-remark@above {{Predecessors: {1,3}}}
      }
      "tf._UnknownSideEffectingOp_"() : () -> ()
      // expected-remark@above {{ID: 6}}
      // expected-remark@above {{Predecessors: {5}}}
      // expected-remark@above {{Successors: {7}}}
      tf_executor.yield
      // expected-remark@above {{ID: 7}}
      // expected-remark@above {{Predecessors: {6}}}
    }
    tf_executor.fetch %island : !tf_executor.control
    // expected-remark@above {{ID: 9}}
    // expected-remark@above {{Predecessors: {8}}}
  }
  return
  // expected-remark@above {{ID: 11}}
  // expected-remark@above {{Predecessors: {10}}}
}
@ -1610,7 +1610,7 @@ func @testSplitUnknownDimInput(%input: tensor<4x?x4xf32>) {

// -----

func @testSplitNonConstSplitDim(%input: tensor<4x4xf32>, %split_dim: tensor<1xi32>) {
func @testSplitNonScalarSplitDim(%input: tensor<4x4xf32>, %split_dim: tensor<1xi32>) {
  // expected-error @+1 {{split dimension should be an integer scalar tensor}}
  %0:2 = "tf.Split"(%split_dim, %input) : (tensor<1xi32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
@ -1674,3 +1674,152 @@ func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
  %0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
  return
}

// -----

func @testSplitVScalarInput(%input: tensor<f32>, %split_sizes: tensor<2xi32>, %split_dim: tensor<i32>) {
  // expected-error @+1 {{cannot split scalar input tensor}}
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<f32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

// -----

func @testSplitVNonScalarSplitDim(%input: tensor<4x4xf32>, %split_sizes: tensor<2xi32>, %split_dim: tensor<1xi32>) {
  // expected-error @+1 {{split dimension should be an integer scalar tensor}}
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<1xi32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

// -----

func @testSplitVSplitDimOutOfRange(%input: tensor<4x4xf32>, %split_sizes: tensor<2xi32>) {
  %split_dim = "tf.Const"() {value = dense<100>: tensor<i32>} : () -> (tensor<i32>)
  // expected-error @+1 {{split dimension must be in range [-2, 2)}}
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

// -----

func @testSplitVWrongSplitSizesType(%input: tensor<4x4xf32>, %split_sizes: tensor<2x2xi32>, %split_dim: tensor<i32>) {
  // expected-error @+1 {{op split sizes should be a 1D tensor of 2 elements}}
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2x2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

// -----

func @testSplitVMultipleDynamicSizes(%input: tensor<4x4xf32>) {
  %split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
  %split_sizes = "tf.Const"() {value = dense<[-1, -1]>: tensor<2xi32>} : () -> (tensor<2xi32>)
  // expected-error @+1 {{cannot have more than one dynamic dimension in split sizes}}
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

// -----

func @testSplitVSplitSizeOutOfRange(%input: tensor<4x4xf32>) {
  %split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
  %split_sizes = "tf.Const"() {value = dense<[-1, 100]>: tensor<2xi32>} : () -> (tensor<2xi32>)
  // expected-error @+1 {{split sizes must sum up to be less than or equal to the dimension size along split dimension, found 100 vs 4}}
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

// -----

func @testSplitVSplitSizeOutOfRange(%input: tensor<4x4xf32>) {
  %split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
  %split_sizes = "tf.Const"() {value = dense<[2, 3]>: tensor<2xi32>} : () -> (tensor<2xi32>)
  // expected-error @+1 {{split sizes must sum up to the dimension size along split dimension, found 5 vs 4}}
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

// -----

func @testSplitV1(%input: tensor<4x4xf32>) {
  %split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
  %split_sizes = "tf.Const"() {value = dense<[-1, 4]>: tensor<2xi32>} : () -> (tensor<2xi32>)
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

func @testSplitV2(%input: tensor<4x4xf32>) {
  %split_dim = "tf.Const"() {value = dense<1>: tensor<i32>} : () -> (tensor<i32>)
  %split_sizes = "tf.Const"() {value = dense<[3, 1]>: tensor<2xi32>} : () -> (tensor<2xi32>)
  %0:2 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>)
  return
}

// -----

//===--------------------------------------------------------------------===//
// tf.All
//===--------------------------------------------------------------------===//

func @testAllDimWrongRank(%input: tensor<4x6xi1>, %dims: tensor<2x2xi32>) {
  // expected-error @+1 {{dimensions can only be 0D or 1D tensor}}
  %0 = "tf.All"(%input, %dims) : (tensor<4x6xi1>, tensor<2x2xi32>) -> (tensor<*xi1>)
  return
}

// -----

func @testAllDimOutOfRange(%input: tensor<4x6xi1>) {
  %dims = "tf.Const"() {value = dense<[-1, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  // expected-error @+1 {{1-th dimension should be in the range of [-2, 2)}}
  %0 = "tf.All"(%input, %dims) : (tensor<4x6xi1>, tensor<2xi32>) -> (tensor<*xi1>)
  return
}

// -----

//===--------------------------------------------------------------------===//
// tf.Any
//===--------------------------------------------------------------------===//

func @testAnyDimWrongRank(%input: tensor<4x6xi1>, %dims: tensor<2x2xi32>) {
  // expected-error @+1 {{dimensions can only be 0D or 1D tensor}}
  %0 = "tf.Any"(%input, %dims) : (tensor<4x6xi1>, tensor<2x2xi32>) -> (tensor<*xi1>)
  return
}

// -----

func @testAnyDimOutOfRange(%input: tensor<4x6xi1>) {
  %dims = "tf.Const"() {value = dense<[-1, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  // expected-error @+1 {{1-th dimension should be in the range of [-2, 2)}}
  %0 = "tf.Any"(%input, %dims) : (tensor<4x6xi1>, tensor<2xi32>) -> (tensor<*xi1>)
  return
}

// -----

//===--------------------------------------------------------------------===//
// tf.Unpack
//===--------------------------------------------------------------------===//

func @testUnpackAxisOutOfRange(%input: tensor<2x6xf32>) {
  // expected-error @+1 {{axis attribute must be in the range of [-2, 2)}}
  %0:2 = "tf.Unpack"(%input) {axis = 5} : (tensor<2x6xf32>) -> (tensor<6xf32>, tensor<6xf32>)
  return
}

// -----

func @testAxisUnknownDim(%input: tensor<?x6xf32>) {
  // CHECK: tf.Unpack
  %0:2 = "tf.Unpack"(%input) {axis = 0} : (tensor<?x6xf32>) -> (tensor<6xf32>, tensor<6xf32>)
  return
}

// -----

func @testAxisDim(%input: tensor<2x6xf32>) {
  // expected-error @+1 {{result count must be equal to 6}}
  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<2x6xf32>) -> (tensor<6xf32>, tensor<6xf32>)
  return
}
@ -0,0 +1,41 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# RUN: (! %p/exported_python_args 2>&1) | FileCheck %s

# pylint: disable=missing-docstring,line-too-long,dangerous-default-value
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common


class TestModule(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def some_function(self, x):
    return self.callee(x)

  # CHECK: While importing SavedModel function 'callee': in input signature:
  # CHECK-SAME: Unhandled structured value kind {{.*}} at index path: <value>.1.foo
  @tf.function
  def callee(self, x, n={'foo': 42}):
    return x


if __name__ == '__main__':
  common.do_test(TestModule)
@ -31,9 +31,14 @@ void CreateTPUBridge(OpPassManager &pm) {
  func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
  func_pm.addPass(CreateTPUClusterFormationPass());
  func_pm.addPass(createCanonicalizerPass());
  // Place DecomposeResourceOpsPass before the TFExecutorConstantSinking pass
  // because DecomposeResourceOpsPass uses a pattern rewriter, which hoists
  // changed constants out of tf_device.Launch.
  func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
  func_pm.addPass(tf_executor::CreateTFExecutorConstantSinkingPass());
  func_pm.addPass(TFDevice::CreateResourceOpLiftingPass());

  pm.addPass(TF::CreateResourceDeviceInferencePass());
  pm.addPass(TFDevice::CreateClusterOutliningPass());
  pm.addPass(CreateTPURewritePass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
@ -44,16 +44,16 @@ struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {

void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op,
                                   OpBuilder* builder) {
  llvm::SmallVector<Value*, 4> operands(launch_return_op.getOperands());
  builder->create<ReturnOp>(launch_return_op.getLoc(), operands);
  builder->create<ReturnOp>(launch_return_op.getLoc(),
                            launch_return_op.getOperands());
  launch_return_op.erase();
}

// Builds a function that outlines region attached to launch_op and inserts
// built function into given module.
FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
                     tf_device::LaunchOp launch_op,
                     ModuleManager* module_manager, OpBuilder* builder) {
                     tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
                     OpBuilder* builder) {
  llvm::SmallVector<Type, 4> operand_types;
  operand_types.reserve(live_ins.size());
  for (Value* v : live_ins) operand_types.emplace_back(v->getType());
@ -92,14 +92,14 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
  builder->setInsertionPoint(launch_return_op);
  ReplaceLaunchReturnWithReturn(launch_return_op, builder);

  module_manager->insert(outlined_func);
  symbol_table->insert(outlined_func);
  return outlined_func;
}

// Outlines the body of `tf_device.launch` into a function and creates a
// `tf_device.launch_func` to invoke that function. `tf_device.launch` is
// removed afterwards.
void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
                   OpBuilder* builder) {
  llvm::SetVector<Value*> live_ins;
  getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
@ -108,7 +108,7 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
      launch_op.getAttrOfType<StringAttr>(kDeviceAttr).getValue();

  FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
                                       launch_op, module_manager, builder);
                                       launch_op, symbol_table, builder);
  launch_op.setAttr(builder->getIdentifier(kFuncAttr),
                    builder->getSymbolRefAttr(outlined_func.getName()));

@ -124,10 +124,10 @@ void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,

void ClusterOutliningPass::runOnModule() {
  ModuleOp m = getModule();
  ModuleManager module_manager(m);
  SymbolTable symbol_table(m);
  OpBuilder builder(m.getContext());
  m.walk([&](tf_device::LaunchOp launch) {
    OutlineLaunch(launch, &module_manager, &builder);
    OutlineLaunch(launch, &symbol_table, &builder);
  });
}

@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TF {

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"

void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
  populateWithGenerated(context, patterns);
}

} // namespace TF
} // namespace mlir
@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_

#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir

namespace mlir {
namespace TF {

// Populates rewrite patterns that decompose composite resource operations into
// primitive ones like ReadVariableOp, AssignVariableOp and other computations
// to facilitate transformations like resource op lifting.
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns);

} // namespace TF
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_
@ -0,0 +1,63 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

def CreateTFReadVariableOp: NativeCodeCall<
    "$_builder.create<TF::ReadVariableOp>("
    "  $0.getLoc(),"
    "  UnrankedTensorType::get("
    "    $1->getType().cast<TensorType>().getElementType()),"
    "  $2)"
>;

def DecomposeAssignAddVariableOp :
  Pat<
    (TF_AssignAddVariableOp:$src_op $resource, $value),
    (TF_AssignVariableOp
      $resource,
      (TF_AddV2Op
        (CreateTFReadVariableOp $src_op, $value, $resource),
        $value
      )
    )
  >;

def DecomposeAssignSubVariableOp :
  Pat<
    (TF_AssignSubVariableOp:$src_op $resource, $value),
    (TF_AssignVariableOp
      $resource,
      (TF_SubOp
        (CreateTFReadVariableOp $src_op, $value, $resource),
        $value
      )
    )
  >;

// This decomposition is only correct inside XLA as it ignores use_locking
// attribute.
def DecomposeResourceApplyGradientDescentOp :
  Pat<
    (TF_ResourceApplyGradientDescentOp:$src_op $resource, $alpha, $delta, $_),
    (TF_AssignVariableOp
      $resource,
      (TF_SubOp
        (CreateTFReadVariableOp $src_op, $alpha, $resource),
        (TF_MulOp $alpha, $delta)
      )
    )
  >;
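Taken together, these patterns rewrite each composite op into an explicit read-modify-write sequence on the resource. As a rough sketch of what DecomposeResourceApplyGradientDescentOp produces (value names here are illustrative, not taken from the patch; the read's result type comes from $alpha's element type, per CreateTFReadVariableOp above):

  "tf.ResourceApplyGradientDescent"(%var, %alpha, %delta)
      : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()

becomes

  %read = "tf.ReadVariableOp"(%var) : (tensor<*x!tf.resource>) -> tensor<*xf32>
  %mul = "tf.Mul"(%alpha, %delta) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %sub = "tf.Sub"(%read, %mul) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
  "tf.AssignVariableOp"(%var, %sub) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()

which matches the CHECK lines in the decompose_resource_ops test earlier in this change.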
@ -0,0 +1,59 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"

namespace mlir {
namespace TFDevice {
namespace {

// A pass that decomposes composite resource operations into primitive ones like
// ReadVariableOp, AssignVariableOp and other computations to facilitate
// transformations like resource op lifting.
//
// For example:
//
//   tf.AssignAddVariableOp(%res, %0)
//
// Becomes
//
//   %res_val = tf.ReadVariableOp(%res)
//   %1 = tf.AddV2(%res_val, %0)
//   tf.AssignVariableOp(%res, %1)
struct DecomposeResourceOps : public FunctionPass<DecomposeResourceOps> {
  void runOnFunction() override {
    // Add lowering patterns to the list.
    OwningRewritePatternList patterns;
    mlir::TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);

    applyPatternsGreedily(getFunction(), patterns);
  }
};

} // namespace

std::unique_ptr<OpPassBase<FuncOp>> CreateDecomposeResourceOpsPass() {
  return std::make_unique<DecomposeResourceOps>();
}

} // namespace TFDevice
} // namespace mlir

static mlir::PassRegistration<mlir::TFDevice::DecomposeResourceOps> pass(
    "tf-device-decompose-resource-ops",
    "Decompose composite resource variable operations into primitive "
    "Read/AssignVariableOp and raw computation");
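Because the pass is registered under "tf-device-decompose-resource-ops", it can be exercised in isolation through tf-opt, mirroring the RUN lines of the tests in this change (a sketch; `input.mlir` is a placeholder file name):

  tf-opt -tf-device-decompose-resource-ops input.mlir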
@ -304,8 +304,7 @@ void InsertDummyIslandForFetch(FetchOp fetch) {
      /*control=*/ControlType::get(fetch.getContext()),
      /*controlInputs=*/control_fetches);
  island.body().push_back(new Block);
  OpBuilder(&island.GetBody())
      .create<YieldOp>(fetch.getLoc(), llvm::to_vector<4>(data_fetches));
  OpBuilder(&island.GetBody()).create<YieldOp>(fetch.getLoc(), data_fetches);
  const int fetch_control_idx = data_fetches.size();
  for (int i = 0, e = fetch.getNumOperands(); i < e; i++) {
    // The fetch could have multiple control operands (all at the end of its
@ -17,6 +17,7 @@ limitations under the License.
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.

#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
@ -79,8 +80,11 @@ static Operation* CallFn(Location loc,
  for (int i = 0; i < num_operands; ++i) {
    Value* val = get_arg(i);
    Type expected = fn_type.getInput(i);
    if (val->getType() != expected)
      val = builder->create<TensorCastOp>(loc, val, expected);
    if (val->getType() != expected) {
      val =
          builder->create<TF::CastOp>(loc, expected, val,
                                      /*Truncate=*/builder->getBoolAttr(false));
    }
    operands.push_back(val);
  }
  return builder->create<CallOp>(loc, fn, operands).getOperation();
@ -100,8 +104,11 @@ static llvm::SmallVector<Value*, 4> PrepareValsForJump(
  for (int i = 0; i < num_vals; ++i) {
    Value* val = get_val(i);
    Type expected = block->getArgument(i)->getType();
    if (val->getType() != expected)
      val = builder->create<TensorCastOp>(loc, val, expected);
    if (val->getType() != expected) {
      val =
          builder->create<TF::CastOp>(loc, expected, val,
                                      /*Truncate=*/builder->getBoolAttr(false));
    }
    result.push_back(val);
  }
  return result;
@ -131,8 +138,11 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
    Value* arg = block->getArgument(i);
    Value* result = op->getResult(i);
    if (arg->getType() != result->getType())
      arg = builder->create<TensorCastOp>(loc, arg, result->getType());
    if (arg->getType() != result->getType()) {
      arg =
          builder->create<TF::CastOp>(loc, result->getType(), arg,
                                      /*Truncate=*/builder->getBoolAttr(false));
    }
    result->replaceAllUsesWith(arg);
  }
}
@ -301,26 +311,15 @@ void FunctionalControlFlowToCFG::runOnFunction() {
    // subsequent blocks.
    //
    // TODO: Use PatternRewriter to eliminate these function control flow ops.
    auto has_variant_operand = [](Operation* op) {
      auto is_variant = [](Type ty) {
        return getElementTypeOrSelf(ty).getKind() == TensorFlowTypes::VARIANT;
      };

      if (llvm::none_of(op->getOperandTypes(), is_variant)) return false;

      op->emitOpError() << "does not yet support operands of type variant "
                           "for conversion to CFG";
      return true;
    };

    if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
      if (has_variant_operand(&op) || failed(LowerIfOp(if_op))) {
      if (failed(LowerIfOp(if_op))) {
        return signalPassFailure();
      }
      break;
    }
    if (WhileOp while_op = llvm::dyn_cast<WhileOp>(op)) {
      if (has_variant_operand(&op) || failed(LowerWhileOp(while_op))) {
      if (failed(LowerWhileOp(while_op))) {
        return signalPassFailure();
      }
      break;
@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
@ -109,6 +110,39 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
  return RankedTensorType::get(shape, ranked_ty.getElementType());
}

// Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
//
// %result = "tf.AddN"(%0, %1, %2)
//
// is lowered to:
//
// %sum_0 = "tf.AddV2"(%0, %1)
// %result = "tf.AddV2"(%sum_0, %2)
//
class LowerAddNOp : public OpRewritePattern<TF::AddNOp> {
 public:
  explicit LowerAddNOp(MLIRContext *context)
      : OpRewritePattern<TF::AddNOp>(context) {}

  PatternMatchResult matchAndRewrite(TF::AddNOp op,
                                     PatternRewriter &rewriter) const override {
    // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't
    // support variant type so variant types require special handling.
    if (getElementTypeOrSelf(op.getType()).isa<VariantType>())
      return matchFailure();

    // TODO(hinsu): Improve parallelism by splitting operands in two halves and
    // accumulating them first.
    Value *result = *op.inputs().begin();
    for (Value *operand : llvm::drop_begin(op.inputs(), 1)) {
      result = rewriter.create<TF::AddV2Op>(op.getLoc(), result, operand);
    }

    rewriter.replaceOp(op, result);
    return matchSuccess();
  }
};

// Lowers Pack op to ConcatV2 op after changing shape of the inputs with
// ExpandDims op.
//
@ -159,6 +193,7 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {

void PopulateLoweringTFPatterns(MLIRContext *context,
                                OwningRewritePatternList *patterns) {
  patterns->insert<LowerAddNOp>(context);
  patterns->insert<LowerPackOp>(context);
  populateWithGenerated(context, patterns);
}
@ -63,7 +63,7 @@ void CreateTFStandardPipeline(OpPassManager &pm,
  if (options.enable_inliner) {
    pm.addPass(createInlinerPass());
  }
  pm.addNestedPass<FuncOp>(CreateTFShapeInferencePass());
  pm.addPass(CreateTFShapeInferencePass());
  pm.addNestedPass<FuncOp>(CreateTFOptimizePass());
  pm.addNestedPass<FuncOp>(createCSEPass());
}
@ -57,6 +57,9 @@ struct StandardPipelineOptions : public PassOptions<StandardPipelineOptions> {
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
void CreateTFStandardPipeline(OpPassManager& pm,
                              const StandardPipelineOptions& options);

// Propagates device attributes of resources from callers to callees.
std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceDeviceInferencePass();
} // namespace TF

namespace TFControlFlow {
@ -96,6 +99,11 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateClusterFormationPass();
// Creates a pass that outlines regions of tf_device.launch operations.
std::unique_ptr<OpPassBase<ModuleOp>> CreateClusterOutliningPass();

// A pass that decomposes composite resource operations into primitive ones like
// ReadVariableOp, AssignVariableOp and other computations to facilitate
// transformations like resource op lifting.
std::unique_ptr<OpPassBase<FuncOp>> CreateDecomposeResourceOpsPass();

// Creates a pass that lifts operations on external resource variables from
// device computation nested in `tf_device::LaunchOp` out so that resource
// variable load operations are all before device computation while resource
@ -64,8 +64,8 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(

  // Replace replicate terminator with YieldOp.
  builder->setInsertionPoint(&terminator);
  builder->create<tf_executor::YieldOp>(
      terminator.getLoc(), llvm::to_vector<8>(terminator.getOperands()));
  builder->create<tf_executor::YieldOp>(terminator.getLoc(),
                                        terminator.getOperands());
  terminator.erase();

  builder->setInsertionPoint(island_op);
@ -0,0 +1,278 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Function.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/IR/Visitors.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
namespace TF {

namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";

// A pass that propagates device assignment of resources on a module. It
// performs in-function propagation, as well as cross-function propagation from
// callers to callees.
//
// This pass changes the module by adding "tf.device" attribute to function
// arguments and adding "device" attribute to TF ops.
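// As an illustrative sketch of the in-function propagation (not taken from
// this patch's tests; value names are hypothetical): given
//
//   %var = "tf.VarHandleOp"() {device = "/CPU:0", ...}
//   %id = "tf.Identity"(%var)
//
// the pass records "/CPU:0" for the resource produced by VarHandleOp and
// annotates the pass-through Identity with {device = "/CPU:0"} as well.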
struct ResourceDeviceInference : public ModulePass<ResourceDeviceInference> {
  void runOnModule() override;
};

// A class that records each resource's device assignment in a function.
class PerFunctionResult {
 public:
  explicit PerFunctionResult(FuncOp func_op) : alias_analysis_(func_op) {}

  // Returns the recorded device assignment for a resource, if any.
  llvm::Optional<llvm::StringRef> DeviceForResource(
      const Value* resource) const {
    llvm::Optional<llvm::StringRef> result;
    if (alias_analysis_.IsUnknownResource(resource)) return result;
    for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
      auto it = resource_id_to_device_.find(id);
      if (it == resource_id_to_device_.end()) continue;
      if (!result) {
        result = it->getSecond();
        continue;
      }
      if (result != it->getSecond()) {
        // Got conflicting assignments, clear the result.
        result.reset();
        return result;
      }
    }
    return result;
  }

  // Records the device assignment for a resource. If the new assignment
  // conflicts with an existing one, returns an error.
  //
  // If `changed` is provided, assign *changed to true if anything is modified.
  LogicalResult AddResourceDevice(const Value* resource, llvm::StringRef device,
                                  bool* changed = nullptr) {
    if (alias_analysis_.IsUnknownResource(resource)) return success();
    for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
      auto emplace_res = resource_id_to_device_.try_emplace(id, device);
      if (emplace_res.second) {
        if (changed) *changed = true;
      } else if (emplace_res.first->getSecond() != device) {
        // Existing assignment does not equal the new assignment.
        return failure();
      }
    }
    return success();
  }

 private:
  llvm::SmallDenseMap<int64_t, llvm::StringRef, 8> resource_id_to_device_;
  TF::ResourceAliasAnalysis alias_analysis_;
};

// Tries to record device assignment for a resource.
LogicalResult AddResourceDeviceAndEmitError(const Value* resource,
                                            llvm::StringRef device,
                                            Operation* error_reporting_op,
                                            PerFunctionResult* result,
                                            bool* changed = nullptr) {
  auto res = result->AddResourceDevice(resource, device, changed);
  if (failed(res)) {
    error_reporting_op->emitError()
        << "Conflicting device assignment for resource";
  }
  return res;
}

// Propagates device assignment inside a function.
LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
                                                  PerFunctionResult* result) {
  OpBuilder builder(func_op);
  // Function arguments.
  for (auto arg : func_op.getArguments()) {
    if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>()) {
      continue;
    }
    auto device_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
        arg->getArgNumber(), kFuncDeviceAttr);
    if (!device_attr || device_attr.getValue() == "") {
      // If device_attr does not exist, try to construct it from any recorded
      // assignment.
      if (auto device = result->DeviceForResource(arg)) {
        func_op.setArgAttr(arg->getArgNumber(), kFuncDeviceAttr,
                           builder.getStringAttr(*device));
      }
      continue;
    }
    // Record the attribute.
    auto res = AddResourceDeviceAndEmitError(arg, device_attr.getValue(),
                                             func_op, result);
    if (failed(res)) return res;
  }
  auto walk_res = func_op.walk([&](Operation* op) {
    if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
      // Record VarHandleOp's device attribute.
      auto device_attr =
          var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
      if (!device_attr || device_attr.getValue().empty()) {
        return WalkResult::advance();
      }
      auto res = AddResourceDeviceAndEmitError(
          var_handle.resource(), device_attr.getValue(), op, result);
      if (failed(res)) return WalkResult::interrupt();
    }
    if (auto identity = llvm::dyn_cast<TF::IdentityOp>(op)) {
      // Try to construct IdentityOp's attribute from recorded assignment.
      if (!mlir::getElementTypeOrSelf(identity.output()->getType())
               .isa<TF::ResourceType>()) {
        return WalkResult::advance();
      }
      if (auto device = result->DeviceForResource(identity.output())) {
        auto device_attr =
            identity.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
        if (!device_attr || device_attr.getValue().empty()) {
          identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
        }
      }
      return WalkResult::advance();
    }
    // Propagate and record output device assignment for other ops based on
    // existing recording. E.g., IdentityN.
    for (auto output : op->getResults()) {
      if (!mlir::getElementTypeOrSelf(output->getType())
               .isa<TF::ResourceType>()) {
        continue;
      }
      if (auto device = result->DeviceForResource(output)) {
        auto res = AddResourceDeviceAndEmitError(output, *device, op, result);
        if (failed(res)) return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  return failure(walk_res.wasInterrupted());
}

void ResourceDeviceInference::runOnModule() {
  auto module = getModule();
  llvm::SmallDenseMap<Operation*, PerFunctionResult, 4> per_function_results;
  llvm::SetVector<FuncOp> worklist;
  module.walk([&](FuncOp func_op) {
    worklist.insert(func_op);
    per_function_results.try_emplace(func_op, func_op);
  });
  // Helper that propagates an op's recorded operand device assignments to its
  // called function's arguments.
  auto propagate_operands_to_callee_arguments =
      [&](Operation* caller,
          llvm::iterator_range<OperandIterator> caller_operands,
          llvm::StringRef called_func_name,
          const PerFunctionResult& caller_res) {
        auto callee =
            llvm::dyn_cast<FuncOp>(module.lookupSymbol(called_func_name));
        assert(callee);
        auto& callee_res = per_function_results.find(callee)->getSecond();
        bool callee_needs_recompute = false;
        for (auto operand_and_argument :
             llvm::zip(caller_operands, callee.getArguments())) {
          if (!mlir::getElementTypeOrSelf(
                   std::get<0>(operand_and_argument)->getType())
                   .isa<TF::ResourceType>()) {
            continue;
          }
          auto device =
              caller_res.DeviceForResource(std::get<0>(operand_and_argument));
          if (!device) continue;
          if (failed(AddResourceDeviceAndEmitError(
                  std::get<1>(operand_and_argument), *device, caller,
                  &callee_res, &callee_needs_recompute))) {
            return failure();
          }
        }
        // If the callee recording is modified, make sure that it will be
        // reprocessed.
        if (callee_needs_recompute) {
          worklist.insert(callee);
        }
        return success();
      };
  while (!worklist.empty()) {
    auto func_op = worklist.back();
    worklist.pop_back();
    auto& func_res = per_function_results.find(func_op)->getSecond();
    // In-function propagation.
    if (failed(ComputeResourceDevicesInComputation(func_op, &func_res))) {
      return signalPassFailure();
    }
    // Propagation to callees.
    auto walk_res = func_op.walk([&](Operation* op) {
      if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
        if (failed(propagate_operands_to_callee_arguments(
                while_op, while_op.getOperands(), while_op.body(), func_res)) ||
            failed(propagate_operands_to_callee_arguments(
                while_op, while_op.getOperands(), while_op.cond(), func_res))) {
          return WalkResult::interrupt();
        }
      } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
        if (failed(propagate_operands_to_callee_arguments(
                if_op, if_op.input(), if_op.then_branch(), func_res)) ||
            failed(propagate_operands_to_callee_arguments(
                if_op, if_op.input(), if_op.else_branch(), func_res))) {
          return WalkResult::interrupt();
        }
      }
      return WalkResult::advance();
    });
    if (walk_res.wasInterrupted()) return signalPassFailure();
  }
}

} // namespace

std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceDeviceInferencePass() {
  return std::make_unique<ResourceDeviceInference>();
}

static PassRegistration<ResourceDeviceInference> pass(
    "tf-resource-device-inference",
    "Propagates the device attribute on resources from callers to callees.");

} // namespace TF
} // namespace mlir
@ -77,129 +77,6 @@ struct ResourceOpLiftingPass : public FunctionPass<ResourceOpLiftingPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Rewrites composite variable op `tf.AssignAddVariableOp` or
|
||||
// `tf.AssignSubVariableOp` into primitive resource/computation ops.
|
||||
// For example:
|
||||
//
|
||||
// tf.AssignAddVariableOp(%res, %0)
|
||||
//
|
||||
// Becomes
|
||||
//
|
||||
// %res_val = tf.ReadVariableOp(%res)
|
||||
// %1 = tf.AddV2(%res_val, %0)
|
||||
// tf.AssignVariableOp(%res, %1)
|
||||
//
|
||||
template <typename T>
|
||||
LogicalResult RewriteCompositeAssignVariableOp(T src_op, OpBuilder* builder) {
|
||||
// Read mangled dtype, which indicates type of data stored in resource
|
||||
// variable. It can then be used to construct type needed for both
|
||||
// ReadVariableOp and AssignVariableOp.
|
||||
StringAttr mangled_dtype_attr =
|
||||
src_op.template getAttrOfType<StringAttr>(kDTypeAttr);
|
||||
std::string type_string = mangled_dtype_attr.getValue();
|
||||
tensorflow::DataType dtype_proto;
|
||||
auto s =
|
||||
tensorflow::mangling_util::DemangleDataType(type_string, &dtype_proto);
|
||||
if (!s.ok()) return src_op.emitError() << s.error_message();
|
||||
|
||||
Type type;
|
||||
s = tensorflow::ConvertDataType(dtype_proto, *builder, &type);
|
||||
if (!s.ok()) return src_op.emitError() << s.error_message();
|
||||
type = UnrankedTensorType::get(type);
|
||||
|
||||
builder->setInsertionPoint(src_op);
|
||||
|
||||
auto read_variable_op = builder->create<TF::ReadVariableOp>(
|
||||
src_op.getLoc(), type, src_op.resource());
|
||||
read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
|
||||
mangled_dtype_attr);
|
||||
|
||||
Value* result;
|
||||
if (std::is_same<T, TF::AssignAddVariableOp>()) {
|
||||
result = builder->create<TF::AddV2Op>(
|
||||
src_op.getLoc(), read_variable_op.value(), src_op.value());
|
||||
} else {
|
||||
result = builder->create<TF::SubOp>(
|
||||
src_op.getLoc(), read_variable_op.value(), src_op.value());
|
||||
}
|
||||
|
||||
auto assign_variable_op = builder->create<TF::AssignVariableOp>(
|
||||
src_op.getLoc(), src_op.resource(), result);
|
||||
assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
|
||||
mangled_dtype_attr);
|
||||
|
||||
src_op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Rewrites `tf.ResourceApplyGradientDescent` into primitive resource and
|
||||
// computation ops.
|
||||
//
|
||||
// Specifically:
|
||||
//
|
||||
// tf.ResourceApplyGradientDescent(%var, %alpha, %delta)
|
||||
//
|
||||
// Becomes
|
||||
//
|
||||
// %old_var_val = tf.ReadVariableOp(%var)
|
||||
// %gradient_update = tf.Mul(%alpha, %delta)
|
||||
// %new_var_val = tf.Sub(%old_var_val, %gradient_update)
|
||||
// tf.AssignVariableOp(%var, %new_var_val)
|
||||
LogicalResult RewriteResourceApplyGradientDescentOp(
|
||||
TF::ResourceApplyGradientDescentOp op, OpBuilder* builder) {
|
||||
Type type = op.alpha()->getType();
|
||||
auto t = UnrankedTensorType::get(type.cast<TensorType>().getElementType());
|
||||
|
||||
tensorflow::DataType data_type;
|
||||
auto s = tensorflow::ConvertToDataType(type, &data_type);
|
||||
if (!s.ok()) return op.emitError() << s.error_message();
|
||||
|
||||
std::string mangled_data_type =
|
||||
tensorflow::mangling_util::MangleDataType(data_type);
|
||||
auto mangled_dtype_attr = builder->getStringAttr(mangled_data_type);
|
||||
|
||||
builder->setInsertionPoint(op);
|
||||
auto read_variable_op =
|
||||
builder->create<TF::ReadVariableOp>(op.getLoc(), t, op.var());
|
||||
read_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
|
||||
mangled_dtype_attr);
|
||||
|
||||
auto mul_op =
|
||||
builder->create<TF::MulOp>(op.getLoc(), t, op.alpha(), op.delta());
|
||||
auto sub_op = builder->create<TF::SubOp>(
|
||||
op.getLoc(), t, read_variable_op.value(), mul_op.z());
|
||||
auto assign_variable_op =
|
||||
builder->create<TF::AssignVariableOp>(op.getLoc(), op.var(), sub_op.z());
|
||||
assign_variable_op.setAttr(builder->getIdentifier(kDTypeAttr),
|
||||
mangled_dtype_attr);
|
||||
|
||||
op.erase();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Rewrites an operation that updates the value of a resource variable into its
// equivalent primitive ops so that following analyses/rewrites can be easier.
// If the given op is not a composite resource store op or is an unsupported
// op, no change is applied.
// TODO(ycao): Explore using pattern rewriter after needed operations are
// defined.
// TODO(ycao): Add support for other composite resource store ops.
LogicalResult MaybeRewriteCompositeResourceStore(Operation* op,
                                                 OpBuilder* builder) {
  if (auto assign_add_op = dyn_cast<TF::AssignAddVariableOp>(op)) {
    return RewriteCompositeAssignVariableOp(assign_add_op, builder);
  } else if (auto assign_sub_op = dyn_cast<TF::AssignSubVariableOp>(op)) {
    return RewriteCompositeAssignVariableOp(assign_sub_op, builder);
  } else if (auto resource_apply_gradient_descent_op =
                 dyn_cast<TF::ResourceApplyGradientDescentOp>(op)) {
    return RewriteResourceApplyGradientDescentOp(
        resource_apply_gradient_descent_op, builder);
  }

  return success();
}

// Performs store-load forwarding. This effectively removes
// 1) any resource load after a store to that same resource is done, and
// 2) any resource store except the last one.
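The effect of store-load forwarding can be shown with a small self-contained simulation over an abstract op list; the names and data structures below are illustrative only, not this pass's real ones:

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    struct Op { std::string kind, resource, value; };  // kind: "store" or "load"

    int main() {
      std::vector<Op> ops = {{"store", "v", "a"},
                             {"load", "v", ""},
                             {"store", "v", "b"},
                             {"load", "v", ""}};
      std::map<std::string, std::string> last_stored;
      for (auto& op : ops) {
        if (op.kind == "store") {
          last_stored[op.resource] = op.value;  // later loads reuse this value
        } else if (last_stored.count(op.resource)) {
          op.value = last_stored[op.resource];  // forwarded: this load is now dead
        }
      }
      // Prints "load v -> a" and "load v -> b"; only the final store of "b"
      // needs to remain.
      for (const auto& op : ops)
        if (op.kind == "load")
          std::cout << "load " << op.resource << " -> " << op.value << "\n";
    }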
@@ -358,10 +235,6 @@ void HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) {
  ModuleOp m = launch_op.getParentOfType<ModuleOp>();
  OpBuilder builder(m);

  // Rewrite composite resource store operations into primitive ones.
  launch_op.walk(
      [&](Operation* op) { MaybeRewriteCompositeResourceStore(op, &builder); });

  // Perform store-load forwarding so that each resource is only loaded with
  // its initial value and is only stored with its final value.
  ForwardStoreToLoad(launch_op);

@@ -47,6 +47,46 @@ using ::tensorflow::int64;

namespace mlir {
namespace TF {
namespace {
Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
    FuncOp func) {
  // Only infer shape when there is one return op for now.
  if (!has_single_element(func.getBody()) || func.front().empty()) {
    return None;
  }

  // Find the return type.
  auto return_op = dyn_cast<mlir::ReturnOp>(func.front().back());
  if (!return_op) {
    return None;
  }

  // Manually fold tf.Cast that precedes the return instruction and only
  // differs in shape refinement level.
  for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
    if (auto cast_op = dyn_cast<CastOp>(arg_op.get()->getDefiningOp())) {
      // Shape inference should not change the element type.
      if (cast_op.SrcT() != cast_op.DstT()) continue;
      // We only refine the result shape if the result has a dynamic shape, the
      // input has a static shape, and the two shapes are compatible.
      auto has_static_shape = [](const Value* value) {
        auto shaped_type = value->getType().dyn_cast<ShapedType>();
        return shaped_type && shaped_type.hasStaticShape();
      };
      Value* input = cast_op.x();
      Value* result = cast_op.y();
      if (!has_static_shape(input) || has_static_shape(result) ||
          failed(verifyCompatibleShape(input->getType(), result->getType())))
        continue;

      arg_op.set(cast_op.x());
      if (cast_op.y()->use_empty()) cast_op.erase();
    }
  }

  return llvm::to_vector<4>(return_op.getOperandTypes());
}
}  // namespace
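The refinement condition above can be restated as a small self-contained predicate (a sketch with dynamic dimensions encoded as -1; these are not this file's real helpers):

    #include <cstddef>
    #include <vector>

    bool HasStaticShape(const std::vector<long>& dims) {
      for (long d : dims)
        if (d < 0) return false;  // -1 marks a dynamic dimension
      return true;
    }

    bool CompatibleShapes(const std::vector<long>& a, const std::vector<long>& b) {
      if (a.size() != b.size()) return false;
      for (size_t i = 0; i < a.size(); ++i)
        if (a[i] >= 0 && b[i] >= 0 && a[i] != b[i]) return false;
      return true;
    }

    // Refine only when the cast input is fully static, the cast result is not,
    // and the shapes agree wherever both are known.
    bool ShouldRefine(const std::vector<long>& input,
                      const std::vector<long>& result) {
      return HasStaticShape(input) && !HasStaticShape(result) &&
             CompatibleShapes(input, result);
    }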
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
                                  int64_t graph_version) {

@@ -245,11 +285,10 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
  return success();
}

-LogicalResult InferShapeForFunction(FuncOp op,
+LogicalResult InferShapeForFunction(FuncOp func,
                                    ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                    int64_t graph_version) {
-  auto main_func = op;
-  mlir::FunctionType func_type = main_func.getType();
+  mlir::FunctionType func_type = func.getType();
  bool needs_refinement = false;
  llvm::SmallVector<mlir::Type, 4> new_arg_types;
  new_arg_types.reserve(func_type.getNumInputs());

@@ -276,7 +315,7 @@ LogicalResult InferShapeForFunction(FuncOp op,
    auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
    if (new_arg_type != func_type.getInput(i)) {
      // If the new type is more detailed, trigger shape inference.
-      main_func.getArgument(i)->setType(new_arg_type);
+      func.getArgument(i)->setType(new_arg_type);
      needs_refinement = true;
    }
    new_arg_types.push_back(new_arg_type);

@@ -287,39 +326,28 @@ LogicalResult InferShapeForFunction(FuncOp op,
  }

  mlir::LogicalResult result =
-      mlir::TF::InferShapeUntilFixPoint(&main_func.getBody(), graph_version);
+      mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version);
  if (failed(result)) {
    return failure();
  }

-  // Must only have 1 block so that there is only one return op.
-  if (main_func.getBody().getBlocks().size() != 1 ||
-      main_func.front().empty()) {
-    return failure();
+  auto return_types = InferShapeForFunctionReturnType(func);
+  func.setType(mlir::FunctionType::get(new_arg_types,
+                                       return_types.hasValue()
+                                           ? return_types.getValue()
+                                           : func.getType().getResults(),
+                                       func.getContext()));
+
+  return success();
+}
+
+LogicalResult InferShapeForFunctionType(FuncOp func) {
+  if (auto return_types = InferShapeForFunctionReturnType(func)) {
+    func.setType(mlir::FunctionType::get(func.getType().getInputs(),
+                                         return_types.getValue(),
+                                         func.getContext()));
  }

-  // Find the return type.
-  auto return_op = dyn_cast<mlir::ReturnOp>(*main_func.front().rbegin());
-  if (!return_op) {
-    return failure();
-  }
-
-  // Manually fold tf.Cast that precedes the return instruction and only differ
-  // in shape refinement level.
-  for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
-    if (auto cast_op = dyn_cast<CastOp>(arg_op.get()->getDefiningOp())) {
-      if (cast_op.SrcT() != cast_op.DstT()) continue;
-      arg_op.set(cast_op.x());
-      if (cast_op.y()->use_empty()) cast_op.erase();
-    }
-  }
-
-  llvm::SmallVector<mlir::Type, 4> return_types(return_op.getOperandTypes());
-
-  // Update function signature with the results of inference.
-  main_func.setType(
-      mlir::FunctionType::get(new_arg_types, return_types, op.getContext()));

  return success();
}

@@ -41,12 +41,16 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
                                      int64_t max_iteration = 10);

-// Given a list of refined shapes matching the function arguments of op, run
+// Given a list of refined shapes matching the function arguments of func, runs
// shape inference over the function to propagate this updated information.
-LogicalResult InferShapeForFunction(FuncOp op,
+LogicalResult InferShapeForFunction(FuncOp func,
                                    ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                    int64_t graph_version);

+// Refines the return type of the given function by folding tf.Cast that
+// precedes the return instruction.
+LogicalResult InferShapeForFunctionType(FuncOp func);
+
}  // namespace TF

}  // namespace mlir
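A hedged usage sketch of the refined declaration above (the shape values and producer version are hypothetical): refine a single-argument function whose feed is known to be a 2x3 tensor.

    llvm::SmallVector<int64_t, 2> shape = {2, 3};
    llvm::SmallVector<llvm::ArrayRef<int64_t>, 1> arg_shapes = {shape};
    if (failed(mlir::TF::InferShapeForFunction(func, arg_shapes,
                                               /*graph_version=*/27)))
      return failure();
    // Afterwards, InferShapeForFunctionType(func) would fold shape-only casts
    // feeding the terminator into the function's declared return types.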
@@ -65,7 +65,11 @@ struct ShapeInference : public ModulePass<ShapeInference> {
      return;
    }
    for (auto func : module.getOps<FuncOp>()) {
-      TF::InferShapeUntilFixPoint(&func.getBody(), producer.getInt());
+      InferShapeUntilFixPoint(&func.getBody(), producer.getInt());
    }
+
+    if (auto main_func = module.lookupSymbol<mlir::FuncOp>("main")) {
+      InferShapeForFunctionType(main_func);
+    }
  }
};

@@ -34,10 +34,12 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Identifier.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/IR/Value.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir

@@ -57,8 +59,6 @@ constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
constexpr char kDeviceAttr[] = "device";
constexpr char kNameAttr[] = "name";
constexpr char kNumReplicasAttr[] = "num_replicas";
-constexpr char kTPUReplicatedInputOp[] = "tf.TPUReplicatedInput";
-constexpr char kTPUReplicatedOutputOp[] = "tf.TPUReplicatedOutput";

constexpr char kBadTPUReplicateAttrMsg[] =
    "requires '_tpu_replicate' string attribute";
@@ -275,9 +275,8 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
  mlir::visitUsedValuesDefinedAbove(
      launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
        Operation* def = operand->get()->getDefiningOp();
-        if (def && def->getName().getStringRef() == kTPUReplicatedInputOp) {
+        if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
          replicated_input_ops.insert(def);
-        }
      });

  // Check if number of operands of each used TPUReplicatedInput op matches

@@ -305,10 +304,10 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
    int idx = result_and_idx.index();
    for (auto& use : result->getUses()) {
      Operation* def = use.getOwner();
-      if (!def || def->getName().getStringRef() != kTPUReplicatedOutputOp)
+      if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
        return launch_op.emitError()
               << "requires output of " << launch_op.getOperationName()
-               << " to lead to a '" << kTPUReplicatedOutputOp << "' op";
+               << " to lead to a 'tf.TPUReplicatedOutput' op";

      if (def->getNumResults() != num_replicas)
        return def->emitOpError() << "requires " << num_replicas << " results";

@@ -331,9 +330,8 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,

  // Create terminator for replicate op and move launch into replicate.
  builder.setInsertionPointToEnd(&replicate_op.GetBody());
-  auto return_op = builder.create<tf_device::ReturnOp>(
-      replicate_op.getLoc(),
-      llvm::SmallVector<Value*, 8>(launch_op.getResults()));
+  auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
+                                                       launch_op.getResults());
  launch_op.getOperation()->moveBefore(return_op);

  return success();

@@ -427,8 +425,8 @@ void TPUClusterFormation::runOnFunction() {

  // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
  auto remove_result = getFunction().walk([&](Operation* op) {
-    auto op_name = op->getName().getStringRef();
-    if (op_name != kTPUReplicatedInputOp && op_name != kTPUReplicatedOutputOp)
+    if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
+        !llvm::isa<TF::TPUReplicatedOutputOp>(op))
      return WalkResult::advance();

    // Forward operand to result. When `num_replicas` attribute is 1, no

@@ -440,7 +438,8 @@ void TPUClusterFormation::runOnFunction() {
    // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
    // `num_replicas` to 1.
    if (!op->use_empty()) {
-      op->emitOpError() << "expects " << op_name << " to have no uses";
+      op->emitOpError() << "expects " << op->getName().getStringRef()
+                        << " to have no uses";
      return WalkResult::interrupt();
    }
@@ -109,13 +109,13 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
    return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));

  module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
-  ModuleManager module_manager(module_for_func.get());
+  SymbolTable symbol_table(module_for_func.get());

  while (!referenced.empty()) {
    auto func = referenced.pop_back_val();

    // Skip functions that have already been cloned into new module.
-    if (module_manager.lookupSymbol<FuncOp>(func.getName())) continue;
+    if (symbol_table.lookup<FuncOp>(func.getName())) continue;

    // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
    // all found FuncOps to new_module to make sure new_module is

@@ -138,7 +138,7 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
      // should be no other reference to it.
      clone.setName("main");
    }
-    module_manager.insert(clone);
+    symbol_table.insert(clone);
  }

  // Serialize module and return.
@@ -13,14 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Builders.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
#include "mlir/Support/STLExtras.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"

// This pass is used in preparation for Graph export.

@@ -38,12 +43,11 @@ struct BreakUpIslands : OperationPass<BreakUpIslands, FuncOp> {
  void runOnOperation() final;

  void BreakUpIsland(tf_executor::IslandOp op,
+                     const TF::SideEffectAnalysis& side_effect_analysis,
                     llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>>*
                         new_control_edges);
};

}  // end anonymous namespace

void BreakUpIslands::runOnOperation() {
  auto graph_op_range = getOperation().getBody().front().without_terminator();
  tf_executor::GraphOp graph_op;

@@ -61,12 +65,13 @@ void BreakUpIslands::runOnOperation() {
  // Map from the users of the existing islands to the list of control
  // edges that need to be added.
  llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>> new_control_edges;
+  auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
  // Iterate in reverse order to avoid invalidating Operation* stored in
  // new_control_edges.
  for (auto& item :
       llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
    if (auto island = dyn_cast<tf_executor::IslandOp>(&item)) {
-      BreakUpIsland(island, &new_control_edges);
+      BreakUpIsland(island, side_effect_analysis, &new_control_edges);
    }
  }
  OpBuilder builder(getOperation());
@@ -106,21 +111,81 @@ void BreakUpIslands::runOnOperation() {
  }
}

+// Helper that creates an island. If `sub_op` is not nullptr, it will be moved
+// to the island.
+tf_executor::IslandOp CreateIsland(ArrayRef<Type> result_types,
+                                   ArrayRef<Value*> control_inputs,
+                                   const tf_executor::ControlType& control_type,
+                                   const Location& loc, Operation* sub_op,
+                                   tf_executor::IslandOp original_island) {
+  OpBuilder builder(original_island);
+  auto island = builder.create<tf_executor::IslandOp>(
+      loc, result_types, control_type, control_inputs);
+  island.body().push_back(new Block);
+  Block* block = &island.body().back();
+  if (sub_op) {
+    sub_op->replaceAllUsesWith(island.outputs());
+    sub_op->moveBefore(block, block->begin());
+  }
+  OpBuilder island_builder(original_island);
+  island_builder.setInsertionPointToEnd(block);
+  if (sub_op) {
+    island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
+  } else {
+    island_builder.create<tf_executor::YieldOp>(loc, ArrayRef<Value*>{});
+  }
+  return island;
+}
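A hedged usage sketch of the helper above (the surrounding names `prev_control`, `sub_op`, `control_type`, and `original_island` are assumed, not from this diff): wrap a single op into its own island that takes one control input and yields the op's results.

    llvm::SmallVector<Value*, 1> controls = {prev_control};
    auto island = CreateIsland(llvm::to_vector<4>(sub_op->getResultTypes()),
                               controls, control_type, sub_op->getLoc(), sub_op,
                               original_island);
    Value* control_out = island.control();  // chain the next island from this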
+// A struct that contains the operations in an island that do not have incoming
+// or outgoing dependencies.
+struct IslandSourcesAndSinks {
+  // Sub-ops that do not depend on other ops in the island.
+  llvm::SmallPtrSet<Operation*, 4> sources;
+  // Sub-ops that do not have other sub-ops in the island depending on them
+  // (excluding yield).
+  llvm::SmallPtrSet<Operation*, 4> sinks;
+};
+
+// Finds IslandSourcesAndSinks for an unmodified island.
+IslandSourcesAndSinks FindSourcesAndSinksInIsland(
+    tf_executor::IslandOp island,
+    const TF::SideEffectAnalysis& side_effect_analysis) {
+  IslandSourcesAndSinks result;
+  auto island_body = island.GetBody().without_terminator();
+  for (Operation& sub_op : island_body) {
+    auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op);
+    result.sinks.insert(&sub_op);
+    // Remove predecessor from sinks.
+    for (auto predecessor : predecessors) result.sinks.erase(predecessor);
+    bool has_in_island_operands = false;
+    for (auto operand : sub_op.getOperands()) {
+      auto defining_op = operand->getDefiningOp();
+      if (!defining_op || defining_op->getParentOp() != island) continue;
+      // Remove operands from sinks.
+      result.sinks.erase(defining_op);
+      has_in_island_operands = true;
+    }
+    if (predecessors.empty() && !has_in_island_operands) {
+      result.sources.insert(&sub_op);
+    }
+  }
+  return result;
+}
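The classification above amounts to computing nodes with no predecessors (sources) and no successors (sinks) in the island's dependency graph. A self-contained illustration on a three-op chain A -> B -> C, with all names invented for the example:

    #include <map>
    #include <set>
    #include <vector>

    int main() {
      // Direct dependency predecessors within one island: B depends on A, C on B.
      std::map<char, std::vector<char>> preds = {
          {'A', {}}, {'B', {'A'}}, {'C', {'B'}}};
      std::set<char> sources, sinks = {'A', 'B', 'C'};
      for (const auto& [op, ps] : preds) {
        if (ps.empty()) sources.insert(op);  // no predecessors -> source
        for (char p : ps) sinks.erase(p);    // has a successor -> not a sink
      }
      // sources == {'A'}, sinks == {'C'}.
      return 0;
    }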
// Converts a single island into multiple islands (one for each op). The
// islands are chained together by control flow values.
void BreakUpIslands::BreakUpIsland(
    tf_executor::IslandOp op,
+    const TF::SideEffectAnalysis& side_effect_analysis,
    llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>>*
        new_control_edges) {
  auto island_body = op.GetBody().without_terminator();
  // Skip islands that are already only a single op.
  // Skip islands that are empty (only yield).
  if (island_body.empty() || has_single_element(island_body)) return;
  OpBuilder builder(op);
-  OpBuilder island_builder(op);
  auto control_type = tf_executor::ControlType::get(&getContext());
-  Value* previous_island = nullptr;
-  auto tmp_control_inputs = llvm::to_vector<4>(op.controlInputs());
+  auto island_control_inputs = llvm::to_vector<4>(op.controlInputs());
  // Add control dependencies for yields of values defined by other islands to
  // the island that defines that fetched value.
  for (auto* fetch : op.GetYield().fetches()) {

@@ -130,7 +195,7 @@ void BreakUpIslands::BreakUpIsland(
      // OK, because it is the same island.
    } else if (auto island_op = llvm::dyn_cast<tf_executor::IslandOp>(
                   fetch->getDefiningOp())) {
-      tmp_control_inputs.push_back(island_op.control());
+      island_control_inputs.push_back(island_op.control());
    } else {
      // TODO(parkers): Any defining op that has a control output can be handled
      // just like an island.
@@ -138,39 +203,71 @@ void BreakUpIslands::BreakUpIsland(
      return signalPassFailure();
    }
  }
-  ArrayRef<Value*> previous_control = tmp_control_inputs;
+  // If there are multiple control inputs, create an empty island to group
+  // them.
+  if (island_control_inputs.size() > 1) {
+    auto island = CreateIsland({}, island_control_inputs, control_type,
+                               op.getLoc(), nullptr, op);
+    island_control_inputs.clear();
+    island_control_inputs.push_back(island.control());
+  }
+  // Find sources and sinks inside the original island.
+  auto sources_and_sinks =
+      FindSourcesAndSinksInIsland(op, side_effect_analysis);
+  // The corresponding control output of the new island created for each
+  // sub-op.
+  llvm::SmallDenseMap<Operation*, Value*, 8> new_control_for_sub_ops;
+  // Control outputs of newly created islands that are sinks.
+  llvm::SmallVector<Value*, 8> sink_island_controls;
  // For each operation in the island, construct a new island to wrap the op,
  // yield all the results, and replace all the usages with the results of the
  // new island.
-  for (Operation& sub_op : llvm::make_early_inc_range(island_body)) {
-    auto loc = sub_op.getLoc();
-    auto island = builder.create<tf_executor::IslandOp>(
-        loc, llvm::to_vector<4>(sub_op.getResultTypes()), control_type,
-        previous_control);
-    island.body().push_back(new Block);
-    Block* block = &island.body().back();
-    sub_op.replaceAllUsesWith(island.outputs());
-    block->getOperations().splice(block->begin(), op.GetBody().getOperations(),
-                                  sub_op);
-    island_builder.setInsertionPointToEnd(block);
-    island_builder.create<tf_executor::YieldOp>(
-        loc, llvm::to_vector<4>(sub_op.getResults()));
-    previous_island = island.control();
-    previous_control = previous_island;
+  for (auto& sub_op : llvm::make_early_inc_range(island_body)) {
+    const auto predecessors =
+        side_effect_analysis.DirectControlPredecessors(&sub_op);
+    // Get the controls from the predecessors.
+    llvm::SmallVector<Value*, 4> predecessors_control;
+    predecessors_control.reserve(predecessors.size());
+    for (auto predecessor : predecessors) {
+      predecessors_control.push_back(new_control_for_sub_ops[predecessor]);
+    }
+    // If sub_op is a source, use island_control_inputs, because that's
+    // required by inter-island dependencies; otherwise, we do not need to
+    // include island_control_inputs, since they must have been tracked by the
+    // (direct or indirect) control predecessors or operands.
+    ArrayRef<Value*> control = sources_and_sinks.sources.count(&sub_op) > 0
+                                   ? island_control_inputs
+                                   : predecessors_control;
+    auto island =
+        CreateIsland(llvm::to_vector<4>(sub_op.getResultTypes()), control,
+                     control_type, sub_op.getLoc(), &sub_op, op);
+    new_control_for_sub_ops[&sub_op] = island.control();
+    if (sources_and_sinks.sinks.count(&sub_op)) {
+      sink_island_controls.push_back(island.control());
+    }
  }
-  op.control()->replaceAllUsesWith(previous_island);
-  // All existing outputs need to add a control flow edge to the
-  // previous_island.
+  // Create output controls for the sinks.
+  assert(!sink_island_controls.empty());
+  // If there are multiple output controls, create an empty island to group
+  // them.
+  if (sink_island_controls.size() > 1) {
+    auto island = CreateIsland({}, sink_island_controls, control_type,
+                               op.getLoc(), nullptr, op);
+    sink_island_controls.clear();
+    sink_island_controls.push_back(island.control());
+  }
+  assert(sink_island_controls.size() == 1);
+  op.control()->replaceAllUsesWith(sink_island_controls[0]);
+  // All existing outputs need to add a control flow edge from
+  // sink_island_controls[0].
  for (Value* out : op.outputs()) {
    for (auto& use : out->getUses()) {
      Operation* owner = use.getOwner();
      if (auto island_op =
              llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
-        (*new_control_edges)[island_op].push_back(previous_island);
+        (*new_control_edges)[island_op].push_back(sink_island_controls[0]);
      } else if (llvm::isa<tf_executor::FetchOp>(owner) ||
                 llvm::isa<tf_executor::MergeOp>(owner) ||
                 llvm::isa<tf_executor::SwitchOp>(owner)) {
-        (*new_control_edges)[owner].push_back(previous_island);
+        (*new_control_edges)[owner].push_back(sink_island_controls[0]);
      } else {
        use.getOwner()->emitError("Adding control dependency not supported");
        return signalPassFailure();
@@ -182,6 +279,8 @@ void BreakUpIslands::BreakUpIsland(
  op.erase();
}

}  // namespace

std::unique_ptr<OpPassBase<FuncOp>> CreateBreakUpIslandsPass() {
  return std::make_unique<BreakUpIslands>();
}
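To visualize the pass, here is a hedged before/after sketch in tf_executor-style IR, written as comments because the exact textual form is assumed rather than taken from this patch:

    // Before: one island wrapping two ops.
    //   %out, %ctl = tf_executor.island {
    //     %a = "tf.OpA"() : () -> tensor<f32>
    //     %b = "tf.OpB"(%a) : (tensor<f32>) -> tensor<f32>
    //     tf_executor.yield %b : tensor<f32>
    //   }
    // After: one island per op; OpB's island is the only sink, so users of the
    // original %ctl now depend on %ctl_b.
    //   %a, %ctl_a = tf_executor.island { ... "tf.OpA" ... }
    //   %b, %ctl_b = tf_executor.island { ... "tf.OpB"(%a) ... }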
@@ -535,6 +535,18 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
        arg, index,
        graph_as_function && !input_names.empty() ? input_names[index] : ""));
  }

+  auto convert_called_function = [&](llvm::StringRef name) {
+    auto func =
+        function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
+            name);
+    if (func != nullptr) {
+      TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
+      TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
+    }
+    return Status::OK();
+  };

  // Adds nodes for operations.
  for (Operation& inst : block) {
    auto op_name = GetTensorFlowOpName(inst.getName().getStringRef());

@@ -544,13 +556,12 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
      // definition library
      // TODO(prakalps): If two functions have cyclic dependence, this will
      // introduce an infinite loop.
-      auto func =
-          function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
-              op_name.ValueOrDie());
-      if (func != nullptr) {
-        TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, func, flib));
-        TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
-      }
+      TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
    }

+    if (IsLegacyCallInstruction(&inst)) {
+      TF_RETURN_IF_ERROR(convert_called_function(
+          inst.getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference()));
+    }

    for (auto type : inst.getResultTypes()) {
@@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"

#include <iterator>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"

@@ -97,6 +100,9 @@ using stream_executor::port::StatusOr;

namespace {

+const char* disable_call_shape_inference_attribute_name =
+    "_disable_call_shape_inference";

// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.

@@ -246,11 +252,14 @@ class ImporterBase {
      llvm::SmallVector<mlir::NamedAttribute, 4>* attributes);

  // Helper to create either a tf_executor operation or a TF operation wrapped
-  // in an island.
+  // in an island. When convert_to_legacy_call is true, converts the operation
+  // representing a call to a library function with a name represented in
+  // node_type_name to LegacyCallOp.
  mlir::Operation* createOperation(
-      const Node& node, llvm::StringRef op_name,
+      const Node& node, llvm::StringRef node_type_name,
      const mlir::OperationState& result,
-      const llvm::SmallVectorImpl<mlir::Value*>& control_operands);
+      const llvm::SmallVectorImpl<mlir::Value*>& control_operands,
+      bool convert_to_legacy_call = false);

  // Converts one NodeDef from the input GraphDef into an Operation and
  // inserts it into the MLIR module using builder_.
@@ -297,19 +306,24 @@ class ImporterBase {
  // Gets the location information string for the given node.
  std::string GetLocationStr(const Node& node, bool includeNodeName = false);

-  // Inserts a placeholder node in the graph to replace the input node.
-  // Replaces all the output edges of the input_node with the placeholder node,
-  // and removes the input_node from the graph. The new node has the same name
-  // as the input_node, so Nodespecs do not need any modification.
-  StatusOr<Node*> ReplaceWithPlaceholderNode(const TensorShapeProto& shape,
-                                             DataType dtype, Node* input_node);
+  // Inserts a placeholder node in the graph to replace a feed output tensor,
+  // and returns the new placeholder node and a boolean indicating if the
+  // original input node was removed from the graph. Uses of the feed output
+  // tensor are replaced with this placeholder node. If the feed output tensor
+  // is of a single output node, the control dependencies are forwarded to
+  // the placeholder node, and the original node will be removed.
+  // Note: This modifies the graph, and so any list of ordered nodes needs to
+  // be reconstructed.
+  StatusOr<std::pair<Node*, bool>> CreatePlaceholderNodeForFeed(
+      const TensorShapeProto& shape, DataType dtype, Node* node, int index,
+      const std::unordered_map<string, Node*>& node_name_map);

  // Gets the input and output nodes corresponding to the specified input and
  // output nodes in specs_. If there are no input or output nodes specified,
-  // nodes will be empty
-  Status GetInputOutputNodes(std::unordered_set<const Node*>* nodes);
+  // nodes will be empty.
+  Status GetInputOutputNodes(
+      const std::unordered_map<string, Node*>& node_name_map,
+      std::unordered_set<const Node*>* nodes);

  // The input graph with backedges removed. The removed backedges are stored
  // in the back_edge_helper.

@@ -339,6 +353,10 @@ class ImporterBase {
  NodeValueMap node_values_;
  std::unique_ptr<ShapeRefiner> shape_refiner_;
  NameUniquifier* function_name_uniquifier_;

+ protected:
+  // Maps feed as TensorId to new Placeholder node name.
+  absl::flat_hash_map<TensorId, absl::string_view> remapped_feeds_;
};

// Returns true if the node with given name has a non primary output that is
@@ -419,6 +437,49 @@ Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) {
  return Status::OK();
}

+// Mapping from node name to feed (index and ArrayInfo). Node name must outlive
+// this map.
+using FeedsByNode = absl::flat_hash_map<
+    absl::string_view,
+    absl::flat_hash_map<int, const std::pair<std::string, ArrayInfo>*>>;
+
+// Creates, from a `GraphImportConfig::InputArrays`, a mapping from a feed's
+// output tensor name to index and ArrayInfo. Keys and values are backed by
+// `GraphImportConfig::InputArrays`.
+StatusOr<FeedsByNode> GetFeedsByNode(
+    const GraphImportConfig::InputArrays& inputs) {
+  FeedsByNode feeds_by_node;
+  feeds_by_node.reserve(inputs.size());
+
+  for (const auto& input : inputs) {
+    TensorId tensor = ParseTensorName(input.first);
+    if (tensor.index() < 0)
+      return errors::FailedPrecondition(
+          "Feed output tensor must be a data output '", tensor.ToString(), "'");
+
+    auto& node = feeds_by_node[tensor.node()];
+    if (!node.insert({tensor.index(), &input}).second)
+      return errors::FailedPrecondition(
+          "Multiple feeds for the same output tensor '", tensor.ToString(),
+          "'");
+  }
+
+  return feeds_by_node;
+}
+
+// Creates a unique name for a node that will be replacing a feed output
+// tensor.
+std::string GetUniqueNodeName(
+    absl::string_view node_name, int index,
+    const std::unordered_map<string, Node*>& node_name_map) {
+  std::string new_node_name_base = absl::StrCat(node_name, "_", index);
+  int count = 0;
+  std::string new_node_name = new_node_name_base;
+  while (node_name_map.find(new_node_name) != node_name_map.end()) {
+    new_node_name = absl::StrCat(new_node_name_base, "_", count++);
+  }
+  return new_node_name;
+}
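A self-contained sketch of the uniquing behavior implemented above (standard library only, hypothetical names): feeding output 1 of node "x" first tries "x_1", then "x_1_0", "x_1_1", and so on while the candidate is taken.

    #include <string>
    #include <unordered_set>

    std::string UniqueNameSketch(const std::string& node_name, int index,
                                 const std::unordered_set<std::string>& taken) {
      std::string base = node_name + "_" + std::to_string(index);
      std::string candidate = base;
      int count = 0;
      while (taken.count(candidate))
        candidate = base + "_" + std::to_string(count++);
      return candidate;
    }
    // UniqueNameSketch("x", 1, {})               == "x_1"
    // UniqueNameSketch("x", 1, {"x_1"})          == "x_1_0"
    // UniqueNameSketch("x", 1, {"x_1", "x_1_0"}) == "x_1_1"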
Status ImporterBase::RemoveBackedges(const Graph& graph) {
  // TODO(fengliuai): Converting to GraphDef and back is the easiest way to
  // clone a graph.

@@ -459,37 +520,54 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) {
  return Status::OK();
}

-StatusOr<Node*> ImporterBase::ReplaceWithPlaceholderNode(
-    const TensorShapeProto& shape, DataType dtype, Node* input_node) {
+StatusOr<std::pair<Node*, bool>> ImporterBase::CreatePlaceholderNodeForFeed(
+    const TensorShapeProto& shape, DataType dtype, Node* node, int index,
+    const std::unordered_map<string, Node*>& node_name_map) {
+  DCHECK_LT(index, node->num_outputs());
+  const bool update_inplace = node->num_outputs() == 1 && index == 0;
+  std::string new_node_name =
+      update_inplace ? node->name()
+                     : GetUniqueNodeName(node->name(), index, node_name_map);
+
  Node* placeholder_node;
-  NodeBuilder builder(input_node->name(), "Placeholder");
+  NodeBuilder builder(new_node_name, "Placeholder");
  builder.Attr("shape", shape);
  builder.Attr("dtype", dtype);
  TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node));

-  while (!input_node->out_edges().empty()) {
-    const Edge* oe = *input_node->out_edges().begin();
-    // UpdateEdge cannot be used with control edges.
-    if (oe->src_output() == Graph::kControlSlot) {
-      graph_->AddControlEdge(placeholder_node, oe->dst());
-      graph_->RemoveControlEdge(oe);
-      continue;
+  // Update edges from original feed with Placeholder node.
+  std::vector<const Edge*> data_edges;
+  std::vector<const Edge*> control_edges;
+  for (const tensorflow::Edge* edge : node->out_edges()) {
+    if (edge->src_output() == index) {
+      data_edges.push_back(edge);
+    } else if (update_inplace && edge->IsControlEdge()) {
+      control_edges.push_back(edge);
    }
-
-    TF_RETURN_IF_ERROR(
-        graph_->UpdateEdge(placeholder_node, 0, oe->dst(), oe->dst_input()));
  }

-  graph_->RemoveNode(input_node);
+  for (const auto* edge : data_edges) {
+    TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(),
+                                          edge->dst_input()));
+  }

-  return placeholder_node;
+  for (const auto* edge : control_edges) {
+    graph_->AddControlEdge(placeholder_node, edge->dst());
+    graph_->RemoveControlEdge(edge);
+  }
+
+  if (update_inplace) {
+    graph_->RemoveNode(node);
+  }
+
+  return std::pair<Node*, bool>(placeholder_node, update_inplace);
}

Status ImporterBase::GetInputOutputNodes(
+    const std::unordered_map<string, Node*>& node_name_map,
    std::unordered_set<const Node*>* nodes) {
-  auto node_name_map = graph_->BuildNodeNameIndex();
-  auto add_node = [&](const string& name) {
-    auto it = node_name_map.find(name);
+  auto add_node = [&](absl::string_view name) {
+    auto it = node_name_map.find(std::string(name));
    if (it == node_name_map.end()) {
      return errors::FailedPrecondition(
          absl::StrCat("Graph does not contain node: ", name));
@@ -498,13 +576,25 @@ Status ImporterBase::GetInputOutputNodes(
    return Status::OK();
  };

+  // Remap feeds and fetches to newly created Placeholder nodes.
  for (const auto& input : specs_.inputs) {
-    TF_RETURN_IF_ERROR(add_node(input.first));
+    TensorId tensor = ParseTensorName(input.first);
+    auto remapped_it = remapped_feeds_.find(tensor);
+    if (remapped_it != remapped_feeds_.end()) {
+      TF_RETURN_IF_ERROR(add_node(remapped_it->second));
+    } else {
+      TF_RETURN_IF_ERROR(add_node(tensor.node()));
+    }
  }

  for (const auto& output : specs_.outputs) {
-    auto output_node_name = std::string(ParseTensorName(output).first);
-    TF_RETURN_IF_ERROR(add_node(output_node_name));
+    TensorId tensor = ParseTensorName(output);
+    auto remapped_it = remapped_feeds_.find(tensor);
+    if (remapped_it != remapped_feeds_.end()) {
+      TF_RETURN_IF_ERROR(add_node(remapped_it->second));
+    } else {
+      TF_RETURN_IF_ERROR(add_node(tensor.node()));
+    }
  }

  return Status::OK();
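The remapping above keys off TensorFlow's "node:index" tensor-name convention. A minimal self-contained sketch of that parsing (not the real ParseTensorName):

    #include <string>
    #include <utility>

    std::pair<std::string, int> ParseTensorNameSketch(const std::string& name) {
      auto pos = name.rfind(':');
      if (pos == std::string::npos) return {name, 0};  // "x" means output 0 of "x"
      return {name.substr(0, pos), std::stoi(name.substr(pos + 1))};
    }
    // ParseTensorNameSketch("x")   -> {"x", 0}
    // ParseTensorNameSketch("x:1") -> {"x", 1}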
@@ -520,6 +610,9 @@ Status ImporterBase::AddNodesToShapeRefiner() {
  shape_refiner_->set_require_shape_inference_fns(false);
  shape_refiner_->set_function_library_for_shape_inference(&graph_flib_);

+  TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs));
+  auto node_name_map = graph_->BuildNodeNameIndex();
+
  // First add all nodes to the refiner.
  for (Node* node : ordered_nodes_) {
    // We need to use a TensorFlow node to teach the shape refiner that user

@@ -533,28 +626,49 @@ Status ImporterBase::AddNodesToShapeRefiner() {
    // it to replace the original input node, so the shape refiner can
    // successfully propagate the user's input type and shape to the rest of the
    // graph.
-    auto it = specs_.inputs.find(node->name());
-    if (it != specs_.inputs.end()) {
-      auto node_name = node->op_def().name();
-      if (node_name != "Placeholder" && node_name != "LegacyFedInput" &&
-          node_name != FunctionLibraryDefinition::kArgOp) {
-        // We do not handle the case where the input node has multiple outputs
-        if (node->num_outputs() > 1) {
-          return errors::FailedPrecondition(absl::StrCat(
-              "Input arrays can only have op with single output. Node op:",
-              node_name));
+    bool node_added_to_shape_refiner = false;
+    auto it = feeds_by_node.find(node->name());
+    if (it != feeds_by_node.end()) {
+      auto op_name = node->op_def().name();
+      if (op_name != "Placeholder" && op_name != "LegacyFedInput" &&
+          op_name != FunctionLibraryDefinition::kArgOp) {
+        for (const auto& output_tensor : it->second) {
+          const int index = output_tensor.first;
+          const ArrayInfo& array_info = output_tensor.second->second;
+
+          DataType dtype = array_info.imported_dtype;
+          // Uses the existing output type if it isn't specified by the user.
+          if (dtype == DT_INVALID) {
+            dtype = node->output_type(0);
+          }
+
+          TF_ASSIGN_OR_RETURN(
+              auto placeholder_node_and_removed,
+              CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index,
+                                           node_name_map));
+
+          Node* placeholder_node = placeholder_node_and_removed.first;
+          if (placeholder_node_and_removed.second) {
+            // Original node has been removed from the graph.
+            node = placeholder_node;
+            node_added_to_shape_refiner = true;
+          }
+          remapped_feeds_[{it->first, index}] = placeholder_node->name();
+          node_name_map[placeholder_node->name()] = placeholder_node;
+          // Add the new placeholder node to the shape refiner.
+          TF_RETURN_WITH_CONTEXT_IF_ERROR(
+              shape_refiner_->AddNode(placeholder_node),
+              GetLocationStr(*placeholder_node));
        }
-        // For single output nodes, replace them with Placeholder node.
-        DataType dtype = it->second.imported_dtype;
-        // Uses the existing output type if it isn't specified by the user.
-        if (dtype == DT_INVALID) {
-          dtype = node->output_type(0);
-        }
-        TF_ASSIGN_OR_RETURN(
-            node, ReplaceWithPlaceholderNode(it->second.shape, dtype, node));
      } else {
-        node->AddAttr("shape", it->second.shape);
-        DataType dtype = it->second.imported_dtype;
+        auto index_it = it->second.find(0);
+        if (index_it == it->second.end()) {
+          return errors::FailedPrecondition(
+              "Missing feed output tensor at index 0 for node '", node->name(),
+              "'");
+        }
+        node->AddAttr("shape", index_it->second->second.shape);
+        DataType dtype = index_it->second->second.imported_dtype;
        // Uses the existing output type if it isn't specified by the user.
        if (dtype == DT_INVALID) {
          dtype = node->output_type(0);
@@ -562,9 +676,11 @@ Status ImporterBase::AddNodesToShapeRefiner() {
        node->AddAttr("dtype", dtype);
      }
    }
-    // Adds the node to the shape refiner.
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node),
-                                    GetLocationStr(*node));
+    if (!node_added_to_shape_refiner) {
+      // Add the node to the shape refiner if the node hasn't been removed.
+      TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node),
+                                      GetLocationStr(*node));
+    }

    auto set_shape_from_list_attr = [&](const AttrValue* attr) {
      auto& list = attr->list();

@@ -625,7 +741,7 @@ Status ImporterBase::AddNodesToShapeRefiner() {
  // Prune nodes in the graph that are not reachable from the output.
  if (specs_.prune_unused_nodes) {
    std::unordered_set<const Node*> prune_start;
-    TF_RETURN_IF_ERROR(GetInputOutputNodes(&prune_start));
+    TF_RETURN_IF_ERROR(GetInputOutputNodes(node_name_map, &prune_start));
    if (!prune_start.empty()) {
      if (PruneForReverseReachability(graph_.get(), prune_start)) {
        VLOG(1) << "Pruned unused nodes in graphdef";

@@ -829,9 +945,11 @@ StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
      return builder_.getFloatAttr(builder_.getF32Type(), value.f());
    case AttrValue::kB:
      return builder_.getBoolAttr(value.b());
-    case AttrValue::kType:
-      return builder_.getStringAttr(
-          mangling_util::MangleDataType(value.type()));
+    case AttrValue::kType: {
+      mlir::Type type;
+      TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type));
+      return mlir::TypeAttr::get(type);
+    }
    case AttrValue::kShape:
      return builder_.getStringAttr(mangling_util::MangleShape(value.shape()));
    case AttrValue::kTensor:
@@ -1106,11 +1224,9 @@ Status ImporterBase::ConvertFunctionArgAndRets(
  builder_.setInsertionPointToEnd(&graph_op.body().front());
  builder_.create<mlir::tf_executor::FetchOp>(graph_op.getLoc(),
                                              inst_to_return);
-  inst_to_return.assign(graph_op.getResults().begin(),
-                        graph_op.getResults().end());
  builder_.setInsertionPointToEnd(bb);
  builder_.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
-                                  inst_to_return);
+                                  graph_op.getResults());
  return Status::OK();
}

@@ -1210,9 +1326,10 @@ std::string ImporterBase::GetLocationStr(const Node& node,
}

mlir::Operation* ImporterBase::createOperation(
-    const Node& node, llvm::StringRef op_name,
+    const Node& node, llvm::StringRef node_type_name,
    const mlir::OperationState& result,
-    const llvm::SmallVectorImpl<mlir::Value*>& control_operands) {
+    const llvm::SmallVectorImpl<mlir::Value*>& control_operands,
+    bool convert_to_legacy_call) {
  // For the tf.executor specific operations (not wrapped in an island), we
  // have an extra returned value for the control result, and we concatenate
  // control and non-control operands.

@@ -1274,11 +1391,31 @@ mlir::Operation* ImporterBase::createOperation(
  mlir::OpBuilder island_builder(&island.GetBody());

  // Create the operation inside the island now.
-  mlir::Operation* inner_op = island_builder.createOperation(result);
+  mlir::Operation* inner_op;
+  if (convert_to_legacy_call) {
+    bool disable_call_shape_inference = false;
+    for (const auto& name_and_value : node.attrs()) {
+      const auto& attr_name = name_and_value.first;
+      const AttrValue& attr_value = name_and_value.second;
+      if (strcmp(attr_name.c_str(),
+                 disable_call_shape_inference_attribute_name) == 0 &&
+          attr_value.value_case() == AttrValue::kB) {
+        disable_call_shape_inference = attr_value.b();
+      }
+    }
+
+    mlir::BoolAttr attribute =
+        builder_.getBoolAttr(disable_call_shape_inference);
+    inner_op = island_builder.create<mlir::TF::LegacyCallOp>(
+        result.location, result.types, result.operands,
+        island_builder.getSymbolRefAttr(node_type_name), attribute);
+  } else {
+    inner_op = island_builder.createOperation(result);
+  }

  // Add the terminator for the island
-  mlir::SmallVector<mlir::Value*, 8> ret_vals(inner_op->getResults());
-  island_builder.create<mlir::tf_executor::YieldOp>(result.location, ret_vals);
+  island_builder.create<mlir::tf_executor::YieldOp>(result.location,
+                                                    inner_op->getResults());
  return island.getOperation();
}
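For reference, a hedged sketch of the IR the legacy-call path above would produce (written as comments; the exact textual form and the function name are assumed):

    // A node calling library function "my_fn" with only the
    // _disable_call_shape_inference=true bool attribute imports roughly as:
    //   %out, %ctl = tf_executor.island {
    //     %0 = "tf.LegacyCall"(%arg)
    //            {f = @my_fn, _disable_call_shape_inference = true}
    //          : (tensor<f32>) -> tensor<f32>
    //     tf_executor.yield %0 : tensor<f32>
    //   }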
@@ -1293,9 +1430,11 @@ Status ImporterBase::ConvertNode(const Node& node) {
  // create the MLIR function and insert it to the module if it doesn't exist.
  std::string node_type_name = node.type_string();
  const auto* func_def = graph_flib_.Find(node_type_name);
+  bool convert_to_legacy_call = false;
  if (func_def) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
+    convert_to_legacy_call = true;
  }

  auto get_full_op_name = [&](const std::string& op_name) {

@@ -1380,6 +1519,14 @@ Status ImporterBase::ConvertNode(const Node& node) {
  for (const auto& name_and_value : node.attrs()) {
    const auto& attr_name = name_and_value.first;
    const AttrValue& attr_value = name_and_value.second;
+    // LegacyCall can only represent the _disable_call_shape_inference
+    // attribute. If a call has other attributes, it can't be converted to
+    // LegacyCall.
+    if (convert_to_legacy_call &&
+        (strcmp(attr_name.c_str(),
+                disable_call_shape_inference_attribute_name) ||
+         attr_value.value_case() != AttrValue::kB)) {
+      convert_to_legacy_call = false;
+    }
    if (attr_value.value_case() == AttrValue::kFunc) {
      // Attribute iteration order is not defined for protocol buffer Map.
      // Process function attributes separately in the lexicographical order to

@@ -1423,9 +1570,8 @@ Status ImporterBase::ConvertNode(const Node& node) {
  }

  // Register the mapping between the TF node and the newly created operation.
-  node_values_[node.id()] =
-      createOperation(node, op_name, result, control_operands);
-
+  node_values_[node.id()] = createOperation(
+      node, node_type_name, result, control_operands, convert_to_legacy_call);
  return Status::OK();
}
@@ -1667,36 +1813,52 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
    const GraphImportConfig& specs, mlir::MLIRContext* context,
    absl::InlinedVector<OutputTensor, 4>* arg_nodes,
    absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
-  // Finds out all the input nodes and output nodes.
-  absl::flat_hash_set<absl::string_view> output_node_names;
-  for (const auto& output_tensor : specs.outputs) {
-    output_node_names.insert(ParseTensorName(output_tensor).node());
+  // Find all the input nodes and output nodes.
+  // Feeds have been remapped to single output nodes (Placeholder), so an exact
+  // name match is sufficient.
+  absl::flat_hash_map<absl::string_view, int> inputs;
+  for (auto input_and_idx : llvm::enumerate(specs.inputs)) {
+    TensorId tensor = ParseTensorName(input_and_idx.value().first);
+    auto remapped_it = remapped_feeds_.find(tensor);
+    if (remapped_it != remapped_feeds_.end()) {
+      inputs.insert({remapped_it->second, input_and_idx.index()});
+    } else {
+      inputs.insert({tensor.node(), input_and_idx.index()});
+    }
  }
-  if (!specs.inputs.empty() || !specs.outputs.empty()) {
-    arg_nodes->resize(specs.inputs.size());
-    ret_nodes->resize(specs.outputs.size());

+  absl::flat_hash_set<absl::string_view> output_node_names;
+  std::vector<TensorId> outputs;
+  output_node_names.reserve(specs.outputs.size());
+  for (const auto& output : specs.outputs) {
+    TensorId tensor = ParseTensorName(output);
+    auto remapped_it = remapped_feeds_.find(tensor);
+    if (remapped_it != remapped_feeds_.end()) {
+      output_node_names.insert(remapped_it->second);
+      outputs.push_back({remapped_it->second, 0});
+    } else {
+      output_node_names.insert(tensor.node());
+      outputs.push_back(tensor);
+    }
+  }
+
+  if (!inputs.empty() || !outputs.empty()) {
+    arg_nodes->resize(inputs.size());
+    ret_nodes->resize(outputs.size());

    for (Node* n : GetOrderedNodes()) {
      // Handle inputs/arguments.
-      auto input_it = specs.inputs.find(n->name());
-      if (input_it != specs.inputs.end()) {
-        (*arg_nodes)[std::distance(specs.inputs.begin(), input_it)] = {n, 0};
+      auto input_it = inputs.find(n->name());
+      if (input_it != inputs.end()) {
+        (*arg_nodes)[input_it->second] = {n, 0};
      }

      // Handle outputs/returns.
      if (output_node_names.contains(n->name())) {
-        for (int i = 0, e = specs.outputs.size(); i != e; ++i) {
-          std::pair<std::string, std::string> name_and_port =
-              absl::StrSplit(specs.outputs[i], ':');
-          auto name = name_and_port.first;
-          if (name != n->name()) continue;
-          int port = 0;
-          if (!name_and_port.second.empty() &&
-              !absl::SimpleAtoi(name_and_port.second, &port)) {
-            return errors::InvalidArgument("Invalid port specification: ",
-                                           specs.outputs[i]);
-          }
-          (*ret_nodes)[i] = {n, port};
+        for (int i = 0, e = outputs.size(); i != e; ++i) {
+          TensorId tensor = outputs[i];
+          if (n->name() != tensor.node()) continue;
+          (*ret_nodes)[i] = {n, tensor.index()};
        }
      }
    }
@@ -2118,7 +2280,11 @@ class StructuredValueLinearizer {

  // Returns the list of index paths to each leaf of the StructuredValue,
  // in a linearized order matching `tf.nest.flatten`.
-  llvm::ArrayRef<mlir::ArrayAttr> GetLeafIndexPaths() const;
+  //
+  // If an error occurred during the linearization process, an error message
+  // with `error_context` prepended will be included in the returned status.
+  StatusOr<llvm::ArrayRef<mlir::ArrayAttr>> GetLeafIndexPaths(
+      llvm::StringRef error_context) const;

 private:
  // Main function that recursively traverses the StructuredValue.

@@ -2130,6 +2296,8 @@ class StructuredValueLinearizer {
  llvm::SmallVector<mlir::Attribute, 4> current_index_path_;
  // The list of leaf index paths we have discovered so far.
  llvm::SmallVector<mlir::ArrayAttr, 4> leaf_index_paths_;
+  // If non-empty, an error message to report.
+  std::string error_message_;
};

StructuredValueLinearizer::StructuredValueLinearizer(

@@ -2138,9 +2306,19 @@ StructuredValueLinearizer::StructuredValueLinearizer(
  RecursivelyFindLeaves(value);
}

-llvm::ArrayRef<mlir::ArrayAttr> StructuredValueLinearizer::GetLeafIndexPaths()
-    const {
-  return leaf_index_paths_;
+StatusOr<llvm::ArrayRef<mlir::ArrayAttr>>
+StructuredValueLinearizer::GetLeafIndexPaths(
+    llvm::StringRef error_context) const {
+  if (error_message_.empty()) {
+    return llvm::makeArrayRef(leaf_index_paths_);
+  }
+  return errors::InvalidArgument(
+      error_context.str(), error_message_,
+      "This likely means that you have @tf.function "
+      "on an exported function instead of "
+      "@tf.function(input_signature=[...]). Consider annotating an "
+      "input_signature or narrowing your set of "
+      "exported names to not include this function.");
}
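For intuition on the index paths mentioned above, a hedged illustration (comments only; the structure and its paths are hypothetical): linearizing the structured value `{"a": (x, y), "b": z}` in `tf.nest.flatten` order yields one path per leaf.

    // Leaf 0: x -> index path ["a", 0]
    // Leaf 1: y -> index path ["a", 1]
    // Leaf 2: z -> index path ["b"]
    // Each path is stored as an mlir::ArrayAttr of string/integer attributes.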
void StructuredValueLinearizer::RecursivelyFindLeaves(

@@ -2196,7 +2374,20 @@ void StructuredValueLinearizer::RecursivelyFindLeaves(
      return;
    }
    default: {
-      llvm_unreachable("Unhandled StructuredValue kind!");
+      llvm::raw_string_ostream os(error_message_);
+      // TODO(silvasean): Use an enumerant name string instead of a number.
+      os << "Unhandled structured value kind " << value.kind_case()
+         << " at index path: <value>";
+      for (auto path_element : current_index_path_) {
+        os << ".";
+        if (auto integer = path_element.dyn_cast<mlir::IntegerAttr>()) {
+          os << integer.getValue();
+        } else {
+          auto str = path_element.cast<mlir::StringAttr>();
+          os << str.getValue();
+        }
+      }
+      os << "\n";
    }
  }
}

@@ -2290,6 +2481,9 @@ Status CreateSavedModelIR(
    if (object_names.GetExportedNames(node_id).empty()) {
      continue;
    }
+    std::string error_context =
+        "While importing SavedModel function '" +
+        object_names.GetExportedNames(node_id)[0].str() + "': ";
    const SavedFunction& function = object.function();
    auto orig_func = symbol_table.lookup<mlir::FuncOp>(
        tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
@@ -2314,8 +2508,7 @@ Status CreateSavedModelIR(
        /*config=*/builder.getStringAttr(""),
        /*config_proto=*/builder.getStringAttr(""),
        /*executor_type=*/builder.getStringAttr(""));
-    body_builder.create<mlir::ReturnOp>(
-        func.getLoc(), llvm::to_vector<4>(call.getResults()));
+    body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
  }
  func.setAttr(
      "tf_saved_model.exported_names",

@@ -2338,9 +2531,12 @@ Status CreateSavedModelIR(

    int bound_input_base =
        func.getNumArguments() - concrete_function.bound_inputs_size();
-    auto input_index_paths = input_linearizer.GetLeafIndexPaths();
+    TF_ASSIGN_OR_RETURN(auto input_index_paths,
+                        input_linearizer.GetLeafIndexPaths(
+                            error_context + "in input signature: "));
    if (bound_input_base != input_index_paths.size()) {
      return errors::InvalidArgument(
+          error_context,
          "Argument mismatch between concrete function input signature "
          "vs underlying FunctionDef for concrete function '",
          function.concrete_functions(0), "' (", input_index_paths.size(),

@@ -2361,9 +2557,12 @@ Status CreateSavedModelIR(

    StructuredValueLinearizer output_linearizer(
        concrete_function.output_signature(), builder.getContext());
-    auto output_index_paths = output_linearizer.GetLeafIndexPaths();
+    TF_ASSIGN_OR_RETURN(auto output_index_paths,
+                        output_linearizer.GetLeafIndexPaths(
+                            error_context + "in output signature: "));
    if (func.getNumResults() != output_index_paths.size()) {
      return errors::InvalidArgument(
+          error_context,
          "Result mismatch between concrete function output signature "
          "vs underlying FunctionDef for concrete function '",
          function.concrete_functions(0), "' (", output_index_paths.size(),
@@ -67,8 +67,6 @@ void FunctionalToExecutorDialectConversion::runOnFunction() {
    LLVM_DEBUG(llvm::dbgs() << "Expect function to end with return\n");
    return;
  }
-  llvm::SmallVector<Value*, 4> args =
-      llvm::to_vector<4>(return_op.getOperands());
  // Build GraphOp.
  OpBuilder builder(&body, body.begin());
  auto graph_op = builder.create<tf_executor::GraphOp>(

@@ -79,10 +77,10 @@ void FunctionalToExecutorDialectConversion::runOnFunction() {
      loc, getFunction().getType().getResults(),
      tf_executor::ControlType::get(&getContext()), ArrayRef<Value*>());
  // Create Fetch.
-  auto to_fetch = llvm::to_vector<4>(island.getResults());
+  ValueRange to_fetch = island.getResults();
  if (to_fetch.size() != 1) {
    // Drop control result for fetch.
-    to_fetch.pop_back();
+    to_fetch = to_fetch.drop_back();
  }
  builder.create<tf_executor::FetchOp>(loc, to_fetch);
  // Build Island.

@@ -91,7 +89,7 @@ void FunctionalToExecutorDialectConversion::runOnFunction() {
      island.body().front().begin(), body.getOperations(), copy_range.begin(),
      copy_range.end());
  builder.setInsertionPointToEnd(&island.body().front());
-  builder.create<tf_executor::YieldOp>(loc, args);
+  builder.create<tf_executor::YieldOp>(loc, return_op.getOperands());
  for (auto item : llvm::enumerate(graph_op.getResults())) {
    return_op.setOperand(item.index(), item.value());
  }
@@ -15,7 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"

#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
@@ -58,7 +58,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,

// Converts arg_shapes to xla::Shape's and store into xla_input_shapes.
Status GetXlaInputShapes(
mlir::ModuleOp module, absl::Span<TensorShape> arg_shapes,
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
const xla::CustomShapeRepresentationFn shape_representation_fn,
std::vector<xla::Shape>* xla_input_shapes) {
xla_input_shapes->clear();
@@ -150,7 +150,8 @@ void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
}

// Refine MLIR types based on new shape information.
Status RefineShapes(absl::Span<TensorShape> arg_shapes, mlir::ModuleOp module) {
Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes,
mlir::ModuleOp module) {
auto versions = module.getAttrOfType<::mlir::DictionaryAttr>("tf.versions");
if (!versions) {
return errors::Internal(
@@ -234,7 +235,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
}

Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, absl::Span<TensorShape> arg_shapes,
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result) {
mlir::MLIRContext mlir_context;
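The absl::Span<TensorShape> to llvm::ArrayRef<TensorShape> migration in this file pays off at call sites: ArrayRef converts implicitly from std::vector and gives the callee a read-only view, so the explicit Span wrapper seen in the old test code further below disappears. A self-contained sketch under those assumptions; the TensorShape stand-in and the CompileWithShapes name are hypothetical, not the real API:

#include <vector>
#include "llvm/ADT/ArrayRef.h"

struct TensorShape {};  // stand-in for tensorflow::TensorShape

// Mirrors the new parameter style: a const, non-owning view of the shapes.
void CompileWithShapes(llvm::ArrayRef<TensorShape> arg_shapes) {
  (void)arg_shapes.size();
}

int main() {
  std::vector<TensorShape> shapes(3);
  // No absl::Span<TensorShape>(shapes) wrapper needed; the vector converts
  // to ArrayRef implicitly and cannot be mutated through the view.
  CompileWithShapes(shapes);
  return 0;
}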
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_

#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -40,7 +40,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
// metadata and stores them in CompilationResult.
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, absl::Span<TensorShape> arg_shapes,
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result);
} // namespace tensorflow
@@ -41,9 +41,9 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
std::vector<TensorShape> arg_shapes;
XlaCompiler::CompilationResult compilation_result;

Status s = CompileSerializedMlirToXlaHlo(
invalid_mlir_module, absl::Span<TensorShape>(arg_shapes),
TestShapeRepresentation, &compilation_result);
Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes,
TestShapeRepresentation,
&compilation_result);
EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
}

@@ -61,8 +61,7 @@ TEST(CompileSerializedMlirToXlaHloTest, Success) {
XlaCompiler::CompilationResult compilation_result;

Status s = CompileSerializedMlirToXlaHlo(
mlir_module, absl::Span<TensorShape>(arg_shapes), TestShapeRepresentation,
&compilation_result);
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
ASSERT_TRUE(s.ok());

const xla::HloModuleConfig module_config(
@@ -134,8 +133,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
XlaCompiler::CompilationResult compilation_result;

Status s = CompileSerializedMlirToXlaHlo(
mlir_module, absl::Span<TensorShape>(arg_shapes), TestShapeRepresentation,
&compilation_result);
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
ASSERT_TRUE(s.ok());

const xla::HloModuleConfig module_config(
@@ -174,8 +172,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
XlaCompiler::CompilationResult compilation_result;

Status s = CompileSerializedMlirToXlaHlo(
mlir_module, absl::Span<TensorShape>(arg_shapes), TestShapeRepresentation,
&compilation_result);
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
TF_ASSERT_OK(s);

const xla::HloModuleConfig module_config(
@@ -22,13 +22,11 @@ limitations under the License.
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

// Error utilities for MLIR when interacting with code using Status returns.
namespace mlir {

// TensorFlow's Status is used for error reporting back to callers.
using stream_executor::port::StatusOr;
using tensorflow::Status;

// Diagnostic handler that collects all the diagnostics reported and can produce
@@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
@@ -253,21 +254,30 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
// Note: we do not use NodeBuilder or NodeDefBuilder as that would require
// mapping back from the inputs to the input arguments.

// Some control flow ops in TensorFlow Graph have their respective "Ref" ops
// as well. For example there is Enter and RefEnter op. RefEnter forwards
// the input ref buffer to output. However both Enter and RefEnter are
// mapped to tf_executor::EnterOp during import and then to _tf.Enter op in
// control dialect. Check if it is a Ref op to correctly map to the TensorFlow
// Graph op.
llvm::SmallString<64> op_name;
if (IsRefTypeControlOp(inst)) op_name = "Ref";

TF_ASSIGN_OR_RETURN(auto tf_name,
GetTensorFlowOpName(inst->getName().getStringRef()));
op_name.append(tf_name);
if (IsLegacyCallInstruction(inst)) {
// The op_name is the name of the function.
op_name.append(
inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
// Remove the attribute from the instruction as it is already converted to
// op_name.
auto attr_id = mlir::Identifier::get("f", inst->getContext());
inst->removeAttr(attr_id);
} else {
// Some control flow ops in TensorFlow Graph have their respective "Ref" ops
// as well. For example there is Enter and RefEnter op. RefEnter forwards
// the input ref buffer to output. However both Enter and RefEnter are
// mapped to tf_executor::EnterOp during import and then to _tf.Enter op in
// control dialect. Check if it is a Ref op to correctly map to the
// TensorFlow Graph op.
if (IsRefTypeControlOp(inst)) op_name = "Ref";
TF_ASSIGN_OR_RETURN(auto tf_name,
GetTensorFlowOpName(inst->getName().getStringRef()));
op_name.append(tf_name);
}

node_def->set_name(name.str());
node_def->set_op(op_name.str());
node_def->set_name(name);

// Add inputs to the NodeDef based on the number of operands. This is required
// as later when edges are added to the Node using Graph::AddEdge the
@@ -454,4 +464,9 @@ Status SetSizeAttribute(absl::string_view name, size_t size,
return Status::OK();
}

bool IsLegacyCallInstruction(mlir::Operation* inst) {
return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst) ||
inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
}

} // namespace tensorflow
@@ -73,5 +73,16 @@ Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shape,
// If the attribute already exists with a different value, returns an error.
Status SetSizeAttribute(absl::string_view name, size_t size,
AttrValueMap* values);

// Returns true if the given instruction is an mlir::TF::LegacyCallOp or the
// result of such an operation transformed by the
// ExecutorToControlDialectConversion pass.
//
// TODO(b/145706023): When the ExecutorToControlDialectConversion pass runs
// before the exporter, it mutates an mlir::TF::LegacyCallOp instruction to
// an instruction with a different operation name. As such, this routine checks
// both forms of a LegacyCall instruction. We only need to check for
// mlir::TF::LegacyCallOp when the ticket is resolved.
bool IsLegacyCallInstruction(mlir::Operation* inst);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EXPORTER_UTILS_H_
@@ -23,6 +23,8 @@ package_group(
],
)

exports_files(["ir/hlo_ops.td"])

filegroup(
name = "hlo_ops_td_files",
srcs = [
@@ -406,6 +408,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
@@ -182,13 +182,12 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
// Setup the return type (HLO only supports a single return value).
TF_ASSIGN_OR_RETURN(auto result,
GetMlirValue(computation->root_instruction()));
llvm::SmallVector<Value*, 1> return_values({result});

// Create terminator op depending on the parent op of this region.
if (llvm::isa<FuncOp>(block->getParentOp())) {
builder.create<mlir::ReturnOp>(loc, makeArrayRef(return_values));
builder.create<mlir::ReturnOp>(loc, result);
} else {
builder.create<mlir::xla_hlo::ReturnOp>(loc, makeArrayRef(return_values));
builder.create<mlir::xla_hlo::ReturnOp>(loc, result);
}
return tensorflow::Status::OK();
}
@@ -266,32 +265,20 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
MakeAndReturn(CompareOp);
}
case HloOpcode::kGather: {
const auto& gather_dimensions = instruction->gather_dimension_numbers();
std::vector<int64_t> offset_dims(gather_dimensions.offset_dims().begin(),
gather_dimensions.offset_dims().end());
auto gather_instruction = static_cast<HloGatherInstruction*>(instruction);
attributes.push_back(ConvertGatherDimensionNumbers(
gather_instruction->gather_dimension_numbers()));

std::vector<int64_t> slice_sizes(
instruction->gather_slice_sizes().begin(),
instruction->gather_slice_sizes().end());
gather_instruction->gather_slice_sizes().begin(),
gather_instruction->gather_slice_sizes().end());
attributes.push_back(
builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(gather_instruction->indices_are_sorted())));

std::vector<int64_t> collapsed_slice_dims(
gather_dimensions.collapsed_slice_dims().begin(),
gather_dimensions.collapsed_slice_dims().end());

std::vector<int64_t> start_index_map(
gather_dimensions.start_index_map().begin(),
gather_dimensions.start_index_map().end());

// TODO(b/132057942): Change to explicitly passing an integer instead of
// call getI64IntegerAttr here.
return func_builder
->create<mlir::xla_hlo::GatherOp>(
loc, result_type, operands[0], operands[1],
func_builder->getI64IntegerAttr(
gather_dimensions.index_vector_dim()),
Convert(offset_dims), Convert(slice_sizes),
Convert(collapsed_slice_dims), Convert(start_index_map))
.getOperation();
MakeAndReturn(GatherOp);
}
case HloOpcode::kDynamicUpdateSlice: {
return func_builder
@@ -707,4 +694,19 @@ mlir::NamedAttribute HloFunctionImporter::ConvertConvDimensionNumbers(
return builder_->getNamedAttr("dimension_numbers", attr);
}

mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums) {
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
std::vector<int64_t> collapsed_slice_dims(
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
auto attr = mlir::xla_hlo::GatherDimensionNumbers::get(
Convert(offset_dims), Convert(collapsed_slice_dims),
Convert(start_index_map),
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
return builder_->getNamedAttr("dimension_numbers", attr);
}

} // namespace xla
@@ -117,6 +117,10 @@ class HloFunctionImporter {
mlir::NamedAttribute ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums);

// Converts the gather dimensions to attributes.
mlir::NamedAttribute ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums);

mlir::MLIRContext* context_;
mlir::ModuleOp module_;
mlir::Builder* builder_;
@@ -606,7 +606,7 @@ static TensorType GetReduceResultType(Type operand_ty,
}

void ReduceOp::build(Builder* builder, OperationState& state,
ArrayRef<Value*> operands, ArrayRef<Value*> init_values,
ValueRange operands, ValueRange init_values,
DenseIntElementsAttr dimensions) {
SmallVector<Type, 1> result_ty;
result_ty.reserve(operands.size());
@@ -845,9 +845,8 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
// SortOp
//===----------------------------------------------------------------------===//

void SortOp::build(Builder* builder, OperationState& state,
ArrayRef<Value*> operands, int64_t dimension,
bool is_stable) {
void SortOp::build(Builder* builder, OperationState& state, ValueRange operands,
int64_t dimension, bool is_stable) {
state.addOperands(operands);
state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder->getBoolAttr(dimension));
@@ -990,7 +989,7 @@ void GetTupleElementOp::build(Builder* builder, OperationState& result,
//===----------------------------------------------------------------------===//

void TupleOp::build(Builder* builder, OperationState& result,
ArrayRef<Value*> values) {
ValueRange values) {
SmallVector<Type, 4> types;
types.reserve(values.size());
for (auto val : values) {
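The builder changes above, and the matching OpBuilder declarations in the .td hunks that follow, replace ArrayRef<Value *> with ValueRange, which uniformly wraps operand ranges, result ranges, and SmallVectors without materializing a copy. A hedged sketch of the shape of such a builder, assuming the MLIR API of this vintage; BuildVariadicOp is illustrative, not the dialect's actual code:

#include "mlir/IR/Builders.h"           // mlir::Builder
#include "mlir/IR/OperationSupport.h"   // mlir::OperationState, mlir::ValueRange

// Sketch: ValueRange as the permissive parameter type for variadic builders.
void BuildVariadicOp(mlir::Builder *builder, mlir::OperationState &state,
                     mlir::ValueRange operands) {
  // addOperands consumes the range directly; no intermediate vector needed.
  state.addOperands(operands);
  (void)builder;
}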
@@ -405,8 +405,8 @@ def HLO_ReduceOp: HLO_Op<"reduce", [
let results = (outs Variadic<HLO_TensorOrTuple>);

let builders = [OpBuilder<
"Builder *, OperationState &state, ArrayRef<Value *> operands, "
"ArrayRef<Value *> init_values, DenseIntElementsAttr dimensions"
"Builder *, OperationState &state, ValueRange operands, "
"ValueRange init_values, DenseIntElementsAttr dimensions"
>];

let hasFolder = 1;
@@ -445,7 +445,7 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {

let builders = [OpBuilder<
"Builder *builder, OperationState &results, "
"ArrayRef<Value*> values">];
"ValueRange values">];

// TupleOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
@@ -777,21 +777,25 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
let hasCustomHLOConverter = 1;
}

def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
StructFieldAttr<"start_index_map", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for gather";
}

def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
let arguments = (ins
HLO_Tensor:$operand,
HLO_IntTensor:$start_indices,
I64Attr:$index_vector_dim,
I64ElementsAttr:$offset_dims,
GatherDimensionNumbers:$dimension_numbers,
I64ElementsAttr:$slice_sizes,
I64ElementsAttr:$collapsed_slice_dims,
I64ElementsAttr:$start_index_map
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
);

let results = (outs HLO_Tensor);

// TODO(b/129422361) Attributes are not supported by the codegen. The
// optional argument (dimensions) needs to be added as an attribute.
let hasCustomHLOConverter = 1;
}

@@ -880,7 +884,7 @@ def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
let regions = (region SizedRegion<1>:$comparator);

let builders = [OpBuilder<
"Builder *builder, OperationState &state, ArrayRef<Value *> operands, "
"Builder *builder, OperationState &state, ValueRange operands, "
"int64_t dimension, bool is_stable"
>];

@@ -40,7 +42,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using ::stream_executor::port::StatusOr;
using ::tensorflow::int16;
using ::tensorflow::int32;
using ::tensorflow::int64;
@@ -149,6 +151,7 @@ I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(strides);
I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);

#undef I64_ELEMENTS_ATTR_TO_VECTOR

@@ -267,6 +270,30 @@ static xla::ComparisonDirection Convert_comparison_direction(
.ValueOrDie();
}

static xla::GatherDimensionNumbers Convert_gather_dimension_numbers(
mlir::xla_hlo::GatherDimensionNumbers input) {
xla::GatherDimensionNumbers output;

auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
std::copy(offset_dims.begin(), offset_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_offset_dims()));

auto collapsed_slice_dims = ConvertDenseIntAttr(input.collapsed_slice_dims());
std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_collapsed_slice_dims()));

auto start_index_map = ConvertDenseIntAttr(input.start_index_map());
std::copy(start_index_map.begin(), start_index_map.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_start_index_map()));

output.set_index_vector_dim(
ConvertAPInt(input.index_vector_dim().getValue()));
return output;
}

static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
mlir::xla_hlo::ScatterDimensionNumbers input) {
xla::ScatterDimensionNumbers output;
@@ -496,7 +523,13 @@ LogicalResult ExportXlaOp(DynamicUpdateSliceOp op, OpLoweringContext ctx) {
LogicalResult ExportXlaOp(FftOp op, OpLoweringContext ctx) { return failure(); }

LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
return failure();
auto& value_map = *ctx.values;
xla::GatherDimensionNumbers dimension_numbers =
Convert_gather_dimension_numbers(op.dimension_numbers());
value_map[op] = xla::Gather(
value_map[op.operand()], value_map[op.start_indices()], dimension_numbers,
Convert_slice_sizes(op.slice_sizes()), op.indices_are_sorted());
return success();
}

LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
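Convert_gather_dimension_numbers above moves each I64ElementsAttr into a proto repeated field through RepeatedFieldBackInserter. A stand-alone sketch of that copy pattern; RepeatedField below is a plain-vector stand-in, since the protobuf type itself is not part of this diff:

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

template <typename T>
using RepeatedField = std::vector<T>;  // stand-in for the proto repeated field

void CopyDims(const std::vector<int64_t> &src, RepeatedField<int64_t> *dst) {
  // The real code uses tensorflow::protobuf::RepeatedFieldBackInserter(dst);
  // std::back_inserter plays the same role for the stand-in container.
  std::copy(src.begin(), src.end(), std::back_inserter(*dst));
}

int main() {
  std::vector<int64_t> offset_dims = {1};
  RepeatedField<int64_t> proto_field;
  CopyDims(offset_dims, &proto_field);
  return proto_field.size() == 1 ? 0 : 1;
}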
@@ -25,18 +25,25 @@ func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>,

// CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>}
// CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>}
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_NCHW
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_dynamic
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
// CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
return %0 : tensor<?x?x?x?xi32>
}

//===----------------------------------------------------------------------===//
// Binary op legalizations.
//===----------------------------------------------------------------------===//
@@ -666,11 +673,18 @@ func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {

// CHECK-LABEL: @const
func @const() -> tensor<2xi32> {
// CHECK-NEXT: xla_hlo.constant dense<0> : tensor<2xi32>
// CHECK: xla_hlo.constant dense<0> : tensor<2xi32>
%0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>)
return %0: tensor<2xi32>
}

// CHECK-LABEL: @const_dynamic_output
func @const_dynamic_output() -> tensor<*xi32> {
// CHECK: xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32>
%0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
return %0: tensor<*xi32>
}

// CHECK-LABEL: @opaque_const
func @opaque_const() -> tensor<!tf.variant<tensor<2xi32>>> {
// CHECK-NOT: xla_hlo.constant
@@ -838,13 +852,14 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
}

// CHECK-LABEL: func @relu_grad
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<4x8xf32>)
func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<4x8xf32>) -> tensor<4x8xf32> {
// CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
// CHECK: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xi1>
// CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<4x8xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
// CHECK: return %[[RESULT]] : tensor<4x8xf32>
%2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
// CHECK-DAG: %[[ZERO_SCALAR:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
// CHECK-DAG: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO_SCALAR]]) {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<*xi1>
// CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<*xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
// CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32>
%2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32>
}

@@ -1019,6 +1034,14 @@ func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
return %0 : tensor<3x2xf32>
}

// CHECK-LABEL: @transpose_3d_int32
func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>)
// CHECK: "xla_hlo.transpose"
%0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}

// CHECK-LABEL: @transpose_3d
func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>)
@@ -1344,35 +1367,42 @@ func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {

// CHECK-LABEL: reshape
func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> {
// CHECK: %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32>
// CHECK: "xla_hlo.reshape"
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}

// CHECK-LABEL: reshape_dynamic
func @reshape_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
// CHECK: %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
func @reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> {
// CHECK: "xla_hlo.reshape"
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xi32>) -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}

// CHECK-LABEL: reshape_unranked
func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
// CHECK: "tf.Reshape"
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: squeeze
func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> {
// CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
// CHECK: "xla_hlo.reshape"
%0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
return %0 : tensor<1x10xf32>
}

// CHECK-LABEL: squeeze_dynamic
func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
// CHECK: "tf.Squeeze"
%0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// CHECK-LABEL: expand_dims
func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> {
// CHECK: "xla_hlo.reshape"{{.*}} : (tensor<2xf32>) -> tensor<1x2xf32>
// CHECK: "xla_hlo.reshape"
%0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
@@ -1380,7 +1410,8 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> {
// CHECK-LABEL: slice_constant_start
func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32>
// CHECK: return %[[RESULT]] : tensor<2xi32>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
@@ -1388,10 +1419,22 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
return %0 : tensor<2xi32>
}

// CHECK-LABEL: slice_i32_consts
func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: slice_sizes = dense<2> : tensor<1xi64>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>)
%0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}

// CHECK-LABEL: slice_constant_start_negative_one_size
func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32>
// CHECK: return %[[RESULT]] : tensor<3xi32>
%starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
@@ -1402,7 +1445,8 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi
// CHECK-LABEL: slice_constant_start_dynamic_shape
func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: return %[[RESULT]] : tensor<1x4xi32>
%starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
%sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
@@ -1412,7 +1456,8 @@ func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2

// CHECK-LABEL: slice_variable_start
func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
// CHECK: return %[[RESULT]] : tensor<1x4xi32>
%sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
%0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
@@ -1525,6 +1570,16 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @mean_scalar_dim
func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
// Verify that tf.Mean op with scalar attributes are lowered successfully.

// CHECK-NOT: tf.Mean
%dimension = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64>
%0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<i64>) -> tensor<4x1xf16>
return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @mean_dynamic
func @mean_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
%dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
@@ -1601,6 +1656,66 @@ func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: @all
func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> {
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[INIT:.*]] = xla_hlo.constant dense<true> : tensor<i1>
// CHECK: "xla_hlo.reduce"(%{{.*}}, %[[INIT]]) ( {
// CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
// CHECK: %[[AND:.*]] = xla_hlo.and %[[ARGA]], %[[ARGB]] : tensor<i1>
// CHECK: "xla_hlo.return"(%[[AND]]) : (tensor<i1>) -> ()
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
%0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

// CHECK-LABEL: @all_keep_dim
func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
// CHECK: "xla_hlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
return %0 : tensor<4x1xi1>
}

// CHECK-LABEL: @all_dynamic
func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[ARG:.*]] = "xla_hlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
// CHECK: "xla_hlo.reduce"(%[[ARG]]
%0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
return %0 : tensor<4x1xi1>
}

// CHECK-LABEL: @any
func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> {
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[INIT:.*]] = xla_hlo.constant dense<false> : tensor<i1>
// CHECK: "xla_hlo.reduce"(%{{.*}}, %[[INIT]]) ( {
// CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
// CHECK: %[[AND:.*]] = xla_hlo.or %[[ARGA]], %[[ARGB]] : tensor<i1>
// CHECK: "xla_hlo.return"(%[[AND]]) : (tensor<i1>) -> ()
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
%0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

// CHECK-LABEL: @any_keep_dim
func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
// CHECK: "xla_hlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
return %0 : tensor<4x1xi1>
}

// CHECK-LABEL: @any_dynamic
func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
%dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[ARG:.*]] = "xla_hlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
// CHECK: "xla_hlo.reduce"(%[[ARG]]
%0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
return %0 : tensor<4x1xi1>
}

//===----------------------------------------------------------------------===//
// Tile op legalizations.
//===----------------------------------------------------------------------===//
@@ -1924,12 +2039,23 @@ func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32
return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32>
}

// CHECK-LABEL: @split_match_and_split_into_two_dynamic
func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) {
%cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
// CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
%0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>)
// CHECK: return %[[ONE]], %[[TWO]]
return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32>
}

// CHECK-LABEL: @split_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) {
%cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%arg0) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
%0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
@@ -1973,3 +2099,82 @@ func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
}

//===----------------------------------------------------------------------===//
// tf.SplitV legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @splitv_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
%split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
%split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[ONE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32>
// CHECK: %[[TWO:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
// CHECK: %[[THREE:.*]] = "xla_hlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32>
%0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}

// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic
func @splitv_match_and_split_into_three_dynamic(%input: tensor<?x6xf32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) {
%split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
%split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x1xf32>
// CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x2xf32>
// CHECK: "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x3xf32>
%0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<?x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>)
return %0#0, %0#1, %0#2 : tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>
}

// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes
func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
%split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
%split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>
// CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>
// CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>
%0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}

//===----------------------------------------------------------------------===//
// tf.Assert legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @assert
func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
// CHECK-NOT: tf.Assert
"tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor<i1>, tensor<*xf32>) -> ()
return
}

//===----------------------------------------------------------------------===//
// tf.Unpack legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @unpack
func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// CHECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32>
// CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// CHECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
// CHECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
// CHECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>

%0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
// return %[[RES1]], %[[RES2]], %[[RES3]]
return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>
}

// CHECK-LABEL: @unpack_dynamic
func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
// CHECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>
// CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32>
// CHECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32>

%0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}
@@ -317,16 +317,33 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {

// -----

// CHECK-LABEL: HloModule
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
// Simple einsum is lowered to HLO dot op.
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
%0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
return %0 : tensor<3x5xi32>
}

// -----

// CHECK-LABEL: HloModule
func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x300xf32> {
// CHECK: [[ARG0:%.*]] = f32[200,100,300] parameter(0)
// CHECK: [[ARG1:%.*]] = s32[10,2] parameter(1)
// CHECK: f32[10,300] gather(f32[200,100,300] [[ARG0]], s32[10,2] [[ARG1]])
// CHECK-SAME: offset_dims={1}
// CHECK-SAME: collapsed_slice_dims={0,1}
// CHECK-SAME: start_index_map={0,1}
// CHECK-SAME: index_vector_dim=1
// CHECK-SAME: slice_sizes={1,1,300}
// CHECK-SAME: indices_are_sorted=true
%0 = "xla_hlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32>
return %0 : tensor<10x300xf32>
}

// -----

// CHECK-LABEL: HloModule
func @main(%arg: tensor<4x2xf32>) -> tensor<i32> {
%0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
@@ -317,6 +317,28 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1)
}

// CHECK-LABEL: func @test_gather(
// CHECK-SAME: [[ARG0:%.+]]: tensor<200x100x300xf32>, [[ARG1:%.+]]: tensor<10x2xi32>) -> tensor<10x300xf32> {
%test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] {
%arg.0 = f32[200,100,300] parameter(0)
%arg.1 = s32[10,2] parameter(1)
// CHECK: "xla_hlo.gather"([[ARG0]], [[ARG1]])
// CHECK-SAME: dimension_numbers
// CHECK-SAME: collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: index_vector_dim = 1 : i64
// CHECK-SAME: offset_dims = dense<1> : tensor<1xi64>
// CHECK-SAME: start_index_map = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: indices_are_sorted = true
// CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>
ROOT gather = f32[10,300] gather(f32[200,100,300] %arg.0, s32[10,2] %arg.1),
collapsed_slice_dims={0,1},
index_vector_dim=1,
offset_dims={1},
start_index_map={0,1},
indices_are_sorted=true,
slice_sizes={1,1,300}
}

// CHECK-LABEL: func @test_get_dimension_size
// CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>)
%test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] {
@@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
@@ -38,13 +39,19 @@ namespace {

constexpr StringRef kTempBufferAttr = "temp";

Value* GetTensorStoreMemRef(Value* value) {
Value* GetTensorStoreOrReturnMemRef(Value* value) {
for (const auto& user : value->getUsers()) {
if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
if (tensor_store.getOperand(0) == value) {
return tensor_store.getOperand(1);
}
}
if (auto return_op = dyn_cast<xla_hlo::ReturnOp>(user)) {
if (return_op.getOperand(0) == value) {
auto block = return_op.getOperation()->getBlock();
return *block->args_rbegin();
}
}
}
return nullptr;
}
@@ -88,8 +95,8 @@ Value* InsertAllocAndDealloc(Location loc, Value* result,
/// function to store that values held in the tensor.
Value* GetBufferForResultValue(Location loc, Value* result,
ConversionPatternRewriter* rewriter) {
if (auto tensor_store_memref = GetTensorStoreMemRef(result)) {
return tensor_store_memref;
if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) {
return existing_memref;
}
return InsertAllocAndDealloc(loc, result, rewriter);
}
@@ -117,7 +124,63 @@ class HloToLhloOpConverter : public ConversionPattern {
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
llvm::to_vector<4>(original_results));
original_results);
return matchSuccess();
}
};

struct HloToLHloReduceConverter
: public OpConversionPattern<xla_hlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;

PatternMatchResult matchAndRewrite(
xla_hlo::ReduceOp op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
if (op.getNumResults() != 1) return matchFailure();
if (op.getParentRegion()->getBlocks().size() != 1) {
emitError(loc,
"tensor to buffer conversion expects a single block in the "
"region containing the operation");
}
const auto& original_results = op.getResults();
SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());

// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());

// Create new block arguments with correct type.
auto& entry_block = new_op.body().front();
int original_arg_count = entry_block.getNumArguments();
for (int i = 0; i < original_arg_count; ++i) {
auto old_arg = entry_block.getArgument(i);
auto old_type = old_arg->getType().cast<TensorType>();
auto new_type =
MemRefType::get(old_type.getShape(), old_type.getElementType());
auto new_arg = entry_block.addArgument(new_type);
rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
}
// Add an argument for the result.
entry_block.addArgument(
entry_block.getArgument(original_arg_count)->getType());
// Remove the old arguments.
for (int i = original_arg_count - 1; i >= 0; --i) {
entry_block.eraseArgument(i);
}
// Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<xla_lhlo::TerminatorOp>(loc);

rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
original_results);

return matchSuccess();
}
};
@@ -130,11 +193,12 @@ class HloToLhloTensorLoadConverter : public ConversionPattern {
PatternMatchResult matchAndRewrite(
Operation* op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOp(op, operands, llvm::to_vector<4>(op->getResults()));
rewriter.replaceOp(op, operands, op->getResults());
return matchSuccess();
}
};

// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreConverter : public ConversionPattern {
public:
explicit HloToLhloTensorStoreConverter(MLIRContext* context)
@@ -148,6 +212,19 @@ class HloToLhloTensorStoreConverter : public ConversionPattern {
}
};

// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloReturnConverter : public OpConversionPattern<xla_hlo::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;

PatternMatchResult matchAndRewrite(
xla_hlo::ReturnOp op, ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.eraseOp(op);
return matchSuccess();
}
};

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
@@ -215,6 +292,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
xla_lhlo::BroadcastInDimOp>,
HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>,
HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>,
@@ -229,6 +307,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
HloToLHloReduceConverter, HloToLhloReturnConverter,
HloToLhloTensorLoadConverter, HloToLhloTensorStoreConverter
>(context);
// clang-format on
@ -53,8 +53,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
|
||||
auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
|
||||
if (!return_op) continue;
|
||||
builder->setInsertionPointToEnd(block);
|
||||
builder->create<mlir::BranchOp>(
|
||||
loc, target_block, llvm::to_vector<4>(return_op.getOperands()));
|
||||
builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
|
||||
return_op.erase();
|
||||
}
|
||||
|
||||
@ -196,8 +195,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
|
||||
dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
|
||||
if (!return_op) continue;
|
||||
builder.setInsertionPointToEnd(new_block);
|
||||
builder.create<mlir::BranchOp>(loc, cond_block,
|
||||
llvm::to_vector<4>(return_op.getOperands()));
|
||||
builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
|
||||
return_op.erase();
|
||||
}
|
||||
|
||||
|
@ -127,8 +127,8 @@ static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value *value,
|
||||
|
||||
/// Returns a `ConvertOp` that casts the elements to a i64 type while retaining
|
||||
/// the shape of the input value.
|
||||
static ConvertOp CastElementsToI64(Location loc, Value *value,
|
||||
PatternRewriter *rewriter) {
|
||||
static ConvertOp CastValueToI64(Location loc, Value *value,
|
||||
PatternRewriter *rewriter) {
|
||||
return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
|
||||
}
|
||||
|
||||
@ -207,7 +207,8 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
|
||||
// Bias op utilities.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd.
|
||||
// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd.
|
||||
// Requires input to have ranked tensor.
|
||||
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
|
||||
StringAttr format,
|
||||
Value *input) {
|
||||
@ -418,7 +419,8 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
|
||||
Builder *builder) {
|
||||
DenseIntElementsAttr constant_start_indices;
|
||||
if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
|
||||
return slice_sizes;
|
||||
return xla::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64))
|
||||
.cast<DenseIntElementsAttr>();
|
||||
}
|
||||
|
||||
auto input_ty = input->getType().dyn_cast<RankedTensorType>();
|
||||
@ -687,7 +689,7 @@ class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
|
||||
rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
|
||||
op, op.getType(), *op.inputs().begin(), equation);
|
||||
} else if (op.N() == 2) {
|
||||
auto inputs = llvm::to_vector<2>(op.inputs());
|
||||
ValueRange inputs = op.inputs();
|
||||
rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
|
||||
inputs[1], equation);
|
||||
} else {
|
||||
@ -924,7 +926,7 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
|
||||
};
|
||||
|
||||
// Converts the tf.Split op into a series of HLO slice ops when the tensor to be
|
||||
// split has fuly static shape and the dimension to split is a constant.
|
||||
// split has fully static shape and the dimension to split is a constant.
|
||||
//
|
||||
// The main logic of this pattern is to calculate the index start and end range
|
||||
// for each slice. And this happens only on the dimension to be split; for all
|
||||
@ -962,9 +964,9 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
|
||||
|
||||
PatternMatchResult matchAndRewrite(TF::SplitOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// We can only match when the tensor to be split has fully static shape.
|
||||
// We can only split along static dimensions.
|
||||
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
||||
if (!input_type || !input_type.hasStaticShape()) return matchFailure();
|
||||
if (!input_type) return matchFailure();
|
||||
|
||||
// We can only match when the split dimension is a constant scalar.
|
||||
DenseIntElementsAttr split_dim_attr;
|
||||
@ -978,6 +980,10 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
|
||||
|
||||
// Calculate the dimension size for each slice along the split dimension.
|
||||
int64_t input_dim_size = input_type.getDimSize(dim_index);
|
||||
// If we are splitting along the dynamic dimension then we cannot compute
|
||||
// the static dimension length.
|
||||
if (TensorType::isDynamic(input_dim_size)) return matchFailure();
|
||||
|
||||
int64_t num_splits = op.getNumResults();
|
||||
int64_t slice_size = input_dim_size / num_splits;
|
||||
|
||||
@ -1011,6 +1017,118 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
|
||||
}
|
||||
};
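For intuition, the elided body computes equal-sized half-open ranges along the split dimension; a small self-contained sketch of that arithmetic (hypothetical helper, not part of the patch):

#include <cstdint>
#include <utility>
#include <vector>

// Equal-chunk ranges along one dimension, as ConvertSplitOp computes them.
// Requires dim_size % num_splits == 0, which tf.Split guarantees.
std::vector<std::pair<int64_t, int64_t>> SplitRanges(int64_t dim_size,
                                                     int64_t num_splits) {
  std::vector<std::pair<int64_t, int64_t>> ranges;
  const int64_t slice_size = dim_size / num_splits;
  for (int64_t i = 0; i < num_splits; ++i)
    ranges.push_back({i * slice_size, (i + 1) * slice_size});
  return ranges;  // dim_size=6, num_splits=3 -> {0,2}, {2,4}, {4,6}
}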

// Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
// be split has fully static shape and the dimension to split and split sizes
// are constants.
//
// This is similar to the conversion for tf.Split op other than that the size of
// each chunk on the dimension to split is explicitly given as an op operand
// and they are not necessarily the same.
//
// For example, given the following IR:
//
// %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
// %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
// %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
// (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
// (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
//
// We will generate the following slices:
// %0 = "xla_hlo.slice"(%input) {
// limit_indices = dense<[4, 1]> : tensor<2xi64>,
// start_indices = dense<0> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x1xf32>
// %1 = "xla_hlo.slice"(%input) {
// limit_indices = dense<[4, 3]> : tensor<2xi64>,
// start_indices = dense<[0, 1]> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x2xf32>
// %2 = "xla_hlo.slice"(%input) {
// limit_indices = dense<[4, 6]> : tensor<2xi64>,
// start_indices = dense<[0, 3]> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x3xf32>
class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
public:
using OpRewritePattern::OpRewritePattern;

PatternMatchResult matchAndRewrite(TF::SplitVOp op,
PatternRewriter &rewriter) const override {
// We can only split along static dimensions.
// TODO(b/145731001): enhance to support dynamic-shaped inputs.
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
if (!input_type) return matchFailure();

// We can only match when the split dimension is a constant scalar.
DenseIntElementsAttr split_dim_attr;
if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
return matchFailure();

// We can only match when the split sizes are a constant int vector.
DenseIntElementsAttr split_sizes_attr;
if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
return matchFailure();

// Get each chunk's size along the dimension to split. It may contain
// dynamic sizes and we need to update it if so.
SmallVector<int64_t, 4> split_sizes;
int64_t total_dim_size = 0;  // Total dimension size assigned to splits
llvm::Optional<int> dynamic_dim_index;
split_sizes.reserve(
split_sizes_attr.getType().cast<ShapedType>().getNumElements());
for (auto dim : llvm::enumerate(split_sizes_attr)) {
int64_t dim_val = dim.value().getSExtValue();
split_sizes.push_back(dim_val);
if (dim_val == ShapedType::kDynamicSize) {
// We cannot have more than one dynamic dimension.
assert(!dynamic_dim_index && "invalid split sizes");
dynamic_dim_index = dim.index();
} else {
total_dim_size += dim_val;
}
}

// Get the dimension we are splitting at. Offset properly if it's negative.
int64_t input_rank = input_type.getRank();
int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
if (dim_index < 0) dim_index += input_rank;

int64_t input_dim_size = input_type.getDimSize(dim_index);
if (TensorType::isDynamic(input_dim_size)) return matchFailure();

assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
(!dynamic_dim_index && total_dim_size == input_dim_size)) &&
"invalid split sizes");

// Update the dynamic dimension with calculated concrete size.
if (dynamic_dim_index)
split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;

// Parameters for constructing each slice.
SmallVector<int64_t, 4> begin_indices(input_rank, 0);
auto end_indices = llvm::to_vector<4>(input_type.getShape());
SmallVector<int64_t, 4> strides(input_rank, 1);

// All HLO slice results used to replace the original tf.Split op.
SmallVector<Value *, 4> slices;
slices.reserve(op.getNumResults());

for (int i = 0; i < op.getNumResults(); ++i) {
end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
slices.push_back(rewriter.create<xla_hlo::SliceOp>(
op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter)));
// Prepare the begin index for the next slice.
begin_indices[dim_index] = end_indices[dim_index];
}

rewriter.replaceOp(op, slices);
return matchSuccess();
}
};
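The only subtle step above is resolving the single -1 entry; a standalone sketch of that resolution (hypothetical helper mirroring the loop over split_sizes_attr):

#include <cassert>
#include <cstdint>
#include <vector>

// Replaces the one negative ("dynamic") split size with the leftover length
// of the split dimension, as ConvertSplitVOp does.
std::vector<int64_t> ResolveSplitSizes(std::vector<int64_t> sizes,
                                       int64_t dim_size) {
  int64_t total = 0;
  int dynamic_index = -1;
  for (int i = 0; i < static_cast<int>(sizes.size()); ++i) {
    if (sizes[i] < 0) {
      assert(dynamic_index < 0 && "at most one dynamic split size");
      dynamic_index = i;
    } else {
      total += sizes[i];
    }
  }
  assert(total <= dim_size && "invalid split sizes");
  if (dynamic_index >= 0) sizes[dynamic_index] = dim_size - total;
  return sizes;  // {1, -1, 3} with dim_size 6 -> {1, 2, 3}
}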

// Converts StridedSlice op to HLO Slice op along with Reverse op to handle
// negative strides and Reshape op to update the output shape. Indices and
// strides operands are converted to attributes with non-negative indexing.

@@ -1182,8 +1300,7 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
ArrayRef<int64_t> input_shape = input_ty.getShape();

DenseIntElementsAttr dimensions;
if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)) ||
dimensions.getType().getRank() != 1)
if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
return this->matchFailure();

// Build the final shape from input_shape and dimensions using a bitmap

@@ -1260,7 +1377,6 @@ class ConvertMeanOp
: public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;

static Value *GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);

@@ -1300,6 +1416,36 @@ class ConvertMaxOp
}
};

// Converts All op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.and"]
// {dimensions = ...}
class ConvertAllOp
: public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
static Value *GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 1, &rewriter);
}
};

// Converts Any op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.or"]
// {dimensions = ...}
class ConvertAnyOp
: public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
static Value *GetInitialValue(Type reduce_element_type, Location loc,
PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
}
};
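The initial values 1 and 0 are simply the identities of the two reductions; restating the converters as formulas:

\[ \mathrm{All}(x) = \bigwedge_i x_i \quad (\text{identity } 1 = \text{true}), \qquad \mathrm{Any}(x) = \bigvee_i x_i \quad (\text{identity } 0 = \text{false}). \]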

// Converts tensorflow ArgMin or ArgMax op to xla_hlo operations that perform
// a reduction on the original input and the corresponding index. The reduction
// sub-computation selects the max (or min) value and the index for the value.

@@ -2000,6 +2146,53 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
}
};

// Converts tf.Unpack to a series of XLA HLO slice ops.
//
// Each slice takes one element along the dimension to unpack and takes the full
// range for all other dimensions. Each slice is then reshaped to drop the
// dimension to unpack (which is always of size 1).
// TODO(antiagainst): consider changing this into a TF internal lowering pass.
class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
public:
using OpRewritePattern::OpRewritePattern;

PatternMatchResult matchAndRewrite(TF::UnpackOp op,
PatternRewriter &rewriter) const override {
auto value_type = op.value()->getType().cast<RankedTensorType>();
if (!value_type) return matchFailure();

int64_t value_rank = value_type.getRank();
int64_t axis = op.axis().getSExtValue();
if (axis < 0) axis += value_rank;

// Parameters for constructing each slice.
SmallVector<int64_t, 4> begin_indices(value_rank, 0);
auto end_indices = llvm::to_vector<4>(value_type.getShape());
SmallVector<int64_t, 4> strides(value_rank, 1);

// All HLO slice+reshape results used to replace the original tf.Unpack op.
SmallVector<Value *, 4> results;
results.reserve(op.getNumResults());

for (int i = 0; i < op.getNumResults(); ++i) {
begin_indices[axis] = i;
end_indices[axis] = i + 1;

auto slice_op = rewriter.create<xla_hlo::SliceOp>(
op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter));
// Reshape to drop the axis dimension.
auto reshape_op = rewriter.create<xla_hlo::ReshapeOp>(
op.getLoc(), op.getType(i), slice_op);
results.push_back(reshape_op);
}

rewriter.replaceOp(op, results);
return matchSuccess();
}
};

#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"

LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {

@@ -2013,16 +2206,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
// level TensorFlow ops. So, we don't have to target all the TensorFlow ops
// here for lowering to HLO.
TF::PopulateLoweringTFPatterns(context, &patterns);
patterns
.insert<ConvertArgMaxOp, ConvertBF16FloorDivOp, ConvertConv2D,
ConvertEinsumOp, ConvertMaxPoolOp, ConvertRangeOp,
ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp,
ConvertSigmoidOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp,
ConvertStridedSliceOp, ConvertTopKV2Op, ConvertMeanOp,
ConvertSumOp, ConvertMaxOp, ConvertTileOp, ConvertMaxPoolGradOp,
ConvertOneHotOp, ConvertConv2DBackpropInputOp,
ConvertConv2DBackpropFilterOp>(op->getContext());
patterns.insert<
ConvertArgMaxOp, ConvertBF16FloorDivOp, ConvertConv2D, ConvertEinsumOp,
ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp,
ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertTopKV2Op, ConvertUnpackOp, ConvertMeanOp,
ConvertSumOp, ConvertMaxOp, ConvertAllOp, ConvertAnyOp, ConvertTileOp,
ConvertMaxPoolGradOp, ConvertOneHotOp, ConvertConv2DBackpropInputOp,
ConvertConv2DBackpropFilterOp>(op->getContext());

ConversionTarget target(*context);
target.addLegalDialect<XlaHloDialect>();

@@ -95,8 +95,7 @@ void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc,
detupled_args.push_back(extract);
}

llvm::SmallVector<Value*, 4> result(
builder.create<CallOp>(loc, func, detupled_args).getResults());
auto result = builder.create<CallOp>(loc, func, detupled_args).getResults();
if (!tuple_return) {
builder.create<xla_hlo::ReturnOp>(loc, result);
} else {

@@ -29,6 +29,9 @@ def FeatureDimension : NativeCodeCall<
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;

def CastValueToI64: NativeCodeCall<
"CastValueToI64($0->getLoc(), $1, &$_builder)">;

def : Pattern<
(TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon,
$data_format, FalseBoolAttr:$is_training),

@@ -43,13 +46,22 @@ def : Pattern<
[(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
(HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>;

//===----------------------------------------------------------------------===//
// Assert op pattern.
//===----------------------------------------------------------------------===//

// HLO and XLA don't support assertions.
def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>;

//===----------------------------------------------------------------------===//
// Bias op patterns.
//===----------------------------------------------------------------------===//
def BiasAddFeatureDimension : NativeCodeCall<
"getBiasFeatureDimension($_builder, $0, $1)">;

def : Pat<(TF_BiasAddOp AnyStaticShapeTensor:$input, $bias, $data_format),
// $input needs to be a ranked tensor to identify index of the feature
// dimension depending on the data_format 'NHWC' or 'NCHW'.
def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format),
(HLO_AddOp $input, $bias,
(BiasAddFeatureDimension $data_format, $input))>;

@@ -298,7 +310,7 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
//===----------------------------------------------------------------------===//

def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value),
[(AnyStaticShapeTensor $res), (HLO_Tensor $res)]>;
[(HLO_Tensor $res)]>;

//===----------------------------------------------------------------------===//
// Relu op patterns.

@@ -316,11 +328,21 @@ def : Pat<(TF_Relu6Op AnyStaticShapeTensor:$input),
(HLO_ConstOp (ConstantSplat<"6"> $input)))>;

// ReluGrad(gradients, features) = gradients * (features > 0)
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyStaticShapeTensor:$features),
//
// $gradients needs to be of static shape so that on_true and on_false operands
// of SelectOp have the same shape.
//
// $features needs to be ranked for computation of the broadcast dimensions for
// CompareOp.
//
// TODO(hinsu): Relax $gradients static shape requirement when there is a way
// to create splat tensor of dynamic shape in HLO.
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features),
(HLO_SelectOp
(HLO_CompareOp $features, (HLO_ConstOp:$zero (ConstantSplat<"0"> $features)),
(HLO_CompareOp $features,
(HLO_ConstOp (GetScalarOfType<0> $features)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT),
$gradients, $zero)>;
$gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>;
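Written out, the new pattern implements:

\[ \mathrm{ReluGrad}(g, f) = g \odot \mathbf{1}[f > 0] = \mathrm{select}(f > 0,\ g,\ 0), \]

which is why $gradients must be static-shaped (for the zero splat) while $features only needs a rank (for the comparison's broadcast dimensions).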

//===----------------------------------------------------------------------===//
// Slice op patterns.

@@ -333,9 +355,9 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall<
"TFSliceSizes2HLOSliceSizes($0, $1, $2.cast<DenseIntElementsAttr>(),"
"&$_builder)">;

def : Pat<(TF_SliceOp HLO_Tensor:$input, HLO_Tensor:$starting_indices,
(TF_ConstOp I64ElementsAttr:$slice_sizes)),
(HLO_DynamicSliceOp $input, $starting_indices,
def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices,
(TF_ConstOp $slice_sizes)),
(HLO_DynamicSliceOp $input, (CastValueToI64 $op, $starting_indices),
(TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)),
[(CanBeTranslatedToDynamicSlice $input, $starting_indices,
$slice_sizes)]>;

@@ -383,19 +405,21 @@ foreach Mapping = [
def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse),
(HLO_ConvertOp $arg)>;

def : Pat<(TF_TransposeOp:$res $arg, (TF_ConstOp I64ElementsAttr:$permutation)),
(HLO_TransposeOp $arg, (CastIntElementsAttr $permutation))>;
def : Pat<(TF_TransposeOp:$res $arg, (TF_ConstOp $permutation)),
(HLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>;

// Result of the following ops changing tensor shape needs to have static
// shape as HLO doesn't yet support dynamic reshaping ops.
//
// TODO(hinsu): Update once HLO supports dynamic reshaping ops.
foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in {
def : Pat<(TfOp:$res AnyStaticShapeTensor:$arg, $ignored),
def : Pat<(TfOp:$res $arg, $ignored),
(HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;
}

//===----------------------------------------------------------------------===//
// RngUniform.
//===----------------------------------------------------------------------===//
def CastElementsToI64: NativeCodeCall<
"CastElementsToI64($0->getLoc(), $1, &$_builder)">;

// TODO(misard,phawkins): handle random number generator seeds/states correctly.
def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2),

@@ -404,5 +428,5 @@ def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2),
(NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)),
(HLO_ConstOp
(NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)),
(CastElementsToI64 $old, $shape)),
(CastValueToI64 $old, $shape)),
[(IsShapedTensor $shape)]>;

@@ -54,7 +54,8 @@ struct LhloFuseLinalg : public FunctionPass<LhloFuseLinalg> {
auto op = cast<LinalgOp>(generic_op.getOperation());
for (const Value* result : op.getOutputs()) {
if (!func_args.count(result)) continue;
if (linalg::tileLinalgOp(b, op, tile_sizes, &folder)) {
if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{},
&folder)) {
generic_op.erase();
return;
}

@@ -112,7 +112,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
rewriter.setInsertionPointToEnd(block);
Operation* op = MapLhloOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(lhlo_op), bodyResultTypes, bodyArgs, rewriter);
rewriter.create<linalg::YieldOp>(loc, llvm::to_vector<1>(op->getResults()));
rewriter.create<linalg::YieldOp>(loc, op->getResults());
rewriter.eraseOp(lhlo_op);
return ConversionPattern::matchSuccess();
}

@@ -653,7 +653,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
np_result = np.true_divide(nums, divs)
np_result[:, divs[0] == 0] = 0
self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result)
self._testBinary(
gen_math_ops.div_no_nan,
nums,
divs,
expected=np_result,
rtol=7e-15 if dtype == np.float64 else None,
atol=3.9e-15 if dtype == np.float64 else None)

if dtype not in self.complex_types:  # floordiv unsupported for complex.
self._testBinary(

@@ -164,7 +164,8 @@ class TensorArrayTest(xla_test.XLATestCase):
dtype=tf_dtype, tensor_array_name="foo", size=3)

# Unpack a matrix into vectors.
w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
w1 = ta.unstack(
convert([[1.0, 1.03125], [2.0, 2.03125], [3.0, 3.03125]]))
r0 = w1.read(0)
r1 = w1.read(1)
r2 = w1.read(2)

@@ -172,9 +173,9 @@ class TensorArrayTest(xla_test.XLATestCase):

d0, d1, d2 = self.evaluate(xla.compile(fn))

self.assertAllEqual(convert([1.0, 1.1]), d0)
self.assertAllEqual(convert([2.0, 2.1]), d1)
self.assertAllEqual(convert([3.0, 3.1]), d2)
self.assertAllEqual(convert([1.0, 1.03125]), d0)
self.assertAllEqual(convert([2.0, 2.03125]), d1)
self.assertAllEqual(convert([3.0, 3.03125]), d2)

def fn():
# Reset ta because we're going to change the shape, else shape

@@ -307,7 +307,7 @@ void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
}
}
}
LOG(FATAL) << "Node " << (**node).name() << " not found in any engine.";
LOG(FATAL) << "Node " << node_name << " not found in any engine.";
}

// Function to insert a TRT engine node into the graph.

@@ -654,9 +654,8 @@ class ConverterTest : public ::testing::Test {
ConverterTest() { Reset(); }

void Reset() {
builder_.reset(nvinfer1::createInferBuilder(logger_));
converter_ =
std::move(Converter::Create(builder_.get(), TrtPrecisionMode::FP32,
std::move(Converter::Create(TrtPrecisionMode::FP32,
/*use_calibration=*/false, &logger_)
.ValueOrDie());
weight_store_ = &converter_->weight_store_;

@@ -702,9 +701,6 @@ class ConverterTest : public ::testing::Test {

private:
Logger logger_;
// These members are ordered in a way such that the destruction order is:
// converter_ -> builder_
TrtUniquePtrType<nvinfer1::IBuilder> builder_;

protected:
std::unique_ptr<Converter> converter_;

@@ -996,9 +992,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
FakeITensor input, infer_1, infer_2, infer_3;
FakeITensor not_infer;
Logger logger;
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(logger));
auto int8_converter = Converter::Create(builder.get(), TrtPrecisionMode::INT8,
auto int8_converter = Converter::Create(TrtPrecisionMode::INT8,
/*use_calibration=*/true, &logger)
.ValueOrDie();
int8_converter->ProvideQuantizationRange(&input, -5.0f, 5.0f);

@@ -1255,12 +1249,8 @@ class OpConverterTest : public ::testing::Test {
engine_.reset(nullptr);

// Re-create them in proper order.
builder_.reset(nvinfer1::createInferBuilder(logger_));
builder_->setMaxWorkspaceSize(1 << 26);

// Reset the converter.
converter_ =
std::move(Converter::Create(builder_.get(), precision_mode_to_test_,
std::move(Converter::Create(precision_mode_to_test_,
/*use_calibration=*/false, &logger_)
.ValueOrDie());

@@ -1294,18 +1284,13 @@ class OpConverterTest : public ::testing::Test {
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));

// Build the TRT engine.
if (precision_mode == TrtPrecisionMode::FP16) {
builder_->setFp16Mode(true);
} else if (precision_mode == TrtPrecisionMode::INT8) {
// Setting FP16 mode as well allows TRT to also consider FP16 kernels and
// use them in situations where they are faster than INT8 or where INT8 is
// not supported for a given layer.
builder_->setFp16Mode(true);
builder_->setInt8Mode(true);
}
ASSERT_EQ(nullptr, engine_.get());
builder_->setMaxBatchSize(batch_size);
TF_ASSERT_OK(converter_->BuildCudaEngine(&engine_));
TF_ASSERT_OK(
converter_->BuildCudaEngine(&engine_,
/*max_batch_size=*/batch_size,
/*max_workspace_size_bytes=*/1 << 26,
/*allocator=*/nullptr,
/*calibrator=*/nullptr));
CHECK_NOTNULL(engine_.get());
CheckDataTypeMatches(input_data);
CheckDataTypeMatches(*output_data);
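Taken together, these hunks change the client-side flow: the Converter now owns the TensorRT builder, so callers stop creating an IBuilder and instead pass engine limits to BuildCudaEngine. A hedged sketch of the new sequence, assuming only the signatures visible in this diff (error handling abbreviated):

// Post-change flow: no nvinfer1::createInferBuilder() at the call site.
Logger logger;
auto converter = std::move(Converter::Create(TrtPrecisionMode::FP32,
                                             /*use_calibration=*/false, &logger)
                               .ValueOrDie());
// ... populate the network through the converter ...
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
TF_CHECK_OK(converter->BuildCudaEngine(&engine,
                                       /*max_batch_size=*/1,
                                       /*max_workspace_size_bytes=*/1 << 26,
                                       /*allocator=*/nullptr,
                                       /*calibrator=*/nullptr));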

@@ -1473,7 +1458,6 @@ class OpConverterTest : public ::testing::Test {

private:
Logger logger_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
cudaStream_t stream_;
// Used to create placeholders with shape and data type information. The

@@ -143,6 +143,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
REGISTER_XLA_OP(
Name("DataFormatVecPermute").TypeConstraint("T", {DT_INT32, DT_INT64}),
DataFormatVecPermuteOp);
REGISTER_XLA_OP(Name("DataFormatVecPermute")
.Label("host")
.TypeConstraint("T", {DT_INT32, DT_INT64}),
DataFormatVecPermuteOp);

} // namespace
} // namespace tensorflow

@@ -723,17 +723,6 @@ Status XlaCompiler::CompileFunction(

std::unique_ptr<Graph> graph = GetGraph(fbody);

// Clear the "_kernel" attribute if it is set to "host". This is used to
// indicate that a computation should happen on the host instead of the
// accelerator, but doesn't make sense in XLA.
const char* const kKernelAttr = "_kernel";
for (Node* n : graph->nodes()) {
string value;
if (TryGetNodeAttr(n->attrs(), kKernelAttr, &value) && value == "host") {
n->ClearAttr(kKernelAttr);
}
}

// _Arg and _Retval nodes don't exist in the stored subgraph for the function;
// they are added by the function body looked up. Therefore, they don't have
// core assignments here.

@@ -1059,7 +1048,12 @@ Status XlaCompiler::BuildArguments(
const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
VLOG(2) << " XLA arg " << i
<< " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
<< " name: " << arg.name << " TF arg " << input_to_args->at(i);
<< " name: " << arg.name << " TF arg " << input_to_args->at(i)
<< " node name: " << arg.node_name
<< (arg_shardings.find(i) == arg_shardings.end()
? ""
: absl::StrCat(" sharding: ",
arg_shardings.at(i).DebugString()));
XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
switch (arg.kind) {
case XlaCompiler::Argument::kResource: {

@@ -147,6 +147,9 @@ class XlaCompiler {
// The name of this argument, used for debugging.
string name;

// The name of the TensorFlow _Arg node, used for debugging.
string node_name;

// For a kResource, what kind of resource is it?
XlaResource::Kind resource_kind = XlaResource::kInvalid;

@@ -61,6 +61,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
/* static */ bool XlaOpRegistry::IsCompatible(const OpRegistration& x,
const OpRegistration& y) {
if (x.name != y.name) return true;
if (x.label != y.label) return true;
// The registrations refer to the same Op: ensure they are compatible and
// are restricted to different device whitelists.
if (x.compilation_only != y.compilation_only) {

@@ -256,6 +257,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
std::unique_ptr<KernelDef> kdef(new KernelDef);
kdef->set_op(op_registration->name);
kdef->set_device_type(backend.first);
kdef->set_label(op_registration->label);

// Constrain each type attribute to the intersection of:
// a) the types supported by the backend, and

@@ -539,6 +541,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
return *this;
}

XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Label(std::string label) {
registration_->label = label;
return *this;
}

std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;

@@ -270,6 +270,8 @@ class XlaOpRegistry {
// operands and not their values.
bool is_metadata_op = false;

std::string label;

// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};

@@ -350,6 +352,9 @@ class XlaOpRegistrationBuilder {
// operands and not their values.
XlaOpRegistrationBuilder& IsMetadataOp();

// Specifies a particular value for the "_kernel" attr.
XlaOpRegistrationBuilder& Label(std::string label);

std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);

@@ -319,6 +319,8 @@ XlaOp Erf(XlaOp x) {
});
}

namespace {

// Approximation for the inverse error function from
// Giles, M., "Approximating the erfinv function".
// The approximation has the form:

@@ -331,7 +333,7 @@ XlaOp Erf(XlaOp x) {
//   p = sum_{i=1}^n gq[i]*w^i
// }
// return p*x
XlaOp ErfInv(XlaOp x) {
XlaOp ErfInv32(XlaOp x) {
constexpr int kDegree = 9;
constexpr std::array<float, 9> w_less_than_5_constants = {
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,

@@ -371,6 +373,101 @@ XlaOp ErfInv(XlaOp x) {
});
}

XlaOp ErfInv64(XlaOp x) {
constexpr std::array<double, 23> w_less_than_6_25_constants = {
-3.6444120640178196996e-21, -1.685059138182016589e-19,
1.2858480715256400167e-18, 1.115787767802518096e-17,
-1.333171662854620906e-16, 2.0972767875968561637e-17,
6.6376381343583238325e-15, -4.0545662729752068639e-14,
-8.1519341976054721522e-14, 2.6335093153082322977e-12,
-1.2975133253453532498e-11, -5.4154120542946279317e-11,
1.051212273321532285e-09, -4.1126339803469836976e-09,
-2.9070369957882005086e-08, 4.2347877827932403518e-07,
-1.3654692000834678645e-06, -1.3882523362786468719e-05,
0.0001867342080340571352, -0.00074070253416626697512,
-0.0060336708714301490533, 0.24015818242558961693,
1.6536545626831027356};
constexpr std::array<double, 19> w_less_than_16_constants = {
2.2137376921775787049e-09, 9.0756561938885390979e-08,
-2.7517406297064545428e-07, 1.8239629214389227755e-08,
1.5027403968909827627e-06, -4.013867526981545969e-06,
2.9234449089955446044e-06, 1.2475304481671778723e-05,
-4.7318229009055733981e-05, 6.8284851459573175448e-05,
2.4031110387097893999e-05, -0.0003550375203628474796,
0.00095328937973738049703, -0.0016882755560235047313,
0.0024914420961078508066, -0.0037512085075692412107,
0.005370914553590063617, 1.0052589676941592334,
3.0838856104922207635,
};
constexpr std::array<double, 17> w_greater_than_16_constants = {
-2.7109920616438573243e-11, -2.5556418169965252055e-10,
1.5076572693500548083e-09, -3.7894654401267369937e-09,
7.6157012080783393804e-09, -1.4960026627149240478e-08,
2.9147953450901080826e-08, -6.7711997758452339498e-08,
2.2900482228026654717e-07, -9.9298272942317002539e-07,
4.5260625972231537039e-06, -1.9681778105531670567e-05,
7.5995277030017761139e-05, -0.00021503011930044477347,
-0.00013871931833623122026, 1.0103004648645343977,
4.8499064014085844221,
};
// Compute logarithm of (1+arg) using log1p(arg) which is more precise than
// log(1+arg) when arg is close to zero. For more details, see
// https://en.cppreference.com/w/cpp/numeric/math/log1p
auto w = -Log1p(-x * x);

auto lt_6_25 = Lt(w, ScalarLike(x, 6.25));
auto lt_16 = Lt(w, ScalarLike(x, 16));
auto coefficient = [&](int i) {
auto c = FullLike(x, w_less_than_6_25_constants[i]);
if (i < 19) {
c = Select(lt_6_25, c, FullLike(x, w_less_than_16_constants[i]));
}
if (i < 17) {
c = Select(lt_16, c, FullLike(x, w_greater_than_16_constants[i]));
}
return c;
};
auto sqrt_w = Sqrt(w);
w = Select(lt_6_25, w - ScalarLike(x, 3.125),
sqrt_w - Select(lt_16, ScalarLike(x, 3.25), ScalarLike(x, 5.0)));
auto p = coefficient(0);
for (int i = 1; i < 17; ++i) {
p = coefficient(i) + p * w;
}
for (int i = 17; i < 19; ++i) {
p = Select(lt_16, coefficient(i) + p * w, p);
}
for (int i = 19; i < 23; ++i) {
p = Select(lt_6_25, coefficient(i) + p * w, p);
}
// Result modulo edge cases.
XlaOp result = p * x;

// Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
// indeterminate, and can give nan or -/+inf.)
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
return Select(Eq(Abs(x), ScalarLike(x, 1)),
x * MaxValue(&b, shape.element_type()), result);
});
}

} // namespace

XlaOp ErfInv(XlaOp x) {
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("ErfInv", x));
TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
if (shape.element_type() == F64) {
return ErfInv64(x);
}
return DoWithUpcastToF32(x, {BF16, F16},
[](XlaOp x) { return ErfInv32(x); });
});
}

namespace {
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/math.h"

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"

@@ -116,6 +117,10 @@ class MathTypedTest : public MathTest {
//
// For good measure, we also check pow with an exponent other than 0.5.
void TestSqrtPowInequivalence() {
// TODO(b/145798892): test fails on GPU for double values.
if (std::is_same<T, double>::value) {
return;
}
SetFastMathDisabled(true);

// Tests disable constant folding by default, but this test needs it

@@ -151,11 +156,16 @@ class MathTypedTest : public MathTest {
};

// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
using TestTypes = ::testing::Types<float>;
#else
using TestTypes = ::testing::Types<float, Eigen::half>;
using TestTypes = ::testing::Types<float
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
,
Eigen::half
#endif
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64
,
double
#endif
>;

TYPED_TEST_CASE(MathTypedTest, TestTypes);

@@ -224,6 +234,28 @@ XLA_TEST_F(MathTest, SqrtF32) {
ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
}

#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64
XLA_TEST_F(MathTest, ErfInvF64) {
XlaBuilder builder(TestName());
auto x = ConstantR1<double>(
&builder, {-0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1,
0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9});
ErfInv(x);

std::vector<double> expected = {-1.163087153676674, -0.9061938024368231,
-0.732869077959217, -0.5951160814499948,
-0.4769362762044698, -0.37080715859355795,
-0.27246271472675443, -0.1791434546212916,
-0.08885599049425767, 0.,
0.08885599049425777, 0.1791434546212916,
0.27246271472675443, 0.37080715859355784,
0.4769362762044698, 0.5951160814499948,
0.732869077959217, 0.9061938024368231,
1.1630871536766736};
ComputeAndCompareR1<double>(&builder, expected, {}, ErrorSpec{1e-15});
}
#endif

XLA_TEST_F(MathTest, SquareTenValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(

@@ -2112,7 +2112,8 @@ XlaOp XlaBuilder::CrossReplicaSum(

XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

@@ -2136,9 +2137,31 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
operand_shapes.push_back(operand_shape);
operands.push_back(operand);
}
TF_ASSIGN_OR_RETURN(Shape shape,

TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferAllReduceShape(operand_shapes));
*instr.mutable_shape() = shape.ToProto();
if (shape_with_layout) {
if (!LayoutUtil::HasLayout(*shape_with_layout)) {
return InvalidArgument("shape_with_layout must have the layout set: %s",
shape_with_layout->ToString());
}
if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
return InvalidArgument(
"Provided shape_with_layout must be compatible with the "
"operand shape: %s vs %s",
shape_with_layout->ToString(), operand_shape->ToString());
}
instr.set_constrain_layout(true);
if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
// For a single-element tuple, take the tuple element shape.
TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
*instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
} else {
*instr.mutable_shape() = shape_with_layout->ToProto();
}
} else {
*instr.mutable_shape() = inferred_shape.ToProto();
}

for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;

@@ -2153,10 +2176,10 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
TF_ASSIGN_OR_RETURN(
auto all_reduce,
AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
if (operand_shape->IsTuple() && !shape.IsTuple()) {
if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
// For a single-element tuple, wrap the result into a tuple.
TF_RET_CHECK(operand_shapes.size() == 1);
TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], shape));
TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
return Tuple({all_reduce});
}
return all_reduce;

@@ -3282,9 +3305,10 @@ XlaOp CrossReplicaSum(const XlaOp operand,

XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout) {
return operand.builder()->AllReduce(operand, computation, replica_groups,
channel_id);
channel_id, shape_with_layout);
}

XlaOp AllToAll(const XlaOp operand, int64 split_dimension,

@@ -514,7 +514,8 @@ class XlaBuilder {
XlaOp AllReduce(
XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
const absl::optional<Shape>& shape_with_layout = absl::nullopt);

XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
int64 split_count,

@@ -922,7 +923,8 @@ class XlaBuilder {
absl::Span<const ReplicaGroup> replica_groups);
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id);
const absl::optional<ChannelHandle>& channel_id,
const absl::optional<Shape>& shape_with_layout);
friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);

@@ -1666,10 +1668,14 @@ XlaOp CrossReplicaSum(XlaOp operand,
// - `channel_id`: for Allreduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied cross modules.
XlaOp AllReduce(
XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
//
// - `shape_with_layout`: forces the layout of the AllReduce to the given
// layout. This is used to guarantee the same layout for a group of AllReduce
// ops compiled separately.
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
const absl::optional<Shape>& shape_with_layout = absl::nullopt);
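A hedged usage sketch of the extended API (the computation `add`, the operand, and the shape are placeholders; ShapeUtil::MakeShapeWithLayout is assumed from shape_util.h):

// Pin the result layout so separately compiled modules agree on it.
Shape s = ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{128},
                                         /*minor_to_major=*/{0});
XlaOp sum = AllReduce(operand, add, /*replica_groups=*/{},
                      /*channel_id=*/absl::nullopt,
                      /*shape_with_layout=*/s);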

// Enqueues an operation that does an AllToAll of the operand across cores.
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,

@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

@@ -131,18 +132,23 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
}
} else if (shape.IsArray()) {
if (allocate_arrays) {
// Literals can be used as DMA targets, which can require alignment. We
// force a 16-byte minimum alignment.
constexpr int kMinimumAlignment = 16;
if (LayoutUtil::IsSparseArray(shape)) {
// For sparse arrays, the buffer must be of the size of the maximum
// number of sparse elements possible.
const int64 max_sparse_elements =
LayoutUtil::MaxSparseElements(shape.layout());
piece->set_buffer(
new char[max_sparse_elements *
ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
max_sparse_elements *
ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()),
kMinimumAlignment)));
piece->set_sparse_indices(
new SparseIndexArray(max_sparse_elements, shape.rank()));
} else {
piece->set_buffer(new char[piece->size_bytes()]);
piece->set_buffer(static_cast<char*>(tensorflow::port::AlignedMalloc(
piece->size_bytes(), kMinimumAlignment)));
}
}
} else {

@@ -174,7 +180,7 @@ void Literal::DeallocateBuffers() {
root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (piece->buffer() != nullptr) {
delete[] piece->buffer();
tensorflow::port::AlignedFree(piece->buffer());
delete piece->sparse_indices();
}
});

@@ -504,7 +510,7 @@ Status Literal::MoveFrom(Literal&& src_literal,
dest_index.push_back(i);
}
Piece& dest_piece = piece(dest_index);
delete[] dest_piece.buffer();
tensorflow::port::AlignedFree(dest_piece.buffer());
dest_piece.set_buffer(src_piece.buffer());
delete dest_piece.sparse_indices();
dest_piece.set_sparse_indices(src_piece.sparse_indices());
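The DeallocateBuffers and MoveFrom hunks exist because allocation and deallocation must switch in lockstep: memory from AlignedMalloc cannot be released with delete[]. A minimal sketch of the pairing (the size is illustrative; the functions come from tensorflow/core/platform/mem.h, included above):

#include "tensorflow/core/platform/mem.h"

// A 16-byte-aligned buffer, suitable for use as a DMA target.
constexpr int kMinimumAlignment = 16;
char* buffer = static_cast<char*>(
    tensorflow::port::AlignedMalloc(/*size=*/1024, kMinimumAlignment));
// ... fill and use the buffer ...
tensorflow::port::AlignedFree(buffer);  // never delete[] this pointer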

@@ -26,7 +26,6 @@ py_test(
name = "xla_client_test",
srcs = ["xla_client_test.py"],
main = "xla_client_test.py",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = ["no_oss"],  # TODO(phawkins): This test passes, but requires --config=monolithic.
deps = [

@@ -31,6 +31,11 @@ tf_proto_library_cc(
use_grpc_namespace = True,
)

cc_library(
name = "c_api",
hdrs = ["c_api.h"],
)

cc_library(
name = "tpu_driver",
srcs = [

tensorflow/compiler/xla/python/tpu_driver/c_api.h (new file, 30 lines)
@@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_

#define TPUDRIVER_CAPI_EXPORT __attribute__((visibility("default")))

extern "C" {

TPUDRIVER_CAPI_EXPORT extern void TpuDriver_Initialize();

TPUDRIVER_CAPI_EXPORT extern void TpuDriver_Open(const char* worker);

TPUDRIVER_CAPI_EXPORT extern const char* TpuDriver_Version(void);
}

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_C_API_H_
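Because the header exposes a plain C ABI, a consumer would typically load it dynamically; a hypothetical example (the shared-object name and flow are assumptions, not part of this change):

#include <dlfcn.h>
#include <cstdio>

int main() {
  void* lib = dlopen("libtpu_driver.so", RTLD_NOW);  // hypothetical .so name
  if (!lib) return 1;
  auto tpu_version = reinterpret_cast<const char* (*)(void)>(
      dlsym(lib, "TpuDriver_Version"));
  auto tpu_initialize = reinterpret_cast<void (*)(void)>(
      dlsym(lib, "TpuDriver_Initialize"));
  if (tpu_version && tpu_initialize) {
    std::printf("TPU driver version: %s\n", tpu_version());
    tpu_initialize();
  }
  dlclose(lib);
  return 0;
}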