commit 3bb28df8d4

Changed files include:

.bazelrc
RELEASE.md
tensorflow/
  BUILD
  api_template.__init__.py
  api_template_v1.__init__.py
  c/
    c_api_experimental.cc
    c_api_experimental.h
    eager/
      BUILD
      c_api.cc
      c_api_distributed_test.cc
      c_api_test_util.cc
      c_api_test_util.h
      c_api_unified_experimental_graph.cc
      c_api_unified_experimental_test.cc
      gradients.cc
      gradients.h
      gradients_test.cc
      immediate_execution_operation.h
    experimental/
      filesystem/
      gradients/
      ops/
      saved_model/
  cc/saved_model/experimental/tests/
  compiler/
    jit/
      BUILD
      kernels/
      mark_for_compilation_pass_test.cc
      xla_compile_on_demand_op.cc
      xla_device_ops.cc
      xla_launch_util.cc
      xla_launch_util.h
    mlir/
      g3doc/
      hlo/
      lite/
        BUILD
        flatbuffer_export.cc
        tf_tfl_passes.cc
        experimental/estimators/
        tfl_hardware_interfaces.td
        ir/
        python/
        quantization/
        tests/
          end2end/
          flatbuffer2mlir/
          fuse-tftext.mlir
          legalize-tf-no-runtime-verification.mlir
          legalize-tf.mlir
          lower-static-tensor-list.mlir
          mlir2flatbuffer/
          ops.mlir
          optimize.mlir
          prepare-composite-functions-tf.mlir
          prepare-tf.mlir
          raise-custom-ops.mlir

.bazelrc (49 changes)
@@ -78,7 +78,16 @@
# elinux: General Embedded Linux options shared by all flavors.
# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
#
# Release build options (for all operating systems)
# release_common: Common options for all builds on all operating systems.
# release_windows_common: Common options for all builds on Windows.
# release_gpu_common: Common options for GPU builds on Linux and Windows.
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.

# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
@@ -534,3 +543,41 @@ try-import %workspace%/.tf_configure.bazelrc

# Put user-specific options in .bazelrc.user
try-import %workspace%/.bazelrc.user

# Here are bazelrc configs for release builds
build:release_common --config=opt
build:release_common --config=v2
build:release_common --action_env TF_CONFIGURE_IOS="0"

build:release_cpu_linux --config=release_common
build:release_cpu_linux --config=avx_linux
# We use the same toolchain for CPU/GPU packages.
# Did not add this to the defaults in case this changes.
build:release_cpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain

build:release_cpu_macos --config=release_common
build:release_cpu_macos --config=avx_linux

build:release_gpu_common --config=release_common
build:release_gpu_common --config=cuda
build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_common --action_env=TF_CUDA_VERSION="10"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"

build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain

build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc

build:release_cpu_windows --config=release_windows_common

build:release_gpu_windows --config=release_windows_common

RELEASE.md (27 changes)
@@ -11,6 +11,9 @@
*   C-API functions `TF_StringDecode`, `TF_StringEncode`, and
    `TF_StringEncodedSize` are no longer relevant and have been removed; see
    core/platform/ctstring.h for string access/modification in C.
*   Removed `tf.distribute.Strategy.experimental_run_v2` method, which was
    deprecated in TF 2.2.
*   `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules
    are now hidden. These modules are not part of the TensorFlow public API.

## Known Caveats

@@ -20,6 +23,7 @@

*   <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
*   <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
*   A new module named `tf.experimental.numpy` is added, which is a
    NumPy-compatible API for writing TF programs. This module provides class
    `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an
    immutable `tf.Tensor` under the hood. A subset of NumPy functions
    (e.g. `numpy.add`) are provided. Their interoperation with TF facilities
    is seamless in most cases. See tensorflow/python/ops/numpy_ops/README.md
    for details of what is supported and how it differs from NumPy.

## Bug Fixes and Other Changes

@@ -27,10 +31,22 @@
*   <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
*   <NOTES SHOULD BE GROUPED PER AREA>
*   TF Core:
    *   <ADD RELEASE NOTES HERE>
    *   <ADD RELEASE NOTES HERE>
    *   `tf.types.experimental.TensorLike` is a new `Union` type that can be
        used as a type annotation for variables representing a Tensor or a
        value that can be converted to a Tensor by `tf.convert_to_tensor`.
    *   Calling ops with Python constants or NumPy values is now consistent
        with `tf.convert_to_tensor` behavior. This avoids operations like
        `tf.reshape` truncating inputs such as from int64 to int32.
    *   Added `tf.sparse.map_values` to apply a function to the `.values` of
        `SparseTensor` arguments.
*   `tf.data`:
    *   Added optional `exclude_cols` parameter to CsvDataset. This parameter
        is the complement of `select_cols`; at most one of these should be
        specified.
    *   We have implemented an optimization which reorders data-discarding
        transformations such as `take` and `shard` to happen earlier in the
        dataset when it is safe to do so. The optimization can be disabled via
        the `experimental_optimization.reorder_data_discarding_ops` dataset
        option.
*   `tf.distribute`:
    *   <ADD RELEASE NOTES HERE>
*   `tf.keras`:
@@ -38,7 +54,8 @@
*   `tf.function`/AutoGraph:
    *   <ADD RELEASE NOTES HERE>
*   `tf.lite`:
    *   <ADD RELEASE NOTES HERE>
    *   Better support for ops with high-dimensional broadcasting inputs by
        adding `BroadcastTo` ops when necessary.
*   `tf.random`:
    *   <ADD RELEASE NOTES HERE>
*   Math and Linear Algebra:
@@ -50,9 +67,9 @@
*   Tracing and Debugging:
    *   <ADD RELEASE NOTES HERE>
*   Other:
-   *   We have replaced uses of "whitelist" with "allowlist" where possible.
-       Please see https://developers.google.com/style/word-list#blacklist for
-       more context.
+   *   We have replaced uses of "whitelist" and "blacklist" with "allowlist"
+       and "denylist" where possible. Please see
+       https://developers.google.com/style/word-list#blacklist for more
+       context.
    *   <ADD RELEASE NOTES HERE>

## Thanks to our Contributors

tensorflow/BUILD

@@ -532,16 +532,14 @@ selects.config_setting_group(
package_group(
    name = "internal",
    packages = [
        # To pass open source testing in the pip Kokoros.
        "//bazel_pip/tensorflow/...",
        "//learning/brain/swift/x10/...",
        "//perftools/accelerators/xprof/api/...",
        "//third_party/py/autograph/...",
        "//third_party/swift/tensorflow/x10/...",
        "//third_party/swift/tensorflow_apis/...",
        "//tensorflow/...",
        "//tensorflow_estimator/python/estimator/...",
        "//tensorflow_models/official/...",
        "//third_party/py/autograph/...",
        "//third_party/swift/tensorflow/x10/...",
        "//third_party/swift/tensorflow_apis/...",
    ],
)

tensorflow/api_template.__init__.py

@@ -158,4 +158,23 @@ if hasattr(_current_module, 'keras'):
  setattr(_current_module, "initializers", initializers)
# pylint: enable=undefined-variable

# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
# pylint: disable=undefined-variable
try:
  del python
except NameError:
  pass
try:
  del core
except NameError:
  pass
try:
  del compiler
except NameError:
  pass

# __all__ PLACEHOLDER

tensorflow/api_template_v1.__init__.py

@@ -156,4 +156,25 @@ if _running_from_pip_package():
  if _fi.file_exists(_plugin_dir):
    _ll.load_library(_plugin_dir)

# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.

# pylint: disable=undefined-variable
try:
  del python
except NameError:
  pass
try:
  del core
except NameError:
  pass
try:
  del compiler
except NameError:
  pass

# __all__ PLACEHOLDER

tensorflow/c/c_api_experimental.cc

@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -525,12 +526,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,

    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
        std::move(new_server), grpc_server->worker_env()->device_mgr,
-       grpc_server->worker_env()->collective_executor_mgr));
+       grpc_server->worker_env()->collective_executor_mgr.get()));
  } else {
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
-       grpc_server->worker_env()->collective_executor_mgr));
+       grpc_server->worker_env()->collective_executor_mgr.get()));
  }
  return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@@ -551,6 +552,14 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
  status->status = EnableCollectiveOps(server_def, ctx);
}

TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                  TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
  collective_executor_handle->get()->StartAbort(status->status);
}

TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;

tensorflow/c/c_api_experimental.h

@@ -230,6 +230,14 @@ TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   size_t proto_len,
                                                   TF_Status* status);

// Aborts all ongoing collectives with the specified status. After abortion,
// subsequent collectives will error with this status immediately.
//
// This is intended to be used when a peer failure is detected. There is not
// yet a way to reset the collectives other than restarting the program.
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                  TF_Status* status);

// Information about the shape of a Tensor and its type.
struct TF_ShapeAndType {
  // Number of dimensions. -1 indicates unknown rank.
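
The abort entry point pairs with TFE_EnableCollectiveOps above. As a sketch of
the intended call pattern (the failure-detection mechanism and the error
handling below are illustrative assumptions, not part of this commit):

    // Assumes a TFE_Context* `ctx` on which TFE_EnableCollectiveOps(ctx,
    // serialized_server_def, proto_len, status) has already succeeded.
    void AbortOnPeerFailure(TFE_Context* ctx) {
      TF_Status* status = TF_NewStatus();
      // Seed `status` with the abort status; TFE_AbortCollectiveOps reads it
      // and uses it to fail all ongoing and subsequent collectives fast.
      TF_SetStatus(status, TF_UNAVAILABLE, "peer failure detected");
      TFE_AbortCollectiveOps(ctx, status);
      TF_DeleteStatus(status);
    }

Note that, per the header comment, the only way to reset the collectives
afterwards is to restart the program.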

tensorflow/c/eager/BUILD

@@ -240,6 +240,8 @@ tf_cuda_cc_test(
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",
@@ -308,6 +310,8 @@ cc_library(
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/util:abstract_stack_trace",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)
@@ -514,7 +518,6 @@ tf_cuda_cc_test(
    extra_copts = tfe_xla_copts(),
    tags = [
        "no_windows",
        "noasan",  # leaks gRPC server instances
    ],
    deps = [
        ":c_api",
@@ -581,7 +584,6 @@ tf_cuda_cc_test(
    extra_copts = tfe_xla_copts(),
    tags = [
        "no_windows",
        "noasan",  # leaks gRPC server instances
    ],
    deps = [
        ":c_api",

tensorflow/c/eager/c_api.cc

@@ -94,7 +94,6 @@ limitations under the License.
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"

using tensorflow::int64;
using tensorflow::string;

namespace {
@@ -968,7 +967,7 @@ int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
    return -1;
  }

-  int64 num_elements = -1;
+  tensorflow::int64 num_elements = -1;
  status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
  return num_elements;
}
@@ -980,7 +979,7 @@ int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
    return -1;
  }

-  int64 dim = -1;
+  tensorflow::int64 dim = -1;
  status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
  return dim;
}

tensorflow/c/eager/c_api_distributed_test.cc

@@ -174,9 +174,9 @@ void TestFunctionWithPackedInput(const bool remote) {
  const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  // Create one variable per task.
-  TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
-  TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
-  TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
+  TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task1_name);
+  TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task2_name);
+  TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task0_name);

  // Add a sync point in order to make sure that variables have been initialized
  // before the function execution starts.
@@ -185,6 +185,9 @@ void TestFunctionWithPackedInput(const bool remote) {
  VarIsInitialized(ctx, h2);

  // Pack 3 variable handles into one TFE_TensorHandle.
  // When remote is false, the function device is placed on task0. Handle types
  // are REMOTE, REMOTE, LOCAL on task0. When remote is true, the function
  // device is placed on task1; handle types are LOCAL, REMOTE, LOCAL on task1.
  int num_replicas = 3;
  std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
  TFE_TensorHandle* packed_handle =

tensorflow/c/eager/c_api_test_util.cc

@@ -88,6 +88,20 @@ TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
  return th;
}

TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
                                                  float data[], int64_t dims[],
                                                  int num_dims) {
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return th;
}

TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
  constexpr int64_t dims[] = {100, 100};
  constexpr int num_elements = dims[0] * dims[1];

tensorflow/c/eager/c_api_test_util.h

@@ -34,6 +34,12 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing a 2x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);

// Return a tensor handle containing a 2D matrix with the given data and
// dimensions.
TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
                                                  float data[], int64_t dims[],
                                                  int num_dims);

// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
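
For illustration, a test can use the new helper to build a non-square matrix
as well; a hypothetical usage sketch, assuming an already-initialized
TFE_Context* `ctx`:

    // Build a 2x3 float matrix handle from caller-provided data.
    float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    int64_t dims[] = {2, 3};
    TFE_TensorHandle* handle =
        TestMatrixTensorHandleWithInput(ctx, data, dims, /*num_dims=*/2);
    // ... use `handle` as an op input ...
    TFE_DeleteTensorHandle(handle);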

tensorflow/c/eager/c_api_unified_experimental_graph.cc

@@ -33,6 +33,7 @@ limitations under the License.

using tensorflow::dyn_cast;
using tensorflow::string;
using tensorflow::gtl::ArraySlice;

namespace tensorflow {
namespace tracing {
@@ -138,20 +139,23 @@ class GraphOperation : public TracingOperation {

  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrString has not been implemented yet.");
+   tensorflow::StringPiece s(data, length);
+   op_->node_builder.Attr(attr_name, s);
+   return Status::OK();
  }
  Status SetAttrInt(const char* attr_name, int64_t value) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrInt has not been implemented yet.");
+   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+                 "64-bit int types should match in size");
+   op_->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
+   return Status::OK();
  }
  Status SetAttrFloat(const char* attr_name, float value) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrFloat has not been implemented yet.");
+   op_->node_builder.Attr(attr_name, value);
+   return Status::OK();
  }
  Status SetAttrBool(const char* attr_name, bool value) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrBool has not been implemented yet.");
+   op_->node_builder.Attr(attr_name, value);
+   return Status::OK();
  }
  Status SetAttrType(const char* const attr_name, DataType value) override {
    if (!op_) {
@@ -164,8 +168,15 @@ class GraphOperation : public TracingOperation {
  }
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrShape has not been implemented yet.");
+   PartialTensorShape shape;
+   if (num_dims >= 0) {
+     static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+                   "64-bit int types should match in size");
+     shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
+         reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
+   }
+   op_->node_builder.Attr(attr_name, shape);
+   return Status::OK();
  }
  Status SetAttrFunction(const char* attr_name,
                         const AbstractOperation* value) override {
@@ -174,8 +185,10 @@ class GraphOperation : public TracingOperation {
  }
  Status SetAttrFunctionName(const char* attr_name, const char* value,
                             size_t length) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrFunctionName has not been implemented yet.");
+   tensorflow::NameAttrList func_name;
+   func_name.set_name(string(value, value + length));
+   op_->node_builder.Attr(attr_name, func_name);
+   return Status::OK();
  }
  Status SetAttrTensor(const char* attr_name,
                       AbstractTensorInterface* tensor) override {
@@ -184,33 +197,71 @@ class GraphOperation : public TracingOperation {
  }
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrStringList has not been implemented yet.");
+   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
+     op_->colocation_constraints.clear();
+     for (int i = 0; i < num_values; ++i) {
+       op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
+                                           lengths[i]);
+     }
+   } else {
+     std::vector<tensorflow::StringPiece> v;
+     v.reserve(num_values);
+     for (int i = 0; i < num_values; ++i) {
+       v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
+     }
+     op_->node_builder.Attr(attr_name, v);
+   }
+   return Status::OK();
  }
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrFloatList has not been implemented yet.");
+   op_->node_builder.Attr(attr_name,
+                          ArraySlice<const float>(values, num_values));
+   return Status::OK();
  }
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrIntList has not been implemented yet.");
+   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+                 "64-bit int types should match in size");
+   op_->node_builder.Attr(
+       attr_name,
+       ArraySlice<const tensorflow::int64>(
+           reinterpret_cast<const tensorflow::int64*>(values), num_values));
+   return Status::OK();
  }
  Status SetAttrTypeList(const char* attr_name, const DataType* values,
                         int num_values) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrTypeList has not been implemented yet.");
+   op_->node_builder.Attr(attr_name,
+                          ArraySlice<const DataType>(values, num_values));
+   return Status::OK();
  }
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrBoolList has not been implemented yet.");
+   std::unique_ptr<bool[]> b(new bool[num_values]);
+   for (int i = 0; i < num_values; ++i) {
+     b[i] = values[i];
+   }
+   op_->node_builder.Attr(attr_name,
+                          ArraySlice<const bool>(b.get(), num_values));
+
+   return Status::OK();
  }
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override {
-   return tensorflow::errors::Unimplemented(
-       "SetAttrShapeList has not been implemented yet.");
+   std::vector<PartialTensorShape> shapes;
+   shapes.reserve(num_values);
+   for (int i = 0; i < num_values; ++i) {
+     if (num_dims[i] < 0) {
+       shapes.emplace_back();
+     } else {
+       static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+                     "64-bit int types should match in size");
+       shapes.emplace_back(ArraySlice<tensorflow::int64>(
+           reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
+     }
+   }
+   op_->node_builder.Attr(attr_name, shapes);
+   return Status::OK();
  }
  Status SetAttrFunctionList(
      const char* attr_name,
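
A detail worth noting in the new SetAttrShape/SetAttrShapeList
implementations: a negative num_dims maps to a default-constructed
PartialTensorShape, which encodes an unknown rank, while -1 inside a dims
array marks a single unknown dimension. A small standalone sketch of that
distinction (illustrative only, not part of the commit):

    #include "tensorflow/core/framework/tensor_shape.h"
    #include "tensorflow/core/platform/logging.h"

    void ShapeAttrSemantics() {
      // num_dims < 0  ->  default PartialTensorShape: unknown rank.
      tensorflow::PartialTensorShape unknown;
      CHECK(unknown.unknown_rank());

      // num_dims >= 0 ->  known rank; individual dimensions may still be
      // unknown (-1), e.g. shape [2, ?].
      tensorflow::int64 dims[] = {2, -1};
      tensorflow::PartialTensorShape partial(dims);
      CHECK_EQ(partial.dims(), 2);
    }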

tensorflow/c/eager/c_api_unified_experimental_test.cc

@@ -92,9 +92,255 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
  TF_DeleteExecutionContext(ctx);
}

// MatMul Test
TEST_P(UnifiedCAPI, TestBasicEagerMatMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  /* Want to test simple MatMul example:
       [[0,0],   *   [[0,0],   =   [[0,0],
        [0,0]]        [0,0]]        [0,0]]
  */

  // Build an abstract input tensor.
  int64_t dims[] = {2, 2};  // Matrices will be 2 x 2
  int num_dims = sizeof(dims) / sizeof(dims[0]);

  float vals[] = {0.0f, 0.0f, 0.0f, 0.0f};
  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
  TFE_TensorHandle* t =
      TestMatrixTensorHandleWithInput(eager_ctx, vals, dims, num_dims);

  TF_AbstractTensor* at = TF_CreateAbstractTensorFromEagerTensor(
      t, status.get());  // get abstract tensor

  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
  auto* op = TF_NewAbstractOp(ctx);
  TF_AbstractOpSetOpType(op, "MatMul", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build inputs and outputs.
  TF_AbstractTensor* inputs[2] = {at, at};
  TF_OutputList* o = TF_NewOutputList();
  TF_OutputListSetNumOutputs(o, 1, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Execute.
  TF_ExecuteOperation(op, 2, inputs, o, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Clean up operation and inputs.
  TF_DeleteAbstractOp(op);
  TF_DeleteAbstractTensor(at);

  // Verify the results.
  ASSERT_EQ(1, TF_OutputListNumOutputs(o));
  TF_AbstractTensor* result = TF_OutputListGet(o, 0);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(result, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());

  // Copy Tensor data into an array.
  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(result_tensor),
         TF_TensorByteSize(result_tensor));

  int data_len = 4;  // length of result_data
  for (int i = 0; i < data_len; i++) {
    EXPECT_EQ(result_data[i], 0);
  }

  TF_DeleteTensor(result_tensor);
  TF_DeleteAbstractTensor(result);
  TF_DeleteOutputList(o);
  TF_DeleteExecutionContext(ctx);
}
// MatMul Test 2
TEST_P(UnifiedCAPI, TestBasicEagerMatMul2) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  /* Want to test simple MatMul example with abstract tensors:
       [[1,2],   *   [[5,6],   =   [[19,22],
        [3,4]]        [7,8]]        [43,50]]
  */

  // Build 1st Matrix.
  int64_t dims[] = {2, 2};  // Matrices will be 2 x 2
  int num_dims = sizeof(dims) / sizeof(dims[0]);

  float vals1[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
  TFE_TensorHandle* t1 =
      TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);

  TF_AbstractTensor* at1 = TF_CreateAbstractTensorFromEagerTensor(
      t1, status.get());  // get abstract tensor
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build 2nd Matrix.
  float vals2[] = {5.0f, 6.0f, 7.0f, 8.0f};
  TFE_TensorHandle* t2 =
      TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);

  TF_AbstractTensor* at2 = TF_CreateAbstractTensorFromEagerTensor(
      t2, status.get());  // get abstract tensor
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
  auto* op = TF_NewAbstractOp(ctx);
  TF_AbstractOpSetOpType(op, "MatMul", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build inputs and outputs.
  TF_AbstractTensor* inputs[2] = {at1, at2};
  TF_OutputList* o = TF_NewOutputList();
  TF_OutputListSetNumOutputs(o, 1, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Execute.
  TF_ExecuteOperation(op, 2, inputs, o, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Clean up operation and inputs.
  TF_DeleteAbstractOp(op);
  TF_DeleteAbstractTensor(at1);
  TF_DeleteAbstractTensor(at2);

  // Verify the results.
  ASSERT_EQ(1, TF_OutputListNumOutputs(o));
  TF_AbstractTensor* result = TF_OutputListGet(o, 0);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(result, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());

  // Copy Tensor data into array.
  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(result_tensor),
         TF_TensorByteSize(result_tensor));

  // Build expected result & verify.
  float e_vals[] = {19.0f, 22.0f, 43.0f, 50.0f};

  int data_len = 4;  // length of e_vals
  for (int i = 0; i < data_len; i++) {
    EXPECT_EQ(result_data[i], e_vals[i]);
  }

  TF_DeleteTensor(result_tensor);
  TF_DeleteAbstractTensor(result);
  TF_DeleteOutputList(o);
  TF_DeleteExecutionContext(ctx);
}
// MatAdd
TEST_P(UnifiedCAPI, TestBasicEagerMatAdd) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  /* Want to test simple MatAdd example with abstract tensors:
       [[1,2],   +   [[5,6],   =   [[6,8],
        [3,4]]        [7,8]]        [10,12]]
  */

  // Build 1st Matrix.
  int64_t dims[] = {2, 2};  // Matrices will be 2 x 2
  int num_dims = sizeof(dims) / sizeof(dims[0]);

  float vals1[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
  TFE_TensorHandle* t1 =
      TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);

  TF_AbstractTensor* at1 = TF_CreateAbstractTensorFromEagerTensor(
      t1, status.get());  // get abstract tensor
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build 2nd Matrix.
  float vals2[] = {5.0f, 6.0f, 7.0f, 8.0f};
  TFE_TensorHandle* t2 =
      TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);

  TF_AbstractTensor* at2 = TF_CreateAbstractTensorFromEagerTensor(
      t2, status.get());  // get abstract tensor
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
  auto* op = TF_NewAbstractOp(ctx);
  TF_AbstractOpSetOpType(op, "Add", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build inputs and outputs.
  TF_AbstractTensor* inputs[2] = {at1, at2};
  TF_OutputList* o = TF_NewOutputList();
  TF_OutputListSetNumOutputs(o, 1, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Execute.
  TF_ExecuteOperation(op, 2, inputs, o, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Clean up operation and inputs.
  TF_DeleteAbstractOp(op);
  TF_DeleteAbstractTensor(at1);
  TF_DeleteAbstractTensor(at2);

  // Verify the results.
  ASSERT_EQ(1, TF_OutputListNumOutputs(o));
  TF_AbstractTensor* result = TF_OutputListGet(o, 0);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(result, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());

  // Copy Tensor data into array.
  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(result_tensor),
         TF_TensorByteSize(result_tensor));

  // Build expected result & verify.
  float e_vals[] = {6.0f, 8.0f, 10.0f, 12.0f};

  int data_len = 4;  // length of e_vals
  for (int i = 0; i < data_len; i++) {
    EXPECT_EQ(result_data[i], e_vals[i]);
  }

  TF_DeleteTensor(result_tensor);
  TF_DeleteAbstractTensor(result);
  TF_DeleteOutputList(o);
  TF_DeleteExecutionContext(ctx);
}
TEST_P(UnifiedCAPI, TestBasicGraph) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  // Start a new function / execution context.
  string fn_name = "double";
  TF_ExecutionContext* graph_ctx =
@@ -142,6 +388,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {

  TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build the abstract op to run the function.
  TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
  TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
@@ -180,6 +427,111 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
  TF_DeleteExecutionContext(eager_execution_ctx);
}
// Graph Tracing for MatMul
TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  // Start a new function / execution context.
  string fn_name = "matrix_multiply";
  TF_ExecutionContext* graph_ctx =
      TF_CreateFunction(fn_name.c_str(), status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  auto* placeholder_t =
      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract operation.
  auto* matmul_op = TF_NewAbstractOp(graph_ctx);
  TF_AbstractOpSetOpType(matmul_op, "MatMul", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_AbstractOpSetOpName(matmul_op, "my_matmul", status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build inputs and outputs.
  TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
  TF_OutputList* mm_outputs = TF_NewOutputList();
  TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Execute.
  TF_ExecuteOperation(matmul_op, 2, inputs, mm_outputs, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Clean up operation and inputs.
  TF_DeleteAbstractOp(matmul_op);

  TF_AbstractFunction* func =
      TF_FinalizeFunction(graph_ctx, mm_outputs, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  /* Now that the graph is built, test graph implementation on matmul example:
       [[1,1],   *   [[1,1],   =   [[2,2],
        [1,1]]        [1,1]]        [2,2]]
  */

  // Build eager context.
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContext* eager_execution_ctx =
      TF_NewEagerExecutionContext(opts, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build the abstract op to run the function.
  TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
  TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Build an abstract input tensor.
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());

  float vals[] = {1.0f, 1.0f, 1.0f, 1.0f};
  int64_t dims[] = {2, 2};  // Matrices will be 2 x 2
  int num_dims = sizeof(dims) / sizeof(dims[0]);

  TFE_TensorHandle* input_eager =
      TestMatrixTensorHandleWithInput(eager_ctx, vals, dims, num_dims);
  TF_AbstractTensor* input_t =
      TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_ExecuteOperation(fn_op, 1, &input_t, mm_outputs, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  ASSERT_EQ(1, TF_OutputListNumOutputs(mm_outputs));
  TF_AbstractTensor* final_result = TF_OutputListGet(mm_outputs, 0);
  TFE_TensorHandle* final =
      TF_AbstractTensorGetEagerTensor(final_result, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(f_t), TF_TensorByteSize(f_t));

  int data_len = 4;
  for (int i = 0; i < data_len; i++) {
    ASSERT_EQ(result_data[i], 2.0f);
  }

  TF_DeleteAbstractTensor(final_result);
  TF_DeleteOutputList(mm_outputs);
  TF_DeleteAbstractTensor(placeholder_t);
  TF_DeleteAbstractOp(fn_op);
  TF_DeleteAbstractTensor(input_t);
  TF_DeleteTensor(f_t);
  TF_DeleteAbstractFunction(func);

  TF_DeleteExecutionContext(eager_execution_ctx);
}
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
@@ -336,6 +688,217 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
  TF_DeleteAbstractFunction(func);
}
TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Status* s = status.get();

  // Start a new function / execution context.
  string fn_name = "two_adds_and_matmul";
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Create a first "Add" computing `arg0 + arg1`.
  TF_AbstractTensor* add_output1;
  {
    // Build an abstract operation, inputs and output.
    auto* add_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(add_op, "Add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractOpSetOpName(add_op, "my_add1", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractTensor* inputs[2] = {arg0, arg1};
    TF_OutputList* add_outputs = TF_NewOutputList();
    TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

    // Trace the operation now (create a node in the graph).
    TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_DeleteAbstractOp(add_op);

    // Extract the resulting tensor.
    add_output1 = TF_OutputListGet(add_outputs, 0);
    TF_DeleteOutputList(add_outputs);
  }

  // Same with a second "Add" computing `arg1 + arg1`.
  TF_AbstractTensor* add_output2;
  {
    // Build an abstract operation, inputs and output.
    auto* add_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(add_op, "Add", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractOpSetOpName(add_op, "my_add2", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractTensor* inputs[2] = {arg1, arg1};
    TF_OutputList* add_outputs = TF_NewOutputList();
    TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

    // Trace the operation now (create a node in the graph).
    TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_DeleteAbstractOp(add_op);

    // Extract the resulting tensor.
    add_output2 = TF_OutputListGet(add_outputs, 0);
    TF_DeleteOutputList(add_outputs);
  }

  // The third output will be the matrix multiplication of add_output1 and
  // add_output2.
  TF_AbstractTensor* mm_output;
  {
    // Build an abstract operation, inputs and output.
    auto* mm_op = TF_NewAbstractOp(graph_ctx);
    TF_AbstractOpSetOpType(mm_op, "MatMul", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractOpSetOpName(mm_op, "mm", s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_AbstractTensor* inputs[2] = {add_output1, add_output2};
    TF_OutputList* mm_outputs = TF_NewOutputList();
    TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

    // Trace the operation now (create a node in the graph).
    TF_ExecuteOperation(mm_op, 2, inputs, mm_outputs, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_DeleteAbstractOp(mm_op);

    // Extract the resulting tensor.
    mm_output = TF_OutputListGet(mm_outputs, 0);
    TF_DeleteOutputList(mm_outputs);
  }

  // Finalize the function by providing the returned values.
  TF_AbstractFunction* func;
  {
    // We want to return the outputs of both Add operations and the MatMul
    // operation; create a new list and populate it.
    TF_OutputList* func_outputs = TF_NewOutputList();
    TF_OutputListPushBack(func_outputs, add_output1, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_OutputListPushBack(func_outputs, add_output2, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_OutputListPushBack(func_outputs, mm_output, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_DeleteOutputList(func_outputs);
  }
  /**
   * So far we have traced this function:
   *
   *   def two_adds_and_mm(A, B):
   *     my_add1 = A + B
   *     my_add2 = B + B
   *     mm = tf.MatMul(my_add1, my_add2)
   *     return my_add1, my_add2, mm
   *
   * Now we will execute this function with an eager context:
   *
   *   A = [[0, 1], [1, 0]]
   *   B = [[1, 0], [0, 1]]
   *
   *   output1, output2, output3 = two_adds_and_mm(A, B)
   *
   * We expect outputs:
   *
   *   output1 = [[1, 1], [1, 1]]
   *   output2 = [[2, 0], [0, 2]]
   *   output3 = [[2, 2], [2, 2]]
   */

  // Build eager context.
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContext* eager_execution_ctx =
      TF_NewEagerExecutionContext(opts, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TFE_DeleteContextOptions(opts);

  TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Build the abstract op to run the function.
  TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
  TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Build two abstract input tensors as function arguments.
  std::vector<TF_AbstractTensor*> func_args;
  {
    TFE_Context* eager_ctx =
        TF_ExecutionContextGetTFEContext(eager_execution_ctx, s);

    // 1st Arg
    float vals1[] = {0.0f, 1.0f, 1.0f, 0.0f};
    int64_t dims[] = {2, 2};  // Matrices will be 2 x 2
    int num_dims = sizeof(dims) / sizeof(dims[0]);

    TFE_TensorHandle* input_eager =
        TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);
    func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

    // 2nd Arg
    float vals2[] = {1.0f, 0.0f, 0.0f, 1.0f};
    input_eager =
        TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);
    func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  }

  TF_OutputList* func_outputs = TF_NewOutputList();
  TF_OutputListSetNumOutputs(func_outputs, 3, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
                      s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteAbstractOp(fn_op);
  for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);

  ASSERT_EQ(3, TF_OutputListNumOutputs(func_outputs));

  float expected_outputs[3][4] = {{1.0f, 1.0f, 1.0f, 1.0f},
                                  {2.0f, 0.0f, 0.0f, 2.0f},
                                  {2.0f, 2.0f, 2.0f, 2.0f}};

  float result_data[4];
  for (int idx = 0; idx < 3; ++idx) {
    TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
    TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

    memcpy(&result_data[0], TF_TensorData(f_t), TF_TensorByteSize(f_t));

    // Verify results for each output.
    for (int j = 0; j < 4; j++) {
      ASSERT_EQ(result_data[j], expected_outputs[idx][j]);
    }

    TF_DeleteTensor(f_t);
  }

  // Free memory associated with the Add and MatMul outputs.
  for (int idx = 0; idx < 3; ++idx) {
    TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
    TF_DeleteAbstractTensor(result);
  }

  TF_DeleteOutputList(func_outputs);
  TF_DeleteExecutionContext(eager_execution_ctx);
  TF_DeleteAbstractFunction(func);
}
TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

tensorflow/c/eager/gradients.cc

@@ -175,7 +175,8 @@ Status TapeVSpace::CallBackwardFunction(
    gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
    std::vector<AbstractTensorHandle*>* result) const {
  if (backward_function == nullptr) return Status::OK();
-  return backward_function->Compute(output_gradients, result);
+  Context ctx = {ctx_};
+  return backward_function->Compute(&ctx, output_gradients, result);
}

// Looks up the ID of a Gradient.

tensorflow/c/eager/gradients.h

@@ -31,7 +31,8 @@ namespace gradients {
//
// class AddGradientFunction : public GradientFunction {
//  public:
-//   Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
+//   Status Compute(Context* ctx,
+//                  absl::Span<AbstractTensorHandle* const> grad_inputs,
//                  std::vector<AbstractTensorHandle*>* grad_outputs) override {
//     grad_outputs->resize(2);
//     (*grad_outputs)[0] = grad_inputs[0];
@@ -50,11 +51,16 @@ namespace gradients {
// Status RegisterGradients(GradientRegistry* registry) {
//   return registry->Register("Add", AddRegisterer);
// }
struct Context {
 public:
  AbstractContext* ctx;
};
class GradientFunction {
 public:
  // TODO(srbs): Figure out how we support CompositeTensors, e.g.
  // IndexedSlices, in `grad_inputs`.
-  virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
+  virtual Status Compute(Context* ctx,
+                         absl::Span<AbstractTensorHandle* const> grad_inputs,
                          std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
  virtual ~GradientFunction() {}
};
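
Taken together with the gradients.cc change above, a gradient function written
against the new interface receives its AbstractContext through the Context
argument instead of capturing it at construction time. A minimal sketch,
assuming the Identity helper that this commit moves into
tensorflow/c/experimental/ops/array_ops.h (the `ops::` namespace and the
helper's exact location are assumptions here):

    // Both inputs of Add receive the upstream gradient unchanged.
    class AddGradientFunction : public GradientFunction {
     public:
      Status Compute(Context* ctx,
                     absl::Span<AbstractTensorHandle* const> grad_inputs,
                     std::vector<AbstractTensorHandle*>* grad_outputs) override {
        grad_outputs->resize(2);
        std::vector<AbstractTensorHandle*> identity_outputs(1);
        // ctx->ctx is the AbstractContext used to build the backward ops.
        TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
                                         absl::MakeSpan(identity_outputs),
                                         "Id0"));
        (*grad_outputs)[0] = identity_outputs[0];
        TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
                                         absl::MakeSpan(identity_outputs),
                                         "Id1"));
        (*grad_outputs)[1] = identity_outputs[0];
        return Status::OK();
      }
      ~AddGradientFunction() override {}
    };

    // The factory no longer needs to capture the forward op's context.
    GradientFunction* AddRegisterer(const ForwardOperation& op) {
      return new AddGradientFunction;
    }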

tensorflow/c/eager/gradients_test.cc

@@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
@@ -42,55 +44,10 @@ class CppGradients
  }
};

// Creates an Identity op.
Status Identity(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr identity_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(identity_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
                           ->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
  int num_retvals = 1;
  TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

// =================== Register gradients for Add ============================
class AddGradientFunction : public GradientFunction {
 public:
  explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
  Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
                 std::vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(2);
    std::vector<AbstractTensorHandle*> identity_outputs(1);
    TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
                                absl::MakeSpan(identity_outputs), "Id0"));
    (*grad_outputs)[0] = identity_outputs[0];
    TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
                                absl::MakeSpan(identity_outputs), "Id1"));
    (*grad_outputs)[1] = identity_outputs[0];
    return Status::OK();
  }
  ~AddGradientFunction() override {}

 private:
  AbstractContext* ctx_;
};

GradientFunction* AddRegisterer(const ForwardOperation& op) {
  return new AddGradientFunction(op.ctx);
}

Status RegisterGradients(GradientRegistry* registry) {
  return registry->Register("Add", AddRegisterer);
}

// =================== End gradient registrations ============================

// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,

tensorflow/c/eager/immediate_execution_operation.h

@@ -17,6 +17,7 @@ limitations under the License.

#include <memory>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
@@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/abstract_stack_trace.h"

struct TFE_Op;

@@ -44,6 +46,12 @@ class ImmediateExecutionOperation : public AbstractOperation {
  // Experimental
  virtual Status SetUseXla(bool enable) = 0;

  // Set the stack trace to be used for potential async error reporting.
  virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;

  // Returns the stack trace set by `SetStackTrace`, if one exists.
  virtual absl::optional<AbstractStackTrace> GetStackTrace() = 0;

  // For LLVM style RTTI.
  static bool classof(const AbstractOperation* ptr) {
    return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
@ -78,6 +78,11 @@ typedef struct TF_Filesystem {
|
||||
void* plugin_filesystem;
|
||||
} TF_Filesystem;
|
||||
|
||||
typedef struct TF_TransactionToken {
|
||||
void* token;
|
||||
TF_Filesystem* owner;
|
||||
} TF_TransactionToken;
|
||||
|
||||
/// SECTION 2. Function tables for functionality provided by plugins
|
||||
/// ----------------------------------------------------------------------------
|
||||
///
|
||||
@ -679,6 +684,133 @@ typedef struct TF_FilesystemOps {
  ///
  /// DEFAULT IMPLEMENTATION: No op.
  void (*flush_caches)(const TF_Filesystem* filesystem);

  /// Starts a new transaction.
  ///
  /// An opaque transaction token is returned in `token`. Ownership of the
  /// token remains with the filesystem. The token is freed by the
  /// `end_transaction` call, and any access to it after that is invalid.
  ///
  /// In case of error, plugins must set `status` to a value different from
  /// `TF_OK`, free any memory allocated for `token`, and return -1.
  ///
  /// The allocation and freeing of memory must happen via the functions sent
  /// to core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// Plugins:
  /// * Must set `status` to `TF_OK` if the transaction started successfully.
  /// * Must set `status` to `TF_FAILED_PRECONDITION` if multiple transactions
  ///   are not supported.
  /// * Might use any other error value for `status` to signal other errors.
  int (*start_transaction)(const TF_Filesystem* filesystem,
                           TF_TransactionToken** token, TF_Status* status);

  /// Ends the transaction and frees the `token`. Any access to the token
  /// after that is invalid.
  ///
  /// In case of error, plugins must set `status` to a value different from
  /// `TF_OK`, free any memory allocated for `token`, and return -1.
  ///
  /// The allocation and freeing of memory must happen via the functions sent
  /// to core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// Plugins:
  /// * Must set `status` to `TF_OK` if the transaction finalized successfully.
  /// * Must set `status` to `TF_NOT_FOUND` if the token is invalid/not found.
  /// * Might use any other error value for `status` to signal other errors.
  int (*end_transaction)(const TF_Filesystem* filesystem,
                         TF_TransactionToken* token, TF_Status* status);

  /// Adds the file/directory in `path` to the transaction in `token`. It is
  /// valid to add a path that doesn't exist yet to a transaction.
  ///
  /// In case of error, plugins must set `status` to a value different from
  /// `TF_OK`, free any memory allocated for `token`, and return -1.
  ///
  /// The allocation and freeing of memory must happen via the functions sent
  /// to core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// Plugins:
  /// * Must set `status` to `TF_OK` if the path was added to the transaction
  ///   successfully.
  /// * Must set `status` to `TF_NOT_FOUND` if `token` is invalid.
  /// * Must set `status` to `TF_FAILED_PRECONDITION` if the file/directory is
  ///   in another transaction and multiple transactions are not supported.
  /// * Might use any other error value for `status` to signal other errors.
  int (*add_to_transaction)(const TF_Filesystem* filesystem, const char* path,
                            TF_TransactionToken* token, TF_Status* status);

  /// Returns the transaction token for the file/directory in `path`. Note
  /// that the path may not exist yet but might still be part of a
  /// transaction.
  ///
  /// The transaction token is returned in `token`. Ownership of the token
  /// remains with the filesystem. The token is freed by the `end_transaction`
  /// call, and any access to it after that is invalid.
  ///
  /// In case of error, plugins must set `status` to a value different from
  /// `TF_OK`, free any memory allocated for `token`, and return -1.
  ///
  /// The allocation and freeing of memory must happen via the functions sent
  /// to core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// Plugins:
  /// * Must set `status` to `TF_OK` if a transaction for `path` is found.
  /// * Must set `status` to `TF_NOT_FOUND` if `path` is not part of any
  ///   transaction.
  /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is not in this
  ///   filesystem.
  /// * Might use any other error value for `status` to signal other errors.
  int (*get_transaction_for_path)(const TF_Filesystem* filesystem,
                                  const char* path,
                                  TF_TransactionToken** token,
                                  TF_Status* status);

  /// Returns the transaction token for `path` if it is part of a transaction;
  /// otherwise starts a new transaction and adds `path` to it.
  ///
  /// The transaction token is returned in `token`. Ownership of the token
  /// remains with the filesystem. The token is freed by the `end_transaction`
  /// call, and any access to it after that is invalid.
  ///
  /// In case of error, plugins must set `status` to a value different from
  /// `TF_OK`, free any memory allocated for `token`, and return -1.
  ///
  /// The allocation and freeing of memory must happen via the functions sent
  /// to core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// Plugins:
  /// * Must set `status` to `TF_OK` if a transaction was found or successfully
  ///   started.
  /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to this
  ///   filesystem.
  /// * Must set `status` to `TF_FAILED_PRECONDITION` if the file/directory is
  ///   not in any transaction and multiple transactions are not supported.
  /// * Might use any other error value for `status` to signal other errors.
  int (*get_or_start_transaction_for_path)(const TF_Filesystem* filesystem,
                                           const char* path,
                                           TF_TransactionToken** token,
                                           TF_Status* status);

  /// Decodes the transaction token in `token` to a human-readable format for
  /// debugging.
  ///
  /// A new `char*` buffer must be allocated by this method. Core TensorFlow
  /// manages the lifetime of the buffer after the call. Thus, all callers of
  /// this method must take ownership of the returned pointer.
  ///
  /// Plugins must not return `nullptr`. Returning empty strings is allowed.
  ///
  /// The allocation and freeing of memory must happen via the functions sent
  /// to core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// DEFAULT IMPLEMENTATION: Dump token and owner address.
  char* (*decode_transaction_token)(const TF_Filesystem* filesystem,
                                    const TF_TransactionToken* token);

} TF_FilesystemOps;
// LINT.ThenChange(:filesystem_ops_version)
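To make the contract above concrete, here is a minimal sketch of a plugin that supports a single open transaction. `plugin_malloc`/`plugin_free` are hypothetical stand-ins for the allocators handed to core TensorFlow at registration (Section 4), and a real plugin would also need locking:

// Hypothetical single-transaction plugin implementation sketch.
static TF_TransactionToken* active = NULL;

static int my_start_transaction(const TF_Filesystem* fs,
                                TF_TransactionToken** token,
                                TF_Status* status) {
  if (active != NULL) {  // only one transaction at a time in this sketch
    TF_SetStatus(status, TF_FAILED_PRECONDITION, "transaction in progress");
    return -1;
  }
  *token = plugin_malloc(sizeof(TF_TransactionToken));
  (*token)->token = NULL;
  (*token)->owner = (TF_Filesystem*)fs;
  active = *token;
  TF_SetStatus(status, TF_OK, "");
  return 0;
}

static int my_end_transaction(const TF_Filesystem* fs,
                              TF_TransactionToken* token, TF_Status* status) {
  if (token != active) {
    TF_SetStatus(status, TF_NOT_FOUND, "unknown transaction token");
    return -1;
  }
  plugin_free(token);  // the token is invalid from here on
  active = NULL;
  TF_SetStatus(status, TF_OK, "");
  return 0;
}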
@ -35,7 +35,8 @@ using UniquePtrTo_TF_Status =
    ::std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;

Status ModularFileSystem::NewRandomAccessFile(
    const std::string& fname, std::unique_ptr<RandomAccessFile>* result) {
    const std::string& fname,
    std::unique_ptr<RandomAccessFile>* result /*, TransactionToken* token */) {
  if (ops_->new_random_access_file == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", fname, " does not support NewRandomAccessFile()"));
@ -54,7 +55,8 @@ Status ModularFileSystem::NewRandomAccessFile(
}

Status ModularFileSystem::NewWritableFile(
    const std::string& fname, std::unique_ptr<WritableFile>* result) {
    const std::string& fname,
    std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
  if (ops_->new_writable_file == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", fname, " does not support NewWritableFile()"));
@ -73,7 +75,8 @@ Status ModularFileSystem::NewWritableFile(
}

Status ModularFileSystem::NewAppendableFile(
    const std::string& fname, std::unique_ptr<WritableFile>* result) {
    const std::string& fname,
    std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
  if (ops_->new_appendable_file == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", fname, " does not support NewAppendableFile()"));
@ -92,7 +95,8 @@ Status ModularFileSystem::NewAppendableFile(
}

Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
    const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
    const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>*
                                   result /*, TransactionToken* token */) {
  if (ops_->new_read_only_memory_region_from_file == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", fname,
@ -112,7 +116,8 @@ Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::FileExists(const std::string& fname) {
Status ModularFileSystem::FileExists(
    const std::string& fname /*, TransactionToken* token */) {
  if (ops_->path_exists == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", fname, " does not support FileExists()"));
@ -124,8 +129,9 @@ Status ModularFileSystem::FileExists(const std::string& fname) {
  return StatusFromTF_Status(plugin_status.get());
}

bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
                                   std::vector<Status>* status) {
bool ModularFileSystem::FilesExist(
    const std::vector<std::string>& files,
    std::vector<Status>* status /*, TransactionToken* token */) {
  if (ops_->paths_exist == nullptr)
    return FileSystem::FilesExist(files, status);

@ -156,8 +162,9 @@ bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
  return result;
}

Status ModularFileSystem::GetChildren(const std::string& dir,
                                      std::vector<std::string>* result) {
Status ModularFileSystem::GetChildren(
    const std::string& dir,
    std::vector<std::string>* result /*, TransactionToken* token */) {
  if (ops_->get_children == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", dir, " does not support GetChildren()"));
@ -181,8 +188,9 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
                                           std::vector<std::string>* result) {
Status ModularFileSystem::GetMatchingPaths(
    const std::string& pattern,
    std::vector<std::string>* result /*, TransactionToken* token */) {
  if (ops_->get_matching_paths == nullptr)
    return internal::GetMatchingPaths(this, Env::Default(), pattern, result);

@ -203,7 +211,8 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::DeleteFile(const std::string& fname) {
Status ModularFileSystem::DeleteFile(
    const std::string& fname /*, TransactionToken* token */) {
  if (ops_->delete_file == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", fname, " does not support DeleteFile()"));
@ -215,9 +224,9 @@ Status ModularFileSystem::DeleteFile(const std::string& fname) {
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
                                            int64* undeleted_files,
                                            int64* undeleted_dirs) {
Status ModularFileSystem::DeleteRecursively(
    const std::string& dirname, int64* undeleted_files,
    int64* undeleted_dirs /*, TransactionToken* token */) {
  if (undeleted_files == nullptr || undeleted_dirs == nullptr)
    return errors::FailedPrecondition(
        "DeleteRecursively must not be called with `undeleted_files` or "
@ -238,7 +247,8 @@ Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::DeleteDir(const std::string& dirname) {
Status ModularFileSystem::DeleteDir(
    const std::string& dirname /*, TransactionToken* token */) {
  if (ops_->delete_dir == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", dirname, " does not support DeleteDir()"));
@ -250,7 +260,8 @@ Status ModularFileSystem::DeleteDir(const std::string& dirname) {
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) {
Status ModularFileSystem::RecursivelyCreateDir(
    const std::string& dirname /*, TransactionToken* token */) {
  if (ops_->recursively_create_dir == nullptr)
    return FileSystem::RecursivelyCreateDir(dirname);

@ -261,7 +272,8 @@ Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) {
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::CreateDir(const std::string& dirname) {
Status ModularFileSystem::CreateDir(
    const std::string& dirname /*, TransactionToken* token */) {
  if (ops_->create_dir == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", dirname, " does not support CreateDir()"));
@ -273,7 +285,9 @@ Status ModularFileSystem::CreateDir(const std::string& dirname) {
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::Stat(const std::string& fname, FileStatistics* stat) {
Status ModularFileSystem::Stat(
    const std::string& fname,
    FileStatistics* stat /*, TransactionToken* token */) {
  if (ops_->stat == nullptr)
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", fname, " does not support Stat()"));
@ -296,7 +310,8 @@ Status ModularFileSystem::Stat(const std::string& fname, FileStatistics* stat) {
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::IsDirectory(const std::string& name) {
Status ModularFileSystem::IsDirectory(
    const std::string& name /*, TransactionToken* token */) {
  if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name);

  UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -306,8 +321,9 @@ Status ModularFileSystem::IsDirectory(const std::string& name) {
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::GetFileSize(const std::string& fname,
                                      uint64* file_size) {
Status ModularFileSystem::GetFileSize(
    const std::string& fname,
    uint64* file_size /*, TransactionToken* token */) {
  if (ops_->get_file_size == nullptr) {
    FileStatistics stat;
    Status status = Stat(fname, &stat);
@ -326,8 +342,9 @@ Status ModularFileSystem::GetFileSize(const std::string& fname,
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::RenameFile(const std::string& src,
                                     const std::string& target) {
Status ModularFileSystem::RenameFile(
    const std::string& src,
    const std::string& target /*, TransactionToken* token */) {
  if (ops_->rename_file == nullptr) {
    Status status = CopyFile(src, target);
    if (status.ok()) status = DeleteFile(src);
@ -342,8 +359,9 @@ Status ModularFileSystem::RenameFile(const std::string& src,
  return StatusFromTF_Status(plugin_status.get());
}

Status ModularFileSystem::CopyFile(const std::string& src,
                                   const std::string& target) {
Status ModularFileSystem::CopyFile(
    const std::string& src,
    const std::string& target /*, TransactionToken* token */) {
  if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);

  UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -354,7 +372,8 @@ Status ModularFileSystem::CopyFile(const std::string& src,
  return StatusFromTF_Status(plugin_status.get());
}

std::string ModularFileSystem::TranslateName(const std::string& name) const {
std::string ModularFileSystem::TranslateName(
    const std::string& name /*, TransactionToken* token */) const {
  if (ops_->translate_name == nullptr) return FileSystem::TranslateName(name);

  char* p = ops_->translate_name(filesystem_.get(), name.c_str());
@ -366,7 +385,7 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
  return ret;
}

void ModularFileSystem::FlushCaches() {
void ModularFileSystem::FlushCaches(/*TransactionToken* token*/) {
  if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get());
}
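Every method in this file follows the same dispatch shape; a condensed sketch (hypothetical free-function rendering — the real methods live on ModularFileSystem and use its `ops_`/`filesystem_` members):

// Sketch of the shared pattern: call the plugin hook if present and convert
// its TF_Status; otherwise fall back or report Unimplemented.
Status DeleteFileViaPlugin(const TF_FilesystemOps* ops, TF_Filesystem* fs,
                           const std::string& fname) {
  if (ops->delete_file == nullptr)  // plugin opted out of this operation
    return errors::Unimplemented(tensorflow::strings::StrCat(
        "Filesystem for ", fname, " does not support DeleteFile()"));
  UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
  ops->delete_file(fs, fname.c_str(), plugin_status.get());
  return StatusFromTF_Status(plugin_status.get());
}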
@ -61,34 +61,69 @@ class ModularFileSystem final : public FileSystem {
  Status NewRandomAccessFile(
      const std::string& fname,
      std::unique_ptr<RandomAccessFile>* result) override;
  Status NewWritableFile(const std::string& fname,
                         std::unique_ptr<WritableFile>* result) override;
  Status NewAppendableFile(const std::string& fname,
                           std::unique_ptr<WritableFile>* result) override;
      std::unique_ptr<RandomAccessFile>*
          result /*, TransactionToken* token = nullptr */) override;
  Status NewWritableFile(
      const std::string& fname,
      std::unique_ptr<WritableFile>*
          result /*, TransactionToken* token = nullptr */) override;
  Status NewAppendableFile(
      const std::string& fname,
      std::unique_ptr<WritableFile>*
          result /*, TransactionToken* token = nullptr */) override;
  Status NewReadOnlyMemoryRegionFromFile(
      const std::string& fname,
      std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
  Status FileExists(const std::string& fname) override;
      std::unique_ptr<ReadOnlyMemoryRegion>*
          result /*, TransactionToken* token = nullptr */) override;
  Status FileExists(
      const std::string& fname /*, TransactionToken* token = nullptr */)
      override;
  bool FilesExist(const std::vector<std::string>& files,
                  std::vector<Status>* status) override;
  Status GetChildren(const std::string& dir,
                     std::vector<std::string>* result) override;
  Status GetMatchingPaths(const std::string& pattern,
                          std::vector<std::string>* results) override;
  Status DeleteFile(const std::string& fname) override;
  Status DeleteRecursively(const std::string& dirname, int64* undeleted_files,
                           int64* undeleted_dirs) override;
  Status DeleteDir(const std::string& dirname) override;
  Status RecursivelyCreateDir(const std::string& dirname) override;
  Status CreateDir(const std::string& dirname) override;
  Status Stat(const std::string& fname, FileStatistics* stat) override;
  Status IsDirectory(const std::string& fname) override;
  Status GetFileSize(const std::string& fname, uint64* file_size) override;
  Status RenameFile(const std::string& src, const std::string& target) override;
  Status CopyFile(const std::string& src, const std::string& target) override;
  std::string TranslateName(const std::string& name) const override;
  void FlushCaches() override;
      std::vector<Status>*
          status /*, TransactionToken* token = nullptr */) override;
  Status GetChildren(
      const std::string& dir,
      std::vector<std::string>* result /*, TransactionToken* token = nullptr */)
      override;
  Status GetMatchingPaths(
      const std::string& pattern,
      std::vector<std::string>*
          results /*, TransactionToken* token = nullptr */) override;
  Status DeleteFile(
      const std::string& fname /*, TransactionToken* token = nullptr */)
      override;
  Status DeleteRecursively(
      const std::string& dirname, int64* undeleted_files,
      int64* undeleted_dirs /*, TransactionToken* token = nullptr */) override;
  Status DeleteDir(
      const std::string& dirname /*, TransactionToken* token = nullptr */)
      override;
  Status RecursivelyCreateDir(
      const std::string& dirname /*, TransactionToken* token = nullptr */)
      override;
  Status CreateDir(
      const std::string& dirname /*, TransactionToken* token = nullptr */)
      override;
  Status Stat(
      const std::string& fname,
      FileStatistics* stat /*, TransactionToken* token = nullptr */) override;
  Status IsDirectory(
      const std::string& fname /*, TransactionToken* token = nullptr */)
      override;
  Status GetFileSize(
      const std::string& fname,
      uint64* file_size /*, TransactionToken* token = nullptr */) override;
  Status RenameFile(
      const std::string& src,
      const std::string& target /*, TransactionToken* token = nullptr */)
      override;
  Status CopyFile(const std::string& src,
                  const std::string&
                      target /*, TransactionToken* token = nullptr */) override;
  std::string TranslateName(
      const std::string& name /*, TransactionToken* token = nullptr */)
      const override;
  void FlushCaches(/* TransactionToken* token=nullptr */) override;

 private:
  std::unique_ptr<TF_Filesystem> filesystem_;
@ -25,7 +25,9 @@ cc_library(
        "//tensorflow:windows": get_win_copts(),
    }),
    deps = [
        ":expiring_lru_cache",
        ":gcs_helper",
        ":ram_file_block_cache",
        "//tensorflow/c:env",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
@ -44,14 +46,6 @@ cc_library(
    ],
)

cc_library(
    name = "file_block_cache",
    hdrs = ["file_block_cache.h"],
    deps = [
        "//tensorflow/c:tf_status",
    ],
)

cc_library(
    name = "cleanup",
    hdrs = ["cleanup.h"],
@ -63,7 +57,6 @@ cc_library(
    hdrs = ["ram_file_block_cache.h"],
    deps = [
        ":cleanup",
        ":file_block_cache",
        "//tensorflow/c:env",
        "//tensorflow/c:tf_status",
        "@com_google_absl//absl/base:core_headers",
@ -1,140 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_

#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/tf_status.h"

namespace tf_gcs_filesystem {

class FileBlockCache;

/// FileBlockCacheStatsInterface allows for instrumentation of the block cache.
///
/// FileBlockCacheStatsInterface and its subclasses must be safe to use from
/// multiple threads concurrently.
///
/// WARNING! This is an experimental interface that may change or go away at
/// any time.
class FileBlockCacheStatsInterface {
 public:
  /// Configure is called to provide instrumentation hooks.
  ///
  /// Note: Configure can be called multiple times (e.g. if the block cache is
  /// re-initialized).
  virtual void Configure(const FileBlockCache* block_cache) = 0;

  /// RecordCacheHitBlockSize is called to record the size of a hit block.
  virtual void RecordCacheHitBlockSize(size_t bytes_transferred) = 0;

  /// RecordCacheMissBlockSize is called to record the size of a missed block.
  virtual void RecordCacheMissBlockSize(size_t bytes_transferred) = 0;

  virtual ~FileBlockCacheStatsInterface() = default;
};

/// \brief A block cache of file contents, keyed by {filename, offset}.
///
/// This class should be shared by read-only random access files on a remote
/// filesystem (e.g. GCS).
class FileBlockCache {
 public:
  /// The callback executed when a block is not found in the cache, and needs
  /// to be fetched from the backing filesystem. This callback is provided
  /// when the cache is constructed. The `status` should be `TF_OK` as long as
  /// the read from the remote filesystem succeeded (similar to the semantics
  /// of the read(2) system call).
  typedef std::function<void(const std::string& filename, size_t offset,
                             size_t buffer_size, char* buffer,
                             size_t* bytes_transferred, TF_Status* status)>
      BlockFetcher;

  virtual ~FileBlockCache() {}

  /// Read `n` bytes from `filename` starting at `offset` into `buffer`. This
  /// method will set `status` to:
  ///
  /// 1) The error from the remote filesystem, if the read from the remote
  ///    filesystem failed.
  /// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem
  ///    succeeded, but the read returned a partial block, and the LRU cache
  ///    contained a block at a higher offset (indicating that the partial
  ///    block should have been a full block).
  /// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded,
  ///    but the file contents do not extend past `offset` and thus nothing
  ///    was placed in `out`.
  /// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was
  ///    placed in `buffer`).
  ///
  /// Caller is responsible for allocating memory for `buffer`.
  /// `buffer` will be left unchanged in case of errors.
  virtual void Read(const std::string& filename, size_t offset, size_t n,
                    char* buffer, size_t* bytes_transferred,
                    TF_Status* status) = 0;

  // Validate the given file signature with the existing file signature in the
  // cache. Returns true if the signature doesn't change or the file did not
  // exist before. If the signature changes, update the existing signature with
  // the new one and remove the file from cache.
  virtual bool ValidateAndUpdateFileSignature(const std::string& filename,
                                              int64_t file_signature) = 0;

  /// Remove all cached blocks for `filename`.
  virtual void RemoveFile(const std::string& filename) = 0;

  /// Remove all cached data.
  virtual void Flush() = 0;

  /// Accessors for cache parameters.
  virtual size_t block_size() const = 0;
  virtual size_t max_bytes() const = 0;
  virtual uint64_t max_staleness() const = 0;

  /// The current size (in bytes) of the cache.
  virtual size_t CacheSize() const = 0;

  // Returns true if the cache is enabled. If false, the BlockFetcher callback
  // is always executed during Read.
  virtual bool IsCacheEnabled() const = 0;

  void SetStats(FileBlockCacheStatsInterface* stats) {
    if (stats == nullptr) {
      std::cerr
          << "Attempted to monitor a NULL stats object. This may prevent the "
             "corresponding monitoring data from being exported";
      return;
    }
    cache_stats_ = stats;
    cache_stats_->Configure(this);
  }

 protected:
  FileBlockCacheStatsInterface* cache_stats_ = nullptr;  // Not owned.
};

}  // namespace tf_gcs_filesystem

#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_
@ -28,6 +28,27 @@ limitations under the License.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;

// The environment variable that overrides the block size for aligned reads from
// GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes).
constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB";
constexpr size_t kDefaultBlockSize = 64 * 1024 * 1024;
// The environment variable that overrides the max size of the LRU cache of
// blocks read from GCS. Specified in MB.
constexpr char kMaxCacheSize[] = "GCS_READ_CACHE_MAX_SIZE_MB";
constexpr size_t kDefaultMaxCacheSize = 0;
// The environment variable that overrides the maximum staleness of cached file
// contents. Once any block of a file reaches this staleness, all cached blocks
// will be evicted on the next read.
constexpr char kMaxStaleness[] = "GCS_READ_CACHE_MAX_STALENESS";
constexpr uint64_t kDefaultMaxStaleness = 0;

constexpr char kStatCacheMaxAge[] = "GCS_STAT_CACHE_MAX_AGE";
constexpr uint64_t kStatCacheDefaultMaxAge = 5;
// The environment variable that overrides the maximum number of entries in the
// Stat cache.
constexpr char kStatCacheMaxEntries[] = "GCS_STAT_CACHE_MAX_ENTRIES";
constexpr size_t kStatCacheDefaultMaxEntries = 1024;

// How to upload new data when Flush() is called multiple times.
// By default the entire file is reuploaded.
constexpr char kAppendMode[] = "GCS_APPEND_MODE";
@ -82,28 +103,16 @@ static void MaybeAppendSlash(std::string* name) {
  name->push_back('/');
}

// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct GCSFile {
  const std::string bucket;
  const std::string object;
  gcs::Client* gcs_client;  // not owned
} GCSFile;

void Cleanup(TF_RandomAccessFile* file) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  delete gcs_file;
}

// TODO(vnvo2409): Adding cache.
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
             char* buffer, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  auto stream = gcs_file->gcs_client->ReadObject(
      gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
// A helper function to actually read the data from GCS.
static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
                                 size_t buffer_size, char* buffer,
                                 tf_gcs_filesystem::GCSFile* gcs_file,
                                 TF_Status* status) {
  std::string bucket, object;
  ParseGCSPath(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return -1;
  auto stream = gcs_file->gcs_client.ReadObject(
      bucket, object, gcs::ReadRange(offset, offset + buffer_size));
  TF_SetStatusFromGCSStatus(stream.status(), status);
  if ((TF_GetCode(status) != TF_OK) &&
      (TF_GetCode(status) != TF_OUT_OF_RANGE)) {
@ -112,16 +121,119 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
  int64_t read;
  if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
                        &read)) {
    TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
    return -1;
  }
  if (read != n) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
    // When we read a file with an offset that is bigger than the actual file
    // size, GCS returns an empty header (i.e. no `content-length` header). In
    // this case, we set `read` to 0 and continue.
    if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
      read = 0;
    } else {
      TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
      return -1;
    }
  }
  // `TF_OUT_OF_RANGE` isn't considered an error here, so we clear it.
  TF_SetStatus(status, TF_OK, "");
  stream.read(buffer, read);
  read = stream.gcount();
  if (read < buffer_size) {
    // Check the stat cache to see if we encountered an interrupted read.
    tf_gcs_filesystem::GcsFileStat stat;
    if (gcs_file->stat_cache->Lookup(path, &stat)) {
      if (offset + read < stat.base.length) {
        TF_SetStatus(status, TF_INTERNAL,
                     absl::StrCat("File contents are inconsistent for file: ",
                                  path, " @ ", offset)
                         .c_str());
      }
    }
  }
  return read;
}

// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
using ReadFn =
    std::function<int64_t(const std::string& path, uint64_t offset, size_t n,
                          char* buffer, TF_Status* status)>;
typedef struct GCSFile {
  const std::string path;
  const bool is_cache_enable;
  const uint64_t buffer_size;
  ReadFn read_fn;
  absl::Mutex buffer_mutex;
  uint64_t buffer_start ABSL_GUARDED_BY(buffer_mutex);
  bool buffer_end_is_past_eof ABSL_GUARDED_BY(buffer_mutex);
  std::string buffer ABSL_GUARDED_BY(buffer_mutex);

  GCSFile(std::string path, bool is_cache_enable, uint64_t buffer_size,
          ReadFn read_fn)
      : path(path),
        is_cache_enable(is_cache_enable),
        buffer_size(buffer_size),
        read_fn(std::move(read_fn)),
        buffer_mutex(),
        buffer_start(0),
        buffer_end_is_past_eof(false),
        buffer() {}
} GCSFile;

void Cleanup(TF_RandomAccessFile* file) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  delete gcs_file;
}

// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
             char* buffer, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  if (gcs_file->is_cache_enable || n > gcs_file->buffer_size) {
    return gcs_file->read_fn(gcs_file->path, offset, n, buffer, status);
  } else {
    absl::MutexLock l(&gcs_file->buffer_mutex);
    size_t buffer_end = gcs_file->buffer_start + gcs_file->buffer.size();
    size_t copy_size = 0;
    if (offset < buffer_end && gcs_file->buffer_start) {
      copy_size = (std::min)(n, static_cast<size_t>(buffer_end - offset));
      memcpy(buffer,
             gcs_file->buffer.data() + (offset - gcs_file->buffer_start),
             copy_size);
    }
    bool consumed_buffer_to_eof =
        offset + copy_size >= buffer_end && gcs_file->buffer_end_is_past_eof;
    if (copy_size < n && !consumed_buffer_to_eof) {
      gcs_file->buffer_start = offset + copy_size;
      gcs_file->buffer.resize(gcs_file->buffer_size);
      auto read_fill_buffer = gcs_file->read_fn(
          gcs_file->path, gcs_file->buffer_start, gcs_file->buffer_size,
          &(gcs_file->buffer[0]), status);
      gcs_file->buffer_end_is_past_eof =
          (TF_GetCode(status) == TF_OUT_OF_RANGE);
      if (read_fill_buffer >= 0) gcs_file->buffer.resize(read_fill_buffer);
      if (TF_GetCode(status) != TF_OK &&
          TF_GetCode(status) != TF_OUT_OF_RANGE) {
        // Empty the buffer to avoid caching bad reads.
        gcs_file->buffer.resize(0);
        return -1;
      }
      size_t remaining_copy =
          (std::min)(n - copy_size, gcs_file->buffer.size());
      memcpy(buffer + copy_size, gcs_file->buffer.data(), remaining_copy);
      copy_size += remaining_copy;
    }
    if (copy_size < n) {
      // Forget the end-of-file flag to allow for clients that poll on the
      // same file.
      gcs_file->buffer_end_is_past_eof = false;
      TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
      return copy_size;
    }
    TF_SetStatus(status, TF_OK, "");
    return copy_size;
  }
}

}  // namespace tf_random_access_file
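To illustrate the buffered path above, a sketch (it assumes `file` wraps a GCSFile created with the cache disabled, and `status` is a valid TF_Status*; both are assumptions for the example):

// Two small sequential reads: the first fills the internal buffer through
// read_fn (one GCS request); the second is served from that buffer without
// touching the network.
char chunk[4096];
int64_t r1 = tf_random_access_file::Read(file, /*offset=*/0, sizeof(chunk),
                                         chunk, status);
int64_t r2 = tf_random_access_file::Read(file, /*offset=*/4096, sizeof(chunk),
                                         chunk, status);  // buffer hit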

// SECTION 2. Implementation for `TF_WritableFile`
@ -290,11 +402,52 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region) {
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
// TODO(vnvo2409): Use partial response for better performance.
// TODO(vnvo2409): We could do some cleanups like `return TF_SetStatus`.
// TODO(vnvo2409): Refactor the filesystem implementation when
// https://github.com/googleapis/google-cloud-cpp/issues/4482 is done.
GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
    : gcs_client(gcs_client), block_cache_lock() {
  const char* append_mode = std::getenv(kAppendMode);
  compose = (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));

  uint64_t value;
  block_size = kDefaultBlockSize;
  size_t max_bytes = kDefaultMaxCacheSize;
  uint64_t max_staleness = kDefaultMaxStaleness;

  // Apply the overrides for the block size (MB), max bytes (MB), and max
  // staleness (seconds) if provided.
  if (absl::SimpleAtoi(std::getenv(kBlockSize), &value)) {
    block_size = value * 1024 * 1024;
  }
  if (absl::SimpleAtoi(std::getenv(kMaxCacheSize), &value)) {
    max_bytes = static_cast<size_t>(value * 1024 * 1024);
  }
  if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
    max_staleness = value;
  }

  file_block_cache = std::make_unique<RamFileBlockCache>(
      block_size, max_bytes, max_staleness,
      [this](const std::string& filename, size_t offset, size_t buffer_size,
             char* buffer, TF_Status* status) {
        return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this,
                                 status);
      });

  uint64_t stat_cache_max_age = kStatCacheDefaultMaxAge;
  size_t stat_cache_max_entries = kStatCacheDefaultMaxEntries;
  if (absl::SimpleAtoi(std::getenv(kStatCacheMaxAge), &value)) {
    stat_cache_max_age = value;
  }
  if (absl::SimpleAtoi(std::getenv(kStatCacheMaxEntries), &value)) {
    stat_cache_max_entries = static_cast<size_t>(value);
  }
  stat_cache = std::make_unique<ExpiringLRUCache<GcsFileStat>>(
      stat_cache_max_age, stat_cache_max_entries);
}

void Init(TF_Filesystem* filesystem, TF_Status* status) {
  google::cloud::StatusOr<gcs::Client> client =
      gcs::Client::CreateDefaultClient();
@ -303,12 +456,7 @@ void Init(TF_Filesystem* filesystem, TF_Status* status) {
    return;
  }

  const char* append_mode = std::getenv(kAppendMode);
  bool compose =
      (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));

  filesystem->plugin_filesystem =
      new GCSFile({std::move(client.value()), compose});
  filesystem->plugin_filesystem = new GCSFile(std::move(client.value()));
  TF_SetStatus(status, TF_OK, "");
}

@ -317,6 +465,19 @@ void Cleanup(TF_Filesystem* filesystem) {
  delete gcs_file;
}

static void UncachedStatForObject(const std::string& bucket,
                                  const std::string& object, GcsFileStat* stat,
                                  gcs::Client* gcs_client, TF_Status* status) {
  auto metadata = gcs_client->GetObjectMetadata(bucket, object);
  if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status);
  stat->generation_number = metadata->generation();
  stat->base.length = metadata->size();
  stat->base.mtime_nsec =
      metadata->time_storage_class_updated().time_since_epoch().count();
  stat->base.is_directory = object.back() == '/';
  return TF_SetStatus(status, TF_OK, "");
}

// TODO(vnvo2409): Implement later
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
                         TF_RandomAccessFile* file, TF_Status* status) {
@ -325,8 +486,46 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
  if (TF_GetCode(status) != TF_OK) return;

  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  bool is_cache_enabled;
  {
    absl::MutexLock l(&gcs_file->block_cache_lock);
    is_cache_enabled = gcs_file->file_block_cache->IsCacheEnabled();
  }
  auto read_fn = [gcs_file, is_cache_enabled, bucket, object](
                     const std::string& path, uint64_t offset, size_t n,
                     char* buffer, TF_Status* status) -> int64_t {
    int64_t read = 0;
    if (is_cache_enabled) {
      absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
      GcsFileStat stat;
      gcs_file->stat_cache->LookupOrCompute(
          path, &stat,
          [gcs_file, bucket, object](const std::string& path, GcsFileStat* stat,
                                     TF_Status* status) {
            UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client,
                                  status);
          },
          status);
      if (TF_GetCode(status) != TF_OK) return -1;
      if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
              path, stat.generation_number)) {
        std::cout
            << "File signature has been changed. Refreshing the cache. Path: "
            << path;
      }
      read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
    } else {
      read = LoadBufferFromGCS(path, offset, n, buffer, gcs_file, status);
    }
    if (TF_GetCode(status) != TF_OK) return -1;
    if (read < n)
      TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
    else
      TF_SetStatus(status, TF_OK, "");
    return read;
  };
  file->plugin_file = new tf_random_access_file::GCSFile(
      {std::move(bucket), std::move(object), &gcs_file->gcs_client});
      std::move(path), is_cache_enabled, gcs_file->block_size, read_fn);
  TF_SetStatus(status, TF_OK, "");
}

@ -17,6 +17,8 @@

#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include "tensorflow/c/tf_status.h"

void ParseGCSPath(const std::string& fname, bool object_empty_ok,
@ -45,10 +47,23 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region);
}  // namespace tf_read_only_memory_region

namespace tf_gcs_filesystem {
typedef struct GcsFileStat {
  TF_FileStatistics base;
  int64_t generation_number;
} GcsFileStat;

typedef struct GCSFile {
  google::cloud::storage::Client gcs_client;  // owned
  bool compose;
  absl::Mutex block_cache_lock;
  std::shared_ptr<RamFileBlockCache> file_block_cache
      ABSL_GUARDED_BY(block_cache_lock);
  uint64_t block_size;  // Reads smaller than block_size will trigger a read
                        // of block_size.
  std::unique_ptr<ExpiringLRUCache<GcsFileStat>> stat_cache;
  GCSFile(google::cloud::storage::Client&& gcs_client);
} GCSFile;
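The ABSL_GUARDED_BY annotation above implies a locking discipline; a minimal conforming accessor might look like this (hypothetical helper, mirroring what NewRandomAccessFile does earlier in this commit):

// Any access to file_block_cache must hold block_cache_lock; clang's
// thread-safety analysis enforces this via the annotation on the member.
static bool BlockCacheEnabled(tf_gcs_filesystem::GCSFile* gcs_file) {
  absl::MutexLock l(&gcs_file->block_cache_lock);
  return gcs_file->file_block_cache->IsCacheEnabled();
}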

void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,

@ -39,9 +39,6 @@ std::shared_ptr<RamFileBlockCache::Block> RamFileBlockCache::Lookup(
  auto entry = block_map_.find(key);
  if (entry != block_map_.end()) {
    if (BlockNotStale(entry->second)) {
      if (cache_stats_ != nullptr) {
        cache_stats_->RecordCacheHitBlockSize(entry->second->data.size());
      }
      return entry->second;
    } else {
      // Remove the stale block and continue.
@ -136,12 +133,9 @@ void RamFileBlockCache::MaybeFetch(const Key& key,
      block->mu.Unlock();  // Release the lock while making the API call.
      block->data.clear();
      block->data.resize(block_size_, 0);
      size_t bytes_transferred;
      block_fetcher_(key.first, key.second, block_size_, block->data.data(),
                     &bytes_transferred, status);
      if (cache_stats_ != nullptr) {
        cache_stats_->RecordCacheMissBlockSize(bytes_transferred);
      }
      int64_t bytes_transferred;
      bytes_transferred = block_fetcher_(key.first, key.second, block_size_,
                                         block->data.data(), status);
      block->mu.Lock();  // Reacquire the lock immediately afterwards
      if (TF_GetCode(status) == TF_OK) {
        block->data.resize(bytes_transferred, 0);
@ -171,18 +165,16 @@ void RamFileBlockCache::MaybeFetch(const Key& key,
      "Control flow should never reach the end of RamFileBlockCache::Fetch.");
}

void RamFileBlockCache::Read(const std::string& filename, size_t offset,
                             size_t n, char* buffer, size_t* bytes_transferred,
                             TF_Status* status) {
  *bytes_transferred = 0;
int64_t RamFileBlockCache::Read(const std::string& filename, size_t offset,
                                size_t n, char* buffer, TF_Status* status) {
  if (n == 0) {
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return 0;
  }
  if (!IsCacheEnabled() || (n > max_bytes_)) {
    // The cache is effectively disabled, so we pass the read through to the
    // fetcher without breaking it up into blocks.
    return block_fetcher_(filename, offset, n, buffer, bytes_transferred,
                          status);
    return block_fetcher_(filename, offset, n, buffer, status);
  }
  // Calculate the block-aligned start and end of the read.
  size_t start = block_size_ * (offset / block_size_);
@ -202,20 +194,20 @@ void RamFileBlockCache::Read(const std::string& filename, size_t offset,
      abort();
    }
    MaybeFetch(key, block, status);
    if (TF_GetCode(status) != TF_OK) return;
    if (TF_GetCode(status) != TF_OK) return -1;
    UpdateLRU(key, block, status);
    if (TF_GetCode(status) != TF_OK) return;
    if (TF_GetCode(status) != TF_OK) return -1;
    // Copy the relevant portion of the block into the result buffer.
    const auto& data = block->data;
    if (offset >= pos + data.size()) {
      // The requested offset is at or beyond the end of the file. This can
      // happen if `offset` is not block-aligned, and the read returns the last
      // block in the file, which does not extend all the way out to `offset`.
      *bytes_transferred = total_bytes_transferred;
      std::stringstream os;
      os << "EOF at offset " << offset << " in file " << filename
         << " at position " << pos << " with data size " << data.size();
      return TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str());
      TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str());
      return total_bytes_transferred;
    }
    auto begin = data.begin();
    if (offset > pos) {
@ -237,8 +229,8 @@ void RamFileBlockCache::Read(const std::string& filename, size_t offset,
      break;
    }
  }
  *bytes_transferred = total_bytes_transferred;
  return TF_SetStatus(status, TF_OK, "");
  TF_SetStatus(status, TF_OK, "");
  return total_bytes_transferred;
}

bool RamFileBlockCache::ValidateAndUpdateFileSignature(

@ -28,7 +28,6 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h"
#include "tensorflow/c/tf_status.h"

namespace tf_gcs_filesystem {
@ -37,16 +36,17 @@ namespace tf_gcs_filesystem {
///
/// This class should be shared by read-only random access files on a remote
/// filesystem (e.g. GCS).
class RamFileBlockCache : public FileBlockCache {
class RamFileBlockCache {
 public:
  /// The callback executed when a block is not found in the cache, and needs to
  /// be fetched from the backing filesystem. This callback is provided when the
  /// cache is constructed. The `status` should be `TF_OK` as long as the
  /// read from the remote filesystem succeeded (similar to the semantics of the
  /// read(2) system call).
  typedef std::function<void(const std::string& filename, size_t offset,
                             size_t buffer_size, char* buffer,
                             size_t* bytes_transferred, TF_Status* status)>
  /// cache is constructed. It returns the total bytes read (-1 in case of
  /// errors). The `status` should be `TF_OK` as long as the read from the
  /// remote filesystem succeeded (similar to the semantics of the read(2)
  /// system call).
  typedef std::function<int64_t(const std::string& filename, size_t offset,
                                size_t buffer_size, char* buffer,
                                TF_Status* status)>
      BlockFetcher;
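A fetcher conforming to the new signature can be sketched over an in-memory backing store (the `backing` string is an assumption for illustration; `<algorithm>` and `<cstring>` are assumed included):

// In-memory BlockFetcher: copies up to buffer_size bytes from `backing`
// starting at `offset`, reports TF_OK, and returns the byte count.
std::string backing = "0123456789abcdef";
auto fetcher = [&backing](const std::string& filename, size_t offset,
                          size_t buffer_size, char* buffer,
                          TF_Status* status) -> int64_t {
  if (offset >= backing.size()) {
    TF_SetStatus(status, TF_OK, "");
    return 0;  // past EOF: nothing to copy
  }
  size_t n = std::min(buffer_size, backing.size() - offset);
  memcpy(buffer, backing.data() + offset, n);
  TF_SetStatus(status, TF_OK, "");
  return static_cast<int64_t>(n);
};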

  RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness,
@ -66,10 +66,10 @@ class RamFileBlockCache : public FileBlockCache {
          TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this));
    }
    std::cout << "GCS file block cache is "
              << (IsCacheEnabled() ? "enabled" : "disabled");
              << (IsCacheEnabled() ? "enabled" : "disabled") << ".\n";
  }

  ~RamFileBlockCache() override {
  ~RamFileBlockCache() {
    if (pruning_thread_) {
      stop_pruning_thread_.Notify();
      // Destroying pruning_thread_ will block until Prune() receives the above
@ -78,8 +78,9 @@ class RamFileBlockCache : public FileBlockCache {
    }
  }

  /// Read `n` bytes from `filename` starting at `offset` into `buffer`. This
  /// method will set `status` to:
  /// Read `n` bytes from `filename` starting at `offset` into `buffer`. It
  /// returns the total bytes read (-1 in case of errors). This method will set
  /// `status` to:
  ///
  /// 1) The error from the remote filesystem, if the read from the remote
  ///    filesystem failed.
@ -97,37 +98,34 @@ class RamFileBlockCache : public FileBlockCache {
  ///
  /// Caller is responsible for allocating memory for `buffer`.
  /// `buffer` will be left unchanged in case of errors.
  void Read(const std::string& filename, size_t offset, size_t n, char* buffer,
            size_t* bytes_transferred, TF_Status* status) override;
  int64_t Read(const std::string& filename, size_t offset, size_t n,
               char* buffer, TF_Status* status);

  // Validate the given file signature with the existing file signature in the
  // cache. Returns true if the signature doesn't change or the file didn't
  // exist before. If the signature changes, update the existing signature with
  // the new one and remove the file from cache.
  bool ValidateAndUpdateFileSignature(const std::string& filename,
                                      int64_t file_signature) override
                                      int64_t file_signature)
      ABSL_LOCKS_EXCLUDED(mu_);

  /// Remove all cached blocks for `filename`.
  void RemoveFile(const std::string& filename) override
      ABSL_LOCKS_EXCLUDED(mu_);
  void RemoveFile(const std::string& filename) ABSL_LOCKS_EXCLUDED(mu_);

  /// Remove all cached data.
  void Flush() override ABSL_LOCKS_EXCLUDED(mu_);
  void Flush() ABSL_LOCKS_EXCLUDED(mu_);

  /// Accessors for cache parameters.
  size_t block_size() const override { return block_size_; }
  size_t max_bytes() const override { return max_bytes_; }
  uint64_t max_staleness() const override { return max_staleness_; }
  size_t block_size() const { return block_size_; }
  size_t max_bytes() const { return max_bytes_; }
  uint64_t max_staleness() const { return max_staleness_; }

  /// The current size (in bytes) of the cache.
  size_t CacheSize() const override ABSL_LOCKS_EXCLUDED(mu_);
  size_t CacheSize() const ABSL_LOCKS_EXCLUDED(mu_);

  // Returns true if the cache is enabled. If false, the BlockFetcher callback
  // is always executed during Read.
  bool IsCacheEnabled() const override {
    return block_size_ > 0 && max_bytes_ > 0;
  }
  bool IsCacheEnabled() const { return block_size_ > 0 && max_bytes_ > 0; }

  // We cannot pass a lambda with capture as a function pointer to
  // `TF_StartThread`, so we have to wrap `Prune` inside a static function.

@ -33,20 +33,22 @@ Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache,
                 std::vector<char>* out) {
  out->clear();
  out->resize(n, 0);
  size_t bytes_transferred = 0;
  TF_Status status;
  cache->Read(filename, offset, n, out->data(), &bytes_transferred, &status);
  EXPECT_LE(bytes_transferred, n);
  out->resize(bytes_transferred, n);
  auto bytes_transferred =
      cache->Read(filename, offset, n, out->data(), &status);
  if (bytes_transferred >= 0) {
    EXPECT_LE(bytes_transferred, n);
    out->resize(bytes_transferred, n);
  }
  return status.status;
}

TEST(RamFileBlockCacheTest, IsCacheEnabled) {
  auto fetcher = [](const string& filename, size_t offset, size_t n,
                    char* buffer, size_t* bytes_transferred,
                    TF_Status* status) {
                    char* buffer, TF_Status* status) -> int64_t {
    // Do nothing.
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return 0;
  };
  tf_gcs_filesystem::RamFileBlockCache cache1(0, 0, 0, fetcher);
  tf_gcs_filesystem::RamFileBlockCache cache2(16, 0, 0, fetcher);
@ -62,12 +64,11 @@ TEST(RamFileBlockCacheTest, IsCacheEnabled) {
TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) {
  int calls = 0;
  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
                          char* buffer, size_t* bytes_transferred,
                          TF_Status* status) {
                          char* buffer, TF_Status* status) -> int64_t {
    calls++;
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return n;
  };
  string filename = "file";
  tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
@ -96,15 +97,14 @@ TEST(RamFileBlockCacheTest, PassThrough) {
  int calls = 0;
  auto fetcher = [&calls, want_filename, want_offset, want_n](
                     const string& got_filename, size_t got_offset,
                     size_t got_n, char* buffer, size_t* bytes_transferred,
                     TF_Status* status) {
                     size_t got_n, char* buffer, TF_Status* status) -> int64_t {
    EXPECT_EQ(got_filename, want_filename);
    EXPECT_EQ(got_offset, want_offset);
    EXPECT_EQ(got_n, want_n);
    calls++;
    memset(buffer, 'x', got_n);
    *bytes_transferred = got_n;
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return got_n;
  };
  // If block_size, max_bytes, or both are zero, or want_n is larger than
  // max_bytes, the cache is a pass-through.
@ -133,16 +133,17 @@ TEST(RamFileBlockCacheTest, BlockAlignment) {
  }
  // The fetcher just fetches slices of the buffer.
  auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
                        char* buffer, size_t* bytes_transferred,
                        TF_Status* status) {
                        char* buffer, TF_Status* status) -> int64_t {
    int64_t bytes_transferred;
    if (offset < buf.size()) {
      size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
      memcpy(buffer, buf.data() + offset, bytes_to_copy);
      *bytes_transferred = bytes_to_copy;
      bytes_transferred = bytes_to_copy;
    } else {
      *bytes_transferred = 0;
      bytes_transferred = 0;
    }
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return bytes_transferred;
  };
  for (size_t block_size = 2; block_size <= 4; block_size++) {
    // Make a cache of N-byte block size (1 block) and verify that reads of
@ -181,15 +182,14 @@ TEST(RamFileBlockCacheTest, CacheHits) {
  std::set<size_t> calls;
  auto fetcher = [&calls, block_size](const string& filename, size_t offset,
                                      size_t n, char* buffer,
                                      size_t* bytes_transferred,
                                      TF_Status* status) {
                                      TF_Status* status) -> int64_t {
    EXPECT_EQ(n, block_size);
    EXPECT_EQ(offset % block_size, 0);
    EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
    calls.insert(offset);
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return n;
  };
  const uint32 block_count = 256;
  tf_gcs_filesystem::RamFileBlockCache cache(
@ -215,8 +215,7 @@ TEST(RamFileBlockCacheTest, OutOfRange) {
  bool second_block = false;
  auto fetcher = [block_size, file_size, &first_block, &second_block](
                     const string& filename, size_t offset, size_t n,
                     char* buffer, size_t* bytes_transferred,
                     TF_Status* status) {
                     char* buffer, TF_Status* status) -> int64_t {
    EXPECT_EQ(n, block_size);
    EXPECT_EQ(offset % block_size, 0);
    size_t bytes_to_copy = 0;
@ -231,8 +230,8 @@ TEST(RamFileBlockCacheTest, OutOfRange) {
      memset(buffer, 'x', bytes_to_copy);
      second_block = true;
    }
    *bytes_transferred = bytes_to_copy;
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return bytes_to_copy;
  };
  tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
                                             fetcher);
@ -260,14 +259,13 @@ TEST(RamFileBlockCacheTest, Inconsistent) {
  const size_t block_size = 16;
  // This fetcher returns OK but only fills in one byte for any offset.
  auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
                              char* buffer, size_t* bytes_transferred,
                              TF_Status* status) {
                              char* buffer, TF_Status* status) -> int64_t {
    EXPECT_EQ(n, block_size);
    EXPECT_EQ(offset % block_size, 0);
    EXPECT_GE(n, 1);
    memset(buffer, 'x', 1);
    *bytes_transferred = 1;
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return 1;
  };
  tf_gcs_filesystem::RamFileBlockCache cache(block_size, 2 * block_size, 0,
                                             fetcher);
@ -286,8 +284,7 @@ TEST(RamFileBlockCacheTest, LRU) {
  std::list<size_t> calls;
  auto fetcher = [&calls, block_size](const string& filename, size_t offset,
                                      size_t n, char* buffer,
                                      size_t* bytes_transferred,
                                      TF_Status* status) {
                                      TF_Status* status) -> int64_t {
    EXPECT_EQ(n, block_size);
    EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
    if (!calls.empty()) {
@ -295,8 +292,8 @@ TEST(RamFileBlockCacheTest, LRU) {
      calls.pop_front();
    }
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return n;
  };
  const uint32 block_count = 2;
  tf_gcs_filesystem::RamFileBlockCache cache(
@ -335,12 +332,11 @@ TEST(RamFileBlockCacheTest, LRU) {
TEST(RamFileBlockCacheTest, MaxStaleness) {
  int calls = 0;
  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
                          char* buffer, size_t* bytes_transferred,
                          TF_Status* status) {
                          char* buffer, TF_Status* status) -> int64_t {
    calls++;
    memset(buffer, 'x', n);
    *bytes_transferred = n;
    return TF_SetStatus(status, TF_OK, "");
    TF_SetStatus(status, TF_OK, "");
    return n;
||||
};
|
||||
std::vector<char> out;
|
||||
std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
|
||||
@ -380,8 +376,7 @@ TEST(RamFileBlockCacheTest, MaxStaleness) {
|
||||
TEST(RamFileBlockCacheTest, RemoveFile) {
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
|
||||
char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
char* buffer, TF_Status* status) -> int64_t {
|
||||
calls++;
|
||||
char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
|
||||
if (offset > 0) {
|
||||
@ -389,8 +384,8 @@ TEST(RamFileBlockCacheTest, RemoveFile) {
|
||||
c = toupper(c);
|
||||
}
|
||||
memset(buffer, c, n);
|
||||
*bytes_transferred = n;
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return n;
|
||||
};
|
||||
// This cache has space for 4 blocks; we'll read from two files.
|
||||
const size_t n = 3;
|
||||
@ -443,12 +438,11 @@ TEST(RamFileBlockCacheTest, RemoveFile) {
|
||||
TEST(RamFileBlockCacheTest, Prune) {
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
|
||||
char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
char* buffer, TF_Status* status) -> int64_t {
|
||||
calls++;
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return n;
|
||||
};
|
||||
std::vector<char> out;
|
||||
// Our fake environment is initialized with the current timestamp.
|
||||
@ -509,17 +503,17 @@ TEST(RamFileBlockCacheTest, ParallelReads) {
|
||||
const int callers = 4;
|
||||
BlockingCounter counter(callers);
|
||||
auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
|
||||
char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
char* buffer, TF_Status* status) -> int64_t {
|
||||
counter.DecrementCount();
|
||||
if (!counter.WaitFor(std::chrono::seconds(10))) {
|
||||
// This avoids having the test time out, which is harder to debug.
|
||||
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"desired concurrency not reached");
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"desired concurrency not reached");
|
||||
return -1;
|
||||
}
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return n;
|
||||
};
|
||||
const int block_size = 8;
|
||||
tf_gcs_filesystem::RamFileBlockCache cache(
|
||||
@ -548,17 +542,16 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
|
||||
Notification notification;
|
||||
auto fetcher = [&num_requests, ¬ification, block_size](
|
||||
const string& filename, size_t offset, size_t n,
|
||||
char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
char* buffer, TF_Status* status) -> int64_t {
|
||||
EXPECT_EQ(n, block_size);
|
||||
EXPECT_EQ(offset, 0);
|
||||
num_requests++;
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
notification.Notify();
|
||||
// Wait for other thread to issue read.
|
||||
Env::Default()->SleepForMicroseconds(100000); // 0.1 secs
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return n;
|
||||
};
|
||||
tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
|
||||
fetcher);
|
||||
@ -580,12 +573,11 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
|
||||
TEST(RamFileBlockCacheTest, Flush) {
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
|
||||
char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
char* buffer, TF_Status* status) -> int64_t {
|
||||
calls++;
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return n;
|
||||
};
|
||||
tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
|
||||
std::vector<char> out;
|
||||
|
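For reference, the hunks above migrate every test fetcher from writing `*bytes_transferred` and returning the result of `TF_SetStatus` to returning the transferred byte count directly as `int64_t`, with -1 signalling failure. A minimal standalone sketch of a fetcher under the new contract, assuming the callback shape shown in these tests (the real typedef lives in the cache header; `DemoFetcher` and the fake file size are illustrative):

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>

#include "tensorflow/c/tf_status.h"

// Hypothetical fetcher: serves a fake 100-byte file filled with 'x'.
static int64_t DemoFetcher(const std::string& filename, size_t offset,
                           size_t n, char* buffer, TF_Status* status) {
  const size_t file_size = 100;
  if (offset >= file_size) {
    TF_SetStatus(status, TF_OK, "");
    return 0;  // EOF is a successful zero-byte read.
  }
  size_t bytes = std::min(n, file_size - offset);
  memset(buffer, 'x', bytes);
  TF_SetStatus(status, TF_OK, "");
  return static_cast<int64_t>(bytes);  // Byte count replaces the out-param.
}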
46
tensorflow/c/experimental/filesystem/plugins/s3/BUILD
Normal file
@ -0,0 +1,46 @@
# Experimental s3 filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")

package(
    licenses = ["notice"],  # Apache 2.0
)

# Filesystem implementation for S3 environments
tf_cc_shared_object(
    name = "s3_filesystem",
    framework_so = [],
    linkstatic = False,
    per_os_targets = 1,
    visibility = ["//visibility:public"],
    deps = [":s3_filesystem_impl"],
)

# The real implementation of the filesystem.
cc_library(
    name = "s3_filesystem_impl",
    srcs = ["s3_filesystem.cc"],
    hdrs = ["s3_filesystem.h"],
    copts = select({
        "//conditions:default": [],
        "//tensorflow:windows": get_win_copts(),
    }),
    deps = [
        ":aws_crypto",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "@aws",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
    ],
)

cc_library(
    name = "aws_crypto",
    srcs = ["aws_crypto.cc"],
    hdrs = ["aws_crypto.h"],
    deps = [
        "@aws",
        "@boringssl//:crypto",
    ],
    alwayslink = 1,
)
133
tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc
Normal file
@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"

#include <aws/core/utils/crypto/HashResult.h>
#include <aws/s3/S3Client.h>
#include <openssl/hmac.h>
#include <openssl/rand.h>
#include <openssl/sha.h>

namespace tf_s3_filesystem {

class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
 public:
  AWSSha256HMACOpenSSLImpl() {}

  virtual ~AWSSha256HMACOpenSSLImpl() = default;

  Aws::Utils::Crypto::HashResult Calculate(
      const Aws::Utils::ByteBuffer& toSign,
      const Aws::Utils::ByteBuffer& secret) override {
    unsigned int length = SHA256_DIGEST_LENGTH;
    Aws::Utils::ByteBuffer digest(length);
    memset(digest.GetUnderlyingData(), 0, length);

    HMAC_CTX ctx;
    HMAC_CTX_init(&ctx);

    HMAC_Init_ex(&ctx, secret.GetUnderlyingData(),
                 static_cast<int>(secret.GetLength()), EVP_sha256(), NULL);
    HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength());
    HMAC_Final(&ctx, digest.GetUnderlyingData(), &length);
    HMAC_CTX_cleanup(&ctx);

    return Aws::Utils::Crypto::HashResult(std::move(digest));
  }
};

class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
 public:
  AWSSha256OpenSSLImpl() {}

  virtual ~AWSSha256OpenSSLImpl() = default;

  Aws::Utils::Crypto::HashResult Calculate(const Aws::String& str) override {
    SHA256_CTX sha256;
    SHA256_Init(&sha256);
    SHA256_Update(&sha256, str.data(), str.size());

    Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
    SHA256_Final(hash.GetUnderlyingData(), &sha256);

    return Aws::Utils::Crypto::HashResult(std::move(hash));
  }

  Aws::Utils::Crypto::HashResult Calculate(Aws::IStream& stream) override {
    SHA256_CTX sha256;
    SHA256_Init(&sha256);

    auto currentPos = stream.tellg();
    if (currentPos == std::streampos(std::streamoff(-1))) {
      currentPos = 0;
      stream.clear();
    }

    stream.seekg(0, stream.beg);

    char streamBuffer
        [Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE];
    while (stream.good()) {
      stream.read(streamBuffer,
                  Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE);
      auto bytesRead = stream.gcount();

      if (bytesRead > 0) {
        SHA256_Update(&sha256, streamBuffer, static_cast<size_t>(bytesRead));
      }
    }

    stream.clear();
    stream.seekg(currentPos, stream.beg);

    Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
    SHA256_Final(hash.GetUnderlyingData(), &sha256);

    return Aws::Utils::Crypto::HashResult(std::move(hash));
  }
};

class AWSSecureRandomBytesImpl : public Aws::Utils::Crypto::SecureRandomBytes {
 public:
  AWSSecureRandomBytesImpl() {}
  virtual ~AWSSecureRandomBytesImpl() = default;
  void GetBytes(unsigned char* buffer, size_t bufferSize) override {
    assert(buffer);
    int success = RAND_bytes(buffer, static_cast<int>(bufferSize));
    if (success != 1) {
      m_failure = true;
    }
  }

 private:
  bool m_failure;
};

std::shared_ptr<Aws::Utils::Crypto::Hash>
AWSSHA256Factory::CreateImplementation() const {
  return Aws::MakeShared<AWSSha256OpenSSLImpl>(AWSCryptoAllocationTag);
}

std::shared_ptr<Aws::Utils::Crypto::HMAC>
AWSSHA256HmacFactory::CreateImplementation() const {
  return Aws::MakeShared<AWSSha256HMACOpenSSLImpl>(AWSCryptoAllocationTag);
}

std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes>
AWSSecureRandomFactory::CreateImplementation() const {
  return Aws::MakeShared<AWSSecureRandomBytesImpl>(AWSCryptoAllocationTag);
}

}  // namespace tf_s3_filesystem
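As a usage sketch (not part of the commit), hashing a string through the factory mirrors what the AWS SDK does once these factories are installed via `Aws::SDKOptions::cryptoOptions`; `DemoHash` and the literal input are illustrative:

#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"

void DemoHash() {
  tf_s3_filesystem::AWSSHA256Factory factory;
  auto sha256 = factory.CreateImplementation();
  auto result = sha256->Calculate("hello world");
  if (result.IsSuccess()) {
    // GetResult() yields the 32-byte SHA-256 digest as a ByteBuffer.
    Aws::Utils::ByteBuffer digest = result.GetResult();
    (void)digest;
  }
}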
47
tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h
Normal file
@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_

#include <aws/core/Aws.h>
#include <aws/core/utils/crypto/Factories.h>
#include <aws/core/utils/crypto/HMAC.h>
#include <aws/core/utils/crypto/Hash.h>
#include <aws/core/utils/crypto/SecureRandom.h>

namespace tf_s3_filesystem {
constexpr char AWSCryptoAllocationTag[] = "AWSCryptoAllocation";

class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory {
 public:
  std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
      const override;
};

class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
 public:
  std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
      const override;
};

class AWSSecureRandomFactory : public Aws::Utils::Crypto::SecureRandomFactory {
 public:
  std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes> CreateImplementation()
      const override;
};

}  // namespace tf_s3_filesystem

#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
981
tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc
Normal file
@ -0,0 +1,981 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h"

#include <aws/core/client/AsyncCallerContext.h>
#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/utils/FileSystemUtils.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/s3/model/AbortMultipartUploadRequest.h>
#include <aws/s3/model/CompleteMultipartUploadRequest.h>
#include <aws/s3/model/CompletedMultipartUpload.h>
#include <aws/s3/model/CompletedPart.h>
#include <aws/s3/model/CopyObjectRequest.h>
#include <aws/s3/model/CreateMultipartUploadRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
#include <aws/s3/model/ListObjectsRequest.h>
#include <aws/s3/model/UploadPartCopyRequest.h>
#include <stdlib.h>
#include <string.h>

#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for S3 environments.
// This filesystem will support `s3://` URI schemes.
constexpr char kS3FileSystemAllocationTag[] = "S3FileSystemAllocation";
constexpr char kS3ClientAllocationTag[] = "S3ClientAllocation";
constexpr int64_t kS3TimeoutMsec = 300000;  // 5 min

constexpr char kExecutorTag[] = "TransferManagerExecutorAllocation";
constexpr int kExecutorPoolSize = 25;

constexpr uint64_t kS3MultiPartUploadChunkSize = 50 * 1024 * 1024;    // 50 MB
constexpr uint64_t kS3MultiPartDownloadChunkSize = 50 * 1024 * 1024;  // 50 MB
constexpr size_t kDownloadRetries = 3;
constexpr size_t kUploadRetries = 3;

constexpr size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;  // 1 MB

static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

static inline void TF_SetStatusFromAWSError(
    const Aws::Client::AWSError<Aws::S3::S3Errors>& error, TF_Status* status) {
  switch (error.GetResponseCode()) {
    case Aws::Http::HttpResponseCode::FORBIDDEN:
      TF_SetStatus(status, TF_FAILED_PRECONDITION,
                   "AWS Credentials have not been set properly. "
                   "Unable to access the specified S3 location");
      break;
    case Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE:
      TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
      break;
    case Aws::Http::HttpResponseCode::NOT_FOUND:
      TF_SetStatus(status, TF_NOT_FOUND, error.GetMessage().c_str());
      break;
    default:
      TF_SetStatus(
          status, TF_UNKNOWN,
          (error.GetExceptionName() + ": " + error.GetMessage()).c_str());
      break;
  }
}

static void ParseS3Path(const Aws::String& fname, bool object_empty_ok,
                        Aws::String* bucket, Aws::String* object,
                        TF_Status* status) {
  size_t scheme_end = fname.find("://") + 2;
  if (fname.substr(0, scheme_end + 1) != "s3://") {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "S3 path doesn't start with 's3://'.");
    return;
  }

  size_t bucket_end = fname.find("/", scheme_end + 1);
  if (bucket_end == std::string::npos) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "S3 path doesn't contain a bucket name.");
    return;
  }

  *bucket = fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
  *object = fname.substr(bucket_end + 1);

  if (object->empty() && !object_empty_ok) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "S3 path doesn't contain an object name.");
  }
}

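To make the index arithmetic in ParseS3Path concrete, here is a hedged walkthrough with a placeholder URI (not code from the commit):

// For fname = "s3://my-bucket/dir/obj":
//   fname.find("://") == 2, so scheme_end == 4 and substr(0, 5) == "s3://".
//   bucket_end == fname.find("/", 5) == 14.
//   bucket == substr(5, 9) == "my-bucket"; object == substr(15) == "dir/obj".
Aws::String bucket, object;
TF_Status* status = TF_NewStatus();
ParseS3Path("s3://my-bucket/dir/obj", /*object_empty_ok=*/false, &bucket,
            &object, status);
// bucket == "my-bucket", object == "dir/obj", TF_GetCode(status) == TF_OK.
TF_DeleteStatus(status);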
static Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
  ABSL_CONST_INIT static absl::Mutex cfg_lock(absl::kConstInit);
  static bool init(false);
  static Aws::Client::ClientConfiguration cfg;

  absl::MutexLock l(&cfg_lock);

  if (!init) {
    const char* endpoint = getenv("S3_ENDPOINT");
    if (endpoint) cfg.endpointOverride = Aws::String(endpoint);
    const char* region = getenv("AWS_REGION");
    // TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
    if (!region) region = getenv("S3_REGION");
    if (region) {
      cfg.region = Aws::String(region);
    } else {
      // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
      // is set with a truthy value.
      const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
      std::string load_config =
          load_config_env ? absl::AsciiStrToLower(load_config_env) : "";
      if (load_config == "true" || load_config == "1") {
        Aws::String config_file;
        // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
        const char* config_file_env = getenv("AWS_CONFIG_FILE");
        if (config_file_env) {
          config_file = config_file_env;
        } else {
          const char* home_env = getenv("HOME");
          if (home_env) {
            config_file = home_env;
            config_file += "/.aws/config";
          }
        }
        Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
        loader.Load();
        auto profiles = loader.GetProfiles();
        if (!profiles["default"].GetRegion().empty())
          cfg.region = profiles["default"].GetRegion();
      }
    }
    const char* use_https = getenv("S3_USE_HTTPS");
    if (use_https) {
      if (use_https[0] == '0')
        cfg.scheme = Aws::Http::Scheme::HTTP;
      else
        cfg.scheme = Aws::Http::Scheme::HTTPS;
    }
    const char* verify_ssl = getenv("S3_VERIFY_SSL");
    if (verify_ssl) {
      if (verify_ssl[0] == '0')
        cfg.verifySSL = false;
      else
        cfg.verifySSL = true;
    }
    // if these timeouts are low, you may see an error when
    // uploading/downloading large files: Unable to connect to endpoint
    int64_t timeout;
    cfg.connectTimeoutMs =
        absl::SimpleAtoi(getenv("S3_CONNECT_TIMEOUT_MSEC"), &timeout)
            ? timeout
            : kS3TimeoutMsec;
    cfg.requestTimeoutMs =
        absl::SimpleAtoi(getenv("S3_REQUEST_TIMEOUT_MSEC"), &timeout)
            ? timeout
            : kS3TimeoutMsec;
    const char* ca_file = getenv("S3_CA_FILE");
    if (ca_file) cfg.caFile = Aws::String(ca_file);
    const char* ca_path = getenv("S3_CA_PATH");
    if (ca_path) cfg.caPath = Aws::String(ca_path);
    init = true;
  }
  return cfg;
};

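Since the configuration above is latched on first use (`init` flips to true once), the environment variables must be set before the first S3 operation. A hedged example of pointing the plugin at a local S3-compatible endpoint; the endpoint value and function name are placeholders:

#include <cstdlib>

void ConfigureS3ForLocalEndpoint() {
  setenv("S3_ENDPOINT", "localhost:4566", /*overwrite=*/1);
  setenv("S3_USE_HTTPS", "0", 1);   // a leading '0' selects plain HTTP
  setenv("S3_VERIFY_SSL", "0", 1);  // a leading '0' disables verification
  setenv("AWS_REGION", "us-east-1", 1);
  // When unset, both timeouts fall back to kS3TimeoutMsec (5 minutes).
  setenv("S3_CONNECT_TIMEOUT_MSEC", "60000", 1);
  setenv("S3_REQUEST_TIMEOUT_MSEC", "60000", 1);
}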
static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
  absl::MutexLock l(&s3_file->initialization_lock);

  if (s3_file->s3_client.get() == nullptr) {
    Aws::SDKOptions options;
    options.cryptoOptions.sha256Factory_create_fn = []() {
      return Aws::MakeShared<tf_s3_filesystem::AWSSHA256Factory>(
          tf_s3_filesystem::AWSCryptoAllocationTag);
    };
    options.cryptoOptions.sha256HMACFactory_create_fn = []() {
      return Aws::MakeShared<tf_s3_filesystem::AWSSHA256HmacFactory>(
          tf_s3_filesystem::AWSCryptoAllocationTag);
    };
    options.cryptoOptions.secureRandomFactory_create_fn = []() {
      return Aws::MakeShared<tf_s3_filesystem::AWSSecureRandomFactory>(
          tf_s3_filesystem::AWSCryptoAllocationTag);
    };
    Aws::InitAPI(options);

    // The creation of S3Client disables virtual addressing:
    //   S3Client(clientConfiguration, signPayloads, useVirtualAddressing =
    //   true)
    // The purpose is to address the issue encountered when there is an `.`
    // in the bucket name. Due to TLS hostname validation or DNS rules,
    // the bucket may not be resolved. Disabling of virtual addressing
    // should address the issue. See GitHub issue 16397 for details.
    s3_file->s3_client = Aws::MakeShared<Aws::S3::S3Client>(
        kS3ClientAllocationTag, GetDefaultClientConfig(),
        Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false);
  }
}

static void GetExecutor(tf_s3_filesystem::S3File* s3_file) {
  absl::MutexLock l(&s3_file->initialization_lock);

  if (s3_file->executor.get() == nullptr) {
    s3_file->executor =
        Aws::MakeShared<Aws::Utils::Threading::PooledThreadExecutor>(
            kExecutorTag, kExecutorPoolSize);
  }
}

static void GetTransferManager(
    const Aws::Transfer::TransferDirection& direction,
    tf_s3_filesystem::S3File* s3_file) {
  absl::MutexLock l(&s3_file->initialization_lock);

  if (s3_file->transfer_managers[direction].get() == nullptr) {
    GetS3Client(s3_file);
    GetExecutor(s3_file);
    Aws::Transfer::TransferManagerConfiguration config(s3_file->executor.get());
    config.s3Client = s3_file->s3_client;
    config.bufferSize = s3_file->multi_part_chunk_sizes[direction];
    // must be larger than pool size * multi part chunk size
    config.transferBufferMaxHeapSize =
        (kExecutorPoolSize + 1) * s3_file->multi_part_chunk_sizes[direction];
    s3_file->transfer_managers[direction] =
        Aws::Transfer::TransferManager::Create(config);
  }
}

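Working through the heap bound above with the defaults (a sketch of the arithmetic, not new code in the commit):

// kExecutorPoolSize = 25, default chunk size = 50 MB:
//   transferBufferMaxHeapSize = (25 + 1) * 50 MB = 1300 MB,
// i.e. one in-flight chunk per pooled thread plus one chunk of slack,
// which satisfies the "larger than pool size * chunk size" requirement.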
static void ShutdownClient(Aws::S3::S3Client* s3_client) {
  if (s3_client != nullptr) {
    delete s3_client;
    Aws::SDKOptions options;
    Aws::ShutdownAPI(options);
  }
}

// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct S3File {
  Aws::String bucket;
  Aws::String object;
  std::shared_ptr<Aws::S3::S3Client> s3_client;
  std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager;
  bool use_multi_part_download;
} S3File;

// AWS Streams destroy the buffer (buf) passed, so creating a new
// IOStream that retains the buffer so the calling function
// can control its lifecycle
class TFS3UnderlyingStream : public Aws::IOStream {
 public:
  using Base = Aws::IOStream;
  TFS3UnderlyingStream(std::streambuf* buf) : Base(buf) {}
  virtual ~TFS3UnderlyingStream() = default;
};

void Cleanup(TF_RandomAccessFile* file) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  delete s3_file;
}

static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
                            char* buffer, TF_Status* status) {
  Aws::S3::Model::GetObjectRequest get_object_request;
  get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object);
  Aws::String bytes =
      absl::StrCat("bytes=", offset, "-", offset + n - 1).c_str();
  get_object_request.SetRange(bytes);
  get_object_request.SetResponseStreamFactory(
      []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });

  auto get_object_outcome = s3_file->s3_client->GetObject(get_object_request);
  if (!get_object_outcome.IsSuccess())
    TF_SetStatusFromAWSError(get_object_outcome.GetError(), status);
  else
    TF_SetStatus(status, TF_OK, "");
  if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_OUT_OF_RANGE)
    return -1;

  int64_t read = get_object_outcome.GetResult().GetContentLength();
  if (read < n)
    TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
  get_object_outcome.GetResult().GetBody().read(buffer, read);
  return read;
}

static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
                                     char* buffer, TF_Status* status) {
  auto create_download_stream = [&]() {
    return Aws::New<TFS3UnderlyingStream>(
        "S3ReadStream",
        Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
            "S3ReadStream", reinterpret_cast<unsigned char*>(buffer), n));
  };
  auto handle = s3_file->transfer_manager->DownloadFile(
      s3_file->bucket, s3_file->object, offset, n, create_download_stream);
  handle->WaitUntilFinished();

  size_t retries = 0;
  while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
         handle->GetLastError().GetResponseCode() !=
             Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
         retries++ < kDownloadRetries) {
    // Only failed parts will be downloaded again.
    s3_file->transfer_manager->RetryDownload(handle);
    handle->WaitUntilFinished();
  }

  if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED)
    TF_SetStatusFromAWSError(handle->GetLastError(), status);
  else
    TF_SetStatus(status, TF_OK, "");
  if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_OUT_OF_RANGE)
    return -1;
  int64_t read = handle->GetBytesTransferred();
  if (read < n)
    TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
  return read;
}

int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
             char* buffer, TF_Status* status) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  if (s3_file->use_multi_part_download)
    return ReadS3TransferManager(s3_file, offset, n, buffer, status);
  else
    return ReadS3Client(s3_file, offset, n, buffer, status);
}

}  // namespace tf_random_access_file

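A hedged sketch of how a caller consumes Read's contract (the surrounding plumbing normally lives in the modular-filesystem layer; `file` and `status` are assumed to be initialized):

std::vector<char> buf(1 << 20);
int64_t read = tf_random_access_file::Read(file, /*offset=*/0, buf.size(),
                                           buf.data(), status);
if (TF_GetCode(status) == TF_OK || TF_GetCode(status) == TF_OUT_OF_RANGE) {
  buf.resize(read);  // On a short read, `read` still counts the valid bytes.
} else {
  // Genuine failure: read == -1 and `status` carries the AWS error.
}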
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct S3File {
  Aws::String bucket;
  Aws::String object;
  std::shared_ptr<Aws::S3::S3Client> s3_client;
  std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager;
  bool sync_needed;
  std::shared_ptr<Aws::Utils::TempFile> outfile;
  S3File(Aws::String bucket, Aws::String object,
         std::shared_ptr<Aws::S3::S3Client> s3_client,
         std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager)
      : bucket(bucket),
        object(object),
        s3_client(s3_client),
        transfer_manager(transfer_manager),
        outfile(Aws::MakeShared<Aws::Utils::TempFile>(
            kS3FileSystemAllocationTag, nullptr, "_s3_filesystem_XXXXXX",
            std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
                std::ios_base::out)) {}
} S3File;

void Cleanup(TF_WritableFile* file) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  delete s3_file;
}

void Append(const TF_WritableFile* file, const char* buffer, size_t n,
            TF_Status* status) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  if (!s3_file->outfile) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "The internal temporary file is not writable.");
    return;
  }
  s3_file->sync_needed = true;
  s3_file->outfile->write(buffer, n);
  if (!s3_file->outfile->good())
    TF_SetStatus(status, TF_INTERNAL,
                 "Could not append to the internal temporary file.");
  else
    TF_SetStatus(status, TF_OK, "");
}

int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  auto position = static_cast<int64_t>(s3_file->outfile->tellp());
  if (position == -1)
    TF_SetStatus(status, TF_INTERNAL,
                 "tellp on the internal temporary file failed");
  else
    TF_SetStatus(status, TF_OK, "");
  return position;
}

void Sync(const TF_WritableFile* file, TF_Status* status) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  if (!s3_file->outfile) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "The internal temporary file is not writable.");
    return;
  }
  if (!s3_file->sync_needed) {
    TF_SetStatus(status, TF_OK, "");
    return;
  }
  auto position = static_cast<int64_t>(s3_file->outfile->tellp());
  auto handle = s3_file->transfer_manager->UploadFile(
      s3_file->outfile, s3_file->bucket, s3_file->object,
      "application/octet-stream", Aws::Map<Aws::String, Aws::String>());
  handle->WaitUntilFinished();

  size_t retries = 0;
  while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
         retries++ < kUploadRetries) {
    // if multipart upload was used, only the failed parts will be re-sent
    s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
    handle->WaitUntilFinished();
  }
  if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED)
    return TF_SetStatusFromAWSError(handle->GetLastError(), status);
  s3_file->outfile->clear();
  s3_file->outfile->seekp(position);
  s3_file->sync_needed = false;
  TF_SetStatus(status, TF_OK, "");
}

void Flush(const TF_WritableFile* file, TF_Status* status) {
  Sync(file, status);
}

void Close(const TF_WritableFile* file, TF_Status* status) {
  auto s3_file = static_cast<S3File*>(file->plugin_file);
  if (s3_file->outfile) {
    Sync(file, status);
    if (TF_GetCode(status) != TF_OK) return;
    s3_file->outfile.reset();
  }
  TF_SetStatus(status, TF_OK, "");
}

}  // namespace tf_writable_file

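A hedged lifecycle sketch of the writable file above: appends land in the local TempFile, and only Sync, Flush, or Close pay for an upload (`file`, `data`, `n`, and `status` are assumed to be initialized):

tf_writable_file::Append(file, data, n, status);  // buffered locally
if (TF_GetCode(status) == TF_OK)
  tf_writable_file::Sync(file, status);           // TransferManager upload
if (TF_GetCode(status) == TF_OK)
  tf_writable_file::Close(file, status);          // final Sync, then reset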
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
typedef struct S3MemoryRegion {
  std::unique_ptr<char[]> data;
  uint64_t length;
} S3MemoryRegion;

void Cleanup(TF_ReadOnlyMemoryRegion* region) {
  auto r = static_cast<S3MemoryRegion*>(region->plugin_memory_region);
  delete r;
}

const void* Data(const TF_ReadOnlyMemoryRegion* region) {
  auto r = static_cast<S3MemoryRegion*>(region->plugin_memory_region);
  return reinterpret_cast<const void*>(r->data.get());
}

uint64_t Length(const TF_ReadOnlyMemoryRegion* region) {
  auto r = static_cast<S3MemoryRegion*>(region->plugin_memory_region);
  return r->length;
}

}  // namespace tf_read_only_memory_region

// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_s3_filesystem {
S3File::S3File()
    : s3_client(nullptr, ShutdownClient),
      executor(nullptr),
      transfer_managers(),
      multi_part_chunk_sizes(),
      use_multi_part_download(true),
      initialization_lock() {
  uint64_t temp_value;
  multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] =
      absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"), &temp_value)
          ? temp_value
          : kS3MultiPartUploadChunkSize;
  multi_part_chunk_sizes[Aws::Transfer::TransferDirection::DOWNLOAD] =
      absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"), &temp_value)
          ? temp_value
          : kS3MultiPartDownloadChunkSize;
  use_multi_part_download =
      absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value)
          ? (temp_value != 1)
          : use_multi_part_download;
  transfer_managers.emplace(Aws::Transfer::TransferDirection::UPLOAD, nullptr);
  transfer_managers.emplace(Aws::Transfer::TransferDirection::DOWNLOAD,
                            nullptr);
}
void Init(TF_Filesystem* filesystem, TF_Status* status) {
  filesystem->plugin_filesystem = new S3File();
  TF_SetStatus(status, TF_OK, "");
}

void Cleanup(TF_Filesystem* filesystem) {
  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  delete s3_file;
}

void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
                         TF_RandomAccessFile* file, TF_Status* status) {
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);
  GetTransferManager(Aws::Transfer::TransferDirection::DOWNLOAD, s3_file);
  file->plugin_file = new tf_random_access_file::S3File(
      {bucket, object, s3_file->s3_client,
       s3_file->transfer_managers[Aws::Transfer::TransferDirection::DOWNLOAD],
       s3_file->use_multi_part_download});
  TF_SetStatus(status, TF_OK, "");
}

void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                     TF_WritableFile* file, TF_Status* status) {
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);
  GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
  file->plugin_file = new tf_writable_file::S3File(
      bucket, object, s3_file->s3_client,
      s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]);
  TF_SetStatus(status, TF_OK, "");
}

void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
                       TF_WritableFile* file, TF_Status* status) {
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);
  GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);

  // We need to delete `file->plugin_file` in case of errors. We set
  // `file->plugin_file` to `nullptr` in order to avoid a segmentation fault
  // when calling the deleter of `unique_ptr`.
  file->plugin_file = nullptr;
  std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile*)> writer(
      file, [](TF_WritableFile* file) {
        if (file != nullptr && file->plugin_file != nullptr) {
          tf_writable_file::Cleanup(file);
        }
      });
  writer->plugin_file = new tf_writable_file::S3File(
      bucket, object, s3_file->s3_client,
      s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]);
  TF_SetStatus(status, TF_OK, "");

  // Wrapping inside a `std::unique_ptr` to prevent memory leaks.
  std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
      new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
        if (file != nullptr) {
          if (file->plugin_file != nullptr)
            tf_random_access_file::Cleanup(file);
          delete file;
        }
      });
  // We set `reader->plugin_file` to `nullptr` in order to avoid a
  // segmentation fault when calling the deleter of `unique_ptr`.
  reader->plugin_file = nullptr;
  NewRandomAccessFile(filesystem, path, reader.get(), status);
  if (TF_GetCode(status) != TF_OK) return;

  uint64_t offset = 0;
  std::string buffer(kS3ReadAppendableFileBufferSize, {});
  while (true) {
    auto read = tf_random_access_file::Read(reader.get(), offset,
                                            kS3ReadAppendableFileBufferSize,
                                            &buffer[0], status);
    if (TF_GetCode(status) == TF_NOT_FOUND) {
      break;
    } else if (TF_GetCode(status) == TF_OK) {
      offset += read;
      tf_writable_file::Append(file, buffer.c_str(), read, status);
      if (TF_GetCode(status) != TF_OK) return;
    } else if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
      offset += read;
      tf_writable_file::Append(file, buffer.c_str(), read, status);
      if (TF_GetCode(status) != TF_OK) return;
      break;
    } else {
      return;
    }
  }
  writer.release();
  TF_SetStatus(status, TF_OK, "");
}

void Stat(const TF_Filesystem* filesystem, const char* path,
          TF_FileStatistics* stats, TF_Status* status) {
  Aws::String bucket, object;
  ParseS3Path(path, true, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);

  if (object.empty()) {
    Aws::S3::Model::HeadBucketRequest head_bucket_request;
    head_bucket_request.WithBucket(bucket);
    auto head_bucket_outcome =
        s3_file->s3_client->HeadBucket(head_bucket_request);
    if (!head_bucket_outcome.IsSuccess())
      return TF_SetStatusFromAWSError(head_bucket_outcome.GetError(), status);
    stats->length = 0;
    stats->is_directory = 1;
    stats->mtime_nsec = 0;
    return TF_SetStatus(status, TF_OK, "");
  }

  bool found = false;
  Aws::S3::Model::HeadObjectRequest head_object_request;
  head_object_request.WithBucket(bucket).WithKey(object);
  head_object_request.SetResponseStreamFactory(
      []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
  auto head_object_outcome =
      s3_file->s3_client->HeadObject(head_object_request);
  if (head_object_outcome.IsSuccess()) {
    stats->length = head_object_outcome.GetResult().GetContentLength();
    stats->is_directory = 0;
    stats->mtime_nsec =
        head_object_outcome.GetResult().GetLastModified().Millis() * 1e6;
    found = true;
  } else {
    return TF_SetStatusFromAWSError(head_object_outcome.GetError(), status);
  }

  auto prefix = object;
  if (prefix.back() != '/') {
    prefix.push_back('/');
  }
  Aws::S3::Model::ListObjectsRequest list_objects_request;
  list_objects_request.WithBucket(bucket).WithPrefix(prefix).WithMaxKeys(1);
  list_objects_request.SetResponseStreamFactory(
      []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
  auto list_objects_outcome =
      s3_file->s3_client->ListObjects(list_objects_request);
  if (list_objects_outcome.IsSuccess()) {
    auto objects = list_objects_outcome.GetResult().GetContents();
    if (objects.size() > 0) {
      stats->length = 0;
      stats->is_directory = 1;
      stats->mtime_nsec = objects[0].GetLastModified().Millis() * 1e6;
      found = true;
    }
  } else {
    TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status);
    if (TF_GetCode(status) == TF_FAILED_PRECONDITION) return;
  }
  if (!found)
    return TF_SetStatus(
        status, TF_NOT_FOUND,
        absl::StrCat("Object ", path, " does not exist").c_str());
  TF_SetStatus(status, TF_OK, "");
}

void PathExists(const TF_Filesystem* filesystem, const char* path,
                TF_Status* status) {
  TF_FileStatistics stats;
  Stat(filesystem, path, &stats, status);
}

int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
                    TF_Status* status) {
  TF_FileStatistics stats;
  Stat(filesystem, path, &stats, status);
  return stats.length;
}

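A hedged sketch of the two Stat outcomes exercised above, with placeholder paths; `filesystem` is assumed to be an initialized TF_Filesystem:

TF_FileStatistics st;
TF_Status* status = TF_NewStatus();
// Bucket root ("s3://my-bucket"): HeadBucket path, reported as a directory.
tf_s3_filesystem::Stat(filesystem, "s3://my-bucket", &st, status);
// On success: st.is_directory == 1, st.length == 0, st.mtime_nsec == 0.
// Plain object: HeadObject path, length/mtime come from object metadata.
tf_s3_filesystem::Stat(filesystem, "s3://my-bucket/data.bin", &st, status);
// On success: st.is_directory == 0, st.length == content length,
// st.mtime_nsec == last-modified converted from ms to ns.
TF_DeleteStatus(status);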
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
                                     const char* path,
                                     TF_ReadOnlyMemoryRegion* region,
                                     TF_Status* status) {
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);
  GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);

  auto size = GetFileSize(filesystem, path, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (size == 0)
    return TF_SetStatus(status, TF_INVALID_ARGUMENT, "File is empty");

  std::unique_ptr<char[]> data(new char[size]);
  // Wrapping inside a `std::unique_ptr` to prevent memory leaks.
  std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
      new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
        if (file != nullptr) {
          if (file->plugin_file != nullptr)
            tf_random_access_file::Cleanup(file);
          delete file;
        }
      });
  // We set `reader->plugin_file` to `nullptr` in order to avoid a
  // segmentation fault when calling the deleter of `unique_ptr`.
  reader->plugin_file = nullptr;
  NewRandomAccessFile(filesystem, path, reader.get(), status);
  if (TF_GetCode(status) != TF_OK) return;
  auto read =
      tf_random_access_file::Read(reader.get(), 0, size, data.get(), status);
  if (TF_GetCode(status) != TF_OK) return;

  region->plugin_memory_region = new tf_read_only_memory_region::S3MemoryRegion(
      {std::move(data), static_cast<uint64_t>(read)});
  TF_SetStatus(status, TF_OK, "");
}

static void SimpleCopyFile(const Aws::String& source,
                           const Aws::String& bucket_dst,
                           const Aws::String& object_dst, S3File* s3_file,
                           TF_Status* status) {
  Aws::S3::Model::CopyObjectRequest copy_object_request;
  copy_object_request.WithCopySource(source)
      .WithBucket(bucket_dst)
      .WithKey(object_dst);
  auto copy_object_outcome =
      s3_file->s3_client->CopyObject(copy_object_request);
  if (!copy_object_outcome.IsSuccess())
    TF_SetStatusFromAWSError(copy_object_outcome.GetError(), status);
  else
    TF_SetStatus(status, TF_OK, "");
};

using EtagOutcome =
    Aws::Utils::Outcome<Aws::String, Aws::Client::AWSError<Aws::S3::S3Errors>>;
typedef struct MultipartCopyAsyncContext
    : public Aws::Client::AsyncCallerContext {
  int part_number;
  int* num_finished_parts;
  Aws::Vector<EtagOutcome>* etag_outcomes;

  // lock and cv for multi part copy
  absl::Mutex* multi_part_copy_mutex;
  absl::CondVar* multi_part_copy_cv;
} MultipartCopyAsyncContext;

static void AbortMultiPartCopy(const Aws::String& bucket_dst,
                               const Aws::String& object_dst,
                               const Aws::String& upload_id, S3File* s3_file,
                               TF_Status* status) {
  Aws::S3::Model::AbortMultipartUploadRequest request;
  request.WithBucket(bucket_dst).WithKey(object_dst).WithUploadId(upload_id);
  auto outcome = s3_file->s3_client->AbortMultipartUpload(request);
  if (!outcome.IsSuccess())
    TF_SetStatusFromAWSError(outcome.GetError(), status);
  else
    TF_SetStatus(status, TF_OK, "");
}

static void MultiPartCopyCallback(
    const Aws::S3::Model::UploadPartCopyRequest& request,
    const Aws::S3::Model::UploadPartCopyOutcome& outcome,
    const std::shared_ptr<const MultipartCopyAsyncContext>& context) {
  // Access to `etag_outcomes` should be thread-safe because of distinct
  // `part_number`.
  auto part_number = context->part_number;
  auto etag_outcomes = context->etag_outcomes;
  if (outcome.IsSuccess()) {
    (*etag_outcomes)[part_number] =
        outcome.GetResult().GetCopyPartResult().GetETag();
  } else {
    (*etag_outcomes)[part_number] = outcome.GetError();
  }
  {
    absl::MutexLock l(context->multi_part_copy_mutex);
    (*context->num_finished_parts)++;
    context->multi_part_copy_cv->Signal();
  }
}

static void MultiPartCopy(const Aws::String& source,
                          const Aws::String& bucket_dst,
                          const Aws::String& object_dst, const size_t num_parts,
                          const uint64_t file_size, S3File* s3_file,
                          TF_Status* status) {
  Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request;
  create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst);

  GetS3Client(s3_file);
  GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);

  auto create_multipart_upload_outcome =
      s3_file->s3_client->CreateMultipartUpload(
          create_multipart_upload_request);
  if (!create_multipart_upload_outcome.IsSuccess())
    return TF_SetStatusFromAWSError(create_multipart_upload_outcome.GetError(),
                                    status);

  auto upload_id = create_multipart_upload_outcome.GetResult().GetUploadId();

  int num_finished_parts = 0;
  // Keep track of `Outcome` of each upload part.
  Aws::Vector<EtagOutcome> etag_outcomes(num_parts);
  // Mutex which protects access to `num_finished_parts`.
  absl::Mutex multi_part_copy_mutex;
  // Condition variable to be used with above mutex for synchronization.
  absl::CondVar multi_part_copy_cv;

  auto chunk_size =
      s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];

  size_t retries = 0;
  while (retries++ < 3) {
    // Queue up parts.
    for (auto part_number = 0; part_number < num_parts; ++part_number) {
      if (etag_outcomes[part_number].IsSuccess()) continue;
      uint64_t start_pos = part_number * chunk_size;
      uint64_t end_pos = start_pos + chunk_size - 1;
      if (end_pos >= file_size) end_pos = file_size - 1;

      Aws::String range =
          absl::StrCat("bytes=", start_pos, "-", end_pos).c_str();
      Aws::S3::Model::UploadPartCopyRequest upload_part_copy_request;
      upload_part_copy_request.WithBucket(bucket_dst)
          .WithKey(object_dst)
          .WithCopySource(source)
          .WithCopySourceRange(range)
          // S3 API partNumber starts from 1.
          .WithPartNumber(part_number + 1)
          .WithUploadId(upload_id);

      auto multi_part_context =
          Aws::MakeShared<MultipartCopyAsyncContext>("MultiPartCopyContext");
      multi_part_context->part_number = part_number;
      multi_part_context->num_finished_parts = &num_finished_parts;
      multi_part_context->etag_outcomes = &etag_outcomes;
      multi_part_context->multi_part_copy_mutex = &multi_part_copy_mutex;
      multi_part_context->multi_part_copy_cv = &multi_part_copy_cv;
      auto callback =
          [](const Aws::S3::S3Client* client,
             const Aws::S3::Model::UploadPartCopyRequest& request,
             const Aws::S3::Model::UploadPartCopyOutcome& outcome,
             const std::shared_ptr<const Aws::Client::AsyncCallerContext>&
                 context) {
            auto multipart_context =
                std::static_pointer_cast<const MultipartCopyAsyncContext>(
                    context);
            MultiPartCopyCallback(request, outcome, multipart_context);
          };

      std::shared_ptr<const Aws::Client::AsyncCallerContext> context =
          multi_part_context;
      s3_file->s3_client->UploadPartCopyAsync(upload_part_copy_request,
                                              callback, context);
    }
    // Wait till they finish.
    {
      absl::MutexLock l(&multi_part_copy_mutex);
      // Wait on the mutex until notify is called then check the finished parts
      // as there could be false notifications.
      while (num_finished_parts != num_parts) {
        multi_part_copy_cv.Wait(&multi_part_copy_mutex);
      }
    }
    // check if there was any error for any part.
    for (auto part_number = 0; part_number < num_parts; ++part_number) {
      if (!etag_outcomes[part_number].IsSuccess()) {
        if (retries >= 3) {
          AbortMultiPartCopy(bucket_dst, object_dst, upload_id, s3_file,
                             status);
          if (TF_GetCode(status) != TF_OK) return;
          return TF_SetStatusFromAWSError(etag_outcomes[part_number].GetError(),
                                          status);
        } else {
          // Retry.
          num_finished_parts--;
        }
      }
    }
  }

  Aws::S3::Model::CompletedMultipartUpload completed_multipart_upload;
  // If there was an error still in any part, it would abort and return in the
  // above loop. We set the eTag of completed parts to the final
  // `completed_multipart_upload`. Note these parts have to be added in order.
  for (int part_number = 0; part_number < num_parts; ++part_number) {
    Aws::S3::Model::CompletedPart completed_part;
    completed_part.SetPartNumber(part_number + 1);
    completed_part.SetETag(etag_outcomes[part_number].GetResult());
    completed_multipart_upload.AddParts(completed_part);
  }

  Aws::S3::Model::CompleteMultipartUploadRequest
      complete_multipart_upload_request;
  complete_multipart_upload_request.WithBucket(bucket_dst)
      .WithKey(object_dst)
      .WithUploadId(upload_id)
      .WithMultipartUpload(completed_multipart_upload);
  auto complete_multipart_upload_outcome =
      s3_file->s3_client->CompleteMultipartUpload(
          complete_multipart_upload_request);
  if (!complete_multipart_upload_outcome.IsSuccess())
    AbortMultiPartCopy(bucket_dst, object_dst, upload_id, s3_file, status);
  else
    return TF_SetStatus(status, TF_OK, "");
  if (TF_GetCode(status) == TF_OK)
    return TF_SetStatusFromAWSError(
        complete_multipart_upload_outcome.GetError(), status);
};

void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
              TF_Status* status) {
  auto file_size = GetFileSize(filesystem, src, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (file_size == 0)
    return TF_SetStatus(status, TF_FAILED_PRECONDITION,
                        "Source is a directory or empty file");

  Aws::String bucket_src, object_src;
  ParseS3Path(src, false, &bucket_src, &object_src, status);
  if (TF_GetCode(status) != TF_OK) return;
  Aws::String copy_src = bucket_src + "/" + object_src;

  Aws::String bucket_dst, object_dst;
  ParseS3Path(dst, false, &bucket_dst, &object_dst, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  auto chunk_size =
      s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];
  size_t num_parts = 1;
  if (file_size > chunk_size) num_parts = ceil((float)file_size / chunk_size);
  if (num_parts == 1)
    SimpleCopyFile(copy_src, bucket_dst, object_dst, s3_file, status);
  else if (num_parts > 10000)
    TF_SetStatus(
        status, TF_UNIMPLEMENTED,
        absl::StrCat("MultiPartCopy with number of parts more than 10000 is "
                     "not supported. Your object ",
                     src, " required ", num_parts,
                     " as multi_part_copy_part_size is set to ", chunk_size,
                     ". You can control this part size using the environment "
                     "variable S3_MULTI_PART_COPY_PART_SIZE to increase it.")
            .c_str());
  else
    MultiPartCopy(copy_src, bucket_dst, object_dst, num_parts, file_size,
                  s3_file, status);
}

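To make CopyFile's chunking arithmetic concrete, a worked example with the 50 MB default (not code from the commit):

// file_size = 120 MB, chunk_size = kS3MultiPartUploadChunkSize = 50 MB:
//   num_parts = ceil(120.0 / 50.0) = 3
//   part 0 copies bytes [0, 50 MB), part 1 [50 MB, 100 MB),
//   part 2 the remaining 20 MB (end_pos is clamped to file_size - 1).
// The 10000-part ceiling mirrors S3's multipart-upload limit, so the
// default chunk size caps a single copy at roughly 500 GB.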
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_s3_filesystem
|
||||
|
||||
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
const char* uri) {
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
}
|
||||
|
||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||
info->plugin_memory_allocate = plugin_memory_allocate;
|
||||
info->plugin_memory_free = plugin_memory_free;
|
||||
info->num_schemes = 1;
|
||||
info->ops = static_cast<TF_FilesystemPluginOps*>(
|
||||
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
|
||||
ProvideFilesystemSupportFor(&info->ops[0], "s3");
|
||||
}
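A minimal hedged sketch of how a host process might consume the entry point above (LoadS3PluginSketch is hypothetical; TF_FilesystemPluginInfo and TF_InitPlugin come from the plugin interface itself):

#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"

void LoadS3PluginSketch() {
  TF_FilesystemPluginInfo info = {};
  TF_InitPlugin(&info);  // The plugin fills in memory hooks and per-scheme ops.
  for (int i = 0; i < info.num_schemes; ++i) {
    // info.ops[i].scheme is "s3" here; a real loader would validate the
    // version metadata set by TF_SetFilesystemVersionMetadata and register
    // the ops table with the core filesystem registry.
  }
  // info.ops was allocated with info.plugin_memory_allocate, so a real host
  // must eventually release it with info.plugin_memory_free.
}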
47
tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h
Normal file
@@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_

#include <aws/core/Aws.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/memory/stl/AWSMap.h>
#include <aws/core/utils/threading/Executor.h>
#include <aws/s3/S3Client.h>
#include <aws/transfer/TransferManager.h>

#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"

namespace tf_s3_filesystem {
typedef struct S3File {
  std::shared_ptr<Aws::S3::S3Client> s3_client;
  std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> executor;
  // We need two `TransferManager`s, one each for multipart upload and
  // download.
  Aws::Map<Aws::Transfer::TransferDirection,
           std::shared_ptr<Aws::Transfer::TransferManager>>
      transfer_managers;
  // Sizes to split objects during multipart upload/download.
  Aws::Map<Aws::Transfer::TransferDirection, uint64_t> multi_part_chunk_sizes;
  bool use_multi_part_download;
  absl::Mutex initialization_lock;
  S3File();
} S3File;
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
}  // namespace tf_s3_filesystem

#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_
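The `initialization_lock` member above suggests the client is created lazily on first use; a hedged sketch of that pattern (GetS3ClientSketch is hypothetical and the real constructor arguments are omitted):

#include <memory>

std::shared_ptr<Aws::S3::S3Client> GetS3ClientSketch(
    tf_s3_filesystem::S3File* s3_file) {
  // Guard first-touch initialization so concurrent filesystem calls do not
  // race to construct the client.
  absl::MutexLock lock(&s3_file->initialization_lock);
  if (!s3_file->s3_client) {
    s3_file->s3_client = std::make_shared<Aws::S3::S3Client>();
  }
  return s3_file->s3_client;
}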
23
tensorflow/c/experimental/gradients/BUILD
Normal file
@@ -0,0 +1,23 @@
# Library of gradient functions.
package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "math_grad",
    srcs = ["math_grad.cc"],
    hdrs = [
        "math_grad.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/c/eager:gradients",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/core/lib/llvm_rtti",
    ],
)
53
tensorflow/c/experimental/gradients/math_grad.cc
Normal file
@@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"

#include "tensorflow/c/experimental/ops/array_ops.h"

using tensorflow::ops::Identity;

namespace tensorflow {
namespace gradients {
namespace {

class AddGradientFunction : public GradientFunction {
 public:
  Status Compute(Context* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_inputs,
                 std::vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(2);
    std::vector<AbstractTensorHandle*> identity_outputs(1);
    // TODO(b/145674566): Handle name unification in tracing code.
    // TODO(b/161805092): Support broadcasting.
    TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
                                     absl::MakeSpan(identity_outputs),
                                     "Identity0"));
    (*grad_outputs)[0] = identity_outputs[0];
    TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
                                     absl::MakeSpan(identity_outputs),
                                     "Identity1"));
    (*grad_outputs)[1] = identity_outputs[0];
    return Status::OK();
  }
  ~AddGradientFunction() override {}
};

}  // namespace

GradientFunction* AddRegisterer(const ForwardOperation& op) {
  return new AddGradientFunction;
}
}  // namespace gradients
}  // namespace tensorflow
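For reference, the math behind AddGradientFunction: with z = x + y, both partial derivatives are 1, so each input receives the upstream gradient unchanged, which is why Compute emits two Identity ops on grad_inputs[0]:

z = x + y \implies \frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} = 1,
\qquad \nabla_x L = \nabla_y L = \nabla_z L .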
26
tensorflow/c/experimental/gradients/math_grad.h
Normal file
@@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
24
tensorflow/c/experimental/ops/BUILD
Normal file
@@ -0,0 +1,24 @@
# Experimental ops. These will eventually be replaced by machine-generated versions.
package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "array_ops",
    srcs = [
        "array_ops.cc",
    ],
    hdrs = [
        "array_ops.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/core/platform:errors",
    ],
)
38
tensorflow/c/experimental/ops/array_ops.cc
Normal file
@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h"

#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace ops {
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr identity_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
  if (isa<tensorflow::tracing::TracingOperation>(identity_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
                           ->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
  int num_retvals = 1;
  return identity_op->Execute(outputs, &num_retvals);
}

}  // namespace ops
}  // namespace tensorflow
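A small usage sketch for ops::Identity, mirroring the call pattern in math_grad.cc above (CallIdentitySketch is a hypothetical caller; the caller owns and sizes the output span):

Status CallIdentitySketch(AbstractContext* ctx, AbstractTensorHandle* x) {
  std::vector<AbstractTensorHandle*> outputs(1);
  TF_RETURN_IF_ERROR(
      ops::Identity(ctx, {x}, absl::MakeSpan(outputs), "IdentitySketch"));
  // outputs[0] now holds the (possibly traced) identity of x.
  return Status::OK();
}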
31
tensorflow/c/experimental/ops/array_ops.h
Normal file
@@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_

#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

namespace tensorflow {
namespace ops {
Status Identity(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name);
}  // namespace ops
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
@@ -113,8 +113,23 @@ cc_library(
    deps = [
        ":concrete_function",
        ":saved_model_api",
        ":saved_model_utils",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
        "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
        "//tensorflow/cc/saved_model:bundle_v2",
        "//tensorflow/cc/saved_model:constants",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

@@ -15,47 +15,360 @@ limitations under the License.

#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"

namespace tensorflow {

// Maps from a FunctionDef's name to the FunctionDef, for a given
// FunctionDefLibrary.
using FunctionDefMap =
    std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
                       StringPieceHasher>;

// Maps from a NodeDef's name to its corresponding AttrValues, for a given
// GraphDef.
using NodeAttrMap =
    std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>;

// Maps from a node ID to a "revived object" implementing
// TensorHandleConvertible.
using RevivedObjectMap =
    std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;

// Maps from a FunctionDef's name to the corresponding TFConcreteFunction.
using ConcreteFunctionMap =
    std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;

namespace {

Status ConstantFromSavedConstant(
    ImmediateExecutionContext* ctx,
    const tensorflow::SavedConstant& saved_constant,
    const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
  const std::string& const_op_name = saved_constant.operation();
  const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
  if (node_name_and_attrs == node_attr_map.end()) {
    return errors::FailedPrecondition(
        "Unable to find Const operation with name '", const_op_name,
        "' in SavedModel graphdef");
  }
  const AttrValueMap* attrs = node_name_and_attrs->second;
  const auto& attr_name_and_value = attrs->find("value");
  if (attr_name_and_value == attrs->end()) {
    return errors::FailedPrecondition("Unable to find Const operation '",
                                      const_op_name, "'s value attribute");
  }
  const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
  return internal::TensorProtoToConstant(ctx, tensor_proto, output);
}
// Restores all non-function objects in the SavedModel's object graph.
// This function walks through the metagraph's saved object graph, and
// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and
// SavedResource. These are returned via the `revived_objects` parameter.
Status ReviveObjects(
    const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
    std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
        revived_objects) {
  // This is needed to restore "Constant" nodes by looking up their
  // "value" attribute.
  NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());

  // Iterate through all the saved objects, restoring objects as we go.
  // We don't recreate functions until all other objects have been created.
  for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
    const SavedObject& node = metagraph.object_graph_def().nodes(i);
    if (node.kind_case() == SavedObject::kVariable) {
      std::unique_ptr<Variable> variable;
      TF_RETURN_IF_ERROR(
          internal::LoadSavedVariable(context, node.variable(), &variable));
      (*revived_objects)[i] = std::move(variable);
    } else if (node.kind_case() == SavedObject::kConstant) {
      std::unique_ptr<Constant> constant;
      TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
                                                   node_attr_map, &constant));
      (*revived_objects)[i] = std::move(constant);
    } else if (node.kind_case() == SavedObject::kAsset) {
      // TODO(bmzhao): Implement Asset C++ class. This should be just recreating
      // the full path to the asset file:
      // https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
      // and storing it as a string tensor:
      // https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
      return errors::Unimplemented("SavedAsset loading is not implemented yet");
    } else if (node.kind_case() == SavedObject::kResource) {
      // TODO(bmzhao): Figure out how resource loading works and implement it.
      return errors::Unimplemented(
          "SavedResource loading is not implemented yet");
    }
  }
  return Status();
}

Status ReviveFunctions(const MetaGraphDef& metagraph,
                       const RevivedObjectMap& revived_objects,
                       ImmediateExecutionContext* context,
                       ConcreteFunctionMap* restored_functions) {
  const FunctionDefMap function_def_map =
      internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());

  // Iterate through all objects, only examining functions.
  for (const SavedObject& node : metagraph.object_graph_def().nodes()) {
    if (node.kind_case() == SavedObject::kBareConcreteFunction) {
      const std::string& function_name =
          node.bare_concrete_function().concrete_function_name();

      const SavedConcreteFunction& saved_concrete_function =
          metagraph.object_graph_def().concrete_functions().at(function_name);

      const FunctionDef* function_def = function_def_map.at(function_name);
      std::unique_ptr<TFConcreteFunction> concrete_function;
      TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
          saved_concrete_function, function_def, revived_objects, context,
          &concrete_function));
      (*restored_functions)[function_name] = std::move(concrete_function);
    } else if (node.kind_case() == SavedObject::kFunction) {
      // We only allow loading functions that have an annotated input signature,
      // which means there is a 1:1 correspondence between tf.function
      // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
      // the same restriction that MLIR has:
      // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
      const SavedFunction& saved_function = node.function();
      if (saved_function.concrete_functions_size() != 1) {
        return errors::FailedPrecondition(
            "Only tf.functions annotated with an input signature are supported "
            "by SavedModelAPI. This means that there should only be a single "
            "ConcreteFunction per tf.function");
      }
      const std::string& function_name = saved_function.concrete_functions(0);
      const SavedConcreteFunction& saved_concrete_function =
          metagraph.object_graph_def().concrete_functions().at(function_name);

      const FunctionDef* function_def = function_def_map.at(function_name);

      std::unique_ptr<TFConcreteFunction> concrete_function;
      TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
          saved_concrete_function, function_def, revived_objects, context,
          &concrete_function));
      (*restored_functions)[function_name] = std::move(concrete_function);
    }
  }
  return Status();
}

const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(
    const TrackableObjectGraph::TrackableObject& trackable_object,
    absl::string_view name) {
  for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
    if (maybe_serialized_tensor.name() == name) {
      return &maybe_serialized_tensor;
    }
  }
  return nullptr;
}

// This function reads the checkpoint embedded in the SavedModel and calls the
// appropriate Restore ops on each of the variables.
// Note(bmzhao): Conceptually, objects that contain checkpointable state
// implement the "_gather_saveables_for_checkpoint" method
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
// which returns a dict of string key -> EITHER:
// 1. a python callable (taking a checkpoint key) returning SaveableObject, OR
// 2. a variable (partitioned/resource/reference or otherwise)
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
// The string key becomes the "name" attribute of the SerializedTensor proto
// in the TrackableObjectGraph,
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
// and the checkpoint_key is a globally unique string derived from this name:
// https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
// SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
// ops via their SaveSpec members
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
// which contain the "real" checkpoint keys into the TensorBundle SSTable.
// They also contain the logic needed to take the restored tensors from
// RestoreV2 and load them back into the "object" they came from via their
// overridden "restore" method:
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
                         const RevivedObjectMap& revived_objects,
                         const std::string& directory,
                         ImmediateExecutionContext* context) {
  // TODO(bmzhao): Batch up all the restores into a single restore op per
  // device, following logic in MultiDeviceSaver.
  TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
      [&revived_objects, &directory, context, bundle](
          int node, const TrackableObjectGraph::TrackableObject& trackable) {
        if (bundle->saved_object_graph().nodes(node).kind_case() !=
            SavedObject::kVariable) {
          // TODO(bmzhao): This requires using the newly added Save/Restore
          // functions from
          // https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
          return errors::Unimplemented(
              "Restoring non-variable objects has not been implemented yet.");
        }

        Variable* variable =
            down_cast<Variable*>(revived_objects.at(node).get());

        // Restore the tensor's value from the checkpoint.
        const TrackableObjectGraph::TrackableObject::SerializedTensor*
            attribute =
                FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
        if (attribute == nullptr) {
          return errors::FailedPrecondition(
              "Could not find SerializedTensor with name VARIABLE_VALUE for "
              "saved variable");
        }

        const std::string& checkpoint_key = attribute->checkpoint_key();
        std::string variables_path_prefix =
            io::JoinPath(directory, kSavedModelVariablesDirectory,
                         kSavedModelVariablesFilename);
        ImmediateTensorHandlePtr restored_output;
        TF_RETURN_IF_ERROR(internal::SingleRestore(
            context, variables_path_prefix, checkpoint_key, variable->dtype(),
            &restored_output));

        // Assign the restored tensor's value to the variable.
        return variable->Assign(restored_output.get());
      }));

  return Status();
}
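To make the checkpoint-key plumbing above concrete, a hedged sketch assuming keys follow the trackable naming scheme referenced in the comment (the example key is illustrative, not taken from this commit):

// A typical checkpoint_key for a variable attribute looks like
//   "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"
// and the restore prefix built above resolves to
//   "<directory>/variables/variables",
// since kSavedModelVariablesDirectory and kSavedModelVariablesFilename are
// both "variables".
std::string VariablesPrefixSketch(const std::string& directory) {
  return io::JoinPath(directory, kSavedModelVariablesDirectory,
                      kSavedModelVariablesFilename);
}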
}  // namespace

Status TFSavedModelAPI::GetFunction(const std::string& function_path,
                                    ConcreteFunction** function) {
  // TODO(bmzhao): Add support for retrieving a function.
  return errors::Unimplemented(
      "Retrieving functions is unimplemented currently");
  const SavedObject* object =
      internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
  if (object == nullptr) {
    return errors::NotFound("No saved object found at path ", function_path);
  }

  if (object->kind_case() == SavedObject::kBareConcreteFunction) {
    *function =
        concrete_functions_
            .at(object->bare_concrete_function().concrete_function_name())
            .get();
  } else if (object->kind_case() == SavedObject::kFunction) {
    *function =
        concrete_functions_.at(object->function().concrete_functions(0)).get();
  } else {
    return errors::InvalidArgument(function_path,
                                   " is not a path to a Function.");
  }

  return Status();
}

Status TFSavedModelAPI::GetSignatureDefFunction(
    const std::string& signature_def_key, ConcreteFunction** function) {
  // TODO(bmzhao): Add support for retrieving a SignatureDef function.
  return errors::Unimplemented(
      "Retrieving functions is unimplemented currently");
      "Retrieving SignatureDef functions is unimplemented currently");
}

std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
  std::vector<ConcreteFunction*> result;
  result.reserve(functions_.size());
  for (ConcreteFunction& function : functions_) {
    result.push_back(&function);
  result.reserve(concrete_functions_.size());
  for (auto& index_and_function : concrete_functions_) {
    result.push_back(index_and_function.second.get());
  }
  return result;
}

TFSavedModelAPI::TFSavedModelAPI(
    const std::string& directory, SavedModelV2Bundle bundle,
    std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
        revived_objects,
    std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
        concrete_functions)
    : directory_(directory),
      bundle_(std::move(bundle)),
      revived_objects_(std::move(revived_objects)),
      concrete_functions_(std::move(concrete_functions)) {}

Status TFSavedModelAPI::Load(
    const std::string& directory,
    const absl::optional<std::unordered_set<std::string>>& tags,
    ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
  // TODO(bmzhao): Add support for loading a TFSavedModelImpl.
  return errors::Unimplemented(
      "TFSavedModelAPIImpl loading is unimplemented currently");
  // TODO(bmzhao): Add support for loading a TF1 SavedModel.
  if (tags) {
    return errors::Unimplemented(
        "Loading saved models with explicit tags will be supported in the "
        "future");
  }

  SavedModelV2Bundle bundle;
  TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));

  // TODO(bmzhao): Mangle loaded function names so that different
  // models loaded in the same runtime Context don't clobber each other.
  // This occurs in python here:
  // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454

  RevivedObjectMap revived_objects;
  TF_RETURN_IF_ERROR(
      ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));

  // TODO(bmzhao): When we later add support for loading resources, we need to
  // handle the case where materializing a function's captures requires
  // invoking other functions. This occurs when retrieving the resource handle
  // for a TrackableResource:
  // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
  // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
  // This requires restoring functions in a topological sort order by capture
  // dependencies.
  ConcreteFunctionMap function_map;
  TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects,
                                     context, &function_map));

  TF_RETURN_IF_ERROR(
      RestoreCheckpoint(&bundle, revived_objects, directory, context));

  out->reset(new TFSavedModelAPI(directory, std::move(bundle),
                                 std::move(revived_objects),
                                 std::move(function_map)));
  return Status();
}

}  // namespace tensorflow

@@ -16,14 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
@@ -63,8 +68,19 @@ class TFSavedModelAPI : public SavedModelAPI {
  ~TFSavedModelAPI() override = default;

 private:
  TFSavedModelAPI() = default;
  std::vector<ConcreteFunction> functions_;
  TFSavedModelAPI(
      const std::string& directory, SavedModelV2Bundle bundle,
      std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
          revived_objects,
      std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
          concrete_functions);

  std::string directory_;
  SavedModelV2Bundle bundle_;
  std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
      revived_objects_;
  std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
      concrete_functions_;
};

}  // namespace tensorflow

@@ -16,10 +16,15 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"

#include <string>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
@@ -92,12 +97,51 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
  // That unblocks writing other tests that require a TF_SavedModel*,
  // like loading a ConcreteFunction. This test at least checks that the
  // C API builds and can be minimally run.
  EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_ConcreteFunction* compute_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(compute_fn, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  const TF_TensorHandleList* captures =
      TF_ConcreteFunctionGetCaptures(compute_fn);

  // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
  // inputs + outputs a function has.
  std::vector<TFE_TensorHandle*> compute_fn_inputs;
  TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
  TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
  compute_fn_inputs.reserve(2 + TF_TensorHandleListSize(captures));
  compute_fn_inputs.push_back(input_a);
  compute_fn_inputs.push_back(input_b);
  for (int i = 0; i < TF_TensorHandleListSize(captures); ++i) {
    compute_fn_inputs.push_back(TF_TensorHandleListGet(captures, i));
  }
  TFE_OpAddInputList(compute_fn_op, compute_fn_inputs.data(),
                     compute_fn_inputs.size(), status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_TensorHandle* compute_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

  TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  EXPECT_EQ(TF_NumDims(result), 0);
  float output_value = *static_cast<float*>(TF_TensorData(result));
  // (1 + 2) * (2 + 1) / 3 + 5 should be 8.
  EXPECT_FLOAT_EQ(output_value, 8.0);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(compute_fn_outputs[0]);
  TFE_DeleteTensorHandle(input_a);
  TFE_DeleteTensorHandle(input_b);
  TFE_DeleteOp(compute_fn_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);

@@ -86,11 +86,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
  std::unique_ptr<SavedModelAPI> model =
      SavedModelAPI::Load(model_dir, *runtime, &status);

  // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
  // That unblocks writing other tests that require a TF_SavedModel*,
  // like loading a ConcreteFunction. This test at least checks that the
  // C API builds and can be minimally run.
  EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message();
  EXPECT_EQ(status.code(), TF_OK) << status.message();
}

INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,

@@ -648,11 +648,11 @@ cc_library(
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_bounds_check",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:bounds_check",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base",
@@ -677,11 +677,11 @@ cc_library(
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_bounds_check",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:bounds_check",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",

@@ -277,7 +277,8 @@ static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info,
    absl::Span<VariableInfo const> variable_infos,
    absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
    absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
    xla::LocalClient** client,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
@@ -332,6 +333,9 @@ static Status CompileToLocalExecutable(
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update = !has_ref_vars &&
                                          !platform_info.is_on_xla_device() &&
                                          may_alias_resource_update;
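  // Why aliasing matters here (hedged reading of the condition above):
  // aliasing a resource update lets XLA write the updated variable value into
  // the buffer that held its input, saving a copy. It is only attempted when
  // there are no reference variables, the computation is not placed on an XLA
  // device, and the caller allows it; XlaCompileOp passes
  // may_alias_resource_update=false because locking variables in XlaCompile
  // and unlocking them in XlaRun could otherwise deadlock (see that comment
  // later in this diff).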

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
@@ -350,20 +354,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

  ResourceVarsSnapshot variables_snapshot;
  std::vector<VariableInfo> variable_infos;
  {
    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
    Status s = CompileToLocalExecutable(
        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
        variable_infos, constants_, /*lazy=*/false, &client,
        &compilation_result, &executable);
        variable_infos, constants_, /*lazy=*/false,
        /*may_alias_resource_update=*/true, &client, &compilation_result,
        &executable);
    OP_REQUIRES_OK(ctx, s);
    OP_REQUIRES_OK(ctx,
                   SnapshotResourceVariables(ctx, resources_, variable_infos,
                                             &variables_snapshot));
  }

  std::map<int, const Tensor*> resource_var_ptrs;
  for (int i = 0; i < resources_.size(); i++) {
    resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
  }

  se::Stream* stream =
@@ -374,12 +380,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::DeviceMemoryAllocator* allocator =
      GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : client->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      client, allocator,
      client, allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  launch_context.PopulateInputs(ctx, compilation_result, variables_snapshot,
                                /*missing_ctx_input_prefix=*/0);
  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
      launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
                                    /*missing_ctx_input_prefix=*/0,
                                    input_output_alias);
  OP_REQUIRES_OK(ctx, execution_inputs.status());

  // Execute the computation.
  VLOG(2) << "Executing computation.";
@@ -403,24 +416,24 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
  xla::StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    run_result = executable->Run(launch_context.arguments(), run_options);
    execution_output =
        executable->Run(std::move(*execution_inputs), run_options);
  } else {
    run_result = executable->RunAsync(launch_context.arguments(), run_options);
    execution_output =
        executable->RunAsync(std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";
  OP_REQUIRES_OK(
      ctx, launch_context.PopulateOutputs(
               ctx, compilation_result, execution_output->ConsumeResult(),
               /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
               input_output_alias, resource_var_ptrs));

  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  OP_REQUIRES_OK(ctx,
                 launch_context.PopulateOutputs(
                     ctx, compilation_result, run_result.ConsumeValueOrDie(),
                     /*missing_ctx_input_prefix=*/0, input_output_alias,
                     variables_snapshot));
  VLOG(1) << "Done";
}

@@ -516,10 +529,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
  OP_REQUIRES_OK(
      ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));

  // Do not alias resource updates as locking variables in XlaCompile and
  // unlocking them in XlaRun may lead to deadlocks.
  Status status = CompileToLocalExecutable(
      ctx, function_, has_ref_vars_, platform_info_, variable_infos,
      constants_,
      /*lazy=*/!must_compile_, &client, &kernel, &executable);
      /*lazy=*/!must_compile_,
      /*may_alias_resource_update=*/false, &client, &kernel, &executable);
  OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                variable_infos, &variables));
  if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
@@ -587,14 +604,22 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::DeviceMemoryAllocator* allocator =
      GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : closure.client()->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      closure.client(), allocator,
      closure.client(), allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

  // We're missing the must-be-constant inputs, tell `PopulateInputs`
  // about this. We don't actually need these inputs because they've
  // already been baked into the compiled kernel.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();
  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
  std::map<int, const Tensor*> snapshot_ptrs;
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] {
@@ -604,13 +629,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
        },
        tensorflow::profiler::TraceMeLevel::kInfo);

    launch_context.PopulateInputs(
        ctx, closure.compilation_result(), closure.resource_var_snapshots(),
        /*missing_ctx_input_prefix=*/closure.num_constant_args());
    for (auto& p : closure.resource_var_snapshots()) {
      snapshot_ptrs.emplace(p.first,
                            p.second.has_value() ? &p.second.value() : nullptr);
    }
    execution_inputs = launch_context.PopulateInputs(
        ctx, closure.compilation_result(), snapshot_ptrs,
        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
        input_output_alias);
    OP_REQUIRES_OK(ctx, execution_inputs.status());
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
@@ -631,21 +660,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::StatusOr<xla::ScopedShapedBuffer> run_result;
  xla::StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    run_result =
        closure.executable()->Run(launch_context.arguments(), run_options);
    execution_output =
        closure.executable()->Run(std::move(*execution_inputs), run_options);
  } else {
    run_result =
        closure.executable()->RunAsync(launch_context.arguments(), run_options);
    execution_output = closure.executable()->RunAsync(
        std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] {
@@ -653,12 +680,16 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
      },
      tensorflow::profiler::TraceMeLevel::kInfo);

  xla::StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
      ctx, *closure.compilation_result(), closure.num_constant_args());
  OP_REQUIRES_OK(ctx, variable_infos.status());
  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
          input_output_alias, closure.resource_var_snapshots()));
          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}

XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

@@ -1829,7 +1829,7 @@ TEST(XlaCompilationTest, XLALiteAllowlist) {
  }
  EXPECT_TRUE(unknow_op.empty())
      << "Someone added support for new TF operations inside XLA. They must "
         "be included in the XLALite allowlist or blacklist:\n"
         "be included in the XLALite allowlist or denylist:\n"
      << absl::StrJoin(unknow_op, "\n");
}
}  // namespace

@@ -50,35 +50,47 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
  // Builds an XLA allocator for the device.
  XlaComputationLaunchContext launch_context(
      client, client->backend().memory_allocator(),
      client->default_device_ordinal(),
      /*allocate_xla_tensors=*/true,
      /*use_multiple_streams=*/metadata.UseMultipleStreams());

  launch_context.PopulateInputs(ctx, result, variable_args,
                                /*missing_ctx_input_prefix=*/0);
  std::map<int, const Tensor*> snapshot_ptrs;
  for (auto& p : variable_args) {
    snapshot_ptrs.emplace(p.first,
                          p.second.has_value() ? &p.second.value() : nullptr);
  }

  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
      launch_context.PopulateInputs(ctx, result, snapshot_ptrs,
                                    /*missing_ctx_input_prefix=*/0,
                                    input_output_alias);
  TF_RETURN_IF_ERROR(execution_inputs.status());

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  TF_RET_CHECK(stream);

  VLOG(2) << "Executing computation: " << name();
  for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
    VLOG(2) << name() << ": " << *arg;
  }
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(client->backend().memory_allocator());
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());

  xla::StatusOr<xla::ScopedShapedBuffer> run_result =
      executable->Run(launch_context.arguments(), run_options);
  xla::StatusOr<xla::ExecutionOutput> run_result =
      executable->Run(execution_inputs.ConsumeValueOrDie(), run_options);
  TF_RETURN_IF_ERROR(run_result.status());

  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  xla::ExecutionOutput execution_output = run_result.ConsumeValueOrDie();
  xla::StatusOr<std::vector<VariableInfo>> variable_infos =
      GatherVariableInfo(ctx, *result, 0);
  TF_RETURN_IF_ERROR(variable_infos.status());
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(*variable_infos)));
  TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
      ctx, result, run_result.ConsumeValueOrDie(),
      /*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
      ctx, result, execution_output.ConsumeResult(),
      /*missing_ctx_input_prefix=*/0, absl::MakeSpan(*variable_infos),
      input_output_alias, snapshot_ptrs));
  return Status::OK();
}

@@ -59,11 +59,13 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
        return Status::OK();
      }));
  mutex_lock ml(*variable->mu());
  OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
              errors::InvalidArgument(
                  "Trying to assign variable with wrong dtype. Expected ",
                  DataTypeString(variable->tensor()->dtype()), " got ",
                  DataTypeString(dtype_)));
  OP_REQUIRES(
      context,
      !variable->is_initialized || variable->tensor()->dtype() == dtype_,
      errors::InvalidArgument(
          "Trying to assign variable with wrong dtype. Expected ",
          DataTypeString(variable->tensor()->dtype()), " got ",
          DataTypeString(dtype_)));
  variable->is_initialized = true;
  *variable->tensor() = value;
}

@@ -91,29 +91,19 @@ VariableInfo::~VariableInfo() {
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
                                     absl::Span<const int> variable_indices,
                                     std::vector<VariableInfo>* result) {
  std::vector<const ResourceHandle*> resource_handles;
  absl::c_transform(
      variable_indices, std::back_inserter(resource_handles),
      [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });

  std::vector<core::RefCountPtr<Var>> variables;
  Status s = LookupResources(ctx, resource_handles, &variables);
  if (!s.ok()) {
    errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
    return s;
  }

  result->clear();
  result->reserve(variable_indices.size());
  for (int i = 0; i < variable_indices.size(); i++) {
    // *Release* the variable because we're going to unref it later in
    // ~VariableInfo.
    Var* variable = variables[i].release();
    int input_idx = variable_indices[i];
    std::string var_name = HandleFromInput(ctx, input_idx).name();
    result->emplace_back(input_idx, var_name, variable);
  for (int var_idx : variable_indices) {
    Var* variable = nullptr;
    ResourceHandle handle = HandleFromInput(ctx, var_idx);
    TF_RETURN_IF_ERROR(
        LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
          // This var is uninitialized for now.
          *ptr = new Var(DT_INVALID);
          return Status::OK();
        }));
    result->emplace_back(var_idx, handle.name(), variable);
  }

  return Status::OK();
}
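// Reading note (hedged inference): with LookupOrCreateResource above, a
// resource input that has never been assigned now yields a placeholder
// Var(DT_INVALID) rather than a lookup failure; this pairs with the relaxed
// check in XlaAssignVariableOp earlier in this diff, which only enforces
// dtype equality once variable->is_initialized is true.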

@@ -176,24 +166,43 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,

XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
    bool allocate_xla_tensors, bool use_multiple_streams)
    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams) {
      use_multiple_streams_(use_multiple_streams),
      device_ordinal_(device_ordinal) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

void XlaComputationLaunchContext::PopulateInputs(
// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
                                         xla::ShapeIndex index,
                                         se::DeviceMemoryBase& buffer,
                                         bool donate_buffer, int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator) {
  xla::MaybeOwningDeviceMemory* in_buffer =
      execution_input.MutableBuffer(index);
  if (donate_buffer) {
    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
    buffer = se::DeviceMemoryBase();
  } else {
    *in_buffer = buffer;
  }
}
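// How donation works (sketch of the reasoning, hedged): wrapping the device
// buffer in se::OwningDeviceMemory transfers ownership to XLA so the runtime
// may reuse it in place for an aliased output; resetting `buffer` to an empty
// se::DeviceMemoryBase keeps the caller from freeing or reusing donated
// memory. Donation is only requested (see PopulateInputs below) when the
// tensor's refcount is one, the resource variable is actually updated, and
// the HLO alias config pairs that parameter with an output.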

xla::StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
  // Build ShapedBuffers that point directly to the Tensor buffers.
  arg_ptrs_ =
      std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
    const std::map<int, const Tensor*>& resource_vars,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  std::vector<xla::ExecutionInput> arguments;
  arguments.reserve(compilation_result->xla_input_shapes.size());

  xla::TransferManager* transfer_manager =
      client_->backend().transfer_manager();
@@ -201,10 +210,28 @@ void XlaComputationLaunchContext::PopulateInputs(
    int arg_num = compilation_result->input_mapping[i];
    CHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
    const Tensor* t = variables.count(arg_num)
                          ? &(variables.at(arg_num).value())
    const xla::Shape& device_shape =
        transfer_manager->HostShapeToDeviceShape(shape);

    bool is_resource_variable = resource_vars.count(arg_num);
    bool is_updated_resource_variable =
        is_resource_variable &&
        absl::c_any_of(compilation_result->resource_updates,
                       [&](const XlaCompiler::ResourceUpdate& update) {
                         return update.input_index == i && update.modified;
                       });

    const Tensor* t = is_resource_variable
                          ? resource_vars.at(arg_num)
                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
    CHECK(t);
    bool donate_buffer =
        t->RefCountIsOne() && is_updated_resource_variable &&
        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
    VLOG(3) << "Processing input: " << i
            << "; is_resource_variable=" << is_resource_variable
            << "; is_updated_resource_variable=" << is_updated_resource_variable
            << "; donate_buffer=" << donate_buffer;

    if (use_multiple_streams_) {
      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
@@ -215,23 +242,28 @@ void XlaComputationLaunchContext::PopulateInputs(
                                ctx->op_device_context()->stream());
    }

    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
            shape, transfer_manager->HostShapeToDeviceShape(shape))) {
    arguments.emplace_back(device_shape, shape);
    xla::ExecutionInput& execution_input = arguments.back();
    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
      se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
      arg_buffers_.emplace_back(
          /*on_host_shape=*/shape, /*on_device_shape=*/shape,
          client_->platform(), client_->default_device_ordinal());
      arg_buffers_.back().set_buffer(dmem, /*index=*/{});
      arg_ptrs_[i] = &arg_buffers_.back();
      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
                                   donate_buffer, device_ordinal_,
                                   xla_allocator_);
    } else {
      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
      xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            PopulateExecutionInputBuffer(execution_input, index, *buffer,
                                         donate_buffer, device_ordinal_,
                                         xla_allocator_);
          });
    }
  }
  return std::move(arguments);
}

// Construct the tensor for given type and buffer.
// Construct the tensor for the given type and buffer.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                         se::DeviceMemoryBase buffer, Allocator* allocator) {
  size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
@@ -247,28 +279,26 @@ static Tensor GetOrCreateTensorForOutput(
    int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
    const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
    Allocator* output_allocator) {
    const std::map<int, const Tensor*>& resource_vars_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
  xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                     ? xla::ShapeIndex({output_num})
                                     : xla::ShapeIndex({});

  CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
  if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
          input_output_alias.GetAliasedParameter(output_index)) {
    VLOG(3) << "Found alias: " << alias->ToString();
    int tf_param =
        input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
    const Tensor* input_tensor = &ctx->input(tf_param);

    // If input tensor is a resource variable, alias to the snapshot we took at
    // entry time.
    if (input_tensor->dtype() == DT_RESOURCE) {
      const absl::optional<Tensor>& v =
          resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
      CHECK(v.has_value());
      return *v;
    const Tensor input_tensor =
        ctx->input(tf_param).dtype() != DT_RESOURCE
            ? ctx->input(tf_param)
            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
    if (output_buffer.opaque() == input_tensor.data()) {
      return input_tensor;
    }
    return *input_tensor;
  }
  return MakeTensor(output_dtype, output_shape, output_buffer,
                    output_allocator);
@@ -291,12 +321,10 @@ static Status SetOutputForConstant(
    OpKernelContext* ctx, se::Stream* stream,
    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
  CHECK(compilation_result->outputs[output_num].is_constant);
  // Output is a constant.
  const Tensor& const_tensor =
      compilation_result->outputs[output_num].constant_value;
  Tensor* output_tensor;
  const size_t total_bytes = const_tensor.TotalBytes();
  if (stream && total_bytes > 0) {
  if (stream && const_tensor.TotalBytes() > 0) {
    // Copy host -> device. (Empty tensors don't have backing buffers.)
    // Manually allocate memory using an XlaTensorBuffer so we can allocate
    // as much memory as the device requires (as given by
@@ -335,52 +363,55 @@ static Status SetOutputForConstant(
  return Status::OK();
}

// Creates a list of updates resource variables.
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> variable_infos;
  variable_infos.reserve(compilation_result->resource_updates.size());
static xla::StatusOr<Var*> GetOrCreateResourceVar(
    OpKernelContext* ctx, const ResourceHandle& handle,
    const XlaCompiler::ResourceUpdate& write) {
  Var* variable = nullptr;
  TF_RETURN_IF_ERROR(
      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
        *ptr = new Var(write.type);
        return Status::OK();
      }));
  return variable;
}

  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult& compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> out;
  out.reserve(compilation_result.resource_updates.size());
  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
        compilation_result.resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
    // not a Tensor.
    Var* variable = nullptr;
    const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
                                                   [&write](Var** ptr) {
                                                     *ptr = new Var(write.type);
                                                     return Status::OK();
                                                   }));
    variable_infos.emplace_back(actual_input_index, handle.name(), variable);
    TF_ASSIGN_OR_RETURN(Var * variable,
                        GetOrCreateResourceVar(ctx, handle, write));
    out.emplace_back(actual_input_index, handle.name(), variable);
  }
  return variable_infos;
  return std::move(out);
}
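GatherVariableInfo and its new helper thread failures through xla::StatusOr instead of CHECK-failing. For readers unfamiliar with the macro, a stripped-down model of what TF_ASSIGN_OR_RETURN does (toy types; the real macro generates collision-free temporaries):

```cpp
#include <iostream>
#include <string>
#include <utility>

// Toy stand-ins for Status / StatusOr<T>; illustration only.
struct Status {
  bool ok = true;
  std::string msg;
};
template <typename T>
struct StatusOr {
  Status status;
  T value{};
  bool ok() const { return status.ok; }
};

// Roughly what TF_ASSIGN_OR_RETURN(lhs, expr) expands to: evaluate expr once,
// early-return its status on failure, otherwise move the value into lhs.
#define ASSIGN_OR_RETURN(lhs, expr)          \
  auto _result = (expr);                     \
  if (!_result.ok()) return _result.status;  \
  lhs = std::move(_result.value);

StatusOr<int> ParsePositive(int x) {
  if (x <= 0) return {{false, "not positive"}, 0};
  return {{}, x};
}

Status Demo() {
  ASSIGN_OR_RETURN(int v, ParsePositive(42));  // v == 42 on the happy path.
  std::cout << v << "\n";
  return {};
}
```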

Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    absl::Span<VariableInfo> variable_infos,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    const ResourceVarsSnapshot& resource_var_snapshots) {
    const std::map<int, const Tensor*>& resource_vars) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  Allocator* allocator = ctx->device()->GetAllocator({});

  // Computation output should always be a tuple.
  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
    VLOG(2) << "Result tuple shape (on device): "
            << output.on_device_shape().DebugString();
  }
  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
  VLOG(2) << "Result tuple shape (on device): "
          << output.on_device_shape().DebugString();
  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
@@ -438,8 +469,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    const TensorShape& shape = output_tensor_shapes[i];
    const DataType& type = compilation_result->outputs[i].type;
    VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
            << DataTypeString(type);
    VLOG(2) << "Populating output for retval " << i << " shape "
            << shape.DebugString() << " type " << DataTypeString(type);
    if (type == DT_VARIANT) {
      return errors::Unimplemented(
          "Support for TensorList crossing the XLA/TF boundary "
@@ -467,30 +498,37 @@ Status XlaComputationLaunchContext::PopulateOutputs(
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      Tensor output_tensor = GetOrCreateTensorForOutput(
          output_num, ctx, missing_ctx_input_prefix, input_output_alias,
          compilation_result->input_mapping, resource_var_snapshots,
          compilation_result->input_mapping, resource_vars,
          ctx->expected_output_dtype(i), shape, buffer, allocator);
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      ctx->set_output(i, output_tensor);
    }
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      ++output_num;
    }

    if (VLOG_IS_ON(3)) {
      VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
    }
  }

  // Apply variable updates, if any.
  VLOG(2) << "Applying variable updates";
  TF_ASSIGN_OR_RETURN(
      std::vector<VariableInfo> variable_infos,
      GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
  // input_index -> index into variable_infos.
  absl::flat_hash_map<int, int> variable_info_lookup;
  for (int i = 0; i < variable_infos.size(); i++) {
    variable_info_lookup.emplace(variable_infos[i].index(), i);
  }

  // Apply variable updates, if any.
  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    if (variable_infos[i].var()->tensor()->dtype() != write.type) {
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    CHECK_GE(actual_input_index, 0);
    CHECK_LT(actual_input_index, ctx->num_inputs());
    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
    CHECK(var);

    VLOG(2) << "Updating variable #" << i
            << " at input index: " << actual_input_index << " with shape "
            << write.shape.DebugString() << "; variable tensor has shape: "
            << var->tensor()->shape().DebugString();

    if (var->is_initialized && var->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

@@ -504,14 +542,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
      }
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      output_tensor = GetOrCreateTensorForOutput(
          output_num, ctx, missing_ctx_input_prefix, input_output_alias,
          compilation_result->input_mapping, resource_var_snapshots, write.type,
          compilation_result->input_mapping, resource_vars, write.type,
          write.shape, buffer, allocator);
    }
    *variable_infos[i].var()->tensor() = output_tensor;
    variable_infos[i].var()->is_initialized |= write.modified;
    output.set_buffer(se::OwningDeviceMemory(), {output_num});
    var->is_initialized |= write.modified;
    *var->tensor() = output_tensor;
    ++output_num;
  }
  return Status::OK();
@@ -562,7 +600,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
      arg.name = std::string(variable.name());
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      if (variable.var()) {
      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();

@@ -81,6 +81,12 @@ class VariableInfo {
  bool lock_held_ = false;
};

// Creates a list of updated resource variables.
xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult& compilation_result,
    int missing_ctx_input_prefix);

// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
@@ -124,7 +130,7 @@ class XlaComputationLaunchContext {
  // objects.
  XlaComputationLaunchContext(xla::LocalClient* client,
                              se::DeviceMemoryAllocator* xla_allocator,
                              bool allocate_xla_tensors,
                              int device_ordinal, bool allocate_xla_tensors,
                              bool use_multiple_streams);

  // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
@@ -142,10 +148,12 @@ class XlaComputationLaunchContext {
  // missing and adjusts input indices accordingly. All elements in kernel's
  // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
  // (in other words, no inputs actually required by the kernel can be missing).
  void PopulateInputs(OpKernelContext* ctx,
                      const XlaCompiler::CompilationResult* compilation_result,
                      const ResourceVarsSnapshot& variables,
                      int missing_ctx_input_prefix);
  xla::StatusOr<std::vector<xla::ExecutionInput>> PopulateInputs(
      OpKernelContext* ctx,
      const XlaCompiler::CompilationResult* compilation_result,
      const std::map<int, const Tensor*>& resource_vars,
      int missing_ctx_input_prefix,
      const xla::HloInputOutputAliasConfig& input_output_alias);

  // Given the XLA output in `output`, populate all outputs of `ctx`. Also
  // writes out the resource variable updates.
@@ -161,20 +169,16 @@ class XlaComputationLaunchContext {
      OpKernelContext* ctx,
      const XlaCompiler::CompilationResult* compilation_result,
      xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
      absl::Span<VariableInfo> variable_infos,
      const xla::HloInputOutputAliasConfig& input_output_alias,
      const ResourceVarsSnapshot& resource_var_snapshots);

  // Return the argument list. Only valid after PopulateInputs() has been
  // called.
  const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
      const std::map<int, const Tensor*>& resource_vars);

 private:
  xla::LocalClient* client_;
  se::DeviceMemoryAllocator* xla_allocator_;
  bool allocate_xla_tensors_;
  bool use_multiple_streams_;
  std::deque<xla::ShapedBuffer> arg_buffers_;
  std::vector<xla::ShapedBuffer*> arg_ptrs_;
  int device_ordinal_;
};
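Putting the two halves of this header change together: PopulateInputs now returns the xla::ExecutionInput list by value and PopulateOutputs consumes the execution result plus the same resource-variable map. A hedged sketch of the resulting call flow in a launch kernel (argument plumbing compressed; the executable-run step is elided, so this is not a verbatim excerpt):

```cpp
// Sketch only. Assumes the declarations above; error handling abbreviated.
Status LaunchSketch(OpKernelContext* ctx,
                    XlaComputationLaunchContext& launch_context,
                    const XlaCompiler::CompilationResult* result,
                    const std::map<int, const Tensor*>& resource_vars,
                    const xla::HloInputOutputAliasConfig& alias_config,
                    xla::ScopedShapedBuffer output) {
  // Inputs come back by value; buffer-donation decisions were already made
  // inside PopulateInputs based on alias_config and tensor refcounts.
  TF_ASSIGN_OR_RETURN(std::vector<xla::ExecutionInput> inputs,
                      launch_context.PopulateInputs(
                          ctx, result, resource_vars,
                          /*missing_ctx_input_prefix=*/0, alias_config));
  (void)inputs;  // ... run the executable with std::move(inputs) here ...
  return launch_context.PopulateOutputs(ctx, result, std::move(output),
                                        /*missing_ctx_input_prefix=*/0,
                                        alias_config, resource_vars);
}
```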

// A simple TensorBuffer implementation that allows us to create Tensors that

@@ -74,7 +74,6 @@ We have several choices on how to lower the host-side part from LHLO:
    * (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
      TFRT ops are interpreted by C++ code.
    * (Con) host side is under development and not tested.
    * (Con) the JAX integration isn’t clear from a runtime point of view
*   Jitted CPU code
    * (Pro) great lower-ability. Create a few loops and conditions and it's
      done.
@@ -84,8 +83,7 @@ We have several choices on how to lower the host-side part from LHLO:
      dynamic loading, etc).
*   Existing (interpreting) XLA runtime

Tentative conclusion: Use jitted CPU code during the transition, and optionally
adopt TFRT in the end.
Decision: adopt TFRT, but also support jitting CPU code in TFRT.

## Migrating Device LLVM IR (Task 3)

@@ -114,7 +112,7 @@ end state of each XLA op:
    * (Cost) Will be throw-away work if we want to ultimately migrate to
      Standard.
    * (Benefit) It is easy and mechanical. Can be done in a short period.
    * (Benefit) It doesn't benefit more compared to a).
    * (Benefit) It doesn't benefit more compared to (1).
1.  Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
    * (Cost) Lifting existing emitters to Standard introduces some challenges.
      Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
@@ -134,6 +132,19 @@ end state of each XLA op:
    * (Benefit) unified stack; community support; portability; more
      optimization potentials.

Conclusions:

*   Don't go for (2). (1) or (3) are just better than (2). (2) costs more than
    (1), since it requires a lot of mechanical refactoring. With (1) we can
    still achieve the goal of enabling XLA to pick up MLIR emitters. This is by
    doing LHLO -> LLVM IR -> run legacy device emitters.
*   ElementalIrEmitter ops go for (4), but not incrementally. There is no way to
    do it op by op, because all elementally-emitted ops are connected into the
    same graph. This work can also serve as a unification point of several
    on-going forces (xla/service/mlir\_gpu, the kernel generator, Linalg).
*   All other ops go for (1). As a stretch goal, they might be migrated to (3)
    or (4).

## Prioritization

While all three tasks mentioned above are parallelizable, under limited
@@ -210,26 +221,19 @@ The exact profiling can't be easily done for MLIR-generated ops, since:

### Step 3: (Task 2) Migrating Thunks

This step migrates all host ops and library calls; it eliminates most of the
thunks and produces serializable MLIR instead.

There are roughly three kinds of thunks:

As a note, there are roughly three kinds of thunks:
*   KernelThunk, which launches a kernel.
*   Control flow thunks, which have host control flow logic (conditional, while,
    for, sequence) and launch body kernels.
*   Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.

The **bottom line** is to:
The plan is:
*   Make Thunks (de)serializable.
*   Help improve TFRT to a state where it can support these semantics.
*   As the state improves, migrate individual thunks incrementally.

*   Create a Thunk dialect that provides (de)serialize logic for all existing
    C++-based Thunks.
*   Change emitters to emit a graph of Thunk dialect.

**Optionally**, we can relieve some thunks from C++ implementation. KernelThunk
can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG
Dialect for loops and conditions, combined with LaunchKernelOp. This optional
step requires profiling and stream support.
These action items are only partially ordered. The actual execution order /
engineering parallelism is to be evaluated as it goes.
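To make the first plan item concrete, here is an entirely hypothetical sketch of what a (de)serializable thunk boundary could look like. Nothing below exists in the tree; it only illustrates what "Thunks that can describe themselves well enough to be re-emitted as a Thunk dialect op" might mean:

```cpp
#include <string>
#include <utility>

// Hypothetical illustration only: real Thunks live in xla/service/gpu and
// carry far richer state (buffer slices, streams, profiling hooks, ...).
class SerializableThunk {
 public:
  virtual ~SerializableThunk() = default;
  // Encode enough state to reconstruct this thunk, e.g. as a Thunk dialect op.
  virtual std::string Serialize() const = 0;
};

class KernelThunkSketch : public SerializableThunk {
 public:
  KernelThunkSketch(std::string kernel_name, int launch_dim)
      : kernel_name_(std::move(kernel_name)), launch_dim_(launch_dim) {}
  std::string Serialize() const override {
    // A Thunk dialect op would carry these as attributes.
    return "thunk.launch_kernel {name = \"" + kernel_name_ +
           "\", launch_dim = " + std::to_string(launch_dim_) + "}";
  }

 private:
  std::string kernel_name_;
  int launch_dim_;
};
```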

### Step 4: (Task 3) Migrated ElementalIrEmitter


@@ -106,6 +106,25 @@ gentbl(
    td_srcs = [":hlo_ops_td_files"],
)

gentbl(
    name = "hlo_ops_pattern_gen",
    strip_include_prefix = "include",
    tbl_outs = [
        (
            "-gen-rewriters",
            "include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "lib/Dialect/mhlo/IR/hlo_patterns.td",
    td_srcs = [
        ":hlo_ops_td_files",
        "@llvm-project//mlir:StdOpsTdFiles",
        "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td",
        "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td",
    ],
)

gentbl(
    name = "lhlo_ops_inc_gen",
    strip_include_prefix = "include",
@@ -203,6 +222,7 @@ cc_library(
    ],
    includes = ["include"],
    deps = [
        "hlo_ops_pattern_gen",
        ":canonicalize_inc_gen",
        ":chlo_ops_inc_gen",
        ":convert_op_folder",
@@ -548,6 +568,26 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "legalize_gather_to_torch_index_select",
    srcs = ["lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc"],
    hdrs = [
        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
    ],
    deps = [
        ":hlo",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
    ],
    alwayslink = 1,
)

cc_library(
    name = "legalize_tanh_to_approximation",
    srcs = ["lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc"],
@@ -590,6 +630,7 @@ cc_library(
        "lib/Dialect/mhlo/transforms/generated_lower_complex.inc",
        "lib/Dialect/mhlo/transforms/lower_complex.cc",
        "lib/Dialect/mhlo/transforms/lower_general_dot.cc",
        "lib/Dialect/mhlo/transforms/optimize_mhlo.cc",
    ],
    hdrs = [
        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
@@ -649,7 +690,9 @@ cc_library(
    deps = [
        ":hlo",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Shape",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Transforms",
    ],
)
@@ -661,6 +704,7 @@ cc_library(
        "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc",
        "lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc",
        "lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc",
        "lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc",
        "lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc",
        "lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc",
    ],
@@ -678,6 +722,7 @@ cc_library(
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LLVMTransforms",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Shape",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Transforms",
@@ -695,6 +740,7 @@ cc_library(
        ":hlo_dialect_registration",
        ":hlo_legalize_to_lhlo",
        ":legalize_control_flow",
        ":legalize_gather_to_torch_index_select",
        ":legalize_tanh_to_approximation",
        ":legalize_to_linalg",
        ":legalize_to_standard",

@@ -21,9 +21,9 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"

def HLO_Dialect : Dialect {
  let name = "mhlo";

@@ -38,6 +38,13 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
void PopulateComplexLoweringPatterns(MLIRContext *context,
                                     OwningRewritePatternList *patterns);

void PopulateOptimizeMHLOPatterns(MLIRContext *context,
                                  OwningRewritePatternList *patterns);

// Rewrite patterns for gather to equivalent torch index select legalization.
void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns);

void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                               MLIRContext *ctx);

@@ -35,6 +35,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
@@ -59,6 +60,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"

namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace mhlo {

@@ -744,7 +746,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic

void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicBroadcastInDimOpNotActuallyDynamic>(context);
  results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
                 DynamicBroadcastToOwnShape>(context);
}

//===----------------------------------------------------------------------===//
@@ -1465,7 +1468,7 @@ static LogicalResult Verify(PadOp op) {

static LogicalResult Verify(ReshapeOp op) {
  // If the operand type is dynamically shaped there is nothing to verify.
  auto operand_ty = op.operand().getType().cast<RankedTensorType>();
  auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
  if (!operand_ty || !operand_ty.hasStaticShape()) return success();

  // If the operand type is statically shaped (not required) the number of

@@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Canonicalization patterns for the MHLO dialect.

include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"

def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;

// Canonicalization patterns.

def DynamicBroadcastToOwnShape : Pat<
  (HLO_DynamicBroadcastInDimOp:$op $arg0,
      (Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
  (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
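For readers who don't read DRR, the pattern above says: if a dynamic_broadcast_in_dim broadcasts a value to that same value's own shape, replace the op with the value. A hand-written C++ sketch of the same rewrite, written for this note (the commit itself generates the pattern from the TableGen via `-gen-rewriters`, so the accessor names here are best-effort assumptions):

```cpp
// Sketch only: C++ equivalent of the DynamicBroadcastToOwnShape DRR pattern.
struct DynamicBroadcastToOwnShapeSketch
    : public OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
  using OpRewritePattern<mhlo::DynamicBroadcastInDimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,
                                PatternRewriter &rewriter) const override {
    // Match: the output shape is shape.to_extent_tensor(shape.shape_of(%x)).
    auto to_extent = llvm::dyn_cast_or_null<shape::ToExtentTensorOp>(
        op.output_dimensions().getDefiningOp());
    if (!to_extent) return failure();
    auto shape_of = llvm::dyn_cast_or_null<shape::ShapeOfOp>(
        to_extent.input().getDefiningOp());
    // Constraint: %x must be the broadcast's own operand (EqualBinaryOperands).
    if (!shape_of || shape_of.arg() != op.operand()) return failure();
    // Rewrite: broadcasting a value to its own shape is the identity.
    rewriter.replaceOp(op, op.operand());
    return success();
  }
};
```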

@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
@@ -22,6 +24,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"

namespace mlir {
@@ -74,10 +77,6 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
// - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
//
// It may be possible to expand this pattern to operate on unranked tensors in
// the future by emitting more code to dynamically differentiate based on rank.
// Whether that is of any practical benefit remains to be seen.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
@@ -160,6 +159,68 @@ struct ConvertRankedDynamicBroadcastBinaryOp
  }
};

// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value lhs = op.lhs();
    Value rhs = op.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();

    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();

    bool lhs_is_scalar = lhs_ranked_type &&
                         lhs_ranked_type.getShape().empty() &&
                         rhs_unranked_type;
    bool rhs_is_scalar = rhs_ranked_type &&
                         rhs_ranked_type.getShape().empty() &&
                         lhs_unranked_type;

    // Only support the case where exactly one operand is scalar and the other
    // is unranked. Other patterns in this file will create more efficient
    // lowerings for cases where both ranks are known or will handle the more
    // generic case of both inputs being unranked.
    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();

    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();

    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor.
    Value shape =
        rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value size = rewriter.create<shape::SizeToIndexOp>(loc, num_elements);
    Value size_tensor = rewriter.create<TensorFromElementsOp>(loc, size);
    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
        loc, RankedTensorType::get({-1}, result_type.getElementType()),
        lhs_is_scalar ? rhs : lhs, size_tensor);

    // Create a new ranked Chlo op that will be further lowered by other
    // patterns into Mhlo.
    SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
                                   rhs_is_scalar ? rhs : reshaped};
    Value computed = rewriter.create<ChloOpTy>(
        loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());

    // Reshape the result back into an unranked tensor.
    Value shape_tensor = rewriter.create<shape::ToExtentTensorOp>(
        loc, RankedTensorType::get({-1}, rewriter.getIndexType()), shape);
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
                                                        computed, shape_tensor);

    return success();
  }
};
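The flatten/compute/restore trick above is easiest to see with concrete numbers: a scalar combined with an unranked tensor that happens to have shape (2, 3) is computed on its 6-element flattening and then reshaped back. A plain-C++ model of that data flow (illustrative only; the real lowering operates on mhlo ops):

```cpp
#include <cassert>
#include <vector>

// Models: scalar `s` op unranked tensor `t` ==> flatten t, apply the ranked
// op elementwise against the broadcast scalar, then restore t's shape.
std::vector<float> ScalarAddViaFlatten(float s, const std::vector<float>& t) {
  // Step 1: "reshape" to rank 1 (mhlo.dynamic_reshape in the pattern); a flat
  // buffer is already rank 1, so it is the identity here.
  std::vector<float> flat = t;
  // Step 2: apply the ranked broadcasting op on the rank-1 value.
  for (float& v : flat) v += s;
  // Step 3: reshape back to the original, dynamically obtained shape.
  return flat;
}

int main() {
  std::vector<float> t{1, 2, 3, 4, 5, 6};  // logically shape (2, 3)
  auto r = ScalarAddViaFlatten(10.f, t);
  assert(r[0] == 11.f && r[5] == 16.f);
}
```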

template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
                         OwningRewritePatternList *patterns) {
@@ -169,6 +230,9 @@ void PopulateForBinaryOp(MLIRContext *context,
  patterns->insert<
      ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context, 5);
  patterns->insert<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>>(
      context);
}

template <typename FromOpTy, typename ToOpTy>

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -37,6 +38,7 @@ struct TestChloLegalizeToHloPass
    // The conversion uses helpers from the Standard dialect.
    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
    conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
    conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();

    PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);


@@ -42,9 +42,6 @@ namespace {

template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter =
    detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
                                              lmhlo::CopyOp, true>;

Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                   Value shape_operand,
@@ -272,27 +269,21 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
    // Copy over the operations inside the region.
    rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());

    // Create new block arguments with correct type.
    // Convert the region signature to memref and add extra result.
    auto& entry_block = new_op.body().front();
    int original_arg_count = entry_block.getNumArguments();
    for (int i = 0; i < original_arg_count; ++i) {
      auto old_arg = entry_block.getArgument(i);
      auto old_type = old_arg.getType().cast<TensorType>();
    TypeConverter::SignatureConversion sig_conversion(
        entry_block.getNumArguments() + 1);
    for (auto arg : entry_block.getArguments()) {
      auto old_type = arg.getType().cast<TensorType>();
      auto new_type =
          MemRefType::get(old_type.getShape(), old_type.getElementType());
      auto new_arg = entry_block.addArgument(new_type);
      rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
      sig_conversion.addInputs(arg.getArgNumber(), new_type);
    }
    // Add an argument for the result.
    entry_block.addArgument(
        entry_block.getArgument(original_arg_count).getType());
    // Remove the old arguments.
    for (int i = original_arg_count - 1; i >= 0; --i) {
      entry_block.eraseArgument(i);
    }
    // Insert terminator at the end.
    rewriter.setInsertionPointToEnd(&entry_block);
    rewriter.create<lmhlo::TerminatorOp>(loc);
    auto return_op = cast<mhlo::ReturnOp>(entry_block.getTerminator());
    auto result_type = return_op.results().front().getType().cast<TensorType>();
    sig_conversion.addInputs({MemRefType::get(result_type.getShape(),
                                              result_type.getElementType())});
    rewriter.applySignatureConversion(&new_op.body(), sig_conversion);

    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));

@@ -300,6 +291,12 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
  }
};

// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality
// is provided by mlir buffer assignment, so use the pattern from there.
// TODO(DFKI): Move this out of detail.
using HloToLhloReturnOpConverter = detail::BufferAssignmentReturnOpConverter<
    mhlo::ReturnOp, lmhlo::TerminatorOp, lmhlo::CopyOp, false>;

class HloToLhloTensorLoadOpConverter
    : public BaseOpConversion<mlir::TensorLoadOp> {
 public:
@@ -312,7 +309,6 @@ class HloToLhloTensorLoadOpConverter
  }
};

// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreOpConverter
    : public BaseOpConversion<mlir::TensorStoreOp> {
 public:
@@ -506,6 +502,7 @@ void populateHLOToLHLOConversionPattern(
                    HloToLhloOpConverter<mhlo::SubOp>,
                    HloToLhloOpConverter<mhlo::TanhOp>,
                    HloToLhloReduceOpConverter,
                    HloToLhloReturnOpConverter,
                    HloToLhloTensorLoadOpConverter,
                    HloToLhloTensorStoreOpConverter
  >(context, bufferAssignment, converter);

@@ -0,0 +1,152 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/memory/memory.h"
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {

namespace mhlo {
namespace {

struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return failure();
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return failure();
    }

    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
        index_vector_dim) {
      return failure();
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return failure();
    }

    // Only support the default case for start_index_map.
    if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map()
                .getValue(0)
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    // Offset dimensions should be the defaults.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank() - index_vector_dim) {
      return failure();
    }

    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if ((it.index() + index_vector_dim) != it.value()) {
        return failure();
      }
    }

    for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return failure();
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return failure();
      }
    }

    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());

    for (auto dim : operand_ty.getShape().drop_front()) {
      index_select_shape.push_back(dim);
    }

    if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
        dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
            1 ||
        dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
      return failure();
    }

    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }
};

struct LegalizeGatherToTorchIndexSelect
    : public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
  /// Perform the rewriting of gather ops to torch_index_select ops.
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
}  // namespace

void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<GatherIsTorchIndexSelect>(context);
}

static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
    "mhlo-legalize-gather-to-torch-index-select",
    "Legalizes gathers to a torch index select.");

}  // namespace mhlo
}  // namespace mlir
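What the pattern recognizes, in plain data terms: a gather whose slice_sizes select whole slices along every dimension but the first, driven by a trailing index dimension of size 1, is just an index-select over dimension 0. A tiny self-contained illustration of that access pattern (ordinary C++, not MLIR):

```cpp
#include <cassert>
#include <vector>

// index_select along dim 0: out[i] = operand[indices[i]], selecting whole
// rows. This is the access pattern the pattern above recognizes in mhlo.gather.
std::vector<std::vector<int>> IndexSelectDim0(
    const std::vector<std::vector<int>>& operand,
    const std::vector<int>& indices) {
  std::vector<std::vector<int>> out;
  out.reserve(indices.size());
  for (int idx : indices) out.push_back(operand[idx]);
  return out;
}

int main() {
  // operand shape [3, 2]; a gather with slice_sizes = [1, 2] and start
  // indices {2, 0} selects rows 2 and 0 -- exactly torch_index_select(dim=0).
  std::vector<std::vector<int>> operand = {{1, 2}, {3, 4}, {5, 6}};
  auto out = IndexSelectDim0(operand, {2, 0});
  assert(out[0][0] == 5 && out[1][0] == 1);
}
```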

@@ -298,8 +298,8 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
    auto nloops = resultType.getRank();
    auto loc = op.getLoc();
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*inputCount=*/1,
        /*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
        loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*argsIn=*/1,
        /*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });
@@ -420,7 +420,7 @@ class LhloBroadcastInDimConverter
        rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
    rewriter.create<linalg::GenericOp>(
        loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
        /*inputCount=*/0, /*outputCount=*/1,
        /*argsIn=*/0, /*argsOut=*/1,
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
@@ -433,7 +433,7 @@ class LhloBroadcastInDimConverter
    rewriter.create<linalg::GenericOp>(
        loc, llvm::None,
        llvm::makeArrayRef({operand, operand_adaptor.output()}),
        /*inputCount=*/1, /*outputCount=*/1, indexing_maps,
        /*argsIn=*/1, /*argsOut=*/1, indexing_maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());

@@ -133,8 +133,8 @@ struct ReshapeMemRefCastOpConverter
    Location loc = op->getLoc();

    auto reshape_op = cast<ReshapeMemRefCastOp>(op);
    Type dst_type = reshape_op.getResult().getType();
    auto element_type = dst_type.cast<ShapedType>().getElementType();
    auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
    auto element_type = dst_type.getElementType();

    auto shape = reshape_op.shape();

@@ -162,18 +162,17 @@ struct ReshapeMemRefCastOpConverter
      desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
      desc.setOffset(rewriter, loc, ptrs_n_offset.offset);

      auto llvmIndexTy = typeConverter.convertType(rewriter.getIndexType())
                             .cast<LLVM::LLVMType>();
      auto llvmIndexTyPtr = llvmIndexTy.getPointerTo();
      auto llvm_index_type = typeConverter.getIndexType();
      auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
      Value stride_carried = rewriter.create<LLVM::ConstantOp>(
          loc, llvmIndexTy,
          loc, llvm_index_type,
          rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
      for (int i = shape_length - 1; i >= 0; --i) {
        Value pos = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexTy,
            loc, llvm_index_type,
            rewriter.getIntegerAttr(rewriter.getIndexType(), i));
        Value ptr = rewriter.create<LLVM::GEPOp>(
            loc, llvmIndexTyPtr, shape_desc.alignedPtr(rewriter, loc),
            loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
            ValueRange{pos});
        Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
        desc.setSize(rewriter, loc, i, extracted_size);
@@ -188,7 +187,7 @@ struct ReshapeMemRefCastOpConverter
        rewriter.replaceOp(op, {desc});
      } else {
        Value rank = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexTy,
            loc, llvm_index_type,
            rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
        Value alloca =
            typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
@@ -199,15 +198,127 @@ struct ReshapeMemRefCastOpConverter
            {rank, void_ptr});
        rewriter.replaceOp(op, {unranked_desc});
      }
    } else {
      /*
       * TODO(pifon, herhut):
       * Compute strides with llvm.loop;
       * Use UnrankedMemrefDescr::ComputeSize with Alloca;
       * Set all the fields using getelementptr.
       */
      return failure();
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Value result_rank = shape_desc.size(rewriter, loc, 0);
    // TODO(herhut): Properly handle address spaces.
    unsigned address_space = 0;
    auto target_type =
        typeConverter
            .convertType(UnrankedMemRefType::get(element_type, address_space))
            .cast<LLVM::LLVMType>();
    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    UnrankedMemRefDescriptor target_desc =
        UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
    target_desc.setRank(rewriter, loc, result_rank);
    SmallVector<Value, 1> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
                                           {target_desc}, sizes);
    auto void_ptr_type =
        LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
    Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
        loc, void_ptr_type, sizes.front(), llvm::None);
    target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);

    // Fill the fixed parts. For this, we cast to a 0-D memref.
    auto zero_d_memref_type = MemRefType::get({}, element_type);
    Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
        loc,
        typeConverter.convertType(zero_d_memref_type)
            .cast<LLVM::LLVMType>()
            .getPointerTo(address_space),
        ranked_desc_mem);
    // Some common constants. Use 32 bit where required by gep struct indexes.
    auto int32_type = typeConverter.convertType(rewriter.getI32Type());
    Value zero_index = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
    Value zero = rewriter.create<LLVM::ConstantOp>(
        loc, int32_type, rewriter.getI32IntegerAttr(0));
    Value one = rewriter.create<LLVM::ConstantOp>(
        loc, int32_type, rewriter.getI32IntegerAttr(1));
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, int32_type, rewriter.getI32IntegerAttr(2));
    // Set base_pointer and aligned pointer.
    auto element_ptr_ptr_type = typeConverter.convertType(element_type)
                                    .cast<LLVM::LLVMType>()
                                    .getPointerTo(address_space)
                                    .getPointerTo(address_space);
    auto base_gep = rewriter.create<LLVM::GEPOp>(
        loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
    auto aligned_gep = rewriter.create<LLVM::GEPOp>(
        loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
    // Set offset.
    auto index_ptr_type =
        typeConverter.getIndexType().getPointerTo(address_space);
    auto offset_gep = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we need to create a loop from
    // rank - 1 to 0.
    Value one_index = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
    auto target_shape_base = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, offset_gep, ValueRange({one}));
    auto target_strides_base = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
    auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
    auto result_rank_minus_one =
        rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);

    Block *init_block = rewriter.getInsertionBlock();
    Block *cond_block =
        rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToEnd(init_block);
    rewriter.create<LLVM::BrOp>(
        loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
    rewriter.setInsertionPointToStart(cond_block);
    auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
    auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
    auto pred = rewriter.create<LLVM::ICmpOp>(
        loc, LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()),
        LLVM::ICmpPredicate::sge, index_arg, zero_index);

    Block *body_block =
        rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(body_block);

    // Copy size from shape to descriptor.
    auto size_load_gep = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
    auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
    auto size_store_gep = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
    rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
    // Write stride value and compute next one.
    auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
        loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
    rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
    auto next_stride =
        rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);

    // Decrement loop counter and branch back.
    auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
                                cond_block);

    Block *remainder =
        rewriter.splitBlock(body_block, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(cond_block);
    rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
                                    remainder, ValueRange());

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);
    rewriter.replaceOp(op, {target_desc});
    return success();
  }
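The branch-based loop emitted above is just the usual row-major stride computation, carried through block arguments: walking dimensions from innermost to outermost, each stride is the running product of the sizes already visited. The same computation in plain C++, for intuition only (the pass emits it as LLVM dialect IR):

```cpp
#include <cassert>
#include <vector>

// Row-major strides: stride[rank-1] = 1; stride[i] = stride[i+1] * size[i+1].
std::vector<long> RowMajorStrides(const std::vector<long>& sizes) {
  std::vector<long> strides(sizes.size());
  long stride = 1;  // corresponds to the `stride_arg` block argument above
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= sizes[i];
  }
  return strides;
}

int main() {
  // A reshape target of memref<2x3x4xf32> gets strides {12, 4, 1}.
  auto s = RowMajorStrides({2, 3, 4});
  assert(s[0] == 12 && s[1] == 4 && s[2] == 1);
}
```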
|
@ -51,7 +51,7 @@ class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
}  // end anonymous namespace

namespace mlir {
namespace hlo {
namespace mhlo {
namespace {

#include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"
@ -62,14 +62,14 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
                                     OwningRewritePatternList* patterns) {
  populateWithGenerated(context, patterns);
}
}  // end namespace hlo
}  // end namespace mhlo
}  // end namespace mlir

// Lowers the complex operations that can be represented using other operations.
void LowerComplex::runOnFunction() {
  // Add lowering patterns to the list.
  OwningRewritePatternList patterns;
  mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
  mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);

  applyPatternsAndFoldGreedily(getFunction(), patterns);
}
@ -0,0 +1,187 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file provides optional optimization patterns for mhlo, canonicalizing
// operations to equivalent but potentially more efficient operations.

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"

using mlir::OwningRewritePatternList;

namespace mlir {
namespace mhlo {
namespace {

// Returns 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                               Builder* builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

class GatherIsSlice : public OpRewritePattern<GatherOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    auto dimension_numbers = gather.dimension_numbers();

    // Inputs need to be ranked to lower.
    if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
        !gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
        !gather.start_indices().getType().cast<ShapedType>().hasRank() ||
        !gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
      return failure();
    }

    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
      return failure();
    }

    // TODO(suderman): Handle start index map != {0}.
    if (!dimension_numbers.start_index_map() ||
        dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
        dimension_numbers.start_index_map()
                .getValue({0})
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();

    // Requires a ranked output.
    if (!result_ty) {
      return failure();
    }
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank()) {
      return failure();
    }
    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if (it.index() != it.value()) {
        return failure();
      }
    }

    // Verify the gather slice sizes are correct.
    if (gather.slice_sizes().getNumElements() !=
        gather.operand().getType().cast<ShapedType>().getRank()) {
      return failure();
    }

    // Validate the slice sizes are correct.
    if (gather.slice_sizes().getType().cast<ShapedType>().getNumElements() <
        result_ty.getShape().size() + 1) {
      return failure();
    }

    for (auto it : llvm::enumerate(result_ty.getShape())) {
      if (gather.slice_sizes()
              .getValue(it.index() + 1)
              .cast<IntegerAttr>()
              .getValue() != it.value()) {
        return failure();
      }
    }

    auto gather_start_indices = gather.start_indices();
    auto gather_start_indices_ty =
        gather_start_indices.getType().cast<ShapedType>();

    llvm::SmallVector<Value, 4> slice_start_indices;

    if (gather_start_indices_ty.getRank() == 0) {
      slice_start_indices.push_back(gather_start_indices);
    } else if (gather_start_indices_ty.getRank() == 1) {
      for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
        auto start = GetI64ElementsAttr({i}, &rewriter);
        auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
        auto stride = GetI64ElementsAttr({1}, &rewriter);
        auto indicesSlice = rewriter.create<SliceOp>(
            gather.getLoc(), gather_start_indices, start, limit, stride);
        auto reshaped = rewriter.create<ReshapeOp>(
            gather.getLoc(),
            RankedTensorType::get(
                {}, indicesSlice.getType().cast<ShapedType>().getElementType()),
            indicesSlice);
        slice_start_indices.push_back(reshaped);
      }
    } else {
      return failure();
    }

    auto sliceSizes = gather.slice_sizes();
    auto sliceSizesTy = sliceSizes.getType();
    if (sliceSizesTy.getRank() != 1) {
      return failure();
    }

    // Start indices have implicit zeros when not specified. This is because
    // gather behaves like a slice where full slices are inferred. Add any
    // missing zeros as necessary.
    auto zero = rewriter.create<ConstOp>(
        gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
                             {}, gather_start_indices_ty.getElementType())));
    while (slice_start_indices.size() < sliceSizesTy.getDimSize(0)) {
      slice_start_indices.push_back(zero);
    }

    SmallVector<int64_t, 5> sliceShape;
    for (auto shapeValue : gather.slice_sizes().getIntValues()) {
      sliceShape.push_back(shapeValue.getSExtValue());
    }

    auto sliceTy =
        RankedTensorType::get(sliceShape, result_ty.getElementType());
    auto slice = rewriter.create<DynamicSliceOp>(
        gather.getLoc(), sliceTy, gather.operand(), slice_start_indices,
        gather.slice_sizes());

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);

    return success();
  }
};

}  // end anonymous namespace

void PopulateOptimizeMHLOPatterns(MLIRContext* context,
                                  OwningRewritePatternList* patterns) {
  patterns->insert<GatherIsSlice>(context);
}
}  // end namespace mhlo
}  // end namespace mlir
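In short, GatherIsSlice fires only when the gather is a slice in disguise: statically shaped ranked operands, an index_vector_dim of zero, an identity start_index_map over a single dimension, offset_dims that line up one-to-one with the result dimensions, and slice sizes that match the result shape. Under those conditions the rewrite unpacks the start indices into scalars (padding with zeros, since omitted indices imply full slices), takes a single mhlo.dynamic-slice, and reshapes it to the gather's result type; the optimize-hlo.mlir tests further down exercise each of these cases.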
@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

using mlir::FunctionPass;
using mlir::PassRegistration;
using mlir::PassWrapper;

namespace {
class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
 public:
  explicit OptimizeMhlo() : PassWrapper<OptimizeMhlo, FunctionPass>() {}

  /// Performs the optional mhlo optimizations.
  void runOnFunction() override;
};
}  // end anonymous namespace

// Applies the optional mhlo optimization patterns to the function.
void OptimizeMhlo::runOnFunction() {
  // Add the optimization patterns to the list.
  mlir::OwningRewritePatternList patterns;
  mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);

  applyPatternsAndFoldGreedily(getFunction(), patterns);
}

static PassRegistration<OptimizeMhlo> pass("mhlo-test-optimize",
                                           "Run optional HLO optimizations.");
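Because the pass registers itself under the name "mhlo-test-optimize", it is driven from mlir-hlo-opt, either as a bare -mhlo-test-optimize flag or inside -pass-pipeline='func(mhlo-test-optimize)', which is exactly how the new optimize-hlo.mlir test below invokes it.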
@ -365,6 +365,16 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar
  return %0 : tensor<5x4xf32>
}

// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape
func @dynamic_broadcast_in_dim_to_same_shape(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
  %0 = shape.shape_of %arg0 : tensor<?xf32>
  %1 = shape.to_extent_tensor %0 : tensor<1xindex>
  %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK: return %[[ARG]] : tensor<?xf32>
  return %2 : tensor<?xf32>
}

// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
  %cst = mhlo.constant dense<0.000000e+00> : tensor<f32>
@ -8,7 +8,7 @@
 func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
  // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
  // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
  // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
  // CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
  // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
  // CHECK: return %[[EXTENTS]]
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

@ -18,7 +18,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
  // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
  // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
  // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
  // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
  // CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
  // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
@ -39,7 +39,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
  // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
  // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
  // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
  // CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
  // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
  // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
  // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@ -60,7 +60,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
  // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
  // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
  // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
  // CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
  // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
  // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@ -237,3 +237,77 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
  %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
  return %0 : tensor<4xi1>
}

// -----
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @addScalarUnranked(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
// CHECK-SAME: ) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]]
// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : tensor<1xindex>
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[VAL_19:.*]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }

// -----
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
      -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedScalar(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]]
// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]] : tensor<1xindex>
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_0]] : tensor<?xindex>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[VAL_19:.*]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }
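Both unranked tests above exercise the same two-stage strategy: the unranked operand is first flattened to a rank-1 tensor (shape_of, num_elements, dynamic_reshape), the ordinary ranked broadcast then runs under the shape.assuming witness, and a final dynamic_reshape restores the original extents of the unranked result.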
@ -0,0 +1,41 @@
// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s

// CHECK-LABEL: @gather_to_index_select
func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME:   batch_dims = 0 : i64,
  // CHECK-SAME:   dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x3x4xf32>
}

// CHECK-LABEL: @scalar_gather_to_index_select
func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME:   batch_dims = 0 : i64,
  // CHECK-SAME:   dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x4xf32>
}

// CHECK-LABEL: @gather_no_lowering_subslice
func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
  return %0 : tensor<1x3x3xf32>
}

// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
  return %0 : tensor<1x3x4xf32>
}
@ -487,3 +487,26 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
  } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
  return %out : tensor<3x5x5x4xf32>
}

// -----

// BOTH-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
  // BOTH: %[[OUT:.*]] = alloc() : memref<1xf32>
  // BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
  // BOTH: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
  // BOTH-SAME: %[[ARG3:.*]]: memref<f32>):
  // BOTH: %[[TMP:.*]] = alloc() : memref<f32>
  // BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
  // BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
  // BOTH: "lmhlo.terminator"() : () -> ()
  // BOTH: }) {dimensions = dense<1> : tensor<1xi64>}
  // BOTH-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (tensor<1x8xf32>, tensor<f32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}
@ -91,3 +91,25 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>,
  dealloc %0 : memref<2x2xf32>
  "lmhlo.terminator"() : () -> ()
}

// -----

// CHECK-LABEL: func @reduce
func @reduce(%arg0: memref<1x8xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) {
  %0 = alloc() : memref<1xf32>
  "lmhlo.reduce"(%arg0, %arg1, %0) ( {
  // CHECK: ^bb0(%[[ARG0:.*]]: memref<f32>, %[[ARG1:.*]]: memref<f32>,
  // CHECK-SAME: %[[ARG2:.*]]: memref<f32>)
  ^bb0(%arg3: memref<f32>, %arg4: memref<f32>, %arg5: memref<f32>):
    %1 = alloc() : memref<f32>
    // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
    "lmhlo.add"(%arg3, %arg4, %1)
        : (memref<f32>, memref<f32>, memref<f32>) -> ()
    // CHECK-NOT: lmhlo.copy
    "lmhlo.copy"(%1, %arg5) : (memref<f32>, memref<f32>) -> ()
    "lmhlo.terminator"() : () -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
  "lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}
64
tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir
Normal file
@ -0,0 +1,64 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s

// CHECK-LABEL: @gather_is_slice_no_rank
func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>

  // CHECK: return [[RESHAPE]]
  return %res : tensor<1x2xi32>
}

// CHECK-LABEL: @gather_is_slice
func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])

  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>

  // CHECK: return [[RES]]
  return %res : tensor<1x2xi32>
}

// CHECK-LABEL: @gather_is_slice_multiple_start_indices
func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
  // CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
  // CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>

  // CHECK: return [[RES]]
  return %res : tensor<1x2xi32>
}
@ -25,7 +25,6 @@ package_group(
filegroup(
    name = "tensorflow_lite_ops_td_files",
    srcs = [
        "experimental/tfl_hardware_interfaces.td",
        "ir/tfl_op_interfaces.td",
        "ir/tfl_ops.td",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
@ -221,18 +220,14 @@ cc_library(
    ],
    deps = [
        ":tensorflow_lite_ops_inc_gen",
        ":validators",
        "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/lite/schema:schema_fbs",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:DerivedAttributeOpInterface",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LoopLikeInterface",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:SideEffects",
        "@llvm-project//mlir:StandardOps",
@ -242,6 +237,28 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "constant_utils",
    srcs = [
        "utils/constant_utils.cc",
    ],
    hdrs = [
        "utils/constant_utils.h",
    ],
    copts = ["-std=c++14"],
    deps = [
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
    ],
)

cc_library(
    name = "lstm_utils",
    srcs = [
@ -273,6 +290,7 @@ cc_library(
    deps = [
        ":tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
        "//tensorflow/core:framework",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
@ -338,6 +356,7 @@ cc_library(
        "transforms/optimize_functional_ops.cc",
        "transforms/prepare_composite_functions_tf.cc",
        "transforms/prepare_tf.cc",
        "transforms/raise_custom_ops.cc",
        "transforms/runtime_verify.cc",
        "transforms/split_merged_operands.cc",
        "transforms/trim_functions_tf.cc",
@ -349,7 +368,7 @@ cc_library(
        "transforms/passes.h",
    ],
    deps = [
        ":common",
        ":constant_utils",
        ":lstm_utils",
        ":stateful_ops_utils",
        ":tensorflow_lite",
@ -360,6 +379,7 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
        "//tensorflow/compiler/xla:status",
@ -368,7 +388,6 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/kernels:tensor_list",
        "//tensorflow/core/platform:logging",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:Support",
@ -399,7 +418,6 @@ cc_library(
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
@ -433,7 +451,6 @@ cc_library(
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
@ -456,7 +473,6 @@ cc_library(
        "//tensorflow/lite/tools/optimize/sparsity:format_converter",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
@ -480,7 +496,6 @@ gentbl(
    td_srcs = [
        "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
        "experimental/tfl_hardware_interfaces.td",
        "ir/tfl_op_interfaces.td",
    ],
)
@ -609,8 +624,6 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:statusor",
@ -620,7 +633,7 @@ cc_library(
        "//tensorflow/core/platform:status",
        "//tensorflow/lite:schema_fbs_version",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
        "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib",
        "//tensorflow/lite/kernels/internal:kernel_utils",
        "//tensorflow/lite/schema:schema_fbs",
        "//tensorflow/lite/tools/versioning",
@ -651,7 +664,6 @@ cc_library(
        ":flatbuffer_tflite_operator_lib",
        ":tensorflow_lite",
        ":tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:statusor",
@ -724,7 +736,6 @@ cc_library(
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirTranslateMain",
@ -858,10 +869,8 @@ cc_library(
        "//tensorflow/core:core_cpu_base",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Transforms",
    ],
)
@ -8,9 +8,6 @@ package(
cc_library(
    name = "cost_estimators",
    textual_hdrs = [
        "estimator.h",
        "cpu_estimators.h",
        "gpu_estimators.h",
        "hardware.h",
        "arithmetic_count_util.h",
    ],
@ -15,13 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project

// For add/mul/div/sub and other broadcastable ops.
class ArithmeticCountUtilHelper {
 public:
  static bool GetArithmeticCountForBroadcastableOp(mlir::Operation* op,
                                                   int64_t* count) {
    auto output = op->getResult(0);
    auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
    auto output_type =
        output.getType().dyn_cast_or_null<mlir::RankedTensorType>();
    if (!output_type || !output_type.hasStaticShape()) return false;

    *count = output_type.getNumElements();
@ -31,7 +35,8 @@ class ArithmeticCountUtilHelper {
  static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) {
    int64_t total_count = 0;
    for (auto input : op->getOperands()) {
      auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
      auto input_type =
          input.getType().dyn_cast_or_null<mlir::RankedTensorType>();
      if (!input_type || !input_type.hasStaticShape()) {
        return false;
      }
@ -43,14 +48,16 @@ class ArithmeticCountUtilHelper {

  // For conv2d/depthwise_conv/fully_connected ops.
  // This algorithm actually comes from TOCO tooling_util.cc
  static bool GetArithmeticCountForConvAndFullyconnectedOp(Operation* op,
  static bool GetArithmeticCountForConvAndFullyconnectedOp(mlir::Operation* op,
                                                           int64_t* count) {
    auto weight = op->getOperand(1);
    auto weight_type = weight.getType().dyn_cast_or_null<RankedTensorType>();
    auto weight_type =
        weight.getType().dyn_cast_or_null<mlir::RankedTensorType>();
    if (weight_type == nullptr || !weight_type.hasStaticShape()) return false;

    auto output = op->getResult(0);
    auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
    auto output_type =
        output.getType().dyn_cast_or_null<mlir::RankedTensorType>();
    if (output_type == nullptr || !output_type.hasStaticShape()) return false;

    int64_t cols = 1;
@ -63,7 +70,8 @@ class ArithmeticCountUtilHelper {

    auto bias = op->getOperand(2);
    if (bias) {
      auto bias_type = bias.getType().dyn_cast_or_null<RankedTensorType>();
      auto bias_type =
          bias.getType().dyn_cast_or_null<mlir::RankedTensorType>();
      if (bias_type && bias_type.hasStaticShape()) {
        *count += bias_type.getNumElements();
      }
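As a worked instance of the broadcastable-op helper above: a tfl.add producing a static tensor<2x3xf32> has a ranked, statically shaped output, so GetArithmeticCountForBroadcastableOp sets *count to its 6 output elements. With the unit costs from the cpu_estimators.h header being removed below (kCPUArithmeticUnitCost = 1.0), that priced the op at 6.0, while any op without a static output shape fell through to the kCPUDefaultFixedValuedCost of 10000.0.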
@ -1,149 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_

// CPU
constexpr float kCPUArithmeticUnitCost = 1.0;

// This basically assumes pure load/store. This is just fake data.
constexpr float kCPUCopyUnitCost = 0.5;
constexpr float kCPUDefaultCost = 3.0f;

// Default values.
constexpr float kCPUDefaultFixedValuedCost = 10000.0;

// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
                                                                        &count))
      return kCPUArithmeticUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kCPUCopyUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.conv_2d
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t arithmetic_count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
            op, &arithmetic_count)) {
      return arithmetic_count * kCPUArithmeticUnitCost;
    }
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.depthwise_conv_2d
template <>
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t arithmetic_count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
            op, &arithmetic_count)) {
      return arithmetic_count * kCPUArithmeticUnitCost;
    }
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.fully_connected
template <>
class TFLiteCostEstimator<FullyConnectedOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t arithmetic_count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
            op, &arithmetic_count)) {
      return arithmetic_count * kCPUArithmeticUnitCost;
    }
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
                                                                        &count))
      return kCPUArithmeticUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.pack
template <>
class TFLiteCostEstimator<PackOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kCPUCopyUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::CPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kCPUCopyUnitCost * count;
    return kCPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
@ -1,51 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"

template <typename Op, typename TargetHardware>
class TFLiteCostEstimator {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) {
    llvm::errs() << "No defined support for op: "
                 << op->getName().getStringRef().str();
    return false;
  }
};

// All ops on CPU are supported.
// TODO(karimnosseir): Only allow TFL ops in the "TFL_OP" param.
template <typename TFL_OP>
class TFLiteCostEstimator<TFL_OP, hardware::CPU> {
 public:
  // TODO(karimnosseir): Update and use table based method and lookup
  // cost from a loadable table?
  static double GetCost(mlir::Operation* op) { return 0.0; }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
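For reference, dispatch into these estimators was purely static, through the template specializations: the primary template above is the "unsupported" fallback, and each (op, hardware) pair opted in with a specialization like the GPU ones being deleted below. A hedged sketch of a call site; the wrapper function is illustrative, not code from this patch, and it assumes the includes of estimator.h plus the constants from gpu_estimators.h:

// Illustrative only: shows how a cost query resolved through the
// TFLiteCostEstimator specializations. If no specialization matches,
// the primary template above runs and reports the op as unsupported.
double AddCostOnGpu(mlir::Operation* op) {
  using Estimator = TFLiteCostEstimator<AddOp, hardware::GPU>;
  if (!Estimator::IsSupported(op)) return kGPUDefaultFixedValuedCost;
  return Estimator::GetCost(op);
}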
@ -1,543 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
|
||||
|
||||
// GPU
|
||||
constexpr float kGPUArithmeticUnitCost = 0.2;
|
||||
|
||||
// The copy can be non-consectutive copy. This is just fake data.
|
||||
constexpr float kGPUCopyUnitCost = 0.2;
|
||||
constexpr float kGPUDefaultCost = 1.0f;
|
||||
|
||||
// Default values.
|
||||
constexpr float kGPUDefaultFixedValuedCost = 10000.0;
|
||||
|
||||
// tfl.abs
|
||||
template <>
|
||||
class TFLiteCostEstimator<AbsOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.add
|
||||
template <>
|
||||
class TFLiteCostEstimator<AddOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
int64_t count;
|
||||
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
|
||||
&count))
|
||||
return kGPUArithmeticUnitCost * count;
|
||||
return kGPUDefaultFixedValuedCost;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.average_pool_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.concatenation
|
||||
template <>
|
||||
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
int64_t count;
|
||||
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
|
||||
return kGPUCopyUnitCost * count;
|
||||
return kGPUDefaultFixedValuedCost;
|
||||
}
|
||||
|
||||
// TODO(renjieliu): We probably need to check for dynamic weights.
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.conv_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
int64_t arithmetic_count;
|
||||
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
|
||||
op, &arithmetic_count)) {
|
||||
return arithmetic_count * kGPUArithmeticUnitCost;
|
||||
}
|
||||
return kGPUDefaultFixedValuedCost;
|
||||
}
|
||||
|
||||
// TODO(renjieliu): We probably need to check for dynamic weights.
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.cos
|
||||
template <>
|
||||
class TFLiteCostEstimator<CosOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.depthwise_conv_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
int64_t arithmetic_count;
|
||||
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
|
||||
op, &arithmetic_count)) {
|
||||
return arithmetic_count * kGPUArithmeticUnitCost;
|
||||
}
|
||||
return kGPUDefaultFixedValuedCost;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.div
|
||||
template <>
|
||||
class TFLiteCostEstimator<DivOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.exp
|
||||
template <>
|
||||
class TFLiteCostEstimator<ExpOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.fully_connected
|
||||
template <>
|
||||
class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
int64_t arithmetic_count;
|
||||
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
|
||||
op, &arithmetic_count)) {
|
||||
return arithmetic_count * kGPUArithmeticUnitCost;
|
||||
}
|
||||
return kGPUDefaultFixedValuedCost;
|
||||
}
|
||||
|
||||
// TODO(renjieliu): we need to check for dynamic weights.
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.hard_swish
|
||||
template <>
|
||||
class TFLiteCostEstimator<HardSwishOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.log
|
||||
template <>
|
||||
class TFLiteCostEstimator<LogOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.logistic
|
||||
template <>
|
||||
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.max_pool_2d
|
||||
template <>
|
||||
class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.mirror_pad
|
||||
template <>
|
||||
class TFLiteCostEstimator<MirrorPadOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.maximum
|
||||
template <>
|
||||
class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.custom
|
||||
template <>
|
||||
class TFLiteCostEstimator<CustomOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.mean
|
||||
template <>
|
||||
class TFLiteCostEstimator<MeanOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// TODO(renjieiu): check for constraints.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.minimum
template <>
class TFLiteCostEstimator<MinimumOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
                                                                        &count))
      return kGPUArithmeticUnitCost * count;
    return kGPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.pad
template <>
class TFLiteCostEstimator<PadOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.pow
template <>
class TFLiteCostEstimator<PowOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.prelu
template <>
class TFLiteCostEstimator<PReluOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.relu
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.relu6
template <>
class TFLiteCostEstimator<Relu6Op, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    int64_t count;
    if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
      return kGPUCopyUnitCost * count;
    return kGPUDefaultFixedValuedCost;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.rsqrt
template <>
class TFLiteCostEstimator<RsqrtOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.sin
template <>
class TFLiteCostEstimator<SinOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.slice
template <>
class TFLiteCostEstimator<SliceOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.softmax
template <>
class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.space_to_depth
template <>
class TFLiteCostEstimator<SpaceToDepthOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.sqrt
template <>
class TFLiteCostEstimator<SqrtOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.square
template <>
class TFLiteCostEstimator<SquareOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.squared_difference
template <>
class TFLiteCostEstimator<SquaredDifferenceOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.strided_slice
template <>
class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.tanh
template <>
class TFLiteCostEstimator<TanhOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.transpose
template <>
class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.transpose_conv
template <>
class TFLiteCostEstimator<TransposeConvOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
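For orientation, a minimal hedged sketch of how these specializations are typically consulted: a caller instantiates the estimator with a statically known op type and the target-hardware tag. The `EstimateGpuCost` wrapper below is illustrative only and is not part of this header; `TFLiteCostEstimator`, `hardware::GPU`, and the `kGPU*` constants come from the code above.

  // Hypothetical helper, not from the commit: dispatch to the specialization
  // above for a statically known op type.
  template <typename Op>
  double EstimateGpuCost(mlir::Operation* op) {
    using Estimator = TFLiteCostEstimator<Op, hardware::GPU>;
    // Ops reported as unsupported should be filtered out before costing.
    if (!Estimator::IsSupported(op)) return kGPUDefaultFixedValuedCost;
    return Estimator::GetCost(op);  // e.g. EstimateGpuCost<MulOp>(mul_op)
  }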
@ -1,76 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// WARNING: This Interface is experimental, DO NOT USE.

// This is the Target Hardware operation interface definition file
// for TensorFlow Lite.

#ifndef TFL_TARGET_HARDWARE_OP_INTERFACES
#define TFL_TARGET_HARDWARE_OP_INTERFACES

def TFL_CpuTargetOp : OpInterface<"CpuOpTargetInterface"> {
  let description = [{
    Interface for ops to run on CPU.
  }];

  let methods = [
    InterfaceMethod<
      [{Returns the cost of running this op on CPU.}],
      // TODO(karimnosseir): Change to return Cost object instead.
      "double", "GetOpCost", (ins "mlir::Operation*":$op_to_check), [{
        // TODO(karimnosseir): Consider changing to another way that doesn't
        // rely on template param name.
        return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::CPU>::GetCost(op_to_check);
      }]
    >,
    InterfaceMethod<
      [{Returns whether this op can be run on CPU.}],
      "bool", "IsSupported", (ins "mlir::Operation*":$op_to_check), [{
        // TODO(karimnosseir): Consider changing to another way that doesn't
        // rely on template param name.
        return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::CPU>::IsSupported(op_to_check);
      }]
    >,
  ];
}

def TFL_GpuTargetOp : OpInterface<"GpuOpTargetInterface"> {
  let description = [{
    Interface for ops to run on GPU.
  }];

  let methods = [
    InterfaceMethod<
      [{Returns the cost of running this op on GPU.}],
      // TODO(karimnosseir): Change to return Cost object instead.
      "double", "GetOpCost", (ins "Operation*":$op_to_check), [{
        // TODO(karimnosseir): Consider changing to another way that doesn't
        // rely on template param name.
        return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::GPU>::GetCost(op_to_check);
      }]
    >,
    InterfaceMethod<
      [{Returns whether this op can be run on GPU.}],
      "bool", "IsSupported", (ins "Operation*":$op_to_check), [{
        // TODO(karimnosseir): Consider changing to another way that doesn't
        // rely on template param name.
        return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::GPU>::IsSupported(op_to_check);
      }]
    >,
  ];
}

#endif // TFL_TARGET_HARDWARE_OP_INTERFACES
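As a rough sketch of how the generated interfaces above would be queried from C++ (the helper name and the infinity fallback are assumptions, not part of this commit):

  // Hypothetical caller; GpuOpTargetInterface is the class name generated from
  // TFL_GpuTargetOp above, queried via the usual MLIR OpInterface dyn_cast.
  double GetGpuCostOrInfinity(mlir::Operation* op) {
    if (auto gpu_op = llvm::dyn_cast<GpuOpTargetInterface>(op)) {
      if (gpu_op.IsSupported(op)) return gpu_op.GetOpCost(op);
    }
    return std::numeric_limits<double>::infinity();  // op cannot run on GPU
  }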
@ -149,6 +149,9 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
      if (ftype && ftype.isF32()) {
        return tflite::TensorType_COMPLEX64;
      }
      if (ftype && ftype.isF64()) {
        return tflite::TensorType_COMPLEX128;
      }
      return Status(error::INVALID_ARGUMENT, "Unsupported type");
    }
    case mlir::StandardTypes::Integer: {
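For context, the fragment above sits inside the complex-type case of GetTFLiteType; a hedged reconstruction of the enclosing case (the surrounding lines are inferred from the hunk, not shown in it):

  // Hedged reconstruction for context only.
  case mlir::StandardTypes::Complex: {
    auto ftype = type.cast<mlir::ComplexType>().getElementType();
    if (ftype && ftype.isF32()) {
      return tflite::TensorType_COMPLEX64;   // complex<f32>
    }
    if (ftype && ftype.isF64()) {
      return tflite::TensorType_COMPLEX128;  // complex<f64>, added here
    }
    return Status(error::INVALID_ARGUMENT, "Unsupported type");
  }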
@ -1193,22 +1196,35 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
      if (IsConst(&inst)) continue;

      // Fetch operand and result tensor indices.
      std::vector<int32_t> operands;
      operands.reserve(inst.getNumOperands());
      for (auto operand : inst.getOperands()) {
        if (operand.getType().isa<NoneType>())
          operands.push_back(kTfLiteOptionalTensor);
        else
          operands.push_back(tensor_index_map.lookup(operand));
      }
      std::vector<int32_t> results;
      results.reserve(inst.getNumOperands());
      for (auto result : inst.getResults()) {
        results.push_back(tensor_index_map.lookup(result));
      }
      Operation* real_inst = &inst;
      // CustomTfOp is just a wrapper around a TF op; we export the custom op,
      // not the wrapper, so we fetch the op from the region.
      if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
        // If we have a custom op with a region, then use the first op in the
        // region, if it exists; otherwise just use params for the custom op.
        if (!custom_op.body().empty()) {
          real_inst = &custom_op.body().front().front();
        } else {
          module_.emitError(
              "Invalid CustomTfOp: Custom TF op has an empty region.");
        }
      }
      std::vector<int32_t> operands;
      operands.reserve(real_inst->getNumOperands());
      for (auto operand : real_inst->getOperands()) {
        if (operand.getType().isa<NoneType>())
          operands.push_back(kTfLiteOptionalTensor);
        else
          operands.push_back(tensor_index_map.lookup(operand));
      }

      if (auto tfl_operator =
              BuildOperator(&inst, operands, results, intermediates))
              BuildOperator(real_inst, operands, results, intermediates))
        operators.push_back(*tfl_operator);
      else
        failed_once = true;

@ -19,7 +19,6 @@ limitations under the License.
#define TFL_OP_INTERFACES

include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td"

//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.

@ -48,14 +48,9 @@ class TensorFlowLiteDialect : public Dialect {
                               Location loc) override;
};

#include "tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specialized estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h"

}  // end namespace TFL
}  // end namespace mlir

@ -410,10 +410,7 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<

class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TFL_Dialect, mnemonic, !listconcat(traits,
       [DeclareOpInterfaceMethods<TFL_RuntimeVerification>,
        // All TFL ops are supported on CPU.
        DeclareOpInterfaceMethods<TFL_CpuTargetOp>
       ])> {
       [DeclareOpInterfaceMethods<TFL_RuntimeVerification>])> {
  // FlatBuffer generation specific information.
  // -------------------------------------------
  // When generating the FlatBuffer output some operations have

@ -435,8 +432,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :

class TFL_ConvOp<string mnemonic, string opSummary, int index> :
    TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
                      AffineQuantizedOpInterface, AffineOpCoefficient<index, 1>,
                      TFL_GpuTargetOp, TFL_SparseOp]> {
                      AffineQuantizedOpInterface, AffineOpCoefficient<index, 1>, TFL_SparseOp]> {
  let summary = opSummary # " operator";

  let description = [{

@ -473,8 +469,7 @@ def TFL_AbsOp : TFL_Op<"abs", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Absolute value operator";

  let description = [{

@ -495,8 +490,7 @@ def TFL_AddOp : TFL_Op<"add", [
      CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast<AddOp>($_op))">>,
    ResultsBroadcastableShape,
    NoSideEffect,
    Commutative,
    TFL_GpuTargetOp]> {
    Commutative]> {
  let summary = "Addition operator";

  let description = [{

@ -573,7 +567,6 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
      TFL_TCresVTEtIsSameAsOp<0, 2>>,
    AccumulatorUniformScale<3, 1, 2>,
    AffineQuantizedOpInterface, AffineOpCoefficient<0, 2>,
    TFL_GpuTargetOp,
    TFL_SparseOp]> {
  let summary = "Transpose convolution operator";

@ -612,8 +605,7 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
def TFL_AveragePool2DOp:
    TFL_Op<"average_pool_2d",
           [NoSideEffect,
            SameOperandsAndResultsScale,
            TFL_GpuTargetOp]> {
            SameOperandsAndResultsScale]> {
  let summary = "Average_pool_2d operator";

  let description = [{

@ -713,8 +705,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
    NoSideEffect,
    PredOpTrait<"values and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp
    SameOperandsAndResultsScale
  ]> {
  let summary = "Concatenation operator";

@ -861,8 +852,7 @@ def TFL_CosOp: TFL_Op<"cos", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Cosine operator";

  let description = [{

@ -916,8 +906,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
    NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
    AffineQuantizedOpInterface,
    AffineOpCoefficient<-1, 1>,
    TFL_SparseOp,
    TFL_GpuTargetOp]> {
    TFL_SparseOp]> {
  let summary = "Fully connected op";

  let arguments = (ins

@ -1070,8 +1059,7 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
    ResultsBroadcastableShape,
    BinaryOpSameElementTypeConstraint,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoSideEffect,
    NoQuantizableResult]> {
    NoSideEffect]> {
  let summary = "Less_equal operator";

  let description = [{

@ -1132,8 +1120,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    ResultsBroadcastableShape,
    NoSideEffect,
    NoQuantizableResult]> {
    NoSideEffect]> {
  let summary = "Greater_equal operator";

  let description = [{

@ -1360,8 +1347,7 @@ def TFL_DivOp : TFL_Op<"div", [
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
    ResultsBroadcastableShape,
    NoSideEffect,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Division operator";

  let description = [{

@ -1427,7 +1413,6 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",

def TFL_EqualOp: TFL_Op<"equal", [
    Commutative,
    NoQuantizableResult,
    ResultsBroadcastableShape,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {

@ -1449,8 +1434,7 @@ def TFL_EqualOp: TFL_Op<"equal", [
}

def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect,
                              SameOperandsAndResultType,
                              TFL_GpuTargetOp]> {
                              SameOperandsAndResultType]> {
  let summary = "Natural exponentiation operator";

  let description = [{

@ -1634,8 +1618,7 @@ def TFL_GreaterOp : TFL_Op<"greater", [
    ResultsBroadcastableShape,
    BinaryOpSameElementTypeConstraint,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoSideEffect,
    NoQuantizableResult]> {
    NoSideEffect]> {
  let summary = "Greater operator";

  let description = [{

@ -1659,8 +1642,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [
    NoSideEffect,
    SameOperandsAndResultShape,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_GpuTargetOp]> {
      TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
  let summary = "Hardswish activation function.";
  let description = [{
    Computes hard-swish activation function

@ -1735,8 +1717,7 @@ def TFL_LessOp : TFL_Op<"less", [
    ResultsBroadcastableShape,
    BinaryOpSameElementTypeConstraint,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoSideEffect,
    NoQuantizableResult]> {
    NoSideEffect]> {
  let summary = "Less operator";

  let description = [{

@ -1812,8 +1793,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
    PredOpTrait<"x and y must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    SameOperandsAndResultShape,
    FixedOutputRangeInterface,
    TFL_GpuTargetOp]> {
    FixedOutputRangeInterface]> {
  let summary = "Logistic operator";

  let description = [{

@ -1841,8 +1821,7 @@ def TFL_LogOp: TFL_Op<"log", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Natural logarithm operator";

  let description = [{

@ -1908,8 +1887,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    NoSideEffect,
    MaxPoolOperandAndResultConstraints,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
    SameOperandsAndResultsScale]> {
  let summary = "Max Pool 2D op";

  let description = [{

@ -1941,8 +1919,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
    NoSideEffect,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
    Commutative,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
    SameOperandsAndResultsScale]> {
  let summary = "Max operator";
  let description = [{
    Element-wise max operation.

@ -1965,8 +1942,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
def TFL_MeanOp : TFL_Op<"mean", [
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    NoSideEffect,
    TFL_GpuTargetOp]> {
    NoSideEffect]> {
  let summary = "Mean operator";

  let description = [{

@ -2044,8 +2020,7 @@ def TFL_SliceOp : TFL_Op<"slice", [
    SameOperandsAndResultsScale,
    TFL_OperandHasRankAtMost<0, 4>,
    TFL_OperandHasRankAtMost<1, 1>,
    TFL_OperandHasRankAtMost<2, 1>,
    TFL_GpuTargetOp]> {
    TFL_OperandHasRankAtMost<2, 1>]> {
  let summary = "Return a slice from 'input'.";

  let description = [{

@ -2176,8 +2151,7 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
    NoSideEffect,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
    Commutative,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
    SameOperandsAndResultsScale]> {
  let summary = "Min operator";
  let description = [{
    Element-wise min operation.

@ -2203,8 +2177,7 @@ def TFL_MulOp : TFL_Op<"mul", [
    Commutative,
    BinaryOpSameElementTypeConstraint,
    TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
      CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast<MulOp>($_op))">>,
    TFL_GpuTargetOp]> {
      CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast<MulOp>($_op))">>]> {
  let summary = "Multiplication operator";

  let description = [{

@ -2310,8 +2283,7 @@ def TFL_PadOp : TFL_Op<"pad", [
    TFL_OperandRankEquals1DimOfOperand<0, 1>,
    PredOpTrait<"the first dim size of the padding argument must be at most 4",
      Or<[TFL_OperandIsUnrankedPred<1>,
          TFL_OperandDimIsAtMost<1, 0, 4>]>>,
    TFL_GpuTargetOp]> {
          TFL_OperandDimIsAtMost<1, 0, 4>]>>]> {
  let summary = "Padding operator";

  let description = [{

@ -2404,8 +2376,7 @@ def TFL_PowOp : TFL_Op<"pow", [
    ResultsBroadcastableShape,
    NoSideEffect,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Power operator";

  let description = [{

@ -2428,7 +2399,6 @@ def TFL_PowOp : TFL_Op<"pow", [
def TFL_PReluOp : TFL_Op<"prelu", [
    NoSideEffect,
    ResultsBroadcastableShape,
    TFL_GpuTargetOp,
    TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
    BinaryOpSameElementTypeConstraint,
    PredOpTrait<"input and output must have the same element type",

@ -2470,8 +2440,7 @@ def TFL_ReluOp: TFL_Op<"relu", [
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
    SameOperandsAndResultsScale]> {
  let summary = "Relu operator";

  let description = [{

@ -2500,8 +2469,7 @@ def TFL_Relu6Op: TFL_Op<"relu6", [
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
    SameOperandsAndResultsScale]> {
  let summary = "Relu6 operator";

  let description = [{

@ -2555,7 +2523,7 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [
}

def TFL_ReshapeOp: TFL_Op<"reshape", [
    NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> {
    NoSideEffect, SameOperandsAndResultsScale]> {
  let summary = "Reshape operator";

  let description = [{

@ -2610,8 +2578,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
                                  SameOperandsAndResultType,
                                  SameOperandsAndResultShape,
                                  NoQuantizableResult,
                                  TFL_GpuTargetOp]> {
                                  NoQuantizableResult]> {
  let summary = "Reciprocal of square root operator";

  let description = [{

@ -2706,6 +2673,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
// are unranked. Therefore, we skip adding shape constraints here.
def TFL_SelectOp : TFL_Op<"select", [
    NoSideEffect,
    SameOperandsAndResultsScale,
    PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
    PredOpTrait<"operands and result have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 1>>]> {

@ -2777,8 +2745,7 @@ def TFL_SinOp: TFL_Op<"sin", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Sine operator";

  let description = [{

@ -2798,8 +2765,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_OperandHasRankRange<0, 1, 4>,
    SameOperandsAndResultShape,
    FixedOutputRangeInterface,
    TFL_GpuTargetOp]> {
    FixedOutputRangeInterface]> {
  let summary = "Softmax operator";

  let description = [{

@ -2834,8 +2800,7 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Square root operator";

  let description = [{

@ -2853,8 +2818,7 @@ def TFL_SquareOp: TFL_Op<"square", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultType,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Square operator";

  let description = [{

@ -2907,8 +2871,7 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
    SameOperandsAndResultElementType,
    ResultsBroadcastableShape,
    NoSideEffect,
    NoQuantizableResult,
    TFL_GpuTargetOp]> {
    NoQuantizableResult]> {
  let summary = "Squared difference operator";

  let description = [{

@ -2933,8 +2896,7 @@ def TFL_TanhOp: TFL_Op<"tanh", [
    SameOperandsAndResultShape,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    FixedOutputRangeInterface,
    TFL_GpuTargetOp]> {
    FixedOutputRangeInterface]> {
  let summary = "Hyperbolic tangent operator";

  let description = [{

@ -3035,8 +2997,7 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
    TFL_OperandHasRank<1, 1>,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp]> {
    SameOperandsAndResultsScale]> {
  let summary = "Transpose operator";

  let description = [{

@ -3170,8 +3131,7 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
    SameOperandsAndResultsScale,
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_OperandHasRankAtMost<0, 4>,
    TFL_GpuTargetOp
    TFL_OperandHasRankAtMost<0, 4>
  ]> {
  let summary = "SpaceToDepth operator";

@ -3383,8 +3343,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
    TFL_OperandHasRankAtMost<0, 5>,
    TFL_OperandHasRank<1, 1>,
    TFL_OperandHasRank<2, 1>,
    TFL_OperandHasRank<3, 1>,
    TFL_GpuTargetOp
    TFL_OperandHasRank<3, 1>
  ]> {
  let summary = "StridedSlice Op";

@ -3434,7 +3393,7 @@ def TFL_CastOp : TFL_Op<"cast", [
}

def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
    NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> {
    NoSideEffect, TFL_OperandHasRank<1, 2>]> {
  let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";

  let description = [{

@ -4337,7 +4296,8 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
  let hasCanonicalizer = 1;
}

def TFL_CustomOp : Op<TFL_Dialect, "custom", [NoSideEffect]> {
def TFL_CustomOp : Op<TFL_Dialect, "custom", [
    NoSideEffect, NoQuantizableResult]> {
  let summary = "Custom op";

  let description = [{

@ -4360,4 +4320,80 @@ def TFL_CustomOp : Op<TFL_Dialect, "custom", [NoSideEffect]> {
  let verifier = [{ return Verify(*this); }];
}

def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
    // Currently the custom ops have no side effect
    // TODO(karimnosseir): Revisit if this needs updating.
    NoSideEffect,
    NoQuantizableResult,
    SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Wrapper Op for TF custom ops.";

  let description = [{
    A wrapper op around any custom TF op. These include ops defined using
    custom_opdefs or linked ops that are not defined in the TF dialect.
    This op just wraps the custom op inside a region.
    Note #1: this op does not include TF Lite custom ops defined using CustomOp.
    Note #2: this op is just an internal representation inside the converter and
    is not exposed/exported when the model is exported to FlatBuffer.
  }];

  let arguments = (ins
    Variadic<TFL_TensorOfOrNone<[AnyType]>>:$input
  );
  let results = (outs Variadic<AnyTensor>:$output);

  let regions = (region SizedRegion<1>:$body);
}

def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_OperandHasRankAtMost<0, 8>,
    TFL_OperandHasRank<1, 1>,
    PredOpTrait<"output dimension count must be at most 8",
      Or<[TFL_OperandIsUnrankedPred<1>,
          TFL_OperandDimIsAtMost<1, 0, 8>]>>,
    NoSideEffect]> {
  let summary = "Broadcast an array for a compatible shape.";

  let description = [{
    Broadcasting is the process of making arrays to have compatible shapes
    for arithmetic operations. Two shapes are compatible if for each
    dimension pair they are either equal or one of them is one. When trying
    to broadcast a Tensor to a shape, it starts with the trailing dimensions,
    and works its way forward.

    For example,

    >>> x = tf.constant([1, 2, 3])
    >>> y = tf.broadcast_to(x, [3, 3])
    >>> print(y)
    tf.Tensor(
        [[1 2 3]
         [1 2 3]
         [1 2 3]], shape=(3, 3), dtype=int32)

    In the above example, the input Tensor with the shape of `[1, 3]`
    is broadcasted to output Tensor with shape of `[3, 3]`.

    When doing broadcasted operations such as multiplying a tensor
    by a scalar, broadcasting (usually) confers some time or space
    benefit, as the broadcasted tensor is never materialized.

    However, `broadcast_to` does not carry with it any such benefits.
    The newly-created tensor takes the full memory of the broadcasted
    shape. (In a graph context, `broadcast_to` might be fused to
    subsequent operation and then be optimized away, however.)
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
    TFL_I32OrI64Tensor:$shape
  );

  let results = (outs
    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
  );
}

#endif // TFL_OPS
@ -151,10 +151,13 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
    return errors::Unimplemented("Only support a single exported name.");
  }

  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = true;

  TF_ASSIGN_OR_RETURN(auto module,
                      ImportSavedModel(model_flags.saved_model_dir(),
                                       model_flags.saved_model_version(), tags,
                                       exported_names, &context));
                                       exported_names, specs, &context));

  if (!model_flags.input_arrays().empty() ||
      !model_flags.output_arrays().empty()) {
@ -123,6 +123,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
      return DT_BOOL;
    case toco::IODataType::COMPLEX64:
      return DT_COMPLEX64;
    case toco::IODataType::COMPLEX128:
      return DT_COMPLEX128;
    default:
      return DT_INVALID;
  }
@ -81,7 +81,6 @@ cc_library(
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
@ -36,11 +36,11 @@ versions {
  producer: 27
}

# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>
# CHECK-NEXT: }
# CHECK-LABEL: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {control_outputs = "", inputs = "input0,input1", outputs = "output"}} {
# CHECK-NEXT: %[[CUSTOM:.*]] = "tfl.custom_tf"(%arg0, %arg1) ( {
# CHECK-NEXT: %[[OUTPUTS:.*]] = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: "tfl.yield"(%[[OUTPUTS]]) : (tensor<*xi32>) -> ()
# CHECK-NEXT: }) : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[CUSTOM]] : tensor<*xi32>
# CHECK-NEXT: }
@ -15,6 +15,13 @@ func @complex64() -> tensor<4xcomplex<f32>> {
  return %0 : tensor<4xcomplex<f32>>
}

func @complex128() -> tensor<4xcomplex<f64>> {
  // CHECK-LABEL: @complex128
  // CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>>
  %0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>> } : () -> tensor<4xcomplex<f64>>
  return %0 : tensor<4xcomplex<f64>>
}

// TODO(b/138847107) this should work but doesn't
// func @f16() -> tensor<4xf16> {
//   %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf16> } : () -> tensor<4xf16>
@ -1,7 +1,7 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s
module {

  func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
  func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
    %0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
    %1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
    %2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -1027,11 +1027,11 @@ module {
    return %1 : tensor<i1>
  }

  // CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
  // CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
  // CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
  // CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>

  func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<?x1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
  func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
    %0 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
    %1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
@ -2160,11 +2160,12 @@ module {
  }


  // CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<?x1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {

  // CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
  // CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<?x1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>)
  // CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>

  func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
  func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
    %0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
    %1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
    %2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -3190,7 +3191,7 @@ module {
    return %1 : tensor<i1>
  }

  // CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
  // CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
  // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
  // CHECK: return %0 : tensor<?x!tf.string>
}
@ -5,8 +5,6 @@ func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3
  return %0: tensor<3x3xbf16>

  // CHECK-LABEL: broadcast_to_bf16
  // CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
  // CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
  // CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
  // CHECK: return [[MUL]] : tensor<3x3xbf16>
  // CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
  // CHECK: return [[BCT]] : tensor<3x3xbf16>
}
@ -1487,10 +1487,8 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3
  return %0: tensor<3x3xf32>

  // CHECK-LABEL: broadcast_to_f32
  // CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
  // CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
  // CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
  // CHECK: return [[MUL]] : tensor<3x3xf32>
  // CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
  // CHECK: return [[BCT]] : tensor<3x3xf32>
}

func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
@ -1498,10 +1496,8 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
  return %0: tensor<3x3xi32>

  // CHECK-LABEL: broadcast_to_i32
  // CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
  // CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
  // CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
  // CHECK: return [[MUL]] : tensor<3x3xi32>
  // CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
  // CHECK: return [[BCT]] : tensor<3x3xi32>
}

func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
@ -277,6 +277,45 @@ func @tensorlistWhileCond(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> ten
// CHECK: return %[[RESULT]] : tensor<i1>
}

// CHECK-LABEL: func @tensorlistWhileRegion
func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
  %cst = constant dense<3> : tensor<1xi32>
  %cst_0 = constant dense<0> : tensor<i32>
  %cst_1 = constant dense<-1> : tensor<i32>
  %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<3xf32>>>
  // CHECK: "tf.WhileRegion"
  %1:2 = "tf.WhileRegion"(%cst_0, %0) ({
    ^bb0(%carg0: tensor<i32>, %carg1: tensor<!tf.variant>):
      %cst_2 = constant dense<2> : tensor<i32>
      %1 = "tf.Less"(%carg0, %cst_2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "tf.Yield"(%1) : (tensor<i1>) -> ()

      // verify condition types
      // CHECK: ^bb0(%[[CARG0:.*]]: tensor<i32>, %[[CARG1:.*]]: tensor<*xf32>):
      // CHECK: %[[COND:.*]] = "tf.Less"(%[[CARG0]], {{.*}}) : (tensor<i32>, tensor<i32>) -> tensor<i1>
      // CHECK: "tf.Yield"(%[[COND]]) : (tensor<i1>) -> ()

  },
  {
    ^bb0(%barg0: tensor<i32>, %barg1: tensor<!tf.variant>):
      %1 = "tf.TensorListLength"(%barg1) : (tensor<!tf.variant>) -> tensor<i32>
      "tf.Yield"(%1, %barg1) : (tensor<i32>, tensor<!tf.variant>) -> ()

      // verify body types
      // CHECK: ^bb0(%[[BARG0:.*]]: tensor<i32>, %[[BARG1:.*]]: tensor<*xf32>):
      // CHECK-NOT: tensor<!tf.variant>
      // CHECK: %[[LEN:.*]] = "tf.Gather"
      // CHECK-NOT: tensor<!tf.variant>
      // CHECK: "tf.Yield"(%[[LEN]], %[[BARG1]]) : (tensor<i32>, tensor<*xf32>) -> ()

  }) {is_stateless = false} : (tensor<i32>, tensor<!tf.variant<tensor<3xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
  // make sure the variant types in input/output have been updated
  // CHECK: {is_stateless = false} : (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
  %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>) -> tensor<*xf32>
  // CHECK: return %0#1 : tensor<*xf32>
  return %2 : tensor<*xf32>
}

func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<?x10xf32> {
  %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
  %1 = "tf.TensorListResize"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>) -> tensor<!tf.variant<tensor<10xf32>>>
@ -0,0 +1,66 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s

func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> {
^bb0(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<4xcomplex<f64>>):
// CHECK: {
// CHECK-NEXT:  version: 3,
// CHECK-NEXT:  operator_codes: [ {
// CHECK-NEXT:    builtin_code: CUSTOM,
// CHECK-NEXT:    custom_code: "FlexAdd"
// CHECK-NEXT:  } ],
// CHECK-NEXT:  subgraphs: [ {
// CHECK-NEXT:    tensors: [ {
// CHECK-NEXT:      shape: [ 4 ],
// CHECK-NEXT:      type: COMPLEX128,
// CHECK-NEXT:      buffer: 1,
// CHECK-NEXT:      name: "arg0",
// CHECK-NEXT:      quantization: {
// CHECK-EMPTY:
// CHECK-NEXT:      }
// CHECK-NEXT:    }, {
// CHECK-NEXT:      shape: [ 4 ],
// CHECK-NEXT:      type: COMPLEX128,
// CHECK-NEXT:      buffer: 2,
// CHECK-NEXT:      name: "arg1",
// CHECK-NEXT:      quantization: {
// CHECK-EMPTY:
// CHECK-NEXT:      }
// CHECK-NEXT:    }, {
// CHECK-NEXT:      shape: [ 4 ],
// CHECK-NEXT:      type: COMPLEX128,
// CHECK-NEXT:      buffer: 3,
// CHECK-NEXT:      name: "add",
// CHECK-NEXT:      quantization: {
// CHECK-EMPTY:
// CHECK-NEXT:      }
// CHECK-NEXT:    } ],
// CHECK-NEXT:    inputs: [ 0, 1 ],
// CHECK-NEXT:    outputs: [ 2 ],
// CHECK-NEXT:    operators: [ {
// CHECK-NEXT:      inputs: [ 0, 1 ],
// CHECK-NEXT:      outputs: [ 2 ],
// CHECK-NEXT:      custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 18, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ]
// CHECK-NEXT:    } ],
// CHECK-NEXT:    name: "main"
// CHECK-NEXT:  } ],
// CHECK-NEXT:  description: "MLIR Converted.",
// CHECK-NEXT:  buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT:  }, {
// CHECK-EMPTY:
// CHECK-NEXT:  }, {
// CHECK-EMPTY:
// CHECK-NEXT:  }, {
// CHECK-EMPTY:
// CHECK-NEXT:  }, {
// CHECK-NEXT:    data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT:  } ],
// CHECK-NEXT:  metadata: [ {
// CHECK-NEXT:    name: "min_runtime_version",
// CHECK-NEXT:    buffer: 4
// CHECK-NEXT:  } ]
// CHECK-NEXT:}

  %0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> loc("add")
  return %0 : tensor<4xcomplex<f64>>
}
@ -598,6 +598,16 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32

// -----

func @testTFLiteDetectionPostProcess(%arg0: tensor<1x64x64x32xf32>, %arg1: tensor<1x64x64x32xf32>, %arg2: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
  %0, %1, %2, %3 = "tfl.custom_tf"(%arg0, %arg1, %arg2) ({
    %4, %5, %6, %7 = "tf.TFLite_Detection_PostProcess"(%arg0, %arg1, %arg2) {_output_quantized = true, _output_types = [f32, f32, f32, f32], _support_output_type_float_in_quantized_op = true, detections_per_class = 100 : i64, device = "", h_scale = 5.000000e+00 : f32, max_classes_per_detection = 1 : i64, max_detections = 20 : i64, nms_iou_threshold = 6.000000e-01 : f32, nms_score_threshold = 3.000000e-01 : f32, num_classes = 90 : i64, use_regular_nms = false, w_scale = 5.000000e+00 : f32, x_scale = 1.000000e+01 : f32, y_scale = 1.000000e+01 : f32} : (tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
    "tfl.yield"(%4, %5, %6, %7) : (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) -> ()
  }) : (tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
  return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}

// -----

func @testMaxPoolingWithArgMax2D(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
  // custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
  %0, %1 = "tfl.custom"(%arg0) {custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
@ -2300,3 +2310,21 @@ func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<i32> {
  }) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>)
  return %0#0 : tensor<i32>
}

// -----

// CHECK-LABEL: testBroadcastToWithI32ShapeTensor
func @testBroadcastToWithI32ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi32>):
  // CHECK: "tfl.broadcast_to"(%arg0, %arg1)
  %0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}

// CHECK-LABEL: testBroadcastToWithI64ShapeTensor
func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi64>):
  // CHECK: "tfl.broadcast_to"(%arg0, %arg1)
  %0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}
@ -992,3 +992,13 @@ func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: return %arg0
}

func @squaredDifferenceReluRemoveRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
  %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
  %1 = "tfl.relu"(%0) : (tensor<1xf32>) -> tensor<1xf32>
  return %1: tensor<1xf32>

  // CHECK-LABEL: squaredDifferenceReluRemoveRelu
  // CHECK: %[[RESULT:.*]] = tfl.squared_difference %arg0, %arg1 : tensor<1xf32>
  // CHECK: return %[[RESULT]]
}

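The test above relies on tfl.squared_difference producing non-negative values, which makes a following relu an identity. A hedged sketch of such a rewrite as an MLIR pattern (illustrative only; not necessarily how the converter's optimize pass expresses it):

  // Illustrative pattern: drop a relu fed by squared_difference, since
  // (x - y)^2 >= 0 already satisfies the relu clamp.
  struct RemoveReluAfterSquaredDifference
      : public mlir::OpRewritePattern<TFL::ReluOp> {
    using OpRewritePattern::OpRewritePattern;
    mlir::LogicalResult matchAndRewrite(
        TFL::ReluOp relu, mlir::PatternRewriter& rewriter) const override {
      if (!relu.x().getDefiningOp<TFL::SquaredDifferenceOp>())
        return mlir::failure();
      rewriter.replaceOp(relu, relu.x());  // forward the non-negative input
      return mlir::success();
    }
  };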
@ -481,3 +481,15 @@ func @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf3
  // expected-error @+1 {{TFLite does not support batched input for non_max_suppression_padded}}
  func @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
}

// -----

module {
  // CHECK-LABEL: func @some_func
  // CHECK-LABEL: func @func_with_call
  func @some_func(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes {tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c"}
  func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> {
    %0 = call @some_func(%arg0) : (tensor<100xf32>) -> tensor<100xf32>
    return %0 : tensor<100xf32>
  }
}
@ -578,3 +578,70 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) ->
  // CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
  // CHECK: return %[[RES]]
}

func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
  return %0 : tensor<3x3xf32>

  // CHECK-LABEL: broadcast_to_f32_low_dim
  // CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
  // CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
  // CHECK: return [[MUL]] : tensor<3x3xf32>
}

func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
  return %0 : tensor<3x3xi32>

  // CHECK-LABEL: broadcast_to_i32_low_dim
  // CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
  // CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
  // CHECK: return [[MUL]] : tensor<3x3xi32>
}

func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32>
  return %0 : tensor<3x3xf32>

  // CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape
  // CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
  // CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
  // CHECK: return [[MUL]] : tensor<3x3xf32>
}

func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> {
  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32>
  return %0 : tensor<*xi32>

  // CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output
  // CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
  // CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<*xi32>
  // CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32>
  // CHECK: return [[MUL]] : tensor<*xi32>
}

func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
  return %0 : tensor<7x8x1x2x3x4x5x6xf32>

  // CHECK-LABEL: broadcast_to_high_dim_with_unknown_shape
  // CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
  // CHECK: return [[BCT]] : tensor<7x8x1x2x3x4x5x6xf32>
}

func @broadcast_to_high_dim_with_unknown_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<*xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>

  // CHECK-LABEL: broadcast_to_high_dim_with_unknown_output
  // CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
  // CHECK: return [[BCT]] : tensor<*xf32>
}

func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>

  // CHECK-LABEL: broadcast_to_with_unknown_shape_and_output
  // CHECK: "tf.BroadcastTo"(%arg0, %arg1)
}
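The low-dim cases above lower tf.BroadcastTo to a tf.Mul against a splat constant of ones (or a tf.Fill of ones when the output shape is only known dynamically), relying on the identity x * ones(shape) == broadcast_to(x, shape) under standard trailing-axis broadcast rules. A minimal standalone C++ illustration of that identity for the 3 -> 3x3 case (plain code, not the pass itself):

#include <cassert>

int main() {
  // Broadcast a length-3 vector to 3x3 by multiplying with a 3x3 tensor of
  // ones, mirroring the BroadcastTo -> Mul(ones) rewrite checked above. The
  // vector is replicated along rows per trailing-axis broadcast semantics.
  const float x[3] = {1.0f, 2.0f, 3.0f};
  float out[3][3];
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 3; ++j) out[i][j] = x[j] * 1.0f;  // ones[i][j] == 1
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 3; ++j) assert(out[i][j] == x[j]);
  return 0;
}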
20
tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir
Normal file
@ -0,0 +1,20 @@
// RUN: tf-opt -tfl-raise-custom-ops -canonicalize %s -o - | FileCheck %s

// CHECK-LABEL: custom_op
func @custom_op(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
  %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // This op will be preserved since its result has uses.
  %2 = "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // This op will be removed since its result has no uses and it has no side effects.
  "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %2 : tensor<4xf32>

  // CHECK-NEXT: %[[CST:.*]] = constant dense<1.000000e+00>
  // CHECK-NEXT: %[[MUL:.*]] = tfl.mul %arg0, %[[CST]] {fused_activation_function = "NONE"} : tensor<4xf32>
  // CHECK-NEXT: %[[CUSTOM:.*]] = "tfl.custom_tf"(%[[MUL]], %[[CST]]) ( {
  // CHECK-NEXT: %[[MY_CUSTOM:.*]] = "tf.MyCustomOp"(%[[MUL]], %[[CST]]) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // CHECK-NEXT: "tfl.yield"(%[[MY_CUSTOM]]) : (tensor<4xf32>) -> ()
  // CHECK-NEXT: }) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // CHECK-NEXT: return %[[CUSTOM]] : tensor<4xf32>
}
@ -175,6 +175,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
    // Add a shape inference pass to optimize away the unnecessary casts.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // Inline function calls that are left in the graph after folding functional
  // control flow ops (IfOp, CaseOp).
  pass_manager->addPass(mlir::createInlinerPass());

  pass_manager->addPass(
      mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
  pass_manager->addPass(mlir::TFL::CreateOptimizePass());
@ -182,6 +187,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
  // so that it can target constants introduced once TensorFlow Identity ops
  // are removed during legalization.
  pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
  pass_manager->addPass(mlir::TFL::CreateRaiseCustomOpsPass());
  pass_manager->addPass(mlir::createSymbolDCEPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
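For context, the fragment above inserts the new raise-custom-ops step between the existing functional-ops optimization and the final cleanup passes. A condensed C++ sketch of just that tail of the pipeline, using only the pass constructors visible in this diff (the include paths are assumptions based on the usual TF/MLIR layout):

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

void AddRaiseCustomOpsTail(mlir::PassManager* pass_manager) {
  // Fold away functional ops where possible, then wrap the remaining
  // unsupported TF ops in tfl.custom_tf regions.
  pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
  pass_manager->addPass(mlir::TFL::CreateRaiseCustomOpsPass());
  // Remove symbols that became dead, then run per-function cleanups.
  pass_manager->addPass(mlir::createSymbolDCEPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
}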