Merge branch 'master' into sign-compare-warning-fixes-batch-1-fix2
This commit is contained in:
commit
41707f74dc
84
.bazelrc
84
.bazelrc
@ -39,6 +39,7 @@
|
||||
#
|
||||
# Feature and Third party library support options:
|
||||
# xla: Build TF with XLA
|
||||
# tpu: Build TF with TPU support
|
||||
# using_cuda: CUDA is available to build system.
|
||||
# cuda: Build with full cuda support.
|
||||
# rocm: Build with AMD GPU support (rocm).
|
||||
@ -57,13 +58,12 @@
|
||||
#
|
||||
#
|
||||
# Remote build execution options (only configured to work with TF team projects for now).
|
||||
# rbe: General RBE options shared by all flavors.
|
||||
# rbe_linux: General RBE options used on all linux builds.
|
||||
# rbe_win: General RBE options used on all windows builds.
|
||||
# rbe: General RBE options shared by all flavors.
|
||||
# rbe_linux: General RBE options used on all linux builds.
|
||||
# rbe_win: General RBE options used on all windows builds.
|
||||
#
|
||||
# rbe_cpu_linux: RBE options to build with only CPU support.
|
||||
# rbe_linux_cuda_nvcc: RBE options to build with GPU support using nvcc.
|
||||
# rbe_gpu_linux: An alias for rbe_linux_cuda_nvcc
|
||||
# rbe_cpu_linux: RBE options to build with only CPU support.
|
||||
# rbe_linux_cuda_nvcc_py*: RBE options to build with GPU support using nvcc.
|
||||
#
|
||||
# rbe_linux_py2: Linux Python 2 RBE config.
|
||||
# rbe_linux_py3: Linux Python 3 RBE config
|
||||
@ -180,6 +180,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
|
||||
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
||||
build:dbg --copt -DDEBUG_BUILD
|
||||
|
||||
# Config to build TPU backend
|
||||
build:tpu --define=with_tpu_support=true
|
||||
|
||||
build:tensorrt --action_env TF_NEED_TENSORRT=1
|
||||
|
||||
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
||||
@ -396,33 +399,48 @@ build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
|
||||
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
|
||||
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
|
||||
|
||||
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
|
||||
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
|
||||
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
|
||||
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
|
||||
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl"
|
||||
build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7"
|
||||
build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5"
|
||||
build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6"
|
||||
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
|
||||
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
|
||||
|
||||
# Map default to CUDA 10.1.
|
||||
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
|
||||
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5
|
||||
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8
|
||||
|
||||
# Deprecated configs that people might still use.
|
||||
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
|
||||
build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
||||
|
||||
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
@ -440,8 +458,6 @@ build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF
|
||||
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
||||
|
||||
build:rbe_linux_py2 --config=rbe_linux
|
||||
build:rbe_linux_py2 --repo_env=PYTHON_BIN_PATH="/usr/bin/python2"
|
||||
build:rbe_linux_py2 --python_path="/usr/bin/python2"
|
||||
|
@ -95,6 +95,7 @@ for general questions and discussion, and please direct specific questions to
|
||||
The TensorFlow project strives to abide by generally accepted best practices in
|
||||
open-source software development:
|
||||
|
||||
[Fuzzing Status](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow)
|
||||
[CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486)
|
||||
[Code of Conduct](CODE_OF_CONDUCT.md)
|
||||
|
||||
|
54
RELEASE.md
54
RELEASE.md
@ -1,3 +1,57 @@
|
||||
# Release 2.4.0
|
||||
|
||||
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
* <DOCUMENT BREAKING CHANGES HERE>
|
||||
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
|
||||
|
||||
## Known Caveats
|
||||
|
||||
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES). E.G. ADDING A NEW DEPENDENCY, BUMPING A DEPENDENCY NUMBER, LACK OF SUPPORT ON SOME PLATFORM, ETC>
|
||||
|
||||
## Major Features and Improvements
|
||||
|
||||
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
|
||||
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
||||
* <NOTES SHOULD BE GROUPED PER AREA>
|
||||
* TF Core:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.data`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.distribute`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.keras`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.function`/AutoGraph:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.lite`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.random`:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* Math and Linear Algebra:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* TPU Enhancements:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* XLA Support:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* Tracing and Debugging:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* Other:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
||||
<INSERT>, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
||||
|
||||
# Release 2.3.0
|
||||
|
||||
## Breaking Changes
|
||||
|
@ -467,6 +467,13 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag enables experimental TPU support
|
||||
config_setting(
|
||||
name = "with_tpu_support",
|
||||
values = {"define": "with_tpu_support=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Specifies via a config setting if this is a mobile build or not, makes
|
||||
# it easier to combine settings later.
|
||||
selects.config_setting_group(
|
||||
|
@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
|
||||
tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
|
||||
node_def.set_name(op->Name());
|
||||
node_def.set_op(op->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
|
@ -54,7 +54,7 @@ Status ProcessInputs(
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
input_tensors->reserve(ninputs);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
Node* node = &inputs[i].oper->node;
|
||||
Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
|
||||
int idx = inputs[i].index;
|
||||
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
@ -90,7 +90,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
|
||||
output_tensors->reserve(noutputs);
|
||||
for (int i = 0; i < noutputs; ++i) {
|
||||
Node* node = &outputs[i].oper->node;
|
||||
Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
|
||||
int idx = outputs[i].index;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
fn_body->graph.IsValidOutputTensor(node, idx),
|
||||
|
@ -38,9 +38,10 @@ tf_cuda_library(
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":context_interface",
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":immediate_execution_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
":abstract_tensor_handle",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_executor_internal",
|
||||
@ -101,13 +102,17 @@ tf_cuda_library(
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"abstract_context.h",
|
||||
"abstract_function.h",
|
||||
"abstract_operation.h",
|
||||
"abstract_tensor_handle.h",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"context_interface.h",
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
"immediate_execution_context.h",
|
||||
"immediate_execution_operation.h",
|
||||
"immediate_execution_tensor_handle.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
@ -163,12 +168,22 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_handle_interface",
|
||||
hdrs = ["tensor_handle_interface.h"],
|
||||
name = "abstract_tensor_handle",
|
||||
hdrs = ["abstract_tensor_handle.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_tensor_handle",
|
||||
hdrs = ["immediate_execution_tensor_handle.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -177,13 +192,13 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "operation_interface",
|
||||
hdrs = ["operation_interface.h"],
|
||||
name = "abstract_operation",
|
||||
hdrs = ["abstract_operation.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -193,14 +208,58 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "context_interface",
|
||||
hdrs = ["context_interface.h"],
|
||||
name = "immediate_execution_operation",
|
||||
hdrs = ["immediate_execution_operation.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":abstract_operation",
|
||||
":abstract_tensor_handle",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "abstract_context",
|
||||
hdrs = ["abstract_context.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_function",
|
||||
":abstract_operation",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "abstract_function",
|
||||
hdrs = ["abstract_function.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_context",
|
||||
hdrs = ["immediate_execution_context.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -217,7 +276,7 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":context_interface",
|
||||
":immediate_execution_context",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
@ -277,7 +336,7 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
":immediate_execution_operation",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
@ -300,7 +359,7 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
@ -480,6 +539,9 @@ tf_cuda_library(
|
||||
":tfe_context_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
":abstract_operation",
|
||||
":abstract_context",
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
83
tensorflow/c/eager/abstract_context.h
Normal file
83
tensorflow/c/eager/abstract_context.h
Normal file
@ -0,0 +1,83 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_function.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a context.
|
||||
//
|
||||
// This serves as a factory for creating `AbstractOperation`s and for
|
||||
// registering traced functions.
|
||||
// Operations created within a context can only be executed in that context
|
||||
// (for now at least).
|
||||
// Implementations of the context may contain some state e.g. an execution
|
||||
// environment, a traced representation etc.
|
||||
class AbstractContext {
|
||||
protected:
|
||||
enum AbstractContextKind { kTracing, kImmediateExecution };
|
||||
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractContext() {}
|
||||
|
||||
public:
|
||||
AbstractContextKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage its own
|
||||
// lifetime through ref counting. Thus clients MUST call Release() in order to
|
||||
// destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
// Creates an operation builder and ties it to this context.
|
||||
// The returned object can be used for setting operation's attributes,
|
||||
// adding inputs and finally executing (immediately or lazily as in tracing)
|
||||
// it in this context.
|
||||
virtual AbstractOperation* CreateOperation() = 0;
|
||||
|
||||
// Registers a function with this context, after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual Status RegisterFunction(AbstractFunction*) = 0;
|
||||
// Remove a function. 'func' argument is the name of a previously added
|
||||
// FunctionDef. The name is in fdef.signature.name.
|
||||
virtual Status RemoveFunction(const string& func) = 0;
|
||||
|
||||
private:
|
||||
const AbstractContextKind kind_;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct AbstractContextDeleter {
|
||||
void operator()(AbstractContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using AbstractContextPtr =
|
||||
std::unique_ptr<AbstractContext, internal::AbstractContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
46
tensorflow/c/eager/abstract_function.h
Normal file
46
tensorflow/c/eager/abstract_function.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A traced function: this hides the complexity of converting the serialized
|
||||
// representation between various supported formats e.g. FunctionDef and MLIR
|
||||
// function.
|
||||
class AbstractFunction {
|
||||
protected:
|
||||
enum AbstractFunctionKind { kGraphFunc, kMlirFunc };
|
||||
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass this instance belongs to.
|
||||
AbstractFunctionKind getKind() const { return kind_; }
|
||||
virtual ~AbstractFunction() = default;
|
||||
|
||||
// Returns the AbstractFunction as a FunctionDef.
|
||||
virtual Status GetFunctionDef(FunctionDef**) = 0;
|
||||
|
||||
private:
|
||||
const AbstractFunctionKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
@ -12,24 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class AbstractOperationInterface {
|
||||
// This interface allows building and executing an operation in either
|
||||
// tracing or immediate execution mode.
|
||||
class AbstractOperation {
|
||||
protected:
|
||||
enum AbstractOperationKind { kTracing, kImmediateExecution };
|
||||
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractOperation() {}
|
||||
|
||||
public:
|
||||
AbstractOperationKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
@ -38,7 +45,6 @@ class AbstractOperationInterface {
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
virtual void Clear() = 0;
|
||||
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
||||
|
||||
virtual const string& Name() const = 0;
|
||||
@ -66,12 +72,10 @@ class AbstractOperationInterface {
|
||||
// existing and given constraints will be performed.
|
||||
virtual Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
|
||||
virtual Status AddInputList(
|
||||
absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
|
||||
virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
|
||||
virtual Status AddInput(AbstractTensorHandle* input) = 0;
|
||||
virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
|
||||
virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) = 0;
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) = 0;
|
||||
@ -82,7 +86,7 @@ class AbstractOperationInterface {
|
||||
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) = 0;
|
||||
virtual Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperationInterface* value) = 0;
|
||||
const AbstractOperation* value) = 0;
|
||||
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) = 0;
|
||||
virtual Status SetAttrTensor(const char* attr_name,
|
||||
@ -102,19 +106,25 @@ class AbstractOperationInterface {
|
||||
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) = 0;
|
||||
virtual Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperationInterface*> values) = 0;
|
||||
const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
|
||||
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractOperationInterface() {}
|
||||
private:
|
||||
const AbstractOperationKind kind_;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct AbstractOperationDeleter {
|
||||
void operator()(AbstractOperation* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using AbstractOpPtr =
|
||||
std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
61
tensorflow/c/eager/abstract_tensor_handle.h
Normal file
61
tensorflow/c/eager/abstract_tensor_handle.h
Normal file
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle {
 protected:
  // Tag recording which execution mode a concrete handle belongs to.
  enum AbstractTensorHandleKind { kTracing, kImmediateExecution };

  // Only subclasses construct handles; each one passes its own kind.
  explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}

  // Protected so that clients cannot destroy a handle directly; they go
  // through Release() instead.
  virtual ~AbstractTensorHandle() = default;

 public:
  // Returns the kind supplied at construction.
  AbstractTensorHandleKind getKind() const { return kind_; }

  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage its own
  // lifetime through ref counting. Thus this must be allocated on the heap and
  // clients MUST call Release() in order to destroy an instance of this class.
  virtual void Release() = 0;

 private:
  const AbstractTensorHandleKind kind_;
};
|
||||
|
||||
namespace internal {
|
||||
struct AbstractTensorHandleDeleter {
|
||||
void operator()(AbstractTensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using AbstractTensorHandlePtr =
|
||||
std::unique_ptr<AbstractTensorHandle,
|
||||
internal::AbstractTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
@ -21,6 +21,8 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
|
||||
// clang-format off
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
@ -31,8 +33,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
tensorflow::AbstractOperationInterface* new_op =
|
||||
tensorflow::ImmediateExecutionOperation* new_op =
|
||||
tensorflow::unwrap(ctx)->CreateOperation();
|
||||
status->status = new_op->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
status->status = tensorflow::unwrap(op)->AddInputList(
|
||||
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
|
||||
{reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
||||
tensorflow::unwrap(inputs)),
|
||||
static_cast<size_t>(num_inputs)});
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
|
||||
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op** value, int num_values) {
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
|
||||
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
|
||||
attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
|
||||
tensorflow::unwrap(value)),
|
||||
static_cast<size_t>(num_values)});
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
status->status = tensorflow::unwrap(op)->Execute(
|
||||
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
|
||||
absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
||||
tensorflow::unwrap(retvals)),
|
||||
*num_retvals),
|
||||
num_retvals);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
|
@ -38,7 +38,7 @@ using tensorflow::string;
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
tensorflow::AbstractOperationInterface* op =
|
||||
tensorflow::ImmediateExecutionOperation* op =
|
||||
tensorflow::unwrap(op_to_reset);
|
||||
op->Clear();
|
||||
status->status = op->Reset(op_or_function_name, raw_device_name);
|
||||
@ -60,6 +60,12 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
// Returns the context id reported by the EagerContext behind `ctx`.
uint64_t TFE_GetContextId(TFE_Context* ctx) {
  // Unwrap the C handle first, then convert the interface to the concrete
  // eager context that exposes GetContextId().
  auto* unwrapped = tensorflow::unwrap(ctx);
  tensorflow::EagerContext* eager_context =
      tensorflow::ContextFromInterface(unwrapped);
  return eager_context->GetContextId();
}
|
||||
|
||||
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
||||
int64_t value) {
|
||||
cell->cell.IncrementBy(value);
|
||||
|
@ -300,6 +300,14 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
|
||||
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
|
||||
bool use_tfrt);
|
||||
|
||||
// Returns the context_id from the EagerContext which is used by the
|
||||
// EagerService to maintain consistency between client and worker. The
|
||||
// context_id is initialized with a dummy value and is later set when the worker
|
||||
// is initialized (either locally or remotely). The context_id can change during
|
||||
// the process lifetime although this should cause the worker to be
|
||||
// reinitialized (e.g. cleared caches) as well.
|
||||
TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Cancellation APIs.
|
||||
|
||||
|
@ -12,15 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
@ -34,16 +36,9 @@ namespace tensorflow {
|
||||
//
|
||||
// A context is responsible for creating key objects such as Tensors,
|
||||
// TensorHandles & Operations.
|
||||
class AbstractContextInterface {
|
||||
class ImmediateExecutionContext : public AbstractContext {
|
||||
public:
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage its own
|
||||
// lifetime through ref counting. Thus clients MUST call Release() in order to
|
||||
// destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
static constexpr AbstractContextKind kKind = kImmediateExecution;
|
||||
// Optimized scalar creation functions
|
||||
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
|
||||
@ -74,15 +69,15 @@ class AbstractContextInterface {
|
||||
void* memory_releaser_arg) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
// Copy the handle to another device.
|
||||
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
|
||||
AbstractTensorHandleInterface* handle, const char* device_name,
|
||||
virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
|
||||
ImmediateExecutionTensorHandle* handle, const char* device_name,
|
||||
Status* status) = 0;
|
||||
|
||||
// Create an operation to perform op execution
|
||||
virtual AbstractOperationInterface* CreateOperation() = 0;
|
||||
ImmediateExecutionOperation* CreateOperation() override = 0;
|
||||
|
||||
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
|
||||
// Runtime. This is necessary to decouple runtime-dependent
|
||||
@ -107,14 +102,26 @@ class AbstractContextInterface {
|
||||
// be executed as an op. Return error if the function with the same name
|
||||
// already exists.
|
||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||
// Remove a function. 'func' argument is the name of a previously added
|
||||
// FunctionDef. The name is in fdef.signature.name.
|
||||
virtual Status RemoveFunction(const string& func) = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
ImmediateExecutionContext() : AbstractContext(kKind) {}
|
||||
~ImmediateExecutionContext() override {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct ImmediateExecutionContextDeleter {
|
||||
void operator()(ImmediateExecutionContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using ImmediateContextPtr =
|
||||
std::unique_ptr<ImmediateExecutionContext,
|
||||
internal::ImmediateExecutionContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
69
tensorflow/c/eager/immediate_execution_operation.h
Normal file
69
tensorflow/c/eager/immediate_execution_operation.h
Normal file
@ -0,0 +1,69 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class ImmediateExecutionOperation : public AbstractOperation {
|
||||
public:
|
||||
static constexpr AbstractOperationKind kKind = kImmediateExecution;
|
||||
virtual void Clear() = 0;
|
||||
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
protected:
|
||||
ImmediateExecutionOperation() : AbstractOperation(kKind) {}
|
||||
~ImmediateExecutionOperation() override {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct ImmediateExecutionOperationDeleter {
|
||||
void operator()(ImmediateExecutionOperation* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using ImmediateOpPtr =
|
||||
std::unique_ptr<ImmediateExecutionOperation,
|
||||
internal::ImmediateExecutionOperationDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
@ -30,15 +31,9 @@ namespace tensorflow {
|
||||
// files. The interface lists the common functionality that must be provided by
|
||||
// any concrete implementation. However, in cases where the true concrete class
|
||||
// is needed a static_cast can be applied.
|
||||
class AbstractTensorHandleInterface {
|
||||
class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
||||
public:
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage its own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
static constexpr AbstractTensorHandleKind kKind = kImmediateExecution;
|
||||
|
||||
// Returns tensor dtype.
|
||||
virtual tensorflow::DataType DataType() const = 0;
|
||||
@ -57,12 +52,27 @@ class AbstractTensorHandleInterface {
|
||||
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
|
||||
|
||||
// Return a copy of the handle.
|
||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||
virtual ImmediateExecutionTensorHandle* Copy() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractTensorHandleInterface() {}
|
||||
ImmediateExecutionTensorHandle() : AbstractTensorHandle(kKind) {}
|
||||
~ImmediateExecutionTensorHandle() override {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct ImmediateExecutionTensorHandleDeleter {
|
||||
void operator()(ImmediateExecutionTensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using ImmediateTensorHandlePtr =
|
||||
std::unique_ptr<ImmediateExecutionTensorHandle,
|
||||
internal::ImmediateExecutionTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
@ -262,14 +262,14 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
int64_t* device_id = new int64_t;
|
||||
int32_t* device_id = new int32_t;
|
||||
*device_id = device_index;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(
|
||||
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||
sizeof(int64_t),
|
||||
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||
sizeof(int32_t),
|
||||
[](void* data, size_t, void* arg) {
|
||||
delete reinterpret_cast<int64_t*>(data);
|
||||
delete reinterpret_cast<int32_t*>(data);
|
||||
},
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
@ -283,7 +283,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
|
||||
TFE_TensorHandle* device_handle;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||
|
@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
||||
ExpectScalarEq<int32_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int32_t>(components[1].get(), 1);
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
|
||||
// Wraps a pointer to a context implementation.
|
||||
//
|
||||
@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
|
||||
// Wraps a pointer to an operation implementation.
|
||||
//
|
||||
@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
||||
// Wraps a pointer to a tensor handle implementation.
|
||||
//
|
||||
@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle,
|
||||
TFE_TensorHandle);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*,
|
||||
TFE_TensorHandle*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
struct TF_StringStream {
|
||||
@ -146,6 +147,10 @@ TF_StringStream* TF_GetLocalTempDirectories() {
|
||||
return list;
|
||||
}
|
||||
|
||||
char* TF_GetTempFileName(const char* extension) {
|
||||
return strdup(::tensorflow::io::GetTempFilename(extension).c_str());
|
||||
}
|
||||
|
||||
// Returns the value of NowNanos() on the default tensorflow::Env.
TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) {
  auto* const default_env = ::tensorflow::Env::Default();
  return default_env->NowNanos();
}
|
||||
|
@ -152,6 +152,10 @@ TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename,
|
||||
// The caller is responsible for freeing the list (see TF_StringStreamDone).
|
||||
TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void);
|
||||
|
||||
// Creates a temporary file name with an extension.
|
||||
// The caller is responsible for freeing the returned pointer.
|
||||
TF_CAPI_EXPORT extern char* TF_GetTempFileName(const char* extension);
|
||||
|
||||
// Returns the number of nanoseconds since the Unix epoch.
|
||||
TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void);
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Experimental gcs filesystem plugin.
|
||||
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
|
||||
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -19,14 +19,46 @@ tf_cc_shared_object(
|
||||
cc_library(
|
||||
name = "gcs_filesystem_impl",
|
||||
srcs = ["gcs_filesystem.cc"],
|
||||
hdrs = ["gcs_filesystem.h"],
|
||||
copts = select({
|
||||
"//conditions:default": [],
|
||||
"//tensorflow:windows": get_win_copts(),
|
||||
}),
|
||||
deps = [
|
||||
":gcs_helper",
|
||||
"//tensorflow/c:env",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gcs_helper",
|
||||
srcs = ["gcs_helper.cc"],
|
||||
hdrs = ["gcs_helper.h"],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//tensorflow/c:env",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gcs_filesystem_test",
|
||||
srcs = [
|
||||
"gcs_filesystem.cc",
|
||||
"gcs_filesystem_test.cc",
|
||||
],
|
||||
local_defines = ["TF_GCS_FILESYSTEM_TEST"],
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
":gcs_filesystem_impl",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/platform:stacktrace_handler",
|
||||
"//tensorflow/core/platform:test",
|
||||
],
|
||||
)
|
||||
|
@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "google/cloud/storage/client.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/env.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// Implementation of a filesystem for GCS environments.
|
||||
@ -36,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus(
|
||||
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
|
||||
static void plugin_memory_free(void* ptr) { free(ptr); }
|
||||
|
||||
static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
|
||||
char** bucket, char** object, TF_Status* status) {
|
||||
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
|
||||
char** object, TF_Status* status) {
|
||||
size_t scheme_end = fname.find("://") + 2;
|
||||
if (fname.substr(0, scheme_end + 1) != "gs://") {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
@ -86,6 +88,20 @@ namespace tf_random_access_file {
|
||||
// SECTION 2. Implementation for `TF_WritableFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_writable_file {
|
||||
typedef struct GCSFile {
|
||||
const char* bucket;
|
||||
const char* object;
|
||||
gcs::Client* gcs_client; // not owned
|
||||
TempFile outfile;
|
||||
bool sync_need;
|
||||
} GCSFile;
|
||||
|
||||
// Destroys the GCSFile attached to `file`: frees the bucket and object name
// buffers, then deletes the GCSFile itself. The GCS client is not owned and is
// left untouched.
static void Cleanup(TF_WritableFile* file) {
  GCSFile* state = static_cast<GCSFile*>(file->plugin_file);
  // The names are stored as const char*; cast the constness away to free them.
  char* bucket_name = const_cast<char*>(state->bucket);
  char* object_name = const_cast<char*>(state->object);
  plugin_memory_free(bucket_name);
  plugin_memory_free(object_name);
  delete state;
}
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
@ -104,7 +120,7 @@ namespace tf_read_only_memory_region {
|
||||
namespace tf_gcs_filesystem {
|
||||
|
||||
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
|
||||
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||
google::cloud::StatusOr<gcs::Client> client =
|
||||
gcs::Client::CreateDefaultClient();
|
||||
if (!client) {
|
||||
@ -117,8 +133,54 @@ static void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
// Releases the plugin state attached to `filesystem` via plugin_memory_free.
void Cleanup(TF_Filesystem* filesystem) {
  auto* plugin_state = filesystem->plugin_filesystem;
  plugin_memory_free(plugin_state);
}
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status) {
|
||||
char* bucket;
|
||||
char* object;
|
||||
ParseGCSPath(path, false, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
|
||||
char* temp_file_name = TF_GetTempFileName("");
|
||||
file->plugin_file = new tf_writable_file::GCSFile(
|
||||
{bucket, object, gcs_client,
|
||||
TempFile(temp_file_name, std::ios::binary | std::ios::out), true});
|
||||
// We are responsible for freeing the pointer returned by TF_GetTempFileName
|
||||
free(temp_file_name);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status) {
|
||||
char* bucket;
|
||||
char* object;
|
||||
ParseGCSPath(path, false, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
|
||||
char* temp_file_name = TF_GetTempFileName("");
|
||||
|
||||
auto gcs_status = gcs_client->DownloadToFile(bucket, object, temp_file_name);
|
||||
TF_SetStatusFromGCSStatus(gcs_status, status);
|
||||
auto status_code = TF_GetCode(status);
|
||||
if (status_code != TF_OK && status_code != TF_NOT_FOUND) {
|
||||
return;
|
||||
}
|
||||
// If this file does not exist on server, we will need to sync it.
|
||||
bool sync_need = (status_code == TF_NOT_FOUND);
|
||||
file->plugin_file = new tf_writable_file::GCSFile(
|
||||
{bucket, object, gcs_client,
|
||||
TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need});
|
||||
free(temp_file_name);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
} // namespace tf_gcs_filesystem
|
||||
|
||||
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
@ -126,9 +188,17 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
|
||||
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
|
||||
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
|
||||
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||
|
||||
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
|
||||
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
|
||||
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
|
||||
ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup;
|
||||
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
|
||||
ops->filesystem_ops->new_appendable_file =
|
||||
tf_gcs_filesystem::NewAppendableFile;
|
||||
}
|
||||
|
||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||
|
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "google/cloud/storage/client.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
|
||||
char** object, TF_Status* status);
|
||||
|
||||
namespace tf_gcs_filesystem {
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
||||
void Cleanup(TF_Filesystem* filesystem);
|
||||
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status);
|
||||
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status);
|
||||
} // namespace tf_gcs_filesystem
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
|
@ -0,0 +1,57 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
|
||||
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/platform/stacktrace_handler.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class GCSFilesystemTest : public ::testing::Test {
|
||||
public:
|
||||
void SetUp() override {
|
||||
status_ = TF_NewStatus();
|
||||
filesystem_ = new TF_Filesystem;
|
||||
tf_gcs_filesystem::Init(filesystem_, status_);
|
||||
ASSERT_TF_OK(status_) << "Can not initialize filesystem. "
|
||||
<< TF_Message(status_);
|
||||
}
|
||||
void TearDown() override {
|
||||
TF_DeleteStatus(status_);
|
||||
tf_gcs_filesystem::Cleanup(filesystem_);
|
||||
delete filesystem_;
|
||||
}
|
||||
|
||||
protected:
|
||||
TF_Filesystem* filesystem_;
|
||||
TF_Status* status_;
|
||||
};
|
||||
|
||||
// We have to add this test here because there must be at least one test.
|
||||
// This test will be removed in the future.
|
||||
TEST_F(GCSFilesystemTest, TestInit) { ASSERT_TF_OK(status_); }
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
GTEST_API_ int main(int argc, char** argv) {
|
||||
tensorflow::testing::InstallStacktraceHandler();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
TempFile::TempFile(const char* temp_file_name, std::ios::openmode mode)
|
||||
: std::fstream(temp_file_name, mode), name_(temp_file_name) {}
|
||||
|
||||
TempFile::TempFile(TempFile&& rhs)
|
||||
: std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}
|
||||
|
||||
TempFile::~TempFile() {
|
||||
std::fstream::close();
|
||||
std::remove(name_.c_str());
|
||||
}
|
||||
|
||||
const std::string TempFile::getName() const { return name_; }
|
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
class TempFile : public std::fstream {
|
||||
public:
|
||||
// We should specify openmode each time we call TempFile.
|
||||
TempFile(const char* temp_file_name, std::ios::openmode mode);
|
||||
TempFile(TempFile&& rhs);
|
||||
~TempFile() override;
|
||||
const std::string getName() const;
|
||||
|
||||
private:
|
||||
const std::string name_;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
|
@ -3,6 +3,10 @@
|
||||
# Targets in this directory are pure C++ "Classes" underlying the C API types
|
||||
# under tf/c/experimental/saved_model/public/. They are subject to change and
|
||||
# have visibility limited to Tensorflow's implementation only.
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -23,8 +27,8 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":function_metadata",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
@ -47,6 +51,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_utils",
|
||||
srcs = [
|
||||
"saved_model_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"saved_model_utils.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_saved_model_impl",
|
||||
srcs = [
|
||||
@ -84,3 +104,26 @@ filegroup(
|
||||
],
|
||||
visibility = ["//tensorflow/core:__pkg__"],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_utils_test",
|
||||
srcs = [
|
||||
"saved_model_utils_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":saved_model_utils",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
],
|
||||
)
|
||||
|
@ -15,12 +15,12 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
|
||||
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>&
|
||||
ConcreteFunction::GetCaptures() const {
|
||||
return captures_;
|
||||
}
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
|
||||
@ -38,15 +38,15 @@ class ConcreteFunction {
|
||||
virtual ~ConcreteFunction() = 0;
|
||||
|
||||
// This method returns the "Call" Op used to execute the function.
|
||||
virtual AbstractOperationInterface* GetCallOp() = 0;
|
||||
virtual ImmediateExecutionOperation* GetCallOp() = 0;
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
|
||||
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
|
||||
const;
|
||||
const FunctionMetadata& GetFunctionMetadata() const;
|
||||
|
||||
private:
|
||||
FunctionMetadata metadata_;
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
|
||||
std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
|
||||
FunctionDef* function_;
|
||||
};
|
||||
|
||||
|
@ -14,44 +14,6 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_eager_op",
|
||||
hdrs = [
|
||||
"owned_eager_op.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_tensor_handle",
|
||||
hdrs = [
|
||||
"owned_tensor_handle.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_eager_context",
|
||||
hdrs = ["owned_eager_context.h"],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_tensor",
|
||||
hdrs = ["owned_tensor.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tensor_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable_ops",
|
||||
srcs = [
|
||||
@ -61,10 +23,11 @@ cc_library(
|
||||
"variable_ops.h",
|
||||
],
|
||||
deps = [
|
||||
":owned_eager_op",
|
||||
":owned_tensor_handle",
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -78,10 +41,11 @@ tf_cc_test(
|
||||
"variable_ops_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":owned_eager_context",
|
||||
":owned_tensor",
|
||||
":owned_tensor_handle",
|
||||
":variable_ops",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -1,54 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractContextInterfaceDeleter {
|
||||
void operator()(AbstractContextInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct EagerContextDeleter {
|
||||
void operator()(EagerContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractContextPtr =
|
||||
std::unique_ptr<AbstractContextInterface,
|
||||
internal::AbstractContextInterfaceDeleter>;
|
||||
|
||||
using EagerContextPtr =
|
||||
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
@ -1,54 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct TensorHandleDeleter {
|
||||
void operator()(TensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct AbstractTensorHandleDeleter {
|
||||
void operator()(AbstractTensorHandleInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using TensorHandlePtr =
|
||||
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
|
||||
|
||||
using AbstractTensorHandlePtr =
|
||||
std::unique_ptr<AbstractTensorHandleInterface,
|
||||
internal::AbstractTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
@ -16,9 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
@ -32,10 +34,10 @@ namespace internal {
|
||||
static const char kNoSharingResourceID[] =
|
||||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
|
||||
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle) {
|
||||
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
ImmediateTensorHandlePtr* handle) {
|
||||
ImmediateOpPtr varhandle_op(ctx->CreateOperation());
|
||||
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
|
||||
@ -50,17 +52,57 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
|
||||
"shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));
|
||||
|
||||
AbstractTensorHandleInterface* var_handle = nullptr;
|
||||
AbstractTensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Execute(
|
||||
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
|
||||
handle->reset(var_handle);
|
||||
AbstractTensorHandlePtr owned_var_handle(var_handle);
|
||||
if (owned_var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
|
||||
owned_var_handle.release()));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status DestroyResource(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* handle) {
|
||||
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
Status AssignVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, ImmediateExecutionTensorHandle* value) {
|
||||
ImmediateOpPtr assign_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(value));
|
||||
|
||||
int num_retvals = 0;
|
||||
TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status ReadVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, ImmediateTensorHandlePtr* output) {
|
||||
ImmediateOpPtr read_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
|
||||
|
||||
AbstractTensorHandle* value = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
|
||||
AbstractTensorHandlePtr owned_value(value);
|
||||
if (owned_value->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
output->reset(
|
||||
reinterpret_cast<ImmediateExecutionTensorHandle*>(owned_value.release()));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status DestroyResource(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* handle) {
|
||||
ImmediateOpPtr destroy_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
|
||||
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));
|
||||
|
@ -16,9 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
|
||||
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
@ -30,15 +29,31 @@ namespace internal {
|
||||
// TensorHandle associated with the variable. This is equivalent to creating an
|
||||
// uninitialized TF2 tf.Variable.
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
|
||||
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle);
|
||||
ImmediateTensorHandlePtr* handle);
|
||||
|
||||
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
|
||||
// with `variable_handle` with `value`. `dtype` must be the datatype of the
|
||||
// underlying variable for `variable_handle`. Note that it is illegal to assign
|
||||
// a Tensor to a variable when its dtype differs from the one the variable was
|
||||
// created with.
|
||||
Status AssignVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, ImmediateExecutionTensorHandle* value);
|
||||
|
||||
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
|
||||
// value of `variable_handle` and copies the value to `output`. `dtype` must be
|
||||
// the dtype of the variable associated with `variable_handle`.
|
||||
Status ReadVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, ImmediateTensorHandlePtr* output);
|
||||
|
||||
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
|
||||
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
|
||||
Status DestroyResource(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* handle);
|
||||
Status DestroyResource(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* handle);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
@ -17,9 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -30,6 +29,13 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
|
||||
float value) {
|
||||
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
|
||||
ImmediateTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
|
||||
return handle;
|
||||
}
|
||||
|
||||
class VariableOpsTest : public ::testing::Test {
|
||||
public:
|
||||
VariableOpsTest()
|
||||
@ -55,7 +61,7 @@ class VariableOpsTest : public ::testing::Test {
|
||||
// Sanity check for variable creation
|
||||
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr handle;
|
||||
ImmediateTensorHandlePtr handle;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &handle));
|
||||
// The created TensorHandle should be a DT_Resource
|
||||
@ -65,7 +71,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) {
|
||||
// Sanity check for variable destruction
|
||||
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr handle;
|
||||
ImmediateTensorHandlePtr handle;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &handle));
|
||||
|
||||
@ -73,5 +79,28 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
|
||||
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
|
||||
}
|
||||
|
||||
// Sanity check for handle assignment and reading
|
||||
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
ImmediateTensorHandlePtr variable;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &variable));
|
||||
|
||||
// Create a Scalar float TensorHandle with value 42, and assign it to
|
||||
// the variable.
|
||||
ImmediateTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
|
||||
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
|
||||
my_value.get()));
|
||||
|
||||
// Read back the value from the variable, and check that it is 42.
|
||||
ImmediateTensorHandlePtr read_value_handle;
|
||||
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
|
||||
&read_value_handle));
|
||||
Status status;
|
||||
AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status);
|
||||
EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,39 @@
|
||||
# This package contains classes corresponding to Revived SavedObjectGraph types
|
||||
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
|
||||
package(
|
||||
default_visibility = [
|
||||
# Restricting visibility for now
|
||||
"//tensorflow/c/experimental/saved_model/core:__pkg__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "constant",
|
||||
srcs = [
|
||||
"constant.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"constant.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_convertible",
|
||||
hdrs = [
|
||||
"tensorhandle_convertible.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
],
|
||||
)
|
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Constant::Constant(ImmediateTensorHandlePtr handle)
|
||||
: TensorHandleConvertible(std::move(handle)) {}
|
||||
|
||||
Status Constant::Create(ImmediateExecutionContext* ctx,
|
||||
AbstractTensorInterface* tensor,
|
||||
std::unique_ptr<Constant>* output) {
|
||||
ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor);
|
||||
if (handle == nullptr) {
|
||||
return errors::Internal("Failed to convert tensor to tensorhandle");
|
||||
}
|
||||
output->reset(new Constant(ImmediateTensorHandlePtr(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This class corresponds to python's tf.constant, which is effectively a
|
||||
// TensorHandle explicitly initialized to some value.
|
||||
// For now this doesn't do much beyond wrap Context's CreateLocalHandle method,
|
||||
// and offer a subclass of TensorHandleConvertible. Note that similar to
|
||||
// Python's eager mode logic, we bypass calling the "Const" op:
|
||||
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301
|
||||
class Constant : public TensorHandleConvertible {
|
||||
public:
|
||||
static Status Create(ImmediateExecutionContext* ctx,
|
||||
AbstractTensorInterface* tensor,
|
||||
std::unique_ptr<Constant>* output);
|
||||
|
||||
// RevivedConstant is movable, but not copyable.
|
||||
Constant(Constant&& other) = default;
|
||||
Constant& operator=(Constant&& other) = default;
|
||||
|
||||
~Constant() override = default;
|
||||
|
||||
private:
|
||||
explicit Constant(ImmediateTensorHandlePtr handle);
|
||||
Constant(const Constant&) = delete;
|
||||
Constant& operator=(const Constant&) = delete;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A common interface for objects that can be converted to a TensorHandle.
|
||||
// Examples of objects that implement this include Variables, Constants, Assets,
|
||||
// etc. This is used to convert captured objects into a ConcreteFunction's
|
||||
// captured TensorHandles:
|
||||
// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240
|
||||
class TensorHandleConvertible {
|
||||
public:
|
||||
explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle)
|
||||
: handle_(std::move(handle)) {}
|
||||
|
||||
ImmediateExecutionTensorHandle* handle() { return handle_.get(); }
|
||||
|
||||
// TensorHandleConvertible is movable, but not copyable.
|
||||
TensorHandleConvertible(TensorHandleConvertible&& other) = default;
|
||||
TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default;
|
||||
|
||||
virtual ~TensorHandleConvertible() = default;
|
||||
|
||||
protected:
|
||||
TensorHandleConvertible(const TensorHandleConvertible&) = delete;
|
||||
TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete;
|
||||
ImmediateTensorHandlePtr handle_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
@ -13,30 +13,26 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractOperationInterfaceDeleter {
|
||||
void operator()(AbstractOperationInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output) {
|
||||
tensorflow::Tensor tensor;
|
||||
bool parse_result = tensor.FromProto(proto);
|
||||
if (!parse_result) {
|
||||
return errors::Internal("Failed to parse tensor from tensorproto");
|
||||
}
|
||||
};
|
||||
|
||||
TensorInterface tensor_interface(std::move(tensor));
|
||||
return Constant::Create(ctx, &tensor_interface, output);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractOpPtr =
|
||||
std::unique_ptr<AbstractOperationInterface,
|
||||
internal::AbstractOperationInterfaceDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
||||
|
||||
// Some internal utility functions for the SavedModelAPI, factored out into a
|
||||
// separately unit-testable header.
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// Load a TensorProto into a tensorflow::Constant. This is similar to the
|
||||
// constant loading logic in python:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
|
||||
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
@ -0,0 +1,199 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
|
||||
// This is needed for GTest's ::testing::ValuesIn, since
|
||||
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
|
||||
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
|
||||
std::vector<DataType> result;
|
||||
result.reserve(set.size());
|
||||
for (DataType dt : set) {
|
||||
result.push_back(dt);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns a vector of shapes intended to be "interesting" test cases.
|
||||
std::vector<std::vector<int64>> InterestingShapes() {
|
||||
std::vector<std::vector<int64>> interesting_shapes;
|
||||
interesting_shapes.push_back({}); // Scalar
|
||||
interesting_shapes.push_back({10}); // 1D Vector
|
||||
interesting_shapes.push_back({3, 3}); // 2D Matrix
|
||||
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
|
||||
return interesting_shapes;
|
||||
}
|
||||
|
||||
// Fills a numeric tensor with `value`.
|
||||
void FillNumericTensor(Tensor* tensor, int8 value) {
|
||||
switch (tensor->dtype()) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
const auto& flattened = tensor->flat<type>(); \
|
||||
for (int i = 0; i < tensor->NumElements(); ++i) { \
|
||||
flattened(i) = value; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: "
|
||||
<< DataTypeString(tensor->dtype());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
||||
// Note: The caller must ensure to check that the dtypes and sizes of the
|
||||
// underlying buffers are the same before calling this.
|
||||
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||
void* b) {
|
||||
switch (dtype) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
type* typed_a = static_cast<type*>(a); \
|
||||
type* typed_b = static_cast<type*>(b); \
|
||||
for (int64 i = 0; i < num_elements; ++i) { \
|
||||
if (DataTypeIsFloating(dtype)) { \
|
||||
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
|
||||
} else { \
|
||||
EXPECT_EQ(typed_a[i], typed_b[i]); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
// Value-parameterized fixture for Constant tests. The parameter tuple holds a
// DataType, a vector of int64 and a bool. The fixture owns a device manager
// created from a single "CPU" device and an EagerContext that uses it.
class ConstantTest : public ::testing::TestWithParam<
                         std::tuple<DataType, std::vector<int64>, bool>> {
 public:
  ConstantTest()
      : device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
            "CPU", {}, "/job:localhost/replica:0/task:0"))),
        // The context is handed a raw pointer to `device_mgr_` and is told it
        // does not own it; `device_mgr_` is declared first, so it is
        // initialized before this line runs.
        ctx_(new EagerContext(
            SessionOptions(),
            tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
            tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
            /* async= */ false,
            /* lazy_copy_function_remote_inputs= */ false,
            device_mgr_.get(),
            /* device_mgr_owned= */ false,
            /* rendezvous= */ nullptr,
            /* custom_kernel_creator= */ nullptr,
            /* cluster_flr= */ nullptr)) {}

  // Returns the fixture's eager context; the fixture keeps holding it.
  EagerContext* context() { return ctx_.get(); }

 private:
  // Declaration order matters: `ctx_` is constructed from `device_mgr_.get()`.
  std::unique_ptr<StaticDeviceMgr> device_mgr_;
  EagerContextPtr ctx_;
};
|
||||
|
||||
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
|
||||
// preserves values.
|
||||
TEST_P(ConstantTest, CreateConstantSuccessful) {
|
||||
// Get test parameters
|
||||
auto& test_params = GetParam();
|
||||
DataType dtype = std::get<0>(test_params);
|
||||
TensorShape shape(std::get<1>(test_params));
|
||||
bool tensorproto_use_tensor_content = std::get<2>(test_params);
|
||||
|
||||
// Construct a Tensor with the given dtype + shape
|
||||
Tensor expected(dtype, shape);
|
||||
FillNumericTensor(&expected, 42);
|
||||
|
||||
// Serialize it to a Tensorproto
|
||||
TensorProto proto;
|
||||
if (tensorproto_use_tensor_content) {
|
||||
expected.AsProtoTensorContent(&proto);
|
||||
} else {
|
||||
expected.AsProtoField(&proto);
|
||||
}
|
||||
|
||||
// Revival should succeed w/o errors
|
||||
std::unique_ptr<Constant> revived;
|
||||
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
|
||||
|
||||
// The revived tensorhandle should have the exact same dtype, shape, +
|
||||
// approx equivalent data to the original.
|
||||
ImmediateExecutionTensorHandle* handle = revived->handle();
|
||||
Status status;
|
||||
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
|
||||
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
|
||||
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
|
||||
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
|
||||
for (int i = 0; i < expected.dims(); ++i) {
|
||||
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
|
||||
}
|
||||
|
||||
CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
|
||||
revived_tensor->Data(), expected.data());
|
||||
}
|
||||
|
||||
// Test against combinations of tensors that are
|
||||
// 1. Varying dtypes
|
||||
// 2. Varying shapes
|
||||
// 3. TensorProto serialized using tensor_content vs repeated type
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantIntegerDtypesTest, ConstantTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)),
|
||||
::testing::ValuesIn(InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantFloatingDtypesTest, ConstantTest,
|
||||
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
||||
::testing::ValuesIn(InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -178,7 +178,7 @@ cc_library(
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
@ -190,7 +190,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to
|
||||
// change and should not be depended on.
|
||||
@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*>,
|
||||
std::vector<tensorflow::ImmediateExecutionTensorHandle*>,
|
||||
TF_TensorHandleList)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -54,6 +54,20 @@ class AbstractTensorInterface {
|
||||
virtual ~AbstractTensorInterface() {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct AbstractTensorInterfaceDeleter {
|
||||
void operator()(AbstractTensorInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using AbstractTensorPtr =
|
||||
std::unique_ptr<AbstractTensorInterface,
|
||||
internal::AbstractTensorInterfaceDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_
|
||||
|
@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
|
||||
RunTest(x, x_init_value, y, y_shape);
|
||||
}
|
||||
|
||||
// TODO(rocm):
|
||||
// Re-enable this test once 3D pooling is supported on ROCm platform
|
||||
#ifndef TENSORFLOW_USE_ROCM
|
||||
TEST_F(NNGradTest, MaxPool3DGradHelper) {
|
||||
TensorShape x_shape({1, 3, 3, 3, 1});
|
||||
TensorShape y_shape({1, 1, 1, 1, 1});
|
||||
@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
|
||||
SetRandomValuesForMaxPooling<float>(&x_init_value);
|
||||
RunTest(x, x_init_value, y, y_shape);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(NNGradTest, AvgPoolGradHelper) {
|
||||
TensorShape x_shape({1, 2, 2, 1});
|
||||
@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
|
||||
RunTest(x, x_shape, y, y_shape);
|
||||
}
|
||||
|
||||
// TODO(rocm):
|
||||
// Re-enable this test once 3D pooling is supported on ROCm platform
|
||||
#ifndef TENSORFLOW_USE_ROCM
|
||||
TEST_F(NNGradTest, AvgPool3DGradHelper) {
|
||||
TensorShape x_shape({1, 3, 3, 3, 1});
|
||||
TensorShape y_shape({1, 1, 1, 1, 1});
|
||||
@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
|
||||
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
|
||||
RunTest(x, x_shape, y, y_shape);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(NNGradTest, LRN) {
|
||||
TensorShape x_shape({1, 1, 2, 1});
|
||||
|
@ -69,6 +69,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
||||
"//tensorflow/core:regexp_internal",
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/base/call_once.h"
|
||||
#include "llvm-c/Target.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/aot/quantize.h"
|
||||
|
@ -108,8 +108,7 @@ class XlaExecutableClosure {
|
||||
explicit XlaExecutableClosure(
|
||||
xla::LocalClient* client, xla::LocalExecutable* executable,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
std::map<int, OptionalTensor> resource_var_snapshots,
|
||||
int num_constant_args)
|
||||
ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
|
||||
: client_(client),
|
||||
executable_(executable),
|
||||
compilation_result_(compilation_result),
|
||||
@ -124,7 +123,7 @@ class XlaExecutableClosure {
|
||||
const XlaCompiler::CompilationResult* compilation_result() const {
|
||||
return compilation_result_;
|
||||
}
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots() const {
|
||||
const ResourceVarsSnapshot& resource_var_snapshots() const {
|
||||
return resource_var_snapshots_;
|
||||
}
|
||||
int num_constant_args() const { return num_constant_args_; }
|
||||
@ -133,7 +132,7 @@ class XlaExecutableClosure {
|
||||
xla::LocalClient* client_;
|
||||
xla::LocalExecutable* executable_;
|
||||
const XlaCompiler::CompilationResult* compilation_result_;
|
||||
std::map<int, OptionalTensor> resource_var_snapshots_;
|
||||
ResourceVarsSnapshot resource_var_snapshots_;
|
||||
int num_constant_args_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
|
||||
@ -276,10 +275,10 @@ static Status BuildCompilationCache(OpKernelContext* ctx,
|
||||
|
||||
static Status CompileToLocalExecutable(
|
||||
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
|
||||
const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
|
||||
const XlaPlatformInfo& platform_info,
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
|
||||
std::map<int, OptionalTensor>* variables,
|
||||
const XlaCompiler::CompilationResult** kernel,
|
||||
const XlaCompiler::CompilationResult** compilation_result,
|
||||
xla::LocalExecutable** executable) {
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
@ -299,7 +298,6 @@ static Status CompileToLocalExecutable(
|
||||
// this is more obviously correct.)
|
||||
core::ScopedUnref cache_ref(cache);
|
||||
|
||||
TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
|
||||
*client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
@ -337,11 +335,11 @@ static Status CompileToLocalExecutable(
|
||||
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_args, *variables, ctx, &args));
|
||||
constant_args, variable_infos, ctx, &args));
|
||||
return cache->Compile(options, function, args, compile_options,
|
||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||
: XlaCompilationCache::CompileMode::kStrict,
|
||||
kernel, executable);
|
||||
compilation_result, executable);
|
||||
}
|
||||
|
||||
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
@ -349,16 +347,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
|
||||
|
||||
xla::LocalClient* client;
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
const XlaCompiler::CompilationResult* compilation_result;
|
||||
xla::LocalExecutable* executable;
|
||||
std::map<int, OptionalTensor> variables;
|
||||
|
||||
ResourceVarsSnapshot variables;
|
||||
{
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
|
||||
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
|
||||
Status s = CompileToLocalExecutable(
|
||||
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
|
||||
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
|
||||
&executable);
|
||||
variable_infos, constants_, /*lazy=*/false, &client,
|
||||
&compilation_result, &executable);
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
|
||||
variable_infos, &variables));
|
||||
}
|
||||
|
||||
se::Stream* stream =
|
||||
@ -373,7 +377,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
client, allocator,
|
||||
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
|
||||
platform_info_.UseMultipleStreams());
|
||||
launch_context.PopulateInputs(ctx, kernel, variables,
|
||||
launch_context.PopulateInputs(ctx, compilation_result, variables,
|
||||
/*missing_ctx_input_prefix=*/0);
|
||||
|
||||
// Execute the computation.
|
||||
@ -413,7 +417,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
OP_REQUIRES_OK(
|
||||
ctx, launch_context.PopulateOutputs(
|
||||
ctx, kernel, run_result.ConsumeValueOrDie(),
|
||||
ctx, compilation_result, run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
|
||||
VLOG(1) << "Done";
|
||||
}
|
||||
@ -494,7 +498,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
xla::LocalClient* client;
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
xla::LocalExecutable* executable;
|
||||
std::map<int, OptionalTensor> variables;
|
||||
ResourceVarsSnapshot variables;
|
||||
|
||||
bool cannot_compile_cluster;
|
||||
{
|
||||
@ -506,9 +510,16 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
cannot_compile_cluster) {
|
||||
executable = nullptr;
|
||||
} else {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
|
||||
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
|
||||
Status status = CompileToLocalExecutable(
|
||||
ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
|
||||
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
|
||||
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
|
||||
constants_,
|
||||
/*lazy=*/!must_compile_, &client, &kernel, &executable);
|
||||
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
|
||||
variable_infos, &variables));
|
||||
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
}
|
||||
|
@ -1837,7 +1837,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
|
||||
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
|
||||
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
|
||||
"Tile", "Transpose", "InvertPermutation", "Unpack"}}};
|
||||
"Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}};
|
||||
// clang-format on
|
||||
return result;
|
||||
}
|
||||
|
@ -28,32 +28,23 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
|
||||
std::map<int, OptionalTensor> variables;
|
||||
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
||||
// Returns argument indices corresponding to the resource variable inputs of
|
||||
// kernel context `ctx`.
|
||||
static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
|
||||
std::vector<int> out;
|
||||
for (int64 i = 0; i < ctx->num_inputs(); i++) {
|
||||
if (ctx->input(i).dtype() == DT_RESOURCE) {
|
||||
core::RefCountPtr<Var> variable;
|
||||
ResourceHandle handle = HandleFromInput(ctx, i);
|
||||
OptionalTensor& optional = variables[i];
|
||||
optional.name = handle.name();
|
||||
if (LookupResource(ctx, handle, &variable).ok()) {
|
||||
tf_shared_lock lock(*variable->mu());
|
||||
optional.present = true;
|
||||
optional.value = *variable->tensor();
|
||||
}
|
||||
out.push_back(i);
|
||||
}
|
||||
}
|
||||
return variables;
|
||||
return out;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
const XlaDevice::Metadata& metadata,
|
||||
const XlaCompiler::CompilationResult* result,
|
||||
xla::LocalExecutable* executable) {
|
||||
std::map<int, OptionalTensor> variables = GetVariables(ctx);
|
||||
|
||||
xla::LocalExecutable* executable,
|
||||
const ResourceVarsSnapshot& variable_args) {
|
||||
xla::LocalClient* client = metadata.client();
|
||||
|
||||
// Builds an XLA allocator for the device.
|
||||
@ -62,7 +53,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
/*allocate_xla_tensors=*/true,
|
||||
/*use_multiple_streams=*/metadata.UseMultipleStreams());
|
||||
|
||||
launch_context.PopulateInputs(ctx, result, variables,
|
||||
launch_context.PopulateInputs(ctx, result, variable_args,
|
||||
/*missing_ctx_input_prefix=*/0);
|
||||
|
||||
se::Stream* stream =
|
||||
@ -87,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
|
||||
ctx, result, run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -115,7 +106,7 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
|
||||
Status XlaCompileOnDemandOp::Compile(
|
||||
OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
|
||||
const XlaCompiler::CompilationResult** result,
|
||||
xla::LocalExecutable** executable) {
|
||||
ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
|
||||
std::map<int, Tensor> constant_arguments;
|
||||
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
||||
const Tensor& device_tensor = ctx->input(i);
|
||||
@ -190,12 +181,18 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
// rather than a one-element tuple.
|
||||
compile_options.always_return_tuple = false;
|
||||
|
||||
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
|
||||
|
||||
std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
|
||||
std::vector<XlaCompiler::Argument> args;
|
||||
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arguments, variable_args, ctx, &args));
|
||||
{
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
TF_RETURN_IF_ERROR(SnapshotResourceVariables(
|
||||
ctx, variables_indices, variable_infos, variable_args));
|
||||
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arguments, variable_infos, ctx, &args));
|
||||
}
|
||||
|
||||
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
|
||||
executable);
|
||||
@ -206,8 +203,10 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
|
||||
xla::LocalExecutable* executable;
|
||||
const XlaDevice::Metadata* metadata;
|
||||
OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
|
||||
OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
|
||||
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
|
||||
ResourceVarsSnapshot variable_args;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
Compile(ctx, *metadata, &result, &variable_args, &executable));
|
||||
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -47,10 +48,12 @@ class XlaCompileOnDemandOp : public OpKernel {
|
||||
bool* result);
|
||||
Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
|
||||
const XlaCompiler::CompilationResult** result,
|
||||
ResourceVarsSnapshot* variable_args,
|
||||
xla::LocalExecutable** executable);
|
||||
Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
|
||||
const XlaCompiler::CompilationResult* result,
|
||||
xla::LocalExecutable* executable);
|
||||
xla::LocalExecutable* executable,
|
||||
const ResourceVarsSnapshot& variable_args);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -52,7 +52,8 @@ const char kPossibleNonVariableResourceHintMessage[] =
|
||||
"resource inputs to XLA.";
|
||||
} // anonymous namespace
|
||||
|
||||
VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}
|
||||
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
|
||||
: index_(index), name_(name), var_(var) {}
|
||||
VariableInfo::VariableInfo(VariableInfo&& other)
|
||||
: index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
|
||||
other.index_ = -1;
|
||||
@ -87,16 +88,15 @@ VariableInfo::~VariableInfo() {
|
||||
// Returns a vector of VariableInfo instances for the resource variable inputs
|
||||
// to the kernel with context `ctx`. The input indices for the resource
|
||||
// variable inputs are in `variable_indices`.
|
||||
static Status GetVariableInfosFromCtxInputs(
|
||||
OpKernelContext* ctx, absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result) {
|
||||
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result) {
|
||||
std::vector<const ResourceHandle*> resource_handles;
|
||||
absl::c_transform(
|
||||
variable_indices, std::back_inserter(resource_handles),
|
||||
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
|
||||
|
||||
std::vector<core::RefCountPtr<Var>> variables;
|
||||
|
||||
Status s = LookupResources(ctx, resource_handles, &variables);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
|
||||
@ -109,7 +109,9 @@ static Status GetVariableInfosFromCtxInputs(
|
||||
// *Release* the variable because we're going to unref it later in
|
||||
// ~VariableInfo.
|
||||
Var* variable = variables[i].release();
|
||||
result->emplace_back(variable_indices[i], variable);
|
||||
int input_idx = variable_indices[i];
|
||||
std::string var_name = HandleFromInput(ctx, input_idx).name();
|
||||
result->emplace_back(input_idx, var_name, variable);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -162,21 +164,12 @@ Status LockVariables(absl::Span<VariableInfo> variables) {
|
||||
|
||||
Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::map<int, OptionalTensor>* result) {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
ResourceVarsSnapshot* result) {
|
||||
for (int i = 0; i < variable_indices.size(); i++) {
|
||||
if (variable_infos[i].var()) {
|
||||
OptionalTensor& tensor = (*result)[variable_indices[i]];
|
||||
tensor.name = HandleFromInput(ctx, variable_indices[i]).name();
|
||||
tensor.present = true;
|
||||
tensor.value = *variable_infos[i].var()->tensor();
|
||||
} else {
|
||||
(*result)[variable_indices[i]] = OptionalTensor();
|
||||
}
|
||||
Var* var = variable_infos[i].var();
|
||||
(*result)[variable_indices[i]] =
|
||||
var ? absl::make_optional(*var->tensor()) : absl::nullopt;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -197,8 +190,7 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
void XlaComputationLaunchContext::PopulateInputs(
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix) {
|
||||
const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_ptrs_ =
|
||||
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
|
||||
@ -210,7 +202,7 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
CHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
|
||||
const Tensor* t = variables.count(arg_num)
|
||||
? &(variables.at(arg_num).value)
|
||||
? &(variables.at(arg_num).value())
|
||||
: &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
CHECK(t);
|
||||
|
||||
@ -262,7 +254,7 @@ static const Tensor* FindAliasedTensorForOutput(
|
||||
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
absl::Span<const int> input_mapping,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||
const ResourceVarsSnapshot& resource_var_snapshots) {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
int xla_param = input_output_alias.GetAliasedParameter({output_num})
|
||||
.value()
|
||||
@ -274,8 +266,8 @@ static const Tensor* FindAliasedTensorForOutput(
|
||||
// entry time.
|
||||
if (input_tensor->dtype() == DT_RESOURCE) {
|
||||
auto& v = resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
|
||||
CHECK(v.present);
|
||||
return &v.value;
|
||||
CHECK(v.has_value());
|
||||
return &v.value();
|
||||
}
|
||||
return input_tensor;
|
||||
}
|
||||
@ -298,9 +290,9 @@ static Tensor GetOrCreateTensorForOutput(
|
||||
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
absl::Span<const int> input_mapping,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots,
|
||||
DataType output_dtype, const TensorShape& output_shape,
|
||||
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
|
||||
const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
|
||||
const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
|
||||
Allocator* output_allocator) {
|
||||
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
input_mapping, resource_var_snapshots)) {
|
||||
@ -431,13 +423,13 @@ static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
|
||||
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
|
||||
// not a Tensor.
|
||||
Var* variable = nullptr;
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
|
||||
ctx, HandleFromInput(ctx, actual_input_index), &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, variable);
|
||||
const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, handle.name(), variable);
|
||||
}
|
||||
return variable_infos;
|
||||
}
|
||||
@ -447,7 +439,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||
const ResourceVarsSnapshot& resource_var_snapshots) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
@ -484,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
stream->ThenRecordEvent(definition_event.get());
|
||||
}
|
||||
|
||||
std::vector<TensorShape> output_tensor_shapes;
|
||||
output_tensor_shapes.reserve(ctx->num_outputs());
|
||||
if (output.on_host_shape().is_dynamic()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
|
||||
xla::Shape output_host_shape = output.on_host_shape();
|
||||
xla::Shape output_device_shape = output.on_device_shape();
|
||||
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||
stream, &output, &output_host_shape, &output_device_shape));
|
||||
|
||||
output.set_shapes(output_host_shape, output_device_shape);
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
const xla::Shape& subshape =
|
||||
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
|
||||
output_tensor_shapes.push_back(shape);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy XLA results to the OpOutputList.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
const TensorShape& shape = compilation_result->outputs[i].shape;
|
||||
const TensorShape& shape = output_tensor_shapes[i];
|
||||
const DataType& type = compilation_result->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
@ -564,12 +582,21 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
|
||||
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
|
||||
std::vector<XlaCompiler::Argument>* args) {
|
||||
args->resize(ctx->num_inputs());
|
||||
|
||||
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
|
||||
for (const VariableInfo& info : variable_args) {
|
||||
CHECK(!info.var() || info.lock_held())
|
||||
<< "Need to hold the lock on resource variables "
|
||||
"before calling BuildXlaCompilerArguments";
|
||||
variable_info_lookup.emplace(info.index(), &info);
|
||||
}
|
||||
|
||||
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
|
||||
XlaCompiler::Argument& arg = (*args)[input_num];
|
||||
|
||||
if (constant_args.count(input_num) > 0) {
|
||||
// Handles compile-time constants.
|
||||
const Tensor& input = constant_args.at(input_num);
|
||||
@ -578,7 +605,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
arg.type = input.dtype();
|
||||
arg.shape = input.shape();
|
||||
arg.constant_value = input;
|
||||
} else if (variable_args.count(input_num) == 0) {
|
||||
} else if (variable_info_lookup.count(input_num) == 0) {
|
||||
// Handles the non-constant arguments.
|
||||
const Tensor& input = ctx->input(input_num);
|
||||
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
|
||||
@ -594,14 +621,14 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
// Handles resource variables.
|
||||
const Tensor& input = ctx->input(input_num);
|
||||
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
|
||||
const OptionalTensor& variable = variable_args.at(input_num);
|
||||
arg.name = variable.name;
|
||||
const VariableInfo& variable = *variable_info_lookup[input_num];
|
||||
arg.name = std::string(variable.name());
|
||||
arg.kind = XlaCompiler::Argument::kResource;
|
||||
arg.resource_kind = XlaResource::kVariable;
|
||||
if (variable.present) {
|
||||
const Tensor& value = variable.value;
|
||||
arg.type = value.dtype();
|
||||
arg.shape = value.shape();
|
||||
if (variable.var()) {
|
||||
const Tensor* value = variable.var()->tensor();
|
||||
arg.type = value->dtype();
|
||||
arg.shape = value->shape();
|
||||
arg.initialized = true;
|
||||
} else {
|
||||
// The values of uninitialized variables are not passed as inputs, since
|
||||
|
@ -34,36 +34,17 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Struct that represents a possibly-absent Tensor.
|
||||
struct OptionalTensor {
|
||||
string name; // A descriptive name
|
||||
bool present = false; // Is the tensor present?
|
||||
Tensor value; // If present, what is the Tensor's value?
|
||||
};
|
||||
|
||||
// Takes a snapshot of the values of resource variable arguments, whose indices
|
||||
// are specified in `variable_indices` argument. We snapshot tensors that back
|
||||
// resource variables since concurrent updates may modify the shape, and it is
|
||||
// important that the shapes used for compilation match the true shapes of the
|
||||
// buffers.
|
||||
//
|
||||
// We snapshot the entire set of resource variables as one atomic operation.
|
||||
// This models Read->* dependencies between resource variable operations. See
|
||||
// jit/resource_operation_safety_analysis for details.
|
||||
//
|
||||
// Returns a map of TensorFlow argument index to resource variable. If a
|
||||
// resource variable is not initialized, the corresponding OptionalTensor
|
||||
// will have its `present` field set to false.
|
||||
Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::map<int, OptionalTensor>* result);
|
||||
// Snapshot of resource variables for a TF kernel invocation, mapping from
|
||||
// parameter number to values at execution time. If the resource variable is not
|
||||
// initialized, the value will not be present.
|
||||
using ResourceVarsSnapshot = absl::flat_hash_map<int, absl::optional<Tensor>>;
|
||||
|
||||
// Information about the state of a variable passed as input to the _XlaCompile
|
||||
// and _XlaRun operators. Unlocks the resource variable and decrements its
|
||||
// refcount on destruction.
|
||||
class VariableInfo {
|
||||
public:
|
||||
explicit VariableInfo(int index, Var* var);
|
||||
explicit VariableInfo(int index, absl::string_view name, Var* var);
|
||||
VariableInfo(VariableInfo&& other);
|
||||
|
||||
VariableInfo& operator=(VariableInfo&& other);
|
||||
@ -79,6 +60,9 @@ class VariableInfo {
|
||||
// "empty", i.e. it does not track a resource variable.
|
||||
Var* var() const { return var_; }
|
||||
|
||||
// Returns the variable name.
|
||||
absl::string_view name() const { return name_; }
|
||||
|
||||
// Returns true if the resource variable lock was successfully acquired by
|
||||
// this thread.
|
||||
bool lock_held() const { return lock_held_; }
|
||||
@ -88,6 +72,7 @@ class VariableInfo {
|
||||
|
||||
private:
|
||||
int index_;
|
||||
std::string name_;
|
||||
Var* var_;
|
||||
|
||||
// We can't use an optional<mutex_lock> here because it confuses the compiler's
|
||||
@ -96,6 +81,20 @@ class VariableInfo {
|
||||
bool lock_held_ = false;
|
||||
};
|
||||
|
||||
// Takes a snapshot of the values of resource variable arguments, whose indices
|
||||
// are specified in `variable_indices` argument. We snapshot tensors that back
|
||||
// resource variables since concurrent updates may modify the shape, and it is
|
||||
// important that the shapes used for compilation match the true shapes of the
|
||||
// buffers.
|
||||
//
|
||||
// We snapshot the entire set of resource variables as one atomic operation.
|
||||
// This models Read->* dependencies between resource variable operations. See
|
||||
// jit/resource_operation_safety_analysis for details.
|
||||
Status SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
ResourceVarsSnapshot* result);
|
||||
|
||||
// Acquires the mutexes for all the variables in `variables` using a
|
||||
// deadlock-safe protocol (acquire the mutexes in increasing-address order).
|
||||
//
|
||||
@ -104,6 +103,13 @@ class VariableInfo {
|
||||
Status LockVariables(absl::Span<VariableInfo> variables)
|
||||
TF_EXCLUSIVE_LOCK_FUNCTION();
|
||||
|
||||
// Returns a vector of VariableInfo instances for the resource variable inputs
|
||||
// to the kernel with context `ctx`. The input indices for the resource
|
||||
// variable inputs are in `variable_indices`.
|
||||
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
|
||||
absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result);
|
||||
|
||||
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
|
||||
// ShapedBuffers suitable for passing to an XLA computation.
|
||||
class XlaComputationLaunchContext {
|
||||
@ -123,9 +129,10 @@ class XlaComputationLaunchContext {
|
||||
|
||||
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
|
||||
// op.
|
||||
// Precondition: variables in `variable_args` are locked.
|
||||
static Status BuildXlaCompilerArguments(
|
||||
const std::map<int, Tensor>& constant_args,
|
||||
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
|
||||
std::vector<XlaCompiler::Argument>* args);
|
||||
|
||||
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
|
||||
@ -137,7 +144,7 @@ class XlaComputationLaunchContext {
|
||||
// (in other words, no inputs actually required by the kernel can be missing).
|
||||
void PopulateInputs(OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
const ResourceVarsSnapshot& variables,
|
||||
int missing_ctx_input_prefix);
|
||||
|
||||
// Given the XLA output in `output`, populate all outputs of `ctx`. Also
|
||||
@ -155,7 +162,7 @@ class XlaComputationLaunchContext {
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots);
|
||||
const ResourceVarsSnapshot& resource_var_snapshots);
|
||||
|
||||
// Return the argument list. Only valid after PopulateInputs() has been
|
||||
// called.
|
||||
|
265
tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md
Normal file
265
tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md
Normal file
@ -0,0 +1,265 @@
|
||||
# MLIR CodeGen for XLA
|
||||
|
||||
<!--*
|
||||
# Document freshness: For more information, see go/fresh-source.
|
||||
freshness: { owner: 'timshen' reviewed: '2020-06-16' }
|
||||
*-->
|
||||
|
||||
XLA operates on `HloInstruction` and performs many optimizations on this
|
||||
representation, sharing a lot of these between targeted devices. As some point a
|
||||
linear schedule is computed and the memory buffer is assigned to each value
|
||||
statically. The device specific codegen operates by traversing this sequence and
|
||||
calling "emitters" to generate a representation suitable for the device (for
|
||||
example a single LLVM function per XLA computation on CPU, or a sequence of
|
||||
"thunks" encapsulating GPU operations and possibly generated PTX when targeting
|
||||
GPU).
|
||||
|
||||
As a staging step, we're currently in the process of intercepting the process
|
||||
right after XLA completes the buffer-assignment phase and emit instead an MLIR
|
||||
module in the `lhlo` dialect. From there we perform the codegen using MLIR
|
||||
components (Linalg, affine, and GPU dialect mainly) depending on the device.
|
||||
|
||||
Below is the plan of record to incrementally migrate XLA/GPU by using `lhlo` as
|
||||
the codegen input.
|
||||
|
||||
## Tasks
|
||||
|
||||
| | Host | Device
|
||||
| ------------- | ------------------------ | ------------------------
|
||||
| Input format | HloInstruction* (Task 1) | HloInstruction* (Task 1)
|
||||
| Output format | xla::Thunk (Task 2) | LLVM IR (Task 3)
|
||||
|
||||
* **Task 1** changes both host and device input format from HloInstruction* to
|
||||
LHLO.
|
||||
* **Task 2** changes output format of host from thunks to "some landing pad
|
||||
for host" (see below).
|
||||
* **Task 3** migrates device output from LLVM IR to some form of MLIR. It's
|
||||
optional to this project, and see the section "Migrating Device LLVM IR" for
|
||||
details.
|
||||
|
||||
This project prioritizes having end-to-end runnable models with LHLO-emitters
|
||||
enabled as much as possible. This implies that the following order list of
|
||||
objectives by priority:
|
||||
|
||||
* Make XLA/GPU runnable with LHLO emitters, with existing Thunks and emitters
|
||||
unmodified.
|
||||
* Eliminate the references to HloInstruction\* in LHLO, case by case:
|
||||
* Switch a legacy emitter to an MLIR-based emitter (e.g. Linalg), or
|
||||
* Mechanically translate the existing emitter to take MLIR representation
|
||||
(migrate to Standard with GPU Dialect).
|
||||
|
||||
## Migrating Thunks (Task 2)
|
||||
|
||||
xla::gpu::Thunk is a data structure that:
|
||||
|
||||
* Can be called into from the host (xla::gpu::Thunk::ExecuteOnStream()).
|
||||
* Carries various data in its subclasses.
|
||||
* Interacts with BufferAllocation::Slice and StreamExecutor.
|
||||
* Launches kernels
|
||||
* Calls into all runtime libraries.
|
||||
|
||||
The cost of that includes:
|
||||
|
||||
* Representing op-specific configuration data (e.g. convolution configs).
|
||||
* Migrating op shape and operand shapes.
|
||||
* Representing a tree of thunks (while, condition, etc).
|
||||
|
||||
The migration work is independent from LHLO / emitter migration. Under limited
|
||||
resources, it's prioritized behind LHLO / emitter migration.
|
||||
|
||||
We have several choices on how to lower the host-side part from LHLO:
|
||||
|
||||
* TFRT
|
||||
* (Pro) great CUDA and HIP wrappers for use.
|
||||
* (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
|
||||
TFRT ops are interpreted by C++ code.
|
||||
* (Con) host side is under development and not tested.
|
||||
* (Con) the JAX integration isn’t clear from a runtime point of view
|
||||
* Jitted CPU code
|
||||
* (Pro) great lower-ability. Create a few loops and conditions and it's
|
||||
done.
|
||||
* (Con) GPUDialect doesn't yet model chains/streams/asynchronicity/device
|
||||
allocation.
|
||||
* (Con) CUDA / HIP runtime support is minimal (toolkit path, version,
|
||||
dynamic loading, etc).
|
||||
* Existing (interpreting) XLA runtime
|
||||
|
||||
Tentative conclusion: Use jitted CPU code during the transition, and optionally
|
||||
adopt TFRT in the end.
|
||||
|
||||
## Migrating Device LLVM IR (Task 3)
|
||||
|
||||
An elemental emitter generates target op by filling it element by element. Each
|
||||
output element depends on a set of elements from the operands. All elements are
|
||||
described by combining the buffer with dynamic indices. It's sufficient to
|
||||
describe almost all "math" ops, but for performance reasons only a large subset
|
||||
of "math" ops are implemented directly in (Cpu|Gpu)ElementalIrEmitter.
|
||||
|
||||
ElementalIrEmitter is unique in that:
|
||||
|
||||
* A large portion of the code is shared between XLA/GPU and CPU.
|
||||
* It represents a large portion of ops seen in models, including all
|
||||
element-wise ops.
|
||||
* Most fusions solely depend on ElementalIrEmitter.
|
||||
* It's structurally simple, as it describes a data dependency DAG between op
|
||||
elements and operand elements.
|
||||
* It's mostly portable and high-level (e.g. unlike GPU kReduce and GPU kCopy).
|
||||
* Dynamic shape support is easy for at least element-wise ops.
|
||||
|
||||
Now, for all ops, elementally-emitted or not, there are several flavors of the
|
||||
end state of each XLA op:
|
||||
|
||||
1. Device code stays as LLVM IR.
|
||||
1. Refactor the old emitter to be like LHLO -> MLIR LLVM Dialect:
|
||||
* (Cost) Will be throw-away work if we want to ultimately migrate to
|
||||
Standard.
|
||||
* (Benefit) It is easy and mechanical. Can be done in a short period.
|
||||
* (Benefit) It doesn't benefit more compared to a).
|
||||
1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
|
||||
* (Cost) Lifting existing emitters to Standard introduces some challenges.
|
||||
Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
|
||||
amdgpu completeness is another one.
|
||||
* (Cost) XLA/GPU heavily relies on LLVM metadata:
|
||||
* `range` for block/thread indices.
|
||||
* `align`, `dereferenceable`, `invariant.load`, `alias.scope`,
|
||||
`noalias` for load/stores.
|
||||
* `llvm.loop.unroll.disable`, `llvm.loop.unroll.full`,
|
||||
`llvm.loop.vectorize.enable` for sequential loops.
|
||||
* (Benefit) Can be long-term. More portable.
|
||||
1. Refactor old emitters to be LHLO -> Linalg, and write new Linalg emitters
|
||||
* (Cost) This is case by case. Compared to previous options, a new
|
||||
implementation that matches XLA's performance needs to go through the
|
||||
benchmark <-> optimize workflow, which can be a significant cost for
|
||||
some ops.
|
||||
* (Benefit) unified stack; community support; portability; more
|
||||
optimization potentials.
|
||||
|
||||
## Prioritization
|
||||
|
||||
While all three tasks mentioned above are parallelizable, under limited
|
||||
resources they have to be serialized. The prioritization focuses on visible
|
||||
results for completion of each task.
|
||||
|
||||
The prioritization is: Task1 (LHLO for legacy emitters) > Task 2 (Thunks) > Task
|
||||
3 (MLIR emitters).
|
||||
|
||||
By the end of Task 1, users of XLA can generate an LHLO (e.g. kernel generator)
|
||||
and execute it. The compilation format will not be serializable MLIR.
|
||||
|
||||
By the end of Task 2, LHLO lowers to proper, serializable MLIR. This enables
|
||||
offline compilation.
|
||||
|
||||
By the end of Task 3, all XLA emitters are MLIR-based in their implementation.
|
||||
|
||||
## Detailed Design
|
||||
|
||||
### Step 1: (Task 1) Complete LHLO and Make Legacy Emitters Take LHLO
|
||||
|
||||
This step makes all existing XLA/GPU emitters interact with MLIR ops. This step
|
||||
is pure refactoring and NFC.
|
||||
|
||||
This step is mostly mechanical, but it's worth noticing the following
|
||||
discrepancies between an unnested HloComputation and LHLO:
|
||||
|
||||
* Each HloInstruction has direct access to its operands (a data-flow DAG). On
|
||||
contrary, each LHLO op only has access to its operand buffers (a bipartite
|
||||
between ops and buffers). LHLO ops have to go through use-def chains to
|
||||
access their operand ops.
|
||||
* Unnested legacy emitters empirically almost never access their operands. The
|
||||
only exception is kReduce.
|
||||
* Unnested legacy emitters access BufferAssignment only for getting slices,
|
||||
not for accessing aux data structures like dataflow\_analysis() or
|
||||
alias\_analysis(). llvm\_ir builds its own alias\_analysis() based on slice
|
||||
information.
|
||||
|
||||
The conclusion is that LHLO should fit right-in without major hassle.
|
||||
|
||||
### Step 2: (Optional) Profiling Support
|
||||
|
||||
**This step is only needed if we start to discard some of the XLA Thunk logic
|
||||
(see the next step).**
|
||||
|
||||
Before actually turning on any MLIR-based emitters, we need profiling for
|
||||
MLIR-based emitters.
|
||||
|
||||
Currently XLA performs its own profiling by calling into StreamExecutor's timer.
|
||||
The timer under the hood inserts two events before and after a kernel launch,
|
||||
and measures the sync time between these two events.
|
||||
|
||||
There are roughly three approaches to support profiling in MLIR:
|
||||
|
||||
* Run a profiler end-to-end
|
||||
* Add a profile op for each op in LHLO, using an injected profiler.
|
||||
|
||||
The "end-to-end" approach is transparent to MLIR, but suffers the same problem
|
||||
that makes XLA not use it in the first place: library calls collected by a
|
||||
profiler (nvprof/...) can't easily relate to HLO ops. For example, cuDNN
|
||||
launches multiple kernels for each HLO, and it's hard to tell which kernels
|
||||
correspond to which HLO.
|
||||
|
||||
The "injected profiler" approach requires:
|
||||
|
||||
* LHLO to take a profiler as a parameter.
|
||||
* inserting profile.start / profile.end before and after each op.
|
||||
* a pass from that lowers profile.{start,end} to a C++ implementation.
|
||||
|
||||
The exact profiling can't be easily done for MLIR-generated ops, since:
|
||||
|
||||
* MLIR doesn't have a timer, nor it depends on TFRT / StreamExecutor.
|
||||
* MLIR doesn't easily call into C functions with complicated parameters.
|
||||
|
||||
### Step 3: (Task 2) Migrating Thunks
|
||||
|
||||
This step migrates all host ops and library calls. This step will eliminate most
|
||||
of the thunks and produce serializable MLIR instead.
|
||||
|
||||
There are roughly three kinds of thunks:
|
||||
|
||||
* KernelThunk, which launches a kernel.
|
||||
* Control flow thunks, which have host control flow logic (conditional, while,
|
||||
for, sequence) and launch body kernels.
|
||||
* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.
|
||||
|
||||
The **bottom line** is to:
|
||||
|
||||
* Create a Thunk dialect that provides (de)serialize logic for all existing
|
||||
C++-based Thunks.
|
||||
* Change emitters to emit a graph of Thunk dialect.
|
||||
|
||||
**Optionally**, we can relieve some thunks from C++ implementation. KernelThunk
|
||||
can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG
|
||||
Dialect for loops and conditions, combined with LaunchKernelOp. This optional
|
||||
step requires profiling and stream support.
|
||||
|
||||
### Step 4: (Task 3) Migrated ElementalIrEmitter
|
||||
|
||||
Once profiling is ready, we can complete and tune all ElementalIrEmitter-based
|
||||
emitters in MLIR. Then we turn them on by default, assuming that all of these
|
||||
MLIR-based emitters use a single stream.
|
||||
|
||||
Notice that it's beneficial to migrate XLA/CPU's ElementalIrEmitter as well,
|
||||
since they share a large portion of the code.
|
||||
|
||||
With all benchmarking and performance hunting done (TODO: define performance
|
||||
parity), we turn on the new MLIR-based elemental emitter, and delete the legacy
|
||||
ElementalIrEmitter.
|
||||
|
||||
This step also provides easy fusion transitions (nested ops) for the later
|
||||
migration.
|
||||
|
||||
### Step 5: Multi-Stream Support or Drop
|
||||
|
||||
We can't delete
|
||||
[some of the emitters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/stream_assignment.cc#L140)
|
||||
until we support it in MLIR, or we drop the feature. It's a relatively large
|
||||
amount of work in MLIR and a small amount of gain for XLA. We should investigate
|
||||
current users of multi-stream XLA/GPU, and try to delete this feature if
|
||||
reasonable.
|
||||
|
||||
### Step 6: (Task 3) Migrated Device Ops
|
||||
|
||||
This step migrates all unnested ops, then we can delete all unnested emitters.
|
||||
|
||||
This calls for a rewrite/refactor of kCopy and kReduce. kReduce is already
|
||||
worked on for plenty, so the actual amount of work that needs to be done remains
|
||||
to be seen.
|
@ -314,7 +314,6 @@ tf_cc_test(
|
||||
cc_library(
|
||||
name = "tensorflow_lite_legalize_tf",
|
||||
srcs = [
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/dilated_conv.cc",
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
|
@ -446,7 +446,7 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
||||
auto desc =
|
||||
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
|
||||
|
||||
// Emit a loop to check all the dynamic values in the pack.
|
||||
// Emit a loop to check all operands.
|
||||
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
|
||||
// Capitalize the first letter to match the function name
|
||||
valueKind.substr(0, 1).upper(), valueKind.substr(1),
|
||||
@ -455,14 +455,10 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
||||
os << " (void)v;\n"
|
||||
<< " if (!("
|
||||
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
||||
<< " if (failure_on_operand_type_mismatch) {\n"
|
||||
<< formatv(
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
"<< \" must be {1}, but got \" << v.getType();\n",
|
||||
valueKind, desc)
|
||||
<< " } else {\n"
|
||||
<< " return ::mlir::LogicalResult::Failure;\n"
|
||||
<< " }\n"
|
||||
<< " }\n" // if
|
||||
<< " ++index;\n"
|
||||
<< " }\n"; // for
|
||||
@ -487,8 +483,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
|
||||
"failure_on_operand_type_mismatch) {\n";
|
||||
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
|
||||
@ -529,11 +524,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
|
||||
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
|
||||
os << tgfmt(
|
||||
" if (!($0)) {\n "
|
||||
" if (failure_on_operand_type_mismatch) {\n"
|
||||
" return top.emitOpError(\"failed to verify that $1\");\n"
|
||||
" } else {\n"
|
||||
" return ::mlir::LogicalResult::Failure;\n }\n }\n",
|
||||
" if (!($0))\n"
|
||||
" return top.emitOpError(\"failed to verify that $1\");\n",
|
||||
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
|
||||
}
|
||||
os << " return top.verify();\n}\n";
|
||||
|
@ -240,10 +240,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
|
||||
}
|
||||
|
||||
for (auto fn : module.getOps<FuncOp>()) {
|
||||
if (fn.getBlocks().size() != 1) {
|
||||
if (!llvm::hasSingleElement(fn)) {
|
||||
return fn.emitError("should have exactly one basic block"), false;
|
||||
}
|
||||
auto& bb = fn.getBlocks().front();
|
||||
auto& bb = fn.front();
|
||||
|
||||
for (auto arg : bb.getArguments()) {
|
||||
if (!HasValidTFLiteType(arg, fn))
|
||||
@ -1089,7 +1089,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
|
||||
dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
|
||||
str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
|
||||
/*KeepEmpty=*/false);
|
||||
auto term = fn.getBlocks().back().getTerminator();
|
||||
auto term = fn.back().getTerminator();
|
||||
if (output_names.size() != term->getNumOperands()) {
|
||||
fn.emitWarning() << "output names (" << output_names.size()
|
||||
<< ") != terminator operands (" << term->getNumOperands()
|
||||
|
@ -94,8 +94,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
||||
let methods = [
|
||||
StaticInterfaceMethod<
|
||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||
"LogicalResult", "VerifyTflRuntimeConstraints",
|
||||
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
|
||||
"LogicalResult", "VerifyTflRuntimeConstraints", (ins "Operation*":$op)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
@ -138,6 +138,11 @@ bool IsI32Type(Type element_type) {
|
||||
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
|
||||
}
|
||||
|
||||
// Return true when the given element_type is I64.
|
||||
bool IsI64Type(Type element_type) {
|
||||
return element_type.isInteger(64) && !element_type.isUnsignedInteger();
|
||||
}
|
||||
|
||||
// Return true if the given Add operation has the CPU kernel supported shapes.
|
||||
bool VerifyAddOpShapeConstraints(AddOp op) {
|
||||
auto element_type = getElementTypeOrSelf(op.output().getType());
|
||||
@ -174,7 +179,8 @@ bool VerifySubOpShapeConstraints(SubOp op) {
|
||||
// Allows F32, I32, I64, QUI8, and QI16 outputs when the operands have valid shapes,
|
||||
// which are broadcastable shapes up to five dimensions or have the same shapes.
|
||||
if (element_type.isF32() || IsI32Type(element_type) ||
|
||||
IsQUI8Type(element_type) || IsQI16Type(element_type)) {
|
||||
IsI64Type(element_type) || IsQUI8Type(element_type) ||
|
||||
IsQI16Type(element_type)) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/5);
|
||||
@ -758,6 +764,22 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
||||
return new_concat.getResult();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CustomOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies the `custom_option` attribute of a CustomOp. The attribute is cast
// to an OpaqueElementsAttr; verification fails with an op error when its type
// does not have a static shape, or when the length of its value differs from
// the size of the type's first dimension. Returns success otherwise.
static LogicalResult Verify(CustomOp op) {
|
||||
OpaqueElementsAttr opaque_attr =
|
||||
op.custom_option().cast<OpaqueElementsAttr>();
|
||||
if (!opaque_attr.getType().hasStaticShape())
|
||||
return op.emitOpError("custom_option should have a static shape.");
|
||||
if (opaque_attr.getValue().size() !=
|
||||
opaque_attr.getType().cast<ShapedType>().getDimSize(0))
|
||||
return op.emitOpError(
|
||||
"custom_option should have the same length of content with shape.");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FullyConnectedOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2169,6 +2191,10 @@ static LogicalResult Verify(TransposeOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WhileOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult Verify(WhileOp op) {
|
||||
if (op.getNumOperands() != op.getNumResults())
|
||||
return op.emitOpError(llvm::formatv(
|
||||
@ -2178,18 +2204,6 @@ LogicalResult Verify(WhileOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult Verify(CustomOp op) {
|
||||
OpaqueElementsAttr opaque_attr =
|
||||
op.custom_option().cast<OpaqueElementsAttr>();
|
||||
if (!opaque_attr.getType().hasStaticShape())
|
||||
return op.emitOpError("custom_option should have a static shape.");
|
||||
if (opaque_attr.getValue().size() !=
|
||||
opaque_attr.getType().cast<ShapedType>().getDimSize(0))
|
||||
return op.emitOpError(
|
||||
"custom_option should have the same length of content with shape.");
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Canonicalize While op so that results and operands match and external values
|
||||
// are via implicit capture rather than via block args.
|
||||
|
@ -571,6 +571,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
|
||||
TFL_OperandHasRank<2, 4>,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 2>>,
|
||||
AccumulatorUniformScale<3, 1, 2>,
|
||||
TFL_ChannelDimIndexInterface, AffineOpCoefficient<0, 2>,
|
||||
TFL_GpuTargetOp,
|
||||
TFL_SparseOp]> {
|
||||
let summary = "Transpose convolution operator";
|
||||
@ -596,6 +598,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// ChannelDimIndexInterface:
|
||||
int GetChannelDimIndex() { return 0; }
|
||||
// SparseOpInterface:
|
||||
std::vector<int> GetSparseOperands() { return {1}; }
|
||||
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
|
||||
@ -953,14 +957,14 @@ in the batch dimensions and broadcasting.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32]>:$x,
|
||||
TFL_TensorOf<[F32]>:$y,
|
||||
TFL_TensorOf<[F32, QI8]>:$x,
|
||||
TFL_TensorOf<[F32, QI8]>:$y,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adj_y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32]>:$output
|
||||
TFL_TensorOf<[F32, QI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -2860,11 +2864,11 @@ def TFL_SubOp : TFL_Op<"sub", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
|
||||
ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
|
||||
TFL_AFAttr:$fused_activation_function);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
|
@ -26,6 +26,7 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -0,0 +1,257 @@
|
||||
# RUN: not tf_tfl_translate -tf-upgrade-legacy=false -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=1,2:1 -tf-output-arrays=cond/Merge -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo -output-mlir %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: error: The graph has Control Flow V1 ops. TFLite converter doesn't support Control Flow V1 ops. Consider using Control Flow V2 ops instead.
|
||||
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
tensor_content: "\315\314\314=\315\314L>\232\231\231>\315\314\314>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Placeholder"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Placeholder_1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder_1"
|
||||
input: "Placeholder_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/switch_t"
|
||||
op: "Identity"
|
||||
input: "cond/Switch:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/switch_f"
|
||||
op: "Identity"
|
||||
input: "cond/Switch"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/pred_id"
|
||||
op: "Identity"
|
||||
input: "Placeholder_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul"
|
||||
op: "MatMul"
|
||||
input: "cond/MatMul/Switch:1"
|
||||
input: "cond/MatMul/Switch_1:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "transpose_a"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "transpose_b"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Placeholder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul/Switch_1"
|
||||
op: "Switch"
|
||||
input: "Const"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Const"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add"
|
||||
op: "Add"
|
||||
input: "cond/Add/Switch"
|
||||
input: "cond/Add/Switch_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Placeholder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add/Switch_1"
|
||||
op: "Switch"
|
||||
input: "Const"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Const"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Merge"
|
||||
op: "Merge"
|
||||
input: "cond/Add"
|
||||
input: "cond/MatMul"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "init"
|
||||
op: "NoOp"
|
||||
}
|
||||
versions {
|
||||
producer: 134
|
||||
}
|
@ -9,6 +9,15 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @sub(%arg0: tensor<1xi64>, %arg1: tensor<1xi64>) -> tensor<1xi64> {
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
|
||||
return %0: tensor<1xi64>
|
||||
|
||||
// CHECK-LABEL: sub
|
||||
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi64>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddHighDimsHaveSameShape
|
||||
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
|
||||
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
|
||||
@ -990,6 +999,13 @@ func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2:
|
||||
// CHECK: "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3x2xi32>) -> tensor<?x3x3x3x4xf32> {
|
||||
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<?x1x1x1x4xf32>, tensor<3xi32>, tensor<3x2xi32>) -> tensor<?x3x3x3x4xf32>
|
||||
return %0 : tensor<?x3x3x3x4xf32>
|
||||
// CHECK-LABEL: batch_to_space_nd_unsupported
|
||||
// CHECK: "tf.BatchToSpaceND"
|
||||
}
|
||||
|
||||
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
@ -269,6 +269,14 @@ func @testSub(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
return %0#0 : tensor<? x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSubInt64
|
||||
func @testSubInt64(tensor<? x i64>, tensor<? x i64>) -> tensor<? x i64> {
|
||||
^bb0(%arg0: tensor<? x i64>, %arg1: tensor<? x i64>):
|
||||
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
|
||||
%0 = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i64>
|
||||
return %0#0 : tensor<? x i64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMul
|
||||
func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
|
||||
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
|
||||
|
@ -70,6 +70,7 @@ func @prepareAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: prepareConv2DSplat
|
||||
// PerTensor-LABEL: prepareConv2DSplat
|
||||
func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> {
|
||||
%w = constant dense<127.0> : tensor<3x3x3x3xf32>
|
||||
%b = constant dense<0.0> : tensor<3xf32>
|
||||
@ -89,6 +90,7 @@ func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: prepareConv2D
|
||||
// PerTensor-LABEL: prepareConv2D
|
||||
func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> {
|
||||
%w = constant dense<[[[[0.0]]], [[[127.0]]], [[[-127.0]]]]> : tensor<3x1x1x1xf32>
|
||||
%b = constant dense<0.0> : tensor<3xf32>
|
||||
@ -108,6 +110,7 @@ func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: prepareDepthwiseConv2D
|
||||
// PerTensor-LABEL: prepareDepthwiseConv2D
|
||||
func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
%w = constant dense<127.0> : tensor<32x3x3x3xf32>
|
||||
%b = constant dense<0.0> : tensor<32xf32>
|
||||
@ -127,6 +130,7 @@ func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeFullyConnected
|
||||
// PerTensor-LABEL: QuantizeFullyConnected
|
||||
func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
%w = constant dense<127.0> : tensor<32x12xf32>
|
||||
%b = constant dense<0.0> : tensor<32xf32>
|
||||
@ -143,3 +147,22 @@ func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112
|
||||
// PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<32x12xf32>
|
||||
// PerTensor: "tfl.fully_connected"(%arg0, %[[dq]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeTransposeConv
|
||||
// PerTensor-LABEL: QuantizeTransposeConv
|
||||
func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>) -> tensor<1x32x42x128xf32> {
|
||||
%w = constant dense<127.0> : tensor<1x32x42x128xf32>
|
||||
%b = constant dense<0.0> : tensor<1x32x42x128xf32>
|
||||
%tc = "tfl.transpose_conv"(%arg1, %arg0, %w, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32>
|
||||
return %tc : tensor<1x32x42x128xf32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<1.270000e+02> : tensor<1x32x42x128xf32>
|
||||
// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32:0, {1.000000e+00}>>, volatile}
|
||||
// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32:0, {1.000000e+00}>>) -> tensor<1x32x42x128xf32>
|
||||
// CHECK: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
|
||||
|
||||
// PerTensor: %[[CST:.*]] = constant dense<1.270000e+02> : tensor<1x32x42x128xf32>
|
||||
// PerTensor: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>, volatile}
|
||||
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
|
||||
// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
|
||||
}
|
||||
|
@ -528,6 +528,26 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
|
||||
return %1 : tensor<1x4x64x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @StridedSliceRewriteMasks
|
||||
func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
|
||||
%cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%cst_0 = "tf.Const"() {device = "", value = dense<[1, 0, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%cst_1 = "tf.Const"() {device = "", value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[1, 0, 0, 1]> : tensor<4xi32>
|
||||
// CHECK: %[[CST0:.*]] = constant dense<[1, 0, 0, 0]> : tensor<4xi32>
|
||||
// CHECK: %[[CST1:.*]] = constant dense<1> : tensor<4xi32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]])
|
||||
// CHECK-SAME: begin_mask = 7 : i64
|
||||
// CHECK-SAME: ellipsis_mask = 0 : i64
|
||||
// CHECK-SAME: end_mask = 14 : i64
|
||||
// CHECK-SAME: new_axis_mask = 0 : i64
|
||||
// CHECK-SAME: shrink_axis_mask = 0 : i64
|
||||
|
||||
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 1 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 4 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<8x4x16x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x4x16x1xf32>
|
||||
return %0 : tensor<8x4x16x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV2Conversion
|
||||
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
|
@ -39,22 +39,18 @@ namespace tensorflow {
|
||||
void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
mlir::OpPassManager* pass_manager) {
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs));
|
||||
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||
bool emit_quant_adaptor_ops =
|
||||
quant_specs.inference_type != quant_specs.inference_input_type;
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
|
||||
if (quant_specs.default_ranges.first.hasValue() ||
|
||||
quant_specs.default_ranges.second.hasValue()) {
|
||||
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
|
||||
quant_specs.default_ranges.first.getValueOr(0.0),
|
||||
quant_specs.default_ranges.second.getValueOr(0.0),
|
||||
quant_specs.IsSignedInferenceType()));
|
||||
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
}
|
||||
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||
bool emit_quant_adaptor_ops =
|
||||
quant_specs.inference_type != quant_specs.inference_input_type;
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
}
|
||||
|
||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
@ -63,7 +59,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
standard_pipeline_options.enable_inliner = false;
|
||||
standard_pipeline_options.form_clusters = pass_config.form_clusters;
|
||||
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
|
||||
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
|
||||
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
|
||||
|
||||
if (pass_config.shape_inference) {
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
@ -212,9 +208,6 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
|
||||
// Saved model pass to mark global tensors immutable.
|
||||
pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
|
||||
// Used to mark non-exported functions in saved model private.
|
||||
pm.addPass(mlir::tf_saved_model::
|
||||
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
|
||||
// Op fusion pass.
|
||||
pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
|
||||
|
||||
|
@ -172,7 +172,7 @@ int main(int argc, char **argv) {
|
||||
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
|
||||
debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays,
|
||||
/*prune_unused_nodes=*/true, &source_mgr, &context);
|
||||
/*prune_unused_nodes=*/true, upgrade_legacy, &source_mgr, &context);
|
||||
}
|
||||
|
||||
// If errors occur, the library call in the above already logged the error
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
@ -39,19 +41,47 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
using mlir::MLIRContext;
|
||||
using mlir::ModuleOp;
|
||||
using mlir::Operation;
|
||||
using mlir::OwningModuleRef;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// Returns true if `op` is one of the tf_executor ops checked below: Switch,
// Merge, Enter, Exit, NextIterationSink or NextIterationSource. These are the
// ops this file treats as "Control Flow V1".
bool IsControlFlowV1Op(Operation* op) {
|
||||
return mlir::isa<mlir::tf_executor::SwitchOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::MergeOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::EnterOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::ExitOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::NextIterationSinkOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::NextIterationSourceOp>(op);
|
||||
}
|
||||
|
||||
// Checks that `module` contains no Control Flow V1 ops.
//
// Walks the operations in the module and interrupts the walk at the first op
// for which IsControlFlowV1Op() returns true. If the walk was interrupted, an
// error is emitted on the module and failure is returned; otherwise success is
// returned.
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
|
||||
auto result = module.walk([&](Operation* op) {
|
||||
return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
|
||||
: mlir::WalkResult::advance();
|
||||
});
|
||||
if (result.wasInterrupted()) {
|
||||
module.emitError(
|
||||
"The graph has Control Flow V1 ops. TFLite converter doesn't support "
|
||||
"Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
|
||||
"https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
|
||||
"enable_control_flow_v2.");
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
const std::string& input_filename, bool input_mlir,
|
||||
bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, MLIRContext* context) {
|
||||
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
|
||||
MLIRContext* context) {
|
||||
// Set up the input file.
|
||||
std::string error_message;
|
||||
auto file = mlir::openInputFile(input_filename, &error_message);
|
||||
@ -86,14 +116,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||
/*graph_as_function=*/false, enable_upgrade_legacy,
|
||||
/*enable_shape_inference=*/false, context);
|
||||
}
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||
/*graph_as_function=*/false, enable_upgrade_legacy,
|
||||
/*enable_shape_inference=*/false, context);
|
||||
}
|
||||
|
||||
@ -104,7 +134,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
mlir::PassManager* pass_manager) {
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
|
||||
/*propagate=*/true);
|
||||
if (failed(pass_manager->run(module))) {
|
||||
|
||||
if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,8 @@ LoadFromGraphdefOrMlirSource(
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
|
||||
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Load Saved model (either v1 or v2) into MLIR.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
|
@ -28,9 +28,11 @@ limitations under the License.
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Threading.h"
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
@ -767,13 +769,26 @@ void LegalizeTF::runOnFunction() {
|
||||
[](Operation* op) {
|
||||
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
||||
if (!tfl_op) return false;
|
||||
return succeeded(tfl_op.VerifyTflRuntimeConstraints(
|
||||
tfl_op.getOperation(),
|
||||
/*failure_on_operand_type_mismatch=*/false));
|
||||
return succeeded(tfl_op.VerifyTflRuntimeConstraints(op));
|
||||
}));
|
||||
} else {
|
||||
target.addLegalDialect<TensorFlowLiteDialect>();
|
||||
}
|
||||
|
||||
// Ignore transient errors by registering a no-op handler.
|
||||
// Applying legalization patterns will emit unwanted, transient errors when
|
||||
// the replaced TFLite ops do not meet the sanity checks. In order to ignore
|
||||
// the transient errors, the following lines override a diagnostic handler
|
||||
// with a no-op handler only while this pass runs.
|
||||
uint64_t current_thread_id = llvm::get_threadid();
|
||||
ScopedDiagnosticHandler scoped_diag_handler(
|
||||
context, [&current_thread_id](Diagnostic&) -> LogicalResult {
|
||||
// Consume only errors that are coming from the same thread in order not
|
||||
// to ignore errors from other passes that are running. Things running
|
||||
// in the pass manager can be multi-threaded.
|
||||
return success(current_thread_id == llvm::get_threadid());
|
||||
});
|
||||
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
// Look if there is a function that tries until it converges.
|
||||
|
@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
|
||||
// Verifies runtime constraints.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
|
||||
|
||||
// Creates function pass to select device index/fold tf.DeviceIndex.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
||||
|
||||
} // namespace TFL
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -52,7 +52,7 @@ class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
|
||||
|
||||
void RemoveQuantizationAdaptorOps(FuncOp func) {
|
||||
mlir::OpBuilder builder(func.getBody());
|
||||
auto& bb = func.getBlocks().front();
|
||||
auto& bb = func.front();
|
||||
auto* terminator = bb.getTerminator();
|
||||
|
||||
int num_args = bb.getNumArguments();
|
||||
|
@ -584,46 +584,50 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
|
||||
const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
|
||||
|
||||
llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
|
||||
llvm::APInt new_end_mask = strided_slice_op.end_mask();
|
||||
int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue();
|
||||
int64_t end_mask = strided_slice_op.end_mask().getSExtValue();
|
||||
int64_t new_begin_mask = 0;
|
||||
int64_t new_end_mask = 0;
|
||||
|
||||
SmallVector<int32_t, 4> padded_begin;
|
||||
SmallVector<int32_t, 4> padded_end;
|
||||
SmallVector<int32_t, 4> padded_stride;
|
||||
|
||||
// Before the ellipsis.
|
||||
uint64_t index = 1;
|
||||
int count = 0;
|
||||
|
||||
while (index < ellipsis_mask) {
|
||||
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
|
||||
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
|
||||
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
|
||||
index <<= 1;
|
||||
count++;
|
||||
int index = 0;
|
||||
int new_index = 0;
|
||||
while (((ellipsis_mask >> index) & 1) == 0) {
|
||||
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
|
||||
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
|
||||
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
|
||||
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
|
||||
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
|
||||
++index;
|
||||
++new_index;
|
||||
}
|
||||
|
||||
// Ellipsis.
|
||||
for (int i = 0; i < ellipsis_filled_dim_size; ++i) {
|
||||
new_begin_mask |= ellipsis_mask;
|
||||
new_end_mask |= ellipsis_mask;
|
||||
for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
|
||||
new_begin_mask |= (1 << new_index);
|
||||
new_end_mask |= (1 << new_index);
|
||||
|
||||
// Mimic the begin/end/strides mask behavior.
|
||||
padded_begin.push_back(0);
|
||||
padded_end.push_back(0);
|
||||
padded_stride.push_back(1);
|
||||
|
||||
ellipsis_mask <<= 1;
|
||||
}
|
||||
|
||||
// Account for ellipsis mask.
|
||||
count++;
|
||||
++index;
|
||||
|
||||
// After the ellipsis.
|
||||
for (; count < begin_shape[0]; ++count) {
|
||||
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
|
||||
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
|
||||
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
|
||||
for (; index < begin_shape[0]; ++index) {
|
||||
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
|
||||
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
|
||||
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
|
||||
|
||||
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
|
||||
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
|
||||
}
|
||||
|
||||
auto attribute_type = rewriter.getIntegerType(64);
|
||||
@ -645,7 +649,7 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
end_op.getResult(), stride_op.getResult(),
|
||||
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
|
||||
rewriter.getIntegerAttr(attribute_type, new_end_mask),
|
||||
rewriter.getI64IntegerAttr(0),
|
||||
/*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
|
||||
rewriter.getIntegerAttr(attribute_type,
|
||||
strided_slice_op.new_axis_mask()),
|
||||
rewriter.getIntegerAttr(attribute_type,
|
||||
@ -655,10 +659,12 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO(renjieliu): Consider expand the transformation for shrink
|
||||
// mask as well.
|
||||
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
|
||||
|
||||
// TODO(renjieliu): Consider expand the transformation for shrink mask as
|
||||
// well.
|
||||
if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure();
|
||||
|
||||
// Handle new axis mask.
|
||||
uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
|
||||
if (new_axis_mask != 0) {
|
||||
|
@ -34,8 +34,7 @@ class RuntimeVerifyPass
|
||||
|
||||
void RuntimeVerifyPass::runOnFunction() {
|
||||
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
|
||||
if (failed(op.VerifyTflRuntimeConstraints(
|
||||
op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
|
||||
if (failed(op.VerifyTflRuntimeConstraints(op.getOperation())))
|
||||
signalPassFailure();
|
||||
});
|
||||
}
|
||||
|
@ -57,6 +57,7 @@ gentbl(
|
||||
td_srcs = [
|
||||
":tensorflow_ops_td_files",
|
||||
],
|
||||
test = True,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
@ -88,6 +89,7 @@ gentbl(
|
||||
td_srcs = [
|
||||
":tensorflow_ops_td_files",
|
||||
],
|
||||
test = True,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
@ -112,6 +114,7 @@ gentbl(
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
],
|
||||
test = True,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
@ -137,6 +140,7 @@ gentbl(
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
],
|
||||
test = True,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
@ -161,6 +165,7 @@ gentbl(
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
],
|
||||
test = True,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
@ -475,6 +480,7 @@ cc_library(
|
||||
"transforms/cluster_outlining.cc",
|
||||
"transforms/collection_ops_util.cc",
|
||||
"transforms/decompose_resource_ops_pass.cc",
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/einsum.cc",
|
||||
"transforms/executor_island_coarsening.cc",
|
||||
"transforms/executor_tpuv1_inline_tpu_island.cc",
|
||||
@ -491,7 +497,6 @@ cc_library(
|
||||
"transforms/graph_pruning.cc",
|
||||
"transforms/launch_to_device_attribute.cc",
|
||||
"transforms/layout_optimization.cc",
|
||||
"transforms/mark_function_visibility.cc",
|
||||
"transforms/materialize_mlir_passthrough_op.cc",
|
||||
"transforms/optimize.cc",
|
||||
"transforms/optimize_global_tensors.cc",
|
||||
@ -661,7 +666,9 @@ cc_library(
|
||||
":tensorflow_types",
|
||||
":translate_utils",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
"//tensorflow/cc/saved_model:loader_util",
|
||||
"//tensorflow/compiler/jit:shape_inference_helpers",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
|
||||
@ -673,6 +680,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||
"//tensorflow/core/platform:protobuf_internal",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
@ -682,7 +690,6 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
@ -1349,6 +1356,7 @@ cc_library(
|
||||
srcs = ["utils/tpu_rewrite_device_util.cc"],
|
||||
hdrs = ["utils/tpu_rewrite_device_util.h"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
@ -1359,6 +1367,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1367,6 +1376,7 @@ tf_cc_test(
|
||||
size = "small",
|
||||
srcs = ["utils/tpu_rewrite_device_util_test.cc"],
|
||||
deps = [
|
||||
":device_util",
|
||||
":tpu_rewrite_device_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -299,13 +299,13 @@ ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
|
||||
parser->parseRegion(body, region_args, region_arg_types))
|
||||
return failure();
|
||||
|
||||
if (body.getBlocks().size() > 1)
|
||||
return parser->emitError(loc) << "expects a single block region";
|
||||
|
||||
// Ensure that the region is well formed: it contains at least a block with
|
||||
// a ReturnOp terminator.
|
||||
ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
|
||||
|
||||
if (!llvm::hasSingleElement(body))
|
||||
return parser->emitError(loc) << "expects a single block region";
|
||||
|
||||
Operation& terminator = body.front().back();
|
||||
if (!isa<ReturnOp>(terminator))
|
||||
return parser->emitError(loc) << "expects a tf_device.return terminator";
|
||||
|
@ -220,13 +220,13 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
|
||||
Region &body = *result.addRegion();
|
||||
if (parser.parseRegion(body, llvm::None, llvm::None)) return failure();
|
||||
|
||||
if (body.getBlocks().size() > 1)
|
||||
return parser.emitError(loc) << "expects a single block region";
|
||||
|
||||
// Ensure that the region is well formed: it contains at least a block with
|
||||
// a FetchOp terminator.
|
||||
GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);
|
||||
|
||||
if (!llvm::hasSingleElement(body))
|
||||
return parser.emitError(loc) << "expects a single block region";
|
||||
|
||||
// Get the results type from the terminator type inside the graph.
|
||||
Operation &fetch = body.back().back();
|
||||
if (!isa<FetchOp>(fetch))
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -232,6 +232,7 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
|
||||
def TF_YieldOp : TF_Op<"Yield", [Terminator]> {
|
||||
let summary = "Yield operation";
|
||||
|
||||
let description = [{
|
||||
The "yield" operation represents a return operation within the conditional
|
||||
and body of structured control flow (e.g., if and while). The operation
|
||||
@ -497,6 +498,7 @@ Inserts a placeholder for a tensor that will be always fed.
|
||||
|
||||
def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> {
|
||||
let summary = "Placeholder op";
|
||||
|
||||
let description = [{
|
||||
A placeholder op that passes through input when its output is not fed.
|
||||
}];
|
||||
@ -839,9 +841,6 @@ def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> {
|
||||
An op which shards the input based on the given sharding attribute.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
|
||||
@ -858,9 +857,6 @@ An op which shards the input based on the given sharding attribute.
|
||||
def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
|
||||
let summary = "Fetches multiple values from infeed as an XLA tuple.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
OptionalAttr<StrAttr>:$_XlaSharding
|
||||
);
|
||||
@ -904,9 +900,6 @@ def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> {
|
||||
Creates a dataset that batches `batch_size` elements from `input_dataset`.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$input_dataset,
|
||||
I64Tensor:$batch_size,
|
||||
@ -1048,4 +1041,46 @@ operation create / operate on a copy of `x`.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the Bessel i0e function of `x` element-wise.";
|
||||
|
||||
let description = [{
|
||||
Exponentially scaled modified Bessel function of order 0 defined as
|
||||
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
|
||||
|
||||
This function is faster and numerically stabler than `bessel_i0(x)`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_FpTensor:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the Bessel i1e function of `x` element-wise.";
|
||||
|
||||
let description = [{
|
||||
Exponentially scaled modified Bessel function of order 1 defined as
|
||||
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
|
||||
|
||||
This function is faster and numerically stabler than `bessel_i1(x)`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_FpTensor:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
#endif // TF_OPS
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
@ -76,6 +77,23 @@ static LogicalResult Verify(GlobalTensorOp global_tensor) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult Verify(SessionInitializerOp session_initializer) {
|
||||
mlir::SymbolTable symbol_table(
|
||||
session_initializer.getParentOfType<ModuleOp>());
|
||||
|
||||
auto init_func_op =
|
||||
symbol_table.lookup<mlir::FuncOp>(session_initializer.initializer());
|
||||
if (!init_func_op)
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function does not exist";
|
||||
|
||||
if (!init_func_op.getType().getResults().empty())
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should have no output";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
|
||||
|
||||
@ -212,14 +230,36 @@ static LogicalResult VerifySavedModelModule(
|
||||
}
|
||||
}
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
if (HasAnyTfSavedModelArgAttr(func)) {
|
||||
if (!IsExported(func)) {
|
||||
return func.emitError()
|
||||
<< "can only apply 'tf_saved_model' argument attributes "
|
||||
"to exported functions";
|
||||
}
|
||||
const bool is_exported = IsExported(func);
|
||||
|
||||
if (is_exported && !func.isPublic()) {
|
||||
return func.emitError()
|
||||
<< "exported function @" << func.getName() << " should be public";
|
||||
}
|
||||
|
||||
if (!is_exported && func.isPublic()) {
|
||||
return func.emitError() << "non-exported function @" << func.getName()
|
||||
<< " should be private";
|
||||
}
|
||||
|
||||
if (!is_exported && HasAnyTfSavedModelArgAttr(func)) {
|
||||
return func.emitError() << "can only apply 'tf_saved_model' argument "
|
||||
"attributes to exported functions";
|
||||
}
|
||||
}
|
||||
|
||||
auto session_initializers = module.getOps<SessionInitializerOp>();
|
||||
if (!session_initializers.empty() &&
|
||||
!llvm::hasSingleElement(session_initializers)) {
|
||||
return (*++session_initializers.begin()).emitError()
|
||||
<< "there must be no more than one session_initializer op";
|
||||
}
|
||||
|
||||
auto is_init = [&session_initializers](mlir::FuncOp func) {
|
||||
if (session_initializers.empty()) return false;
|
||||
return (*session_initializers.begin()).initializer() == func.getName();
|
||||
};
|
||||
|
||||
SymbolTable symbol_table(module);
|
||||
auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
|
||||
if (!symbol_uses.hasValue()) {
|
||||
@ -230,6 +270,12 @@ static LogicalResult VerifySavedModelModule(
|
||||
auto func = symbol_table.lookup<FuncOp>(
|
||||
symbol_use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
|
||||
if (func && IsExported(func)) {
|
||||
// If it is an init function, then it can be used by the unique
|
||||
// session_initializer op.
|
||||
if (is_init(func) &&
|
||||
llvm::isa<SessionInitializerOp>(symbol_use.getUser()))
|
||||
continue;
|
||||
|
||||
return symbol_use.getUser()
|
||||
->emitError("exported function cannot be internally referenced")
|
||||
.attachNote(func.getLoc())
|
||||
@ -349,5 +395,39 @@ GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
|
||||
return symbol_table.lookup<GlobalTensorOp>(attr.getValue());
|
||||
}
|
||||
|
||||
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) {
|
||||
auto initializers = op.getOps<SessionInitializerOp>();
|
||||
if (initializers.empty()) return {};
|
||||
return *initializers.begin();
|
||||
}
|
||||
|
||||
class OptimizeSessionInitializerPattern
|
||||
: public OpRewritePattern<SessionInitializerOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SessionInitializerOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SymbolTable symbol_table(op.getParentOfType<ModuleOp>());
|
||||
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(op.initializer());
|
||||
|
||||
// The init function can only be referenced from the SessionInitializerOp.
|
||||
// And there is at most one SessionInitializerOp in the module. So both ops
|
||||
// have no other uses and can be simply erased.
|
||||
if (init_func_op.front().begin()->isKnownTerminator()) {
|
||||
rewriter.eraseOp(init_func_op);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
void SessionInitializerOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<OptimizeSessionInitializerPattern>(context);
|
||||
}
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
@ -61,6 +61,10 @@ GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
|
||||
// should have.
|
||||
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor);
|
||||
|
||||
// Returns the session initializer of this module if it exists. Returns null
|
||||
// otherwise.
|
||||
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -128,4 +128,30 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
|
||||
let summary = "Initializes TensorFlow session state.";
|
||||
let description = [{
|
||||
The session initializer op marks a function that must be called by an
|
||||
external agent exactly once to initialize TensorFlow session state, and this
|
||||
must happen before any other exported functions are called. There must be no
|
||||
more than one session initializer in a saved model.
|
||||
|
||||
The `initializer` represents the initialization function. The function must have
|
||||
no output and this function should be only called once.
|
||||
|
||||
This is used, for example, to initialize hash tables stored in resources and
|
||||
accessed by resource name (rather than as resource handles or bound inputs
|
||||
which is how `global_tensor`s are referenced)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$initializer
|
||||
);
|
||||
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
#endif // SAVED_MODEL_DIALECT
|
||||
|
@ -1,47 +0,0 @@
|
||||
// RUN: tf-opt -tf-saved-model-mark-func-visibility -split-input-file %s | FileCheck --check-prefix=SAVEDMODEL %s
|
||||
// RUN: tf-opt -tf-mark-func-visibility -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
// SAVEDMODEL: func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]}
|
||||
func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]} {
|
||||
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// SAVEDMODEL: func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]}
|
||||
func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]} {
|
||||
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// SAVEDMODEL: func @func_not_exported() attributes {sym_visibility = "private"}
|
||||
func @func_not_exported() {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// CHECK: func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}}
|
||||
func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} {
|
||||
return %arg0 : tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK: func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> attributes {sym_visibility = "private"}
|
||||
func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) {T = i32, device = ""} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// expected-error @+1 {{can't overwrite the visibility of function private_func_with_entry_spec with private visibility}}
|
||||
func @private_func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}, sym_visibility = "private"} {
|
||||
return %arg0 : tensor<1xi32>
|
||||
}
|
||||
}
|
@ -433,4 +433,28 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
// CHECK: return %[[CAST_RESULT_0]], %[[CAST_RESULT_1]], %[[ADDI]]
|
||||
return %27, %28, %2 : tensor<*xui8>, tensor<*xi8>, tensor<*xi8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: infer_device_launch
|
||||
func @infer_device_launch(%arg0: tensor<1x8x2xi32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf_device.launch"() ({
|
||||
%2 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32>
|
||||
tf_device.return %2 : tensor<1x8x2xf32>
|
||||
// CHECK: () -> tensor<1x8x2xf32>
|
||||
}) {device = "/device:CPU:0"} : () -> tensor<*xf32>
|
||||
// CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32>
|
||||
// CHECK: (tensor<i32>, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>)
|
||||
%3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
%4 = addf %1, %1 : tensor<*xf32>
|
||||
return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<1xi32>
|
||||
func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<*xi32> {
|
||||
// CHECK: %[[RESULT:.*]] = tensor_cast
|
||||
// CHECK-SAME: tensor<1xi32> to tensor<1xi32>
|
||||
// CHECK: return %[[RESULT]] : tensor<1xi32>
|
||||
%1 = tensor_cast %arg0 : tensor<1xi32> to tensor<*xi32>
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
}
|
||||
|
@ -4,8 +4,6 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "lit_test")
|
||||
|
||||
def tf_saved_model_test(name, data, tags = None):
|
||||
"""Create a SavedModel test."""
|
||||
if tags == None:
|
||||
tags = ["no_rocm"]
|
||||
native.py_binary(
|
||||
name = name,
|
||||
testonly = 1,
|
||||
@ -26,5 +24,5 @@ def tf_saved_model_test(name, data, tags = None):
|
||||
name = name + ".py",
|
||||
data = [name] + data,
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags = tags,
|
||||
tags = tags + ["no_rocm"],
|
||||
)
|
||||
|
@ -46,7 +46,10 @@ def set_tf_options():
|
||||
# This function needs to take a "create_module_fn", as opposed to just the
|
||||
# module itself, because the creation of the module has to be delayed until
|
||||
# after absl and tensorflow have run various initialization steps.
|
||||
def do_test(signature_def_map, show_debug_info=False):
|
||||
def do_test(signature_def_map,
|
||||
init_op=None,
|
||||
canonicalize=False,
|
||||
show_debug_info=False):
|
||||
"""Runs test.
|
||||
|
||||
1. Performs absl and tf "main"-like initialization that must run before almost
|
||||
@ -61,6 +64,9 @@ def do_test(signature_def_map, show_debug_info=False):
|
||||
Args:
|
||||
signature_def_map: A map from string key to signature_def. The key will be
|
||||
used as function name in the resulting MLIR.
|
||||
init_op: The initializer op for the saved model. If set, it will generate an
|
||||
initializer graph in the resulting MLIR.
|
||||
canonicalize: If true, canonicalizer will be run on the resulting MLIR.
|
||||
show_debug_info: If true, shows debug locations in the resulting MLIR.
|
||||
"""
|
||||
|
||||
@ -84,6 +90,7 @@ def do_test(signature_def_map, show_debug_info=False):
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, [tf.saved_model.tag_constants.SERVING],
|
||||
signature_def_map,
|
||||
main_op=init_op,
|
||||
strip_default_attrs=True)
|
||||
builder.save()
|
||||
|
||||
@ -97,6 +104,9 @@ def do_test(signature_def_map, show_debug_info=False):
|
||||
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
|
||||
'tf-standard-pipeline',
|
||||
show_debug_info)
|
||||
if canonicalize:
|
||||
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
|
||||
show_debug_info)
|
||||
print(mlir)
|
||||
|
||||
app.run(app_main)
|
||||
|
@ -0,0 +1,92 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# RUN: %p/hash_table_v1 | FileCheck %s
|
||||
|
||||
# pylint: disable=missing-docstring,line-too-long
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||
|
||||
# Verify that the tf.versions attribute exists. It is difficult to enforce
|
||||
# contents, since the version numbers change over time. The conversion logic
|
||||
# itself is verified in the common graphdef converter, so here just assert
|
||||
# it is being invoked.
|
||||
# CHECK: module
|
||||
# CHECK-SAME: tf.versions
|
||||
# CHECK-SAME: bad_consumers
|
||||
# CHECK-SAME: min_consumer
|
||||
# CHECK-SAME: producer
|
||||
|
||||
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
|
||||
# CHECK: "tf_saved_model.global_tensor"()
|
||||
|
||||
# CHECK: func [[init]]
|
||||
# CHECK-NEXT: [[R5:%.*]] = "tf.Const"()
|
||||
# CHECK-NEXT: [[R6:%.*]] = "tf.Const"()
|
||||
# CHECK-NEXT: [[R7:%.*]] = "tf.HashTableV2"()
|
||||
# CHECK-SAME: shared_name = "[[hash_table:.*]]"
|
||||
# CHECK-NEXT: "tf.LookupTableImportV2"([[R7]], [[R5]], [[R6]])
|
||||
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: [[ARG0:%.*]]: tensor<i32>
|
||||
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
|
||||
|
||||
# CHECK-NEXT: [[R0:%.*]] = "tf.Const"()
|
||||
# CHECK-NEXT: [[R1:%.*]] = "tf.HashTableV2"()
|
||||
# CHECK-SAME: shared_name = "[[hash_table]]"
|
||||
# CHECK-NEXT: [[R2:%.*]] = "tf.LookupTableFindV2"([[R1]], [[ARG0]], [[R0]])
|
||||
# CHECK-NEXT: [[R3:%.*]] = "tf.ReadVariableOp"([[ARG1]])
|
||||
# CHECK-NEXT: [[R4:%.*]] = "tf.AddV2"([[R2]], [[R3]])
|
||||
# CHECK-NEXT: return [[R4]]
|
||||
|
||||
|
||||
def Test():
|
||||
|
||||
z = tf.compat.v1.get_variable(
|
||||
name='y',
|
||||
shape=(),
|
||||
initializer=tf.random_normal_initializer(),
|
||||
trainable=True)
|
||||
table_initializer = tf.lookup.KeyValueTensorInitializer(
|
||||
keys=[1, 2, 3, 4],
|
||||
values=[5, 6, 7, 8],
|
||||
key_dtype=tf.int32,
|
||||
value_dtype=tf.float32)
|
||||
table = tf.lookup.StaticHashTable(
|
||||
table_initializer, default_value=tf.constant(0.0))
|
||||
|
||||
x = tf.placeholder(tf.int32, shape=(), name='input')
|
||||
y = table.lookup(x)
|
||||
r = tf.add(y, z)
|
||||
|
||||
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
|
||||
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
|
||||
|
||||
return {
|
||||
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
|
||||
inputs={'x': tensor_info_x},
|
||||
outputs={'r': tensor_info_r},
|
||||
method_name='some_function'))
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
common_v1.set_tf_options()
|
||||
common_v1.do_test(Test(), tf.tables_initializer())
|
@ -0,0 +1,74 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# RUN: %p/remove_init_variable_v1 | FileCheck %s
|
||||
|
||||
# pylint: disable=missing-docstring,line-too-long
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||
|
||||
# Verify that the tf.versions attribute exists. It is difficult to enforce
|
||||
# contents, since the version numbers change over time. The conversion logic
|
||||
# itself is verified in the common graphdef converter, so here just assert
|
||||
# it is being invoked.
|
||||
# CHECK: module
|
||||
# CHECK-SAME: tf.versions
|
||||
# CHECK-SAME: bad_consumers
|
||||
# CHECK-SAME: min_consumer
|
||||
# CHECK-SAME: producer
|
||||
|
||||
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
|
||||
# CHECK-NOT: session_initializer
|
||||
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: [[ARG0:%.*]]: tensor<3x1xf32> {tf_saved_model.index_path = ["x"]},
|
||||
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @[[VAR]]})
|
||||
# CHECK-SAME: -> (tensor<3x3xf32> {tf_saved_model.index_path = ["r"]})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
|
||||
|
||||
# CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor<!tf.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
|
||||
# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
|
||||
# CHECK-NEXT: return [[R1]] : tensor<3x3xf32>
|
||||
|
||||
|
||||
def Test():
|
||||
|
||||
x = tf.constant([[1.0], [1.0], [1.0]])
|
||||
y = tf.compat.v1.get_variable(
|
||||
name='y',
|
||||
shape=(1, 3),
|
||||
initializer=tf.random_normal_initializer(),
|
||||
trainable=True)
|
||||
r = tf.matmul(x, y)
|
||||
|
||||
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
|
||||
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
|
||||
|
||||
return {
|
||||
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
|
||||
inputs={'x': tensor_info_x},
|
||||
outputs={'r': tensor_info_r},
|
||||
method_name='some_function'))
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
common_v1.set_tf_options()
|
||||
common_v1.do_test(
|
||||
Test(), tf.initializers.global_variables(), canonicalize=True)
|
@ -1,96 +0,0 @@
|
||||
// RUN: tf-opt -tf-saved-model-mark-func-visibility -symbol-dce -split-input-file %s | FileCheck %s
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Unused function should be deleted.
|
||||
|
||||
// CHECK-NOT: func @unused
|
||||
func @unused() {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Root calls child. Child should not be deleted.
|
||||
|
||||
// CHECK: func @root
|
||||
func @root() attributes {tf_saved_model.exported_names = ["root"]} {
|
||||
"tf.some_call"() { callee = @child } : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @child
|
||||
func @child() {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Don't crash if attribute that doesn't reference a func.
|
||||
|
||||
"tf.some_opaque_global_variable"() { sym_name = "some_global" } : () -> ()
|
||||
|
||||
func @root2() attributes {tf_saved_model.exported_names = ["root2"]} {
|
||||
"tf.do_something_with_a_global"() { global = @some_global } : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Delete recursively dead cycle.
|
||||
|
||||
// CHECK-NOT: func @recursively_dead0
|
||||
func @recursively_dead0() {
|
||||
"tf.some_call"() { callee = @recursively_dead1 } : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-NOT: func @recursively_dead1
|
||||
func @recursively_dead1() {
|
||||
"tf.some_call"() { callee = @recursively_dead0 } : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Root calls child with a deeply nested symbol reference.
|
||||
// Child should not be deleted.
|
||||
|
||||
// CHECK: func @root
|
||||
func @root() attributes {tf_saved_model.exported_names = ["root"]} {
|
||||
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @child
|
||||
func @child() {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: If the module doesn't have tf_saved_model semantics, then this
|
||||
// pass shouldn't do anything.
|
||||
module {
|
||||
// CHECK: func @not_dead()
|
||||
func @not_dead() {
|
||||
return
|
||||
}
|
||||
}
|
@ -64,7 +64,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
return
|
||||
}
|
||||
|
||||
func @f_callee(%arg0: tensor<!tf.resource<tensor<f32>>>) {
|
||||
func @f_callee(%arg0: tensor<!tf.resource<tensor<f32>>>) attributes {sym_visibility = "private"} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,11 @@
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: tf_saved_model.session_initializer
|
||||
"tf_saved_model.session_initializer"() {
|
||||
initializer = @init
|
||||
} : () -> ()
|
||||
|
||||
// Representation for constants: (immutable) global tensor.
|
||||
// CHECK: tf_saved_model.global_tensor
|
||||
"tf_saved_model.global_tensor"() {
|
||||
@ -35,7 +40,18 @@ module attributes {tf_saved_model.semantics} {
|
||||
return %arg0 : tensor<f32>
|
||||
}
|
||||
|
||||
func @f() {
|
||||
func @f() attributes {sym_visibility = "private"} {
|
||||
return
|
||||
}
|
||||
|
||||
// Representation for init functions
|
||||
// CHECK: func @init
|
||||
// CHECK-SAME: exported_names = ["__tf_saved_model_session_initializer"]
|
||||
func @init(
|
||||
%arg1: tensor<!tf.resource<tensor<1x64xf32>>> {tf_saved_model.bound_input = @some_constant}
|
||||
) attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]}
|
||||
{
|
||||
"tf.some_call"(%arg1) : (tensor<!tf.resource<tensor<1x64xf32>>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{unknown tf_saved_model dialect arg attribute 'tf_saved_model.not_a_real_arg_attr'}}
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.not_a_real_arg_attr = 1 : i32}) {
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.not_a_real_arg_attr = 1 : i32}) attributes {sym_visibility = "private"} {
|
||||
return
|
||||
}
|
||||
|
||||
@ -233,7 +233,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||
// expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}}
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<?xf32>>> {tf_saved_model.bound_input = @v})
|
||||
-> (tensor<?xf32> {tf_saved_model.index_path = []}) {
|
||||
-> (tensor<?xf32> {tf_saved_model.index_path = []}) attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<?xf32>>>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -258,3 +258,97 @@ module attributes {tf_saved_model.semantics} {
|
||||
// expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
|
||||
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function does not exist}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function should have no output}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
// expected-error@+1 {{there must be no more than one session_initializer op}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{exported function @f should be public}}
|
||||
func @f(
|
||||
%arg0: tensor<f32> {tf.resource_name = "resource"}
|
||||
) attributes { sym_visibility = "private", tf_saved_model.exported_names = ["foo.some_func"] } {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{non-exported function @f should be private}}
|
||||
func @f(
|
||||
%arg0: tensor<f32> {tf.resource_name = "resource"}
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function does not exist}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function should have no output}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]})
|
||||
attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } {
|
||||
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
// expected-error@+1 {{there must be no more than one session_initializer op}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]})
|
||||
attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } {
|
||||
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Freezing.
|
||||
// Immutability.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
@ -142,3 +142,89 @@ module attributes {tf_saved_model.semantics} {
|
||||
// Test running the pass on a module that does not have
|
||||
// tf_saved_model.semantics.
|
||||
module {}
|
||||
|
||||
// -----
|
||||
|
||||
// Test use as an input in unhandled op
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
"tf.unhandled_op"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Test use as a region capture in an unhandled op
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
"tf.unhandled"() ({
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
"tf.unhandled_terminator"() : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test use as region capture as well as input in an unhandled op
|
||||
// to the unhandled op.
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%0 = "tf.unhandled"(%arg0) ({
|
||||
%val = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
"tf.unhandled_terminator"() : () -> ()
|
||||
}) : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<!tf.resource<tensor<f32>>>)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test multiple global tensors uses as operands for an unhandled op.
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
"tf.unhandled"(%arg0, %arg1) : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user