commit 41da235fd0
Merge remote-tracking branch 'upstream/master'
.bazelrc (19 changes)
@ -30,6 +30,7 @@
|
||||
# short_logs: Only log errors during build, skip warnings.
|
||||
# monolithic: Build all TF C++ code into a single shared object.
|
||||
# dynamic_kernels: Try to link all kernels dynamically (experimental).
|
||||
# libc++: Link against libc++ instead of libstdc++
|
||||
#
|
||||
#
|
||||
# TF version options;
|
||||
@ -38,6 +39,7 @@
|
||||
#
|
||||
# Feature and Third party library support options:
|
||||
# xla: Build TF with XLA
|
||||
# tpu: Build TF with TPU support
|
||||
# using_cuda: CUDA is available to build system.
|
||||
# cuda: Build with full cuda support.
|
||||
# rocm: Build with AMD GPU support (rocm).
|
||||
@ -79,6 +81,14 @@
|
||||
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
# Allow builds using libc++ as a linker library
|
||||
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
|
||||
build:libc++ --action_env=CC
|
||||
build:libc++ --action_env=CXX
|
||||
build:libc++ --action_env=CXXFLAGS=-stdlib=libc++
|
||||
build:libc++ --action_env=PATH
|
||||
build:libc++ --define force_libcpp=enabled
|
||||
build:libc++ --linkopt -fuse-ld=lld
# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
|
||||
# target CPU to build transient dependencies correctly. See
|
||||
@ -171,6 +181,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
|
||||
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
||||
build:dbg --copt -DDEBUG_BUILD
# Config to build TPU backend
|
||||
build:tpu --define=with_tpu_support=true
|
||||
|
||||
build:tensorrt --action_env TF_NEED_TENSORRT=1
|
||||
|
||||
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
||||
@ -200,6 +213,8 @@ build:nogcp --define=no_gcp_support=true
|
||||
build:nohdfs --define=no_hdfs_support=true
|
||||
build:nonccl --define=no_nccl_support=true
|
||||
|
||||
build:stackdriver_support --define=stackdriver_support=true
|
||||
|
||||
build --define=use_fast_cpp_protos=true
|
||||
build --define=allow_oversize_protos=true
|
||||
|
||||
@ -441,8 +456,8 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
|
||||
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
|
||||
|
||||
build:rbe_win --config=rbe
|
||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
|
||||
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
|
||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:toolchain"
|
||||
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:cc-toolchain-x64_windows"
|
||||
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
|
||||
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
|
||||
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
|
||||
|
@ -1 +1 @@
|
||||
3.0.0
|
||||
3.1.0
|
||||
|
.github/stale.yml (vendored, 2 changes)
@ -23,7 +23,7 @@
|
||||
daysUntilStale: 7
|
||||
# Number of days of inactivity before a stale Issue or Pull Request is closed
|
||||
daysUntilClose: 7
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled)
|
||||
onlyLabels:
|
||||
- stat:awaiting response
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
|
@ -61,7 +61,6 @@ commands.
|
||||
*Nightly binaries are available for testing using the
|
||||
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
|
||||
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.*
|
||||
|
||||
#### *Try your first TensorFlow program*
|
||||
|
||||
```shell
|
||||
@ -96,6 +95,7 @@ for general questions and discussion, and please direct specific questions to
|
||||
The TensorFlow project strives to abide by generally accepted best practices in
|
||||
open-source software development:
|
||||
|
||||
[](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow)
|
||||
[](https://bestpractices.coreinfrastructure.org/projects/1486)
|
||||
[](CODE_OF_CONDUCT.md)
|
||||
|
||||
@ -114,6 +114,12 @@ Build Type | Status
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
|
@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
|
||||
_TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
_TF_MIN_BAZEL_VERSION = '2.0.0'
|
||||
_TF_MIN_BAZEL_VERSION = '3.1.0'
|
||||
_TF_MAX_BAZEL_VERSION = '3.99.0'
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
|
@ -298,6 +298,13 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Experimental features
|
||||
config_setting(
|
||||
name = "stackdriver_support",
|
||||
define_values = {"stackdriver_support": "true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Crosses between platforms and file system libraries not supported on those
|
||||
# platforms due to limitations in nested select() statements.
|
||||
config_setting(
|
||||
@ -460,6 +467,13 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# This flag enables experimental TPU support
|
||||
config_setting(
|
||||
name = "with_tpu_support",
|
||||
values = {"define": "with_tpu_support=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Specifies via a config setting if this is a mobile build or not, makes
|
||||
# it easier to combine settings later.
|
||||
selects.config_setting_group(
|
||||
@ -531,6 +545,7 @@ package_group(
|
||||
|
||||
# Packages that use composite tensors or dispatch.
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
# If this is modified, then copy.bara.sky must also be modified.
|
||||
package_group(name = "composite_tensor_whitelist")
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
@ -540,6 +555,11 @@ package_group(
|
||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
||||
)
|
||||
|
||||
# Packages that use StructuredTensors.
|
||||
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
|
||||
# If this is modified, then copy.bara.sky must also be modified.
|
||||
package_group(name = "structured_tensor_whitelist")
|
||||
|
||||
filegroup(
|
||||
name = "intel_binary_blob",
|
||||
data = if_mkl_ml(
|
||||
|
@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
|
||||
tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
|
||||
node_def.set_name(op->Name());
|
||||
node_def.set_op(op->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
|
@ -38,9 +38,10 @@ tf_cuda_library(
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":context_interface",
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":immediate_execution_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
":abstract_tensor_handle",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_executor_internal",
|
||||
@ -101,13 +102,17 @@ tf_cuda_library(
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"abstract_context.h",
|
||||
"abstract_function.h",
|
||||
"abstract_operation.h",
|
||||
"abstract_tensor_handle.h",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"context_interface.h",
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
"immediate_execution_context.h",
|
||||
"immediate_execution_operation.h",
|
||||
"immediate_execution_tensor_handle.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
@ -163,12 +168,22 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_handle_interface",
|
||||
hdrs = ["tensor_handle_interface.h"],
|
||||
name = "abstract_tensor_handle",
|
||||
hdrs = ["abstract_tensor_handle.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_tensor_handle",
|
||||
hdrs = ["immediate_execution_tensor_handle.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -177,13 +192,13 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "operation_interface",
|
||||
hdrs = ["operation_interface.h"],
|
||||
name = "abstract_operation",
|
||||
hdrs = ["abstract_operation.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -193,16 +208,59 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "context_interface",
|
||||
hdrs = ["context_interface.h"],
|
||||
name = "immediate_execution_operation",
|
||||
hdrs = ["immediate_execution_operation.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":abstract_operation",
|
||||
":abstract_tensor_handle",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "abstract_context",
|
||||
hdrs = ["abstract_context.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_function",
|
||||
":abstract_operation",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "abstract_function",
|
||||
hdrs = ["abstract_function.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_context",
|
||||
hdrs = ["immediate_execution_context.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -218,7 +276,7 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":context_interface",
|
||||
":immediate_execution_context",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
@ -278,7 +336,7 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
":immediate_execution_operation",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
@ -301,7 +359,7 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
@ -481,6 +539,9 @@ tf_cuda_library(
|
||||
":tfe_context_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
":abstract_operation",
|
||||
":abstract_context",
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
tensorflow/c/eager/abstract_context.h (new file, 69 lines)
@ -0,0 +1,69 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_function.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a context.
|
||||
//
|
||||
// This serves as a factory for creating `AbstractOperation`s and for
|
||||
// registering traced functions.
|
||||
// Operations created within a context can only be executed in that context
|
||||
// (for now at least).
|
||||
// Implementations of the context may contain some state e.g. an execution
|
||||
// environment, a traced representation etc.
|
||||
class AbstractContext {
|
||||
protected:
|
||||
enum AbstractContextKind { kTracing, kImmediateExecution };
|
||||
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractContext() {}
|
||||
|
||||
public:
|
||||
AbstractContextKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage its own
|
||||
// lifetime through ref counting. Thus clients MUST call Release() in order to
|
||||
// destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
// Creates an operation builder and ties it to this context.
|
||||
// The returned object can be used for setting the operation's attributes,
|
||||
// adding inputs and finally executing (immediately or lazily as in tracing)
|
||||
// it in this context.
|
||||
virtual AbstractOperation* CreateOperation() = 0;
|
||||
|
||||
// Registers a function with this context; after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual Status RegisterFunction(AbstractFunction*) = 0;
|
||||
// Remove a function. 'func' argument is the name of a previously added
|
||||
// FunctionDef. The name is in fdef.signature.name.
|
||||
virtual Status RemoveFunction(const string& func) = 0;
|
||||
|
||||
private:
|
||||
const AbstractContextKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
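To make the intent of this new interface concrete, here is a minimal usage sketch (not part of the commit); `UseAbstractContext` is a hypothetical caller, and `ctx`/`fn` are assumed to come from some concrete implementation:

```cpp
#include "tensorflow/c/eager/abstract_context.h"

// Sketch only: drives the factory/registration methods declared above.
tensorflow::Status UseAbstractContext(tensorflow::AbstractContext* ctx,
                                      tensorflow::AbstractFunction* fn) {
  // Make the traced function callable by name inside this context.
  tensorflow::Status s = ctx->RegisterFunction(fn);
  if (!s.ok()) return s;

  // CreateOperation() returns a builder whose behaviour (immediate execution
  // vs. tracing) depends on the concrete context, i.e. on ctx->getKind().
  tensorflow::AbstractOperation* op = ctx->CreateOperation();
  s = op->Reset("Identity", /*raw_device_name=*/nullptr);
  // ... SetAttr*/AddInput/Execute calls would follow here ...
  op->Release();  // Possibly ref-counted, so Release(), never delete.
  return s;
}
```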
tensorflow/c/eager/abstract_function.h (new file, 46 lines)
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A traced function: this hides the complexity of converting the serialized
|
||||
// representation between various supported formats e.g. FunctionDef and Mlir
|
||||
// function.
|
||||
class AbstractFunction {
|
||||
protected:
|
||||
enum AbstractFunctionKind { kGraphFunc, kMlirFunc };
|
||||
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass this is an instance of.
|
||||
AbstractFunctionKind getKind() const { return kind_; }
|
||||
virtual ~AbstractFunction() = default;
|
||||
|
||||
// Returns the AbstractFunction as a FunctionDef.
|
||||
virtual Status GetFunctionDef(FunctionDef**) = 0;
|
||||
|
||||
private:
|
||||
const AbstractFunctionKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
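As a hedged illustration of the `kGraphFunc` case (not part of the commit), a FunctionDef-backed subclass could look roughly like this; `example::GraphFunction` is a hypothetical name:

```cpp
#include <utility>

#include "tensorflow/c/eager/abstract_function.h"

namespace example {

// Sketch only: a minimal FunctionDef-backed AbstractFunction.
class GraphFunction : public tensorflow::AbstractFunction {
 public:
  explicit GraphFunction(tensorflow::FunctionDef fdef)
      : AbstractFunction(kGraphFunc), fdef_(std::move(fdef)) {}

  // Hands out a pointer to the internally owned FunctionDef.
  tensorflow::Status GetFunctionDef(tensorflow::FunctionDef** out) override {
    *out = &fdef_;
    return tensorflow::Status::OK();
  }

 private:
  tensorflow::FunctionDef fdef_;
};

}  // namespace example
```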
@ -12,24 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class AbstractOperationInterface {
|
||||
// This interface allows building and executing an operation in either
|
||||
// tracing or immediate execution mode.
|
||||
class AbstractOperation {
|
||||
protected:
|
||||
enum AbstractOperationKind { kTracing, kImmediateExecution };
|
||||
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractOperation() {}
|
||||
|
||||
public:
|
||||
AbstractOperationKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
@ -38,7 +43,6 @@ class AbstractOperationInterface {
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
virtual void Clear() = 0;
|
||||
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
||||
|
||||
virtual const string& Name() const = 0;
|
||||
@ -66,12 +70,10 @@ class AbstractOperationInterface {
|
||||
// existing and given constraints will be performed.
|
||||
virtual Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
|
||||
virtual Status AddInputList(
|
||||
absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
|
||||
virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
|
||||
virtual Status AddInput(AbstractTensorHandle* input) = 0;
|
||||
virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
|
||||
virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) = 0;
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) = 0;
|
||||
@ -82,7 +84,7 @@ class AbstractOperationInterface {
|
||||
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) = 0;
|
||||
virtual Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperationInterface* value) = 0;
|
||||
const AbstractOperation* value) = 0;
|
||||
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) = 0;
|
||||
virtual Status SetAttrTensor(const char* attr_name,
|
||||
@ -102,19 +104,12 @@ class AbstractOperationInterface {
|
||||
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) = 0;
|
||||
virtual Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperationInterface*> values) = 0;
|
||||
const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
|
||||
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractOperationInterface() {}
|
||||
private:
|
||||
const AbstractOperationKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
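A rough sketch (not part of the commit) of how a caller that only sees the renamed `AbstractOperation`/`AbstractTensorHandle` types might run a single op; `MatMulViaAbstractOperation`, `ctx`, `a` and `b` are hypothetical:

```cpp
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"

// Sketch only: the same code path works for tracing and immediate execution.
tensorflow::Status MatMulViaAbstractOperation(
    tensorflow::AbstractContext* ctx, tensorflow::AbstractTensorHandle* a,
    tensorflow::AbstractTensorHandle* b,
    tensorflow::AbstractTensorHandle** out) {
  tensorflow::AbstractOperation* op = ctx->CreateOperation();
  tensorflow::Status s = op->Reset("MatMul", /*raw_device_name=*/nullptr);
  if (s.ok()) s = op->AddInput(a);
  if (s.ok()) s = op->AddInput(b);
  int num_retvals = 1;
  if (s.ok()) s = op->Execute(absl::MakeSpan(out, 1), &num_retvals);
  op->Release();  // Builder may be ref-counted; never delete directly.
  return s;
}
```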
tensorflow/c/eager/abstract_tensor_handle.h (new file, 45 lines)
@ -0,0 +1,45 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor handle in either tracing or immediate
|
||||
// execution mode.
|
||||
class AbstractTensorHandle {
|
||||
protected:
|
||||
enum AbstractTensorHandleKind { kTracing, kImmediateExecution };
|
||||
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractTensorHandle() {}
|
||||
|
||||
public:
|
||||
AbstractTensorHandleKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage its own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
private:
|
||||
const AbstractTensorHandleKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
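Since the base class carries nothing but the kind tag and `Release()`, the smallest possible concrete handle (a sketch, not part of the commit; `example::TracingTensorHandle` is hypothetical) is just:

```cpp
#include "tensorflow/c/eager/abstract_tensor_handle.h"

namespace example {

// Sketch only: the minimal shape a tracing-mode handle could take.
class TracingTensorHandle : public tensorflow::AbstractTensorHandle {
 public:
  TracingTensorHandle() : AbstractTensorHandle(kTracing) {}

  // Lifetime is managed by the object itself, so deletion goes through
  // Release() rather than a public destructor.
  void Release() override { delete this; }

 protected:
  ~TracingTensorHandle() override {}
};

}  // namespace example
```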
@ -21,6 +21,8 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
|
||||
// clang-format off
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
@ -31,8 +33,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
@ -713,8 +715,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
status->status = tfrt::ListOpHandlerChains(
|
||||
opts->session_options.options, &op_handler_chains, &device_attributes);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return tensorflow::wrap(
|
||||
new tfrt::ContextInterface(op_handler_chains, device_attributes));
|
||||
return tensorflow::wrap(new tfrt::ContextInterface(
|
||||
op_handler_chains, device_attributes, opts->async));
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
tensorflow::AbstractOperationInterface* new_op =
|
||||
tensorflow::ImmediateExecutionOperation* new_op =
|
||||
tensorflow::unwrap(ctx)->CreateOperation();
|
||||
status->status = new_op->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
||||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
status->status = tensorflow::unwrap(op)->AddInputList(
|
||||
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
|
||||
{reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
||||
tensorflow::unwrap(inputs)),
|
||||
static_cast<size_t>(num_inputs)});
|
||||
}
|
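The `reinterpret_cast` above exists because `unwrap(inputs)` presumably yields an array of `ImmediateExecutionTensorHandle*` while `AddInputList` now takes a span of base-class `AbstractTensorHandle*` pointers, and arrays of pointers are not covariant in C++. A standalone sketch of the same trick (not part of the commit; `Base`, `Derived`, and `Forward` are hypothetical):

```cpp
#include <cstddef>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

// Consumer that only understands base-class pointers.
void UseBasePointers(Base** items, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    (void)items[i];  // e.g. read through the Base interface
  }
}

// Derived** does not implicitly convert to Base**, so the outer pointer is
// reinterpret_cast instead of copying the array; this relies on Derived
// using single inheritance so Derived* and Base* share the same address.
void Forward(Derived** items, std::size_t n) {
  UseBasePointers(reinterpret_cast<Base**>(items), n);
}
```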
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
|
||||
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op** value, int num_values) {
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
|
||||
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
|
||||
attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
|
||||
tensorflow::unwrap(value)),
|
||||
static_cast<size_t>(num_values)});
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
||||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
status->status = tensorflow::unwrap(op)->Execute(
|
||||
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
|
||||
absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
||||
tensorflow::unwrap(retvals)),
|
||||
*num_retvals),
|
||||
num_retvals);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
@ -1397,23 +1406,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
||||
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
|
||||
return;
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function_def);
|
||||
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
|
||||
}
|
||||
|
||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function->fdef);
|
||||
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
|
||||
}
|
||||
|
||||
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->RemoveFunction(name);
|
||||
status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
@ -1479,14 +1482,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
tensorflow::AttrValueMap m;
|
||||
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(op));
|
||||
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
|
||||
for (const auto& attribute : m) {
|
||||
destination->Set(attribute.first, attribute.second);
|
||||
}
|
||||
destination->CopyAttributes(*tensorflow::unwrap(attrs));
|
||||
}
|
||||
|
||||
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
||||
|
@ -38,7 +38,7 @@ using tensorflow::string;
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
tensorflow::AbstractOperationInterface* op =
|
||||
tensorflow::ImmediateExecutionOperation* op =
|
||||
tensorflow::unwrap(op_to_reset);
|
||||
op->Clear();
|
||||
status->status = op->Reset(op_or_function_name, raw_device_name);
|
||||
|
@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
|
||||
TFE_DeleteCancellationManager(c_mgr);
|
||||
}
|
||||
|
||||
TEST(CAPI, ExecutorContextDestructionOrder) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
{
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
|
||||
{
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, Function_ident_CPU) {
|
||||
// First create a simple identity function.
|
||||
TF_Graph* function_graph = TF_NewGraph();
|
||||
|
@ -12,17 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
@ -34,16 +35,9 @@ namespace tensorflow {
|
||||
//
|
||||
// A context is responsible for creating key objects such as Tensors,
|
||||
// TensorHandles & Operations.
|
||||
class AbstractContextInterface {
|
||||
class ImmediateExecutionContext : public AbstractContext {
|
||||
public:
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage its own
|
||||
// lifetime through ref counting. Thus clients MUST call Release() in order to
|
||||
// destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
static constexpr AbstractContextKind kKind = kImmediateExecution;
|
||||
// Optimized scalar creation functions
|
||||
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
|
||||
@ -74,21 +68,20 @@ class AbstractContextInterface {
|
||||
void* memory_releaser_arg) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
// Copy the handle to another device.
|
||||
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
|
||||
AbstractTensorHandleInterface* handle, const char* device_name,
|
||||
virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
|
||||
ImmediateExecutionTensorHandle* handle, const char* device_name,
|
||||
Status* status) = 0;
|
||||
|
||||
// Create an operation to perform op execution
|
||||
virtual AbstractOperationInterface* CreateOperation() = 0;
|
||||
ImmediateExecutionOperation* CreateOperation() override = 0;
|
||||
|
||||
// Load a SavedModelAPI object from the given directory and tags
|
||||
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
tensorflow::Status* status) = 0;
|
||||
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
|
||||
// Runtime. This is necessary to decouple runtime-dependent
|
||||
// code that is layered on top of the runtime.
|
||||
virtual bool UsesTFRT() = 0;
|
||||
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
@ -104,10 +97,16 @@ class AbstractContextInterface {
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
// Add a function (serialized FunctionDef protocol buffer) so that it can
|
||||
// be executed as an op. Return error if the function with the same name
|
||||
// already exists.
|
||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
ImmediateExecutionContext() : AbstractContext(kKind) {}
|
||||
~ImmediateExecutionContext() override {}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
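With the `kKind` constant above, code holding only an `AbstractContext*` can recover the immediate-execution API; a hedged sketch (not part of the commit; `AsImmediateContext` is hypothetical):

```cpp
#include "tensorflow/c/eager/immediate_execution_context.h"

// Sketch only: kind-checked downcast from the generic context interface.
tensorflow::ImmediateExecutionContext* AsImmediateContext(
    tensorflow::AbstractContext* ctx) {
  if (ctx->getKind() == tensorflow::ImmediateExecutionContext::kKind) {
    return static_cast<tensorflow::ImmediateExecutionContext*>(ctx);
  }
  return nullptr;  // Tracing context: no immediate-execution API available.
}
```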
tensorflow/c/eager/immediate_execution_operation.h (new file, 53 lines)
@ -0,0 +1,53 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class ImmediateExecutionOperation : public AbstractOperation {
|
||||
public:
|
||||
static constexpr AbstractOperationKind kKind = kImmediateExecution;
|
||||
virtual void Clear() = 0;
|
||||
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
protected:
|
||||
ImmediateExecutionOperation() : AbstractOperation(kKind) {}
|
||||
~ImmediateExecutionOperation() override {}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
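The `Clear()` method kept on this subclass is what makes the op-reuse path in `TFE_OpReset` (earlier in this diff) possible; a small sketch (not part of the commit; `ReuseOp` is hypothetical):

```cpp
#include "tensorflow/c/eager/immediate_execution_operation.h"

// Sketch only: wipe an existing op object and re-target it instead of
// allocating a new one, mirroring what TFE_OpReset does.
tensorflow::Status ReuseOp(tensorflow::ImmediateExecutionOperation* op,
                           const char* op_name, const char* raw_device_name) {
  op->Clear();  // Drop inputs/attrs left over from the previous use.
  return op->Reset(op_name, raw_device_name);  // Point it at the next op.
}
```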
@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
@ -30,15 +31,9 @@ namespace tensorflow {
|
||||
// files. The interface lists the common functionality that must be provided by
|
||||
// any concrete implementation. However, in cases where the true concrete class
|
||||
// is needed a static_cast can be applied.
|
||||
class AbstractTensorHandleInterface {
|
||||
class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
||||
public:
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage its own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
static constexpr AbstractTensorHandleKind kKind = kImmediateExecution;
|
||||
|
||||
// Returns tensor dtype.
|
||||
virtual tensorflow::DataType DataType() const = 0;
|
||||
@ -57,12 +52,13 @@ class AbstractTensorHandleInterface {
|
||||
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
|
||||
|
||||
// Return a copy of the handle.
|
||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||
virtual ImmediateExecutionTensorHandle* Copy() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractTensorHandleInterface() {}
|
||||
ImmediateExecutionTensorHandle() : AbstractTensorHandle(kKind) {}
|
||||
~ImmediateExecutionTensorHandle() override {}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
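A short sketch of how the renamed handle type is typically consumed (not part of the commit; `ResolveOrNull` is hypothetical and error handling is simplified):

```cpp
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"

// Sketch only: materialize a handle into a concrete tensor, or nullptr on
// failure (e.g. the handle refers to data on a remote device).
tensorflow::AbstractTensorInterface* ResolveOrNull(
    tensorflow::ImmediateExecutionTensorHandle* handle) {
  tensorflow::Status status;
  tensorflow::AbstractTensorInterface* tensor = handle->Resolve(&status);
  return status.ok() ? tensor : nullptr;
}
```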
@ -40,6 +40,9 @@ using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
|
||||
using MaybeParallelTensorOwned =
|
||||
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
|
||||
|
||||
using MaybeParallelTensorUnowned =
|
||||
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
|
||||
|
||||
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
|
||||
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
|
||||
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
|
||||
@ -141,9 +144,32 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
}
|
||||
std::vector<ParallelTensor*> parallel_inputs;
|
||||
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
|
||||
parallel_inputs.reserve(inputs.size());
|
||||
implicitly_broadcast_tensors.reserve(inputs.size()); // not tight
|
||||
for (const auto& input : inputs) {
|
||||
if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
|
||||
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
|
||||
// to each parallel operation.
|
||||
//
|
||||
// TODO(allenl): There may be smarter ways to do this copy in some
|
||||
// cases, i.e. with a collective broadcast. We'll need to be careful
|
||||
// about things that are taken as inputs on the host or on their
|
||||
// existing device (for multi-device functions).
|
||||
std::unique_ptr<ParallelTensor> parallel_tensor(
|
||||
parallel_device.CopyToParallelDevice(
|
||||
context, absl::get<TFE_TensorHandle*>(input), status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
parallel_inputs.push_back(parallel_tensor.get());
|
||||
implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
|
||||
} else {
|
||||
parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
|
||||
}
|
||||
}
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
maybe_parallel_results(
|
||||
parallel_device.Execute(context, std::move(inputs), operation_name,
|
||||
parallel_device.Execute(context, parallel_inputs, operation_name,
|
||||
attributes, expected_max_outputs, status));
|
||||
if (!maybe_parallel_results.has_value()) return result;
|
||||
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
|
||||
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace parallel_device {
|
||||
@ -28,21 +30,216 @@ class OpDeleter {
|
||||
|
||||
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
|
||||
|
||||
// Creates a vector of `count` new executors (threads).
|
||||
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
|
||||
std::vector<ExecutorPtr> executors;
|
||||
executors.reserve(count);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
|
||||
class StatusDeleter {
|
||||
public:
|
||||
void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
|
||||
};
|
||||
|
||||
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
|
||||
|
||||
class ExecutorDeleter {
|
||||
public:
|
||||
void operator()(TFE_Executor* to_delete) const {
|
||||
TFE_DeleteExecutor(to_delete);
|
||||
}
|
||||
return executors;
|
||||
}
|
||||
};
|
||||
|
||||
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Allows a single op at a time to be launched without blocking.
|
||||
//
|
||||
// DeviceThread itself is thread-safe, in that StartExecute will block if there
|
||||
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
|
||||
// multiple DeviceThreads should always be accessed in the same order to avoid
|
||||
// deadlocks.
|
||||
class DeviceThread {
|
||||
public:
|
||||
// Starts a background thread waiting for `StartExecute`.
|
||||
explicit DeviceThread(const std::string& device)
|
||||
: status_(TF_NewStatus()),
|
||||
device_(device),
|
||||
// If the context's default executor is set to async, re-using that in
|
||||
// each thread would cause collectives to deadlock. For consistency we
|
||||
// create a new sync executor for every thread.
|
||||
//
|
||||
// TODO(allenl): We should have an async API that works with the
|
||||
// parallel device.
|
||||
executor_(TFE_NewExecutor(/*is_async=*/false)),
|
||||
op_(nullptr),
|
||||
thread_(tensorflow::Env::Default()->StartThread(
|
||||
tensorflow::ThreadOptions(), "parallel_device_execute",
|
||||
std::bind(&DeviceThread::Run, this))) {}
|
||||
~DeviceThread();
|
||||
|
||||
// Requests that the worker thread execute the specified operation. Blocks
|
||||
// until the previously pending operation (a StartExecute without a Join) has
|
||||
// finished, if any.
|
||||
void StartExecute(TFE_Context* context, const char* operation_name,
|
||||
std::vector<TFE_TensorHandle*> inputs,
|
||||
const TFE_OpAttrs* attributes, int expected_max_outputs);
|
||||
// Block until the previous `StartExecute` operation has executed. Forwards
|
||||
// the status from `TFE_Execute` and returns outputs if the status is OK.
|
||||
std::vector<TensorHandlePtr> Join(TF_Status* status);
|
||||
|
||||
private:
|
||||
void Run();
|
||||
|
||||
void Execute(TFE_Context* context, const char* operation_name,
|
||||
std::vector<TFE_TensorHandle*> inputs,
|
||||
const TFE_OpAttrs* attributes, int expected_max_outputs,
|
||||
std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);
|
||||
|
||||
enum class ExecutionState {
|
||||
kReadyToExecute,
|
||||
kHasResult,
|
||||
kIdle,
|
||||
kShuttingDown,
|
||||
};
|
||||
|
||||
tensorflow::mutex execution_mutex_;
|
||||
ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
|
||||
ExecutionState::kIdle;
|
||||
// Tells the worker thread that there is new work.
|
||||
tensorflow::condition_variable start_execute_;
|
||||
// The worker thread notifies that work has finished.
|
||||
tensorflow::condition_variable finished_execute_;
|
||||
// Notifies a StartExecute that the previous Join has finished.
|
||||
tensorflow::condition_variable finished_join_;
|
||||
|
||||
// Temporary state between `StartExecute` and `Join`.
|
||||
// Inputs
|
||||
TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
|
||||
const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
|
||||
std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
|
||||
const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
|
||||
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
|
||||
// Outputs
|
||||
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
|
||||
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
|
||||
|
||||
const std::string device_;
|
||||
ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
|
||||
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
|
||||
std::unique_ptr<Thread> thread_;
|
||||
};
|
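Given the locking described above, callers are expected to fan work out to every `DeviceThread` before joining any of them, always in the same order; a sketch of that calling pattern (not part of the commit; `RunOnAllDevices` and its parameters are hypothetical, and it assumes the surrounding file's includes and aliases such as `TensorHandlePtr`):

```cpp
// Sketch only: mirrors the two-phase StartExecute/Join loops that
// ParallelDevice::Execute uses further down in this file.
void RunOnAllDevices(
    const std::vector<std::unique_ptr<DeviceThread>>& threads,
    TFE_Context* context, const char* operation_name,
    const TFE_OpAttrs* attributes,
    const std::vector<std::vector<TFE_TensorHandle*>>& per_device_inputs,
    int expected_max_outputs, TF_Status* status) {
  // Phase 1: launch asynchronously on every device thread.
  for (size_t i = 0; i < threads.size(); ++i) {
    threads[i]->StartExecute(context, operation_name, per_device_inputs[i],
                             attributes, expected_max_outputs);
  }
  // Phase 2: collect results in the same fixed order to avoid deadlocks.
  for (size_t i = 0; i < threads.size(); ++i) {
    std::vector<TensorHandlePtr> outputs = threads[i]->Join(status);
    if (TF_GetCode(status) != TF_OK) return;
    // ... consume `outputs` for device i ...
  }
}
```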
||||
|
||||
DeviceThread::~DeviceThread() {
|
||||
{
|
||||
tensorflow::mutex_lock l(execution_mutex_);
|
||||
execution_state_ = ExecutionState::kShuttingDown;
|
||||
}
|
||||
start_execute_.notify_one();
|
||||
}
|
||||
|
||||
void DeviceThread::Run() {
|
||||
while (true) {
|
||||
{
|
||||
tensorflow::mutex_lock l(execution_mutex_);
|
||||
while (execution_state_ == ExecutionState::kIdle ||
|
||||
execution_state_ == ExecutionState::kHasResult) {
|
||||
start_execute_.wait(l);
|
||||
}
|
||||
if (execution_state_ == ExecutionState::kShuttingDown) {
|
||||
return;
|
||||
} else if (execution_state_ == ExecutionState::kReadyToExecute) {
|
||||
// op_outputs_ may have been std::moved
|
||||
op_outputs_ = std::vector<TensorHandlePtr>();
|
||||
Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
|
||||
expected_max_outputs_, &op_outputs_, status_.get());
|
||||
execution_state_ = ExecutionState::kHasResult;
|
||||
}
|
||||
}
|
||||
finished_execute_.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceThread::StartExecute(TFE_Context* context,
|
||||
const char* operation_name,
|
||||
std::vector<TFE_TensorHandle*> inputs,
|
||||
const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs) {
|
||||
{
|
||||
tensorflow::mutex_lock l(execution_mutex_);
|
||||
while (execution_state_ != ExecutionState::kIdle) {
|
||||
// If there's already a pending execution, wait until Join finishes before
|
||||
// starting on the next operation.
|
||||
finished_join_.wait(l);
|
||||
}
|
||||
context_ = context;
|
||||
operation_name_ = operation_name;
|
||||
op_inputs_ = inputs;
|
||||
attributes_ = attributes;
|
||||
expected_max_outputs_ = expected_max_outputs;
|
||||
execution_state_ = ExecutionState::kReadyToExecute;
|
||||
}
|
||||
start_execute_.notify_one();
|
||||
}
|
||||
|
||||
std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
|
||||
std::vector<TensorHandlePtr> result;
|
||||
{
|
||||
tensorflow::mutex_lock l(execution_mutex_);
|
||||
while (execution_state_ != ExecutionState::kHasResult) {
|
||||
finished_execute_.wait(l);
|
||||
}
|
||||
if (TF_GetCode(status_.get()) != TF_OK) {
|
||||
TF_SetStatus(status, TF_GetCode(status_.get()),
|
||||
TF_Message(status_.get()));
|
||||
}
|
||||
execution_state_ = ExecutionState::kIdle;
|
||||
result = std::move(op_outputs_);
|
||||
}
|
||||
finished_join_.notify_one();
|
||||
return result;
|
||||
}
|
||||
|
||||
void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
|
||||
std::vector<TFE_TensorHandle*> inputs,
|
||||
const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs,
|
||||
std::vector<TensorHandlePtr>* outputs,
|
||||
TF_Status* status) const {
|
||||
if (op_ == nullptr) {
|
||||
TFE_ContextSetExecutorForThread(context, executor_.get());
|
||||
op_.reset(TFE_NewOp(context, operation_name, status));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
} else {
|
||||
TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
TFE_OpAddAttrs(op_.get(), attributes);
|
||||
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
|
||||
TFE_OpAddInput(op_.get(), inputs[input_index], status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
|
||||
int real_num_outputs = expected_max_outputs;
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
unwrapped_results.resize(real_num_outputs);
|
||||
outputs->reserve(real_num_outputs);
|
||||
for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
|
||||
outputs->emplace_back(unwrapped_result);
|
||||
}
|
||||
}
|
||||
|
||||
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
|
||||
: underlying_devices_(devices),
|
||||
executors_(MakeExecutors(underlying_devices_.size())) {}
|
||||
: underlying_devices_(devices) {
|
||||
device_threads_.reserve(devices.size());
|
||||
for (int device_index = 0; device_index < devices.size(); ++device_index) {
|
||||
device_threads_.emplace_back(
|
||||
new DeviceThread(devices[device_index].c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
// Necessary for a unique_ptr to a forward-declared type.
|
||||
ParallelDevice::~ParallelDevice() = default;
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
|
||||
@ -100,7 +297,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ParallelDevice::Execute(TFE_Context* context,
|
||||
std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int expected_max_outputs,
|
||||
TF_Status* status) const {
|
||||
@ -108,88 +305,34 @@ ParallelDevice::Execute(TFE_Context* context,
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
  // setting the thread-local executor like this.
  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
  auto reset_executor =
      tensorflow::gtl::MakeCleanup([context, previous_executor]() {
        TFE_ContextSetExecutorForThread(context, previous_executor);
        TFE_DeleteExecutor(previous_executor);
      });
  int first_op_output_count;
  int first_op_output_count = 0;
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // Note that the `reset_executor` cleanup sets the thread's executor back to
    // the value before this function ran.
    TFE_ContextSetExecutorForThread(context, executor);
    OpPtr op(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
                    status);
    TFE_OpAddAttrs(op.get(), attributes);
    DeviceThread* device_thread = device_threads_[device_index].get();
    std::vector<TFE_TensorHandle*> device_inputs;
    device_inputs.reserve(device_inputs.size());
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
        // to each parallel operation.
        //
        // TODO(allenl): There may be smarter ways to do this copy in some
        // cases, i.e. with a collective broadcast. We'll need to be careful
        // about things that are taken as inputs on the host or on their
        // existing device (for multi-device functions).
        TFE_OpAddInput(op.get(),
                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      } else {
        // Parallel tensors are divided between operations by device.
        TFE_OpAddInput(op.get(),
                       absl::get<ParallelTensor*>(inputs[input_index])
                           ->tensor(device_index),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      }
      // Parallel tensors are divided between operations by device.
      device_inputs.push_back(inputs[input_index]->tensor(device_index));
    }
    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
    int real_num_outputs = expected_max_outputs;
    // For nested devices, the inner device sees the async executor we've
    // set. Inner parallel devices will just overwrite this with their own and
    // then set it back to ours before returning. This means parallel devices
    // which consist of several aliased parallel devices would hypothetically
    // deadlock if the outer parallel device ran one collective with a group
    // size equal to the total number of aliased physical devices. Currently
    // physical devices cannot participate in a single collective reduction
    // multiple times, so this would fail earlier.
    //
    // TODO(allenl): Keep a map from outer executor to list of inner executors
    // rather than a single list of executors so aliased nested parallel devices
    // don't re-use an executor.
    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
    device_thread->StartExecute(context, operation_name,
                                std::move(device_inputs), attributes,
                                expected_max_outputs);
  }
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    if (TF_GetCode(status) != TF_OK) return result;
    if (device_index == 0) {
      first_op_output_count = real_num_outputs;
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
      if (real_num_outputs != first_op_output_count) {
      if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
    if (TF_GetCode(status) != TF_OK) return result;
    std::vector<TensorHandlePtr> this_outputs;
    this_outputs.reserve(real_num_outputs);
    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
      this_outputs.emplace_back(op_outputs[output_num]);
    }
    per_device_output_tensors.push_back(std::move(this_outputs));
  }
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // TODO(b/157523095): Syncing the executor here shouldn't be
    // necessary. Currently async+remote is missing cross-executor
    // coordination.
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    if (TF_GetCode(status) != TF_OK) return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
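In the hunk above, the old single loop (build an op per device and execute it inline on that device's executor) is replaced by a two-phase shape: a first loop hands each DeviceThread its inputs via StartExecute without blocking, and a second loop calls Join on every thread in device order to collect the per-device outputs. A minimal, self-contained sketch of that start-all/join-all pattern follows; all names here are illustrative and are not part of this commit.

// start_join_sketch.cc -- illustrative only; stands in for StartExecute/Join.
#include <string>
#include <thread>
#include <vector>

struct FakeResult {
  std::string device;  // stand-in for a per-device TensorHandle
};

int main() {
  std::vector<std::string> devices = {"CPU:0", "CPU:1"};
  std::vector<std::thread> workers;
  std::vector<FakeResult> results(devices.size());
  // Phase 1: start one worker per device; nothing blocks yet.
  for (size_t i = 0; i < devices.size(); ++i) {
    workers.emplace_back(
        [i, &devices, &results]() { results[i] = FakeResult{devices[i]}; });
  }
  // Phase 2: join in device order, collecting per-device outputs.
  for (std::thread& worker : workers) worker.join();
  return 0;
}

Starting every device's work before joining any of it is what lets blocking collective ops on different devices make progress against each other.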
@ -41,19 +41,8 @@ class TensorHandleDeleter {

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

class ParallelTensor;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
class DeviceThread;

// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
@ -61,6 +50,8 @@ class ParallelDevice {
 public:
  explicit ParallelDevice(const std::vector<std::string>& devices);

  ~ParallelDevice();

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
@ -79,10 +70,9 @@ class ParallelDevice {

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors (or
  // implicitly-mirrored tensors on other devices). Wraps the resulting
  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
  // output of the original operation.
  // its corresponding inputs from the input ParallelTensors. Wraps the
  // resulting per-device and per-output TFE_TensorHandles into one
  // ParallelTensor per output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
@ -90,7 +80,7 @@ class ParallelDevice {
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;

@ -98,9 +88,19 @@ class ParallelDevice {
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of TFE_Executors, one per device, for executing operations in
  // A sequence of thread wrappers, one per device, for executing operations in
  // parallel.
  const std::vector<ExecutorPtr> executors_;
  //
  // Conceptually this is a thread pool with one thread per device. It requires
  // less synchronization than a thread pool would for this task, since Execute
  // acquires each thread in order (and so only one Execute will schedule
  // blocking collective operations at a time), and avoids some dynamic
  // allocation/scheduling.
  //
  // TODO(allenl): Keep a map from outer thread to list of inner threads rather
  // than a single list of threads so aliased nested parallel devices don't
  // re-use a thread.
  std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
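The comment on device_threads_ above describes one long-lived thread per device rather than a general thread pool. A hedged, simplified sketch of such a worker, assuming a Start/Join protocol along the lines the comment describes, is shown next; this is not the commit's DeviceThread implementation, and all names are illustrative.

#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>

// One persistent worker; Start() hands it a unit of work, Join() waits for it.
class WorkerThread {
 public:
  WorkerThread() : thread_([this]() { Run(); }) {}
  ~WorkerThread() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      shutting_down_ = true;
    }
    cv_.notify_one();
    thread_.join();
  }
  // StartExecute-style call: non-blocking hand-off of one unit of work.
  void Start(std::function<void()> work) {
    std::lock_guard<std::mutex> lock(mu_);
    work_ = std::move(work);
    cv_.notify_one();
  }
  // Join-style call: block until the pending unit of work has finished.
  void Join() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this]() { return work_ == nullptr; });
  }

 private:
  void Run() {
    std::unique_lock<std::mutex> lock(mu_);
    while (true) {
      cv_.wait(lock, [this]() { return work_ != nullptr || shutting_down_; });
      if (shutting_down_) return;
      work_();  // a real implementation would drop the lock while running
      work_ = nullptr;
      cv_.notify_all();
    }
  }
  std::mutex mu_;
  std::condition_variable cv_;
  std::function<void()> work_;
  bool shutting_down_ = false;
  std::thread thread_;  // declared last so other members exist when it starts
};

Because Execute hands work to each thread and later joins them in device order, one mutex and condition variable per thread is enough; no shared work queue is needed.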
@ -407,11 +407,12 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
  return TensorHandlePtr(result_handle);
}

TEST(PARALLEL_DEVICE, TestCollective) {
void TestCollective(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  TFE_ContextOptionsSetAsync(opts.get(), async);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
@ -454,6 +455,12 @@ TEST(PARALLEL_DEVICE, TestCollective) {
  ExpectScalarEq<float>(result_components[1].get(), 3.);
}

TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }

// Note that ops on the parallel device currently don't execute
// asynchronously. The test is just that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }

void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {
@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/immediate_execution_context.h"

// Wraps a pointer to a context implementation.
//
@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context);

}  // namespace tensorflow

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"

// Wraps a pointer to an operation implementation.
//
@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*);

}  // namespace tensorflow

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"

// Wraps a pointer to a tensor handle implementation.
//
@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle,
                            TFE_TensorHandle);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*,
                            TFE_TensorHandle*);

}  // namespace tensorflow
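The three headers above switch the C++ type handed to DEFINE_CONVERSION_FUNCTIONS from the older Abstract* interfaces to the new ImmediateExecution* interfaces, so the opaque TFE_* structs now alias the latter. For readers unfamiliar with conversion_macros.h, a macro of this kind typically just generates reinterpret_cast-based wrap/unwrap helpers between the opaque C struct and the implementation type; roughly, as a hedged illustration that may differ from the exact TensorFlow macro:

// Illustrative stand-in for a conversion macro; not the real definition.
#define DEMO_CONVERSION_FUNCTIONS(cpp_impl, wrapper)              \
  inline cpp_impl* unwrap(wrapper* w) {                           \
    return reinterpret_cast<cpp_impl*>(w);                        \
  }                                                               \
  inline wrapper* wrap(cpp_impl* i) {                             \
    return reinterpret_cast<wrapper*>(i);                         \
  }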
@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/types.h"

struct TF_StringStream {
@ -146,6 +147,10 @@ TF_StringStream* TF_GetLocalTempDirectories() {
  return list;
}

char* TF_GetTempFileName(const char* extension) {
  return strdup(::tensorflow::io::GetTempFilename(extension).c_str());
}

TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) {
  return ::tensorflow::Env::Default()->NowNanos();
}

@ -152,6 +152,10 @@ TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename,
// The caller is responsible for freeing the list (see TF_StringStreamDone).
TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void);

// Creates a temporary file name with an extension.
// The caller is responsible for freeing the returned pointer.
TF_CAPI_EXPORT extern char* TF_GetTempFileName(const char* extension);

// Returns the number of nanoseconds since the Unix epoch.
TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void);

@ -26,5 +26,7 @@ cc_library(
    deps = [
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
        "@com_google_absl//absl/strings",
    ],
)
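The two declarations added above are thin wrappers over tensorflow::Env, matching the definitions in the .cc hunk. A hedged usage sketch (the tensorflow/c/env.h include path is assumed from context, and the caller frees the returned name as the header comment requires):

#include <cstdio>
#include <cstdlib>

#include "tensorflow/c/env.h"  // assumed location of these declarations

int main() {
  char* tmp_name = TF_GetTempFileName(".txt");
  std::printf("temp file: %s\n", tmp_name);
  std::free(tmp_name);  // caller-owned, per the header comment
  std::printf("ns since epoch: %llu\n",
              static_cast<unsigned long long>(TF_NowNanos()));
  return 0;
}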
@ -15,15 +15,66 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>

#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;

// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
    const google::cloud::Status& gcs_status, TF_Status* status) {
  TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
               gcs_status.message().c_str());
}

static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
                         char** bucket, char** object, TF_Status* status) {
  size_t scheme_end = fname.find("://") + 2;
  if (fname.substr(0, scheme_end + 1) != "gs://") {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "GCS path doesn't start with 'gs://'.");
    return;
  }

  size_t bucket_end = fname.find("/", scheme_end + 1);
  if (bucket_end == absl::string_view::npos) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "GCS path doesn't contain a bucket name.");
    return;
  }
  absl::string_view bucket_view =
      fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
  *bucket =
      static_cast<char*>(plugin_memory_allocate(bucket_view.length() + 1));
  memcpy(*bucket, bucket_view.data(), bucket_view.length());
  (*bucket)[bucket_view.length()] = '\0';

  absl::string_view object_view = fname.substr(bucket_end + 1);
  if (object_view.empty()) {
    if (object_empty_ok) {
      *object = nullptr;
      return;
    } else {
      TF_SetStatus(status, TF_INVALID_ARGUMENT,
                   "GCS path doesn't contain an object name.");
      return;
    }
  }
  *object =
      static_cast<char*>(plugin_memory_allocate(object_view.length() + 1));
  // object_view.data() is a null-terminated string_view because fname is.
  strcpy(*object, object_view.data());
}

// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -52,6 +103,20 @@ namespace tf_read_only_memory_region {
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {

// TODO(vnvo2409): Add lazy-loading and customizing parameters.
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
  google::cloud::StatusOr<gcs::Client> client =
      gcs::Client::CreateDefaultClient();
  if (!client) {
    TF_SetStatusFromGCSStatus(client.status(), status);
    return;
  }
  filesystem->plugin_filesystem = plugin_memory_allocate(sizeof(gcs::Client));
  auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
  (*gcs_client) = client.value();
  TF_SetStatus(status, TF_OK, "");
}

// TODO(vnvo2409): Implement later

}  // namespace tf_gcs_filesystem
@ -60,6 +125,10 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                        const char* uri) {
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);

  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
  ops->filesystem_ops->init = tf_gcs_filesystem::Init;
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
@ -69,4 +138,4 @@ void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  ProvideFilesystemSupportFor(&info->ops[0], "gs");
}
}
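ParseGCSPath above splits a gs://bucket/object URI into a separately allocated bucket string and object string, reporting TF_INVALID_ARGUMENT when the scheme, the bucket, or (unless object_empty_ok) the object is missing. A hedged usage sketch, as it might be written elsewhere in the same translation unit (the function is file-static, so this is illustrative only):

// Inside gcs_filesystem.cc, after the definitions above.
TF_Status* status = TF_NewStatus();
char* bucket = nullptr;
char* object = nullptr;
ParseGCSPath("gs://my-bucket/path/to/object.txt", /*object_empty_ok=*/false,
             &bucket, &object, status);
// On success: bucket == "my-bucket", object == "path/to/object.txt"; both were
// allocated with plugin_memory_allocate and are released by the caller.
plugin_memory_free(bucket);
plugin_memory_free(object);
TF_DeleteStatus(status);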
@ -23,8 +23,8 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":function_metadata",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
@ -57,6 +57,7 @@ cc_library(
|
||||
":concrete_function",
|
||||
":saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -15,12 +15,12 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
|
||||
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>&
|
||||
ConcreteFunction::GetCaptures() const {
|
||||
return captures_;
|
||||
}
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
|
||||
@ -38,15 +38,15 @@ class ConcreteFunction {
|
||||
virtual ~ConcreteFunction() = 0;
|
||||
|
||||
// This method returns the "Call" Op used to execute the function.
|
||||
virtual AbstractOperationInterface* GetCallOp() = 0;
|
||||
virtual ImmediateExecutionOperation* GetCallOp() = 0;
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
|
||||
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
|
||||
const;
|
||||
const FunctionMetadata& GetFunctionMetadata() const;
|
||||
|
||||
private:
|
||||
FunctionMetadata metadata_;
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
|
||||
std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
|
||||
FunctionDef* function_;
|
||||
};
|
||||
|
||||
|
96
tensorflow/c/experimental/saved_model/core/ops/BUILD
Normal file
@ -0,0 +1,96 @@
|
||||
# This package contains written convenience helpers for Eager Operations
|
||||
# used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these.
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
# Restricting visibility for now
|
||||
"//tensorflow/c/experimental/saved_model/core:__subpackages__",
|
||||
"//tensorflow/c/experimental/saved_model/internal:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_eager_op",
|
||||
hdrs = [
|
||||
"owned_eager_op.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_tensor_handle",
|
||||
hdrs = [
|
||||
"owned_tensor_handle.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_eager_context",
|
||||
hdrs = ["owned_eager_context.h"],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_tensor",
|
||||
hdrs = ["owned_tensor.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tensor_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable_ops",
|
||||
srcs = [
|
||||
"variable_ops.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"variable_ops.h",
|
||||
],
|
||||
deps = [
|
||||
":owned_eager_op",
|
||||
":owned_tensor_handle",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "variable_ops_test",
|
||||
srcs = [
|
||||
"variable_ops_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":owned_eager_context",
|
||||
":owned_tensor",
|
||||
":owned_tensor_handle",
|
||||
":variable_ops",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
],
|
||||
)
|
@ -0,0 +1,54 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct ImmediateExecutionContextDeleter {
|
||||
void operator()(ImmediateExecutionContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct EagerContextDeleter {
|
||||
void operator()(EagerContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractContextPtr =
|
||||
std::unique_ptr<ImmediateExecutionContext,
|
||||
internal::ImmediateExecutionContextDeleter>;
|
||||
|
||||
using EagerContextPtr =
|
||||
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct ImmediateExecutionOperationDeleter {
|
||||
void operator()(ImmediateExecutionOperation* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractOpPtr =
|
||||
std::unique_ptr<ImmediateExecutionOperation,
|
||||
internal::ImmediateExecutionOperationDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractTensorInterfaceDeleter {
|
||||
void operator()(AbstractTensorInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractTensorPtr =
|
||||
std::unique_ptr<AbstractTensorInterface,
|
||||
internal::AbstractTensorInterfaceDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
@ -0,0 +1,54 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct TensorHandleDeleter {
|
||||
void operator()(TensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct AbstractTensorHandleDeleter {
|
||||
void operator()(ImmediateExecutionTensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using TensorHandlePtr =
|
||||
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
|
||||
|
||||
using AbstractTensorHandlePtr =
|
||||
std::unique_ptr<ImmediateExecutionTensorHandle,
|
||||
internal::AbstractTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
111
tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc
Normal file
@ -0,0 +1,111 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
static const char kNoSharingResourceID[] =
|
||||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
|
||||
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle) {
|
||||
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
|
||||
|
||||
// Note that if shape is unknown rank, shape.dim_sizes() will be empty, and
|
||||
// shape.dims() will be -1.
|
||||
gtl::InlinedVector<int64, 4> dim_sizes = shape.dim_sizes();
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrShape(
|
||||
"shape", reinterpret_cast<const int64_t*>(dim_sizes.data()),
|
||||
shape.dims()));
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString("container", "", 0));
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
|
||||
"shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));
|
||||
|
||||
AbstractTensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Execute(
|
||||
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
|
||||
if (var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(var_handle));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status AssignVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, ImmediateExecutionTensorHandle* value) {
|
||||
AbstractOpPtr assign_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(value));
|
||||
|
||||
int num_retvals = 0;
|
||||
TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status ReadVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output) {
|
||||
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
|
||||
|
||||
AbstractTensorHandle* value = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
|
||||
if (value->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
output->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(value));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status DestroyResource(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* handle) {
|
||||
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
|
||||
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));
|
||||
|
||||
int num_retvals = 0;
|
||||
TF_RETURN_IF_ERROR(destroy_op->Execute({}, &num_retvals));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace internal {

// Executes a VarHandleOp using `ctx`, and fills `handle` with the DT_RESOURCE
// TensorHandle associated with the variable. This is equivalent to creating an
// uninitialized TF2 tf.Variable.
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
                                           DataType dtype, TensorShape shape,
                                           AbstractTensorHandlePtr* handle);

// Executes an AssignVariableOp using `ctx`, assigning the variable associated
// with `variable_handle` with `value`. `dtype` must be the datatype of the
// underlying variable for `variable_handle`. Note that it is illegal to assign
// a variable to a Tensor with a different dtype than what the variable was
// created with.
Status AssignVariable(ImmediateExecutionContext* ctx,
                      ImmediateExecutionTensorHandle* variable_handle,
                      DataType dtype, ImmediateExecutionTensorHandle* value);

// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
// value of `variable_handle` and copies the value to `output`. `dtype` must be
// the dtype of the variable associated with `variable_handle`.
Status ReadVariable(ImmediateExecutionContext* ctx,
                    ImmediateExecutionTensorHandle* variable_handle,
                    DataType dtype, AbstractTensorHandlePtr* output);

// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
Status DestroyResource(ImmediateExecutionContext* ctx,
                       ImmediateExecutionTensorHandle* handle);

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
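Taken together, the helpers declared above follow the tf.Variable lifecycle: create the DT_RESOURCE handle, assign a value, read it back, and destroy the resource. A hedged sketch of chaining them (ctx is assumed to be a valid ImmediateExecutionContext*; error handling reduced to TF_CHECK_OK); variable_ops_test.cc, added below, exercises the same flow end to end:

// Illustrative only; mirrors the sequence used by the test below.
tensorflow::AbstractTensorHandlePtr var;
TF_CHECK_OK(tensorflow::internal::CreateUninitializedResourceVariable(
    ctx, tensorflow::DT_FLOAT, {}, &var));
// ... AssignVariable / ReadVariable with a DT_FLOAT value here ...
TF_CHECK_OK(tensorflow::internal::DestroyResource(ctx, var.get()));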
@ -0,0 +1,107 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
|
||||
float value) {
|
||||
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
|
||||
AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
|
||||
return handle;
|
||||
}
|
||||
|
||||
class VariableOpsTest : public ::testing::Test {
|
||||
public:
|
||||
VariableOpsTest()
|
||||
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||
"CPU", {}, "/job:localhost/replica:0/task:0"))),
|
||||
ctx_(new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
||||
/* async= */ false,
|
||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
|
||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||
/* custom_kernel_creator= */ nullptr,
|
||||
/* cluster_flr= */ nullptr)) {}
|
||||
|
||||
EagerContext* context() { return ctx_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||
EagerContextPtr ctx_;
|
||||
};
|
||||
|
||||
// Sanity check for variable creation
|
||||
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr handle;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &handle));
|
||||
// The created TensorHandle should be a DT_Resource
|
||||
EXPECT_EQ(handle->DataType(), DT_RESOURCE);
|
||||
}
|
||||
|
||||
// Sanity check for variable destruction
|
||||
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr handle;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &handle));
|
||||
|
||||
// Destroy the variable
|
||||
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
|
||||
}
|
||||
|
||||
// Sanity check for handle assignment and reading
|
||||
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr variable;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &variable));
|
||||
|
||||
// Create a Scalar float TensorHandle with value 42, and assign it to
|
||||
// the variable.
|
||||
AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
|
||||
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
|
||||
my_value.get()));
|
||||
|
||||
// Read back the value from the variable, and check that it is 42.
|
||||
AbstractTensorHandlePtr read_value_handle;
|
||||
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
|
||||
&read_value_handle));
|
||||
Status status;
|
||||
AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status);
|
||||
EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
|
||||
Status TFSavedModelAPIImpl::Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out) {
|
||||
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
|
||||
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
|
||||
return errors::Unimplemented(
|
||||
"TFSavedModelAPIImpl loading is unimplemented currently");
|
||||
|
@ -23,14 +23,13 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TFSavedModelAPIImpl : public SavedModelAPI {
|
||||
public:
|
||||
TFSavedModelAPIImpl() = default;
|
||||
|
||||
Status GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) override;
|
||||
|
||||
@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
|
||||
static Status Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out);
|
||||
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);
|
||||
|
||||
std::vector<ConcreteFunction*> ListFunctions() override;
|
||||
|
||||
~TFSavedModelAPIImpl() override = default;
|
||||
|
||||
private:
|
||||
TFSavedModelAPIImpl() = default;
|
||||
std::vector<ConcreteFunction> functions_;
|
||||
};
|
||||
|
||||
|
@ -144,7 +144,9 @@ cc_library(
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
@ -176,7 +178,7 @@ cc_library(
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
@ -188,7 +190,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,11 +22,15 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
extern "C" {
|
||||
@ -34,10 +38,21 @@ extern "C" {
|
||||
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
std::string saved_model_dir(dirname);
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result;
|
||||
|
||||
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"TFRT SavedModel implementation will be added in the future");
|
||||
} else {
|
||||
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
|
||||
status->status = tensorflow::TFSavedModelAPIImpl::Load(
|
||||
dirname, absl::nullopt,
|
||||
tensorflow::down_cast<tensorflow::EagerContext*>(
|
||||
tensorflow::unwrap(ctx)),
|
||||
&saved_model);
|
||||
result = std::move(saved_model);
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
tagset.insert(std::string(tags[i]));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
|
||||
&status->status);
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result;
|
||||
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"TFRT SavedModel implementation will be added in the future");
|
||||
} else {
|
||||
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
|
||||
status->status = tensorflow::TFSavedModelAPIImpl::Load(
|
||||
dirname, tagset,
|
||||
tensorflow::down_cast<tensorflow::EagerContext*>(
|
||||
tensorflow::unwrap(ctx)),
|
||||
&saved_model);
|
||||
result = std::move(saved_model);
|
||||
}
|
||||
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to
|
||||
// change and should not be depended on.
|
||||
@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*>,
|
||||
std::vector<tensorflow::ImmediateExecutionTensorHandle*>,
|
||||
TF_TensorHandleList)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -106,6 +106,7 @@ cc_library(
|
||||
hdrs = ["loader.h"],
|
||||
deps = [
|
||||
":constants",
|
||||
":loader_util",
|
||||
":reader",
|
||||
] + if_not_mobile([
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -132,6 +133,17 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "loader_util",
|
||||
srcs = ["loader_util.cc"],
|
||||
hdrs = ["loader_util.h"],
|
||||
deps = [":constants"] + if_not_mobile([
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "bundle_v2_test",
|
||||
srcs = ["bundle_v2_test.cc"],
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -29,7 +30,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf_internal.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/protobuf/saver.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A SavedModel may store the name of the initialization op to run in the
|
||||
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
|
||||
// exists, then the collection must contain exactly one op.
|
||||
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
|
||||
string* init_op_name) {
|
||||
const auto& sig_def_map = meta_graph_def.signature_def();
|
||||
const auto& init_op_sig_it =
|
||||
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
|
||||
if (init_op_sig_it != sig_def_map.end()) {
|
||||
*init_op_name = init_op_sig_it->second.outputs()
|
||||
.find(kSavedModelInitOpSignatureKey)
|
||||
->second.name();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
string init_op_collection_key;
|
||||
if (collection_def_map.find(kSavedModelMainOpKey) !=
|
||||
collection_def_map.end()) {
|
||||
init_op_collection_key = kSavedModelMainOpKey;
|
||||
} else {
|
||||
init_op_collection_key = kSavedModelLegacyInitOpKey;
|
||||
}
|
||||
|
||||
const auto init_op_it = collection_def_map.find(init_op_collection_key);
|
||||
if (init_op_it != collection_def_map.end()) {
|
||||
if (init_op_it->second.node_list().value_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Expected exactly one main op in : ", export_dir));
|
||||
}
|
||||
*init_op_name = init_op_it->second.node_list().value(0);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
const StringPiece restore_op_name,
|
||||
const StringPiece variable_filename_const_op_name,
|
||||
@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
nullptr /* outputs */, &run_metadata, session);
|
||||
}
|
||||
|
||||
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
|
||||
std::vector<AssetFileDef>* asset_file_defs) {
|
||||
// With SavedModel v2, we write asset file def into metagraph instead of
|
||||
// collection, so read from metagraph first.
|
||||
if (meta_graph_def.asset_file_def_size() > 0) {
|
||||
for (const auto& asset : meta_graph_def.asset_file_def()) {
|
||||
asset_file_defs->push_back(asset);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Fall back to read from collection to be backward compatible with v1.
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
|
||||
if (assets_it == collection_def_map.end()) {
|
||||
return Status::OK();
|
||||
}
|
||||
const auto& any_assets = assets_it->second.any_list().value();
|
||||
for (const auto& any_asset : any_assets) {
|
||||
AssetFileDef asset_file_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
|
||||
asset_file_defs->push_back(asset_file_def);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadSavedModelDebugInfoIfPresent(
|
||||
const string& export_dir,
|
||||
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
|
||||
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
|
||||
|
||||
std::vector<AssetFileDef> asset_file_defs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
|
||||
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunRestore(run_options, export_dir,
|
||||
bundle->meta_graph_def.saver_def().restore_op_name(),
|
||||
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
|
||||
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
|
||||
string init_op_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
|
||||
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
|
||||
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
|
||||
asset_file_defs, bundle->session.get(),
|
||||
init_op_name));
|
||||
|
90
tensorflow/cc/saved_model/loader_util.cc
Normal file
@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader_util.h"

#include <vector>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"

namespace tensorflow {
namespace internal {

// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection exists,
// then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
                 string* init_op_name) {
  const auto& sig_def_map = meta_graph_def.signature_def();
  const auto& init_op_sig_it =
      meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
  if (init_op_sig_it != sig_def_map.end()) {
    *init_op_name = init_op_sig_it->second.outputs()
                        .find(kSavedModelInitOpSignatureKey)
                        ->second.name();
    return Status::OK();
  }

  const auto& collection_def_map = meta_graph_def.collection_def();
  string init_op_collection_key;
  if (collection_def_map.find(kSavedModelMainOpKey) !=
      collection_def_map.end()) {
    init_op_collection_key = kSavedModelMainOpKey;
  } else {
    init_op_collection_key = kSavedModelLegacyInitOpKey;
  }

  const auto init_op_it = collection_def_map.find(init_op_collection_key);
  if (init_op_it != collection_def_map.end()) {
    if (init_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(
          strings::StrCat("Expected exactly one main op in : ", export_dir));
    }
    *init_op_name = init_op_it->second.node_list().value(0);
  }
  return Status::OK();
}

Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  // With SavedModel v2, we write asset file def into metagraph instead of
  // collection, so read from metagraph first.
  if (meta_graph_def.asset_file_def_size() > 0) {
    for (const auto& asset : meta_graph_def.asset_file_def()) {
      asset_file_defs->push_back(asset);
    }
    return Status::OK();
  }
  // Fall back to read from collection to be backward compatible with v1.
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
  if (assets_it == collection_def_map.end()) {
    return Status::OK();
  }
  const auto& any_assets = assets_it->second.any_list().value();
  for (const auto& any_asset : any_assets) {
    AssetFileDef asset_file_def;
    TF_RETURN_IF_ERROR(
        ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
    asset_file_defs->push_back(asset_file_def);
  }
  return Status::OK();
}

}  // namespace internal
}  // namespace tensorflow
39
tensorflow/cc/saved_model/loader_util.h
Normal file
@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_

#include <string>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace internal {

// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection exists,
// then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
                 string* init_op_name);

Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs);

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
|
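The new internal helpers are consumed by LoadSavedModelInternal as shown in the
loader.cc hunk above. A minimal caller sketch follows (the wrapper function name
is hypothetical; only the internal::GetInitOp and internal::GetAssetFileDefs
calls come from this change):

#include "tensorflow/cc/saved_model/loader_util.h"

// Collects the init op name and asset file defs for a loaded MetaGraphDef,
// mirroring the calls made from LoadSavedModelInternal.
Status CollectInitData(const string& export_dir,
                       const MetaGraphDef& meta_graph_def,
                       string* init_op_name,
                       std::vector<AssetFileDef>* asset_file_defs) {
  TF_RETURN_IF_ERROR(
      internal::GetInitOp(export_dir, meta_graph_def, init_op_name));
  TF_RETURN_IF_ERROR(
      internal::GetAssetFileDefs(meta_graph_def, asset_file_defs));
  return Status::OK();
}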
@ -67,13 +67,13 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@llvm-project//llvm:arm_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:target_base",
|
||||
"@llvm-project//llvm:x86_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
||||
"//tensorflow/core:regexp_internal",
|
||||
] + if_llvm_aarch64_available([
|
||||
"@llvm-project//llvm:aarch64_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
|
||||
]),
|
||||
)
|
||||
|
||||
@ -94,8 +94,8 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support", # fixdeps: keep
|
||||
"@llvm-project//llvm:x86_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:Support", # fixdeps: keep
|
||||
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
@ -109,12 +109,12 @@ cc_library(
|
||||
name = "llvm_targets",
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:arm_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:powerpc_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:target_base",
|
||||
"@llvm-project//llvm:x86_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
||||
] + if_llvm_aarch64_available([
|
||||
"@llvm-project//llvm:aarch64_target", # fixdeps: keep
|
||||
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
|
||||
]),
|
||||
)
|
||||
|
||||
@ -286,9 +286,9 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:ir",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:target_base",
|
||||
"@llvm-project//llvm:Core",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:Target",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s
|
||||
|
||||
# Checks the error message produced by tfcompile with mlir_component
|
||||
# Checks that source debug information is used in the output error message and
|
||||
|
@ -4,6 +4,7 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 1
|
||||
col: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -12,9 +13,11 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 3
|
||||
col: 1
|
||||
}
|
||||
file_line_cols: {
|
||||
line: 4
|
||||
col: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -23,6 +26,7 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 2
|
||||
col: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
|
||||
XlaDeviceFlags* device_flags;
|
||||
XlaOpsCommonFlags* ops_flags;
|
||||
IntroduceFloatingPointJitterPassFlags* jitter_flags;
|
||||
MlirCommonFlags* mlir_flags;
|
||||
|
||||
std::vector<Flag>* flag_list;
|
||||
absl::once_flag flags_init;
|
||||
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
|
||||
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
|
||||
jitter_flags->jitter_amount = 1e-5;
|
||||
|
||||
mlir_flags = new MlirCommonFlags;
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge = false;
|
||||
|
||||
auto setter_for_jitter_tensor_names = [](string sequence) {
|
||||
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
|
||||
return true;
|
||||
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
|
||||
Flag("tf_introduce_floating_point_jitter_amount",
|
||||
&jitter_flags->jitter_amount,
|
||||
"The amount of jitter to introduce. This amount is added to each "
|
||||
"element in the tensors named in `tensor_names.")});
|
||||
"element in the tensors named in `tensor_names."),
|
||||
|
||||
Flag("tf_mlir_enable_mlir_bridge",
|
||||
&mlir_flags->tf_mlir_enable_mlir_bridge,
|
||||
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
|
||||
|
||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
|
||||
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
|
||||
return *jitter_flags;
|
||||
}
|
||||
|
||||
MlirCommonFlags* GetMlirCommonFlags() {
|
||||
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||
return mlir_flags;
|
||||
}
|
||||
|
||||
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
|
||||
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||
|
@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
|
||||
std::vector<string> tensor_names;
|
||||
};
|
||||
|
||||
// Flags for common MLIR configurations.
|
||||
struct MlirCommonFlags {
|
||||
bool tf_mlir_enable_mlir_bridge;
|
||||
};
|
||||
|
||||
// Return a pointer to the DumpGraphFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
|
||||
const IntroduceFloatingPointJitterPassFlags&
|
||||
GetIntroduceFloatingPointJitterPassFlags();
|
||||
|
||||
MlirCommonFlags* GetMlirCommonFlags();
|
||||
|
||||
// Appends the flag definitions associated with
|
||||
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
|
||||
//
|
||||
|
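For orientation (a sketch, not part of the diff): like the other XLA flags
registered above, the new flag is parsed from the TF_XLA_FLAGS environment
variable, and C++ callers read it through the accessor added in this change.
The script name below is a placeholder.

// C++: query whether the MLIR bridge is enabled.
bool bridge_enabled = GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;

// Shell: enable it for a single run via the environment.
//   TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge=true python my_model.py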
@ -195,53 +195,46 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
}
|
||||
|
||||
void XlaComputationLaunchContext::PopulateInputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
|
||||
arg_ptrs_ =
|
||||
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
|
||||
|
||||
// Pass remaining parameters.
|
||||
const Tensor* t;
|
||||
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
|
||||
int arg_num = kernel->input_mapping[i];
|
||||
DCHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = kernel->xla_input_shapes[i];
|
||||
if (variables.count(arg_num)) {
|
||||
t = &(variables.at(arg_num).value);
|
||||
CHECK(t);
|
||||
} else {
|
||||
t = &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
}
|
||||
xla::TransferManager* transfer_manager =
|
||||
client_->backend().transfer_manager();
|
||||
for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
|
||||
int arg_num = compilation_result->input_mapping[i];
|
||||
CHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
|
||||
const Tensor* t = variables.count(arg_num)
|
||||
? &(variables.at(arg_num).value)
|
||||
: &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
CHECK(t);
|
||||
|
||||
if (use_multiple_streams_) {
|
||||
CHECK(stream) << "Must have a stream available when using XLA tensors!";
|
||||
CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
|
||||
<< "Must have a stream available when using XLA tensors!";
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor);
|
||||
xla_tensor->WaitForDefinitionEventOnStream(stream);
|
||||
xla_tensor->WaitForDefinitionEventOnStream(
|
||||
ctx->op_device_context()->stream());
|
||||
}
|
||||
|
||||
const xla::Shape on_device_shape =
|
||||
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
|
||||
if (on_device_shape.IsTuple()) {
|
||||
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
|
||||
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
|
||||
} else {
|
||||
CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
|
||||
on_device_shape))
|
||||
<< "On-device shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
|
||||
<< " not the same as on-host shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(shape);
|
||||
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||
shape, transfer_manager->HostShapeToDeviceShape(shape))) {
|
||||
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
|
||||
arg_buffers_.emplace_back(
|
||||
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
|
||||
client_->platform(), client_->default_device_ordinal());
|
||||
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = &arg_buffers_.back();
|
||||
} else {
|
||||
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
|
||||
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -370,13 +363,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Sets output `output_num` for `ctx` provided it is known at compile time.
|
||||
static Status SetOutputForConstant(
|
||||
OpKernelContext* ctx, se::Stream* stream,
|
||||
const XlaCompiler::CompilationResult* compilation_result, int output_num) {
|
||||
CHECK(compilation_result->outputs[output_num].is_constant);
|
||||
// Output is a constant.
|
||||
const Tensor& const_tensor =
|
||||
compilation_result->outputs[output_num].constant_value;
|
||||
Tensor* output_tensor;
|
||||
const size_t total_bytes = const_tensor.TotalBytes();
|
||||
if (stream && total_bytes > 0) {
|
||||
// Copy host -> device. (Empty tensors don't have backing buffers.)
|
||||
// Manually allocate memory using an XlaTensorBuffer so we can allocate
|
||||
// as much memory as the device requires (as given by
|
||||
// GetByteSizeRequirement). This avoids XlaTransferManager having to
|
||||
// reallocate the device buffer later.
|
||||
VLOG(1) << "Constant output tensor on device";
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
|
||||
Device* device = dynamic_cast<Device*>(ctx->device());
|
||||
if (device == nullptr) {
|
||||
return errors::Internal("DeviceBase was not a Device.");
|
||||
}
|
||||
ctx->op_device_context()->CopyCPUTensorToDevice(
|
||||
&const_tensor, device, output_tensor,
|
||||
[&](Status status) { TF_CHECK_OK(status); });
|
||||
|
||||
if (device->device_type() == DEVICE_GPU) {
|
||||
// The GPUDeviceContext enqueues the host->device transfer in a
|
||||
// separate stream from the main compute stream. We must ensure the
|
||||
// compute stream is synchronized with the host->device transfer
|
||||
// stream now otherwise we will create a race condition.
|
||||
auto* gpu_device_context =
|
||||
static_cast<GPUDeviceContext*>(ctx->op_device_context());
|
||||
gpu_device_context->stream()->ThenWaitFor(
|
||||
gpu_device_context->host_to_device_stream());
|
||||
}
|
||||
} else {
|
||||
// No copy required.
|
||||
ctx->set_output(output_num, const_tensor);
|
||||
output_tensor = ctx->mutable_output(output_num);
|
||||
}
|
||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
||||
xla_tensor->set_host_tensor(const_tensor);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a list of updated resource variables.
|
||||
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
int missing_ctx_input_prefix) {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
variable_infos.reserve(compilation_result->resource_updates.size());
|
||||
|
||||
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write =
|
||||
compilation_result->resource_updates[i];
|
||||
int actual_input_index = write.input_index - missing_ctx_input_prefix;
|
||||
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
|
||||
return errors::Internal("Invalid input index for variable write.");
|
||||
}
|
||||
|
||||
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
|
||||
// not a Tensor.
|
||||
Var* variable = nullptr;
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
|
||||
ctx, HandleFromInput(ctx, actual_input_index), &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, variable);
|
||||
}
|
||||
return variable_infos;
|
||||
}
|
||||
|
||||
Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
|
||||
// Computation output should always be a tuple.
|
||||
if (VLOG_IS_ON(2)) {
|
||||
@ -384,7 +458,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
VLOG(2) << "Result tuple shape (on device): "
|
||||
<< output.on_device_shape().DebugString();
|
||||
}
|
||||
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
|
||||
CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
|
||||
|
||||
// If the on-host-shape isn't a tuple, create a new single-element tuple
|
||||
// buffer with a nullptr root index table. This allows the code below to treat
|
||||
@ -413,86 +487,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
// Copy XLA results to the OpOutputList.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
if (kernel->outputs[i].is_constant) {
|
||||
// Output is a constant.
|
||||
const Tensor& const_tensor = kernel->outputs[i].constant_value;
|
||||
Tensor* output_tensor;
|
||||
const size_t total_bytes = const_tensor.TotalBytes();
|
||||
if (stream && total_bytes > 0) {
|
||||
// Copy host -> device. (Empty tensors don't have backing buffers.)
|
||||
// Manually allocate memory using an XlaTensorBuffer so we can allocate
|
||||
// as much memory as the device requires (as given by
|
||||
// GetByteSizeRequirement). This avoids XlaTransferManager having to
|
||||
// reallocate the device buffer later.
|
||||
VLOG(1) << "Constant output tensor on device";
|
||||
const TensorShape& shape = compilation_result->outputs[i].shape;
|
||||
const DataType& type = compilation_result->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
if (type == DT_VARIANT) {
|
||||
return errors::Unimplemented(
|
||||
"Support for TensorList crossing the XLA/TF boundary "
|
||||
"is not implemented");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
|
||||
|
||||
Device* device = dynamic_cast<Device*>(ctx->device());
|
||||
if (device == nullptr) {
|
||||
return errors::Internal("DeviceBase was not a Device.");
|
||||
}
|
||||
ctx->op_device_context()->CopyCPUTensorToDevice(
|
||||
&const_tensor, device, output_tensor,
|
||||
[&](Status status) { TF_CHECK_OK(status); });
|
||||
|
||||
if (device->device_type() == DEVICE_GPU) {
|
||||
// The GPUDeviceContext enqueues the host->device transfer in a
|
||||
// separate stream from the main compute stream. We must ensure the
|
||||
// compute stream is synchronized with the host->device transfer
|
||||
// stream now otherwise we will create a race condition.
|
||||
auto* gpu_device_context =
|
||||
static_cast<GPUDeviceContext*>(ctx->op_device_context());
|
||||
gpu_device_context->stream()->ThenWaitFor(
|
||||
gpu_device_context->host_to_device_stream());
|
||||
}
|
||||
} else {
|
||||
// No copy required.
|
||||
ctx->set_output(i, const_tensor);
|
||||
output_tensor = ctx->mutable_output(i);
|
||||
}
|
||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
||||
xla_tensor->set_host_tensor(const_tensor);
|
||||
}
|
||||
if (compilation_result->outputs[i].is_constant) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetOutputForConstant(ctx, stream, compilation_result, i));
|
||||
} else if (type == DT_RESOURCE) {
|
||||
int input_index =
|
||||
compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
|
||||
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
const TensorShape& shape = kernel->outputs[i].shape;
|
||||
const DataType& type = kernel->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
if (type == DT_RESOURCE) {
|
||||
int input_index =
|
||||
kernel->outputs[i].input_index - missing_ctx_input_prefix;
|
||||
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
DCHECK(output.buffer({output_num}).is_null())
|
||||
<< "Expected output buffer to be aliased, but it is not nil.";
|
||||
}
|
||||
if (allocate_xla_tensors_) {
|
||||
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
|
||||
input_output_alias, output_num, ctx, i, shape, &output,
|
||||
definition_event, stream, use_multiple_streams_));
|
||||
} else {
|
||||
if (type == DT_VARIANT) {
|
||||
return errors::Unimplemented(
|
||||
"Support for TensorList crossing the XLA/TF boundary "
|
||||
"is not implemented");
|
||||
}
|
||||
if (allocate_xla_tensors_) {
|
||||
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
|
||||
input_output_alias, output_num, ctx, i, shape, &output,
|
||||
definition_event, stream, use_multiple_streams_));
|
||||
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots,
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
compilation_result->input_mapping, resource_var_snapshots,
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
@ -502,34 +531,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
|
||||
// Apply variable updates, if any.
|
||||
VLOG(2) << "Applying variable updates";
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
variable_infos.reserve(kernel->resource_updates.size());
|
||||
|
||||
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
|
||||
int actual_input_index = write.input_index - missing_ctx_input_prefix;
|
||||
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
|
||||
return errors::Internal("Invalid input index for variable write.");
|
||||
}
|
||||
|
||||
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
|
||||
// not a Tensor.
|
||||
Var* variable = nullptr;
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
|
||||
ctx, HandleFromInput(ctx, actual_input_index), &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, variable);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<VariableInfo> variable_infos,
|
||||
GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
|
||||
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
|
||||
|
||||
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write =
|
||||
compilation_result->resource_updates[i];
|
||||
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
|
||||
return errors::Internal("Mismatched type in variable write");
|
||||
}
|
||||
@ -543,7 +552,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots, write.type,
|
||||
compilation_result->input_mapping, resource_var_snapshots, write.type,
|
||||
write.shape, buffer, allocator);
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
variable_infos[i].var()->is_initialized |= write.modified;
|
||||
|
@ -136,7 +136,7 @@ class XlaComputationLaunchContext {
|
||||
// input_mapping must be greater than or equal to `missing_ctx_input_prefix`
|
||||
// (in other words, no inputs actually required by the kernel can be missing).
|
||||
void PopulateInputs(OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* kernel,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix);
|
||||
|
||||
@ -148,10 +148,11 @@ class XlaComputationLaunchContext {
|
||||
// See jit/resource_operation_safety_analysis for details.
|
||||
//
|
||||
//
|
||||
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
|
||||
// missing and adjusts input indices accordingly.
|
||||
// Assumes that the first `missing_ctx_input_prefix` inputs to the
|
||||
// compilation_result are missing and adjusts input indices accordingly.
|
||||
Status PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots);
|
||||
|
@ -27,10 +27,6 @@ namespace tensorflow {
|
||||
return xla_tensor;
|
||||
}
|
||||
|
||||
/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
|
||||
return tensor.RefCountIsOne();
|
||||
}
|
||||
|
||||
/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
|
||||
const Tensor& tensor) {
|
||||
const XlaTensor* xla_tensor = FromTensor(&tensor);
|
||||
|
@ -39,8 +39,6 @@ class XlaTensor {
|
||||
// fails.
|
||||
static XlaTensor* FromTensor(const Tensor* tensor);
|
||||
|
||||
static bool RefCountIsOne(const Tensor& tensor);
|
||||
|
||||
// Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
|
||||
// which case the returned value is shaped_buffer()->root_buffer(), or a
|
||||
// normal Tensor in which case the returned value is
|
||||
@ -57,7 +55,7 @@ class XlaTensor {
|
||||
// manage the memory for these tensors a ShapedBuffer may be required.
|
||||
|
||||
// Return true if this XlaTensor contains a ShapedBuffer.
|
||||
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
|
||||
bool has_shaped_buffer() const { return shaped_buffer_.has_value(); }
|
||||
// Return the contained ShapedBuffer.
|
||||
// REQUIRES: has_shaped_buffer()
|
||||
const xla::ShapedBuffer& shaped_buffer() const {
|
||||
@ -70,8 +68,7 @@ class XlaTensor {
|
||||
}
|
||||
// Mutates the XlaTensor to set the ShapedBuffer.
|
||||
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
|
||||
shaped_buffer_ =
|
||||
absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
|
||||
shaped_buffer_ = std::move(shaped_buffer);
|
||||
}
|
||||
|
||||
// Some tensors on the device may have known values on the host. We use these
|
||||
@ -79,14 +76,12 @@ class XlaTensor {
|
||||
// host value already.
|
||||
|
||||
// Return true if this XlaTensor contains a host tensor.
|
||||
bool has_host_tensor() const { return host_tensor_ != nullptr; }
|
||||
bool has_host_tensor() const { return host_tensor_.has_value(); }
|
||||
// Return the contained host tensor.
|
||||
// REQUIRES: has_host_tensor()
|
||||
const Tensor& host_tensor() const { return *host_tensor_; }
|
||||
// Sets the contained host tensor.
|
||||
void set_host_tensor(const Tensor& tensor) {
|
||||
host_tensor_.reset(new Tensor(tensor));
|
||||
}
|
||||
void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
|
||||
|
||||
// Adds synchronization events to 'stream' that wait for this tensor to be
|
||||
// defined on 'stream'. Does nothing if the tensor is already defined on that
|
||||
@ -113,9 +108,9 @@ class XlaTensor {
|
||||
|
||||
private:
|
||||
// The optional contained ShapedBuffer.
|
||||
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
|
||||
absl::optional<xla::ScopedShapedBuffer> shaped_buffer_;
|
||||
// An optional host tensor value.
|
||||
std::unique_ptr<Tensor> host_tensor_;
|
||||
absl::optional<Tensor> host_tensor_;
|
||||
// An optional event that is triggered when the tensor's content has been
|
||||
// defined. If this event is nullptr, it is assumed that the tensor's content
|
||||
// is always defined.
|
||||
|
@ -30,7 +30,7 @@ cc_library(
|
||||
hdrs = ["op_or_arg_name_mapper.h"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
@ -42,7 +42,7 @@ cc_library(
|
||||
":init_mlir",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
@ -86,7 +86,7 @@ cc_library(
|
||||
hdrs = ["init_mlir.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -102,7 +102,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
@ -155,7 +155,7 @@ tf_cc_binary(
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
|
@ -225,7 +225,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
@ -253,7 +253,7 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
@ -272,7 +272,7 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
@ -289,7 +289,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
@ -304,7 +304,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
@ -314,6 +314,7 @@ tf_cc_test(
|
||||
cc_library(
|
||||
name = "tensorflow_lite_legalize_tf",
|
||||
srcs = [
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/dilated_conv.cc",
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
@ -357,7 +358,7 @@ cc_library(
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
@ -383,7 +384,7 @@ cc_library(
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
@ -416,7 +417,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
@ -441,7 +442,7 @@ cc_library(
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
@ -494,8 +495,8 @@ tf_native_cc_binary(
|
||||
"converter_gen.cc",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:tablegen",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TableGen",
|
||||
"@llvm-project//mlir:TableGen",
|
||||
],
|
||||
)
|
||||
@ -541,8 +542,8 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
"@llvm-project//llvm:analysis",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Analysis",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
@ -619,7 +620,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
@ -653,7 +654,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
@ -713,7 +714,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MlirTranslateMain",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
@ -743,7 +744,7 @@ cc_library(
|
||||
"tf_tfl_translate_cl.h",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -755,7 +756,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -780,7 +781,7 @@ tf_cc_binary(
|
||||
":tf_tfl_translate_cl_options",
|
||||
":tf_to_tfl_flatbuffer",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
# TODO(b/155809683): Link only necessary dialects.
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -805,7 +806,7 @@ tf_cc_binary(
|
||||
":flatbuffer_translate_lib",
|
||||
":flatbuffer_translate_registeration",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
# TODO(b/155809683): Link only necessary dialects.
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -874,7 +875,7 @@ cc_library(
|
||||
"//tensorflow/lite/tools/optimize:quantize_weights",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
@ -894,6 +895,6 @@ cc_library(
|
||||
"//tensorflow/lite/experimental/mlir:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
@ -32,7 +32,6 @@ struct PassConfig {
|
||||
lower_tensor_list_ops(false),
|
||||
trim_functions_whitelist({}),
|
||||
quant_specs(std::move(specs)),
|
||||
skip_control_dialect(false),
|
||||
form_clusters(false),
|
||||
unfold_batch_matmul(true),
|
||||
legalize_tf_while(true),
|
||||
@ -49,13 +48,8 @@ struct PassConfig {
|
||||
llvm::ArrayRef<std::string> trim_functions_whitelist;
|
||||
// All information about quantization.
|
||||
QuantizationSpecs quant_specs;
|
||||
// If `skip_control_dialect` is true, TF executor dialect is not converted to
|
||||
// TF control dialect prior to legalization to TF Lite.
|
||||
// TODO(b/142911013): Remove flag once control dialect is removed.
|
||||
bool skip_control_dialect;
|
||||
// If `form_clusters` is true (and `skip_control_dialect` is true), clusters
|
||||
// are formed by grouping consecutive ops of the same device, under a
|
||||
// `tf_device.launch` op.
|
||||
// If `form_clusters` is true, clusters are formed by grouping consecutive
// ops of the same device, under a `tf_device.launch` op.
|
||||
bool form_clusters;
|
||||
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
|
||||
// of tfl.fully_connected ops.
|
||||
|
@ -446,7 +446,7 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
||||
auto desc =
|
||||
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
|
||||
|
||||
// Emit a loop to check all the dynamic values in the pack.
|
||||
// Emit a loop to check all operands.
|
||||
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
|
||||
// Capitalize the first letter to match the function name
|
||||
valueKind.substr(0, 1).upper(), valueKind.substr(1),
|
||||
@ -455,14 +455,10 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
||||
os << " (void)v;\n"
|
||||
<< " if (!("
|
||||
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
||||
<< " if (failure_on_operand_type_mismatch) {\n"
|
||||
<< formatv(
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
"<< \" must be {1}, but got \" << v.getType();\n",
|
||||
valueKind, desc)
|
||||
<< " } else {\n"
|
||||
<< " return ::mlir::LogicalResult::Failure;\n"
|
||||
<< " }\n"
|
||||
<< " }\n" // if
|
||||
<< " ++index;\n"
|
||||
<< " }\n"; // for
|
||||
@ -487,8 +483,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
|
||||
"failure_on_operand_type_mismatch) {\n";
|
||||
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
|
||||
@ -525,11 +520,13 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
auto *val = trait.getDef().getValue("tflRuntimePredicate");
|
||||
if (!val) continue;
|
||||
|
||||
auto desc = trait.getDef().getValueAsString("tflRuntimeDescription");
|
||||
|
||||
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
|
||||
os << tgfmt(
|
||||
" if (!($0)) {\n "
|
||||
" return ::mlir::LogicalResult::Failure;\n }\n",
|
||||
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
|
||||
" if (!($0))\n"
|
||||
" return top.emitOpError(\"failed to verify that $1\");\n",
|
||||
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
|
||||
}
|
||||
os << " return top.verify();\n}\n";
|
||||
}
|
||||
|
@ -1406,22 +1406,67 @@ BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
|
||||
for (int j = 0; j < segments.size(); j++) {
|
||||
vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
|
||||
}
|
||||
auto array_segments =
|
||||
tflite::CreateInt32Vector(builder_,
|
||||
builder_.CreateVector(vector_segments))
|
||||
.Union();
|
||||
tflite::SparseIndexVector segments_type;
|
||||
BufferOffset<void> array_segments;
|
||||
// The segment array is sorted.
|
||||
// TODO(b/147449640): Clean this up with util functions.
|
||||
int max_of_segments = vector_segments[segments.size() - 1];
|
||||
if (max_of_segments <= UINT8_MAX) {
|
||||
segments_type = tflite::SparseIndexVector_Uint8Vector;
|
||||
std::vector<uint8_t> uint8_vector(vector_segments.begin(),
|
||||
vector_segments.end());
|
||||
array_segments = tflite::CreateUint8Vector(
|
||||
builder_, builder_.CreateVector(uint8_vector))
|
||||
.Union();
|
||||
} else if (max_of_segments <= UINT16_MAX) {
|
||||
segments_type = tflite::SparseIndexVector_Uint16Vector;
|
||||
std::vector<uint16_t> uint16_vector(vector_segments.begin(),
|
||||
vector_segments.end());
|
||||
array_segments = tflite::CreateUint16Vector(
|
||||
builder_, builder_.CreateVector(uint16_vector))
|
||||
.Union();
|
||||
} else {
|
||||
segments_type = tflite::SparseIndexVector_Int32Vector;
|
||||
array_segments = tflite::CreateInt32Vector(
|
||||
builder_, builder_.CreateVector(vector_segments))
|
||||
.Union();
|
||||
}
|
||||
|
||||
auto indices = dim_metadata.indices();
|
||||
std::vector<int> vector_indices(indices.size(), 0);
|
||||
int max_of_indices = 0;
|
||||
for (int j = 0; j < indices.size(); j++) {
|
||||
vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
|
||||
if (vector_indices[j] > max_of_indices) {
|
||||
max_of_indices = vector_indices[j];
|
||||
}
|
||||
}
|
||||
auto array_indices = tflite::CreateInt32Vector(
|
||||
builder_, builder_.CreateVector(vector_indices))
|
||||
.Union();
|
||||
tflite::SparseIndexVector indices_type;
|
||||
BufferOffset<void> array_indices;
|
||||
if (max_of_indices <= UINT8_MAX) {
|
||||
indices_type = tflite::SparseIndexVector_Uint8Vector;
|
||||
std::vector<uint8_t> uint8_vector(vector_indices.begin(),
|
||||
vector_indices.end());
|
||||
array_indices = tflite::CreateUint8Vector(
|
||||
builder_, builder_.CreateVector(uint8_vector))
|
||||
.Union();
|
||||
} else if (max_of_indices <= UINT16_MAX) {
|
||||
indices_type = tflite::SparseIndexVector_Uint16Vector;
|
||||
std::vector<uint16_t> uint16_vector(vector_indices.begin(),
|
||||
vector_indices.end());
|
||||
array_indices = tflite::CreateUint16Vector(
|
||||
builder_, builder_.CreateVector(uint16_vector))
|
||||
.Union();
|
||||
} else {
|
||||
indices_type = tflite::SparseIndexVector_Int32Vector;
|
||||
array_indices = tflite::CreateInt32Vector(
|
||||
builder_, builder_.CreateVector(vector_indices))
|
||||
.Union();
|
||||
}
|
||||
|
||||
fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
|
||||
builder_, tflite::DimensionType_SPARSE_CSR, 0,
|
||||
tflite::SparseIndexVector_Int32Vector, array_segments,
|
||||
tflite::SparseIndexVector_Int32Vector, array_indices);
|
||||
builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
|
||||
array_segments, indices_type, array_indices);
|
||||
}
|
||||
}
|
||||
|
||||
|
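The blocks above pick the narrowest flatbuffer index-vector encoding that can
hold the largest segment or index value. A condensed sketch of that decision
(the standalone helper is hypothetical; the enum values are the ones used
above):

// Returns the narrowest SparseIndexVector type able to represent max_value.
tflite::SparseIndexVector PickIndexVectorType(int max_value) {
  if (max_value <= UINT8_MAX) return tflite::SparseIndexVector_Uint8Vector;
  if (max_value <= UINT16_MAX) return tflite::SparseIndexVector_Uint16Vector;
  return tflite::SparseIndexVector_Int32Vector;
}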
@ -424,6 +424,10 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
|
||||
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
const std::vector<uint8_t>& buffer,
|
||||
OpBuilder builder, Location loc) {
|
||||
if (buffer.empty()) {
|
||||
return errors::InvalidArgument("Constant's buffer may not be empty");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
|
||||
/*shapeless_are_scalars=*/true,
|
||||
/*is_constant=*/true));
|
||||
@ -799,9 +803,17 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
|
||||
for (auto output : func_outputs) {
|
||||
bool is_constant = !is_op_output[output];
|
||||
const bool is_func_input = input_index_set.contains(output);
|
||||
bool is_constant = !is_op_output[output] && !is_func_input;
|
||||
// There are two cases where a tensor without a shape in the flatbuffer is
// treated as a scalar:
// 1. `is_constant` = true, meaning this tensor is created from a constant op.
// 2. `is_func_input` = true and `is_entry_point` = true, meaning this tensor
//    is a function input and the function input type is a scalar tensor.
|
||||
const bool shapeless_is_scalar =
|
||||
is_constant || (is_func_input && is_entry_point);
|
||||
auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
|
||||
/*shapeless_are_scalars=*/is_constant,
|
||||
shapeless_is_scalar,
|
||||
/*is_constant=*/is_constant);
|
||||
if (!type_or_err.ok()) {
|
||||
emitError(func_loc, "error reading return types")
|
||||
@ -856,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
subgraph, &builder, "outputs", func_outputs));
|
||||
}
|
||||
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
|
||||
} else {
|
||||
func.setVisibility(FuncOp::Visibility::Private);
|
||||
}
|
||||
|
||||
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
||||
|
@ -94,8 +94,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
||||
let methods = [
|
||||
StaticInterfaceMethod<
|
||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||
"LogicalResult", "VerifyTflRuntimeConstraints",
|
||||
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
|
||||
"LogicalResult", "VerifyTflRuntimeConstraints", (ins "Operation*":$op)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
@ -46,28 +46,183 @@ namespace mlir {
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
|
||||
namespace TFL {
|
||||
|
||||
// Returns true when the given two types have the same shape or broadcastable
|
||||
// shape within the given rank. If any given shapes are non-static, this method
|
||||
// returns true.
|
||||
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
|
||||
int max_bcast_rank) {
|
||||
// Ignore shape checking on the non-static shapes for model compatibility.
|
||||
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
|
||||
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
|
||||
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
|
||||
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
|
||||
// Returns true when the given operand arguments have the same shape or
|
||||
// broadcastable shape within the given rank. If any given shapes are
|
||||
// non-static and maximum rank is within the given rank, this method returns
|
||||
// true.
|
||||
bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
|
||||
if (indices.empty()) return true;
|
||||
|
||||
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
|
||||
return true;
|
||||
// First, check whether any of the inputs has an unknown rank.
|
||||
bool has_unknown_shape_input = false;
|
||||
bool has_same_shape = true;
|
||||
bool reach_first_known_shape = false;
|
||||
int64_t max_rank = -1;
|
||||
|
||||
ArrayRef<int64_t> pivot_shape;
|
||||
SmallVector<int64_t, 4> current_shape;
|
||||
SmallVector<int64_t, 4> result_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
|
||||
rhs_shaped_type.getShape(),
|
||||
result_shape)) {
|
||||
return false;
|
||||
|
||||
for (unsigned index : indices) {
|
||||
ShapedType shaped_type =
|
||||
op->getOperand(index).getType().dyn_cast<ShapedType>();
|
||||
if (!shaped_type || !shaped_type.hasRank()) {
|
||||
// Marks that we have an unknown rank input.
|
||||
has_unknown_shape_input = true;
|
||||
continue;
|
||||
}
|
||||
max_rank = std::max(max_rank, shaped_type.getRank());
|
||||
if (!shaped_type.hasStaticShape()) {
|
||||
// Marks that we have an unknown shape input.
|
||||
has_unknown_shape_input = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> shape = shaped_type.getShape();
|
||||
if (!reach_first_known_shape) {
|
||||
pivot_shape = shape;
|
||||
current_shape.assign(shape.begin(), shape.end());
|
||||
reach_first_known_shape = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!pivot_shape.equals(shape)) {
|
||||
has_same_shape = false;
|
||||
}
|
||||
// Checks if all the inputs are broadcastable since they have not all the
|
||||
// same shapes.
|
||||
if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
|
||||
result_shape)) {
|
||||
return false;
|
||||
}
|
||||
current_shape = result_shape;
|
||||
}
|
||||
return lhs_shaped_type.getRank() <= max_bcast_rank &&
|
||||
rhs_shaped_type.getRank() <= max_bcast_rank;
|
||||
|
||||
// Treat unknown-shape inputs as acceptable for model compatibility, unless a
// known rank is bigger than the allowed maximum broadcast rank.
if (has_unknown_shape_input) return max_rank <= max_bcast_rank;

// If all the shapes are known and the same, CPU kernels are able to handle
// inputs regardless of dimension size.
return has_same_shape || max_rank <= max_bcast_rank;
|
||||
}
|
||||
|
||||
// Return true when the given element_type is QI8.
|
||||
bool IsQI8Type(Type element_type) {
|
||||
auto quantized_type = element_type.dyn_cast<QuantizedType>();
|
||||
return quantized_type != nullptr &&
|
||||
quantized_type.getStorageTypeIntegralWidth() == 8 &&
|
||||
quantized_type.isSigned();
|
||||
}
|
||||
|
||||
// Return true when the given element_type is QUI8.
|
||||
bool IsQUI8Type(Type element_type) {
|
||||
auto quantized_type = element_type.dyn_cast<QuantizedType>();
|
||||
return quantized_type != nullptr &&
|
||||
quantized_type.getStorageTypeIntegralWidth() == 8 &&
|
||||
!quantized_type.isSigned();
|
||||
}
|
||||
|
||||
// Return true when the given element_type is QI16.
|
||||
bool IsQI16Type(Type element_type) {
|
||||
auto quantized_type = element_type.dyn_cast<QuantizedType>();
|
||||
return quantized_type != nullptr &&
|
||||
quantized_type.getStorageTypeIntegralWidth() == 16 &&
|
||||
quantized_type.isSigned();
|
||||
}
|
||||
|
||||
// Return true when the given element_type is I32.
|
||||
bool IsI32Type(Type element_type) {
|
||||
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
|
||||
}
|
||||
|
||||
// Return true if the given Add operation has the CPU kernel supported shapes.
|
||||
bool VerifyAddOpShapeConstraints(AddOp op) {
|
||||
auto element_type = getElementTypeOrSelf(op.output().getType());
|
||||
|
||||
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
|
||||
// which are broadcastable shapes up to five dimension or have same shapes.
|
||||
if (element_type.isF32() || IsQI8Type(element_type) ||
|
||||
IsQUI8Type(element_type)) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/5);
|
||||
}
|
||||
|
||||
// Allows I32 output when the operands have valid shapes, which are
|
||||
// broadcastable shapes up to four dimension or have same shapes.
|
||||
if (IsI32Type(element_type)) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/4);
|
||||
}
|
||||
|
||||
// Allows QI16 output when operands have the same shape.
|
||||
if (IsQI16Type(element_type)) {
|
||||
return succeeded(
|
||||
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Return true if the given Sub operation has the CPU kernel supported shapes.
bool VerifySubOpShapeConstraints(SubOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows F32, I32, QUI8, and QI16 outputs when the operands have valid
  // shapes, which are broadcastable shapes up to five dimensions or the same
  // shapes.
  if (element_type.isF32() || IsI32Type(element_type) ||
      IsQUI8Type(element_type) || IsQI16Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/5);
  }

  // Allows QI8 output when the operands have valid shapes, which are
  // broadcastable shapes up to four dimensions or the same shapes.
  if (IsQI8Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

// Return true if the given Mul operation has the CPU kernel supported shapes.
bool VerifyMulOpShapeConstraints(MulOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());

  // Allows QI8 and QUI8 outputs when the operands are broadcastable up to
  // five dimensions. If the lhs operand type is QI16, only operands with the
  // same shape are allowed.
  if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
    if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
      return succeeded(
          mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
    }
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/5);
  }

  // Allows F32 output when the operands have valid shapes, which are
  // broadcastable shapes up to five dimensions or the same shapes.
  if (element_type.isF32()) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/5);
  }

  // Allows I32 and QI16 outputs when the operands have valid shapes, which
  // are broadcastable shapes up to four dimensions or the same shapes.
  if (IsI32Type(element_type) || IsQI16Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/4);
  }
  return false;
}

//===----------------------------------------------------------------------===//
@ -1882,7 +2037,7 @@ static LogicalResult Verify(TransposeConvOp op) {

  auto expected_output_type =
      RankedTensorType::get(output_shape, output_type.getElementType());
  if (output_type != expected_output_type) {
  if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
    return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
                                        expected_output_type, output_type));
  }
@ -2004,7 +2159,8 @@ static LogicalResult Verify(TransposeOp op) {
  }
  auto expected_output_type =
      RankedTensorType::get(transposed_shape, input_type.getElementType());
  if (output_type != expected_output_type) {
  if (failed(
          mlir::verifyCompatibleShape(output_type, expected_output_type))) {
    return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
                                        expected_output_type, output_type));
  }

@ -123,14 +123,13 @@ class TFL_RuntimePredOpTrait<string desc, Pred pred> :
  string tflRuntimeDescription = desc;
}

class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
    int i, int j, int max_bcast_rank> :
  TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
      " have the same shape or broadcastable shapes within the rank " #
      max_bcast_rank,
    CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
      "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
      ").getType(), " # max_bcast_rank # ")">>;
class TFL_OperandsHaveSameShapesOrBroadcastableShape<
    list<int> indices, int max_bcast_rank> :
  TFL_RuntimePredOpTrait<"operands do not have the same shape or "
      "broadcastable shapes within the rank " # max_bcast_rank,
    CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape("
      "$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result #
      "}), " # max_bcast_rank # ")">>;

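For a concrete sense of what these TableGen traits turn into, the CPred string above for TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4> expands, roughly, into a call to the C++ helper defined earlier in this change. The wrapper below is a hypothetical illustration of the generated check, assuming the helper's declaration is visible through the usual TFLite dialect headers; the real glue code is emitted by TableGen.

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Operation.h"

// Hypothetical mirror of the generated trait check for operands #0 and #1
// with a maximum broadcast rank of 4.
bool Operands01BroadcastableUpToRank4(mlir::Operation *op) {
  return TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape(
      op, llvm::ArrayRef<unsigned>({0, 1}), /*max_bcast_rank=*/4);
}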
// These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining
@ -213,11 +212,20 @@ class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
        # dim>]>;

// Returns true if the n-th operand is ranked and has a dimension length = size
// at the rank dim.
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
  TFL_OperandIsRankedAndHasDimPred<n, dim>,
  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
        ".getShape()[" # dim # " ] == " # size>]>;

// Returns true if the n-th operand is ranked and has a dimension length <=
// size at the rank dim.
class TFL_OperandDimIsAtMost<int n, int dim, int size> : And<[
  TFL_OperandIsRankedAndHasDimPred<n, dim>,
  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
        ".getShape()[" # dim # " ] <= " # size>]>;

// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
@ -428,7 +436,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
    TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
                      TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>,
                      TFL_GpuTargetOp]> {
                      TFL_GpuTargetOp, TFL_SparseOp]> {
  let summary = opSummary # " operator";

  let description = [{
@ -463,6 +471,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
//===----------------------------------------------------------------------===//
def TFL_AbsOp : TFL_Op<"abs", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
@ -482,7 +491,8 @@ an output element, this operation computes \\(y = |x|\\).
}

def TFL_AddOp : TFL_Op<"add", [
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
      CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast<AddOp>($_op))">>,
    ResultsBroadcastableShape,
    NoSideEffect,
    Commutative,
@ -561,7 +571,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
    TFL_OperandHasRank<2, 4>,
    PredOpTrait<"input and output must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 2>>,
    TFL_GpuTargetOp]> {
    TFL_GpuTargetOp,
    TFL_SparseOp]> {
  let summary = "Transpose convolution operator";

  let description = [{
@ -583,6 +594,13 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
  let hasOptions = 1;

  let verifier = [{ return Verify(*this); }];

  let extraClassDeclaration = [{
    // SparseOpInterface:
    std::vector<int> GetSparseOperands() { return {1}; }
    std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
    std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
  }];
}

def TFL_AveragePool2DOp:
@ -669,7 +687,10 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
  }]>;
}

def TFL_CeilOp: TFL_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_CeilOp: TFL_Op<"ceil", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType]> {
  let summary = "Ceil operator";

  let description = [{
@ -813,11 +834,16 @@ def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
  let extraClassDeclaration = [{
    // ChannelDimIndexInterface:
    int GetChannelDimIndex() { return 0; }
    // SparseOpInterface:
    std::vector<int> GetSparseOperands() { return {1}; }
    std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
    std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
  }];
}

def TFL_CosOp: TFL_Op<"cos", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
@ -852,6 +878,10 @@ def TFL_DepthwiseConv2DOp :
  let extraClassDeclaration = [{
    // ChannelDimIndexInterface:
    int GetChannelDimIndex() { return 3; }
    // SparseOpInterface:
    std::vector<int> GetSparseOperands() { return {1}; }
    std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
    std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
  }];
}

@ -1021,7 +1051,7 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [
def TFL_LessEqualOp : TFL_Op<"less_equal", [
    ResultsBroadcastableShape,
    BinaryOpSameElementTypeConstraint,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoSideEffect,
    NoQuantizableResult]> {
  let summary = "Less_equal operator";
@ -1082,7 +1112,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
}

def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    ResultsBroadcastableShape,
    NoSideEffect,
    NoQuantizableResult]> {
@ -1150,12 +1180,12 @@ innermost matrices. These will be overwritten by the values in `diagonal`.
  }];

  let arguments = (ins
    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
  );

  let results = (outs
    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
  );

  let hasOptions = 0;
@ -1273,7 +1303,7 @@ larger than 0.
}

def TFL_NotEqualOp : TFL_Op<"not_equal", [
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    BinaryOpSameElementTypeConstraint,
    ResultsBroadcastableShape,
    Commutative,
@ -1309,7 +1339,7 @@ def TFL_DivOp : TFL_Op<"div", [
    // TODO(fengliuai): NoQuantizableResult is only correct for int8
    // quantization. Update to handle Uint8 quantization.
    BinaryOpSameElementTypeConstraint,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
    ResultsBroadcastableShape,
    NoSideEffect,
    NoQuantizableResult,
@ -1338,7 +1368,10 @@ def TFL_DivOp : TFL_Op<"div", [
  let hasFolder = 1;
}

def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_EluOp: TFL_Op<"elu", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType]> {
  let summary = "Exponential Linear Unit operator";
  let description = [{
    Computes the exponential linear
@ -1374,10 +1407,11 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
  let results = (outs TFL_TensorOf<[F32, I8, UI8]>:$output);
}

def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
def TFL_EqualOp: TFL_Op<"equal", [
    Commutative,
    NoQuantizableResult,
    ResultsBroadcastableShape,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
  let summary = "Equal operator";

@ -1516,7 +1550,10 @@ def TFL_FillOp: TFL_Op<"fill", [
  let hasOptions = 0;
}

def TFL_FloorOp: TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_FloorOp: TFL_Op<"floor", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType]> {
  let summary = "Floor operator";

  let description = [{
@ -1534,7 +1571,7 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [
    BinaryOpSameElementTypeConstraint,
    PredOpTrait<"lhs and output must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>]> {
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> {
  let summary = "Floor div operator";

  let description = [{
@ -1559,7 +1596,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
    BinaryOpSameElementTypeConstraint,
    PredOpTrait<"lhs and output must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>]> {
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> {
  let summary = "Division remainder";

  let description = [{
@ -1578,7 +1615,7 @@
def TFL_GreaterOp : TFL_Op<"greater", [
    ResultsBroadcastableShape,
    BinaryOpSameElementTypeConstraint,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoSideEffect,
    NoQuantizableResult]> {
  let summary = "Greater operator";
@ -1670,7 +1707,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
def TFL_LessOp : TFL_Op<"less", [
    ResultsBroadcastableShape,
    BinaryOpSameElementTypeConstraint,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoSideEffect,
    NoQuantizableResult]> {
  let summary = "Less operator";
@ -1710,7 +1747,10 @@ def TFL_LogicalAndOp : TFL_Op<"logical_and", [NoSideEffect]> {
  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}

def TFL_LogicalNotOp : TFL_Op<"logical_not", [NoSideEffect, NoQuantizableResult]> {
def TFL_LogicalNotOp : TFL_Op<"logical_not", [
    NoSideEffect,
    SameOperandsAndResultShape,
    NoQuantizableResult]> {
  let summary = "Logical NOT operator";

  let description = [{
@ -1794,6 +1834,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [

def TFL_LogOp: TFL_Op<"log", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
@ -1884,6 +1925,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
def TFL_MaximumOp : TFL_Op<"maximum", [
    ResultsBroadcastableShape,
    NoSideEffect,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
    Commutative,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
@ -2118,6 +2160,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
def TFL_MinimumOp : TFL_Op<"minimum", [
    ResultsBroadcastableShape,
    NoSideEffect,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
    Commutative,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
@ -2145,7 +2188,8 @@ def TFL_MulOp : TFL_Op<"mul", [
    NoSideEffect,
    Commutative,
    BinaryOpSameElementTypeConstraint,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
      CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast<MulOp>($_op))">>,
    TFL_GpuTargetOp]> {
  let summary = "Multiplication operator";

@ -2171,7 +2215,10 @@ def TFL_MulOp : TFL_Op<"mul", [
  let hasOptions = 1;
}

def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_NegOp: TFL_Op<"neg", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType]> {
  let summary = "Negation operator";

  let description = [{
@ -2247,6 +2294,9 @@ def TFL_PadOp : TFL_Op<"pad", [
    TFL_OperandHasRankAtMost<0, 4>,
    TFL_OperandHasRank<1, 2>,
    TFL_OperandRankEquals1DimOfOperand<0, 1>,
    PredOpTrait<"the first dim size of the padding argument must be at most 4",
      Or<[TFL_OperandIsUnrankedPred<1>,
          TFL_OperandDimIsAtMost<1, 0, 4>]>>,
    TFL_GpuTargetOp]> {
  let summary = "Padding operator";

@ -2292,6 +2342,9 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
    TFL_OperandHasRank<1, 2>,
    TFL_OperandHasRank<2, 0>,
    TFL_OperandRankEquals1DimOfOperand<0, 1>,
    PredOpTrait<"the first dim size of the padding argument must be at most 4",
      Or<[TFL_OperandIsUnrankedPred<1>,
          TFL_OperandDimIsAtMost<1, 0, 4>]>>,
    PredOpTrait<"input and constant value operands must have same element type",
                TFL_TCopVTEtAreSameAt<0, 2>>]> {
  let summary = "Padding operator v2";
@ -2333,10 +2386,12 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
  let hasOptions = 1;
}

def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
                               NoSideEffect,
                               NoQuantizableResult,
                               TFL_GpuTargetOp]> {
def TFL_PowOp : TFL_Op<"pow", [
    ResultsBroadcastableShape,
    NoSideEffect,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
  let summary = "Power operator";

  let description = [{
@ -2360,7 +2415,7 @@ def TFL_PReluOp : TFL_Op<"prelu", [
    NoSideEffect,
    ResultsBroadcastableShape,
    TFL_GpuTargetOp,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    BinaryOpSameElementTypeConstraint,
    PredOpTrait<"input and output must have the same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
@ -2671,8 +2726,9 @@ def TFL_SelectOp : TFL_Op<"select", [
}

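The TFL_OperandDimIsAtMost<1, 0, 4> constraint used by the pad ops above composes two predicates: operand #1 must be ranked with rank greater than the indexed dimension, and that dimension must be at most the given size (the full trait additionally accepts unranked operands via TFL_OperandIsUnrankedPred). A hedged C++ sketch of the check the generated verifier performs, for illustration only and assuming the usual MLIR headers; the real code is emitted by TableGen:

#include "mlir/IR/Operation.h"

// Sketch: operand #1 (the paddings tensor) must be ranked and its first
// dimension must be at most 4.
bool PaddingsFirstDimAtMost4(mlir::Operation *op) {
  auto type = op->getOperand(1).getType().dyn_cast<mlir::ShapedType>();
  return type && type.hasRank() && type.getRank() > 0 &&
         type.getShape()[0] <= 4;
}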
def TFL_SelectV2Op : TFL_Op<"select_v2", [
    ResultsBroadcastableShape,
    NoSideEffect,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<1, 2, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>,
    PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
    PredOpTrait<"operands and result have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
@ -2705,6 +2761,7 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [

def TFL_SinOp: TFL_Op<"sin", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
@ -2752,6 +2809,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [

def TFL_SqrtOp: TFL_Op<"sqrt", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
@ -2770,6 +2828,7 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [

def TFL_SquareOp: TFL_Op<"square", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
@ -2791,7 +2850,8 @@ def TFL_SquareOp: TFL_Op<"square", [
def TFL_SubOp : TFL_Op<"sub", [
    ResultsBroadcastableShape,
    BinaryOpSameElementTypeConstraint,
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
      CPred<"TFL::VerifySubOpShapeConstraints(llvm::cast<SubOp>($_op))">>,
    NoSideEffect]> {
  let summary = "Subtraction operator";

@ -2820,7 +2880,7 @@ def TFL_SubOp : TFL_Op<"sub", [
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    SameOperandsAndResultElementType,
    ResultsBroadcastableShape,
    NoSideEffect,
@ -3007,6 +3067,8 @@ def TFL_UnpackOp : TFL_Op<"unpack", [
def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [
    PredOpTrait<"input and output must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    SameOperandsAndResultType,
    SameOperandsAndResultShape,
    NoSideEffect]> {
  let summary = "ZerosLike operator";

@ -3319,7 +3381,9 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
}

def TFL_CastOp : TFL_Op<"cast", [
    NoSideEffect, SameOperandsAndResultShape, NoQuantizableResult]> {
    NoSideEffect,
    SameOperandsAndResultShape,
    NoQuantizableResult]> {
  let summary = "Cast operator";

  let description = [{

@ -27,7 +27,7 @@ cc_library(
        "//tensorflow/lite/toco:toco_flags_proto_cc",
        "//tensorflow/lite/toco:types_proto_cc",
        "//tensorflow/stream_executor/lib",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
@ -56,7 +56,7 @@ cc_library(
        "//tensorflow/lite/toco:toco_flags_proto_cc",
        "//tensorflow/lite/toco:types_proto_cc",
        "//tensorflow/stream_executor/lib",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
@ -85,7 +85,7 @@ cc_library(
        "//tensorflow/lite/toco:toco_flags_proto_cc",
        "//tensorflow/lite/toco:types_proto_cc",
        "//tensorflow/stream_executor/lib",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",

@ -80,7 +80,7 @@ cc_library(
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
@ -106,7 +106,7 @@ cc_library(
    deps = [
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
@ -125,7 +125,7 @@ cc_library(
    deps = [
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
    ],
)

@ -135,8 +135,8 @@ tf_native_cc_binary(
        "tools/op_quant_spec_getters_gen.cc",
    ],
    deps = [
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:tablegen",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:TableGen",
        "@llvm-project//mlir:TableGen",
    ],
)
@ -157,7 +157,7 @@ cc_library(
    deps = [
        ":numerical_utils",
        "@com_google_absl//absl/types:optional",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",
@ -172,7 +172,7 @@ cc_library(
        ":device_target",
        ":quantization_lib",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",

@ -76,7 +76,8 @@ class ImportQuantStatsPass
  // If the index is out of range, this method returns false. Otherwise it
  // returns true if the value is a float tensor.
  bool IsQuantizableResult(Operation *op, int index) {
    if (index < 0 || index >= op->getNumResults()) return false;
    if (index < 0 || index >= static_cast<int>(op->getNumResults()))
      return false;
    Value res = op->getResult(index);
    return res.getType().isa<ShapedType>() &&
           res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
@ -158,7 +159,7 @@ void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
    InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
                          axis);
  } else {
    for (int i = 0; i < op->getNumResults(); ++i) {
    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      if (IsQuantizableResult(op, i)) {
        InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
                              axis);

@ -36,7 +36,7 @@ cc_library(
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
    ],
@ -54,7 +54,7 @@ cc_library(
    deps = [
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
@ -73,7 +73,7 @@ tf_cc_binary(
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
    ],
)

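A recurring cleanup in this change is the loop-bound pattern used above and below: getNumResults(), getNumOperands(), and size() return unsigned types, so comparing them against a plain int induction variable triggers sign-compare warnings. The diff hoists the bound into a local of matching type once. A minimal sketch of the pattern outside MLIR (a hypothetical example, not taken from the patch):

#include <vector>

int CountPositive(const std::vector<int>& values) {
  int count = 0;
  // Hoist the size into a signed local once: the comparison is then int vs
  // int, and the bound is not recomputed on every iteration.
  for (int i = 0, e = static_cast<int>(values.size()); i < e; ++i) {
    if (values[i] > 0) ++count;
  }
  return count;
}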
@ -294,7 +294,7 @@ class QuantizationDriver {
      return;
    if (current_op == op) llvm::errs() << "===>>>";
    llvm::errs() << op->getName() << " : (";
    for (auto i = 0; i < op->getNumOperands(); ++i) {
    for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
      if (auto params = GetOperandQuantState(op, i).params)
        params.print(llvm::errs());
      else
@ -303,7 +303,7 @@ class QuantizationDriver {
      llvm::errs() << ",";
    }
    llvm::errs() << ") -> (";
    for (auto i = 0; i < op->getNumResults(); ++i) {
    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      if (auto params = GetResultQuantState(op, i).params)
        params.print(llvm::errs());
      else

@ -55,7 +55,7 @@ static Type GetQuantizedType(Builder builder, Type input_type,
  } else if (min.size() == max.size()) {
    auto shape = input_type.dyn_cast<ShapedType>();
    if (!shape || shape.getRank() <= quant_dim ||
        min.size() != shape.getDimSize(quant_dim)) {
        static_cast<int64_t>(min.size()) != shape.getDimSize(quant_dim)) {
      return {};
    }
    // TODO(b/141508873): the quantization dim is set to the last dimension.
@ -76,7 +76,8 @@ TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
  if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    ArrayRef<double> scales = qtype.getScales();
    // Broadcasting hasn't been implemented yet.
    if (scales.size() != factor_values.getNumElements()) return {};
    if (static_cast<int64_t>(scales.size()) != factor_values.getNumElements())
      return {};
    SmallVector<double, 4> new_scales;
    new_scales.reserve(scales.size());
    auto scales_iter = scales.begin();
@ -270,7 +271,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
                                             bool narrow_range) {
  Builder builder(attr.getContext());
  auto shape = attr.getType().cast<ShapedType>().getShape();
  if (shape.size() <= quant_dim) return {};
  if (static_cast<int>(shape.size()) <= quant_dim) return {};
  // `symmetric` can only be used when it is `signed` and `narrow_range`.
  if (symmetric && (!is_signed || !narrow_range)) return {};

@ -335,7 +336,7 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
    const std::vector<quant::QuantizedType>& op_types) {
  if (op_types.empty()) return {};

  int axis_size = 1;
  size_t axis_size = 1;
  int32_t quant_dim = -1;
  Type expressed_type;
  // Requires all the op types are valid UniformQuantizedTypes or
@ -369,7 +370,7 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
      scales[index_scale.index()] *= index_scale.value();
    }
  } else if (auto type = op_type.dyn_cast<quant::UniformQuantizedType>()) {
    for (int index = 0; index != axis_size; ++index) {
    for (int index = 0, e = axis_size; index != e; ++index) {
      scales[index] *= type.getScale();
    }
  }

@ -27,7 +27,7 @@ cc_library(
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",

@ -1,4 +1,4 @@
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s


// CHECK-LABEL: import_stats_skip

@ -32,7 +32,7 @@ cc_library(
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
    ],

@ -1,4 +1,4 @@
// RUN: tf-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s

// Checks that tfl.reshape should be removed if its output's only user is
// another tfl.reshape

@ -1,4 +1,4 @@
// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: @add_float
func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {

@ -1,4 +1,4 @@
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s

func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s

# Add two tensor<4xi32> inputs and return the result


@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s --dump-input-on-failure
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - | flatbuffer_to_string - | FileCheck %s

node {

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=unranked -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_INT32 -tf-output-arrays=unranked,static,static_10 %s -o - --output-mlir | FileCheck %s --dump-input-on-failure
# RUN: tf_tfl_translate -tf-input-arrays=unranked -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_INT32 -tf-output-arrays=unranked,static,static_10 %s -o - --output-mlir | FileCheck %s

node {
  name: "tf.Const"

@ -1,5 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s --dump-input-on-failure

# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
node {
  name: "tf.Less"
  op: "Less"

@ -54,7 +54,7 @@ tf_native_cc_binary(
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
    ],
)

@ -70,6 +70,6 @@ tf_native_cc_binary(
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
    ],
)

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure basic_lstm roundtrip exactly

func @main(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>) -> tensor<1x96xf32> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure constants roundtrip exactly

func @bool() -> tensor<4xi1> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s

func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
  %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s

// CHECK: func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32>
func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --use-external-constant - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --use-external-constant - -o - | FileCheck %s
// Ensure that `tfl.external_const` is imported when the flag `-use-external-constant` is enabled.

func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Confirm function references in if ops are preserved
func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK: %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}, %{{.*}}) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

@ -1,4 +1,4 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s

// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// This only test the exporter and importer are working without min/max quantization parameters.

func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Tests -input-arrays flag.

func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s

// Tests input and output names from FlatBuffer are added to `tf.entry_function` attribute.


@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure lstm roundtrip exactly

func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s

// Confirm a wide array of attribute survives the round-trip
func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Confirm float constants and operators survive a roundtrip

func @main(tensor<4xf32>) -> tensor<4xf32> {

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Test to make sure optional parameters survive a roundtrip

func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
